In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        (os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import os
from glob import glob
from pathlib import Path
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    from skimage.filters import threshold_otsu
except:
    threshold_otsu = None


In [3]:
CONFIG = {
    'DATA_DIR': '/kaggle/input/x-ray-images/images',  # change this to your dataset folder
    'IMG_SIZE': 224,
    'NUM_CLASSES': 2,  # background + your object
    'BATCH_SIZE': 8,
    'EPOCHS': 30,
    'LR': 1e-4,
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
    'CHECKPOINT_DIR': './checkpoints'
}

os.makedirs(CONFIG['CHECKPOINT_DIR'], exist_ok=True)


In [4]:
def list_images(folder: str, ext: str = '.png'):
    pattern = str(Path(folder) / f"**/*{ext}")
    return sorted([p for p in glob(pattern, recursive=True)])

def read_image(path: str, size: int):
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (size, size))
    return img

def pseudo_mask_otsu(img: np.ndarray):
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    if threshold_otsu:
        t = threshold_otsu(gray)
        mask = (gray > t).astype(np.uint8)
    else:
        _, mask = cv2.threshold(gray, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8))
    return mask


In [5]:
class MedicalSegDataset(Dataset):
    def __init__(self, image_paths, img_size=224):
        self.image_paths = image_paths
        self.img_size = img_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = read_image(img_path, self.img_size)
        mask = pseudo_mask_otsu(img)
        mask = cv2.resize(mask.astype(np.uint8), (self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)

        img = torch.from_numpy(img.transpose(2,0,1)).float() / 255.0
        mask = torch.from_numpy(mask).long()
        return img, mask


In [6]:
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class UpConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            ConvBlock(in_ch, out_ch)
        )
    def forward(self, x):
        return self.up(x)

class HybridUNet(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.enc1 = ConvBlock(3, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBlock(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Transformer Bottleneck
        self.trans = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=8),
            num_layers=1
        )

        self.up4 = UpConv(512, 256)
        self.up3 = UpConv(256+256, 128)
        self.up2 = UpConv(128+128, 64)
        self.up1 = UpConv(64+64, 32)
        self.final = nn.Conv2d(32, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Transformer bottleneck
        b = e4.flatten(2).permute(2,0,1)
        b = self.trans(b)
        b = b.permute(1,2,0).view(e4.shape)

        d4 = self.up4(b)
        d3 = self.up3(torch.cat([d4, e3], dim=1))
        d2 = self.up2(torch.cat([d3, e2], dim=1))
        d1 = self.up1(torch.cat([d2, e1], dim=1))
        out = self.final(d1)
        return out


In [7]:
def dice_loss(pred, target, eps=1e-6):
    pred = F.softmax(pred, dim=1)
    target_onehot = F.one_hot(target, num_classes=pred.shape[1]).permute(0,3,1,2).float()
    inter = (pred * target_onehot).sum(dim=(2,3))
    denom = pred.sum(dim=(2,3)) + target_onehot.sum(dim=(2,3))
    dice = (2*inter + eps) / (denom + eps)
    return 1 - dice.mean()

def iou_score(pred, target, num_classes=2):
    pred_labels = pred.argmax(dim=1)
    ious = []
    for cls in range(num_classes):
        pred_c = (pred_labels == cls)
        target_c = (target == cls)
        inter = (pred_c & target_c).sum().item()
        union = (pred_c | target_c).sum().item()
        ious.append(inter / union if union > 0 else 1.0)
    return np.mean(ious)


In [8]:
images = list_images(CONFIG['DATA_DIR'], ext='.png')
n = len(images)
split = int(0.8 * n)
train_imgs = images[:split]
val_imgs = images[split:]

train_ds = MedicalSegDataset(train_imgs, img_size=CONFIG['IMG_SIZE'])
val_ds = MedicalSegDataset(val_imgs, img_size=CONFIG['IMG_SIZE'])

train_loader = DataLoader(train_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=2)


In [9]:
device = CONFIG['DEVICE']
model = HybridUNet(num_classes=CONFIG['NUM_CLASSES']).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['LR'])

best_iou = 0.0

for epoch in range(1, CONFIG['EPOCHS']+1):
    model.train()
    total_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(imgs)
        loss = F.cross_entropy(logits, masks) + dice_loss(logits, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)

    # Validation
    model.eval()
    val_loss = 0.0
    val_iou = 0.0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            logits = model(imgs)
            val_loss += (F.cross_entropy(logits, masks) + dice_loss(logits, masks)).item() * imgs.size(0)
            val_iou += iou_score(logits, masks, CONFIG['NUM_CLASSES']) * imgs.size(0)
    val_loss /= len(val_loader.dataset)
    val_iou /= len(val_loader.dataset)

    print(f"Epoch {epoch}/{CONFIG['EPOCHS']}  Train Loss: {total_loss/len(train_loader.dataset):.4f}  Val Loss: {val_loss:.4f}  Val IoU: {val_iou:.4f}")

    if val_iou > best_iou:
        best_iou = val_iou
        torch.save(model.state_dict(), os.path.join(CONFIG['CHECKPOINT_DIR'], 'best_model.pth'))




RuntimeError: input and target batch or spatial sizes don't match: target [8, 224, 224], input [8, 2, 448, 448]

In [None]:
def predict_and_visualize(model, img_path, img_size=224, device='cpu'):
    model.eval()
    img = read_image(img_path, img_size)
    inp = torch.from_numpy(img.transpose(2,0,1)).float().unsqueeze(0).to(device)/255.0
    with torch.no_grad():
        logits = model(inp)
        pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()
    
    overlay = (pred / pred.max() * 255).astype(np.uint8)
    overlay = cv2.applyColorMap(overlay, cv2.COLORMAP_JET)
    base = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    blended = cv2.addWeighted(base, 0.6, overlay, 0.4, 0)
    return blended
