# Chest X-Ray Training Pipeline (Robust V3)

**Updates (V3):** Improved the column-matching logic so that both `image` and `image_index` headers are recognized.

In [None]:
import os
import sys
import zipfile
import glob
from pathlib import Path
import pandas as pd
import torch

# 1. MOUNT DRIVE
# Detect Colab by trying to import its Drive helper. Only the import itself is
# guarded: an ImportError raised later (e.g. inside drive.mount()) must not be
# misreported as "running locally".
try:
    from google.colab import drive
except ImportError:
    IS_COLAB = False
    print("üíª Running Locally")
else:
    print("üîå Connecting to Google Drive...")
    drive.mount('/content/drive')
    IS_COLAB = True

# 2. AUTO-DISCOVERY
def find_file(filename, search_path):
    """Recursively search ``search_path`` for a file named ``filename``.

    When several copies exist, the first one in ``Path`` sort order is chosen,
    so the result is deterministic across runs.

    Returns:
        The chosen ``Path``, or ``None`` if nothing matched.
    """
    print(f"üîç Searching for '{filename}' in {search_path}...")
    first_match = min(Path(search_path).rglob(filename), default=None)
    if first_match is None:
        return None
    print(f"   ‚úÖ Found: {first_match}")
    return first_match

# 3. SETUP PATHS
# On Colab, search the mounted Drive and extract into /content/work;
# locally, search the parent directory and extract into ./temp_work.
if IS_COLAB:
    SEARCH_ROOT = "/content/drive/My Drive"
    WORK_DIR = Path("/content/work")
else:
    SEARCH_ROOT = ".."
    WORK_DIR = Path("./temp_work")

WORK_DIR.mkdir(parents=True, exist_ok=True)

ZIP_PATH = find_file("images-224.zip", SEARCH_ROOT)
if not ZIP_PATH:
    raise FileNotFoundError("Please upload images-224.zip to Drive")

