# Global

In [9]:
global_var = {
    # --- Data location and download links ---
    "data_path": "/content/vessel_data",
    # A download URL is built as prefix + <dataset>_data_id
    "prefix": "https://drive.google.com/uc?export=",
    "bifurcating_data_id": "download&id=1bzXusjOMgh-5hnw6RDNc03UUktQixtL4",
    "single_data_id": "download&id=19AHhDU1UWBzpG33wlB1r7b6sH3lMVly5",
    # Label shown by bar_progress; overwritten by get_links_list
    "current_name": "bifurcating",

    # --- Expected dataset layout, used by quality_check ---
    "data_length": 2000,
    "data_keys": ["face", "inlet_idcs", "pos", "pressure", "wss"],

    # --- Geometric algebra ---
    # Number of components of every multivector built by the get_*_mv helpers
    "ga_dimension": 16,
}

# Imports

In [10]:
# Install the `wget` package imported below; %pip installs into the running kernel's environment
%pip install wget --quiet

In [11]:
import gzip
import h5py
import numpy as np
import os
import pickle as pkl
import re
import sys
import torch
import wget
import zipfile

from tqdm.notebook import tqdm

# Dataset

## Download dataset

In [None]:
def bar_progress(current,total,width = 80):
    """
        Progress callback handed to wget.download: rewrites a single
        console line with the download status of the links file

        Args:
            current: number of bytes downloaded so far
            total: total number of bytes to download
            width: accepted for signature compatibility, not used
                   by this implementation
    """
    # A reported total of 0 would otherwise raise ZeroDivisionError
    progress = current / total * 100 if total else 0.0
    progress_string = f"{progress} % [{current} / {total}] bytes"

    # Label set by get_links_list before the download starts
    name = global_var['current_name']
    description = f"[{name}] Recovering links: " + progress_string

    # "\r" moves back to the line start so each update overwrites the last
    sys.stdout.write("\r" + description)
    sys.stdout.flush()

In [None]:
def get_links_list(list_link,dataset_name):
    """
        Download the pickled file holding the links of the dataset
        patches, load it, delete the downloaded file and create the
        dataset output directory

        Args:
            list_link: link from where to download the file
            dataset_name: final dataset name used for log; also the
                          name of the directory that gets created

        Returns:
            links: list of the zipped dataset patches
    """
    # bar_progress reads this entry to label its output
    global_var['current_name'] = dataset_name
    list_name = wget.download(list_link,bar = bar_progress)

    try:
        # SECURITY: unpickling can execute arbitrary code, so list_link
        # must point to a trusted source
        with open(list_name, 'rb') as list_file:
            links = pkl.load(list_file)
    finally:
        # The handle is closed at this point; remove the temporary file
        # even when unpickling fails
        os.remove(list_name)

    os.makedirs(dataset_name, exist_ok = True)

    return links

In [None]:
def unzip_data_patches(links,dataset_name,debug):
    """
        Download every zipped patch, extract it under "/content/" and
        delete the archive once it has been extracted

        Args:
            links: list with the links referred to file to download
            dataset_name: final dataset name used for log
            debug: boolean that allow testing the function; when true
                   only the first four patches are processed

        Returns:
            hdf5_patches: list of HDF5 file names, obtained from the
                          archive names by replacing ".zip" with ".hdf5"
    """
    hdf5_patches = []

    tqdm_desc = f"[{dataset_name}] Downloading and unzipping data patches"
    for patch_idx, link in enumerate(tqdm(links,desc = tqdm_desc)):
        name = wget.download(link)

        with zipfile.ZipFile(name, 'r') as zip_ref:
            zip_ref.extractall("/content/")

        # Delete the archive only after the ZipFile handle has been closed
        os.remove(name)
        hdf5_patches.append(name.replace(".zip",".hdf5"))

        # Debug runs stop after the fourth patch (index 3)
        if debug and patch_idx >= 3:
            break

    return hdf5_patches

