In [None]:
import pathlib
import numpy as np
from thunderpack import ThunderDB

In [None]:
# Dataset version tag; thunderify() stores it under the "version" key of each
# database's "_attrs" entry.
VERSION = "v0.1"

from typing import List, Tuple
from sklearn.model_selection import train_test_split
from pydantic import validate_arguments


def _hold_out(pool, ratio, seed):
    """Split ``pool`` into ``(kept, held_out)``, holding out ``ratio`` of it."""
    if ratio == 0:
        # train_test_split rejects test_size=0.0, so build the empty
        # partition directly instead of raising.
        return list(pool), []
    kept, held_out = train_test_split(pool, test_size=ratio, random_state=seed)
    return kept, held_out


@validate_arguments
def data_splits(
    values: List[str],
    splits: Tuple[float, float, float, float],
    seed: int
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Partition ``values`` into train / calibration / validation / test lists.

    The values are sorted first so that the result depends only on the set of
    values and ``seed``, not on the order in which they were passed in.

    Args:
        values: Unique identifiers to partition.
        splits: ``(train, cal, val, test)`` sizes. ``test`` is used as-is as
            the fraction of all values to hold out; the other three are only
            used relative to one another. A size of ``0`` yields an empty
            list for that partition.
        seed: Random state passed to every ``train_test_split`` call.

    Returns:
        ``(train, cal, val, test)`` lists that together contain every value
        exactly once.

    Raises:
        ValueError: If ``values`` contains duplicates. Errors raised by
            ``train_test_split`` (e.g. a split that would leave a side empty)
            propagate unchanged.
    """
    if len(set(values)) != len(values):
        raise ValueError("Duplicate entries found in values")

    # NOTE: there is deliberately no exact ``sum(splits) == 1.0`` check; such a
    # comparison is fragile under floating-point rounding.

    train_size, cal_size, val_size, test_size = splits
    values = sorted(values)

    # First carve off the test split.
    traincalval, test = _hold_out(values, test_size, seed)

    # Next the val split, as a fraction of what remains.
    val_ratio = val_size / (train_size + cal_size + val_size) if val_size else 0.0
    traincal, val = _hold_out(traincalval, val_ratio, seed)

    # Finally the cal split, as a fraction of train + cal.
    cal_ratio = cal_size / (train_size + cal_size) if cal_size else 0.0
    train, cal = _hold_out(traincal, cal_ratio, seed)

    # Internal sanity check: the four partitions must cover the input exactly.
    assert sorted(train + cal + val + test) == values, "Missing Values"

    return (train, cal, val, test)


def thunderify(proc_root, dst):
    """Pack each datacenter directory under ``proc_root`` into its own ThunderDB.

    Layout read by this function::

        proc_root/<datacenter>/<subject>/image.npy
        proc_root/<datacenter>/<subject>/observer*/label.npy

    For every datacenter directory a database is opened at
    ``dst/<datacenter>`` and filled with:

    * one entry per subject, keyed by the subject directory name, holding
      ``{"image": <array>, "masks": {<observer dir name>: <array>}}``;
    * ``_subjects`` / ``_samples``: the subject names, sorted;
    * ``_num_annotators``: number of ``observer*`` directories per subject,
      aligned with ``_subjects``;
    * ``_splits``: dict with ``train`` / ``cal`` / ``val`` / ``test`` lists;
    * ``_splits_kwarg``: the ratio and seed used to build the splits;
    * ``_attrs``: hard-coded dataset metadata (not checked against the arrays).

    Entries that are not directories are skipped at both the datacenter and
    the subject level.

    Args:
        proc_root: ``pathlib.Path`` to the processed dataset root.
        dst: ``pathlib.Path`` under which one database per datacenter is written.

    Raises:
        ValueError: If a datacenter directory contains no subject directories.
    """
    # Train Calibration Val Test
    splits_ratio = (0.6, 0.1, 0.2, 0.1)
    splits_seed = 42

    # Make sure the output root exists before any database is opened in it
    # (a no-op when it is already there).
    dst.mkdir(parents=True, exist_ok=True)

    datacenters = sorted(
        (p for p in proc_root.iterdir() if p.is_dir()),
        key=lambda p: p.name,
    )
    for dc in datacenters:
        # Sorting up front keeps the stored subject order deterministic and
        # lets an empty datacenter fail before a database is created for it.
        subj_dirs = sorted(
            (p for p in dc.iterdir() if p.is_dir()),
            key=lambda p: p.name,
        )
        if not subj_dirs:
            raise ValueError(f"No subject directories found in {dc}")

        with ThunderDB.open(str(dst / dc.name), "c") as db:
            subjects = []
            num_annotators = []
            for subj in subj_dirs:
                key = subj.name
                print(subj)
                img = np.load(subj / "image.npy")
                mask_dirs = sorted(subj.glob("observer*"))
                masks = {
                    mask_dir.name: np.load(mask_dir / "label.npy")
                    for mask_dir in mask_dirs
                }
                db[key] = {
                    "image": img,
                    "masks": masks,
                }
                subjects.append(key)
                print(len(mask_dirs))
                num_annotators.append(len(mask_dirs))

            splits = dict(
                zip(
                    ("train", "cal", "val", "test"),
                    data_splits(subjects, splits_ratio, splits_seed),
                )
            )
            attrs = dict(
                dataset="WMH",
                version=VERSION,
                group=dc.name,
                modality="FLAIR",
                resolution=256,
            )

            # Stored as tuples, matching what earlier versions wrote.
            db["_subjects"] = tuple(subjects)
            db["_samples"] = tuple(subjects)
            db["_num_annotators"] = tuple(num_annotators)
            db["_splits"] = splits
            db["_splits_kwarg"] = {
                "ratio": splits_ratio,
                "seed": splits_seed,
            }
            db["_attrs"] = attrs

In [None]:
# Input: processed WMH data; output: directory that receives the databases.
# NOTE(review): absolute, machine-specific paths — adjust before running elsewhere.
root = pathlib.Path("/storage/vbutoi/datasets/WMH/processed")
dst = pathlib.Path("/storage/vbutoi/datasets/WMH/thunder_wmh")
# Build one database under `dst` for each entry found under `root`.
thunderify(root, dst)