<a href="https://colab.research.google.com/github/airoprojects/vessel-classification/blob/main/code/VesselFormer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --upgrade gdown

Collecting gdown
  Downloading gdown-5.1.0-py3-none-any.whl (17 kB)
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 4.7.3
    Uninstalling gdown-4.7.3:
      Successfully uninstalled gdown-4.7.3
Successfully installed gdown-5.1.0


In [2]:
import gdown

In [3]:
# Source folder: https://drive.google.com/drive/u/0/folders/1o34kRpwSsGAXhDw4K8eOXHce6JTIOQ18
import os

output_file = "single.zip"
# Skip the ~620 MB download when the archive is already present (e.g. on Restart & Run All).
if not os.path.exists(output_file):
    gdown.download('https://drive.google.com/uc?id=1nzSdInny5VCi7PPT5EjsB8i1_QLbeOMu', output_file)
output_file

Downloading...
From (original): https://drive.google.com/uc?id=1nzSdInny5VCi7PPT5EjsB8i1_QLbeOMu
From (redirected): https://drive.usercontent.google.com/download?id=1nzSdInny5VCi7PPT5EjsB8i1_QLbeOMu&confirm=t&uuid=9e3d5be4-0c54-4245-a28f-9855cb88a08c
To: /content/single.zip
100%|██████████| 623M/623M [00:12<00:00, 49.6MB/s]


'single.zip'

In [4]:
# Source file: https://drive.google.com/file/d/1RYCJyMXccTWCDMN8c_xrZMXu47gaSpvg/view?usp=sharing
import os

output_file = "bifurcating.zip"
# Skip the ~1.2 GB download when the archive is already present (e.g. on Restart & Run All).
if not os.path.exists(output_file):
    gdown.download('https://drive.google.com/uc?id=1RYCJyMXccTWCDMN8c_xrZMXu47gaSpvg', output_file)
output_file

Downloading...
From (original): https://drive.google.com/uc?id=1RYCJyMXccTWCDMN8c_xrZMXu47gaSpvg
From (redirected): https://drive.usercontent.google.com/download?id=1RYCJyMXccTWCDMN8c_xrZMXu47gaSpvg&confirm=t&uuid=231ed069-084c-463c-bfa6-237139f6a31a
To: /content/bifurcating.zip
100%|██████████| 1.18G/1.18G [00:22<00:00, 53.5MB/s]


'bifurcating.zip'

In [5]:
!unzip single.zip
!unzip bifurcating.zip

Archive:  single.zip
   creating: single/
   creating: single/raw/
  inflating: single/raw/README       
  inflating: single/raw/database.hdf5  
  inflating: single/raw/CC-BY.svg    
  inflating: single/raw/vtk_demo.py  
  inflating: single/raw/md5_sum      
  inflating: single/raw/licence_CC-BY  
Archive:  bifurcating.zip
   creating: bifurcating/
   creating: bifurcating/raw/
  inflating: bifurcating/raw/database.hdf5  
  inflating: bifurcating/raw/md5_sum  
  inflating: bifurcating/raw/CC-BY.svg  
  inflating: bifurcating/raw/README  
  inflating: bifurcating/raw/vtk_demo.py  
  inflating: bifurcating/raw/licence_CC-BY  


In [6]:
!pip install -q torch_geometric vtk

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.0/92.0 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [7]:
import random

import h5py
import torch
from torch_geometric.data import Data

import vtk
from vtk.util.numpy_support import numpy_to_vtk, numpy_to_vtkIdTypeArray
import numpy as np


def main(path_to_hdf5="database.hdf5", path_to_vtu="random_sample.vtu", sample_id=None):
    """Export one sample of an HDF5 vessel database to a VTU file.

    Args:
        path_to_hdf5: database to read. A relative path resolves against the current
            working directory, i.e. the dataset selected with ``%cd``.
        path_to_vtu: path of the ``.vtu`` file to write.
        sample_id: integer id of the sample to export. If ``None``, a random id is
            drawn from the ``sample_<digits>`` groups actually present in the file.

    Returns:
        The id of the exported sample, so a random pick can be reproduced.

    Raises:
        ValueError: if ``sample_id`` is ``None`` and the file contains no
            ``sample_<digits>`` groups.
    """
    if sample_id is None:
        # Draw from the groups that exist rather than a hard-coded range: the former
        # randint(0, 1998) could never pick sample_1999 (if present) and would request a
        # missing group in a smaller database.
        prefix = "sample_"
        with h5py.File(path_to_hdf5, 'r') as f:
            available_ids = sorted(
                int(name[len(prefix):])
                for name in f.keys()
                if name.startswith(prefix) and name[len(prefix):].isdigit()
            )
        if not available_ids:
            raise ValueError(f"no 'sample_*' groups found in {path_to_hdf5!r}")
        sample_id = random.choice(available_ids)

    vtu_writer = VTUWriter()
    vtu_writer(path_to_vtu, hdf5_to_pyg(path_to_hdf5, sample_id=sample_id))
    return sample_id


def hdf5_to_pyg(path_to_hdf5, sample_id):
    """Load one sample of the HDF5 database as a PyG ``Data`` object.

    Args:
        path_to_hdf5: path to the ``database.hdf5`` file.
        sample_id: integer id; the group read is ``sample_<id zero-padded to 4 digits>``.

    Returns:
        ``Data`` with ``wss``, ``pos``, ``face`` (transposed, int64) and
        ``inlet_index`` tensors. The inlet tensor comes from the ``inlet_idcs`` dataset.
    """
    group_name = "sample_{:04d}".format(sample_id)
    with h5py.File(path_to_hdf5, 'r') as hdf5_file:
        sample = hdf5_file[group_name]
        wss = torch.from_numpy(sample['wss'][()])
        pos = torch.from_numpy(sample['pos'][()])
        # Transposed so each face is a column, the layout PyG's ``face`` attribute uses.
        face = torch.from_numpy(sample['face'][()].T).long()
        # Stored as "inlet_index": the "_index" suffix makes PyG offset it correctly when batching.
        inlet_index = torch.from_numpy(sample['inlet_idcs'][()])

    # All datasets were read fully into memory via [()], so closing the file first is safe.
    return Data(wss=wss, pos=pos, face=face, inlet_index=inlet_index)