In [None]:
def compose_dataset(dataset_file,dataset_name,hdf5_patches):
    """
        Merge several HDF5 patch files into a single HDF5 file and
        delete the patch files afterwards

        Args:
            dataset_file: final dataset path
            dataset_name: final dataset name, used for the progress bar
            hdf5_patches: list of file to merge, looked up under "/content/"
    """
    progress_desc = f"[{dataset_name}] Composing HDF5 full dataset"
    patch_paths = [os.path.join("/content/", patch) for patch in hdf5_patches]

    with h5py.File(dataset_file, 'w') as merged:
        for patch_path in tqdm(patch_paths,desc = progress_desc):
            with h5py.File(patch_path, 'r') as patch:
                # Recreate each top-level group and copy its datasets by value
                for group_name, group in patch.items():
                    merged_group = merged.create_group(group_name)
                    for data_name, data in group.items():
                        merged_group.create_dataset(data_name, data=data[()])

    # The patches are no longer needed once the merged file is written
    for patch_path in patch_paths:
        os.remove(patch_path)

In [None]:
def quality_check(dataset_file,dataset_name):
    """
        Check that the final dataset meets the standards:
        expected number of samples, expected keys inside every sample
        and sample names numbered consecutively starting from 0000

        Args:
            dataset_file: final dataset path
            dataset_name: final dataset name used for log

        Returns:
            (Boolean): logical result of the check, False as soon as
                       one of the checks fails
    """
    print(f"[{dataset_name}]\033[1m Dataset quality check \033[0m")

    with h5py.File(dataset_file, 'r') as dataset:
        # Check correct length
        expected_length = global_var['data_length']
        current_length = len(dataset)
        if current_length == expected_length:
            print(f"[{dataset_name}] All samples present in the dataset \u2714")
        elif current_length < expected_length:
            missing_samples = expected_length - current_length
            print(f"[{dataset_name}] Not all samples are in the dataset \u2718")
            verb = "is" if missing_samples == 1 else "are"
            print(f"[{dataset_name}] \t -> {missing_samples} {verb} missing")
            return False
        else:
            # A surplus is reported as such instead of a negative "missing" count
            extra_samples = current_length - expected_length
            print(f"[{dataset_name}] Too many samples in the dataset \u2718")
            print(f"[{dataset_name}] \t -> {extra_samples} more than expected")
            return False

        # Check correct keys
        expected_keys = global_var['data_keys']
        for sample in dataset.keys():
            if list(dataset[sample].keys()) != expected_keys:
                print(f"[{dataset_name}] Error in samples keys \u2718")
                print(f"[{dataset_name}] \t -> Check {sample}")
                return False
        print(f"[{dataset_name}] All samples keys are correct \u2714")

        # Check correct ordering: the i-th sample name must carry the
        # zero-padded number i after an underscore
        for position, sample in enumerate(dataset.keys()):
            counter = f'{position:04d}'
            match = re.search(r'_\d+', sample)
            # A name without "_<digits>" counts as an ordering failure
            if match is None or match.group()[1:] != counter:
                print(f"[{dataset_name}] Samples are NOT ordered \u2718")
                print(f"[{dataset_name}] \t -> Check {counter}")
                return False
        print(f"[{dataset_name}] Samples are ordered \u2714")

    return True

In [None]:
def download_dataset(list_link, dataset_name, debug):
    """
        Run the whole download pipeline for one dataset: fetch the
        list of patch links, download and unzip the patches, merge
        them into a single HDF5 file and report the quality check

        Args:
            list_link: link from where to download the file
            dataset_name: final dataset name used for log
            debug: boolean that allow testing the function
    """
    # The merged file lives at <dataset_name>/<dataset_name>.hdf5
    dataset_file = "/".join((dataset_name, dataset_name + ".hdf5"))

    patch_links = get_links_list(list_link,dataset_name)
    patches = unzip_data_patches(patch_links,dataset_name,debug)
    compose_dataset(dataset_file,dataset_name,patches)

    if quality_check(dataset_file,dataset_name):
        mark, outcome = "\u2714", "Dataset correctly downloaded"
    else:
        mark, outcome = "\u2718", "Dataset download failed"

    # Bold outcome text framed by the same mark on both sides
    print(f"[{dataset_name}] {mark}\033[1m {outcome} \033[0m{mark}")

