In [3]:
import os
import io
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from google.cloud import storage, bigquery
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from monai.transforms import Compose, ScaleIntensity, Resize, EnsureType
from monai.networks.nets import DenseNet121
from captum.attr import LayerGradCam

# Configuration
PROJECT_ID = 'pneumonia-detection-2026'
LOCAL_DATA_DIR = "./data/nih_images"
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

# Clients
storage_client = storage.Client(project=PROJECT_ID)
bq_client = bigquery.Client(project=PROJECT_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"‚úÖ Environment Ready. Running on: {device}")

‚úÖ Environment Ready. Running on: cuda


In [4]:
def get_metadata_from_bigquery(limit=200):
    # This query extracts the filename and creates a binary label for Pneumonia
    query = f"""
        SELECT
            image_id,
            IF('Pneumonia' IN UNNEST(SPLIT(finding_labels, '|')), 1, 0) as label,
            patient_id
        FROM `search.chc-nih-chest-xray.nih_chest_xray.metadata`
        LIMIT {limit}
    """
    return bq_client.query(query).to_dataframe()

# Pull data and perform Patient-level split to prevent data leakage
df = get_metadata_from_bigquery(limit=500)
unique_patients = df['patient_id'].unique()
np.random.shuffle(unique_patients)
split_idx = int(len(unique_patients) * 0.8)

train_pts = unique_patients[:split_idx]
train_df = df[df['patient_id'].isin(train_pts)].reset_index(drop=True)
test_df = df[~df['patient_id'].isin(train_pts)].reset_index(drop=True)

print(f"üìä Training images: {len(train_df)} | Testing images: {len(test_df)}")

Forbidden: 403 Access Denied: Table search:chc-nih-chest-xray.nih_chest_xray.metadata: User does not have permission to query table search:chc-nih-chest-xray.nih_chest_xray.metadata, or perhaps it does not exist.; reason: accessDenied, message: Access Denied: Table search:chc-nih-chest-xray.nih_chest_xray.metadata: User does not have permission to query table search:chc-nih-chest-xray.nih_chest_xray.metadata, or perhaps it does not exist.

Location: US
Job ID: 6c028dd7-9105-4e88-b5cb-8b09628e5501


In [None]:
def download_to_local(dataframe):
    bucket = storage_client.bucket("gcs-public-data--healthcare-nih-chest-xray")
    count = 0
    for img_id in dataframe['image_id']:
        local_path = os.path.join(LOCAL_DATA_DIR, img_id)
        if not os.path.exists(local_path):
            blob = bucket.blob(f"png/{img_id}")
            blob.download_to_filename(local_path)
            count += 1
    print(f"‚úÖ Local Cache Updated. Downloaded {count} new images.")

# Sync local WSL storage with GCS references
download_to_local(df)

In [None]:
class NIHDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(LOCAL_DATA_DIR, row['image_id'])

        try:
            image = Image.open(img_path).convert('RGB')
            # MONAI transforms expect (C, H, W)
            image_np = np.array(image).astype(np.float32).transpose(2, 0, 1)
            if self.transform:
                image_np = self.transform(image_np)
            return image_np, torch.tensor(row['label'], dtype=torch.long)
        except Exception as e:
            return torch.zeros((3, 224, 224)), torch.tensor(row['label'], dtype=torch.long)

transforms = Compose([ScaleIntensity(), Resize((224, 224)), EnsureType()])
train_loader = DataLoader(NIHDataset(train_df, transforms), batch_size=8, shuffle=True)
test_loader = DataLoader(NIHDataset(test_df, transforms), batch_size=8, shuffle=False)

print("‚úÖ Local Dataloaders Initialized.")

In [None]:
# Calculate class weights for imbalance
neg, pos = np.bincount(train_df['label'])
weights = torch.tensor([1.0, neg/pos], dtype=torch.float).to(device)

model = DenseNet121(spatial_dims=2, in_channels=3, out_channels=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(weight=weights)

print(f"‚úÖ Model ready with Class Weights: {weights.tolist()}")

In [None]:
model.train()
for epoch in range(5): # Increase for real training
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 5 == 0:
            print(f"Epoch {epoch} | Step {i} | Loss: {loss.item():.4f}")

print("üèÅ Training Finished.")

In [None]:
def visualize_prediction(model, loader):
    model.eval()
    images, labels = next(iter(loader))
    img_tensor = images[0:1].to(device)

    lgc = LayerGradCam(model, model.features)
    attr = lgc.attribute(img_tensor, target=labels[0].item())
    attr_upsampled = LayerGradCam.interpolate(attr, (224, 224))

    img_display = img_tensor[0].cpu().numpy().transpose(1, 2, 0)
    img_display = (img_display - img_display.min()) / (img_display.max() - img_display.min() + 1e-8)

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img_display)
    plt.title(f"X-Ray (Label: {labels[0].item()})")

    plt.subplot(1, 2, 2)
    plt.imshow(img_display)
    plt.imshow(attr_upsampled.cpu().detach().numpy().squeeze(), cmap='jet', alpha=0.4)
    plt.title("Grad-CAM Heatmap")
    plt.show()

visualize_prediction(model, test_loader)