## Two-contrast model (T1w vs. T2w)

This notebook loads and preprocesses the data, then trains a first model to predict whether a 2D image is T1w or T2w.
The notebook format makes it fast to run and test ideas before coding the final structure.

In [1]:
import json
import os

import monai
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    RandFlipd,
    RandRotate,
    RandRotate90d,
    RandRotated,
    RandScaleCrop,
    RandScaleCropd,
    RandShiftIntensityd,
    Resized,
    ScaleIntensityd,
    ToTensord,
)
from sklearn.model_selection import train_test_split

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Collect the T1w and T2w image paths used by this first model.
# Images are located through their JSON sidecar files; the matching image is assumed
# to sit next to the sidecar with the same stem and a `.nii.gz` extension.

base_dir = "data//data-multi-subject//"

desired_extension = ".json"

# Number of images kept per contrast for this quick experiment.
MAX_FILES_PER_CONTRAST = 20


def find_contrast_paths(root_dir, contrasts=("T1w", "T2w"), sidecar_extension=".json"):
    """Return, per contrast, the sorted relative image paths found under ``root_dir``.

    Walks ``root_dir`` (skipping any ``derivatives`` sub-directory). Every file whose name
    ends with ``sidecar_extension`` and contains one of ``contrasts`` is recorded under the
    first matching contrast, in the order given. The recorded path is relative to
    ``root_dir``, with the sidecar extension replaced by ``.nii.gz``.

    Parameters
    ----------
    root_dir : str
        Directory to search.
    contrasts : sequence of str
        Substrings identifying each contrast in a file name.
    sidecar_extension : str
        Extension of the files used to locate images.

    Returns
    -------
    dict[str, list[str]]
        Contrast -> sorted list of relative ``.nii.gz`` paths. The lists are sorted because
        ``os.walk`` does not guarantee any listing order. Without sorting, truncating the
        lists would select different files on different machines.
    """
    found = {contrast: [] for contrast in contrasts}
    for root, dirs, files in os.walk(root_dir):
        # Pruning `dirs` in place stops os.walk from descending into derivatives.
        if "derivatives" in dirs:
            dirs.remove("derivatives")
        for file in files:
            if not file.endswith(sidecar_extension):
                continue
            for contrast in contrasts:
                if contrast in file:
                    relative_path = os.path.relpath(os.path.join(root, file), root_dir)
                    found[contrast].append(os.path.splitext(relative_path)[0] + ".nii.gz")
                    break  # first matching contrast wins
    return {contrast: sorted(paths) for contrast, paths in found.items()}


print("Searching for T1w and T2w files in", base_dir, "...")

contrast_paths = find_contrast_paths(base_dir, ("T1w", "T2w"), desired_extension)
t1w_file_paths = contrast_paths["T1w"][:MAX_FILES_PER_CONTRAST]
t2w_file_paths = contrast_paths["T2w"][:MAX_FILES_PER_CONTRAST]

print("Found", len(t1w_file_paths), "T1w files and", len(t2w_file_paths), "T2w files.")

Searching for T1w, T2w, and DWI files in data//data-multi-subject// ...
Found 20 T1w files and 20 T2w files.


In [3]:
# Image loading helper
def load_image(image_path):
    """Read the image at ``image_path`` with nibabel and return ``get_fdata()`` of it."""
    return nib.load(image_path).get_fdata()

def From_3D_to_2Ds(image, label, n_slices=10, rng=None):
    """Sample distinct 2D slices from a 3D volume along randomly chosen axes.

    For each of the ``n_slices`` draws, an axis (0, 1 or 2) is picked uniformly among the
    axes that still have unused indices. Then an index not yet used along that axis is
    picked, so no slice is returned twice.

    Parameters
    ----------
    image : array-like
        3D volume.
    label : any
        Attached unchanged to every extracted slice.
    n_slices : int, default 10
        Number of slices to extract.
    rng : ``np.random.Generator`` / ``RandomState`` / ``np.random`` module, optional
        Source of randomness. Defaults to the global ``np.random`` state, so that
        ``np.random.seed`` controls it.

    Returns
    -------
    list of dict
        ``{"image": 2D ndarray, "label": label}`` entries. Each slice is an independent
        copy, so the returned list does not keep the whole volume alive in memory.

    Raises
    ------
    ValueError
        If ``image`` is not 3D, or has fewer than ``n_slices`` slices over all three axes.
    """
    image = np.asarray(image)
    if image.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {image.shape}")
    if n_slices > sum(image.shape):
        raise ValueError(
            f"cannot draw {n_slices} distinct slices from a volume of shape {image.shape}"
        )
    rng = np.random if rng is None else rng

    used = ([], [], [])  # indices already drawn, per axis (0: sagittal, 1: coronal, 2: axial)
    slices_2d = []
    for _ in range(n_slices):
        # Only consider axes that still have an unused index, so sampling never gets stuck.
        open_axes = [axis for axis in range(3) if len(used[axis]) < image.shape[axis]]
        axis = rng.choice(open_axes)
        taken = set(used[axis])
        index = rng.choice([k for k in range(image.shape[axis]) if k not in taken])
        used[axis].append(index)
        # np.take returns a new array; basic slicing (image[index, :, :]) would return a
        # view that keeps the full 3D volume referenced.
        slices_2d.append({"image": np.take(image, index, axis=axis), "label": label})
    return slices_2d


