In [None]:
# %pip (not !pip) installs into the environment of the running kernel.
# albumentations, scikit-learn and pandas are imported further down, so they are listed here too.
%pip install -q torch torchvision transformers datasets huggingface_hub opencv-python matplotlib tqdm albumentations scikit-learn pandas


Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting dill

## Kaggle API and Dataset download

In [None]:
# Kaggle credentials come from the KAGGLE_USERNAME / KAGGLE_KEY environment
# variables, or are typed in at a prompt, so no API key is stored in the notebook.
# NOTE: any key that was ever saved in a notebook (or its history) should be rotated.
%pip install -q kaggle

import json
import os
from getpass import getpass
from pathlib import Path

kaggle_dir = Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)  # safe to re-run, unlike a bare `mkdir`

kaggle_creds = {
    "username": os.environ.get("KAGGLE_USERNAME") or input("Kaggle username: "),
    "key": os.environ.get("KAGGLE_KEY") or getpass("Kaggle API key: "),
}

kaggle_json = kaggle_dir / "kaggle.json"
# Restrict to owner read/write *before* writing the key, so it is never world-readable.
kaggle_json.touch(mode=0o600, exist_ok=True)
kaggle_json.chmod(0o600)
kaggle_json.write_text(json.dumps(kaggle_creds))



In [None]:
# -q keeps thousands of "inflating ..." lines out of the notebook output;
# -o overwrites without prompting, so re-running the cell does not block on
# unzip's interactive "replace? [y]es, [n]o ..." question.
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
!unzip -q -o lgg-mri-segmentation.zip -d LGG_Dataset

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7294_19890104/TCGA_DU_7294_19890104_9_mask.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_1.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_10.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_10_mask.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_11.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_11_mask.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_12.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation/kaggle_3m/TCGA_DU_7298_19910324/TCGA_DU_7298_19910324_12_mask.tif  
  inflating: LGG_Dataset/lgg-mri-segmentation

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import cv2
import pandas as pd
from glob import glob
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

In [None]:
def create_df(data_dir):
  """Pair every segmentation mask under ``data_dir`` with its source image.

  Expects one sub-directory per case, each holding ``<name>.tif`` slices and
  matching ``<name>_mask.tif`` masks (the layout of the extracted archive).

  Args:
    data_dir: root directory whose immediate sub-directories contain the files.

  Returns:
    DataFrame with columns 'images' and 'masks' (file paths), one row per mask,
    sorted by mask path so the row order is deterministic across runs.
  """
  # Glob only the mask files. A bare '*/*' also matches the images themselves,
  # which would then be "paired" with themselves (mask == image) and double the
  # dataset with bogus samples.
  masks_paths = sorted(glob(os.path.join(data_dir, '*', '*_mask*')))
  # Strip only the last '_mask' (the file-name suffix), so '_mask' appearing
  # elsewhere in the path is left untouched.
  images_path = [head + tail for head, _, tail in (p.rpartition('_mask') for p in masks_paths)]
  return pd.DataFrame({'images': images_path, 'masks': masks_paths})

data_dir = "/content/LGG_Dataset/kaggle_3m"
df = create_df(data_dir)

# Fail fast with an actionable message; otherwise an empty frame only surfaces
# later as an opaque error from train_test_split or cv2.
if df.empty:
    raise FileNotFoundError(
        f"No '*_mask*' files found under {data_dir!r}; check where the archive was extracted."
    )
missing_images = df.loc[~df['images'].map(os.path.exists), 'images']
if not missing_images.empty:
    raise FileNotFoundError(
        f"{len(missing_images)} masks have no matching image, e.g. {missing_images.iloc[0]!r}"
    )

In [None]:
# Split by case folder, not by slice. Each sub-folder holds the slices of one
# scan; adjacent slices are nearly identical, so a slice-level split leaks
# validation data into training and inflates the validation metrics.
case_ids = df['masks'].map(lambda p: os.path.basename(os.path.dirname(p)))
train_cases, val_cases = train_test_split(case_ids.unique(), test_size=0.2, random_state=42)
train_df = df[case_ids.isin(train_cases)]
val_df = df[case_ids.isin(val_cases)]

In [None]:
target_size = 512

# ImageNet channel statistics, shared by both pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training: resize, random flips / 90-degree rotations, mild photometric jitter,
# then normalise and convert to a tensor.
train_transform = A.Compose([
    A.Resize(target_size, target_size),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
    ToTensorV2(),
])

# Validation: deterministic resize + normalisation only (no augmentation).
val_transform = A.Compose([
    A.Resize(target_size, target_size),
    A.Normalize(mean=imagenet_mean, std=imagenet_std),
    ToTensorV2(),
])

