## 2 constrasts model

This notebook load, preprocess the data and train a first modèle to predict if a 2 image is T1w or T2w.
The Notebook form helps running and testing fast before coding the final structure

In [151]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import nibabel as nib
import pandas as pd
from monai.data import Dataset, DataLoader, CacheDataset
import torchvision.models as models
from monai.transforms import (
    Compose,
    RandScaleCrop,
    RandFlipd,
    RandRotate90d,
    RandRotate,
    RandShiftIntensityd,
    ToTensord,
    RandSpatialCrop,
)
import os
import nibabel as nib
import json
import pandas as pd
from sklearn.model_selection import train_test_split
import monai




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
# this cell aims at extracting the list of path relevant for the first model test which takes T1w T2w adn DWI as input

base_dir="data//data-multi-subject//"

desired_extension = ".json"

# Initialize lists to store the relative paths for T1w, T2w, and DWI files
t1w_file_paths = []
t2w_file_paths = []

print("Searching for T1w, T2w, and DWI files in", base_dir, "...")

# Traverse the directory structure
for root, dirs, files in os.walk(base_dir):
    # Exclude the "derivatives" subfolder
    if "derivatives" in dirs:
        dirs.remove("derivatives")
    for file in files:
        # Check if the file name contains the desired names
        if "T1w" in file and file.endswith(desired_extension):
            # Get the relative path of the T1w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(relative_path)[0] + ".nii.gz"
            # Append the relative path to the T1w file paths list
            t1w_file_paths.append(relative_path)
        elif "T2w" in file and file.endswith(desired_extension):
            # Get the relative path of the T2w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(relative_path)[0] + ".nii.gz"
            # Append the relative path to the T2w file paths list
            t2w_file_paths.append(relative_path)

t1w_file_paths = t1w_file_paths[:20]
t2w_file_paths = t2w_file_paths[:20]

print("Found", len(t1w_file_paths), "T1w files and", len(t2w_file_paths), "T2w files.")

Searching for T1w, T2w, and DWI files in data//data-multi-subject// ...
Found 20 T1w files and 20 T2w files.


In [136]:
# Define a function to load image data
def load_image(image_path):
    img = nib.load(image_path)
    img_data = img.get_fdata()
    return img_data

def From_3D_to_2Ds(image, label):
    #from each 3D image we will extract 10 2D images, from random different views
    data_2D=[]
    selected_perax=[[],[],[]] #list of the selected indexes for each view
    for i in range(10):
        # randomly select a view
        view = np.random.choice([0, 1, 2]) # 0: sagittal, 1: coronal, 2: axial
        # randomly select an index for wich we will extract the 2D image among the index which were not yet selected
        index = np.random.choice([i for i in range(image.shape[view]) if i not in selected_perax[view]])
        if view == 0:
            image_2D = np.array([image[index, :, :]])
        elif view == 1:
            image_2D = np.array([image[:, index, :]])
        else:
            image_2D = np.array([image[:, :, index]])
        selected_perax[view].append(index)
        data_2D.append({"image" : image_2D, "label" : label})
    return(data_2D)


In [145]:
# split the data into training and validation sets

# build a dataset with a colmn "file path" wich contiains the paths listed in both t1w_file_paths and t2w_file_paths
path_data = pd.DataFrame({"image_path" : t1w_file_paths + t2w_file_paths, "label" : len(t1w_file_paths) * [[1,0]] + len(t2w_file_paths) * [[0,1]]})
print(path_data.head())

train_data, val_data = train_test_split(path_data, test_size=0.2, random_state=0)

# load the 3D images
train_data_2D = []
val_data_2D = []
for index, row in train_data.iterrows():
    image_path = os.path.join(base_dir, row["image_path"])
    image = load_image(image_path)
    train_data_2D += From_3D_to_2Ds(image, row["label"])
for index, row in val_data.iterrows():
    image_path = os.path.join(base_dir, row["image_path"])
    image = load_image(image_path)
    val_data_2D += From_3D_to_2Ds(image, row["label"])

# Save the 2D images to data/preprocessed_2D_images/{train/val}/{label}_idx.png
preprocessed_dir = "data//preprocessed_2D_images//"
if not os.path.exists(preprocessed_dir):
    os.makedirs(preprocessed_dir)
    os.makedirs(os.path.join(preprocessed_dir, "train"))
    os.makedirs(os.path.join(preprocessed_dir, "val"))