# Writes polygonal mesh data stored in a PyG Data object to a VTU file
class VTUWriter():
    """Write a PyG ``Data`` surface mesh to a VTK XML unstructured-grid (``.vtu``) file.

    The writer reads these attributes from ``data``:

    * ``pos``: vertex coordinates, one row per point.
    * ``face``: vertex ids of each cell, one column per cell, i.e. shape
      ``(vertices_per_cell, num_cells)``. Cells are written as ``VTK_POLYGON``.
    * Any other tensor whose first dimension equals the number of points is written
      as a point-data array under its attribute name.
    * Any ``*_index`` attribute except ``edge_index`` is written as an int32 0/1 point
      mask named without the ``_index`` suffix.

    Tensors are detached and copied to the CPU first, so GPU tensors are accepted.
    """

    def __init__(self):
        self.vtu_writer = vtk.vtkXMLUnstructuredGridWriter()

    def __call__(self, path_to_file, data):
        """Convert ``data`` to a ``vtkUnstructuredGrid`` and write it to ``path_to_file``.

        Raises:
            IOError: if VTK reports a failure while writing the file.
        """
        self.vtu_writer.SetFileName(path_to_file)
        self.vtu_writer.SetInputData(self.pyg_to_vtk(data))

        # Write() is the writer's entry point. Unlike a bare Update(), it forces the file
        # to be (re)written. VTK only logs I/O errors, so check and raise instead of
        # failing silently.
        written = self.vtu_writer.Write()
        if not written or self.vtu_writer.GetErrorCode() != 0:  # 0 == vtkErrorCode::NoError
            raise IOError(f"VTK could not write {path_to_file!r}")

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def pyg_to_vtk(self, data):
        """Build a ``vtkUnstructuredGrid`` holding the points, polygon cells and point data of ``data``."""
        vtk_unstructured_grid = vtk.vtkUnstructuredGrid()

        vtk_points = vtk.vtkPoints()
        vtk_points.SetData(numpy_to_vtk(self._to_numpy(data.pos)))

        vtk_cell_array = vtk.vtkCellArray()
        # One column of ``face`` per cell, hence shape[-1] cells.
        vtk_cell_array.SetCells(data.face.shape[-1], numpy_to_vtkIdTypeArray(self.serialise_simplices(data.face)))

        vtk_unstructured_grid.SetPoints(vtk_points)
        vtk_unstructured_grid.SetCells(vtk.VTK_POLYGON, vtk_cell_array)

        return self.add_point_data(vtk_unstructured_grid, data)

    @staticmethod
    def serialise_simplices(simplices):
        """Flatten cell connectivity into VTK's legacy cell-array layout.

        ``simplices`` holds one cell per column, shape ``(k, num_cells)``. The result is
        the int64 array ``[k, v0, ..., v(k-1), k, v0, ...]``: each cell's vertex count
        followed by its vertex ids.
        """
        cells = VTUWriter._to_numpy(simplices).T  # (k, num_cells) -> (num_cells, k)
        num_cells, verts_per_cell = cells.shape
        counts = np.full((num_cells, 1), verts_per_cell, dtype=np.int64)
        return np.concatenate((counts, cells.astype(np.int64, copy=False)), axis=-1).ravel()

    def add_point_data(self, vtk_unstructured_grid, data):
        """Attach every per-point tensor and every ``*_index`` mask of ``data`` as named point data."""
        point_arrays = {**self.parse_point_data(data), **self.parse_point_indices(data)}
        for key, value in point_arrays.items():
            array = numpy_to_vtk(self._to_numpy(value))
            array.SetName(key)

            vtk_unstructured_grid.GetPointData().AddArray(array)

        return vtk_unstructured_grid

    @staticmethod
    def parse_point_data(data):
        """Return ``{name: tensor}`` for attributes with one entry per point (``pos`` excluded).

        Non-tensor and 0-dim attributes are skipped instead of raising.
        """
        num_points = data.pos.size(0)
        return {
            key: value
            for key, value in data
            if key != 'pos'
            and torch.is_tensor(value)
            and value.dim() > 0
            and value.size(0) == num_points
        }

    @staticmethod
    def parse_point_indices(data):
        """Turn each ``*_index`` attribute (except ``edge_index``) into an int32 0/1 point mask.

        Masks are keyed by the attribute name with ``"_index"`` removed, e.g.
        ``inlet_index`` becomes ``inlet``.
        """
        num_points = data.pos.size(0)
        point_mask_dict = {}

        for key, value in data:
            if "_index" in key and key != 'edge_index':
                point_mask = np.zeros(num_points, dtype='i4')
                # Convert explicitly rather than indexing NumPy with a torch tensor.
                point_mask[VTUWriter._to_numpy(value)] = 1

                point_mask_dict[key.replace("_index", "")] = point_mask

        return point_mask_dict


In [17]:
# Return to the Colab working directory that holds the downloaded archives and the extracted datasets.
%cd /content

/content


In [18]:
%cd bifurcating/raw/

/content/bifurcating/raw


In [15]:
%cd single/raw/

/content/single/raw


In [19]:
# Export a randomly chosen sample of ./database.hdf5 (the dataset selected by the %cd
# cells above) to ./random_sample.vtu.
main()