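You can run this notebook directly on Google Colab (uncomment the setup commands in the first code cell so that the repository is cloned and its requirements are installed):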

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DaniAffCH/Vessel-Geometric-Transformers/blob/main/main.ipynb)

In [1]:
import sys

# True when the `google.colab` module is already loaded in this interpreter,
# which this notebook uses as the signal that it is running on Google Colab.
COLAB_RUNTIME = "google.colab" in sys.modules

# Environment setup is currently disabled; uncomment the whole if/else below
# to enable it.
#   - On Colab: clone the repository, move its contents into the working
#     directory and install the requirements.
#   - Elsewhere: install the requirements and set up the pre-commit hooks.
#
# if COLAB_RUNTIME:
#     !git clone https://github.com/DaniAffCH/Vessel-Geometric-Transformers.git
#     !mv Vessel-Geometric-Transformers/* .
#     !pip install -q -r requirements.txt
# else:
#     !pip install -q -r requirements.txt
#     !pre-commit autoupdate
#     !pre-commit install

Loading the configuration

In [2]:
from src.utils import load_config
import os
from config import DatasetConfig, TrainerConfig, BaselineConfig

# Path of the YAML configuration file, relative to the current working
# directory.
config_path = os.path.join("config", "config.yaml")
config = load_config(config_path)

# Pull out the per-component sections; the annotations document which config
# type each section is expected to be.
baseline_config: BaselineConfig = config.baseline
trainer_config: TrainerConfig = config.trainer
dataset_config: DatasetConfig = config.dataset

Loading the dataset

In [3]:
from src.data import VesselDataModule

# Build the data module from the dataset section of the configuration.
data = VesselDataModule(dataset_config)

# Report the size of each split, then show the first training sample so the
# available fields are visible. One item is printed per line.
print(
    f'Train size: {len(data.train_set)}',
    f'Validation size: {len(data.val_set)}',
    f'Test size: {len(data.test_set)}',
    data.train_set[0],
    sep="\n",
)

Train size: 2999
Validation size: 599
Test size: 401
Data(pos=[19217, 3], wss=[19217, 3], pressure=[19217], face=[3, 38430], inlet_index=[921])


Testing the geometric algebra embeddings

In [4]:
from src.lib import PointGeometricAlgebra, TranslationGeometricAlgebra, ScalarGeometricAlgebra, PlaneGeometricAlgebra

# Embed a single vertex position. `unsqueeze(0)` adds a leading dimension so
# the one point is passed as a batch of size one. The raw coordinates are
# printed first for comparison.
print(data.train_set[0].pos[0])
print(
    PointGeometricAlgebra.fromElement(data.train_set[0].pos[0].unsqueeze(0)),
    end="\n\n",
)

# Embed the `wss` field (presumably wall shear stress) and display only the
# first eight columns of the result.
# NOTE(review): this reads sample 1 while every other line in this cell reads
# sample 0 — confirm this is intentional.
print(
    TranslationGeometricAlgebra.fromElement(data.train_set[1].wss)[:, :8],
    end="\n\n",
)

# Embed the per-vertex pressure values.
print(
    ScalarGeometricAlgebra.fromElement(data.train_set[0].pressure),
    end="\n\n",
)

# Embed the transposed `face` tensor, i.e. one row per face.
print(PlaneGeometricAlgebra.fromElement(data.train_set[0].face.T))

tensor([-0.1435, -0.1550, -0.0041])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000, -0.1435, -0.1550, -0.0041,  1.0000,  0.0000]])

tensor([[ 1.0000,  0.0000,  0.0000,  ..., -1.3098, -3.7878, -0.7568],
        [ 1.0000,  0.0000,  0.0000,  ..., -1.3070, -3.7272, -0.7964],
        [ 1.0000,  0.0000,  0.0000,  ..., -1.2395, -3.6962, -0.8062],
        ...,
        [ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

tensor([[133372.5312,      0.0000,      0.0000,  ...,      0.0000,
              0.0000,      0.0000],
        [133347.1875,      0.0000,      0.0000,  ...,      0.0000,
              0.0000,      0.0000],
        [133348.7500,      0.0000,      0.0000,  ...,      0.0000,
              0.0000,      0.0000],
        ...,
        [133410.9688,      0.0000,      0.

Training Loop

In [5]:
from src.trainer import VesselTrainer
from src.models import BaselineTransformer

# Build the baseline model and the trainer from their configuration sections,
# then fit the model on the data module created earlier in the notebook.
model = BaselineTransformer(baseline_config)
trainer = VesselTrainer(trainer_config)
# NOTE(review): the output recorded for this cell ends in an IsADirectoryError
# on the `ckpt` directory, right after "Restoring states from the checkpoint
# path at ckpt". Presumably the trainer resumes from a directory instead of a
# checkpoint file — confirm in VesselTrainer / the trainer configuration.
trainer.fit(model, data)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/neverorfrog/.miniconda3/envs/gatr/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/home/neverorfrog/.miniconda3/envs/gatr/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /home/neverorfrog/code/deep-learning/gatr/ckpt exists and is not empty.
Restoring states from the checkpoint path at ckpt


IsADirectoryError: [Errno 21] Is a directory: '/home/neverorfrog/code/deep-learning/gatr/ckpt'