for i, data in enumerate(train_data_2D):
    image = data["image"]
    label = data["label"]
    image_path = os.path.join(preprocessed_dir, "train", f"{label}_{i}.npy")
    np.save(image_path, image)

for i, data in enumerate(val_data_2D):
    image = data["image"]
    label = data["label"]
    image_path = os.path.join(preprocessed_dir, "val", f"{label}_{i}.npy")
    np.save(image_path, image)







                            image_path   label
0  sub-amu01\anat\sub-amu01_T1w.nii.gz  [1, 0]
1  sub-amu02\anat\sub-amu02_T1w.nii.gz  [1, 0]
2  sub-amu03\anat\sub-amu03_T1w.nii.gz  [1, 0]
3  sub-amu04\anat\sub-amu04_T1w.nii.gz  [1, 0]
4  sub-amu05\anat\sub-amu05_T1w.nii.gz  [1, 0]
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
[[[ 79.  49.  52. ...  68.  48. 108.]
  [139.  81.  96. ... 102.  90. 174.]
  [101. 157. 115. ... 151. 159. 180.]
  ...
  [139. 106.  87. ... 102. 129. 118.]
  [ 26.  46.  56. ... 157. 183.  99.]
  [  1.  -1.  -2. ...  71. 110

In [172]:
# Define a custom dataset class
class Dataset_2D(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path = self.paths[index]
        label = self.labels[index]
        if self.transform:
            image = self.transform(path)
        return image, label
    
# use monai to define the transforms for data augmentation
# perform the following transformations : rotation (random between +3° and -3°), flipping (random between 0°,  90 °, 180° and 270°), cropping (Random size, random place) and shifting (random shift)

train_transforms = Compose(
    [
        RandRotate90d(keys=["image"], prob=0.5),
        RandFlipd(keys=["image"], prob=0.5),
        RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
        RandScaleCrop([0.6, 0.5], random_center=True),
        ToTensord(keys=["image"]),
    ]
)


In [264]:
test_path = "data\data-multi-subject\sub-amu01\\anat\sub-amu01_T1w.nii.gz"
from monai.transforms import LoadImage

#load from a path the image and apply the transforms
test_transform = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        RandSpatialCrop(np.array([1, 1, 1]), max_roi_size = np.random.choice(np.array([-1,-1,1]),3, replace=False), random_size=True, random_center=True),
    ]
)

image = test_transform(test_path)
print(image.shape)

torch.Size([1, 15, 1, 254])


In [260]:
import random as rd

A = np.array([-1,1,1])

# translate A to a random position
B = np.random.choice(A,3, replace=False)
print(B) 

[ 1  1 -1]


In [139]:
# Build the training and validation datasets
train_dataset = Dataset_2D(train_data['image_path'], train_data['label'], transform=train_transforms)
val_dataset = Dataset_2D(val_data['image_path'], val_data['label'], transform=val_transforms)


# Create data loaders
train_loader = DataLoader(train_dataset.data, batch_size=10, shuffle=True)
val_loader = DataLoader(val_dataset.data, batch_size=10)

N = len(train_dataset.data)
for i in range(10):
    rd_index = np.random.randint(0, N)
    print("shape of the " +str(i)+ "th image : " + str(train_dataset[i]['image'].shape))

RuntimeError: applying transform <monai.transforms.croppad.array.RandScaleCrop object at 0x000001E32FBDD6F0>

In [50]:


class ResNet18SingleChannel(nn.Module):
    # Define the ResNet18 model with a single input channel and an output value between 0 and 1
    def __init__(self, num_classes=2):
        super(ResNet18SingleChannel, self).__init__()
        # Load the pre-trained ResNet18 model
        resnet = models.resnet18(pretrained=True)
        # Modify the first convolutional layer to take a single channel input
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
        # Modify the final fully connected layer to output a single value

        self.resnet = resnet

        #final fc to go from [batch_size, 1000] to [batch_size, num_classes]
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

model = ResNet18SingleChannel(num_classes=2).to(device)

#output = model.forward(torch.randn(3, 1, 49, 29))
#print(output)



tensor([[0.5767, 0.5020],
        [0.8802, 0.8488],
        [0.0636, 0.6604]], grad_fn=<SigmoidBackward0>)


NameError: name 'train_dataset' is not defined

In [None]:
# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.argmax(dim=1)).sum().item()
print(f"Validation accuracy: {correct / total}")
