In [64]:
import os 
from omegaconf import OmegaConf
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule

torch.random.manual_seed(0)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [96]:
# Load decoder config
decoder_cfg = OmegaConf.load("../configs/config.yaml")

# Load encoder's training config
encoder_config_path = os.path.join(
    decoder_cfg.model_to_decode_path,
    ".hydra",
    "config.yaml"
)
encoder_cfg = OmegaConf.load(encoder_config_path)

# Override dataset_folder which conains a mistake
encoder_cfg["dataset_folder"] = decoder_cfg["dataset_folder"]

# Only resolve the dataset part to avoid errors in unrelated keys
region = list(encoder_cfg.dataset.keys())[0]
print("Region:", region)

dataset_info = OmegaConf.to_container(encoder_cfg.dataset[region], resolve=True)

# Access relevant dataset info

print("Numpy path:", dataset_info['numpy_all'])
print("Subjects path:", dataset_info['subjects_all'])
print("subject column name:", dataset_info['subject_column_name'])
print("Input size:", dataset_info['input_size'])

Region: CINGULATE_left
Numpy path: /neurospin/dico/data/deep_folding/current/datasets/UkBioBank40/crops/2mm/CINGULATE./mask/Lskeleton.npy
Subjects path: /neurospin/dico/data/deep_folding/current/datasets/UkBioBank40/crops/2mm/CINGULATE./mask/Lskeleton_subject.csv
subject column name: Subject
Input size: (1, 18, 41, 38)


In [94]:
skels = np.load(dataset_info['numpy_all'])
list_sub = pd.read_csv(dataset_info['subjects_all'])
print(region, skels.shape, len(list_sub))

CINGULATE_left (42433, 18, 41, 38, 1) 42433


In [46]:
rootpath = os.path.join(decoder_cfg["model_to_decode_path"])
train_path = os.path.join(rootpath, config["train_csv"])
val_test_path = os.path.join(rootpath, config["val_test_csv"])

In [47]:
train_data = pd.read_csv(train_path)
a = pd.read_csv(val_test_path)
a['IID'] = a['ID'].apply(lambda x : int(x[4:]))
val_data = a[a['IID']%2 ==0].drop('IID', axis=1)
test_data = a[a['IID']%2 ==1].drop('IID', axis=1)

#train_data = train_data.drop('ID', axis=1)
#val_data = val_data.drop('ID', axis=1)
#test_data = test_data.drop('ID', axis=1)

print(f"Train set size: {len(train_data)}") 
print(f"Validation set size: {len(val_data)}") 
print(f"Test set size: {len(test_data)}")

Train set size: 38190
Validation set size: 2092
Test set size: 2151


In [48]:
train_data.head(3)

Unnamed: 0,ID,dim1,dim2,dim3,dim4,dim5,dim6,dim7,dim8,dim9,...,dim23,dim24,dim25,dim26,dim27,dim28,dim29,dim30,dim31,dim32
0,sub-1000021,11.229898,29.822453,-4.026312,3.358199,0.231886,1.422046,-18.252083,12.020048,39.133064,...,-12.145392,-5.654875,24.603487,25.662783,-27.461916,-21.059357,-16.466547,17.850504,-43.04543,-16.672508
1,sub-1000325,-31.725147,-0.39503,-0.939265,1.984952,47.066414,-37.844696,11.580789,8.266258,-1.891401,...,-14.643673,-34.058014,20.73292,-26.947844,6.144578,13.84668,15.867406,-57.067223,-28.45412,-0.502026
2,sub-1000575,-13.318561,42.696262,-25.401484,0.247414,-8.130686,0.259071,15.793086,2.559356,-76.482475,...,3.452085,12.126092,-33.715843,-16.900974,-16.233204,-17.635584,-20.436428,-15.893841,-1.953091,-7.96151


