# Initial tests of ViT for VALDO dataset

In [1]:
# !pip install transformers torch torchvision torchinfo
# %pip install accelerate -U
# %pip install transformers[torch]

#### Imports

In [2]:
import os
import cv2
import torch
import evaluate
import nibabel as nib
import numpy as np
import albumentations as A
import torchvision.transforms as T
import pandas as pd
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from albumentations import Compose, Normalize, Resize 
from albumentations.pytorch import ToTensorV2
from transformers import ViTFeatureExtractor, ViTForMaskedImageModeling, ViTImageProcessor
from transformers import AutoImageProcessor, AutoModel
from sklearn.model_selection import train_test_split

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.11 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


### Import the dataset

In [3]:
# Locate the VALDO Task 2 data, which lives two directory levels above the
# current working directory.
relative_path = 'VALDO_Dataset/Task2'
current_directory = os.getcwd()
parent_directory = os.path.abspath(os.path.join(current_directory, '../../'))
testing_label = os.path.join(parent_directory, relative_path)

# Keep only sub-directories (one per case); plain files are ignored.
folders = [
    entry
    for entry in os.listdir(testing_label)
    if os.path.isdir(os.path.join(testing_label, entry))
]

# Group the case folders by cohort, based on whether the folder name contains
# "sub-1" or "sub-2"; everything else is assigned to cohort 3.
cases = {"cohort1": [], "cohort2": [], "cohort3": []}
for folder in folders:
    if "sub-1" in folder:
        cohort_key = "cohort1"
    elif "sub-2" in folder:
        cohort_key = "cohort2"
    else:
        cohort_key = "cohort3"
    cases[cohort_key].append(folder)

In [4]:
# print(cases)

In [5]:
# Divide the cases according to their cohorts

In [6]:
def _cohort_paths(case_names):
    """Build the mask and scan file paths for every case folder in a cohort.

    Args:
        case_names: Folder names of the cases, relative to ``testing_label``.

    Returns:
        A ``(mask_paths, scan_paths)`` pair of lists, in the same order as
        ``case_names``.
    """
    mask_paths = []
    scan_paths = []
    for case_name in case_names:
        # os.path.join picks the right separator for the current OS instead of
        # hard-coding the Windows backslash.
        case_dir = os.path.join(testing_label, case_name)
        mask_paths.append(
            os.path.join(case_dir, f"{case_name}_space-T2S_CMB.nii.gz"))
        scan_paths.append(
            os.path.join(case_dir, f"{case_name}_space-T2S_desc-masked_T2S.nii.gz"))
    return mask_paths, scan_paths


# "labels" are the CMB segmentation masks, "ids" are the masked T2S scans.
cohort1_labels, cohort1_ids = _cohort_paths(cases["cohort1"])
cohort2_labels, cohort2_ids = _cohort_paths(cases["cohort2"])
cohort3_labels, cohort3_ids = _cohort_paths(cases["cohort3"])

# Concatenate in cohort order so masks and scans stay index-aligned.
all_labels = cohort1_labels + cohort2_labels + cohort3_labels
all_ids = cohort1_ids + cohort2_ids + cohort3_ids

In [7]:
# Total number of segmentation-mask paths collected across all cohorts.
print(len(all_labels))

72


In [8]:
# Total number of scan paths collected across all cohorts.
print(len(all_ids))

72


# Customized Dataset class for VALDO 

**This dataset class is different from the final project since ViT outputs the segmented mask**

