# Vision Transformer for VALDO Dataset

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, RandomSampler

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import nibabel as nib
import numpy as np
import pandas as pd
import cv2

from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from transformers import ViTForMaskedImageModeling, ViTFeatureExtractor


In [None]:
# Export CUDA-related flags through the process environment
# (presumably debugging aids for CUDA errors -- confirm they are still needed).
os.environ.update(
    CUDA_LAUNCH_BLOCKING='1',
    TORCH_USE_CUDA_DSA='1',
)

### Class for VALDO Dataset 

In [None]:
class VALDODataset(Dataset):
    """Dataset pairing MRI volume paths with segmentation-mask paths.

    Every item is produced by calling ``transform`` on one (case, mask) path
    pair. The number of mask contours per volume is counted once, at
    construction time, and exposed as ``cmb_counts``.

    Args:
        cases: Sequence of MRI image paths.
        masks: Sequence of segmentation-mask paths, aligned with ``cases``.
        transform: Callable invoked as
            ``transform(mri_image_path=..., segmentation_mask_path=...)`` and
            returning an ``(image, mask)`` pair, or ``(None, None)`` on failure.

    Raises:
        ValueError: If ``cases`` and ``masks`` differ in length.
    """

    def __init__(self, cases, masks, transform):
        # Validate first: counting the contours reads every mask from disk, so
        # a length mismatch should be reported before that work starts.
        if len(cases) != len(masks):
            raise ValueError('Cases and masks must have the same length')

        self.cases = cases
        self.masks = masks
        self.transform = transform
        self.cmb_counts = self.count_cmb_per_image(self.masks)

    def __len__(self):
        """Return the number of (case, mask) pairs."""
        return len(self.cases)

    def __getitem__(self, idx):
        """Load and transform the pair stored at ``idx``.

        Returns:
            ``([image], [mask], case_path, cmb_count)`` on success, or
            ``(None, None, None, None)`` when the transform fails (the error
            is printed, not raised).

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Look the paths up outside the ``try`` so an out-of-range index
        # raises IndexError; swallowing it would make plain iteration over the
        # dataset loop forever.
        case = self.cases[idx]
        mask = self.masks[idx]
        cmb_count = self.cmb_counts[idx]

        try:
            s, m = self.transform(mri_image_path=case, segmentation_mask_path=mask)
            if s is None or m is None:
                raise ValueError(f"Transform returned None for {case} and {mask}")
        except Exception as e:
            # Best effort: report the failure and hand back an all-None item.
            print(f'Error loading image: {e}')
            return None, None, None, None

        return [s], [m], case, cmb_count

    def extract_bounding_boxes(self, mask):
        """Return one padded box per external contour found in ``mask``.

        Args:
            mask: 2-D ``uint8`` array, as built by ``count_cmb_per_image``.

        Returns:
            A list of ``[x_min, y_min, x_max, y_max]`` boxes. Each box is the
            contour's bounding rectangle enlarged by ``w/2.5`` and ``h/2.5`` on
            the top-left side and by ``w/3`` and ``h/3`` on the bottom-right
            side.
        """
        boxes = []
        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            boxes.append([(x-(w/2.5)), (y-(h/2.5)), ((w+x) + (w/3)), ((h+y) + (h/3))])
        return boxes

    def count_cmb_per_image(self, segmented_images):
        """Count mask contours in every volume listed in ``segmented_images``.

        Each slice along the third axis is counted independently, so a region
        that is visible in several slices contributes once per slice.

        Args:
            segmented_images: Iterable of NIfTI mask paths.

        Returns:
            A list with one integer per path, in the same order.
        """
        cmb_counts = []
        for img_path in segmented_images:
            data = nib.load(img_path).get_fdata()
            total_cmb_count = sum(
                len(self.extract_bounding_boxes(
                    (data[:, :, i] > 0).astype(np.uint8)))
                for i in range(data.shape[2]))
            cmb_counts.append(total_cmb_count)
        return cmb_counts


## Transform 

In [None]:
def load_nifti(file_path):
    """Read a NIfTI file and return its voxel data.

    The shape of the loaded data is printed. Any failure is reported on
    stdout and signalled to the caller by returning ``None`` instead of
    raising.
    """
    try:
        volume = nib.load(file_path).get_fdata()
        print(f"Loaded NIfTI data shape: {volume.shape}")
    except Exception as err:
        print(f"Error loading NIfTI file {file_path}: {err}")
        return None
    return volume

class NiftiToTensorTransform:
    """Load an MRI volume and its mask from NIfTI files and turn them into tensors.

    The mask is binarised (``> 0``) and both arrays are passed through the
    same pipeline: an in-plane resize to ``target_shape`` followed by
    ``ToTensorV2``.

    Args:
        target_shape: ``(height, width)`` both arrays are resized to.
        in_channels: Stored on the instance; the transform itself does not
            use it.
    """

    def __init__(self, target_shape=(256, 256), in_channels=1):
        self.target_shape = target_shape
        self.in_channels = in_channels
        self.transform = A.Compose([
            # p=1.0 already applies the resize on every call, so the extra
            # ``always_apply`` flag is not needed.
            A.Resize(height=target_shape[0], width=target_shape[1], p=1.0),
            ToTensorV2()
        ], is_check_shapes=False)  # Disable shape checking if you are sure about your data consistency

    def convert_to_binary_mask(self, segmentation_mask):
        """Return a ``uint8`` array that is 1 where the mask is > 0, else 0."""
        binary_mask = (segmentation_mask > 0).astype(np.uint8)
        return binary_mask

    def __call__(self, mri_image_path, segmentation_mask_path):
        """Load, binarise and transform one image/mask pair.

        Args:
            mri_image_path: Path of the MRI NIfTI file.
            segmentation_mask_path: Path of the segmentation-mask NIfTI file.

        Returns:
            ``(image, mask)`` as produced by the transform pipeline, or
            ``(None, None)`` when loading or transforming fails (the error is
            printed, not raised).
        """
        try:
            mri_image = load_nifti(mri_image_path)
            segmentation_mask = load_nifti(segmentation_mask_path)

            if mri_image is None or segmentation_mask is None:
                raise ValueError("Failed to load NIfTI files or data is None.")

            # Convert multi-label mask to binary mask
            binary_mask = self.convert_to_binary_mask(segmentation_mask)

            # The image is deliberately NOT repeated along axis 0 any more.
            # The former repeat factor was the header's ``dim[0]`` entry (the
            # number of dimensions), compared against a spatial extent, and it
            # was applied to the image only -- leaving image and mask with
            # different shapes before the resize.
            augmented = self.transform(image=mri_image, mask=binary_mask)
            return augmented['image'], augmented['mask']
        except Exception as e:
            print(f"Error in __call__ with {mri_image_path} and {segmentation_mask_path}: {e}")
            return None, None


In [None]:
# Shared transform instance; the arguments spell out the class defaults.
transform = NiftiToTensorTransform(target_shape=(256, 256), in_channels=1)

## Train the model

### Dataloader 

In [None]:
# Location of the Task 2 data, relative to the parent of the working directory.
testing_label_relative = '../VALDO_Dataset/Task2'
current_directory = os.getcwd()

# NOTE: despite its name this is only ONE level above the working directory;
# the leading '..' in `testing_label_relative` supplies the second level.
two_directories_up = os.path.abspath(os.path.join(current_directory, "../"))

# Combine the parent directory with the relative path
testing_label_absolute = os.path.join(
    two_directories_up, testing_label_relative)

# os.listdir returns entries in arbitrary order; sort them so the case order
# (and with it the seeded train/validation split) is the same on every run.
folders = sorted(
    item for item in os.listdir(testing_label_absolute)
    if os.path.isdir(os.path.join(testing_label_absolute, item)))

# Group the case folders by cohort, based on the subject-id prefix in the name.
cases = {"cohort1": [], "cohort2": [], "cohort3": []}
for folder in folders:
    if "sub-1" in folder:
        cases["cohort1"].append(folder)
    elif "sub-2" in folder:
        cases["cohort2"].append(folder)
    else:
        cases["cohort3"].append(folder)

In [None]:
def _cohort_paths(root, case_names):
    """Return ``(label_paths, image_paths)`` for the case folders under ``root``.

    Paths are assembled with ``os.path.join`` so they are valid on every
    platform, not only where the separator is a backslash.
    """
    label_paths = []
    image_paths = []
    for case_name in case_names:
        case_dir = os.path.join(root, case_name)
        label_paths.append(
            os.path.join(case_dir, f"{case_name}_space-T2S_CMB.nii.gz"))
        image_paths.append(
            os.path.join(case_dir, f"{case_name}_space-T2S_desc-masked_T2S.nii.gz"))
    return label_paths, image_paths


# Mask ("labels") and MRI ("ids") paths, one pair of lists per cohort.
cohort1_labels, cohort1_ids = _cohort_paths(testing_label_absolute, cases["cohort1"])
cohort2_labels, cohort2_ids = _cohort_paths(testing_label_absolute, cases["cohort2"])
cohort3_labels, cohort3_ids = _cohort_paths(testing_label_absolute, cases["cohort3"])

# All cohorts concatenated, in cohort order.
all_labels = cohort1_labels + cohort2_labels + cohort3_labels
all_ids = cohort1_ids + cohort2_ids + cohort3_ids

# print(all_labels[0])
# print(all_ids[0])

## Collate for each batch

The collate function stacks the slices and masks of every item in a batch into tensors and returns them, together with the image paths and CMB counts, on each iteration of the dataloader.


In [None]:
def collate_fn(batch):
    """Merge dataset items into one batch, skipping items that failed to load.

    Args:
        batch: Iterable of items shaped like
            ``(slices, targets, img_path, cmb_count)`` where ``slices`` and
            ``targets`` are lists of tensors. A failed item may be ``None`` or
            a tuple whose slices/targets are ``None``.

    Returns:
        ``(cases, masks, img_paths, cmb_counts)`` where ``cases`` and ``masks``
        are the stacked tensors, or ``(None, None, [], [])`` when no item in
        the batch could be used.
    """
    slices = []
    targets = []
    img_paths = []
    cmb_counts = []

    for item in batch:
        if item is None:
            continue
        item_slices, item_targets, item_img_path, item_cmb_counts = item
        # A failed load is reported as a tuple of Nones, which is itself not
        # None -- so the individual fields have to be checked as well.
        if item_slices is None or item_targets is None:
            continue
        slices.extend(item_slices)
        targets.extend(item_targets)
        img_paths.append(item_img_path)
        cmb_counts.append(item_cmb_counts)

    if not slices:
        return None, None, [], []

    stacked_cases = torch.stack(slices, dim=0)
    stacked_masks = torch.stack(targets, dim=0)
    return stacked_cases, stacked_masks, img_paths, cmb_counts

In [None]:
# dataset = VALDODataset(
#     cases=all_ids, masks=all_labels, transform=transform)

# Build the dataset from the cohort 1 scans and masks only.
dataset = VALDODataset(cohort1_ids, cohort1_labels, transform)

In [None]:
# 1 for scans whose mask has at least one counted contour, 0 otherwise.
has_cmb = [int(count > 0) for count in dataset.cmb_counts]

# One row per scan; the 'Has CMB' column is what the split stratifies on.
_dataset_columns = {
    'MRI Scans': dataset.cases,
    'Segmented Masks': dataset.masks,
    'CMB Count': dataset.cmb_counts,
    'Has CMB': has_cmb,
}
df_dataset = pd.DataFrame(_dataset_columns)

# df_dataset

In [None]:
# Hold out 20% of the scans for validation, stratified on the 'Has CMB' flag
# and seeded so the split can be repeated.
train_df, val_df = train_test_split(
    df_dataset,
    test_size=0.2,
    random_state=42,
    stratify=df_dataset['Has CMB'],
)

In [None]:
# Show the MRI scan paths that ended up in the training split.
print(train_df['MRI Scans'].values)

In [None]:
# Show the rows that ended up in the validation split.
print(val_df)

In [None]:
# Wrap each split in its own dataset; both share the same transform.
train_dataset = VALDODataset(
    cases=train_df['MRI Scans'].tolist(),
    masks=train_df['Segmented Masks'].tolist(),
    transform=transform,
)
val_dataset = VALDODataset(
    cases=val_df['MRI Scans'].tolist(),
    masks=val_df['Segmented Masks'].tolist(),
    transform=transform,
)

In [None]:
# Sanity check: load and print the first training item.
print(train_dataset[0])

In [None]:
# Sanity check: load and print the first validation item.
print(val_dataset[0])

## Global Variables

In [None]:
in_channels = 1          # channels of the volume fed to the patch embedding
out_channels = 35        # NOTE(review): not referenced in the visible code -- confirm it is needed
embed_dim = 35           # token width; must be divisible by num_heads
image_size = 256         # in-plane size the positional embedding and head are built for
patch_size = (1,32,32)   # (depth, height, width) of one patch
num_layers = 12          # number of stacked encoder layers
num_heads = 7            # attention heads per encoder layer
mlp_dim = 3072           # hidden width of the feed-forward block
num_classes = 2          # output channels of the segmentation head
num_epochs = 10
batch_size = 1
# Was 1e-50: that value underflows to exactly 0 in float32, so the optimizer
# could never change a parameter. 1e-4 is a conventional Adam starting point.
lr = 1e-4
num_workers = 0


# Cuda

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

### Setup patch embeddings

In [None]:
class PatchEmbedding(nn.Module):
    """Cut a volume into non-overlapping patches and project each to ``embed_dim``.

    Args:
        in_channels: Channels of the input volume.
        patch_size: ``(depth, height, width)`` used as both kernel size and
            stride of the projection.
        embed_dim: Width of every patch token.
        img_size: In-plane image size; stored on the module only.
    """

    def __init__(self, in_channels, patch_size, embed_dim, img_size):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        # Use the constructor argument, not the notebook-level `image_size`.
        self.img_size = img_size
        self.projection = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=(0,0,0), dilation=(1,1,1), groups=1, bias=True)

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, num_patches)``.

        Args:
            x: Tensor of shape ``[batch, in_channels, D, H, W]``. When
                ``in_channels`` is 1, a 4-D ``[batch, D, H, W]`` tensor is
                accepted too and given the missing channel axis.

        Returns:
            ``tokens`` of shape ``[batch, num_patches, embed_dim]`` and
            ``num_patches``, the number of tokens per sample.
        """
        # Keep the projection on the same device as the input.
        if x.device != self.projection.weight.device:
            self.projection = self.projection.to(x.device)

        # Without this, Conv3d would read a 4-D batch as ONE unbatched sample
        # and silently mix the batch and channel axes.
        if x.dim() == 4 and self.in_channels == 1:
            x = x.unsqueeze(1)

        x = self.projection(x)  # [batch, embed_dim, D', H', W']
        x = x.flatten(2)        # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [batch, num_patches, embed_dim]

        # Patches per sample = product of the spatial dims after projection.
        num_patches = x.shape[1]
        return x, num_patches


