You can run this notebook directly on Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DaniAffCH/Vessel-Geometric-Transformers/blob/main/main.ipynb)

In [None]:
import sys

COLAB_RUNTIME = 'google.colab' in sys.modules
if COLAB_RUNTIME:
    !git init
    !git remote add origin https://github.com/DaniAffCH/Vessel-Geometric-Transformers.git
    !git pull origin main
    !pip install -q -r requirements.txt
else: # Development mode, setting precommit checks 
    !pip install -r requirements.txt
    !pre-commit autoupdate
    !pre-commit install


Loading the configuration

In [2]:
from src.utils import load_config
import os
from config import DatasetConfig, TrainerConfig, BaselineConfig, GatrConfig

# Path of the YAML configuration file, relative to the current working directory.
config_path = os.path.join("config","config.yaml")

# Load the configuration once, then expose each of its sections under a typed
# name so the following cells can pass them around directly.
config = load_config(config_path)
dataset_config: DatasetConfig = config.dataset
trainer_config: TrainerConfig = config.trainer
baseline_config: BaselineConfig = config.baseline
gatr_config: GatrConfig = config.gatr

---

Loading the dataset

In [None]:
from src.data import VesselDataModule
from src.utils.data_analysis import data_info

# Build the data module from the dataset section of the configuration; `data`
# is reused by every later cell (plots, equivariance check, training, testing).
data = VesselDataModule(dataset_config)
# Presumably prints/returns a summary of the dataset — TODO confirm in
# src.utils.data_analysis.
data_info(data)

Data distribution

In [None]:
import seaborn as sns
# Bar chart of how many samples carry each label. The labels are passed as `x`
# explicitly: in current seaborn releases the first positional parameter of
# countplot is `data`, so a positional vector is not treated as the variable
# to count.
sns.countplot(x=data.data.label)

In [None]:
from src.utils.definitions import Feature, Category
from src.utils.data_analysis import plot_data

# Extract each feature from the data module. Every call also returns labels and
# rebinds `labels`, so the plots below all use the labels returned together
# with Feature.FACE.
# NOTE(review): this assumes the four calls return identical labels — confirm.
wss, labels = data.extract_feature(Feature.WSS)
pos, labels = data.extract_feature(Feature.POS)
pressure, labels = data.extract_feature(Feature.PRESSURE)
face, labels = data.extract_feature(Feature.FACE)
# One plot per feature, each titled with the feature's name.
plot_data(pos, labels, Category, "Position")
plot_data(wss, labels, Category, "Wall Shear Stress")
plot_data(pressure, labels, Category, "Pressure")
plot_data(face, labels, Category, "Face")

---

Performing equivariance check

In [None]:
from src.lib.geometricAlgebraElements import GeometricAlgebraBase
from src.test.test_equivariance import TestEquivariance
import unittest

# Take the first batch of the training dataloader, reshape its leading data
# entry into rows of width GeometricAlgebraBase.GA_size and keep only the first
# ten rows as input for the equivariance tests.
dl = data.train_dataloader()
batch = next(iter(dl)).data[0].view(-1, GeometricAlgebraBase.GA_size)[:10]
TestEquivariance.INPUT_DATA = batch

# Collect every test defined on TestEquivariance and run them with the quietest
# reporting level; the unittest result object is kept in `restResult`.
suite = unittest.TestSuite(
    unittest.TestLoader().loadTestsFromTestCase(TestEquivariance)
)
test_runner = unittest.TextTestRunner(verbosity=0)
restResult = test_runner.run(suite)

---

Baseline (training and testing)

In [None]:
from src.trainer import VesselTrainer
from src.models import BaselineTransformer
from src.utils.hpo import baseline_hpo

# Build the trainer and the baseline model from their configuration sections.
trainer = VesselTrainer(trainer_config)
model = BaselineTransformer(baseline_config)
# Presumably a hyperparameter search for the baseline (the name suggests HPO);
# its return value is discarded here — TODO confirm how its outcome reaches
# `model` / `config` before training.
baseline_hpo(config, model, data)
# Train the baseline model on the vessel data module.
trainer.fit(model, data)

In [None]:
# Evaluate with whatever `trainer` and `model` are currently bound in the
# kernel. Both names are rebound by the Gatr cell further down, so run this
# right after the baseline training cell to test the baseline model.
trainer.test(model, data)

---

Gatr

In [None]:
from src.trainer import VesselTrainer
from src.models import Gatr
from src.utils.hpo import gatr_hpo

# Rebinds `model` (previously the baseline) to a Gatr model built from its
# configuration section.
model = Gatr(gatr_config)
# Presumably a hyperparameter search for Gatr (the name suggests HPO); its
# return value is discarded here — TODO confirm how its outcome reaches
# `model` / `config` before training.
gatr_hpo(config, model, data)
# A fresh trainer is created for this model, replacing the baseline's trainer.
trainer = VesselTrainer(trainer_config)
trainer.fit(model, data)

In [9]:
# Evaluate with the `trainer` and `model` currently bound in the kernel; after
# the cell above these are the Gatr trainer and model.
trainer.test(model, data)

----