In [9]:
class VALDODataset(Dataset):
    """Dataset of MRI volumes paired with their CMB segmentation masks.

    One item corresponds to one volume. The volume is split along its third
    axis and returned as per-slice lists (see ``__getitem__``).

    Args:
        img_paths: Paths of the NIfTI scans.
        mask_paths: Paths of the NIfTI masks, index-aligned with ``img_paths``.
        transform: Optional callable applied separately to every image slice
            and every mask slice. When ``None`` the slices are returned
            without any transformation.
    """

    def __init__(self, img_paths, mask_paths, transform=None):
        # Check the pairing first so a mismatch is reported before the
        # (slow) scan over every mask file below.
        assert len(img_paths) == len(
            mask_paths), "Number of images and masks should be same"

        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.cmb_counts = self.count_cmb_per_image(self.mask_paths)

    def __len__(self):
        """Return the number of volumes."""
        return len(self.img_paths)

    @staticmethod
    def _scale_to_uint8(volume):
        """Rescale ``volume`` so that its maximum maps to 255, as uint8.

        A volume whose maximum is not positive (for example a mask without
        any labelled voxel) is returned as all zeros; dividing by that
        maximum would otherwise produce NaN and an undefined uint8 cast.
        """
        peak = np.max(volume)
        if peak > 0:
            return (volume / peak * 255).astype(np.uint8)
        return np.zeros(volume.shape, dtype=np.uint8)

    def __getitem__(self, idx):
        """Load one volume and return it slice by slice.

        Returns:
            ``(slices, targets, img_path, cmb_count)`` where ``slices`` is a
            list with one three-channel image per slice, ``targets`` is a list
            with one ``torch.long`` mask tensor per slice, ``img_path`` is the
            scan path and ``cmb_count`` the pre-computed count for this
            volume's mask.

        Raises:
            Exception: Any loading or transform error is printed with the
                offending index and then re-raised unchanged.
        """
        try:
            img_path = self.img_paths[idx]
            mask_path = self.mask_paths[idx]
            cmb_count = self.cmb_counts[idx]

            # Load the image and its mask annotations as 8-bit volumes.
            img = self._scale_to_uint8(nib.load(img_path).get_fdata())
            ann = self._scale_to_uint8(nib.load(mask_path).get_fdata())

            slices = []
            targets = []

            for i in range(img.shape[2]):
                # Convert single-channel to three-channel.
                img_slice = cv2.merge([img[:, :, i]] * 3)
                ann_slice = ann[:, :, i]

                # The same transform is applied to image and mask, but in two
                # independent calls.
                # NOTE(review): a transform with random augmentation would not
                # stay aligned between the two calls; confirm it is
                # deterministic.
                if self.transform is not None:
                    img_slice = self.transform(img_slice)
                    ann_slice = self.transform(ann_slice)

                slices.append(img_slice)
                # as_tensor accepts both arrays and tensors and avoids the
                # warning torch.tensor() emits when given a tensor.
                targets.append(torch.as_tensor(ann_slice, dtype=torch.long))

            return slices, targets, img_path, cmb_count
        except Exception as e:
            print(f"Error processing index {idx}: {e}")
            raise

    def extract_bounding_boxes(self, mask):
        """Return one padded box per external contour found in ``mask``.

        Each box is ``[x - w/2.5, y - h/2.5, x + w + w/3, y + h + h/3]`` where
        ``(x, y, w, h)`` is the contour's bounding rectangle.
        """
        boxes = []
        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for cnt in contours:
            x, y, w, h = cv2.boundingRect(cnt)
            boxes.append([(x-(w/2.5)), (y-(h/2.5)), ((w+x) + (w/3)), ((h+y) + (h/3))])
        return boxes

    def count_cmb_per_image(self, segmented_images):
        """Count the contours of every mask volume, summed over its slices.

        Args:
            segmented_images: Paths of the NIfTI mask volumes.

        Returns:
            A list with one integer per path: the number of bounding boxes
            found over all slices of the binarised mask.
        """
        cmb_counts = []
        for img_path in segmented_images:
            data = nib.load(img_path).get_fdata()
            boxes_per_slice = [
                self.extract_bounding_boxes((data[:, :, i] > 0).astype(np.uint8))
                for i in range(data.shape[2])
            ]
            cmb_counts.append(sum(len(boxes) for boxes in boxes_per_slice))
        return cmb_counts

### Transformations used in the dataset

In [10]:
# transform = Compose(
#     [
#         A.Resize(height=512, width = 512, p=1.0),
#     ],
    
#     )

In [11]:
# Per-slice preprocessing steps, applied in order to a NumPy slice.
# NOTE(review): the checkpoint loaded later in this notebook is a "...-224"
# model; confirm that 512 x 512 is the intended input size.
preprocessing_steps = [
    T.ToPILImage(),        # Convert to PIL image
    T.Resize((512, 512)),  # Resize every slice to 512 x 512
    T.ToTensor(),          # Convert to PyTorch tensor and scale to [0, 1]
]
transform = T.Compose(preprocessing_steps)

## Collate for each batch

The collate function below flattens each batch so that every DataLoader iteration returns the slices, the targets, and the image paths.

