In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [1]:
# 📌 Setup
!pip install -q timm
!pip install -q einops
!pip install -q torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from einops import rearrange
import timm

ERROR: Exception:
Traceback (most recent call last):
  File "C:\Users\ksair\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 438, in _error_catcher
    yield
  File "C:\Users\ksair\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 561, in read
    data = self._fp_read(amt) if not fp_closed else b""
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\ksair\anaconda3\Lib\site-packages\pip\_vendor\urllib3\response.py", line 527, in _fp_read
    return self._fp.read(amt) if amt is not None else self._fp.read()
           ^^^^^^^^^^^^^^^^^^
  File "C:\Users\ksair\anaconda3\Lib\site-packages\pip\_vendor\cachecontrol\filewrapper.py", line 98, in read
    data: bytes = self.__fp.read(amt)
                  ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\ksair\anaconda3\Lib\http\client.py", line 479, in read
    s = self.fp.read(amt)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\ksair\anaconda3\Lib\socket.py", line 720, in readinto
    return self._sock.recv_into(b)
           ^

ModuleNotFoundError: No module named 'torch'

In [None]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/brain_tumour.zip"  # Update this path as needed
extract_path = "/content/dataset"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)


In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class LungTumorDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        image = Image.open(img_path).convert("L")
        mask = Image.open(mask_path).convert("L")

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


In [None]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

dataset = LungTumorDataset(
    image_dir="/content/dataset/images",
    mask_dir="/content/dataset/masks",
    transform=transform
)

from torch.utils.data import random_split

train_len = int(0.8 * len(dataset))
val_len = len(dataset) - train_len
train_set, val_set = random_split(dataset, [train_len, val_len])

train_loader = DataLoader(train_set, batch_size=4, shuffle=True)
val_loader = DataLoader(val_set, batch_size=4, shuffle=False)


In [None]:
# ✅ Deformable Convolution Placeholder (replace with actual if using mmcv or deformable conv support)
class DeformableConv2D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)

    def forward(self, x):
        # NOTE: Actual deformable conv is not implemented here.
        return self.conv(x)