### Setup Position Embedding

In [None]:
class PositionEmbedding(nn.Module):
    """Prepend a learnable class token and add learnable positional embeddings.

    Args:
        embed_dim: Width of every token.
        patch_size: ``(depth, height, width)`` of one patch; only the height
            and width entries are used to size the patch grid.
        image_size: In-plane image size the patch grid is derived from.
    """

    def __init__(self, embed_dim, patch_size, image_size):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.embed_dim = embed_dim

        # Calculate number of patches
        num_patches_x = image_size // patch_size[1]
        num_patches_y = image_size // patch_size[2]
        self.num_patches = num_patches_x * num_patches_y

        # One positional vector per patch plus one for the class token, so the
        # table lines up with the sequence after the token is prepended.
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim))

        # Class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

    def forward(self, x, max_seq_length=None):
        """Return ``x`` with the class token prepended and positions added.

        Args:
            x: Tensor of shape ``[batch, num_patches, embed_dim]``.
            max_seq_length: Ignored. Accepted only so that callers passing a
                second positional argument keep working.

        Returns:
            Tensor of shape ``[batch, num_patches + 1, embed_dim]``.

        Raises:
            TypeError: If ``x`` is not a tensor.
            ValueError: If ``x`` does not hold ``self.num_patches`` patches.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Expected input x to be a tensor, but got {type(x)}")

        n_samples, num_patches, _ = x.size()

        # Ensure the dimensions of x match the expected number of patches
        if num_patches != self.num_patches:
            raise ValueError(f"Number of patches ({num_patches}) does not match the expected number of patches ({self.num_patches}).")

        # Move class token and positional embeddings to the same device as x
        cls_tokens = self.cls_token.expand(n_samples, -1, -1).to(x.device)
        position_embedding = self.position_embedding.to(x.device)

        # Sequence becomes [cls, patch_0, ..., patch_{n-1}].
        x = torch.cat((cls_tokens, x), dim=1)

        return x + position_embedding

### Transformer Encoder

In [None]:
class TransformerencoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a feed-forward block.

    Args:
        embed_dim: Width of every token.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability used in attention, the feed-forward
            block and on both residual branches.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``[batch, seq, embed_dim]``.

        Returns:
            A tensor with the same ``[batch, seq, embed_dim]`` layout.
        """
        # Keep every sub-module on the same device as the input.
        self.to(x.device)

        # The attention module is built without batch_first, so it expects
        # [seq, batch, embed_dim].
        seq_first = x.permute(1, 0, 2)
        attn_output, _ = self.attention(seq_first, seq_first, seq_first)
        seq_first = self.norm1(seq_first + self.dropout(attn_output))

        ff_output = self.feed_forward(seq_first)
        seq_first = self.norm2(seq_first + self.dropout(ff_output))

        # Restore the caller's layout. Without this, stacked layers alternate
        # between attending over the sequence and attending over the batch.
        return seq_first.permute(1, 0, 2)



