## 2 constrasts model

This notebook load, preprocess the data and train a first modèle to predict if a 2 image is T1w or T2w.
The Notebook form helps running and testing fast before coding the final structure

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import nibabel as nib
import pandas as pd
from monai.data import Dataset, DataLoader, CacheDataset
import torchvision.models as models
from monai.transforms import (
    Compose,
    RandScaleCrop,
    RandFlip,
    RandRotate90,
    RandRotate,
    RandShiftIntensity,
    ToTensor,
    RandSpatialCrop,
    LoadImage,
    SqueezeDim,
    RandRotate,
)
import os
import nibabel as nib
import json
import pandas as pd
from sklearn.model_selection import train_test_split
import monai




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# this cell aims at extracting the list of path relevant for the first model test which takes T1w T2w adn DWI as image

base_dir="data//data-multi-subject//"

desired_extension = ".json"

# Initialize lists to store the relative paths for T1w, T2w, and DWI files
t1w_file_paths = []
t2w_file_paths = []

print("Searching for T1w, T2w, and DWI files in", base_dir, "...")

# Traverse the directory structure
for root, dirs, files in os.walk(base_dir):
    # Exclude the "derivatives" subfolder
    if "derivatives" in dirs:
        dirs.remove("derivatives")
    for file in files:
        # Check if the file name contains the desired names
        if "T1w" in file and file.endswith(desired_extension):
            # Get the relative path of the T1w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(base_dir + relative_path)[0] + ".nii.gz"
            # Append the relative path to the T1w file paths list
            t1w_file_paths.append(relative_path)
        elif "T2w" in file and file.endswith(desired_extension):
            # Get the relative path of the T2w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(relative_path)[0] + ".nii.gz"
            # Append the relative path to the T2w file paths list
            t2w_file_paths.append(base_dir + relative_path)

#t1w_file_paths = t1w_file_paths[:20]
#t2w_file_paths = t2w_file_paths[:20]

print("Found", len(t1w_file_paths), "T1w files and", len(t2w_file_paths), "T2w files.")

Searching for T1w, T2w, and DWI files in data//data-multi-subject// ...
Found 267 T1w files and 267 T2w files.


In [3]:
# split the data into training and validation sets

# build a dataset with a colmn "file path" wich contiains the paths listed in both t1w_file_paths and t2w_file_paths
path_data = pd.DataFrame({"image_path" : t1w_file_paths + t2w_file_paths, "labels" : len(t1w_file_paths) * [0] + len(t2w_file_paths) * [1]})

train_data, val_data = train_test_split(path_data, test_size=0.2, random_state=0)
train_data.reset_index(drop=True, inplace=True)
val_data.reset_index(drop=True, inplace=True)



In [68]:
# Define a custom dataset class
class Dataset_2D(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.data = {"paths" : paths, "labels" : labels}
        self.transform = transform
        self.length = len(self.data["paths"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path = self.data["paths"][index]
        label = [0,1] if self.data["labels"][index] else [1,0]
        if self.transform:
            image = self.transform(path)
            dim_to_squeeze = int(np.random.choice([0,1,2]))
            roi_min = np.array([15, 15, 15])
            roi_max = np.array([-1, -1, -1])
            roi_min[dim_to_squeeze] = 1
            roi_max[dim_to_squeeze] = 1
            image = RandSpatialCrop(roi_min,  max_roi_size = roi_max, random_size=True, random_center=True)(image)
            image = SqueezeDim(dim=dim_to_squeeze + 1)(image)
            # Convert to tensor
            image = ToTensor()(image)
            # add a dimension to the image, for exemple [1, 256, 256] -> [1, 1, 256, 256]
            image = image.unsqueeze(0)

        # convert label list to tensor with shape [1,2]
        label = torch.tensor([label])
        return image, label
    
# use monai to define the transforms for data augmentation
# perform the following transformations : rotation (random between +3° and -3°), flipping (random between 0°,  90 °, 180° and 270°), cropping (Random size, random place) and shifting (random shift)

train_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        RandRotate90(prob=0.5),
        RandFlip(prob=0.5),
        RandShiftIntensity(offsets=0.1, prob=0.5),
        RandRotate(range_x=3, range_y=3, range_z=3, prob=0.2),
        
    ]
)

val_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
    ]
)

# Create the custom datasets
train_dataset = Dataset_2D(
    paths=train_data['image_path'],
    labels=train_data['labels'],
    transform=train_transforms,
)

val_dataset = Dataset_2D(
    paths=val_data['image_path'],
    labels=val_data['labels'],
    transform=val_transforms,
)


