In [1]:
# ===============================
# tpu_resnet_multilabel_training.py
# ===============================

import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from tqdm import tqdm

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# -------------------------
# Label columns
# -------------------------
LABEL_COLUMNS = [
    'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum',
    'Fracture', 'Lung Lesion', 'Lung Opacity', 'No Finding', 'Pleural Effusion',
    'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
]

# -------------------------
# Dataset (3-channel images with augmentation)
# -------------------------
class XRayDataset(Dataset):
    def __init__(self, df, image_dir, img_size=(512,512), augment=False):
        self.df = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.img_size = img_size
        self.augment = augment

        if augment:
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.485, 0.485], [0.229, 0.229, 0.229])
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize(img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.485, 0.485], [0.229, 0.229, 0.229])
            ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row['Image_name'])
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            img = np.zeros((self.img_size[0], self.img_size[1], 3), dtype=np.uint8)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(img)
        labels = row[LABEL_COLUMNS].values.astype(np.float32)
        return img, torch.tensor(labels, dtype=torch.float32)

# -------------------------
# Weighted Focal Loss
# -------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        BCE_loss = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        if self.alpha is not None:
            BCE_loss = self.alpha.to(inputs.device) * BCE_loss
        F_loss = (1-pt)**self.gamma * BCE_loss
        if self.reduction == 'mean':
            return F_loss.mean()
        else:
            return F_loss.sum()

# -------------------------
# Hybrid Loss (Focal + BCE)
# -------------------------
class HybridLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, focal_weight=0.5):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=gamma)
        self.bce = nn.BCEWithLogitsLoss(pos_weight=alpha)
        self.focal_weight = focal_weight

    def forward(self, inputs, targets):
        focal_loss = self.focal(inputs, targets)
        bce_loss = self.bce(inputs, targets)
        return self.focal_weight * focal_loss + (1 - self.focal_weight) * bce_loss

# -------------------------
# Main Training
# -------------------------
if __name__ == "__main__":
    print("================================================")
    # -------------------------
    # Paths
    # -------------------------
    TRAIN_CSV = "/kaggle/input/grand-xray-slam-division-b/train2.csv"
    TRAIN_DIR = "/kaggle/input/grand-xray-slam-division-b/train2"

    # -------------------------
    # Load CSV and convert labels to numeric
    # -------------------------
    train_df = pd.read_csv(TRAIN_CSV)
    train_df[LABEL_COLUMNS] = train_df[LABEL_COLUMNS].apply(pd.to_numeric, errors='coerce').fillna(0)

    # -------------------------
    # Compute class weights
    # -------------------------
    pos_counts = train_df[LABEL_COLUMNS].sum()
    neg_counts = len(train_df) - pos_counts
    class_weights = torch.tensor((neg_counts / (pos_counts + 1e-6)).values, dtype=torch.float32)

    # -------------------------
    # Dataset and DataLoader
    # -------------------------
    dataset = XRayDataset(train_df, TRAIN_DIR, img_size=(224,224), augment=True)
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)  # adjusted for 512x512

    # -------------------------
    # TPU Device and Loader
    # -------------------------
    device = xm.xla_device()
    loader = pl.MpDeviceLoader(loader, device)
    print("Using TPU device:", device)

    # -------------------------
    # Model
    # -------------------------
    from torchvision.models import resnet152, ResNet152_Weights
    model = resnet152(weights=ResNet152_Weights.IMAGENET1K_V2)
    model.fc = nn.Linear(2048, len(LABEL_COLUMNS))
    model.to(device)

    # -------------------------
    # Optimizer & Loss
    # -------------------------
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    criterion = HybridLoss(alpha=class_weights, focal_weight=0.6)

    # -------------------------
    # Training loop
    # -------------------------
    num_epochs = 5

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        tqdm_loader = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")

        for imgs, labels in tqdm_loader:
            imgs, labels = imgs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            xm.optimizer_step(optimizer)

            total_loss += loss.item()
            tqdm_loader.set_postfix({"loss": total_loss/(tqdm_loader.n+1)})

        print(f"Epoch {epoch+1} completed. Avg Loss: {total_loss/len(loader):.4f}")

        import numpy as np
        import torch
        from sklearn.metrics import roc_auc_score
        all_labels = torch.cat(tuple(labels), dim=0).cpu().numpy()
        all_preds = torch.cat(tuple(outputs), dim=0).detach().cpu().numpy()
        try:
            auc = roc_auc_score(all_labels, all_preds, average='macro')
        except ValueError:
            auc = float('nan')

        # Print results
        print(f"\nEpoch {epoch + 1}/{num_epochs} Summary:")
        print(f"Avg Loss: {total_loss / len(loader):.4f}")
        print(f"ROC AUC: {auc:.4f}\n")

    print("Training completed on TPU cores!")






  device = xm.xla_device()


Using TPU device: xla:0


E0000 00:00:1759972506.158885      10 common_lib.cc:648] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


Downloading: "https://download.pytorch.org/models/resnet152-f82ba261.pth" to /root/.cache/torch/hub/checkpoints/resnet152-f82ba261.pth


100%|██████████| 230M/230M [00:00<00:00, 256MB/s] 
Epoch 1/5: 100%|██████████| 1696/1696 [31:57<00:00,  1.13s/batch, loss=0.517] 


Epoch 1 completed. Avg Loss: 0.5170

Epoch 1/5 Summary:
Avg Loss: 0.5170
ROC AUC: 0.8761



