In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file mounted under the Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        full_path = os.path.join(root_dir, file_name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## 📊 Exploratory Data Analysis (EDA)

In this section, we perform initial exploration of the dataset to understand its structure, completeness, and label distribution.  

---

### 1. Dataset Overview
We start by loading the training CSV file and previewing the first 10 rows.

- **Total Images** → Total number of X-ray images available.  
- **Total Patients** → Unique patient identifiers present in the dataset.  
- **Total Studies** → Number of unique studies conducted.  

This gives a quick sense of dataset size and coverage.

---

### 2. Handling Missing Values
- **Age** → Missing ages are imputed with the **median** age.  
- **Sex** → Missing sex values are filled with `"Unknown"`.  

This ensures no missing values remain in critical demographic columns.

---

### 3. Label Columns
We define 14 condition labels for classification:  

- **No Finding**, **Lung Opacity**, **Support Devices**, **Atelectasis**,  
- **Cardiomegaly**, **Pleural Effusion**, **Enlarged Cardiomediastinum**,  
- **Edema**, **Consolidation**, **Pneumonia**, **Fracture**,  
- **Lung Lesion**, **Pneumothorax**, **Pleural Other**.  

For each condition, we calculate:  
- **Count** → Number of images with the condition.  
- **Percent (%)** → Prevalence as a percentage of the dataset.  

This helps us understand class imbalance and disease prevalence.

---

### 4. Data Quality Checks
To ensure data integrity, we check for the following:

- **Duplicate Images** → Verifies if any images are repeated.  
- **Duplicate Patients** → Expected, since a patient can have multiple images.  
- **Invalid Age Values** → Counts number of negative ages (should be zero).  

---

### 5. Outputs
- **Summary metrics** (images, patients, studies).  
- **Prevalence table** showing condition counts and percentages.  
- **Data quality reports** on duplicates and invalid values.  

This forms the foundation of our dataset understanding before moving into deeper analysis and modeling.


In [None]:
# Read the training metadata table and preview its first 10 rows.
train_df = pd.read_csv(filepath_or_buffer='/kaggle/input/grand-xray-slam-division-b/train2.csv')
train_df.head(n=10)

In [None]:
# Column dtypes, non-null counts and memory usage of the raw metadata table.
train_df.info()

In [None]:
# Headline dataset size metrics; total_images is reused for the prevalence percentages.
total_images = train_df.shape[0]
total_patients = train_df['Patient_ID'].nunique()
total_studies = train_df['Study'].nunique()

for metric_name, metric_value in [
    ("Total Images", total_images),
    ("Total Patients", total_patients),
    ("Total Studies", total_studies),
]:
    print(f"{metric_name}: {metric_value}")

In [None]:
# Number of missing values per column, before any imputation.
train_df.isna().sum()

In [None]:
# Impute demographics: the median age for Age, an explicit "Unknown" category for Sex.
# NOTE(review): the median is computed over all rows, i.e. before the train/val split.
train_df = train_df.fillna({'Age': train_df['Age'].median(), 'Sex': 'Unknown'})

In [None]:
# Re-count missing values after imputation and enforce that the imputed
# demographic columns are now complete (the markdown above promises this).
missing_after_imputation = train_df.isnull().sum()
assert missing_after_imputation[['Age', 'Sex']].eq(0).all(), (
    "Age/Sex still contain missing values after imputation"
)
missing_after_imputation

In [None]:
# The 14 condition labels used as multi-label targets throughout the notebook.
label_columns = ['No Finding', 'Lung Opacity', 'Support Devices', 'Atelectasis',
                 'Cardiomegaly', 'Pleural Effusion', 'Enlarged Cardiomediastinum',
                 'Edema', 'Consolidation', 'Pneumonia', 'Fracture', 'Lung Lesion',
                 'Pneumothorax', 'Pleural Other']

# Per-condition positive counts and their share of all images (rounded to 2 dp).
label_counts = train_df[label_columns].sum()
label_percentages = label_counts.div(total_images).mul(100).round(2)

# One row per condition: name, count, percentage.
prevalence_df = (
    pd.concat({'Count': label_counts, 'Percent (%)': label_percentages}, axis=1)
    .rename_axis('Condition')
    .reset_index()
)

print("Label Prevalence:")
print(prevalence_df)

In [None]:
# Each row should reference a distinct image file; repeats indicate duplicated records.
duplicate_images = train_df['Image_name'].duplicated().sum()
print(f"Duplicated Image_Name entries: {duplicate_images}")

# Repeated Patient_IDs are expected: one patient can contribute several images.
duplicate_patients = total_images - total_patients
print(f"Duplicated Patient_ID entries: {duplicate_patients}")

# Negative ages are impossible and would point to corrupt metadata.
age_values = train_df['Age'].dropna()
invalid_ages = (age_values < 0).sum()
print(f"Invalid Age values (<0): {invalid_ages}")

## 🧩 Train–Validation Split

To ensure robust evaluation, we carefully split the dataset into **training** and **validation** subsets.  

---

### 1. Why Grouped Splitting?
A key challenge in medical imaging datasets is **patient-level data leakage**.  
If images from the same patient appear in both training and validation sets, the model may learn patient-specific features instead of generalizable disease patterns.  

✅ To prevent this, we use **GroupShuffleSplit** with `Patient_ID` as the grouping variable.  
This guarantees that all images from a single patient are restricted to **either training or validation**, never both.  

---

### 2. Split Details
- **Split ratio** → 80% training, 20% validation.  
- **Random state** fixed at `42` for reproducibility.  
- Splitting performed only once (`n_splits=1`).  

---

### 3. Dataset Sizes
After the split:  
- **Train set size** → Number of images used for model training.  
- **Validation set size** → Number of images held out for unbiased performance evaluation.  

This ensures that the model’s validation metrics (loss and ROC-AUC) reflect real-world generalization ability rather than memorization of individual patients.

---


In [None]:
import os, random
from collections import Counter
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

from sklearn.model_selection import GroupShuffleSplit

# Seed every RNG the training pipeline touches (shuffling, augmentation, dropout,
# weight init of the new head) so runs are reproducible, matching the split's seed.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Important: split by Patient_ID so same patient never leaks into train+val
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
train_idx, val_idx = next(gss.split(train_df, groups=train_df["Patient_ID"]))

train_df_split = train_df.iloc[train_idx].reset_index(drop=True)
val_df_split   = train_df.iloc[val_idx].reset_index(drop=True)

# Enforce the no-leakage guarantee explicitly: a patient present in both splits
# would let the model memorise patient-specific features and inflate val scores.
shared_patients = set(train_df_split["Patient_ID"]).intersection(val_df_split["Patient_ID"])
assert not shared_patients, f"{len(shared_patients)} patients appear in both train and val"

print(f"Train size: {len(train_df_split)} | Val size: {len(val_df_split)}")


## 🗂️ Custom Dataset Class: `ChestXrayDataset`

To train deep learning models on medical images, we define a **custom PyTorch Dataset** that handles loading images, applying transformations, and returning labels in a structured format.

---

### 1. Purpose
The dataset class:
- Loads **X-ray images** directly from the dataset directory.  
- Converts every image to grayscale; further preprocessing (resizing, augmentation, normalization) is applied through the optional `transform`.  
- Returns both **image tensors** and their associated **multi-label targets**.  

This provides a clean pipeline for training and validation.

---

### 2. Key Components

#### **Initialization (`__init__`)**
- Accepts a dataframe (`dataframe`) containing metadata and image names.  
- Stores image directory (`img_dir`) for locating files.  
- Applies optional transformations (`transform`) for augmentation and normalization.  
- Extracts label columns (`label_columns`) and stores them as a NumPy float32 array.  

#### **Dataset Length (`__len__`)**
- Returns the number of samples (rows in the dataframe).  

#### **Fetching a Sample (`__getitem__`)**
- Retrieves an image name and constructs its full path.  
- Loads the image with **PIL** and converts it to **grayscale** (`"L"`), since chest X-rays are single-channel.  
- Applies transformations if provided.  
- Returns:
  - `(image, labels)` → when labels exist (train/validation).  
  - `(image, img_name)` → for test set (no labels).  

---

### 3. Advantages
- 🔄 **Reusability** → Works for train, validation, and test splits.  
- 🩻 **Medical image ready** → Handles grayscale conversion properly.  
- ⚡ **Integration** → Compatible with PyTorch `DataLoader` for efficient batching and shuffling.  

This dataset class forms the backbone of our training pipeline, ensuring images and labels are consistently processed.


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class ChestXrayDataset(Dataset):
    """Map-style PyTorch dataset over chest X-ray images listed in a dataframe.

    Items are ``(image, labels)`` when ``label_columns`` is given (train/val),
    otherwise ``(image, image_name)`` (e.g. an unlabeled test set).

    Args:
        dataframe: Metadata table with one row per image. The column named by
            ``image_col`` holds file names relative to ``img_dir``.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to the grayscale PIL image
            (e.g. a ``torchvision.transforms.Compose``).
        label_columns: Optional list of label column names. Their values are
            cached once as a float32 array with one row per sample.
        image_col: Name of the column holding image file names
            (defaults to ``"Image_name"``).
    """

    def __init__(self, dataframe, img_dir, transform=None, label_columns=None,
                 image_col="Image_name"):
        # Reset the index so positional idx from the DataLoader matches .loc lookups.
        self.df = dataframe.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.label_columns = label_columns
        self.image_col = image_col

        # Cache labels as a float32 matrix so __getitem__ does no pandas work for them.
        # NOTE(review): NaN or uncertainty codes in the label columns are passed through
        # unchanged; confirm the label encoding is strictly 0/1.
        if self.label_columns is not None:
            self.targets = self.df[list(self.label_columns)].to_numpy(dtype=np.float32)
        else:
            self.targets = None

    def __len__(self):
        """Number of samples (rows of the dataframe)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, convert and transform one image; return it with its labels or name."""
        img_name = self.df.loc[idx, self.image_col]
        img_path = os.path.join(self.img_dir, img_name)

        # Open inside a context manager so the file handle is released immediately.
        # .convert() materialises a new in-memory image, so it stays valid after the
        # file is closed. Without this, handles linger until GC, which can exhaust
        # file descriptors with many DataLoader workers.
        with Image.open(img_path) as img:
            image = img.convert("L")  # X-ray = single-channel grayscale

        if self.transform is not None:
            image = self.transform(image)

        if self.targets is None:
            return image, img_name  # test set: no labels
        labels = torch.tensor(self.targets[idx], dtype=torch.float32)
        return image, labels


## ⚙️ Data Pipeline Setup

Now that we have our custom dataset class, we define the **data preprocessing pipeline** and prepare dataloaders for training and validation.

---

### 1. Image Transformations
We use **torchvision transforms** to preprocess and augment images before feeding them into the model.

- **Train Transformations**
  - `Resize((224, 224))` → standardize input size for CNNs.  
  - `RandomHorizontalFlip()` → simulates left/right orientation changes.  
  - `RandomRotation(8)` → random rotation within ±8° for slight rotational variation.  
  - `ColorJitter()` → adds minor brightness/contrast shifts.  
  - `Grayscale(num_output_channels=3)` → converts single-channel X-ray to **3-channel grayscale** (so pretrained CNNs like ResNet can process them).  
  - `ToTensor()` → converts to PyTorch tensor.  
  - `Normalize(mean, std)` → applies ImageNet normalization.  

- **Validation Transformations**
  - Only resizing, grayscale conversion, tensor conversion, and normalization.  
  - ❌ No random augmentations → ensures consistent validation results.

---

### 2. Dataset Creation
We wrap the preprocessed data into our custom `ChestXrayDataset`:

- **`train_ds`** → Training split with augmentations.  
- **`val_ds`** → Validation split with minimal preprocessing.  

This makes the dataset ready for PyTorch `DataLoader`.

---

### 3. Dataloaders
We create efficient data pipelines:

- **`train_loader`**
  - `batch_size = BATCH_SIZE`  
  - `shuffle=True` → ensures batches are randomized each epoch.  
  - `num_workers=4` & `pin_memory=True` → speed up data loading.  

- **`val_loader`**
  - `shuffle=False` → validation set order is fixed.  
  - Same parallelization optimizations as training.

---

### 4. Handling Class Imbalance
Chest X-ray datasets are **highly imbalanced** (e.g., "No Finding" dominates).  
To address this imbalance in **multi-label classification**:

- Count positives (`pos_counts`) and negatives (`neg_counts`) for each condition.  
- Compute **`pos_weight = neg_counts / pos_counts`**.  
- This weight is passed to **`BCEWithLogitsLoss`**, giving rare diseases higher importance during training.  

---

### 5. Device Setup
- Automatically detects **GPU (`cuda`)** if available, otherwise falls back to CPU.  
- Prints device information (including GPU name when applicable).  

✅ This ensures efficient training and fair handling of imbalanced medical labels.


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms

import os

# ---- Paths & loader settings ----
# TODO: set IMG_DIR to the folder holding the files named in the "Image_name"
# column; "path/to/images" is a placeholder and will not exist on Kaggle.
IMG_DIR = "path/to/images"
# The DataLoaders below are built in this cell, so their batch size must be
# defined here; the hyperparameter cell further down runs only afterwards.
BATCH_SIZE = 32
NUM_WORKERS = 4

# Fail fast with a clear message rather than a FileNotFoundError raised later
# inside a DataLoader worker on the first batch.
if not os.path.isdir(IMG_DIR):
    raise FileNotFoundError(
        f"Image directory {IMG_DIR!r} does not exist; set IMG_DIR to the dataset's image folder."
    )

# ---- Transforms ----
# ImageNet channel statistics, matching the ImageNet-pretrained backbone; the
# grayscale X-ray is replicated to 3 channels before normalizing.
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # NOTE(review): horizontal flips mirror anatomy (e.g. heart side); confirm this
    # augmentation is acceptable for laterality-sensitive findings.
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.ColorJitter(brightness=0.08, contrast=0.08),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Validation: deterministic preprocessing only, no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# ---- Datasets ----
train_ds = ChestXrayDataset(train_df_split, IMG_DIR, transform=train_transform, label_columns=label_columns)
val_ds   = ChestXrayDataset(val_df_split, IMG_DIR, transform=val_transform, label_columns=label_columns)

# ---- Dataloaders ----
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

# ---- Class imbalance for BCEWithLogitsLoss ----
# pos_weight[c] = negatives / positives for condition c, computed on the training split only.
pos_counts = train_df_split[label_columns].sum().to_numpy(dtype=np.float64)
neg_counts = len(train_df_split) - pos_counts
# Clip positives to >= 1 so a condition absent from the training split gives a
# large finite weight instead of inf (which would make the loss inf/NaN).
pos_weight = torch.tensor(neg_counts / np.clip(pos_counts, 1.0, None), dtype=torch.float32)

# ---- Device ----
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)
if DEVICE.type == "cuda":
    print("CUDA device:", torch.cuda.get_device_name(0))

## 🏋️ Model Training Pipeline

With the dataset and dataloaders ready, we now move to **model definition, training, and fine-tuning**.

---

### 1. Model Architecture
We use a **ResNet-50** backbone pretrained on ImageNet and replace its classification head with a custom multi-label head:

- **Dropout** → reduces overfitting.  
- **Hidden layer + ReLU** → adds non-linearity.  
- **Final Linear layer** → outputs predictions for **14 conditions**.  

This allows transfer learning while adapting the network to chest X-ray classification.

---

### 2. Loss Function
We use **`BCEWithLogitsLoss`** (binary cross-entropy with sigmoid), suitable for **multi-label classification**.  
- Weighted with **`pos_weight`** to handle strong class imbalance.  
- Ensures rare conditions contribute fairly to the gradient updates.

---

### 3. Training Strategy
We adopt a **two-stage training** process:

#### **Stage 1 — Train Head Only**
- Backbone (ResNet-50 convolutional layers) **frozen**.  
- Only the new classification head is trained.  
- Optimizer → `Adam`, learning rate = **1e-3**.  
- LR Scheduler → `ReduceLROnPlateau` for adaptive learning rate adjustment.  

#### **Stage 2 — Fine-Tune Backbone**
- All layers **unfrozen** for full network training.  
- Optimizer → `AdamW`, learning rate = **1e-4** (smaller to avoid catastrophic forgetting).  
- LR Scheduler → same as Stage 1.  

This strategy stabilizes training and avoids destroying pretrained weights.

---

### 4. Training Loop
We define a reusable function **`run_epoch`**:

- **Mode switching** → `train()` for training, `eval()` for validation.  
- **Mixed precision option (`use_amp`)** for faster training on GPUs.  
- Tracks:
  - **Loss** → mean BCE loss across batches.  
  - **AUC (ROC-AUC score)** → primary evaluation metric for multi-label tasks.  

---

### 5. Monitoring & Logging
For each epoch, we log:  
- Training and validation **loss**.  
- Training and validation **AUC**.  
- **Time per epoch** for efficiency tracking.  

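This allows us to monitor model convergence. No automatic early stopping is implemented: `PATIENCE` only controls when `ReduceLROnPlateau` lowers the learning rate, so training must be stopped manually if needed.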

---

✅ With this setup, the model benefits from **transfer learning**, balanced training for rare conditions, and careful fine-tuning of the backbone for maximum performance.


In [None]:
import os
import time
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from torchvision import models

# Device selection, repeated here with the same expression as the data-pipeline cell.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Hyperparameters ----
HEAD_EPOCHS   = 5          # epochs training only the new classifier head (backbone frozen)
FT_EPOCHS     = 12         # epochs fine-tuning the whole network afterwards
BATCH_SIZE    = 32         # NOTE(review): train_loader/val_loader are built in an earlier cell, so changing this value here does not affect them
LR_HEAD       = 1e-3       # Adam learning rate for the head-only stage
LR_FT         = 1e-4       # AdamW learning rate for full fine-tuning (smaller, to preserve pretrained features)
WEIGHT_DECAY  = 1e-4       # weight decay passed to both optimizers
PATIENCE      = 3          # ReduceLROnPlateau patience: epochs without val-loss improvement before the LR is reduced (no early stopping is implemented)
MIN_LR        = 1e-7       # lower bound on the learning rate for both ReduceLROnPlateau schedulers

def build_model(num_classes=14, dropout=0.3, pretrained=True):
    """Create a ResNet-50 with a two-layer multi-label classification head.

    The original ``fc`` layer is replaced by
    Dropout -> Linear(in, in // 2) -> ReLU -> Dropout(dropout / 2) -> Linear(in // 2, num_classes).

    Args:
        num_classes: Number of output logits (one per condition).
        dropout: Dropout probability before the hidden layer; the second dropout
            uses half of this value.
        pretrained: If True, load the ImageNet ``IMAGENET1K_V1`` weights (the
            weights the deprecated ``pretrained=True`` flag selected). If False,
            build a randomly initialised network without downloading anything.

    Returns:
        An ``nn.Module`` producing raw logits (no sigmoid), suitable for
        ``BCEWithLogitsLoss``.
    """
    # torchvision >= 0.13 replaces the deprecated `pretrained=` argument with `weights=`.
    weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet50(weights=weights)
    in_features = model.fc.in_features
    hidden_features = in_features // 2

    # Hidden layer + ReLU + Dropout head for multi-label output.
    model.fc = nn.Sequential(
        nn.Dropout(dropout),
        nn.Linear(in_features, hidden_features),
        nn.ReLU(),
        nn.Dropout(dropout / 2),
        nn.Linear(hidden_features, num_classes)
    )
    return model
# Build the network on the active device, with one output logit per label column.
model = build_model(num_classes=len(label_columns), dropout=0.3).to(DEVICE)

# pos_weight re-weights positive examples per condition, so it needs exactly one
# entry per label column; a mismatch would silently broadcast or fail mid-training.
assert pos_weight.shape == (len(label_columns),), (
    f"pos_weight has shape {tuple(pos_weight.shape)}, expected ({len(label_columns)},)"
)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(DEVICE))

# Stage 1: disable gradients everywhere, then re-enable them only for the new head (model.fc).
model.requires_grad_(False)
model.fc.requires_grad_(True)

# Only the trainable (head) parameters are handed to the optimizer.
head_params = [p for p in model.parameters() if p.requires_grad]
opt_head = optim.Adam(head_params, lr=LR_HEAD, weight_decay=WEIGHT_DECAY)

# Halve the LR (down to MIN_LR) when validation loss stops improving for PATIENCE epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    opt_head,
    mode='min',
    factor=0.5,
    patience=PATIENCE,
    min_lr=MIN_LR,
)

def _macro_auc(labels, probs):
    """Mean ROC-AUC over the label columns where it is defined.

    sklearn's multilabel ``roc_auc_score`` fails if *any* column contains a single
    class. The whole epoch metric would then become NaN whenever a rare condition
    has no positives in the evaluated split. Scoring columns separately and
    skipping undefined ones keeps the metric informative.

    Args:
        labels: 2-D array (samples x labels) of ground-truth values.
        probs: 2-D array of predicted probabilities with the same shape.

    Returns:
        float: mean AUC over scorable columns, or NaN if no column can be scored.
    """
    scores = []
    for col in range(labels.shape[1]):
        y_true = labels[:, col]
        if np.unique(y_true).size < 2:
            continue  # single-class column: AUC undefined
        try:
            score = roc_auc_score(y_true, probs[:, col])
        except ValueError:
            continue  # e.g. non-binary labels or non-finite probabilities
        if np.isfinite(score):
            scores.append(score)
    return float(np.mean(scores)) if scores else float('nan')


def run_epoch(model, loader, optimizer=None, train=False, device=DEVICE, use_amp=False,
              loss_fn=None):
    """Run one pass over ``loader`` in training or evaluation mode.

    Args:
        model: Network returning one logit per label.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch; required when ``train=True``.
        train: If True, enable gradients and update weights; otherwise only evaluate.
        device: Device the batches are moved to.
        use_amp: Use CUDA autocast (and a GradScaler when training); ignored off-CUDA.
        loss_fn: Callable ``(logits, labels) -> scalar loss``; defaults to the
            global ``criterion``.

    Returns:
        tuple: ``(mean_loss, macro_auc)``. ``mean_loss`` is the mean of per-batch
        losses. ``macro_auc`` is the mean per-label ROC-AUC over labels where it is
        defined. Both are NaN for an empty loader.

    Raises:
        ValueError: if ``train=True`` and no optimizer is given.
    """
    if train and optimizer is None:
        raise ValueError("run_epoch(train=True) requires an optimizer")
    loss_fn = criterion if loss_fn is None else loss_fn
    device = torch.device(device)
    amp_enabled = use_amp and device.type == 'cuda'
    # A scaler is only needed when back-propagating in mixed precision.
    # NOTE: a fresh scaler per epoch re-calibrates its loss scale each call.
    scaler = torch.cuda.amp.GradScaler() if amp_enabled and train else None

    model.train(train)  # train(False) is equivalent to eval()
    losses = []
    all_labels = []
    all_probs = []
    loop = tqdm(loader, desc='Train' if train else 'EVAL', leave=False)
    for imgs, labels in loop:
        imgs = imgs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(train):
            if amp_enabled:
                with torch.cuda.amp.autocast():
                    logits = model(imgs)
                    loss = loss_fn(logits, labels)
            else:
                logits = model(imgs)
                loss = loss_fn(logits, labels)

            if train:
                optimizer.zero_grad()
                if scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()

        # Cast to float32 before sigmoid so half-precision logits under autocast
        # do not lose resolution in the AUC computation.
        all_probs.append(torch.sigmoid(logits.detach().float()).cpu().numpy())
        all_labels.append(labels.detach().cpu().numpy())
        losses.append(loss.item())
        loop.set_postfix(loss=np.mean(losses))

    if not losses:
        # Empty loader: nothing to average and nothing to concatenate.
        return float('nan'), float('nan')

    auc = _macro_auc(np.concatenate(all_labels), np.concatenate(all_probs))
    return np.mean(losses), auc

### STAGE 1 - HEAD TRAINING
# NOTE(review): run_epoch(train=True) switches the whole model to train mode, so
# BatchNorm running statistics in the frozen backbone still update here; confirm intended.
for epoch in range(1, HEAD_EPOCHS + 1):
    start = time.time()
    train_loss, train_auc = run_epoch(
        model, train_loader, optimizer=opt_head, train=True, use_amp=False
    )
    val_loss, val_auc = run_epoch(model, val_loader, train=False, use_amp=False)
    # ReduceLROnPlateau (mode='min') watches validation loss.
    scheduler.step(val_loss)
    elapsed = time.time() - start
    print(
        f"Epoch {epoch}/{HEAD_EPOCHS}  "
        f"train_loss={train_loss:.4f} train_auc={train_auc:.4f}  "
        f"val_loss={val_loss:.4f} val_auc={val_auc:.4f}  "
        f"time={elapsed:.1f}s"
    )

## STAGE 2 - UNFREEZE BACKBONE AND FINE-TUNE
# Make every parameter trainable again and fine-tune end to end with the smaller
# LR_FT, using AdamW and a fresh plateau scheduler.
for param in model.parameters():
    param.requires_grad = True

opt_ft = optim.AdamW(model.parameters(), lr=LR_FT, weight_decay=WEIGHT_DECAY)
scheduler_ft = optim.lr_scheduler.ReduceLROnPlateau(
    opt_ft, mode='min', factor=0.5, patience=PATIENCE, min_lr=MIN_LR
)

# Keep the weights from the epoch with the best validation AUC. Otherwise the
# final model is simply the last epoch's, even if validation performance peaked earlier.
best_val_auc = float('-inf')
best_state = None

for epoch in range(1, FT_EPOCHS + 1):
    t0 = time.time()
    train_loss, train_auc = run_epoch(model, train_loader, optimizer=opt_ft, train=True, use_amp=False)
    val_loss, val_auc = run_epoch(model, val_loader, train=False, use_amp=False)
    scheduler_ft.step(val_loss)
    elapsed = time.time() - t0
    if np.isfinite(val_auc) and val_auc > best_val_auc:
        best_val_auc = val_auc
        # Snapshot on CPU so the checkpoint does not take extra GPU memory.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    print(f"FT Epoch {epoch}/{FT_EPOCHS}  train_loss={train_loss:.4f} train_auc={train_auc:.4f}  val_loss={val_loss:.4f} val_auc={val_auc:.4f}  time={elapsed:.1f}s")

if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored fine-tuned weights from the epoch with best val_auc={best_val_auc:.4f}")