# ✅ Weight Generation Unit
class WeightGenUnit(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        w = self.fc(x)
        return x * w

# ✅ Double Conv Block
def double_conv(in_ch, out_ch):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

# ✅ WDU-Net Block with WGU + Deformable Conv
class WDUBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.wgu = WeightGenUnit(out_ch, out_ch)
        self.deform = DeformableConv2D(out_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        x = self.wgu(x)
        x = self.deform(x)
        return x

# ✅ WDU-Net Architecture
class WDUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super().__init__()
        self.enc1 = WDUBlock(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = WDUBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = WDUBlock(128, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = WDUBlock(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = WDUBlock(128, 64)

        self.final = nn.Conv2d(64, out_ch, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.final(d1)

# ✅ Focal Asymmetric Similarity Loss (FASL)
class FASLoss(nn.Module):
    def __init__(self, alpha=0.8, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.where(targets == 1, inputs, 1 - inputs)
        FAS = self.alpha * (1 - pt) ** self.gamma * BCE
        return FAS.mean()

# ✅ Dummy Training Loop
def train_step(model, dataloader, optimizer, criterion):
    model.train()
    for x, y in dataloader:
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# ✅ Dummy Inference with Ensemble (if multiple models are trained)
def ensemble_inference(models, x):
    with torch.no_grad():
        preds = [torch.sigmoid(m(x.cuda())) for m in models]
        return torch.mean(torch.stack(preds), dim=0)

# Example Usage
# model = WDUNet().cuda()
# criterion = FASLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
model = WDUNet().cuda()
criterion = FASLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [None]:
from tqdm import tqdm

def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            imgs, masks = imgs.cuda(), masks.cuda()
            preds = model(imgs)
            loss = criterion(preds, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        print(f"Train Loss: {avg_train_loss:.4f}")

        evaluate_model(model, val_loader)

def evaluate_model(model, val_loader):
    model.eval()
    iou_total = 0
    f_measure_total = 0
    mae_total = 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.cuda(), masks.cuda()
            preds = model(imgs)
            preds_sigmoid = torch.sigmoid(preds)
            preds_binary = (preds_sigmoid > 0.5).float()

            intersection = (preds_binary * masks).sum((1, 2, 3))
            union = ((preds_binary + masks) > 0).float().sum((1, 2, 3))
            iou = (intersection + 1e-6) / (union + 1e-6)
            iou_total += iou.mean().item()

            # Calculate F-measure and MAE for the batch
            f_beta_batch = f_measure(preds_binary, masks)
            f_measure_total += f_beta_batch.mean().item()

            mae_batch = mae(preds_sigmoid, masks)
            mae_total += mae_batch.mean().item()


    avg_iou = iou_total / len(val_loader)
    avg_f_measure = f_measure_total / len(val_loader)
    avg_mae = mae_total / len(val_loader)

    print(f"Validation IoU: {avg_iou:.4f}, F-measure: {avg_f_measure:.4f}, MAE: {avg_mae:.4f}")

In [None]:
train_model(model, train_loader, val_loader, criterion, optimizer, epochs=15)


Epoch 1/15: 100%|██████████| 613/613 [02:23<00:00,  4.26it/s]


Train Loss: 0.0134
Validation IoU: 0.1058


Epoch 2/15: 100%|██████████| 613/613 [02:26<00:00,  4.19it/s]


Train Loss: 0.0089
Validation IoU: 0.2863


Epoch 3/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0077
Validation IoU: 0.3178


Epoch 4/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0067
Validation IoU: 0.3985


Epoch 5/15: 100%|██████████| 613/613 [02:25<00:00,  4.20it/s]


Train Loss: 0.0063
Validation IoU: 0.4440


Epoch 6/15: 100%|██████████| 613/613 [02:25<00:00,  4.20it/s]


Train Loss: 0.0058
Validation IoU: 0.4775


Epoch 7/15: 100%|██████████| 613/613 [02:28<00:00,  4.13it/s]


Train Loss: 0.0053
Validation IoU: 0.4316


Epoch 8/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0052
Validation IoU: 0.4586


Epoch 9/15: 100%|██████████| 613/613 [02:26<00:00,  4.20it/s]


Train Loss: 0.0049
Validation IoU: 0.5029


Epoch 10/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0046
Validation IoU: 0.4976


Epoch 11/15: 100%|██████████| 613/613 [02:25<00:00,  4.20it/s]


Train Loss: 0.0044
Validation IoU: 0.4069


Epoch 12/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0041
Validation IoU: 0.5191


Epoch 13/15: 100%|██████████| 613/613 [02:25<00:00,  4.21it/s]


Train Loss: 0.0040
Validation IoU: 0.5319


Epoch 14/15: 100%|██████████| 613/613 [02:25<00:00,  4.20it/s]


Train Loss: 0.0037
Validation IoU: 0.4926


Epoch 15/15: 100%|██████████| 613/613 [02:25<00:00,  4.22it/s]


Train Loss: 0.0036
Validation IoU: 0.5643


In [None]:
def f_measure(pred, gt, beta_squared=0.3):
    """
    Computes the Fβ-score between predicted and ground truth masks for a batch.

    Args:
        pred (torch.Tensor): Binary prediction mask (0 or 1), shape [N, C, H, W]
        gt (torch.Tensor): Binary ground truth mask (0 or 1), shape [N, C, H, W]
        beta_squared (float): Beta^2 parameter. Defaults to 0.3.

    Returns:
        torch.Tensor: Fβ-scores for each item in the batch
    """
    pred = pred.float()
    gt = gt.float()

    # Ensure shapes match (assuming single channel for now)
    if pred.shape[-2:] != gt.shape[-2:]:
        pred = F.interpolate(pred, size=gt.shape[-2:], mode='bilinear', align_corners=False)

    TP = (pred * gt).sum(dim=(-2, -1))
    precision = TP / (pred.sum(dim=(-2, -1)) + 1e-8)
    recall = TP / (gt.sum(dim=(-2, -1)) + 1e-8)

    f_beta = ((1 + beta_squared) * precision * recall) / (beta_squared * precision + recall + 1e-8)
    return f_beta

def mae(pred, gt):
    """
    Computes Mean Absolute Error between predicted and ground truth masks for a batch.

    Args:
        pred (torch.Tensor): Predicted mask, values in [0, 1], shape [N, C, H, W]
        gt (torch.Tensor): Ground truth mask, values in [0, 1], shape [N, C, H, W]

    Returns:
        torch.Tensor: MAE values for each item in the batch
    """
    pred = pred.float()
    gt = gt.float()

    # Ensure shapes match (assuming single channel for now)
    if pred.shape[-2:] != gt.shape[-2:]:
        pred = F.interpolate(pred, size=gt.shape[-2:], mode='bilinear', align_corners=False)


    return torch.abs(pred - gt).mean(dim=(-2, -1))

In [None]:
# Save
torch.save(model.state_dict(), "wdu_net.pth")

# Load
model.load_state_dict(torch.load("wdu_net.pth"))
model.eval()


WDUNet(
  (enc1): WDUBlock(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
    (wgu): WeightGenUnit(
      (fc): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (2): Sigmoid()
      )
    )
    (deform): DeformableConv2D(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): WDUBlock(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2