<a href="https://colab.research.google.com/github/Advait-git123/ct/blob/main/Untitled3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q kagglehub segmentation-models-pytorch albumentations opencv-python matplotlib


In [None]:
import torch, os
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
import kagglehub, os

# Download the dataset through kagglehub and remember the path it returns.
dataset_path = kagglehub.dataset_download(
    "orvile/cpaisd-acute-ischemic-stroke-dataset"
)
print("Dataset path:", dataset_path)

# List the top-level entries of the download as a quick sanity check.
print("Folders:", os.listdir(path=dataset_path))


In [38]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset

# Directory that StrokeDataset scans for cases: <dataset_path>/dataset/train
TRAIN_ROOT = os.path.join(
    dataset_path,
    "dataset",
    "train",
)

class StrokeDataset(Dataset):
    """Image/mask pairs stored as ``image.npz`` / ``mask.npz`` files under ``root``.

    Every directory below ``root`` that contains both files contributes one
    sample. Samples are sorted by path so that the subset kept by ``max_cases``
    (and the meaning of each index) is the same on every run and machine.

    Args:
        root: Directory that is searched recursively for cases.
        max_cases: Keep only the first ``max_cases`` samples (after sorting).
            A falsy value (``None`` or ``0``) keeps every sample found.

    Each item is a ``(image, mask)`` tuple of ``float32`` tensors:
        * ``image`` – the first array of ``image.npz`` min-max scaled to
          ``[0, 1]`` and replicated along a new leading axis three times.
        * ``mask`` – the first array of ``mask.npz`` binarised with ``> 0``,
          with a new leading axis of size one.
    """

    def __init__(self, root, max_cases=500):
        self.samples = []

        # Recursively collect every directory holding both archives.
        for root_dir, _dirs, files in os.walk(root):
            if "image.npz" in files and "mask.npz" in files:
                self.samples.append(
                    (
                        os.path.join(root_dir, "image.npz"),
                        os.path.join(root_dir, "mask.npz"),
                    )
                )

        # os.walk yields directories in filesystem order, which is not stable
        # across machines; sort so the max_cases truncation is reproducible.
        self.samples.sort()

        print("Total found:", len(self.samples))

        if max_cases:
            self.samples = self.samples[:max_cases]

        print("Using cases:", len(self.samples))

    def __len__(self):
        """Return the number of samples kept after the ``max_cases`` cut."""
        return len(self.samples)

    @staticmethod
    def _load_first_array(path):
        """Return the first array stored in the ``.npz`` at ``path`` as float32.

        The archive is opened in a ``with`` block so its file handle is
        released immediately instead of lingering until garbage collection.
        """
        with np.load(path) as archive:
            return archive[archive.files[0]].astype("float32")

    def __getitem__(self, idx):
        """Load, normalise and tensorise the sample at position ``idx``."""
        img_path, mask_path = self.samples[idx]

        img = self._load_first_array(img_path)
        mask = self._load_first_array(mask_path)

        # ---------- Safe normalization ----------
        # Take the intensity range from finite pixels only. Previously a single
        # NaN made min/max NaN, which turned the *whole* image into zeros.
        finite = np.isfinite(img)
        if finite.any():
            min_val = img[finite].min()
            max_val = img[finite].max()
        else:
            min_val = max_val = np.float32(0.0)

        if max_val - min_val < 1e-5:
            # (Near-)constant image: nothing to scale, return all zeros.
            img = np.zeros_like(img)
        else:
            img = (img - min_val) / (max_val - min_val)
            img[~finite] = 0.0  # NaN / inf pixels become background

        # ---------- Clean mask ----------
        # NaN compares False against 0, so NaN mask pixels become background.
        mask = (mask > 0).astype("float32")

        # ---------- Convert to model format ----------
        img = np.stack([img, img, img], axis=0)     # 3-channel
        mask = np.expand_dims(mask, 0)

        # Final guard so the tensor never carries NaN/inf into the loss.
        img = np.nan_to_num(img)

        return torch.from_numpy(img), torch.from_numpy(mask)



In [39]:
from torch.utils.data import random_split, DataLoader

# Build the dataset in this cell: the split below needs it, and relying on a
# `dataset` variable created by a later cell breaks "Restart & Run All".
dataset = StrokeDataset(TRAIN_ROOT, max_cases=500)

# 80 / 20 train / validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same samples land in train/val on every run;
# otherwise validation scores are not comparable between runs.
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_rng)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=8)


In [40]:
import segmentation_models_pytorch as smp

