## 2 constrasts model

This notebook load, preprocess the data and train a first modèle to predict if a 2 image is T1w or T2w.
The Notebook form helps running and testing fast before coding the final structure

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import nibabel as nib
import pandas as pd
from monai.data import Dataset, DataLoader, CacheDataset
import torchvision.models as models
from monai.transforms import (
    Compose,
    RandScaleCrop,
    RandFlip,
    RandRotate90,
    RandRotate,
    RandShiftIntensity,
    ToTensor,
    RandSpatialCrop,
    LoadImage,
    SqueezeDim,
    RandRotate,
)
import os
import nibabel as nib
import json
import pandas as pd
from sklearn.model_selection import train_test_split
import monai




device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [50]:
# this cell aims at extracting the list of path relevant for the first model test which takes T1w T2w adn DWI as input

base_dir="data//data-multi-subject//"

desired_extension = ".json"

# Initialize lists to store the relative paths for T1w, T2w, and DWI files
t1w_file_paths = []
t2w_file_paths = []

print("Searching for T1w, T2w, and DWI files in", base_dir, "...")

# Traverse the directory structure
for root, dirs, files in os.walk(base_dir):
    # Exclude the "derivatives" subfolder
    if "derivatives" in dirs:
        dirs.remove("derivatives")
    for file in files:
        # Check if the file name contains the desired names
        if "T1w" in file and file.endswith(desired_extension):
            # Get the relative path of the T1w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(base_dir + relative_path)[0] + ".nii.gz"
            # Append the relative path to the T1w file paths list
            t1w_file_paths.append(relative_path)
        elif "T2w" in file and file.endswith(desired_extension):
            # Get the relative path of the T2w file
            relative_path = os.path.relpath(os.path.join(root, file), base_dir)
            # Remove the file extension
            relative_path = os.path.splitext(relative_path)[0] + ".nii.gz"
            # Append the relative path to the T2w file paths list
            t2w_file_paths.append(base_dir + relative_path)

#t1w_file_paths = t1w_file_paths[:20]
#t2w_file_paths = t2w_file_paths[:20]

print("Found", len(t1w_file_paths), "T1w files and", len(t2w_file_paths), "T2w files.")

Searching for T1w, T2w, and DWI files in data//data-multi-subject// ...
Found 267 T1w files and 267 T2w files.


In [97]:
# split the data into training and validation sets

# build a dataset with a colmn "file path" wich contiains the paths listed in both t1w_file_paths and t2w_file_paths
path_data = pd.DataFrame({"image_path" : t1w_file_paths + t2w_file_paths, "labels" : len(t1w_file_paths) * [0] + len(t2w_file_paths) * [1]})

train_data, val_data = train_test_split(path_data, test_size=0.2, random_state=0)
train_data.reset_index(drop=True, inplace=True)
val_data.reset_index(drop=True, inplace=True)



data//data-multi-subject//sub-ucl04\anat\sub-ucl04_T1w.nii.gz


In [98]:
# Define a custom dataset class
class Dataset_2D(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.data = {"paths" : paths, "labels" : labels}
        self.transform = transform
        self.length = len(self.data["paths"])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path = self.data["paths"][index]
        label = [0,1] if self.data["labels"][index] else [1,0]
        if self.transform:
            image = self.transform(path)
            dim_to_squeeze = int(np.argmin(image.shape[1:])) + 1
            image = SqueezeDim(dim=dim_to_squeeze)(image)
            # Convert to tensor
            image = ToTensor()(image)
        return image, label
    
# use monai to define the transforms for data augmentation
# perform the following transformations : rotation (random between +3° and -3°), flipping (random between 0°,  90 °, 180° and 270°), cropping (Random size, random place) and shifting (random shift)

train_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        RandRotate90(prob=0.5),
        RandFlip(prob=0.5),
        RandShiftIntensity(offsets=0.1, prob=0.5),
        RandRotate(range_x=3, range_y=3, range_z=3, prob=0.2),
        RandSpatialCrop(np.array([1, 1, 1]),  max_roi_size = np.random.choice(np.array([-1,-1,1]),3, replace=False), random_size=True, random_center=True),
    ]
)