In [100]:
class LatentTargetDataset(Dataset):
    def __init__(self, latent_csv_path, target_npy_path, subjects_all_path, subject_list):
        self.subject_list = subject_list

        # Load latent vectors
        latent_df = pd.read_csv(latent_csv_path)
        latent_df = latent_df[latent_df["ID"].isin(subject_list)]

        # Enforce the order in subject_list
        latent_df = latent_df.set_index("ID").loc[subject_list].reset_index()
        self.latents = latent_df.drop(columns=["ID"])  # keep only the vector columns

        # Load subject order from the .csv (matching the .npy)
        subjects_all_df = pd.read_csv(subjects_all_path)

        # Create mapping from subject ID to index in the .npy array
        subject_to_index = {subj: idx for idx, subj in enumerate(subjects_all_df["Subject"])}

        # For each subject in subject_list, get its corresponding index in .npy file
        try:
            self.indices = [subject_to_index[subj] for subj in subject_list]
        except KeyError as e:
            raise ValueError(f"Subject {e.args[0]} not found in subjects_all list")

        # Load target volumes
        self.targets = np.load(target_npy_path)

        # Safety check
        assert len(self.latents) == len(self.indices), "Mismatch between latent vectors and volume indices"

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        latent_vector = torch.tensor(self.latents.iloc[idx].values, dtype=torch.float32)
        target_volume = torch.tensor(self.targets[self.indices[idx]], dtype=torch.float32).permute(3, 0, 1, 2)
        return latent_vector, target_volume


In [101]:
class DataModule_Learning(LightningDataModule):
    def __init__(self, config, dataset_info):
        super().__init__()
        self.config = config
        self.dataset_info = dataset_info

    def setup(self, stage=None):

        # Full paths to CSV and NPY files
        train_path = os.path.join(self.config["model_to_decode_path"], self.config["train_csv"])
        val_test_path = os.path.join(self.config["model_to_decode_path"], self.config["val_test_csv"])
        print(self.config.keys())
        subjects_all = os.path.join(self.config["dataset_folder"], self.dataset_info['subjects_all'])
        target_npy_path = os.path.join(self.config["dataset_folder"], self.dataset_info['numpy_all'])

        # Read splits
        train_data = pd.read_csv(train_path)
        val_test_data = pd.read_csv(val_test_path)
        val_test_data['IID'] = val_test_data['ID'].apply(lambda x: int(x[4:]))

        # Split into validation and test
        train_subjects = train_data['ID'].tolist()
        val_subjects = val_test_data[val_test_data['IID'] % 2 == 0]['ID'].tolist()
        test_subjects = val_test_data[val_test_data['IID'] % 2 == 1]['ID'].tolist()

        # Instantiate datasets
        self.dataset_train = LatentTargetDataset(
            latent_csv_path=train_path,
            target_npy_path=target_npy_path,
            subjects_all_path=subjects_all,
            subject_list=train_subjects
        )
        self.dataset_val = LatentTargetDataset(
            latent_csv_path=val_test_path,
            target_npy_path=target_npy_path,
            subjects_all_path=subjects_all,
            subject_list=val_subjects
        )
        self.dataset_test = LatentTargetDataset(
            latent_csv_path=val_test_path,
            target_npy_path=target_npy_path,
            subjects_all_path=subjects_all,
            subject_list=test_subjects
        )

    def train_dataloader(self):
        return DataLoader(self.dataset_train,
                          batch_size=self.config["batch_size"],
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.dataset_val,
                          batch_size=self.config["batch_size"],
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.dataset_test,
                          batch_size=self.config["batch_size"],
                          shuffle=False)

In [104]:
datamodule = DataModule_Learning(decoder_cfg, dataset_info)
datamodule.setup()
train_loader = datamodule.train_dataloader()

for latent_vector, target_volume in train_loader:
    print(latent_vector.shape, target_volume.shape)
    break

dict_keys(['model_to_decode_path', 'dataset_folder', 'train_csv', 'val_test_csv', 'output_dir', 'batch_size', 'learning_rate', 'num_epochs', 'latent_dim', 'dropout', 'decoder_type', 'encoder_depth', 'block_depth', 'filters', 'last_kernel_size', 'activation'])
torch.Size([32, 32]) torch.Size([32, 1, 18, 41, 38])


In [110]:
from dataloader import DataModule_Learning

# Load configs
decoder_cfg = OmegaConf.load("../configs/config.yaml")
encoder_config_path = os.path.join(decoder_cfg.model_to_decode_path, ".hydra", "config.yaml")
encoder_cfg = OmegaConf.load(encoder_config_path)
encoder_cfg["dataset_folder"] = decoder_cfg["dataset_folder"]
region = list(encoder_cfg.dataset.keys())[0]
dataset_info = OmegaConf.to_container(encoder_cfg.dataset[region], resolve=True)

# Instantiate and setup the datamodule
dm = DataModule_Learning(decoder_cfg, dataset_info)
dm.setup()

# Get the data loaders
train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

# Example loop
for latent_vector, target_volume in train_loader:
    print(latent_vector.shape, target_volume.shape)
    break  # Just show one batch

torch.Size([32, 32]) torch.Size([32, 1, 16, 37, 37])