In [12]:
def collate_fn(batch):
    """Flatten a batch of volumes into per-slice lists.

    Args:
        batch: Iterable of ``(slices, targets, img_path, cmb_count)`` items.

    Returns:
        A 3-tuple ``(slices, targets, img_paths)``: every slice of every item
        in batch order, the matching targets, and one path per item. The
        per-item counts are unpacked but not returned.
    """
    flat_slices = []
    flat_targets = []
    paths = []

    for volume_slices, volume_targets, volume_path, _volume_count in batch:
        flat_slices += volume_slices
        flat_targets += volume_targets
        paths.append(volume_path)

    # Each slice is rebuilt by stacking its own sub-tensors along dim 0, which
    # yields a tensor with the same shape and values as the input slice.
    restacked = []
    for one_slice in flat_slices:
        restacked.append(torch.stack(tuple(one_slice)))

    return restacked, flat_targets, paths


## Loading all cohorts into the dataset

In [13]:
dataset = VALDODataset(
    img_paths=all_ids, mask_paths=all_labels, transform=transform)

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO:nibabel.global:pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


# Balancing the dataset for the numbers of CMBs


In [None]:
# Binary flag per scan: 1 when at least one CMB was counted, otherwise 0.
has_cmb = [int(count > 0) for count in dataset.cmb_counts]

# One row per scan, pairing each scan with its mask and its CMB statistics.
scan_table = {
    'MRI Scans': dataset.img_paths,
    'Segmented Masks': dataset.mask_paths,
    'CMB Count': dataset.cmb_counts,
    'Has CMB': has_cmb,
}
df_dataset = pd.DataFrame(scan_table)


## Loading ViT feature Extractor and Model 

In [None]:
# # import model
# model_id = 'google/vit-base-patch16-224-in21k'
# feature_extractor = ViTFeatureExtractor.from_pretrained(
#     model_id
# )

**Documentation for this model:** https://huggingface.co/learn/computer-vision-course/en/unit3/vision-transformers/vision-transformers-for-image-segmentation

In [None]:
# Load the feature extractor and the model
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor; the variable
# name is kept so the cells below keep working.
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
model = ViTForMaskedImageModeling.from_pretrained('google/vit-base-patch16-224-in21k')


### Sanity Check
_A good sanity check before launching the training is to compute the loss on one sample_


In [None]:
# Show the loaded feature extractor as the cell output.
feature_extractor   

In [None]:
# example = feature_extractor(images=dataset.__getitem__(1), return_tensors="pt")
# example

# Training


In [None]:
# Hold out 20% of the scans for validation, keeping the share of scans with
# and without CMBs the same in both splits. The seed is fixed for
# reproducibility.
stratify_labels = df_dataset['Has CMB']
train_df, val_df = train_test_split(
    df_dataset,
    test_size=0.2,
    stratify=stratify_labels,
    random_state=42,
)

In [None]:
# Build one dataset per split from the scan and mask columns of each frame.
train_dataset = VALDODataset(
    img_paths=train_df['MRI Scans'].tolist(),
    mask_paths=train_df['Segmented Masks'].tolist(),
    transform=transform,
)
val_dataset = VALDODataset(
    img_paths=val_df['MRI Scans'].tolist(),
    mask_paths=val_df['Segmented Masks'].tolist(),
    transform=transform,
)

In [None]:
# Number of volumes (not slices) in the training split.
print(f"Training dataset size: {len(train_dataset)}")

In [None]:
# Number of volumes (not slices) in the validation split.
print(f"Validation dataset size: {len(val_dataset)}")

In [None]:
# print('Target - ', train_dataset.__getitem__(0)[0])
# Inspect the first training volume. The names avoid shadowing the builtin
# `slice`, and indexing goes through the normal subscript syntax.
slices, targets, img_path, cmb_count = train_dataset[0]
print('Slice - ', slices)
print('Target - ', targets)
print('Image Path - ', img_path)
print('CMB Count - ', cmb_count)


In [None]:
# Options shared by both loaders; only the shuffling differs.
loader_options = {"batch_size": 8, "num_workers": 0, "collate_fn": collate_fn}

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
# NOTE: `val_dataset` is rebound to the validation DataLoader here; the
# training call at the end of the notebook passes it under that name.
val_dataset = DataLoader(val_dataset, shuffle=False, **loader_options)

