In [None]:
!pip install opendatasets --q
import opendatasets as od
od.download('https://www.kaggle.com/datasets/briscdataset/brisc2025')

In [None]:
import torch
import torchvision
import torch.functional as F

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
import os

from tqdm.auto import tqdm

from torch.utils.data import Dataset, DataLoader
from torch import nn
"""
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
"""
import torchvision.transforms as TT

In [None]:
# Run on the GPU when CUDA is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# Seed torch's global RNG (weight init, DataLoader shuffling, random transforms) for reproducible runs.
SEED = 24
torch.manual_seed(SEED)

In [None]:
def dice_coefficient_calcuatior(logits: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, eps: float = 1e-8):
  """Compute the Dice coefficient between raw logits and a target mask.

  Every element of the inputs is pooled into a single score: the whole batch is
  treated as one volume, not averaged per sample.

  Args:
    logits: raw (pre-sigmoid) model outputs, any shape.
    targets: ground-truth mask with the same number of elements as ``logits``,
      expected to hold values in [0, 1].
    threshold: if not None, sigmoid probabilities are binarised at this value
      ("hard" Dice, a metric). If None, the probabilities are used directly
      ("soft" Dice, differentiable, suitable for a loss).
    eps: smoothing term added to numerator and denominator; it makes the score
      1 when both prediction and target are empty instead of dividing by zero.

  Returns:
    A 0-dim tensor holding the Dice score.
  """
  probs = torch.sigmoid(logits)
  if threshold is not None:
    preds = (probs > threshold).float()
  else:
    preds = probs

  # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced tensors, also work.
  preds = preds.reshape(-1)
  targets = targets.reshape(-1)

  intersection = torch.sum(preds * targets)
  union = torch.sum(preds) + torch.sum(targets)

  return (2 * intersection + eps) / (union + eps)

In [None]:
# Built once at module level and reused by every loss call, instead of being re-created for each batch.
BCE = nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
def UNet_Loss(logits: torch.Tensor, targets: torch.Tensor):
  """Segmentation loss: BCE-with-logits plus the soft Dice loss (1 - soft Dice).

  Args:
    logits: raw model outputs.
    targets: ground-truth mask with the same shape as ``logits``.

  Returns:
    A scalar loss tensor.
  """
  bce_term = BCE(logits, targets)
  # threshold=None keeps Dice differentiable (sigmoid probabilities, no hard cut-off).
  soft_dice = dice_coefficient_calcuatior(logits, targets, threshold=None)
  return bce_term + (1 - soft_dice)

In [None]:
class Swish(nn.Module):
  """Swish (a.k.a. SiLU) activation, applied element-wise: x * sigmoid(x)."""

  def forward(self, X):
    gate = torch.sigmoid(X)
    return gate * X