In [69]:
# get an image
i = np.random.randint(0, train_dataset.length)
print("Index:", i)
image, label = train_dataset[i]
print(image.shape, label)


Index: 143
torch.Size([1, 1, 55, 301]) tensor([[0, 1]])


In [70]:
class ResNet18SingleChannel(nn.Module):
    # Define the ResNet18 model with a single image channel and an output value between 0 and 1
    def __init__(self, num_classes=2):
        super(ResNet18SingleChannel, self).__init__()
        # Load the pre-trained ResNet18 model
        resnet = models.resnet18(pretrained=True)
        # Modify the first convolutional layer to take a single channel input
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
        # Modify the final fully connected layer to output a single value

        self.resnet = resnet

        #final fc to go from [batch_size, 1000] to [batch_size, num_classes]
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

model = ResNet18SingleChannel(num_classes=2).to(device)

image, label = train_dataset[0]
print(image)
output = model.forward(image.to(device))
print(output)



metatensor([[[[1.0714, 2.0714, 3.0714,  ..., 2.0714, 2.0714, 3.0714],
          [6.0714, 2.0714, 1.0714,  ..., 0.0714, 2.0714, 3.0714],
          [3.0714, 1.0714, 3.0714,  ..., 1.0714, 2.0714, 1.0714],
          ...,
          [3.0714, 2.0714, 2.0714,  ..., 3.0714, 2.0714, 3.0714],
          [2.0714, 4.0714, 4.0714,  ..., 3.0714, 3.0714, 3.0714],
          [6.0714, 6.0714, 2.0714,  ..., 2.0714, 3.0714, 3.0714]]]])
metatensor([[0.5073, 0.5722]], grad_fn=<AliasBackward0>)


In [72]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


# Define the training loop
def training_one_epoch(model):
    model.train()
    running_loss = 0.0
    queue_line = np.arange(train_dataset.length)
    np.random.shuffle(queue_line) 
    for i in queue_line:
        image, label = train_dataset[i]
        image, label = image.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = model(image)
        loss = criterion(outputs, label.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return model, running_loss / len(train_dataset)

# Define the number of epochs
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1} / {num_epochs}")
    model, train_loss = training_one_epoch(model)
    print(f"Epoch {epoch + 1} training loss: {train_loss}")

#save model
torch.save(model.state_dict(), "model.pth")



Epoch 1 / 10
Training...
Model in training mode
Training on image  355
torch.Size([1, 1, 24, 296])
Training on image  353
torch.Size([1, 1, 125, 98])
Training on image  231
torch.Size([1, 1, 26, 246])
Training on image  1
torch.Size([1, 1, 51, 124])
Training on image  95