In [4]:
# Split the volumes into training and validation sets, then extract 2D slices from each.

SEED = 0
# Seed the global NumPy state: slice sampling (From_3D_to_2Ds) and the shuffles below use it.
np.random.seed(SEED)

# One row per volume. Labels are one-hot: [1, 0] = T1w, [0, 1] = T2w.
# A fresh list is built per row, so rows do not share one mutable label object.
path_data = pd.DataFrame({
    "image_path": t1w_file_paths + t2w_file_paths,
    "label": [[1, 0] for _ in t1w_file_paths] + [[0, 1] for _ in t2w_file_paths],
})
print(path_data.head())

# Stratify on the class index so both splits keep the T1w/T2w balance.
train_data, val_data = train_test_split(
    path_data,
    test_size=0.2,
    random_state=SEED,
    stratify=path_data["label"].map(np.argmax),
)


def _slices_from_rows(rows):
    """Load every volume listed in ``rows`` and return all their extracted 2D slices."""
    slices = []
    for _, row in rows.iterrows():
        volume = load_image(os.path.join(base_dir, row["image_path"]))
        slices += From_3D_to_2Ds(volume, row["label"])
    return slices


train_data_2D = _slices_from_rows(train_data)
val_data_2D = _slices_from_rows(val_data)

# Shuffle the slices (reproducible thanks to the seed above).
np.random.shuffle(train_data_2D)
np.random.shuffle(val_data_2D)



                            image_path   label
0  sub-amu01\anat\sub-amu01_T1w.nii.gz  [1, 0]
1  sub-amu02\anat\sub-amu02_T1w.nii.gz  [1, 0]
2  sub-amu03\anat\sub-amu03_T1w.nii.gz  [1, 0]
3  sub-amu04\anat\sub-amu04_T1w.nii.gz  [1, 0]
4  sub-amu05\anat\sub-amu05_T1w.nii.gz  [1, 0]


In [9]:
# Define a custom dataset class
class Dataset_2D(Dataset):
    """Dataset over pre-extracted 2D slices.

    Each entry of ``data`` is a dict with keys ``"image"`` and ``"label"``, as produced by
    ``From_3D_to_2Ds``. The optional transform receives a ``{"image": ..., "label": ...}``
    dict, so MONAI dictionary transforms (the ``...d`` classes, which pick entries through
    ``keys``) can be used.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for one slice.

        The label is returned as a float32 tensor, so one-hot labels collate into a
        ``(batch, n_classes)`` tensor. Without this, a list label would collate into a
        list of per-position tensors.
        """
        entry = self.data[index]
        # Build a fresh dict so transforms never modify the stored samples.
        sample = {"image": entry["image"], "label": entry["label"]}
        if self.transform is not None:
            sample = self.transform(sample)
        label = torch.as_tensor(sample["label"], dtype=torch.float32)
        return sample["image"], label
    
# Data augmentation with MONAI dictionary transforms (they act on the "image" entry of a
# {"image": ..., "label": ...} sample). Slices are 2D arrays without a channel axis, so
# one is added first. Intensities are then scaled to [0, 1], which gives the 0.1
# intensity-shift offset a consistent meaning across scans.
# Each random step draws new parameters for every sample:
#   - rotation by a random angle in [-3°, +3°]
#   - rotation by a random multiple of 90° (0°, 90°, 180°, 270°), then random flipping
#   - random intensity shift
#   - random crop of 50–100 % of the slice size, at a random position
# Finally, slices are resized to SLICE_SIZE, so slices taken along different axes or from
# different subjects can be stacked into one batch.

SLICE_SIZE = (128, 128)

train_transforms = Compose(
    [
        EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        ScaleIntensityd(keys=["image"]),
        RandRotated(keys=["image"], range_x=np.deg2rad(3), prob=0.5),
        RandRotate90d(keys=["image"], prob=0.5),
        RandFlipd(keys=["image"], prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandScaleCropd(
            keys=["image"],
            roi_scale=0.5,
            max_roi_scale=1.0,
            random_center=True,
            random_size=True,
        ),
        Resized(keys=["image"], spatial_size=SLICE_SIZE),
        ToTensord(keys=["image"]),
    ]
)


In [10]:
# Build the training and validation datasets.
# Validation must be deterministic: it reuses only the non-random steps of the training
# pipeline, so validation metrics do not vary with random augmentations.
val_transforms = Compose(
    [t for t in train_transforms.transforms if not isinstance(t, monai.transforms.Randomizable)]
)

BATCH_SIZE = 10

train_dataset = Dataset_2D(train_data_2D, transform=train_transforms)
val_dataset = Dataset_2D(val_data_2D, transform=val_transforms)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [7]:



# Slice-level classifier: a single-channel 2D slice in, one output per class (T1w, T2w) out.
# A UNet is a segmentation network: it outputs a per-pixel map rather than one prediction
# per image, and it needs input sizes divisible by the product of its strides.
# DenseNet121 ends with global pooling and produces one logit vector per slice.
model = monai.networks.nets.DenseNet121(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
).to(device)



