In [9]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import math
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset
import os
from tqdm import tqdm

In [10]:
class MRIDataset(Dataset):
    """Custom Dataset for loading MRI data as 2D slices."""
    def __init__(self, file_list, K, transform=None):
        self.file_list = file_list
        self.K = K
        self.transform = transform
        self.slices = []
        
        for file in tqdm(file_list, desc="Loading files"):
            img = nib.load(file).get_fdata()
            for i in range(img.shape[2]):  # Assuming the third dimension is for slices
                self.slices.append((file, i))

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, index):
        file, slice_index = self.slices[index]
        img = nib.load(file).get_fdata()[:, :, slice_index]
        
        img_list = []
        if self.transform:
            for _ in range(self.K):
                img_transformed = self.transform(Image.fromarray(img.astype(np.uint8)))
                img_list.append(img_transformed)
        else:
            img_list = [torch.from_numpy(img).float().unsqueeze(0) for _ in range(self.K)]
        
        return img_list, 0  # 0 is a dummy target


In [11]:
class Conv4(torch.nn.Module):
    """A simple 4 layers CNN."""
    def __init__(self):
        super(Conv4, self).__init__()
        self.feature_size = 64
        self.name = "conv4"

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(8),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(kernel_size=2, stride=2)
        )

        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1)
        )

        self.flatten = torch.nn.Flatten()

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        h = self.layer1(x)
        h = self.layer2(h)
        h = self.layer3(h)
        h = self.layer4(h)
        h = self.flatten(h)
        return h

In [12]:
class RelationalReasoning(torch.nn.Module):
  """Self-Supervised Relational Reasoning.
  Essential implementation of the method, which uses
  the 'cat' aggregation function (the most effective),
  and can be used with any backbone.
  """
  def __init__(self, backbone, feature_size=64):
    super(RelationalReasoning, self).__init__()
    self.backbone = backbone
    self.relation_head = torch.nn.Sequential(
                             torch.nn.Linear(feature_size*2, 256),
                             torch.nn.BatchNorm1d(256),
                             torch.nn.LeakyReLU(),
                             torch.nn.Linear(256, 1))

  def aggregate(self, features, K):
    relation_pairs_list = list()
    targets_list = list()
    size = int(features.shape[0] / K)
    shifts_counter=1
    for index_1 in range(0, size*K, size):
      for index_2 in range(index_1+size, size*K, size):
        # Using the 'cat' aggregation function by default
        pos_pair = torch.cat([features[index_1:index_1+size], 
                              features[index_2:index_2+size]], 1)
        # Shuffle without collisions by rolling the mini-batch (negatives)
        neg_pair = torch.cat([
                     features[index_1:index_1+size], 
                     torch.roll(features[index_2:index_2+size], 
                     shifts=shifts_counter, dims=0)], 1)
        relation_pairs_list.append(pos_pair)
        relation_pairs_list.append(neg_pair)
        targets_list.append(torch.ones(size, dtype=torch.float32))
        targets_list.append(torch.zeros(size, dtype=torch.float32))
        shifts_counter+=1
        if(shifts_counter>=size): 
            shifts_counter=1 # avoid identity pairs
    relation_pairs = torch.cat(relation_pairs_list, 0)
    targets = torch.cat(targets_list, 0)
    return relation_pairs, targets

  def train(self, tot_epochs, train_loader):
    optimizer = torch.optim.Adam([
                  {'params': self.backbone.parameters()},
                  {'params': self.relation_head.parameters()}])                               
    BCE = torch.nn.BCEWithLogitsLoss()
    self.backbone.train()
    self.relation_head.train()
    for epoch in range(tot_epochs):
      # the real target is discarded (unsupervised)
      for i, (data_augmented, _) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{tot_epochs}"):
        K = len(data_augmented) # tot augmentations
        x = torch.cat(data_augmented, 0)
        optimizer.zero_grad()              
        # forward pass (backbone)
        features = self.backbone(x) 
        # aggregation function
        relation_pairs, targets = self.aggregate(features, K)
        # forward pass (relation head)
        score = self.relation_head(relation_pairs).squeeze()        
        # cross-entropy loss and backward
        loss = BCE(score, targets)
        loss.backward()
        optimizer.step()            
        # estimate the accuracy
        predicted = torch.round(torch.sigmoid(score))
        correct = predicted.eq(targets.view_as(predicted)).sum()
        accuracy = (100.0 * correct / float(len(targets)))
        
        if i % 100 == 0:
          print(f'Batch [{i+1}/{len(train_loader)}] - Loss: {loss.item():.5f}; Accuracy: {accuracy.item():.2f}%')