In [None]:
class DoubleConv(nn.Module):
  """Two stacked (3x3 conv -> BatchNorm -> Swish) stages.

  Both convolutions use kernel 3, stride 1 and padding 1, so the spatial size is
  unchanged; only the channel count goes from ``in_channels`` to ``out_channels``.
  The convolutions are bias-free; each is followed by a BatchNorm with its own
  learnable shift.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    def conv_bn_act(c_in, c_out):
      # One "same"-size convolution stage: conv, normalisation, activation.
      return [
          nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(c_out),
          Swish(),
      ]

    # The attribute name and the Sequential indices 0..5 are unchanged: they make up
    # the state_dict keys stored in saved checkpoints.
    self.DoubleConv = nn.Sequential(
        *conv_bn_act(in_channels, out_channels),
        *conv_bn_act(out_channels, out_channels),
    )

  def forward(self, X):
    return self.DoubleConv(X)

In [None]:
class UNet(nn.Module):
  """U-Net encoder/decoder that outputs per-pixel logits.

  Encoder: one ``DoubleConv`` per entry of ``model_channels`` with a 2x2 max-pool
  between consecutive levels. Decoder: for each level (deepest first) a
  ``ConvTranspose2d`` upsampling, concatenation with the matching encoder feature
  map (skip connection), then a ``DoubleConv``. A final 1x1 conv maps to
  ``num_classes`` channels; no activation is applied, so outputs are raw logits.

  Args:
    in_channels: channels of the input image.
    num_classes: channels of the output logit map.
    model_channels: feature width of each encoder level, shallowest first.

  Raises:
    ValueError: if ``model_channels`` is empty.
  """

  def __init__(self, in_channels=1, num_classes=1, model_channels=(64, 128, 256, 512, 1024)):
    super().__init__()

    channels = list(model_channels)
    if not channels:
      raise ValueError("model_channels must contain at least one level")

    self.in_channels = in_channels
    self.model_channels = channels

    # Module layout (and therefore state_dict keys and parameter-init order) matches
    # the original implementation, so existing checkpoints still load:
    #   downs = [DoubleConv, pool, DoubleConv, pool, ..., DoubleConv]
    #   ups   = [DoubleConv_1, ConvT_1, DoubleConv_2, ConvT_2, ...]  (shallowest level first)
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()

    # Parameter-free, so a single instance is shared by every pooling slot.
    self.max_pool = nn.MaxPool2d(kernel_size=2,
                                 stride=2)

    last_level = len(channels) - 1
    for level, n_channels in enumerate(channels):
      prev_channels = in_channels if level == 0 else channels[level - 1]
      self.downs.append(DoubleConv(prev_channels, n_channels))

      if level != last_level:
        self.downs.append(self.max_pool)

      if level != 0:
        # Upsampled features (prev_channels) are concatenated with the skip
        # connection (prev_channels), hence 2 * prev_channels input channels.
        self.ups.append(DoubleConv(2 * prev_channels, prev_channels))
        self.ups.append(
            nn.ConvTranspose2d(n_channels,
                               prev_channels,
                               kernel_size=2,
                               stride=2)
        )

    self.final_conv = nn.Conv2d(in_channels=channels[0],
                                out_channels=num_classes,
                                kernel_size=1,
                                stride=1,
                                padding=0)

  def forward(self, X):
    """Return logits with ``num_classes`` channels at the spatial size of ``X``.

    Assumes ``X`` is (batch, in_channels, H, W) — TODO confirm for other callers.
    """
    out = X
    skips = []
    for part in self.downs:
      if part is self.max_pool:
        # Keep this level's encoder output (before pooling) for the skip connection.
        skips.append(out)
      out = part(out)

    # reversed(self.ups) yields (ConvTranspose2d, DoubleConv) pairs, deepest level first.
    decoder = list(reversed(self.ups))
    for upsample, fuse in zip(decoder[0::2], decoder[1::2]):
      skip = skips.pop()
      out = upsample(out)
      if out.shape[-2:] != skip.shape[-2:]:
        # Odd spatial sizes lose a row/column when pooled; match the skip so torch.cat works.
        out = nn.functional.interpolate(out, size=skip.shape[-2:], mode='nearest')
      out = torch.cat((out, skip), dim=1)
      out = fuse(out)

    return self.final_conv(out)

In [None]:
def _build_image_transform(*augmentations):
  """Resize to 256x256, apply any extra PIL-level transforms, then ToTensor and Normalize(0.5, 0.5)."""
  return torchvision.transforms.Compose([
      TT.Resize((256, 256)),
      *augmentations,
      TT.ToTensor(),
      TT.Normalize(mean=(0.5,),
                   std=(0.5,)),
  ])


# Training adds photometric-only augmentation (random inversion, brightness/contrast
# jitter). No geometric transform is used because the mask is transformed separately.
train_transform = _build_image_transform(
    TT.RandomInvert(p=0.3),
    TT.ColorJitter(brightness=0.4, contrast=0.2),
)

test_transform = _build_image_transform()

In [None]:
class BrainTumorDataset(Dataset):
  """Paired (scan, mask) segmentation dataset.

  Scans are loaded as single-channel ("L") images and passed through the
  module-level ``train_transform`` or ``test_transform`` depending on
  ``transform_mode``. Masks are converted to "L" and resized to 256x256 with
  nearest-neighbour interpolation (so labels stay binary), then converted to a
  float tensor in [0, 1].

  Args:
    imgs_path: directory containing the scan images.
    masks_path: directory containing the masks; each mask must share its
      filename stem with a scan (e.g. ``x.jpg`` <-> ``x.png``).
    transform_mode: 'train' or 'test' (case-insensitive), checked on access.

  Raises:
    ValueError: (in ``__init__``) if some scan has no mask with the same stem.
    RuntimeError: (in ``__getitem__``) if transform_mode is neither 'train' nor 'test'.
  """

  def __init__(self, imgs_path, masks_path, transform_mode):
    self.imgs_path = imgs_path
    self.masks_path = masks_path
    self.transform_mode = transform_mode

    image_files = sorted(os.listdir(imgs_path))
    masks_by_stem = {os.path.splitext(name)[0]: name for name in os.listdir(masks_path)}

    # Pair by filename stem rather than by sorted position, so a stray or missing
    # file raises here instead of silently shifting every later label.
    missing = [name for name in image_files if os.path.splitext(name)[0] not in masks_by_stem]
    if missing:
      raise ValueError(
          f"{len(missing)} image(s) in {imgs_path!r} have no matching mask in {masks_path!r}, "
          f"e.g. {missing[0]!r}"
      )

    self.image_files = image_files
    # Aligned with image_files: mask_files[i] is the mask of image_files[i].
    self.mask_files = [masks_by_stem[os.path.splitext(name)[0]] for name in image_files]

    # Built once instead of on every __getitem__ call.
    self.mask_transform = TT.Compose([
        TT.Resize((256, 256), interpolation=TT.InterpolationMode.NEAREST),
        TT.ToTensor(),
    ])

  def __len__(self):
    return len(self.image_files)

  def __getitem__(self, idx):
    img_path = os.path.join(self.imgs_path, self.image_files[idx])
    mask_path = os.path.join(self.masks_path, self.mask_files[idx])

    with Image.open(img_path) as img_file:
      image = img_file.convert("L")
    # Convert masks to "L" too: the model predicts one channel, and RGB or palette
    # masks would otherwise yield the wrong channel count or value range.
    with Image.open(mask_path) as mask_file:
      label = mask_file.convert("L")

    if self.transform_mode.upper() == 'TRAIN':
      image = train_transform(image)
    elif self.transform_mode.upper() == 'TEST':
      image = test_transform(image)
    else:
      raise RuntimeError("Please specify the transform_mode as either 'train' or 'test'")

    label = self.mask_transform(label)

    return image, label

In [None]:
# Root of the downloaded BRISC segmentation split (Colab working directory layout).
SEGMENTATION_ROOT = '/content/brisc2025/brisc2025/segmentation_task'

train_img_path = f'{SEGMENTATION_ROOT}/train/images'
train_mask_path = f'{SEGMENTATION_ROOT}/train/masks'

test_img_path = f'{SEGMENTATION_ROOT}/test/images'
test_mask_path = f'{SEGMENTATION_ROOT}/test/masks'

In [None]:
# One dataset per split; the mode string selects the augmenting or the plain image transform.
train_ds = BrainTumorDataset(imgs_path=train_img_path, masks_path=train_mask_path, transform_mode='train')
test_ds = BrainTumorDataset(imgs_path=test_img_path, masks_path=test_mask_path, transform_mode='test')

In [None]:
# Settings shared by both loaders.
common_loader_kwargs = dict(batch_size=16, num_workers=2, pin_memory=True)

# Training: shuffled, and the final incomplete batch is dropped.
train_dl = DataLoader(train_ds, shuffle=True, drop_last=True, **common_loader_kwargs)

# Evaluation: fixed order, every sample kept.
test_dl = DataLoader(test_ds, shuffle=False, drop_last=False, **common_loader_kwargs)

In [None]:
# Grayscale input (1 channel) and a single tumour logit map (1 class).
model = UNet(in_channels=1, num_classes=1)
model = model.to(device)

In [None]:
# AdamW with decoupled weight decay; StepLR multiplies the LR by 0.5 every 8 scheduler steps.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.5)

In [15]:
# SAVE_MODEL: write the best checkpoint to 'checkpoint.pth' while training.
# LOAD_MODEL: skip training and restore 'checkpoint.pth' instead.
SAVE_MODEL = False
LOAD_MODEL = False

# Validate the types first, so a non-bool value (e.g. 1) gets the type message
# rather than being misjudged by the mutual-exclusion check.
assert isinstance(SAVE_MODEL, bool) and isinstance(LOAD_MODEL, bool), "'SAVE_MODEL' and 'LOAD_MODEL' must be booleans"
assert not (SAVE_MODEL and LOAD_MODEL), "'SAVE_MODEL' and 'LOAD_MODEL' cannot be equal to 'True' at the same time"

In [None]:
if not LOAD_MODEL:
  EPOCHS = 64

  # Best epoch-average training Dice seen so far; only used when SAVE_MODEL is True.
  best_dice_score = -1.0

  for epoch in tqdm(range(EPOCHS)):
    losses = []
    dice_scores = []

    model.train()
    for X, y in train_dl:
      X, y = X.to(device), y.to(device)

      preds = model(X)
      loss = UNet_Loss(preds, y)

      # The thresholded (hard) Dice is only reported, so keep it out of the autograd graph.
      with torch.no_grad():
        dice_score = dice_coefficient_calcuatior(preds, y, threshold=0.5)

      losses.append(loss.item())
      dice_scores.append(dice_score.item())

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # The scheduler is stepped once per epoch, not per batch.
    lr_scheduler.step()

    avg_dice_score = sum(dice_scores) / len(dice_scores)
    avg_loss = sum(losses) / len(losses)
    print('====================')
    print(f"Epoch: {epoch + 1}")
    print(f"Dice Score: {avg_dice_score:.2f} | Loss: {avg_loss:.4f}")

    # Only build and write the checkpoint when this epoch is the new best.
    # NOTE(review): "best" is judged on *training* Dice (there is no validation split
    # here), which may favour over-fitted epochs — TODO consider a held-out split.
    if SAVE_MODEL and avg_dice_score > best_dice_score:
      best_dice_score = avg_dice_score
      checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler else None,
        'loss': avg_loss,
        'dice_score': avg_dice_score
      }
      torch.save(checkpoint, 'checkpoint.pth')

else:
  assert os.path.exists('checkpoint.pth'), "Please upload your checkpoint as 'checkpoint.pth'"
  # NOTE(review): weights_only=False unpickles arbitrary Python objects; only load checkpoints you trust.
  checkpoint = torch.load('checkpoint.pth', map_location=device, weights_only=False)

  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  # The scheduler state is saved with the checkpoint; restore it too so a resumed
  # run continues the same learning-rate schedule.
  if checkpoint.get('lr_scheduler_state_dict') is not None:
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])

In [None]:
# A fixed training sample used for a qualitative look at the model's output.
random_image_path = '/content/brisc2025/brisc2025/segmentation_task/train/images/brisc2025_train_00024_gl_ax_t1.jpg'
random_image_mask_path = '/content/brisc2025/brisc2025/segmentation_task/train/masks/brisc2025_train_00024_gl_ax_t1.png'

with Image.open(random_image_path) as img_file:
  random_image = img_file.convert('L')
with Image.open(random_image_mask_path) as mask_file:
  random_mask = mask_file.convert('L')

# Nearest-neighbour resizing keeps the mask binary.
random_mask = random_mask.resize([256, 256], Image.Resampling.NEAREST)

# Preprocess the scan with the same test_transform the evaluation DataLoader uses,
# so the model sees identically resized and normalised input.
X = test_transform(random_image).to(device)
y = TT.ToTensor()(random_mask).to(device)

model.eval()
with torch.no_grad():
  pred = model(X.unsqueeze(dim=0))
pred = pred.squeeze(dim=0)

train_dice_score = dice_coefficient_calcuatior(pred, y)

pred = torch.where(torch.sigmoid(pred) > 0.5, 1, 0)

In [None]:
def denorm(img, mean=0.5, std=0.5):
  """Undo ``Normalize(mean, std)`` by returning ``img * std + mean``.

  The defaults invert the ``Normalize(mean=(0.5,), std=(0.5,))`` used in this
  notebook's transforms, mapping [-1, 1] back to [0, 1].
  """
  return img * std + mean

In [None]:
img_to_show = denorm(X)[0].cpu()
mask_to_show = y[0].cpu()
pred_to_show = pred[0].cpu()

# Scan, ground truth and prediction side by side, all in grayscale.
panels = [
    (img_to_show, 'Brain Scan'),
    (mask_to_show, 'True Brain Tumor Location'),
    (pred_to_show, 'Predicted Brain Tumor Location'),
]

plt.figure(figsize=(12, 4))
for position, (panel, title) in enumerate(panels, start=1):
  plt.subplot(1, 3, position)
  plt.imshow(panel, cmap='gray')
  plt.title(title)
  plt.axis('off')

plt.show()

In [None]:
# Ground truth and prediction drawn as translucent red overlays on the scan.
overlays = [
    (mask_to_show, 'True Brain Tumor Location'),
    (pred_to_show, 'Predicted Brain Tumor Location'),
]

plt.figure(figsize=(8, 4))
for position, (overlay, title) in enumerate(overlays, start=1):
  plt.subplot(1, 2, position)
  plt.imshow(img_to_show, cmap='gray')
  plt.imshow(overlay, cmap='Reds', alpha=0.2)
  plt.title(title)
  plt.axis('off')

plt.tight_layout()

plt.figtext(0.67, -0.01, f"Dice Score: {train_dice_score:.2f}", fontsize=10)
plt.show()

In [None]:
model.eval()

test_dice_scores = []
test_losses = []
batch_sizes = []

with torch.no_grad():
  for X, y in test_dl:
    X, y = X.to(device), y.to(device)

    preds = model(X)

    loss = UNet_Loss(preds, y)

    dice_score = dice_coefficient_calcuatior(preds, y)

    test_dice_scores.append(dice_score.item())
    test_losses.append(loss.item())
    batch_sizes.append(X.shape[0])

# Weight each batch by its size: test_dl keeps the final (possibly smaller) batch,
# which would otherwise count as much as a full one.
mean_test_dice = np.average(test_dice_scores, weights=batch_sizes)
mean_test_loss = np.average(test_losses, weights=batch_sizes)

print('##################\n#  Test Results  #\n##################\n')
print(f'Dice Score: {mean_test_dice:.2f} | Loss: {mean_test_loss:.4f}')

In [None]:
# A fixed test sample used for a qualitative look at the model's output.
random_image_path = '/content/brisc2025/brisc2025/segmentation_task/test/images/brisc2025_test_00024_gl_ax_t1.jpg'
random_image_mask_path = '/content/brisc2025/brisc2025/segmentation_task/test/masks/brisc2025_test_00024_gl_ax_t1.png'

with Image.open(random_image_path) as img_file:
  # Same preprocessing as the test DataLoader: grayscale, then test_transform.
  image = test_transform(img_file.convert('L'))
with Image.open(random_image_mask_path) as mask_file:
  # Nearest-neighbour keeps the mask binary; PIL's default filter for mode "L"
  # is bicubic, which would turn tumour borders into fractional labels.
  mask = TT.ToTensor()(mask_file.convert('L').resize([256, 256], Image.Resampling.NEAREST))

X = image.to(device)
y = mask.to(device)

# Set inference mode here rather than relying on an earlier cell, and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
  pred = model(X.unsqueeze(dim=0))

test_dice_score = dice_coefficient_calcuatior(pred, y)

pred = torch.where(torch.sigmoid(pred) > 0.5, 1, 0)

In [None]:
image_to_show = denorm(image)[0].cpu().squeeze()
mask_to_show = mask[0].cpu().squeeze()
pred_to_show = pred[0].cpu().squeeze()

# Scan, ground truth and prediction side by side, all in grayscale.
test_panels = [
    (image_to_show, 'Brain Scan'),
    (mask_to_show, 'True Brain Tumor Location'),
    (pred_to_show, 'Predicted Brain Tumor Location'),
]

plt.figure(figsize=(12, 4))
for position, (panel, title) in enumerate(test_panels, start=1):
  plt.subplot(1, 3, position)
  plt.imshow(panel, cmap='gray')
  plt.title(title)
  plt.axis('off')

plt.show()

In [None]:
# Ground truth and prediction drawn as translucent red overlays on the test scan.
# 'gray' matches the other plots and is valid on every matplotlib version.
test_overlays = [
    (mask_to_show, 'True Brain Tumor Location'),
    (pred_to_show, 'Predicted Brain Tumor Location'),
]

plt.figure(figsize=(8, 4))
for position, (overlay, title) in enumerate(test_overlays, start=1):
  plt.subplot(1, 2, position)
  plt.imshow(image_to_show, cmap='gray')
  plt.imshow(overlay, cmap='Reds', alpha=0.2)
  plt.title(title)
  plt.axis('off')

plt.tight_layout()

plt.figtext(0.67, -0.01, f"Dice Score: {test_dice_score:.2f}", fontsize=10)

plt.show()