#### Device: CUDA when available, otherwise CPU

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
device

In [None]:
# Load the "mean_iou" metric through the `evaluate` library and display it.
metric = evaluate.load("mean_iou")
metric

In [None]:
# Image processor with resizing, rescaling and normalisation all switched off.
# NOTE(review): `ignore_index` and `do_reduce_labels` look like options of the
# segmentation-oriented image processors; confirm ViTImageProcessor uses them.
preprocessor = ViTImageProcessor(
    ignore_index=0,
    do_reduce_labels=False,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)

### Transfer Learning

Transfer learning involves **freezing** certain layers of the model that have already been trained on a large dataset and fine-tuning other layers on a new (smaller) dataset. To do transfer learning with MaskFormer, the developers proposed the following approach (the code cell below applies the same freeze-then-fine-tune idea to the ViT model loaded above):

**Freezing Components**: The Backbone and the Pixel Decoder will be frozen. Their pre-trained weights capture universal features applicable across different datasets and domains.

**Training Components**: The _Transformer Decoder_ and _MLP_ will be fine-tuned. This process customizes the segment embeddings and classification layers.

In [None]:
# Freeze all parameters
# Step 1: start from a completely frozen network.
for weight in model.parameters():
    weight.requires_grad_(False)

# Step 2: freeze the encoder explicitly. This is already covered by step 1 and
# only shows how a selective freeze would be written.
for weight in model.vit.encoder.parameters():
    weight.requires_grad_(False)

# Step 3: re-enable gradients for the final layer norm of the backbone.
for weight in model.vit.layernorm.parameters():
    weight.requires_grad_(True)

# Step 4: re-enable gradients for every parameter whose name contains one of
# these markers.
# NOTE(review): the 'layernorm' substring may also match per-layer norms inside
# the encoder, and 'heads' may match nothing on this model; check the printout.
trainable_markers = ('layernorm', 'heads')
for name, param in model.named_parameters():
    if any(marker in name for marker in trainable_markers):
        param.requires_grad_(True)

# Report the final trainable/frozen state of every parameter.
for name, param in model.named_parameters():
    print(f"{name}: {'requires_grad=True' if param.requires_grad else 'requires_grad=False'}")

### Evaluates the given model using the specified dataloader and computes the mean Intersection over Union (IoU).
----
Args:
- model (ViTForMaskedImageModeling): The trained model to be evaluated.
- dataloader (DataLoader): DataLoader containing the dataset for evaluation.
- preprocessor (AutoImageProcessor): The preprocessor used for post-processing the model outputs.
- metric (Any): Metric instance used for calculating IoU.
- id2label (dict): Dictionary mapping class ids to their corresponding labels.
- max_batches (int, optional): Maximum number of batches to evaluate. If None, evaluates on the entire validation dataset.

Returns:
float: The mean IoU calculated over the specified number of batches.

In [None]:
# Number of passes over the training data.
epochs: int = 1
# Learning rate handed to the optimizer.
learning_rate: float = 1e-5
# The running loss is printed every `log_interval` batches.
log_interval: int = 100
# Class id -> class name; its length is used as the number of labels when the
# metric is computed.
id2label: dict = {0: 'background', 1: 'CMB'}