In [None]:
# Download, assemble and validate the "bifurcating" dataset
bifurcating_link = f"{global_var['prefix']}{global_var['bifurcating_data_id']}"
download_dataset(list_link=bifurcating_link, dataset_name="bifurcating", debug=False)

[bifurcating] Recovering links: 100.0 % [8708 / 8708] bytes

[bifurcating] Downloading and unzipping data patches:   0%|          | 0/100 [00:00<?, ?it/s]

[bifurcating] Composing HDF5 full dataset:   0%|          | 0/100 [00:00<?, ?it/s]

[bifurcating][1m Dataset quality check [0m
[bifurcating] All samples present in the dataset ✔
[bifurcating] All samples keys are correct ✔
[bifurcating] Samples are ordered ✔
[bifurcating] ✔[1m Dataset correctly downloaded [0m✔


In [None]:
# Download, assemble and validate the "single" dataset
single_link = f"{global_var['prefix']}{global_var['single_data_id']}"
download_dataset(list_link=single_link, dataset_name="single", debug=False)

[single] Recovering links: 100.0 % [8708 / 8708] bytes

[single] Downloading and unzipping data patches:   0%|          | 0/100 [00:00<?, ?it/s]

[single] Composing HDF5 full dataset:   0%|          | 0/100 [00:00<?, ?it/s]

[single][1m Dataset quality check [0m
[single] All samples present in the dataset ✔
[single] All samples keys are correct ✔
[single] Samples are ordered ✔
[single] ✔[1m Dataset correctly downloaded [0m✔


## Embedding in geometric algebra

In [12]:
def get_pos_mv(pos):
    """
        Build one multivector per row of 'pos'

        Args:
            pos: tensor of positions; components 0, 1, 2 of the last
                 axis are read as x, y, z

        Returns:
            mv: zero tensor of shape (pos.shape[0], ga_dimension) with
                x, y, z written to components 11, 12, 13 and the
                homogeneous coordinate 1 written to component 14
    """
    n_rows = pos.shape[0]
    mv = torch.zeros(n_rows, global_var['ga_dimension'])

    # x, y, z go to consecutive components starting at 11
    for axis, component in enumerate((11, 12, 13)):
        mv[..., component] = pos[..., axis]

    mv[..., 14] = 1 # homogeneous coordinates

    return mv

In [13]:
def get_face_mv(face):
    """
        Build one multivector per row of 'face'

        Args:
            face: tensor whose last axis holds (at least) three values

        Returns:
            mv: zero tensor of shape (face.shape[0], ga_dimension) with
                the three values written to components 2, 3, 4
    """
    n_rows = face.shape[0]
    mv = torch.zeros(n_rows, global_var['ga_dimension'])

    # The three face values go to consecutive components starting at 2
    for axis, component in enumerate((2, 3, 4)):
        mv[..., component] = face[..., axis]

    return mv

In [14]:
def get_wss_mv(wss):
    """
        Build one multivector per row of 'wss'

        Args:
            wss: tensor whose last axis holds (at least) three values

        Returns:
            mv: zero tensor of shape (wss.shape[0], ga_dimension) with
                1 written to component 0 and the three values written
                to components 5, 6, 7
    """
    n_rows = wss.shape[0]
    mv = torch.zeros(n_rows, global_var['ga_dimension'])

    mv[..., 0] = 1 # homogeneous coordinates

    # The three wss values go to consecutive components starting at 5
    for axis, component in enumerate((5, 6, 7)):
        mv[..., component] = wss[..., axis]

    return mv

In [15]:
def get_inlet_mv(inlet):
    """
        Build one multivector per entry of 'inlet'

        Args:
            inlet: tensor of values, one per output row

        Returns:
            mv: zero tensor of shape (inlet.shape[0], ga_dimension)
                with the values written to component 0
    """
    n_rows = inlet.shape[0]
    mv = torch.zeros(n_rows, global_var['ga_dimension'])

    # Only component 0 is filled, every other component stays zero
    mv[:, 0] = inlet

    return mv

In [8]:
def get_pressure_mv(pressure):
    """
        Build one multivector per entry of 'pressure'

        Args:
            pressure: tensor of values, one per output row

        Returns:
            mv: zero tensor of shape (pressure.shape[0], ga_dimension)
                with the values written to component 0
    """
    n_rows = pressure.shape[0]
    mv = torch.zeros(n_rows, global_var['ga_dimension'])

    # Only component 0 is filled, every other component stays zero
    mv[:, 0] = pressure

    return mv

In [17]:
def embed_data(data_path,dataset_name):
    """
        Embed every sample of an HDF5 dataset in geometric algebra and
        save the result as one pickle file per sample, written to
        <dataset_name>/<sample name>.pkl

        Args:
            data_path: path of the HDF5 dataset to read
            dataset_name: dataset name used for log and as the
                          output directory of the pickle files

        Returns:
            None: the embedded samples are only written to disk
    """
    tqdm_desc = f"[{dataset_name}] Embedding data in geometric algebra"

    # The pickle files are written here; make sure the directory exists
    # even when the download step has not been run in this session
    os.makedirs(dataset_name, exist_ok = True)

    with h5py.File(data_path, 'r') as dataset:
        for sample_idx in tqdm(list(dataset.keys()),desc = tqdm_desc):
            sample = dataset[sample_idx]
            pkl_path = dataset_name + "/" + sample_idx + ".pkl"

            # 'pos' property modeled as a point
            input_pos = torch.tensor(np.array(sample['pos']))
            pos_mv = get_pos_mv(input_pos)

            # 'face' property modeled as a plane
            input_face = torch.tensor(np.array(sample['face']))
            face_mv = get_face_mv(input_face)

            # 'wss' property modeled as a translation
            input_wss = torch.tensor(np.array(sample['wss']))
            wss_mv = get_wss_mv(input_wss)

            # 'inlet' modeled as a scalar
            input_inlet = torch.tensor(np.array(sample['inlet_idcs']))
            inlet_mv = get_inlet_mv(input_inlet)

            # 'pressure' modeled as a scalar
            input_pressure = torch.tensor(np.array(sample['pressure']))
            pressure_mv = get_pressure_mv(input_pressure)

            # Stack all the multivectors of the sample along the first
            # axis, in the order pos, face, wss, inlet, pressure
            total_mv = torch.concatenate(
                [pos_mv, face_mv, wss_mv, inlet_mv, pressure_mv]
            )
            with open(pkl_path, 'wb') as file:
                pkl.dump(total_mv, file)

In [18]:
# Embed the "bifurcating" dataset; the results are the .pkl files written
# under "bifurcating/".
# NOTE: embed_data has no return statement, so this name is bound to None.
bifurcating_embedded = embed_data(
    "/content/bifurcating/bifurcating.hdf5",
    "bifurcating",
)

[bifurcating] Embedding data in geometric algebra:   0%|          | 0/2000 [00:00<?, ?it/s]

In [19]:
# Embed the "single" dataset; the results are the .pkl files written
# under "single/".
# NOTE: embed_data has no return statement, so this name is bound to None.
single_embedded = embed_data(
    "/content/single/single.hdf5",
    "single",
)

[single] Embedding data in geometric algebra:   0%|          | 0/2000 [00:00<?, ?it/s]