### SegmentationHead

In [None]:
class SegmentationHead(nn.Module):
    """Map patch tokens to per-patch class logits laid out on the patch grid.

    Args:
        embed_dim: Width of every token.
        num_classes: Number of output channels (one per class).
        image_size: In-plane image size the patch grid is derived from.
        patch_size: ``(depth, height, width)`` of one patch; only the height
            and width entries are used.

    Raises:
        ValueError: If the number of patches is not a perfect square.
    """

    def __init__(self, embed_dim, num_classes, image_size, patch_size):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = ((image_size // patch_size[1]) * (image_size // patch_size[2]))

        # Side length of the (square) patch grid.
        self.side_length = int(self.num_patches ** 0.5)
        if self.side_length ** 2 != self.num_patches:
            raise ValueError("Number of patches is not a perfect square")

        # A 1x1 convolution classifies every patch independently and keeps
        # the grid intact. The former Conv3d could not consume the 4-D grid
        # tensor, and its kernel spanned the whole grid, reducing it to a
        # single output position.
        self.conv = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits of shape ``[batch, num_classes, side, side]``.

        Args:
            x: Tensor of shape ``[batch, num_patches, embed_dim]`` holding
                patch tokens only (no class token), in row-major grid order.

        Raises:
            ValueError: If ``x`` does not hold ``self.num_patches`` patches.
        """
        # Keep the convolution on the same device as the input.
        if x.device != self.conv.weight.device:
            self.conv = self.conv.to(x.device)

        n_samples, num_patches, embed_dim = x.shape
        expected_patches = self.num_patches  # No additional token in this context

        if num_patches != expected_patches:
            raise ValueError(f"Expected {expected_patches} patches, but got {num_patches}")

        # [batch, num_patches, embed_dim] -> [batch, embed_dim, side, side]
        grid = x.transpose(1, 2).reshape(
            n_samples, embed_dim, self.side_length, self.side_length)

        return self.conv(grid)


### Form the ViT

In [None]:
class VisionTransformerSegmentation(nn.Module):
    """Vision-transformer segmentation model.

    Pipeline: patch embedding -> class token + positional embedding ->
    ``num_layers`` encoder layers -> segmentation head -> bilinear upsampling
    to ``image_size`` x ``image_size``.

    The positional embedding and the head are both sized for
    ``(image_size // patch_h) * (image_size // patch_w)`` patches per sample.

    Args:
        in_channels: Channels of the input volume.
        patch_size: ``(depth, height, width)`` of one patch.
        embed_dim: Width of every token.
        num_classes: Number of output channels.
        image_size: In-plane input size, also the size of the output map.
        num_heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each feed-forward block.
        num_layers: Number of encoder layers.
        dropout: Dropout probability passed to every encoder layer.
    """

    def __init__(self, in_channels, patch_size, embed_dim, num_classes, image_size, num_heads, mlp_dim, num_layers, dropout=0.1):
        super().__init__()
        self.image_size = image_size
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, embed_dim, image_size)
        self.position_embedding = PositionEmbedding(embed_dim, patch_size, image_size)
        self.transformer_encoder_layers = nn.ModuleList([
            TransformerencoderLayer(embed_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)
        ])
        self.segmentation_head = SegmentationHead(embed_dim, num_classes, image_size, patch_size)

    def forward(self, x):
        """Return logits upsampled to ``image_size`` x ``image_size``."""
        x, _ = self.patch_embedding(x)

        # PositionEmbedding.forward takes the token tensor only.
        x = self.position_embedding(x)

        for layer in self.transformer_encoder_layers:
            x = layer(x)

        # The head works on patch tokens only; drop the leading class token
        # when the sequence carries exactly one extra token.
        if x.size(1) == self.segmentation_head.num_patches + 1:
            x = x[:, 1:, :]

        x = self.segmentation_head(x)

        # Upsample to the configured image size instead of a hard-coded 256.
        # The former second interpolate with scale_factor=1 was a no-op.
        x = F.interpolate(x, size=(self.image_size, self.image_size), mode='bilinear', align_corners=True)
        return x

In [None]:
# Options shared by the training and validation loaders.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=False,
    collate_fn=collate_fn,
)

# Training batches are drawn in random order; drop_last=False keeps the
# final, possibly smaller, batch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    drop_last=False,
    **_loader_options,
)

# Validation batches are visited in dataset order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    shuffle=False,
    sampler=SequentialSampler(val_dataset),
    **_loader_options,
)

In [None]:
# Smoke-test the training loader: print the first collated batch, then stop.
for c in train_loader:
    print(c)
    break

In [None]:
# Display the selected device again (notebook cell output).
device

In [None]:
# Build the model and move it to the selected device BEFORE the optimizer is
# created, so the optimizer tracks the parameters on that device.
model = VisionTransformerSegmentation(in_channels, patch_size, embed_dim, num_classes, image_size, num_heads, mlp_dim, num_layers).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

### Training loop

In [None]:
def _masks_to_targets(masks, outputs):
    """Turn collated masks into the class-index target CrossEntropyLoss expects.

    ``outputs`` is ``[N, C, H, W]`` logits, so the target has to be an integer
    tensor of shape ``[N, H, W]``. Masks that carry exactly one label per
    output pixel (for example with an extra singleton axis) are reshaped to
    that layout.

    Raises:
        ValueError: If the mask does not hold one label per output pixel.
    """
    expected_shape = (outputs.size(0),) + tuple(outputs.shape[2:])
    if masks.numel() != outputs.numel() // outputs.size(1):
        raise ValueError(
            f"Mask shape {tuple(masks.shape)} does not provide one label per "
            f"output pixel; expected a layout equivalent to {expected_shape}")
    return masks.reshape(expected_shape).long()


scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5)
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    epoch_loss = 0.0
    train_batches = 0

    # Iterate over batches
    for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        # collate_fn yields None tensors when no sample in the batch loaded.
        if batch[0] is None:
            continue

        images = batch[0].float().to(device)
        masks = batch[1].to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)

        # Logits stay channels-first ([N, C, H, W]) and targets are class
        # indices; permuting the logits or passing float masks makes the loss
        # treat the wrong axis as the class axis.
        loss = criterion(outputs, _masks_to_targets(masks, outputs))
        epoch_loss += loss.item()
        train_batches += 1

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

    # Average over the batches that were actually processed.
    avg_epoch_loss = epoch_loss / max(train_batches, 1)

    # Validation pass; its loss drives the learning-rate scheduler.
    val_loss = 0.0
    val_batches = 0
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for val_batch in val_loader:
            if val_batch[0] is None:
                continue
            val_images = val_batch[0].float().to(device)
            val_masks = val_batch[1].to(device)
            val_outputs = model(val_images)
            val_loss += criterion(
                val_outputs, _masks_to_targets(val_masks, val_outputs)).item()
            val_batches += 1

    avg_val_loss = val_loss / max(val_batches, 1)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_epoch_loss:.4f}')
    print(f'Validation Loss: {avg_val_loss:.4f}')
    scheduler.step(avg_val_loss)