In [None]:
def evaluate_model(
        model: ViTForMaskedImageModeling, 
        dataloader: DataLoader,
        preprocessor: AutoImageProcessor,
        metric: any,
        id2label: dict, 
        max_batches=None,
):
    """Run the model over a dataloader and return the mean IoU.

    Args:
        model: Model to evaluate; it is switched to eval mode.
        dataloader: Yields batches whose first two elements are the slices and
            the targets. Extra elements (such as image paths) are ignored.
        preprocessor: Object used to post-process the model outputs.
        metric: Metric object exposing ``add_batch`` and ``compute``.
        id2label: Class-id to label mapping; only its length is used.
        max_batches: Optional cap on the number of batches. A falsy value
            (``None`` or 0) evaluates the whole dataloader.

    Returns:
        The ``"mean_iou"`` entry of the metric computed over all added batches.

    Relies on the notebook-level ``device`` for tensor placement.
    """
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader)):
            # Stop before doing any work on a batch beyond the cap.
            if max_batches and idx >= max_batches:
                break

            # Take the first two entries so that batches carrying extra
            # fields (for example image paths) can still be consumed.
            slices, targets = batch[0], batch[1]
            if not torch.is_tensor(slices):
                # A list of per-slice tensors becomes one batch tensor.
                slices = torch.stack(list(slices))

            outputs = model(pixel_values=slices.to(device))

            # Use the last two dimensions so that both (H, W) and
            # channel-first (C, H, W) targets give the spatial size.
            target_sizes = [tuple(target.shape[-2:]) for target in targets]

            # NOTE(review): confirm that `preprocessor` exposes
            # post_process_panoptic_segmentation and that it accepts this
            # model's outputs; it is normally provided by segmentation
            # processors.
            predicted_segmentation_maps = (
                preprocessor.post_process_panoptic_segmentation(
                    outputs, target_sizes=target_sizes
                )
            )

            metric.add_batch(
                predictions=predicted_segmentation_maps,
                references=targets,
            )

        # The metric accumulates every batch, so it is computed once here.
        results = metric.compute(num_labels=len(id2label), ignore_index=0)

    return results["mean_iou"]

def train_model(
        model: ViTForMaskedImageModeling, 
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        preprocessor: AutoImageProcessor,
        metric: any,
        num_epochs: int,
        learning_rate: float,
        log_interval: int = 100,
        id2label: dict = None,
):
    """Fine-tune the model and report the validation mIoU after every epoch.

    Args:
        model: Model to train; it is moved to the notebook-level ``device``.
        train_dataloader: Yields batches whose first two elements are the
            slices and the targets. Extra elements are ignored.
        val_dataloader: Dataloader forwarded to ``evaluate_model``.
        preprocessor: Forwarded to ``evaluate_model``.
        metric: Forwarded to ``evaluate_model``.
        num_epochs: Number of passes over ``train_dataloader``.
        learning_rate: Learning rate of the AdamW optimizer.
        log_interval: The running loss is printed every this many batches.
        id2label: Class-id to label mapping, forwarded to ``evaluate_model``.

    A failure inside one batch is printed and the loop moves on to the next
    batch; it does not abort training.
    """
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        print(f"Current epoch: {epoch+1}/{num_epochs}")
        model.train()

        running_loss = 0.0
        num_samples = 0

        for idx, batch in enumerate(tqdm(train_dataloader)):
            try:
                # Take the first two entries so that batches carrying extra
                # fields (for example image paths) can still be consumed.
                slices, targets = batch[0], batch[1]
                if not torch.is_tensor(slices):
                    slices = torch.stack(list(slices))
                if not torch.is_tensor(targets):
                    targets = torch.stack(list(targets))

                optimizer.zero_grad()

                # NOTE(review): confirm that this model's forward() accepts a
                # `labels` argument and returns an output with `.loss` set.
                outputs = model(
                    pixel_values=slices.to(device),
                    labels=targets.to(device),
                )

                loss = outputs.loss
                loss.backward()
                optimizer.step()

                # Weight the batch loss by the batch size so that dividing by
                # the sample count gives a per-sample average (this assumes
                # `loss` is a mean over the batch).
                batch_size = slices.size(0)
                running_loss += loss.item() * batch_size
                num_samples += batch_size

                if idx % log_interval == 0 and idx:
                    print(f"Current loss: {running_loss / num_samples}")
            except Exception as e:
                print(f"Error in batch {idx}: {e}")
                # continue  # Skip this batch and continue

        val_mean_iou = evaluate_model(
            model=model,
            dataloader=val_dataloader,
            preprocessor=preprocessor,
            metric=metric,
            # Pass the mapping itself, not the builtin `dict` type.
            id2label=id2label,
        )
        print(f"Validation mIoU: {val_mean_iou}")


In [None]:
# for idx, batch in enumerate(train_dataloader):
#     slices, targets, img_paths = batch
#     print("Slices" , slices)
#     print("targets", targets)
#     print("Path", img_paths)
#     break

In [None]:
# Launch fine-tuning with the hyperparameters defined above.
# `val_dataset` holds the validation DataLoader at this point in the notebook.
train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataset,
    preprocessor=preprocessor,
    metric=metric,
    num_epochs=epochs,
    learning_rate=learning_rate,
    log_interval=log_interval,
    id2label=id2label,
)