tensor([[ 0.0000e+00, -5.2965e-03,  1.0000e+00, -2.0622e+01],
        [-1.0000e+00,  0.0000e+00,  0.0000e+00,  2.6243e+01],
        [ 0.0000e+00,  9.9999e-01,  5.2966e-03, -7.7048e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 94, 56])
Training on image  36
torch.Size([1, 1, 23, 299])
Training on image  113
torch.Size([1, 1, 136, 302])
Training on image  181
torch.Size([1, 1, 178, 211])
Training on image  397
torch.Size([1, 1, 68, 282])
Training on image  272
torch.Size([1, 1, 59, 38])
Training on image  296
torch.Size([1, 1, 183, 187])
Training on image  135
torch.Size([1, 1, 254, 21])
Training on image  199
torch.Size([1, 1, 38, 304])
Training on image  78
torch.Size([1, 1, 110, 154])
Training on image  167
torch.Size([1, 1, 136, 299])
Training on image  121
torch.Size([1, 1, 34, 48])
Training on image  40
torch.Size([1, 1, 172, 214])
Training on image  212
torch.Size([1, 1, 49, 146])
Training on image  403
torch.Size([1, 1, 72, 102])
Training on image  269
torch.Size([1, 1, 204, 137])
Training on image  107
torch.Size([1, 1, 15, 180])
Training on image  168
torch.Size([1, 1, 112, 103])
Training on image  160
torch.Size([1, 1, 134, 302])
Training on image  81
torch.Size([1, 1, 187, 103])


tensor([[ 0.0000e+00,  2.8356e-02,  9.9959e-01,  6.1151e+01],
        [-1.0002e+00,  0.0000e+00,  0.0000e+00, -8.2081e+01],
        [ 0.0000e+00,  9.9982e-01, -2.8350e-02,  9.1249e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 36, 30])
Training on image  261
torch.Size([1, 1, 30, 166])
Training on image  263
torch.Size([1, 1, 283, 18])
Training on image  6
torch.Size([1, 1, 177, 274])
Training on image  254
torch.Size([1, 1, 247, 246])
Training on image  104


tensor([[ 0.0000e+00, -2.2680e-02, -7.9968e-01,  6.0325e+00],
        [ 8.0000e-01,  0.0000e+00,  0.0000e+00, -5.2215e+01],
        [ 0.0000e+00, -7.9968e-01,  2.2680e-02,  8.1739e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 113, 37])
Training on image  252
torch.Size([1, 1, 285, 24])
Training on image  141
torch.Size([1, 1, 71, 58])
Training on image  23
torch.Size([1, 1, 38, 62])
Training on image  27
torch.Size([1, 1, 210, 310])
Training on image  164
torch.Size([1, 1, 20, 316])
Training on image  415
torch.Size([1, 1, 233, 164])
Training on image  143
torch.Size([1, 1, 305, 115])
Training on image  408
torch.Size([1, 1, 57, 237])
Training on image  92
torch.Size([1, 1, 163, 204])
Training on image  22
torch.Size([1, 1, 44, 295])
Training on image  372
torch.Size([1, 1, 117, 275])
Training on image  10
torch.Size([1, 1, 18, 172])
Training on image  289
torch.Size([1, 1, 43, 143])
Training on image  198
torch.Size([1, 1, 224, 185])
Training on image  329
torch.Size([1, 1, 21, 434])
Training on image  30
torch.Size([1, 1, 54, 234])
Training on image  64
torch.Size([1, 1, 28, 64])
Training on image  48
torch.Size([1, 1, 35, 177])
Training on image  337
torch.Size([1, 1, 131, 26])
Training

tensor([[ 0.0000e+00,  1.8870e-02,  7.9978e-01, -5.3758e+00],
        [-8.0000e-01,  0.0000e+00,  0.0000e+00, -5.2409e+01],
        [ 0.0000e+00,  7.9978e-01, -1.8870e-02, -5.5562e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 59, 33])
Training on image  190
torch.Size([1, 1, 237, 265])
Training on image  381
torch.Size([1, 1, 29, 215])
Training on image  236
torch.Size([1, 1, 52, 261])
Training on image  77
torch.Size([1, 1, 16, 197])
Training on image  220
torch.Size([1, 1, 186, 319])
Training on image  72
torch.Size([1, 1, 301, 161])
Training on image  387
torch.Size([1, 1, 191, 86])
Training on image  217
torch.Size([1, 1, 83, 276])
Training on image  145
torch.Size([1, 1, 64, 258])
Training on image  331
torch.Size([1, 1, 52, 124])
Training on image  368
torch.Size([1, 1, 90, 146])
Training on image  388
torch.Size([1, 1, 32, 181])
Training on image  187
torch.Size([1, 1, 37, 264])
Training on image  155
torch.Size([1, 1, 284, 88])
Training on image  184
torch.Size([1, 1, 19, 245])
Training on image  291
torch.Size([1, 1, 51, 228])
Training on image  14
torch.Size([1, 1, 288, 75])
Training on image  109
torch.Size([1, 1, 60, 225])
Training on image  152
torch.Size([1, 1, 140, 158])
Tra

tensor([[ 0.0000e+00,  3.4767e-02,  7.9924e-01, -4.3796e+00],
        [-8.0000e-01,  0.0000e+00,  0.0000e+00,  1.0099e+02],
        [ 0.0000e+00,  7.9924e-01, -3.4767e-02,  6.3333e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 306, 34])
Training on image  256
torch.Size([1, 1, 44, 117])
Training on image  270
torch.Size([1, 1, 210, 94])
Training on image  377
torch.Size([1, 1, 152, 164])
Training on image  383
torch.Size([1, 1, 48, 169])
Training on image  74
torch.Size([1, 1, 37, 213])
Training on image  394
torch.Size([1, 1, 25, 236])
Training on image  140
torch.Size([1, 1, 36, 201])
Training on image  75
torch.Size([1, 1, 219, 35])
Training on image  55
torch.Size([1, 1, 122, 163])
Training on image  193
torch.Size([1, 1, 22, 45])
Training on image  112
torch.Size([1, 1, 166, 268])
Training on image  24
torch.Size([1, 1, 61, 220])
Training on image  400
torch.Size([1, 1, 313, 128])
Training on image  266
torch.Size([1, 1, 276, 48])
Training on image  319
torch.Size([1, 1, 303, 225])
Training on image  87
torch.Size([1, 1, 148, 125])
Training on image  399
torch.Size([1, 1, 89, 174])
Training on image  301
torch.Size([1, 1, 157, 159])
Training on image  25
torch.Size([1, 1, 247, 205])
Tr

tensor([[ 0.0000e+00,  1.2637e-02, -7.9990e-01,  1.9791e+01],
        [ 8.0000e-01,  0.0000e+00,  0.0000e+00,  8.6826e+01],
        [ 0.0000e+00, -7.9990e-01, -1.2637e-02,  1.1423e+02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 41, 51])
Training on image  423
torch.Size([1, 1, 25, 67])
Training on image  205
torch.Size([1, 1, 43, 113])
Training on image  210
torch.Size([1, 1, 22, 34])
Training on image  284
torch.Size([1, 1, 173, 21])
Training on image  389
torch.Size([1, 1, 34, 187])
Training on image  209
torch.Size([1, 1, 21, 140])
Training on image  250
torch.Size([1, 1, 182, 104])
Training on image  165
torch.Size([1, 1, 66, 310])
Training on image  67
torch.Size([1, 1, 165, 41])
Training on image  102
torch.Size([1, 1, 53, 283])
Training on image  357
torch.Size([1, 1, 185, 163])
Training on image  215
torch.Size([1, 1, 107, 225])
Training on image  318
torch.Size([1, 1, 23, 74])
Training on image  118
torch.Size([1, 1, 141, 158])
Training on image  290
torch.Size([1, 1, 155, 30])
Training on image  351
torch.Size([1, 1, 197, 31])
Training on image  376
torch.Size([1, 1, 53, 208])
Training on image  258
torch.Size([1, 1, 301, 103])
Training on image  57
torch.Size([1, 1, 42, 164])
Trai

tensor([[ 0.0000e+00, -2.2680e-02, -7.9968e-01, -6.2642e+00],
        [ 8.0000e-01,  0.0000e+00,  0.0000e+00,  5.4985e+01],
        [ 0.0000e+00, -7.9968e-01,  2.2680e-02,  1.3898e+02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 180, 16])
Training on image  201
torch.Size([1, 1, 85, 57])
Training on image  42
torch.Size([1, 1, 175, 314])
Training on image  346
torch.Size([1, 1, 274, 158])
Training on image  243
torch.Size([1, 1, 103, 25])
Training on image  204
torch.Size([1, 1, 32, 230])
Training on image  111
torch.Size([1, 1, 16, 145])
Training on image  96
torch.Size([1, 1, 63, 42])
Training on image  57
torch.Size([1, 1, 133, 137])
Training on image  259
torch.Size([1, 1, 64, 304])
Training on image  354
torch.Size([1, 1, 241, 305])
Training on image  421
torch.Size([1, 1, 100, 49])
Training on image  367
torch.Size([1, 1, 30, 165])
Training on image  373
torch.Size([1, 1, 43, 181])
Training on image  174
torch.Size([1, 1, 88, 188])
Training on image  103
torch.Size([1, 1, 203, 106])
Training on image  295
torch.Size([1, 1, 206, 185])
Training on image  234
torch.Size([1, 1, 125, 71])
Training on image  40
torch.Size([1, 1, 31, 115])
Training on image  145
torch.Size([1, 1, 42, 143])
Tra

tensor([[ 0.0000e+00, -2.8356e-02, -9.9960e-01,  6.3793e+01],
        [ 1.0002e+00,  0.0000e+00,  0.0000e+00,  6.0526e+01],
        [ 0.0000e+00, -9.9982e-01,  2.8350e-02,  1.4476e+02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 1, 49, 160])
Training on image  211
torch.Size([1, 1, 18, 269])
Training on image  75
torch.Size([1, 1, 20, 248])
Training on image  423
torch.Size([1, 1, 37, 69])
Training on image  167
torch.Size([1, 1, 56, 311])
Training on image  185
torch.Size([1, 1, 171, 173])
Training on image  324
torch.Size([1, 1, 264, 318])
Training on image  45
torch.Size([1, 1, 30, 308])
Training on image  240
torch.Size([1, 1, 195, 83])
Training on image  150
torch.Size([1, 1, 22, 118])
Training on image  206
torch.Size([1, 1, 92, 18])
Training on image  221
torch.Size([1, 1, 245, 86])
Training on image  418
torch.Size([1, 1, 69, 265])
Training on image  15
torch.Size([1, 1, 153, 22])
Training on image  351
torch.Size([1, 1, 121, 149])
Training on image  59
torch.Size([1, 1, 99, 197])
Training on image  223
torch.Size([1, 1, 250, 63])
Training on image  133
torch.Size([1, 1, 294, 234])
Training on image  33
torch.Size([1, 1, 257, 314])
Training on image  228
torch.Size([1, 1, 79, 512])
Train