In [None]:
class LGGSegmentationDataset(Dataset):
  """(image, mask) pairs read from the file paths listed in a DataFrame.

  Args:
    df: DataFrame with 'images' and 'masks' columns holding file paths. Its
      index is reset so ``idx`` is a plain 0..len-1 position.
    transforms: albumentations-style callable invoked as
      ``transforms(image=..., mask=...)`` that returns a dict whose 'image'
      and 'mask' entries are tensors.

  Each item is ``(image_tensor, mask_tensor)``. The mask gets a leading
  channel dim and is cast to float, ready for a BCE-style loss.
  """

  def __init__(self, df, transforms):
    self.df = df.reset_index(drop=True)
    self.transforms = transforms

  def __len__(self):
    return len(self.df)

  def __getitem__(self, idx):
    image_path = self.df.loc[idx, 'images']
    mask_path = self.df.loc[idx, 'masks']

    # cv2.imread returns None instead of raising on a missing/unreadable file.
    image = cv2.imread(image_path)
    if image is None:
      raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
      raise FileNotFoundError(f"Could not read mask: {mask_path}")
    # Binarise by threshold: identical to //255 for strict 0/255 masks, but
    # near-white values (e.g. 254) also count as foreground instead of 0.
    mask = (mask > 127).astype(np.uint8)

    augmented = self.transforms(image=image, mask=mask)
    return augmented['image'], augmented['mask'].unsqueeze(0).float()

In [None]:
# Training split gets the augmenting pipeline; validation stays deterministic.
train_dataset = LGGSegmentationDataset(df=train_df, transforms=train_transform)
val_dataset = LGGSegmentationDataset(df=val_df, transforms=val_transform)

In [None]:
BATCH_SIZE = 16

# A seeded generator makes the training shuffle order reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
num_classes = 1  # one logit per pixel; sigmoid(logit) is the tumour probability

# The ADE20K checkpoint's classifier has 150 outputs (see the load log);
# ignore_mismatched_sizes=True lets from_pretrained re-initialise
# decode_head.classifier with `num_classes` output channels. No manual
# replacement of the head is needed.
model = SegformerForSemanticSegmentation.from_pretrained(
    'nvidia/segformer-b0-finetuned-ade-512-512',
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

model.config.semantic_loss_ignore_index = 255
# config.num_labels is derived from the size of id2label: a two-entry map
# made it report 2 for this single-logit head. Keep exactly one entry for the
# one output channel (background is implicit: logit < 0), and put the
# name->id mapping in label2id, not id2label.
model.config.id2label = {0: 'tumor'}
model.config.label2id = {'tumor': 0}

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/6.88k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/15.0M [00:00<?, ?B/s]

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Sanity check: number of output channels the config reports for the head.
model.config.num_labels

2


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# nn.Module.to moves the parameters in place; the trailing `;` suppresses the
# very long module repr that would otherwise be displayed as cell output.
model.to(device);

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [None]:
class BalancedBCEDiceLoss(nn.Module):
    """Positive-weighted BCE plus soft Dice loss for binary segmentation.

    Args:
        pos_weight: weight on the positive (foreground) term of the BCE, to
            counter foreground/background imbalance. Default 5.0.
        smooth: additive constant in the Dice ratio; keeps it defined when
            both prediction and target are empty. Default 1e-6.

    ``forward(inputs, targets)``:
        inputs: raw logits with a leading (batch, channel) layout; bilinearly
            resized to the target's spatial size when the sizes differ.
        targets: float masks in {0, 1}, same (batch, channel) layout.
        Returns a scalar tensor: mean BCE + (1 - Dice). Dice is computed over
        the whole batch at once, not per sample.
    """

    def __init__(self, pos_weight=5.0, smooth=1e-6):
        super().__init__()
        # A buffer moves with `.to(device)`; forward() additionally aligns it
        # with the logits' device, so no global `device` is needed here.
        self.register_buffer('pos_weight', torch.tensor([float(pos_weight)]))
        self.smooth = smooth

    def forward(self, inputs, targets):
        if inputs.shape[2:] != targets.shape[2:]:
            inputs = F.interpolate(inputs, size=targets.shape[2:], mode='bilinear', align_corners=False)

        bce_loss = F.binary_cross_entropy_with_logits(
            inputs, targets, pos_weight=self.pos_weight.to(inputs.device)
        )

        probs = torch.sigmoid(inputs)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + self.smooth) / (probs.sum() + targets.sum() + self.smooth)

        return bce_loss + dice_loss

# Place the loss on the model's device so any buffers it holds (e.g. the BCE
# positive-class weight) live alongside the logits.
criterion = BalancedBCEDiceLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# The scheduler is stepped with validation Dice (higher is better), hence
# mode='max': halve the LR after 2 epochs without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5)