val_transforms = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        RandSpatialCrop(np.array([1, 1, 1]),  max_roi_size = np.random.choice(np.array([-1,-1,1]),3, replace=False), random_size=True, random_center=True),
    ]
)

# Create the custom datasets
train_dataset = Dataset_2D(
    paths=train_data['image_path'],
    labels=train_data['labels'],
    transform=train_transforms,
)

val_dataset = Dataset_2D(
    paths=val_data['image_path'],
    labels=val_data['labels'],
    transform=val_transforms,
)


In [99]:
# get an image
for i in range(13, train_dataset.length):
    print("Index:", i)
    image, label = train_dataset[i]
    print(image.shape, label)


# plot the image
import matplotlib.pyplot as plt
plt.imshow(image[0, :, :])
plt.show()


Index: 13
torch.Size([1, 50, 206]) [0, 1]
Index: 14
torch.Size([1, 189, 224]) [1, 0]
Index: 15
torch.Size([1, 34, 362]) [0, 1]
Index: 16
torch.Size([1, 5, 125]) [0, 1]
Index: 17
torch.Size([1, 125, 35]) [1, 0]
Index: 18
torch.Size([1, 44, 81]) [0, 1]
Index: 19
torch.Size([1, 92, 58]) [1, 0]
Index: 20
torch.Size([1, 29, 149]) [0, 1]
Index: 21
torch.Size([1, 122, 106]) [1, 0]
Index: 22
torch.Size([1, 28, 123]) [0, 1]
Index: 23
torch.Size([1, 36, 247]) [0, 1]
Index: 24
torch.Size([1, 61, 12]) [0, 1]
Index: 25
torch.Size([1, 261, 161]) [0, 1]
Index: 26
torch.Size([1, 6, 126]) [0, 1]
Index: 27
torch.Size([1, 134, 117]) [1, 0]
Index: 28
torch.Size([1, 52, 320]) [1, 0]
Index: 29
torch.Size([1, 36, 61]) [1, 0]
Index: 30
torch.Size([1, 3, 264]) [1, 0]
Index: 31
torch.Size([1, 118, 205]) [1, 0]
Index: 32
torch.Size([1, 134, 255]) [1, 0]
Index: 33
torch.Size([1, 214, 225]) [1, 0]
Index: 34
torch.Size([1, 178, 190]) [1, 0]
Index: 35
torch.Size([1, 55, 188]) [0, 1]
Index: 36
torch.Size([1, 60, 307]

tensor([[ 0.0000e+00, -1.2637e-02,  7.9990e-01, -3.6361e+00],
        [-8.0000e-01,  0.0000e+00,  0.0000e+00,  4.4262e+00],
        [ 0.0000e+00,  7.9990e-01,  1.2637e-02, -7.4166e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 46, 21]) [0, 1]
Index: 138
torch.Size([1, 55, 7]) [0, 1]
Index: 139
torch.Size([1, 164, 203]) [1, 0]
Index: 140
torch.Size([1, 110, 56]) [1, 0]
Index: 141
torch.Size([1, 76, 125]) [1, 0]
Index: 142
torch.Size([1, 19, 303]) [1, 0]
Index: 143
torch.Size([1, 10, 24]) [0, 1]
Index: 144
torch.Size([1, 288, 273]) [0, 1]
Index: 145
torch.Size([1, 266, 60]) [0, 1]
Index: 146
torch.Size([1, 72, 242]) [1, 0]
Index: 147
torch.Size([1, 302, 15]) [0, 1]
Index: 148
torch.Size([1, 61, 145]) [0, 1]
Index: 149
torch.Size([1, 55, 151]) [0, 1]
Index: 150
torch.Size([1, 52, 242]) [0, 1]
Index: 151
torch.Size([1, 86, 247]) [1, 0]
Index: 152
torch.Size([1, 173, 316]) [1, 0]
Index: 153
torch.Size([1, 214, 160]) [1, 0]
Index: 154
torch.Size([1, 26, 185]) [1, 0]
Index: 155
torch.Size([1, 87, 159]) [0, 1]
Index: 156
torch.Size([1, 45, 122]) [0, 1]
Index: 157
torch.Size([1, 108, 190]) [1, 0]
Index: 158
torch.Size([1, 215, 63]) [0, 1]
Index: 159
torch.Size([1, 213, 207]) [0, 1]
Index: 160
torch.Siz

tensor([[ 0.0000e+00, -1.8870e-02, -7.9978e-01, -9.2078e+00],
        [ 8.0000e-01,  0.0000e+00,  0.0000e+00, -1.1801e+02],
        [ 0.0000e+00, -7.9978e-01,  1.8870e-02,  5.3359e+01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],
       dtype=torch.float64).


torch.Size([1, 74, 16]) [0, 1]
Index: 397
torch.Size([1, 112, 1]) [1, 0]
Index: 398
torch.Size([1, 11, 103]) [0, 1]
Index: 399
torch.Size([1, 71, 151]) [1, 0]
Index: 400
torch.Size([1, 138, 73]) [1, 0]
Index: 401
torch.Size([1, 61, 119]) [0, 1]
Index: 402
torch.Size([1, 39, 144]) [0, 1]
Index: 403
torch.Size([1, 127, 43]) [1, 0]
Index: 404
torch.Size([1, 167, 183]) [1, 0]
Index: 405
torch.Size([1, 191, 53]) [1, 0]
Index: 406
torch.Size([1, 88, 165]) [1, 0]
Index: 407
torch.Size([1, 128, 155]) [0, 1]
Index: 408
torch.Size([1, 50, 130]) [0, 1]
Index: 409
torch.Size([1, 118, 192]) [1, 0]
Index: 410
torch.Size([1, 219, 48]) [1, 0]
Index: 411
torch.Size([1, 78, 318]) [1, 0]
Index: 412
torch.Size([1, 51, 100]) [0, 1]
Index: 413
torch.Size([1, 47, 42]) [0, 1]
Index: 414


In [37]:
# Build the training and validation datasets
train_dataset = Dataset_2D(train_data['image_path'], train_data['label'], transform=train_transforms)
#val_dataset = Dataset_2D(val_data['image_path'], val_data['label'], transform=val_transforms)


# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
#val_loader = DataLoader(val_dataset, batch_size=10)

N = len(train_dataset.data)
for i in range(10):
    rd_index = np.random.randint(0, N)
    # print the rd_index-th image
    print(train_dataset[rd_index])

TypeError: dim must be None or a int but is int64.

In [50]:


class ResNet18SingleChannel(nn.Module):
    # Define the ResNet18 model with a single input channel and an output value between 0 and 1
    def __init__(self, num_classes=2):
        super(ResNet18SingleChannel, self).__init__()
        # Load the pre-trained ResNet18 model
        resnet = models.resnet18(pretrained=True)
        # Modify the first convolutional layer to take a single channel input
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3, bias=False)
        # Modify the final fully connected layer to output a single value

        self.resnet = resnet

        #final fc to go from [batch_size, 1000] to [batch_size, num_classes]
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.resnet(x)
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

model = ResNet18SingleChannel(num_classes=2).to(device)

#output = model.forward(torch.randn(3, 1, 49, 29))
#print(output)



tensor([[0.5767, 0.5020],
        [0.8802, 0.8488],
        [0.0636, 0.6604]], grad_fn=<SigmoidBackward0>)


NameError: name 'train_dataset' is not defined

In [None]:
# Evaluate the model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels.argmax(dim=1)).sum().item()
print(f"Validation accuracy: {correct / total}")
