In [1]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

# -------------------------------
# 1️⃣ Dataset Class
# -------------------------------
class CXRDataset(Dataset):
    """Chest X-ray images paired with tabular clinical features and a class label.

    Rows come from a CSV that must contain an ``image_id`` column (a file name
    relative to ``image_dir``), every column listed in ``clinical_cols`` and
    ``label_col``. The table is cleaned once at construction time:

    * clinical columns are coerced to numbers (unparseable values become NaN);
    * rows with a missing clinical value or a missing label are dropped;
    * rows whose image file does not exist under ``image_dir`` are dropped.

    Args:
        csv_path: Path to the CSV file.
        image_dir: Directory containing the image files named by ``image_id``.
        transform: Optional callable applied to each RGB PIL image.
        clinical_cols: Names of the clinical feature columns (list or tuple).
            ``None`` (the default) means "no clinical features".
        label_col: Name of the label column.

    Each item is ``(image, clinical, label)``: the (optionally transformed)
    RGB image, a float32 tensor with one entry per clinical column, and an
    int64 scalar label tensor.
    """

    def __init__(self, csv_path, image_dir, transform=None, clinical_cols=None, label_col="label"):
        # Normalize to a fresh list: accepts tuples, makes ``None`` mean "no features",
        # and lets ``clinical_cols + [label_col]`` below work in every case.
        clinical_cols = [] if clinical_cols is None else list(clinical_cols)

        self.image_dir = image_dir
        self.transform = transform
        self.clinical_cols = clinical_cols
        self.label_col = label_col

        df = pd.read_csv(csv_path)

        # Coerce clinical columns to numbers; anything unparseable becomes NaN and is dropped.
        for col in clinical_cols:
            df[col] = pd.to_numeric(df[col], errors="coerce")
        df = df.dropna(subset=clinical_cols + [label_col])

        # Keep only rows whose image file exists. ``astype(str)`` guards against numeric
        # IDs, which ``os.path.join`` would reject with a TypeError.
        df["image_path"] = [os.path.join(image_dir, name) for name in df["image_id"].astype(str)]
        # Force a boolean mask: on an empty frame an object-dtype Series would be treated
        # as a list of column labels instead of a row filter.
        exists = df["image_path"].map(os.path.exists).astype(bool)
        df = df[exists].reset_index(drop=True)

        self.df = df
        # Clinical values are already numeric and NaN-free here, so convert them once
        # rather than re-parsing the row on every __getitem__ call.
        self._clinical = df[clinical_cols].to_numpy(dtype="float32")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # ---- Image ----
        # Open inside ``with`` so the file handle is released immediately; ``convert``
        # returns a new, fully loaded image that stays valid after the file is closed.
        with Image.open(row["image_path"]) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        # ---- Clinical features ----
        clinical = torch.tensor(self._clinical[idx], dtype=torch.float)

        # ---- Label ----
        label = torch.tensor(row[self.label_col], dtype=torch.long)

        return img, clinical, label