# U-Net with an EfficientNet-B0 encoder initialised from ImageNet weights;
# takes 3-channel input and emits a single output channel.
model = smp.Unet(encoder_name="efficientnet-b0", encoder_weights="imagenet", in_channels=3, classes=1)

# Move the parameters to the selected device (Module.to returns the module itself).
model = model.to(device)


In [41]:
from torch.utils.data import DataLoader

# Dataset capped at 500 cases, plus a shuffled loader over all of it.
# NOTE(review): the training loop in this notebook iterates `train_loader`,
# not `loader` — confirm whether `loader` is still needed.
dataset = StrokeDataset(root=TRAIN_ROOT, max_cases=500)
loader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)

# Two loss terms, both in binary mode; they are combined in the training loop.
dice, focal = smp.losses.DiceLoss(mode="binary"), smp.losses.FocalLoss(mode="binary")

# Adam over all model parameters with learning rate 5e-5.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)


Total found: 8376
Using cases: 500


In [42]:
def dice_score(pred, target, eps=1e-6, threshold=0.5):
    """Dice coefficient between a thresholded prediction and a target.

    The score is computed over *all* elements of the tensors at once (a whole
    batch is treated as one big volume), not per sample.

    Args:
        pred: Tensor of scores; elements strictly greater than ``threshold``
            count as foreground.
        target: Tensor of the same shape as ``pred``; it is used as-is (not
            thresholded), so it is expected to already hold 0/1 values.
        eps: Smoothing term added to numerator and denominator. It avoids a
            division by zero and makes the score 1 when both inputs are empty.
        threshold: Cut-off applied to ``pred``. Defaults to 0.5, the value that
            was previously hard-coded.

    Returns:
        A 0-dim tensor holding ``(2*|P∩T| + eps) / (|P| + |T| + eps)``.
    """
    binary_pred = (pred > threshold).float()
    overlap = (binary_pred * target).sum()
    return (2 * overlap + eps) / (binary_pred.sum() + target.sum() + eps)


In [None]:
# Train for 10 epochs, validate after each one, and keep the weights of the
# epoch with the highest validation Dice in "best_model.pth".
best_dice = 0

for epoch in range(10):

    # ---------- TRAIN ----------
    model.train()
    loss_sum = 0.0
    batches_used = 0  # batches that actually contributed an optimisation step

    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        preds = model(imgs)  # raw logits; the loss objects are fed these directly
        loss = 0.5 * dice(preds, masks) + 0.5 * focal(preds, masks)

        # Skip batches with a NaN *or* infinite loss: back-propagating either
        # would poison the weights with non-finite gradients.
        if not torch.isfinite(loss):
            continue

        optimizer.zero_grad()
        loss.backward()
        # Cap the global gradient norm at 1.0 to damp occasional spikes.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += loss.item()
        batches_used += 1

    skipped = len(train_loader) - batches_used
    if skipped:
        print("Epoch", epoch, "skipped", skipped, "batches with non-finite loss")

    # Average over the batches that were really used. Dividing by
    # len(train_loader) would understate the loss whenever batches are skipped.
    print("Epoch", epoch, "Train Loss:", loss_sum / max(batches_used, 1))


    # ---------- VALIDATION ----------
    model.eval()
    dice_total = 0.0

    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = torch.sigmoid(model(imgs))  # logits -> probabilities for thresholding
            dice_total += dice_score(preds, masks).item()

    # Mean of the per-batch Dice scores; max(..., 1) guards an empty loader.
    val_dice = dice_total / max(len(val_loader), 1)
    print("Epoch", epoch, "Validation Dice:", val_dice)


    # ---------- SAVE BEST MODEL ----------
    if val_dice > best_dice:
        best_dice = val_dice
        torch.save(model.state_dict(), "best_model.pth")
        print("Saved new best model")


In [None]:
# Shell out to `nvidia-smi` to show the GPU(s) visible to this runtime.
# The command fails on machines without an NVIDIA driver (e.g. CPU-only runtimes).
!nvidia-smi


In [None]:
import matplotlib.pyplot as plt

# Qualitative check: show one sample next to its ground truth and the model output.
model.eval()

img, mask = dataset[5]
with torch.no_grad():
    batch = img.unsqueeze(0).to(device)          # add a batch axis
    probabilities = torch.sigmoid(model(batch))  # logits -> probabilities
    pred = probabilities[0, 0].cpu().numpy()

fig, axes = plt.subplots(1, 3, figsize=(10, 3))
panels = (("Image", img[0]), ("GT", mask[0]), ("Prediction", pred))
for ax, (title, data) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, cmap='gray')
plt.show()