In [13]:
train_path = '../dataset/MICCAI_BraTS2020_TrainingData/'
modality_keys = ["flair"]


In [14]:
# Function to create a list of data dictionaries
def create_data_list(data_dir):
    data_list = []
    patients = os.listdir(data_dir)
    for patient in tqdm(patients, desc="Creating data list"):
        patient_dir = os.path.join(data_dir, patient)
        if os.path.isdir(patient_dir):
            data_dict = os.path.join(patient_dir, f"{patient}_flair.nii")
            data_list.append(data_dict)
    return data_list

In [15]:
# Hyper-parameters
K = 4
batch_size = 64
tot_epochs = 1
feature_size = 64

# Transformations for MRI slices
normalize = transforms.Normalize(mean=[0.5], std=[0.5])
train_transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])


In [16]:

backbone = Conv4()
model = RelationalReasoning(backbone, feature_size)

# Replace this with your actual list of .nii files
file_list = create_data_list(train_path)

Creating data list: 100%|██████████| 371/371 [00:00<00:00, 2403.43it/s]


In [17]:

train_set = MRIDataset(file_list=file_list, K=K, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)


Loading files: 100%|██████████| 369/369 [04:54<00:00,  1.25it/s]


In [18]:


model.train(tot_epochs=tot_epochs, train_loader=train_loader)
torch.save(model.backbone.state_dict(), './backbone_mri.tar')


Epoch 1/1:   0%|          | 1/894 [00:08<2:01:20,  8.15s/it]

Batch [1/894] - Loss: 0.68923; Accuracy: 54.95%


Epoch 1/1:  11%|█▏        | 101/894 [13:23<1:32:39,  7.01s/it]

Batch [101/894] - Loss: 0.14401; Accuracy: 96.74%


Epoch 1/1:  22%|██▏       | 201/894 [24:49<1:20:15,  6.95s/it]

Batch [201/894] - Loss: 0.10423; Accuracy: 97.27%


Epoch 1/1:  34%|███▎      | 301/894 [36:00<1:04:40,  6.54s/it]

Batch [301/894] - Loss: 0.08755; Accuracy: 97.01%


Epoch 1/1:  45%|████▍     | 401/894 [47:22<58:20,  7.10s/it]  

Batch [401/894] - Loss: 0.07395; Accuracy: 97.79%


Epoch 1/1:  56%|█████▌    | 501/894 [57:57<30:25,  4.64s/it]  

Batch [501/894] - Loss: 0.05847; Accuracy: 98.44%


Epoch 1/1:  67%|██████▋   | 601/894 [1:06:05<23:21,  4.78s/it]

Batch [601/894] - Loss: 0.10885; Accuracy: 96.22%


Epoch 1/1:  78%|███████▊  | 701/894 [1:14:09<15:45,  4.90s/it]

Batch [701/894] - Loss: 0.04078; Accuracy: 98.83%


Epoch 1/1:  90%|████████▉ | 801/894 [1:22:18<07:19,  4.72s/it]

Batch [801/894] - Loss: 0.04672; Accuracy: 98.57%


Epoch 1/1: 100%|██████████| 894/894 [1:29:54<00:00,  6.03s/it]