# Extract only once: skip when the image folder is already present from a
# previous run.
IMAGE_DIR = WORK_DIR / "images-224"
if not IMAGE_DIR.exists():
    print("‚è≥ Unzipping...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as z:
        z.extractall(WORK_DIR)
    print("‚úÖ Unzip Done.")

# The archive must contain a top-level "images-224/" folder. Fail fast if it
# does not: otherwise every image path built from IMAGE_DIR is invalid and the
# archive would be re-extracted on every run.
if not IMAGE_DIR.is_dir():
    raise FileNotFoundError(
        f"Expected images under {IMAGE_DIR} after extracting {ZIP_PATH}; "
        "check the folder layout inside the zip."
    )

# Look for the CSV under the search root first, then fall back to WORK_DIR
# (where the zip was extracted).
CSV_PATH = find_file("Data_Entry_2017.csv", SEARCH_ROOT)
if not CSV_PATH:
    CSV_PATH = find_file("Data_Entry_2017.csv", WORK_DIR)
if not CSV_PATH:
    raise FileNotFoundError("Please upload Data_Entry_2017.csv to Drive")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ Device: {device}")

In [None]:
# 4. ROBUST DATA PROCESSING
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import numpy as np

print("üìä Loading CSV...")
df = pd.read_csv(CSV_PATH)

# --- DEFENSIVE COLUMN CLEANING ---
# 1. Normalize all columns to lower_snake_case
df.columns = df.columns.str.strip().str.lower().str.replace(' ', '_')
print(f"   Columns detected: {list(df.columns)}")

# 2. Identify Key Columns dynamically (Improved Logic)
def match_col(cols, keywords):
    """Return the first name in ``cols`` that contains every keyword as a substring.

    Columns are scanned in their given order; ``None`` is returned when no
    column contains all of ``keywords``.
    """
    return next((name for name in cols if all(k in name for k in keywords)), None)

# Priority 1: 'image' AND 'index'. Priority 2: just 'image'
img_col = match_col(df.columns, ['image', 'index']) or match_col(df.columns, ['image'])
# Priority 1: 'patient' AND 'id'. Priority 2: just 'patient'
id_col  = match_col(df.columns, ['patient', 'id']) or match_col(df.columns, ['patient'])
# Labels: prefer a column containing 'label', else one containing 'finding'
lbl_col = match_col(df.columns, ['label']) or match_col(df.columns, ['finding'])

if not img_col or not id_col or not lbl_col:
    raise ValueError(f"Could not identify columns automatically. Found: {list(df.columns)}")

print(f"   Mappings: Image='{img_col}', ID='{id_col}', Label='{lbl_col}'")

# 3. Process Labels
# Each label cell is a '|'-separated string (e.g. "A|B"); split it into a list
# and multi-hot encode it with MultiLabelBinarizer.
# NOTE(review): astype(str) turns missing cells into the literal class 'nan'.
df[lbl_col] = df[lbl_col].astype(str).str.split('|')
mlb = MultiLabelBinarizer()
encoded = mlb.fit_transform(df[lbl_col])
classes = mlb.classes_
print(f"   Classes Found ({len(classes)}): {classes}")

# 4. Construct Final DataFrame
# Give the encoded frame df's own index: concat(axis=1) aligns rows by index
# label, so a fresh 0..n-1 index would silently pair images with the wrong
# labels whenever df's index is not a default RangeIndex.
df_enc = pd.DataFrame(encoded, columns=classes, index=df.index)
df = pd.concat([df[[img_col, id_col]], df_enc], axis=1)
# Standardize names for the Dataset class
df = df.rename(columns={img_col: 'image', id_col: 'patientid'})

# 5. Dataset Class
class ChestXRayDataset(Dataset):
    """Map-style dataset yielding ``(image, label_vector)`` pairs.

    Args:
        df: DataFrame with an ``'image'`` column (file name relative to
            ``img_dir``) and a ``'patientid'`` column; every other column is
            treated as a label and cast to float32.
        img_dir: Directory containing the image files.
        tf: Optional callable applied to the loaded RGB PIL image.

    Images that cannot be opened or decoded (OSError) are replaced by a zero
    tensor of shape ``FALLBACK_SHAPE``, so one bad file does not abort an epoch.
    A warning is printed the first time this happens for this dataset object
    (each DataLoader worker holds its own copy, so each may warn once).
    Any other error (e.g. a bug in ``tf``) propagates instead of being hidden.
    """

    # Shape of the placeholder returned for unreadable images.
    FALLBACK_SHAPE = (3, 224, 224)

    def __init__(self, df, img_dir, tf=None):
        self.df = df
        self.img_dir = img_dir
        self.tf = tf
        self.img_names = df['image'].values
        # Drop metadata to keep only the one-hot label columns
        self.labels = df.drop(['image', 'patientid'], axis=1).values.astype('float32')
        self._warned = False

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        label = torch.tensor(self.labels[idx])
        path = os.path.join(self.img_dir, self.img_names[idx])
        try:
            # `with` closes the file handle; convert() returns a new image.
            with Image.open(path) as raw:
                img = raw.convert("RGB")
        except OSError as err:
            # OSError covers missing files and undecodable/truncated images.
            if not self._warned:
                print(f"   WARNING: could not read {path} ({err}); substituting a "
                      "blank image. Further failures will not be reported.")
                self._warned = True
            return torch.zeros(self.FALLBACK_SHAPE), label
        if self.tf:
            img = self.tf(img)
        return img, label

# 6. Split & Loaders
# Split by patient rather than by image, so no patient appears in more than one
# split: 15% of patients are held out for test, then 15% of the rest for val.
pats = df['patientid'].unique()
train_p, test_p = train_test_split(pats, test_size=0.15, random_state=42)
train_p, val_p  = train_test_split(train_p, test_size=0.15, random_state=42)

train_df = df[df['patientid'].isin(train_p)]
val_df   = df[df['patientid'].isin(val_p)]
# Held-out test rows: never used for training or model selection.
test_df  = df[df['patientid'].isin(test_p)]

print(f"   Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

# Standard ImageNet normalization statistics, shared by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

BATCH_SIZE = 32
NUM_WORKERS = 2
# Pinned host memory speeds up host-to-GPU copies; pointless without CUDA.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(ChestXRayDataset(train_df, IMAGE_DIR, train_tf), batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(ChestXRayDataset(val_df, IMAGE_DIR, val_tf), batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("‚úÖ Pipeline Ready.")

In [None]:
# 5. TRAINING LOOP
class ResNet50(nn.Module):
    """torchvision ResNet-50 whose final ``fc`` layer is replaced by an
    ``n_classes``-way linear head; ``forward`` returns the raw head outputs
    (no sigmoid applied).

    Args:
        n_classes: Number of output logits.
        pretrained: Load the ImageNet-1k (IMAGENET1K_V1) weights. Defaults to
            True, matching the previous behavior.
    """

    def __init__(self, n_classes, pretrained=True):
        super().__init__()
        self.base = self._build_backbone(pretrained)
        # Replace the stock classifier with a task-specific linear head.
        self.base.fc = nn.Linear(self.base.fc.in_features, n_classes)

    @staticmethod
    def _build_backbone(pretrained):
        """Create resnet50 via the weights API, falling back for old torchvision."""
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only supports the legacy boolean flag.
            return models.resnet50(pretrained=pretrained)
        # IMAGENET1K_V1 is the checkpoint the deprecated pretrained=True loaded.
        return models.resnet50(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        return self.base(x)

model = ResNet50(len(classes)).to(device)
# Multi-label setup: an independent sigmoid + binary cross-entropy per class.
crit = nn.BCEWithLogitsLoss()
opt = optim.Adam(model.parameters(), lr=1e-4)

EPOCHS = 10
best_loss = float('inf')

# On Colab, persist the best checkpoint to Drive; locally, keep it in WORK_DIR
# (previously nothing was saved at all when running locally).
if IS_COLAB:
    save_path = "/content/drive/My Drive/xray_best_model.pth"
else:
    save_path = str(WORK_DIR / "xray_best_model.pth")

print("üî• Starting Training...")

for ep in range(EPOCHS):
    model.train()
    train_loss = 0.0
    loop = tqdm(train_loader, desc=f"Epoch {ep+1}/{EPOCHS}")

    for imgs, lbls in loop:
        imgs, lbls = imgs.to(device), lbls.to(device)
        opt.zero_grad()
        loss = crit(model(imgs), lbls)
        loss.backward()
        opt.step()
        # .item() copies the scalar to the host; do it once per step.
        batch_loss = loss.item()
        train_loss += batch_loss
        loop.set_postfix(loss=batch_loss)

    # Validation: eval mode, no autograd graph; report the mean batch loss.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for imgs, lbls in val_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            val_loss += crit(model(imgs), lbls).item()
    val_loss /= len(val_loader)

    print(f"   Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss:.4f}")

    # Save Best (lowest mean validation loss so far)
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), save_path)
        print(f"   üíæ Best Model Saved to: {save_path}")