You can run this notebook directly on Google Colab:

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DaniAffCH/Vessel-Geometric-Transformers/blob/main/main.ipynb)

In [1]:
import sys

COLAB_RUNTIME = 'google.colab' in sys.modules
if COLAB_RUNTIME:
    !git init
    !git remote add origin https://github.com/DaniAffCH/Vessel-Geometric-Transformers.git
    !git pull origin main
    !pip install -q -r requirements.txt
else:  # Development mode: set up the pre-commit checks
    !pip install -r requirements.txt
    !pre-commit autoupdate
    !pre-commit install


[https://github.com/psf/black] already up to date!
[https://github.com/pycqa/isort] already up to date!
[https://github.com/PyCQA/flake8] already up to date!
[https://github.com/pre-commit/mirrors-mypy] updating v1.11.2 -> v1.12.0
pre-commit installed at .git/hooks/pre-commit


Loading the configuration

In [2]:
from src.utils import load_config
import os

config_path = os.path.join("config","config.yaml")
config = load_config(config_path)
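
The later cells read nested settings as attributes (`config.dataset`, `config.trainer`, and so on). The implementation of `load_config` is not shown in this notebook; the block below is only a minimal sketch of one way to get that attribute-style access from a parsed YAML dictionary. The helper `to_namespace` and every key and value in `raw` are hypothetical.

```python
from types import SimpleNamespace


def to_namespace(obj):
    """Recursively convert dicts (and lists of dicts) into SimpleNamespace objects."""
    if isinstance(obj, dict):
        return SimpleNamespace(
            **{key: to_namespace(value) for key, value in obj.items()}
        )
    if isinstance(obj, list):
        return [to_namespace(item) for item in obj]
    return obj


# Stand-in for the dict a YAML parser would return.
# All keys and values here are hypothetical, not the project's real settings.
raw = {
    "dataset": {"batch_size": 8},
    "trainer": {"max_epochs": 10},
}

example_config = to_namespace(raw)
print(example_config.dataset.batch_size, example_config.trainer.max_epochs)
```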

---

Loading the dataset

In [3]:
from src.data import VesselDataModule
from src.utils.data_analysis import data_info

data = VesselDataModule(config.dataset)
data_info(data)

Processing...
Loading /home/neverorfrog/code/Vessel-Geometric-Transformers/dataset/stead/bifurcating/raw/database.hdf5: 100%|██████████| 1999/1999 [00:06<00:00, 288.50it/s]
Loading /home/neverorfrog/code/Vessel-Geometric-Transformers/dataset/stead/single/raw/database.hdf5: 100%|██████████| 2000/2000 [00:04<00:00, 463.57it/s]



Data distribution

In [None]:
import seaborn as sns
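# Pass the labels explicitly as x: recent seaborn versions treat the first positional argument as `data`
sns.countplot(x=data.data.label)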
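
The class balance shown by the count plot can also be checked numerically. The sketch below rebuilds a label list from the sample counts in the loading log (1999 bifurcating, 2000 single) and counts it with the standard library. It assumes every loaded sample ends up in the dataset and uses the folder names as label names; the real labels in `data.data.label` may be encoded differently.

```python
from collections import Counter

# Label list rebuilt from the sample counts in the loading log.
# Assumption: each loaded sample is kept and labeled after its folder.
labels = ["bifurcating"] * 1999 + ["single"] * 2000

counts = Counter(labels)
total = sum(counts.values())
fractions = {name: n / total for name, n in counts.items()}

for name, n in counts.items():
    print(f"{name}: {n} samples ({fractions[name]:.1%})")
```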

In [None]:
from src.utils.definitions import Feature, Category
from src.utils.data_analysis import plot_data

wss, labels = data.extract_feature(Feature.WSS)
pos, labels = data.extract_feature(Feature.POS)
pressure, labels = data.extract_feature(Feature.PRESSURE)
face, labels = data.extract_feature(Feature.FACE)
plot_data(pos, labels, Category, "Position")
plot_data(wss, labels, Category, "Wall Shear Stress")
plot_data(pressure, labels, Category, "Pressure")
plot_data(face, labels, Category, "Face")

---

Performing equivariance check

In [None]:
from src.lib.geometricAlgebraElements import GeometricAlgebraBase
from src.test.test_equivariance import TestEquivariance
import unittest

dl = data.train_dataloader()

batch = next(iter(dl)).data[0]
batch = batch.view(-1, GeometricAlgebraBase.GA_size)[:10]
TestEquivariance.INPUT_DATA = batch

suite = unittest.TestSuite()
suite.addTests(unittest.TestLoader().loadTestsFromTestCase(TestEquivariance))
test_runner = unittest.TextTestRunner(verbosity=0)
test_result = test_runner.run(suite)
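
For readers unfamiliar with the property being tested: a map `f` is equivariant to a transformation `g` when `f(g(x)) == g(f(x))`. The block below is a toy illustration of that check, run through `unittest` in the same way as the cell above. It is not the project's `TestEquivariance`: it uses plain 2D rotations rather than geometric algebra multivectors, and `rotate`, `toy_layer`, and `ToyEquivarianceTest` are hypothetical names.

```python
import math
import random
import unittest


def rotate(point, angle):
    """Rotate a 2D point counter-clockwise by `angle` radians."""
    x, y = point
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def toy_layer(point):
    """Hypothetical equivariant map: rescale a point by a function of its norm."""
    x, y = point
    scale = 1.0 + math.hypot(x, y)  # the norm is unchanged by rotations
    return (scale * x, scale * y)


class ToyEquivarianceTest(unittest.TestCase):
    def test_rotation_equivariance(self):
        rng = random.Random(0)  # fixed seed for reproducibility
        for _ in range(100):
            p = (rng.uniform(-1, 1), rng.uniform(-1, 1))
            angle = rng.uniform(0, 2 * math.pi)
            transformed_then_mapped = toy_layer(rotate(p, angle))
            mapped_then_transformed = rotate(toy_layer(p), angle)
            for a, b in zip(transformed_then_mapped, mapped_then_transformed):
                self.assertAlmostEqual(a, b, places=9)


toy_suite = unittest.TestLoader().loadTestsFromTestCase(ToyEquivarianceTest)
toy_result = unittest.TextTestRunner(verbosity=0).run(toy_suite)
```

The returned result object exposes `wasSuccessful()`, which is handy when the check should gate later cells.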

---

Baseline

In [None]:
from src.models import BaselineTransformer
from src.utils.hpo import baseline_hpo

# Hyperparameter optimization: writes the best hyperparameters to the config file
baseline_hpo(config, data)

In [None]:
from src.trainer import VesselTrainer

model = BaselineTransformer(config.baseline)
trainer = VesselTrainer(config.trainer)
trainer.fit(model, data)

In [None]:
trainer.test(model, data)

---

GATr

In [None]:
from src.models import Gatr
from src.utils.hpo import gatr_hpo

# Hyperparameter optimization: writes the best hyperparameters to the config file
gatr_hpo(config, data)

In [None]:
from src.trainer import VesselTrainer

model = Gatr(config.gatr)
trainer = VesselTrainer(config.trainer)
trainer.fit(model, data)

In [9]:
trainer.test(model, data)

---