In [None]:
def calculate_metrics(preds, masks):
    """Mean Dice and IoU over the samples of a batch that contain foreground.

    Args:
        preds: raw logits, indexed by sample along the first dimension;
            binarised at sigmoid(logit) > 0.5.
        masks: binary ground-truth masks, one per sample, with the same number
            of elements per sample as ``preds``.

    Returns:
        (mean_dice, mean_iou) as Python floats. Samples whose mask is empty are
        skipped (their Dice/IoU would be undefined), so false positives on
        tumour-free slices are not penalised. Returns (0.0, 0.0) when no
        sample has foreground.
    """
    hard = (torch.sigmoid(preds) > 0.5).float()

    dice_scores, iou_scores = [], []
    for p, m in zip(hard, masks):
        p, m = p.view(-1), m.view(-1)
        m_total = m.sum()
        if m_total == 0:
            continue

        p_total = p.sum()
        overlap = (p * m).sum()
        dice_scores.append(((2. * overlap) / (p_total + m_total + 1e-8)).item())
        iou_scores.append((overlap / ((p + m).sum() - overlap + 1e-8)).item())

    n_scored = max(1, len(dice_scores))
    return sum(dice_scores) / n_scored, sum(iou_scores) / n_scored

In [None]:
NUM_EPOCHS = 100
EARLY_STOP_PATIENCE = 3  # epochs without a new best val Dice before stopping

best_dice = 0
early_stop_counter = 0

for epoch in range(NUM_EPOCHS):

    # ---- training ----
    model.train()
    train_loss = 0
    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        images, masks = images.to(device), masks.to(device)

        # Logits may be smaller than the masks; the criterion resizes them.
        outputs = model(pixel_values=images).logits
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # ---- validation ----
    model.eval()
    dice_sum, iou_sum, n_scored = 0.0, 0.0, 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(pixel_values=images).logits
            outputs = F.interpolate(outputs, size=masks.shape[2:], mode='bilinear', align_corners=False)

            batch_dice, batch_iou = calculate_metrics(outputs, masks)
            # calculate_metrics averages only over slices that contain tumour
            # and returns 0 for a batch with none. Weight each batch mean by
            # that count, so tumour-free batches don't drag the epoch mean
            # (and hence LR scheduling / early stopping) toward zero.
            n_pos = int((masks.flatten(1).sum(dim=1) > 0).sum().item())
            dice_sum += batch_dice * n_pos
            iou_sum += batch_iou * n_pos
            n_scored += n_pos

    train_loss /= len(train_loader)
    val_dice = dice_sum / max(1, n_scored)
    val_iou = iou_sum / max(1, n_scored)

    scheduler.step(val_dice)

    print(f"\nEpoch {epoch+1}:")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Dice: {val_dice:.4f}")
    print(f"Val IoU: {val_iou:.4f}")
    print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")

    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), 'best_model.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= EARLY_STOP_PATIENCE:
            print("Early stopping triggered")
            break


Epoch 1/20:  46%|████▌     | 180/393 [1:59:32<2:19:36, 39.33s/it]

In [None]:
def plot_sample(image, mask, pred, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show an input slice, its ground-truth mask and the predicted mask side by side.

    Args:
        image: normalised (C, H, W) image tensor, as produced by the transforms.
        mask: ground-truth mask tensor; singleton dims are squeezed for display.
        pred: predicted mask tensor; singleton dims are squeezed for display.
        mean, std: per-channel statistics that were used to normalise
            ``image``; they are used to undo the normalisation so colours
            display correctly. Defaults are the ImageNet values used by this
            notebook's transforms.
    """
    # Undo the normalisation: imshow expects RGB floats in [0, 1] and would
    # otherwise clip the z-scored values into a distorted picture.
    img = image.detach().cpu().float()
    channel_mean = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    channel_std = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    img = (img * channel_std + channel_mean).clamp(0, 1).permute(1, 2, 0)

    panels = [
        (img, "Input Image", None),
        (mask.squeeze().cpu(), "Ground Truth", 'gray'),
        (pred.squeeze().cpu(), "Prediction", 'gray'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (data, title, cmap) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    plt.close(fig)  # free the figure; this is called repeatedly in a loop


NUM_PLOT_BATCHES = 3  # plot the first sample of this many validation batches

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

with torch.no_grad():
    for i, (images, masks) in enumerate(val_loader):
        if i >= NUM_PLOT_BATCHES:
            break
        images, masks = images.to(device), masks.to(device)
        outputs = model(pixel_values=images).logits
        outputs = F.interpolate(outputs, size=masks.shape[2:], mode='bilinear', align_corners=False)
        preds = torch.sigmoid(outputs) > 0.5
        plot_sample(images[0], masks[0], preds[0])