In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        (os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
# Run this cell if packages missing (Kaggle/Colab)
!pip install timm==0.9.2 albumentations==1.4.3 opencv-python scikit-image


Collecting timm==0.9.2
  Downloading timm-0.9.2-py3-none-any.whl.metadata (68 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.5/68.5 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting albumentations==1.4.3
  Downloading albumentations-1.4.3-py3-none-any.whl.metadata (37 kB)
Collecting scikit-learn>=1.3.2 (from albumentations==1.4.3)
  Downloading scikit_learn-1.8.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting numpy>=1.24.4 (from albumentations==1.4.3)
  Downloading numpy-2.2.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.7->timm==0.9.2)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.7->timm==0

In [3]:
# Cell 2
import os, sys, time
from glob import glob
from pathlib import Path
from typing import List, Tuple, Optional
import numpy as np
import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from skimage.filters import threshold_otsu




In [4]:
# Cell 3 - EDIT THESE
CONFIG = {
    'DATA_DIR': '/kaggle/input/x-ray-images/images',  # change to your dataset root
    'IMG_EXTS': ('.png', '.jpg', '.jpeg'),
    'IMG_SIZE': 256,
    'NUM_CLASSES': 4,   # include background, e.g., 0=bg,1=class1,2=class2,3=class3
    'BATCH_SIZE': 8,
    'EPOCHS': 20,
    'LR': 1e-4,
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
    'BACKBONE': 'swin_tiny_patch4_window7_224',  # smaller by default
    'PRETRAINED': True,
    'CHECKPOINT_DIR': './checkpoints',
    'MODE': 'auto_mask'   # we will autogenerate masks
}
os.makedirs(CONFIG['CHECKPOINT_DIR'], exist_ok=True)
# Where to save generated masks
GENERATED_MASKS_DIR = os.path.join(CONFIG['DATA_DIR'], 'masks_auto')
os.makedirs(GENERATED_MASKS_DIR, exist_ok=True)


OSError: [Errno 30] Read-only file system: '/kaggle/input/x-ray-images/images/masks_auto'

In [None]:
# Cell 4
def collect_images(root_dir: str, exts=('.png','.jpg','.jpeg')) -> List[str]:
    p = Path(root_dir)
    files = []
    for ext in exts:
        files += [str(x) for x in p.rglob(f'*{ext}')]
    files = sorted(files)
    return files

def read_image(path: str, size: int) -> np.ndarray:
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
    return img


In [None]:
# Cell 5
from sklearn.cluster import KMeans

def generate_pseudo_mask(img: np.ndarray, num_classes: int) -> np.ndarray:
    """
    Returns HxW mask with integer labels in [0, num_classes-1].
    Uses Otsu for coarse foreground/background, then kmeans on color for multiclass.
    """
    h, w, _ = img.shape
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    try:
        t = threshold_otsu(gray)
        bin_mask = (gray > t).astype(np.uint8)
    except Exception:
        _, bin_mask = cv2.threshold(gray, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # clean
    kernel = np.ones((5,5), np.uint8)
    clean = cv2.morphologyEx(bin_mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
    clean = cv2.morphologyEx(clean, cv2.MORPH_CLOSE, kernel)

    if num_classes == 2:
        return clean.astype(np.uint8)

    # For multi-class, run KMeans on RGB of foreground pixels + assign labels
    Z = img.reshape(-1, 3).astype(np.float32)
    K = num_classes
    # if dataset small, KMeans can be slow; use sample
    sample_idx = np.random.choice(Z.shape[0], min(20000, Z.shape[0]), replace=False)
    km = KMeans(n_clusters=K, n_init=3, random_state=42)
    km.fit(Z[sample_idx])

    labels = km.predict(Z).reshape(h, w).astype(np.uint8)

    # choose cluster with lowest mean intensity as background
    centers = km.cluster_centers_
    center_means = centers.mean(axis=1)
    bg_cluster = int(np.argmin(center_means))

    remapped = np.zeros_like(labels)
    cur = 1
    for c in range(K):
        if c == bg_cluster:
            remapped[labels==c] = 0
        else:
            remapped[labels==c] = cur
            cur += 1
            if cur >= num_classes:
                # wrap if more clusters than expected
                cur = num_classes - 1
    return remapped.astype(np.uint8)

def save_mask(mask: np.ndarray, out_path: str):
    # save mask as single-channel PNG with values 0..(C-1)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # ensure dtype uint8
    cv2.imwrite(out_path, mask.astype(np.uint8))


In [None]:
# Cell 6
images = collect_images(CONFIG['DATA_DIR'], CONFIG['IMG_EXTS'])
print(f"Found {len(images)} images")

# If images are inside e.g., /Training/glioma/..., they will be included.
# We'll generate masks into GENERATED_MASKS_DIR preserving filenames.
for i, img_path in enumerate(images):
    fname = os.path.basename(img_path)
    out_mask = os.path.join(GENERATED_MASKS_DIR, fname)
    if os.path.exists(out_mask):
        if i % 200 == 0:
            print(f"[{i}] mask exists, skip: {fname}")
        continue
    img = read_image(img_path, CONFIG['IMG_SIZE'])
    mask = generate_pseudo_mask(img, CONFIG['NUM_CLASSES'])
    save_mask(mask, out_mask)
    if i % 200 == 0:
        print(f"[{i}] saved mask: {fname}")
print("Mask generation finished. Masks saved to:", GENERATED_MASKS_DIR)


In [None]:
# Cell 7
class MedicalSegDataset(Dataset):
    def __init__(self, image_paths: List[str], mask_dir: str, img_size:int=256, transform=None):
        self.image_paths = image_paths
        self.mask_dir = mask_dir
        self.img_size = img_size
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        fname = os.path.basename(img_path)
        mask_path = os.path.join(self.mask_dir, fname)

        img = read_image(img_path, self.img_size)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # fallback to zeros
            mask = np.zeros((self.img_size, self.img_size), dtype=np.uint8)

        if self.transform:
            aug = self.transform(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']
        else:
            img = torch.from_numpy(img.transpose(2,0,1)).float() / 255.0
            mask = torch.from_numpy(mask).long()

        return img, mask


In [None]:
# Cell 8
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self,x): return self.conv(x)

class UpConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(in_ch, out_ch)
        )
    def forward(self,x): return self.up(x)

class SwinGRUSegmenter(nn.Module):
    def __init__(self, backbone_name='swin_tiny_patch4_window7_224', pretrained=True, num_classes=4):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True, out_indices=(0,1,2,3))
        feats = self.backbone.feature_info.channels()
        self.reduce4 = nn.Conv2d(feats[-1], 512, 1)
        self.reduce3 = nn.Conv2d(feats[-2], 256, 1)
        self.reduce2 = nn.Conv2d(feats[-3], 128, 1)
        self.reduce1 = nn.Conv2d(feats[-4], 64, 1)
        self.gru_in_dim = feats[-2]
        self.gru = nn.GRU(input_size=self.gru_in_dim, hidden_size=256, batch_first=True)
        self.gru_proj = nn.Linear(256, 256)
        self.dec4 = ConvBlock(512,256)
        self.up3 = UpConv(256+256, 128)
        self.up2 = UpConv(128+128, 64)
        self.up1 = UpConv(64+64, 32)
        self.final_conv = nn.Sequential(nn.Conv2d(32,32,3,padding=1), nn.ReLU(inplace=True), nn.Conv2d(32, CONFIG['NUM_CLASSES'],1))

    def forward(self,x):
        feats = self.backbone(x)
        s1,s2,s3,s4 = feats
        r4 = self.reduce4(s4); r3 = self.reduce3(s3); r2 = self.reduce2(s2); r1=self.reduce1(s1)
        d4 = self.dec4(r4)
        b,c,h,w = r3.shape
        tokens = r3.view(b,c,h*w).permute(0,2,1)
        gru_out, _ = self.gru(tokens)
        pooled = gru_out.mean(dim=1)
        pooled = self.gru_proj(pooled).unsqueeze(-1).unsqueeze(-1)
        pooled = pooled.expand(-1,-1,d4.shape[2],d4.shape[3])
        d4 = d4 + pooled
        u3 = F.interpolate(d4, scale_factor=2, mode='bilinear', align_corners=False)
        u3 = torch.cat([u3, r3], dim=1); u3 = self.up3(u3)
        u2 = F.interpolate(u3, scale_factor=2, mode='bilinear', align_corners=False)
        u2 = torch.cat([u2, r2], dim=1); u2 = self.up2(u2)
        u1 = F.interpolate(u2, scale_factor=2, mode='bilinear', align_corners=False)
        u1 = torch.cat([u1, r1], dim=1); u1 = self.up1(u1)
        out = self.final_conv(u1)
        return out


In [None]:
# Cell 9
def dice_loss(pred, target, eps=1e-6):
    pred = F.softmax(pred, dim=1)
    target_onehot = F.one_hot(target, num_classes=pred.shape[1]).permute(0,3,1,2).float()
    inter = (pred * target_onehot).sum(dim=(2,3))
    denom = pred.sum(dim=(2,3)) + target_onehot.sum(dim=(2,3))
    dice = (2*inter + eps)/(denom + eps)
    return 1 - dice.mean()

def iou_score(pred, target, num_classes):
    pred_labels = pred.argmax(dim=1)
    ious = []
    for cls in range(num_classes):
        pred_c = (pred_labels == cls)
        target_c = (target == cls)
        inter = (pred_c & target_c).sum().item()
        union = (pred_c | target_c).sum().item()
        if union == 0:
            ious.append(1.0)
        else:
            ious.append(inter/union)
    return np.mean(ious)

# transforms
train_tf = A.Compose([A.Resize(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']), A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.3), A.Normalize(), ToTensorV2()])
val_tf = A.Compose([A.Resize(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']), A.Normalize(), ToTensorV2()])

# collect images and create dataset objects
all_images = collect_images(CONFIG['DATA_DIR'], CONFIG['IMG_EXTS'])
# ensure masks are generated in GENERATED_MASKS_DIR
train_n = int(0.8 * len(all_images))
np.random.seed(42)
idxs = np.random.permutation(len(all_images))
train_imgs = [all_images[i] for i in idxs[:train_n]]
val_imgs = [all_images[i] for i in idxs[train_n:]]

train_ds = MedicalSegDataset(train_imgs, GENERATED_MASKS_DIR, img_size=CONFIG['IMG_SIZE'], transform=train_tf)
val_ds   = MedicalSegDataset(val_imgs, GENERATED_MASKS_DIR, img_size=CONFIG['IMG_SIZE'], transform=val_tf)

train_loader = DataLoader(train_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=4, pin_memory=True)


In [None]:
# Cell 10
device = CONFIG['DEVICE']
model = SwinGRUSegmenter(backbone_name=CONFIG['BACKBONE'], pretrained=CONFIG['PRETRAINED'], num_classes=CONFIG['NUM_CLASSES']).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LR'])

def train_one_epoch():
    model.train()
    total_loss = 0.0
    for imgs, masks in train_loader:
        imgs = imgs.to(device); masks = masks.to(device)
        logits = model(imgs)
        ce = F.cross_entropy(logits, masks)
        d = dice_loss(logits, masks)
        loss = ce + d
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item() * imgs.size(0)
    return total_loss / len(train_loader.dataset)

def validate_epoch():
    model.eval()
    total_loss = 0.0; total_iou = 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs = imgs.to(device); masks = masks.to(device)
            logits = model(imgs)
            ce = F.cross_entropy(logits, masks)
            d = dice_loss(logits, masks)
            loss = ce + d
            total_loss += loss.item() * imgs.size(0)
            total_iou += iou_score(logits, masks, CONFIG['NUM_CLASSES']) * imgs.size(0)
    return total_loss / len(val_loader.dataset), total_iou / len(val_loader.dataset)

best_iou = 0.0
for epoch in range(1, CONFIG['EPOCHS']+1):
    t0 = time.time()
    train_loss = train_one_epoch()
    val_loss, val_iou = validate_epoch()
    t1 = time.time()
    print(f"Epoch {epoch}/{CONFIG['EPOCHS']} time {t1-t0:.1f}s train_loss {train_loss:.4f} val_loss {val_loss:.4f} val_iou {val_iou:.4f}")
    ckpt = {'epoch':epoch, 'model':model.state_dict(), 'optim':optimizer.state_dict(), 'val_iou':val_iou}
    torch.save(ckpt, os.path.join(CONFIG['CHECKPOINT_DIR'], f'epoch_{epoch}.pth'))
    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(ckpt, os.path.join(CONFIG['CHECKPOINT_DIR'], f'best.pth'))
print("Training finished. Best val IoU:", best_iou)


In [None]:
# Cell 11
import matplotlib.pyplot as plt

def predict_mask(model, img_path):
    model.eval()
    img = read_image(img_path, CONFIG['IMG_SIZE'])
    inp = torch.from_numpy(img.transpose(2,0,1)).float().unsqueeze(0)/255.0
    inp = inp.to(device)
    with torch.no_grad():
        logits = model(inp)
        pred = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    return img, pred

# show 4 samples from val set
for i in range(4):
    p = val_imgs[i]
    img, pred = predict_mask(model, p)
    plt.figure(figsize=(8,4))
    plt.subplot(1,3,1); plt.imshow(img); plt.title("Image"); plt.axis('off')
    plt.subplot(1,3,2); plt.imshow(pred, cmap='jet'); plt.title("Pred mask"); plt.axis('off')
    # load generated ground-truth mask if exists
    gt = cv2.imread(os.path.join(GENERATED_MASKS_DIR, os.path.basename(p)), cv2.IMREAD_GRAYSCALE)
    plt.subplot(1,3,3); plt.imshow(gt, cmap='gray'); plt.title("Pseudo GT mask"); plt.axis('off')
    plt.show()
