In [16]:
import os
import time

import pandas as pd
import numpy as np

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

from torchvision import transforms
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torchinfo import summary

import warnings
warnings.filterwarnings('ignore')

from tqdm import tqdm

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [18]:
def build_image_path_map(image_root):
    image_path_map = {}
    for folder in os.listdir(image_root):
        subfolder = os.path.join(image_root, folder, "images")
        if not os.path.isdir(subfolder):
            continue
        for img_file in os.listdir(subfolder):
            if img_file.endswith(".png"):
                full_path = os.path.join(subfolder, img_file)
                image_path_map[img_file] = full_path
    return image_path_map

# Example
image_root = "/kaggle/input/data/"
image_path_map = build_image_path_map(image_root)

In [19]:
df = pd.read_csv("/kaggle/input/data/Data_Entry_2017.csv")
df['Finding Labels'] = df['Finding Labels'].str.split('|')

mlb = MultiLabelBinarizer()
df['labels'] = mlb.fit_transform(df['Finding Labels']).tolist()

In [20]:
df['Patient Gender'] = df['Patient Gender'].map({'M': 0, 'F': 1})
df['View Position'] = df['View Position'].map({'PA': 0, 'AP': 1})

df['patient_data'] = df[['Patient Age', 'Patient Gender', 'View Position']].values.tolist()

In [21]:
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11,labels,patient_data
0,00000001_000.png,[Cardiomegaly],0,1,58,0,0,2682,2749,0.143,0.143,,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[58, 0, 0]"
1,00000001_001.png,"[Cardiomegaly, Emphysema]",1,1,58,0,0,2894,2729,0.143,0.143,,"[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[58, 0, 0]"
2,00000001_002.png,"[Cardiomegaly, Effusion]",2,1,58,0,0,2500,2048,0.168,0.168,,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]","[58, 0, 0]"
3,00000002_000.png,[No Finding],0,2,81,0,0,2500,2048,0.171,0.171,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]","[81, 0, 0]"
4,00000003_000.png,[Hernia],0,3,81,1,0,2582,2991,0.143,0.143,,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]","[81, 1, 0]"


In [22]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5172546, 0.5172546, 0.5172546],
                         std=[0.23124999, 0.23124999, 0.23124999])
])

In [23]:
class ChestXrayWithMetaDataset(Dataset):
    def __init__(self, dataframe, image_path_map, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.image_path_map = image_path_map
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Image
        img_name = row['Image Index']
        img_path = self.image_path_map.get(img_name)
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {img_name} not found.")
        
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Labels
        labels = torch.tensor(row['labels'], dtype=torch.float32)

        # Metadata
        metadata = torch.tensor(row['patient_data'], dtype=torch.float32)
        metadata[0] = metadata[0] / 100.0  # Normalize age to 0–1

        return image, metadata, labels

In [24]:
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

train_dataset = ChestXrayWithMetaDataset(train_df, image_path_map, transform=transform)
test_dataset = ChestXrayWithMetaDataset(test_df, image_path_map, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, pin_memory=True, num_workers=2)

In [25]:
train_dataset.__getitem__(0)

(tensor([[[-2.1011, -2.1689, -2.1859,  ..., -0.4053, -0.1509,  0.0356],
          [-2.1181, -2.1689, -2.1859,  ..., -0.8971, -0.5579, -0.0492],
          [-2.1181, -2.1859, -2.2029,  ..., -1.2362, -0.9140, -0.3205],
          ...,
          [-2.1011, -2.1350, -2.1350,  ..., -0.3714, -0.1509, -0.1848],
          [-2.0333, -2.0672, -2.0672,  ..., -0.3544, -0.1340, -0.1679],
          [-1.9654, -1.9824, -1.9994,  ..., -0.3375, -0.1340, -0.1679]],
 
         [[-2.1011, -2.1689, -2.1859,  ..., -0.4053, -0.1509,  0.0356],
          [-2.1181, -2.1689, -2.1859,  ..., -0.8971, -0.5579, -0.0492],
          [-2.1181, -2.1859, -2.2029,  ..., -1.2362, -0.9140, -0.3205],
          ...,
          [-2.1011, -2.1350, -2.1350,  ..., -0.3714, -0.1509, -0.1848],
          [-2.0333, -2.0672, -2.0672,  ..., -0.3544, -0.1340, -0.1679],
          [-1.9654, -1.9824, -1.9994,  ..., -0.3375, -0.1340, -0.1679]],
 
         [[-2.1011, -2.1689, -2.1859,  ..., -0.4053, -0.1509,  0.0356],
          [-2.1181, -2.1689,

In [26]:
# from torch.utils.data import Subset
# import random

# subset_size = 5000
# total_size = len(train_dataset)

# random_indices = random.sample(range(total_size), subset_size)
# subset_dataset = Subset(train_dataset, random_indices)

In [27]:
class ChestXrayMultiInputCNN(nn.Module):
    def __init__(self, num_labels=15):
        super().__init__()

        # CNN for image
        self.image_conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),  # (B, 3, 224, 224) → (B, 16, 224, 224)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (B, 16, 112, 112)
            
            nn.Conv2d(16, 32, kernel_size=3, padding=1), # (B, 32, 112, 112)
            nn.ReLU(),
            nn.MaxPool2d(2),                             # (B, 32, 56, 56)
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1), # (B, 64, 56, 56)
            nn.ReLU(),
            nn.MaxPool2d(2)                              # (B, 64, 28, 28)
        )

        self.image_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, 128),
            nn.ReLU()
        )

        # Small MLP for patient data (3 features)
        self.patient_fc = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU()
        )

        # Combined output
        self.classifier = nn.Sequential(
            nn.Linear(128 + 16, 64),
            nn.ReLU(),
            nn.Linear(64, num_labels),
            nn.Sigmoid()  # For multi-label classification
        )

    def forward(self, image, patient_data):
        img_feat = self.image_conv(image)
        img_feat = self.image_fc(img_feat)

        patient_feat = self.patient_fc(patient_data)

        combined = torch.cat((img_feat, patient_feat), dim=1)
        output = self.classifier(combined)
        return output