# -------------------------------
# 2️⃣ Model: CNN + Clinical MLP + Fusion
# -------------------------------
class CXRClinicalFusionModel(nn.Module):
    """Late-fusion classifier combining ResNet-18 image features with clinical features.

    * Image branch: torchvision ResNet-18 with its final ``fc`` layer removed,
      flattened to one ``self.cnn_dim``-dim vector per image.
    * Clinical branch: a small MLP mapping ``num_tabular_features`` inputs to 32 features.
    * Head: the two vectors are concatenated and mapped to ``num_classes`` logits.

    Args:
        num_tabular_features: Number of clinical input features.
        num_classes: Number of output logits.
        pretrained: If True (the default, matching the previous behavior), load the
            ImageNet ``IMAGENET1K_V1`` backbone weights. Pass False for a randomly
            initialized backbone, e.g. when the weights cannot be downloaded.

    NOTE: the clinical MLP contains ``nn.BatchNorm1d``, which raises in training mode
    when a batch holds a single sample, so training batches need at least two rows.
    """

    def __init__(self, num_tabular_features, num_classes, pretrained=True):
        super().__init__()
        # CNN backbone
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        base = models.resnet18(weights=weights)
        # Every child except the final classifier ``fc``. The module layout (and therefore
        # state_dict key names) is intentionally unchanged so existing checkpoints still load.
        self.cnn_encoder = nn.Sequential(*list(base.children())[:-1])
        self.cnn_dim = base.fc.in_features

        # MLP for clinical features
        self.mlp = nn.Sequential(
            nn.Linear(num_tabular_features, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU()
        )

        # Fusion classifier over [image features | clinical features]
        self.classifier = nn.Sequential(
            nn.Linear(self.cnn_dim + 32, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, image, clinical):
        """Return ``(batch, num_classes)`` logits for a batch of images and clinical rows."""
        # flatten(…, 1) collapses everything after the batch axis; unlike view(N, -1)
        # it also works for non-contiguous tensors and empty batches.
        img_feat = torch.flatten(self.cnn_encoder(image), 1)
        tab_feat = self.mlp(clinical)
        fused = torch.cat([img_feat, tab_feat], dim=1)
        return self.classifier(fused)

# -------------------------------
# 3️⃣ Image Transforms
# -------------------------------
# ImageNet per-channel statistics, matching the ImageNet-pretrained ResNet-18 backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize, light geometric augmentation, then tensor + normalize.
# NOTE(review): horizontal flips mirror the anatomy (e.g. heart side) of a chest X-ray —
# confirm this augmentation is appropriate for the labels being predicted.
train_tf = T.Compose(
    [
        T.Resize((224, 224)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(5),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Validation pipeline: deterministic — same resize and normalization, no augmentation.
val_tf = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# -------------------------------
# 4️⃣ Dataset and Dataloader
# -------------------------------
clinical_cols = ["age", "gender"]  # example columns
label_col = "label"
# NOTE(review): hard-coded Windows paths — consider a configurable data root.
img_dir = r"C:\Users\sureb\dataset\images\train"
img_dir1 = r"C:\Users\sureb\dataset\images\val"

# Training split uses the augmenting transforms.
train_data = CXRDataset(
    csv_path=r"C:\Users\sureb\dataset\train.csv",
    image_dir=img_dir,
    transform=train_tf,
    clinical_cols=clinical_cols,
    label_col=label_col
)
# Validation split uses the deterministic transforms.
val_data = CXRDataset(
    csv_path=r"C:\Users\sureb\dataset\val.csv",
    image_dir=img_dir1,
    transform=val_tf,
    clinical_cols=clinical_cols,
    label_col=label_col
)

batch_size = 16
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    # The fusion model's clinical MLP uses BatchNorm1d, which raises in training mode on
    # a one-sample batch. Drop the trailing batch only when it would be exactly that size,
    # so every other dataset size trains on all samples as before.
    drop_last=len(train_data) % batch_size == 1,
)
# Evaluation runs in eval mode (BatchNorm uses running stats), so a size-1 batch is fine.
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# -------------------------------
# 5️⃣ Training Setup
# -------------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# One MLP input per clinical column; the classifier head emits three logits.
model = CXRClinicalFusionModel(num_classes=3, num_tabular_features=len(clinical_cols))
model = model.to(device)

# Cross-entropy over the model's raw logits, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# -------------------------------
# 6️⃣ Training / Validation Loops
# -------------------------------
def _resolve_device(model, device):
    """Return ``device`` if given, else the device of ``model``'s first parameter (CPU if none)."""
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns ``(mean_loss, accuracy)`` averaged per sample.

    Raises:
        ValueError: if ``loader`` yields no samples (the averages would divide by zero).
    """
    training = optimizer is not None
    # train(True) enables dropout and batch-statistics BatchNorm; train(False) is eval().
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    # Autograd bookkeeping is only needed when backward() will be called.
    with torch.set_grad_enabled(training):
        for images, clinical, labels in loader:
            images = images.to(device)
            clinical = clinical.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images, clinical)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Assumes ``criterion`` returns the batch mean (CrossEntropyLoss default);
            # re-weighting by batch size keeps the epoch average per-sample even when
            # the last batch is smaller.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch loss/accuracy")
    return total_loss / total, correct / total


def train_one_epoch(model, loader, optimizer, criterion, device=None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: Module called as ``model(images, clinical)`` that returns logits.
        loader: Iterable of ``(images, clinical, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to ``(logits, labels)``.
        device: Where to move each batch. Defaults to the device holding the model's
            parameters, so no module-level ``device`` global is required.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, each batch measured on its outputs
        from before that batch's optimizer step.

    Raises:
        ValueError: if ``loader`` is empty.
    """
    return _run_epoch(model, loader, criterion, _resolve_device(model, device), optimizer)


def evaluate(model, loader, criterion, device=None):
    """Evaluate ``model`` over ``loader`` in eval mode without tracking gradients.

    Args:
        model: Module called as ``model(images, clinical)`` that returns logits.
        loader: Iterable of ``(images, clinical, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        device: Where to move each batch. Defaults to the device holding the model's
            parameters.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample.

    Raises:
        ValueError: if ``loader`` is empty.
    """
    return _run_epoch(model, loader, criterion, _resolve_device(model, device))

# -------------------------------
# 7️⃣ Training Loop
# -------------------------------
num_epochs = 10
# NOTE(review): the final-epoch weights are kept regardless of validation loss; there is
# no checkpointing or early stopping on the validation metrics.
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    # One summary line per epoch; fields are joined with " | ".
    print(" | ".join([
        f"Epoch {epoch+1}/{num_epochs}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_acc:.4f}",
        f"Val Loss: {val_loss:.4f}",
        f"Val Acc: {val_acc:.4f}",
    ]))


Epoch 1/10 | Train Loss: 0.9055 | Train Acc: 0.5846 | Val Loss: 0.9611 | Val Acc: 0.5622
Epoch 2/10 | Train Loss: 0.7174 | Train Acc: 0.7065 | Val Loss: 1.0542 | Val Acc: 0.5514
Epoch 3/10 | Train Loss: 0.6238 | Train Acc: 0.7515 | Val Loss: 1.6763 | Val Acc: 0.3730
Epoch 4/10 | Train Loss: 0.5212 | Train Acc: 0.8012 | Val Loss: 1.4521 | Val Acc: 0.5622
Epoch 5/10 | Train Loss: 0.4257 | Train Acc: 0.8225 | Val Loss: 1.5937 | Val Acc: 0.4757
Epoch 6/10 | Train Loss: 0.3574 | Train Acc: 0.8627 | Val Loss: 1.6661 | Val Acc: 0.4378
Epoch 7/10 | Train Loss: 0.3283 | Train Acc: 0.8722 | Val Loss: 1.3682 | Val Acc: 0.5622
Epoch 8/10 | Train Loss: 0.2637 | Train Acc: 0.9053 | Val Loss: 1.8565 | Val Acc: 0.4703
Epoch 9/10 | Train Loss: 0.2254 | Train Acc: 0.9136 | Val Loss: 2.3903 | Val Acc: 0.3459
Epoch 10/10 | Train Loss: 0.1789 | Train Acc: 0.9325 | Val Loss: 2.1019 | Val Acc: 0.5514