Epoch 2/5: 100%|██████████| 1696/1696 [20:55<00:00,  1.35batch/s, loss=0.465]


Epoch 2 completed. Avg Loss: 0.4649

Epoch 2/5 Summary:
Avg Loss: 0.4649
ROC AUC: 0.9179



Epoch 3/5: 100%|██████████| 1696/1696 [21:24<00:00,  1.32batch/s, loss=0.448]


Epoch 3 completed. Avg Loss: 0.4479

Epoch 3/5 Summary:
Avg Loss: 0.4479
ROC AUC: 0.8935



Epoch 4/5: 100%|██████████| 1696/1696 [21:43<00:00,  1.30batch/s, loss=0.437]  


Epoch 4 completed. Avg Loss: 0.4366

Epoch 4/5 Summary:
Avg Loss: 0.4366
ROC AUC: 0.9688



Epoch 5/5: 100%|██████████| 1696/1696 [21:26<00:00,  1.32batch/s, loss=0.427]

Epoch 5 completed. Avg Loss: 0.4268

Epoch 5/5 Summary:
Avg Loss: 0.4268
ROC AUC: 0.9532

Training completed on TPU cores!





In [2]:
# =========================
# Inference & Submission
# =========================
if __name__ == "__main__":
    # -------------------------
    # Test Set Path
    # -------------------------
    TEST_DIR = "/kaggle/input/grand-xray-slam-division-b/test2"

    # -------------------------
    # Create Test Dataset
    # -------------------------
    class TestDataset(Dataset):
        def __init__(self, image_dir, img_size=(224,224)):
            self.image_dir = image_dir
            self.img_size = img_size
            self.images = sorted([f for f in os.listdir(image_dir) if f.endswith(".jpg")])
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.485, 0.485], [0.229, 0.229, 0.229])
            ])

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            img_name = self.images[idx]
            img_path = os.path.join(self.image_dir, img_name)
            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if img is None:
                img = np.zeros((self.img_size[0], self.img_size[1], 3), dtype=np.uint8)
            img = cv2.resize(img, self.img_size)
            img = img.astype(np.float32) / 255.0
            img = self.transform(img)
            return img_name, img

    # -------------------------
    # Test Loader
    # -------------------------
    test_dataset = TestDataset(TEST_DIR)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
    test_loader = pl.MpDeviceLoader(test_loader, device)

    # -------------------------
    # Inference
    # -------------------------
    model.eval()
    submission = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Inference"):
            img_names, imgs = batch
            imgs = imgs.to(device)
            outputs = model(imgs)
            probs = torch.sigmoid(outputs).cpu().numpy()

            for name, prob in zip(img_names, probs):
                row = [name] + prob.tolist()
                submission.append(row)

    # -------------------------
    # Save Submission
    # -------------------------
    submission_df = pd.DataFrame(submission, columns=["Image_name"] + LABEL_COLUMNS)
    SUBMISSION_CSV = "/kaggle/working/submission.csv"
    submission_df.to_csv(SUBMISSION_CSV, index=False)
    print(f"Submission file saved to {SUBMISSION_CSV}")

Inference: 100%|██████████| 749/749 [09:46<00:00,  1.28it/s]


Submission file saved to /kaggle/working/submission.csv


In [3]:
submission_df

Unnamed: 0,Image_name,Atelectasis,Cardiomegaly,Consolidation,Edema,Enlarged Cardiomediastinum,Fracture,Lung Lesion,Lung Opacity,No Finding,Pleural Effusion,Pleural Other,Pneumonia,Pneumothorax,Support Devices
0,00000002_001_001.jpg,0.588338,0.533304,0.420442,0.383140,0.618749,0.535954,0.318367,0.631151,0.220034,0.499770,0.295218,0.389012,0.324956,0.480835
1,00000002_001_002.jpg,0.537321,0.480841,0.522963,0.443063,0.541917,0.624062,0.493039,0.627037,0.130759,0.477923,0.347875,0.489386,0.430935,0.577663
2,00000002_002_001.jpg,0.477501,0.475001,0.377278,0.369730,0.495254,0.528144,0.330540,0.522241,0.155403,0.426558,0.301524,0.359055,0.359818,0.495589
3,00000008_001_001.jpg,0.549936,0.436892,0.395875,0.345480,0.540365,0.595596,0.374730,0.550680,0.058653,0.463184,0.240875,0.338732,0.613156,0.759055
4,00000008_002_001.jpg,0.522037,0.371948,0.347818,0.280848,0.438411,0.601109,0.321890,0.461097,0.049066,0.445307,0.206535,0.317135,0.650601,0.741227
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
47922,20009239_009_000.jpg,0.459558,0.054689,0.335621,0.038043,0.019514,0.183478,0.425327,0.383629,0.290678,0.400904,0.490137,0.216664,0.536702,0.019602
47923,20009239_010_000.jpg,0.404337,0.032144,0.279932,0.060307,0.002514,0.017291,0.339331,0.346398,0.410858,0.458241,0.361558,0.111223,0.404965,0.001718
47924,20009239_011_000.jpg,0.442121,0.021986,0.230720,0.033964,0.000618,0.005011,0.340831,0.304567,0.444894,0.350364,0.309063,0.063857,0.453756,0.000833
47925,20009239_012_000.jpg,0.431989,0.047798,0.256906,0.052510,0.006022,0.039037,0.372614,0.347557,0.383492,0.396977,0.401870,0.138498,0.480332,0.004249