In [28]:
learning_rate = 1e-3
epochs = 10

model = ChestXrayMultiInputCNN(num_labels=15)
model = model.to(device)

criterion = nn.BCELoss()  # Since we use sigmoid
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [29]:
summary(model)

Layer (type:depth-idx)                   Param #
ChestXrayMultiInputCNN                   --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       448
│    └─ReLU: 2-2                         --
│    └─MaxPool2d: 2-3                    --
│    └─Conv2d: 2-4                       4,640
│    └─ReLU: 2-5                         --
│    └─MaxPool2d: 2-6                    --
│    └─Conv2d: 2-7                       18,496
│    └─ReLU: 2-8                         --
│    └─MaxPool2d: 2-9                    --
├─Sequential: 1-2                        --
│    └─Flatten: 2-10                     --
│    └─Linear: 2-11                      6,422,656
│    └─ReLU: 2-12                        --
├─Sequential: 1-3                        --
│    └─Linear: 2-13                      64
│    └─ReLU: 2-14                        --
├─Sequential: 1-4                        --
│    └─Linear: 2-15                      9,280
│    └─ReLU: 2-16                        --
│    └─Li

In [None]:
for epoch in range(epochs):
    epoch_start_time = time.time()
    running_loss = 0.0

    model.train()
    print(f"\nEpoch [{epoch + 1}/{epochs}]")

    batch_loader = tqdm(train_loader, desc="Training", unit="batch")

    for batch_idx, batch in enumerate(batch_loader):
        images, patient_data, labels = batch
        images = images.to(device, non_blocking=True)
        patient_data = patient_data.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images, patient_data)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        avg_loss = running_loss / (batch_idx + 1)
        batch_loader.set_postfix({
            "Batch Loss": f"{loss.item():.4f}",
            "Avg Loss": f"{avg_loss:.4f}"
        })

    epoch_time = time.time() - epoch_start_time
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{epochs}] Average Loss: {epoch_loss:.4f} | Time: {epoch_time:.2f} sec ({epoch_time/60:.2f} min)")


Epoch [1/10]


Training: 100%|██████████| 2803/2803 [19:55<00:00,  2.34batch/s, Batch Loss=0.1615, Avg Loss=0.2032]


Epoch [1/10] Average Loss: 0.2032 | Time: 1195.55 sec (19.93 min)

Epoch [2/10]


Training: 100%|██████████| 2803/2803 [20:42<00:00,  2.26batch/s, Batch Loss=0.2083, Avg Loss=0.1972]


Epoch [2/10] Average Loss: 0.1972 | Time: 1242.06 sec (20.70 min)

Epoch [3/10]


Training: 100%|██████████| 2803/2803 [20:34<00:00,  2.27batch/s, Batch Loss=0.2416, Avg Loss=0.1906]


Epoch [3/10] Average Loss: 0.1906 | Time: 1234.46 sec (20.57 min)

Epoch [4/10]


Training: 100%|██████████| 2803/2803 [19:53<00:00,  2.35batch/s, Batch Loss=0.1677, Avg Loss=0.1799]


Epoch [4/10] Average Loss: 0.1799 | Time: 1193.55 sec (19.89 min)

Epoch [5/10]


Training: 100%|██████████| 2803/2803 [19:20<00:00,  2.41batch/s, Batch Loss=0.1880, Avg Loss=0.1631]


Epoch [5/10] Average Loss: 0.1631 | Time: 1160.71 sec (19.35 min)

Epoch [6/10]


Training:  54%|█████▎    | 1504/2803 [10:27<08:30,  2.55batch/s, Batch Loss=0.1150, Avg Loss=0.1379]