In [1]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchvision import transforms, models
from tqdm import tqdm

from Model_A_Dataset import Model_A_Dataset

In [2]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Display the device that was selected above.
device

device(type='cuda')

## Load and pre-process data

In [4]:
# Read the label tables of the three splits (train / val / test) in one pass.
df_train, df_val, df_test = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/data_entries/miccai2023_nih-cxr-lt_labels_train.csv',
        '../data/data_entries/miccai2023_nih-cxr-lt_labels_val.csv',
        '../data/data_entries/miccai2023_nih-cxr-lt_labels_test.csv',
    )
)

In [5]:
# Folders holding the image files: one for train/val images, one for test images.
image_dir_train, image_dir_test = '../data/train_images', '../data/test_images'

In [6]:
def get_valid_image_ids(df, image_dir):
    """Return the ``id`` values of ``df`` that match a file name in ``image_dir``.

    Args:
        df: DataFrame with an ``id`` column holding image file names.
        image_dir: Directory whose entries are compared against ``df['id']``.

    Returns:
        pandas.Series: The ``id`` column restricted to rows whose value is an
        entry of ``image_dir``; the original index is kept.
    """
    present_on_disk = set(os.listdir(image_dir))
    has_file = df['id'].isin(present_on_disk)
    return df.loc[has_file, 'id']

In [7]:
# Keep only the label rows whose image ID is found in the matching image folder.
# The validation IDs are looked up in the training image folder.
valid_train_ids = get_valid_image_ids(df_train, image_dir_train)
df_train_valid = df_train.loc[df_train['id'].isin(valid_train_ids)]

valid_val_ids = get_valid_image_ids(df_val, image_dir_train)
df_val_valid = df_val.loc[df_val['id'].isin(valid_val_ids)]

valid_test_ids = get_valid_image_ids(df_test, image_dir_test)
df_test_valid = df_test.loc[df_test['id'].isin(valid_test_ids)]

In [8]:
# Pool the filtered train and val rows into one frame with a fresh 0..n-1 index.
df_train_val = pd.concat(
    [df_train_valid, df_val_valid],
    ignore_index=True,
)
# No row may be lost or duplicated by the concatenation.
assert len(df_train_val) == len(df_train_valid) + len(df_val_valid)

In [9]:
# Drop subj_id
# `drop` returns a new frame, so the frames built in earlier cells are left
# untouched. `errors='ignore'` keeps this cell re-runnable: on a second run the
# column is already gone, which would otherwise raise a KeyError.
df_train_val = df_train_val.drop(columns=['subj_id'], errors='ignore')

df_test_valid = df_test_valid.drop(columns=['subj_id'], errors='ignore')

In [10]:
# Sanity check: a row marked 'No Finding' must not have any other category set.
# The positional slice 1:-1 skips the first and the last column
# (presumably 'id' and 'No Finding' — verify if the column order changes).
inconsistent_rows = df_train_val.loc[
    df_train_val['No Finding'].eq(1)
    & df_train_val.iloc[:, 1:-1].sum(axis=1).gt(0)
]
assert len(inconsistent_rows) == 0

In [11]:
# Class name -> integer label passed to the dataset objects.
category_mapping = {
    'No Finding': 0,
    'Finding': 1,
}

In [12]:
# Image pipelines. Both resize to 224x224, convert to a tensor and normalise with
# the same per-channel mean/std; only the training pipeline adds a random
# horizontal flip as augmentation.
train_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [13]:
# Step 1: keep a stratified 80% subset of the pooled data (the other 20% is
# discarded). Despite its name, `df_half` therefore holds 80% of the rows.
df_half, _ = train_test_split(
    df_train_val,
    test_size=0.2,
    random_state=42,
    stratify=df_train_val['No Finding'],
)

# Step 2: split that subset 80/20 into training and validation frames, again
# stratified on 'No Finding'. Note that this rebinds `df_train` / `df_val`,
# which until here referred to the frames read from the CSV files.
df_train, df_val = train_test_split(
    df_half,
    test_size=0.2,
    random_state=42,
    stratify=df_half['No Finding'],
)

In [14]:
# Show (rows, columns) of the training and validation frames produced by the split.
df_train.shape, df_val.shape

((55375, 21), (13844, 21))

In [15]:
# Create the datasets (the dataloaders are built in a later cell).
# Train and val rows both read images from `image_dir_train`; the val and test
# datasets use `val_transform`, i.e. the pipeline without the random flip.
train_dataset = Model_A_Dataset(df_train, image_dir_train, category_mapping, transform=train_transform)
val_dataset = Model_A_Dataset(df_val, image_dir_train, category_mapping, transform=val_transform)
test_dataset = Model_A_Dataset(df_test_valid, image_dir_test, category_mapping, transform=val_transform)

In [16]:
# Number of samples per batch, shared by all three dataloaders.
batch_size = 32

In [17]:
# Wrap each dataset in a DataLoader. Only the training loader shuffles;
# validation and test keep the dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=False,
)

In [18]:
# Number of batches per epoch in each loader.
len(train_loader), len(val_loader), len(test_loader)

(1731, 433, 659)

In [19]:
# Fetch one training sample to check what the dataset returns:
# an image tensor, its label and the image ID.
img, label, img_id = train_dataset[1]
print(
    f"Image ID: {img_id}",
    f"Label: {label}",
    f"Image Shape: {img.shape}",
    sep="\n",
)

Image ID: 00022949_001.png
Label: 1
Image Shape: torch.Size([3, 224, 224])


In [20]:
# Look up the label row of the sample inspected above to cross-check its label.
df_train.loc[df_train['id'] == img_id]

Unnamed: 0,id,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,...,Nodule,Pleural Thickening,Pneumonia,Pneumothorax,Pneumoperitoneum,Pneumomediastinum,Subcutaneous Emphysema,Tortuous Aorta,Calcification of the Aorta,No Finding
68218,00022949_001.png,0,0,0,0,1,0,0,0,1,...,0,0,1,0,0,0,0,0,0,0


## Define train, validate and test

In [21]:
def train(model, dataloader, criterion, optimizer, device, thresholds=None):
    """Run one training epoch, then pick the decision threshold with the best F1.

    Args:
        model: Network being optimised; called as ``model(images)``.
        dataloader: Iterable yielding ``(images, labels, image_id)`` batches.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimiser; zeroed and stepped once per batch.
        device: Device the image and label tensors are moved to.
        thresholds: Optional iterable of candidate decision thresholds.
            Defaults to ``np.arange(0.4, 1.0, 0.025)``.

    Returns:
        tuple: ``(avg_loss, best_accuracy, best_f1, best_threshold)`` where
        ``avg_loss`` is the mean per-batch loss and the other three values
        describe the candidate threshold with the highest binary F1 on this
        epoch's training outputs. If no candidate reaches an F1 above 0, the
        threshold is 0.5, the F1 is 0 and the accuracy is the one measured
        at 0.5.
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    for images, labels, _ in tqdm(dataloader, desc="Training"):
        images = images.to(device)
        labels = labels.to(device)

        # The loss receives float targets shaped [batch_size, 1].
        labels = labels.view(-1, 1).float()

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Keep raw outputs and labels for the threshold search after the epoch.
        all_preds.append(outputs.detach().cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # NOTE(review): these outputs were produced with the model in train mode
    # while its weights were still being updated, so the selected threshold
    # reflects training-time behaviour — confirm this is intended.
    if thresholds is None:
        thresholds = np.arange(0.4, 1.0, 0.025)

    best_threshold = 0.5
    best_f1 = 0
    best_accuracy = None  # filled in once a threshold wins, or by the fallback

    for threshold in thresholds:
        preds = (all_preds > threshold).astype(float)
        f1 = f1_score(all_labels, preds, average='binary')
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold
            best_accuracy = accuracy_score(all_labels, preds)

    if best_accuracy is None:
        # No candidate produced a positive F1: report the accuracy at the
        # default threshold instead of a misleading 0.
        fallback_preds = (all_preds > best_threshold).astype(float)
        best_accuracy = accuracy_score(all_labels, fallback_preds)

    avg_loss = total_loss / len(dataloader)

    # Return a plain float so callers always get the same type, whether the
    # threshold came from the NumPy grid or from the 0.5 default.
    return avg_loss, best_accuracy, best_f1, float(best_threshold)


In [22]:
def validate(model, dataloader, criterion, device, threshold):
    """Evaluate the model on a dataloader without updating its weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable yielding ``(images, labels, image_id)`` batches.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        device: Device the image and label tensors are moved to.
        threshold: Outputs strictly greater than this value count as class 1.

    Returns:
        tuple: ``(avg_loss, f1)`` — the mean per-batch loss and the binary F1
        of the thresholded predictions over the whole dataloader.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)

            # The loss receives float targets shaped [batch_size, 1].
            labels = labels.view(-1, 1).float()

            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            preds = (outputs > threshold).float()

            # Collect hard predictions and labels for the F1 computed below.
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    f1 = f1_score(all_labels, all_preds, average='binary')

    return avg_loss, f1

In [23]:
def test(model, dataloader, device, threshold):
    """Predict over a dataloader and score the thresholded predictions.

    Args:
        model: Network to evaluate; switched to eval mode.
        dataloader: Iterable yielding ``(images, labels, image_id)`` batches.
        device: Device the image and label tensors are moved to.
        threshold: Outputs strictly greater than this value count as class 1.

    Returns:
        tuple: ``(all_preds, all_labels, f1)`` — the concatenated hard
        predictions, the concatenated labels (both NumPy arrays) and the
        binary F1 between them.
    """
    model.eval()
    pred_batches, label_batches = [], []

    with torch.no_grad():
        for images, labels, _ in tqdm(dataloader, desc="Testing"):
            images = images.to(device)
            labels = labels.to(device)
            # Same [batch_size, 1] float layout as the predictions.
            labels = labels.view(-1, 1).float()

            outputs = model(images)
            hard_preds = (outputs > threshold).float()

            pred_batches.append(hard_preds.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    all_preds = np.concatenate(pred_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)

    return all_preds, all_labels, f1_score(all_labels, all_preds, average='binary')

## Define the model: train, validate and test it

In [24]:
# ResNet-152 with pretrained weights; its final `fc` layer is replaced by a
# small MLP head that ends in a single sigmoid unit.
model = models.resnet152(weights='IMAGENET1K_V2')
model.fc = nn.Sequential(
    # block 1: backbone features -> 512
    nn.Linear(in_features=model.fc.in_features, out_features=512),
    nn.BatchNorm1d(num_features=512),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    # block 2: 512 -> 256
    nn.Linear(in_features=512, out_features=256),
    nn.BatchNorm1d(num_features=256),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    # output: one value squashed into (0, 1)
    nn.Linear(in_features=256, out_features=1),
    nn.Sigmoid(),
)

model = model.to(device)
print(device)

cuda


In [25]:
# Loss, optimiser and learning-rate schedule.
# Binary cross-entropy on the model's sigmoid output.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Gentle decay: the learning rate is multiplied by 0.9 after every 5 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.9, step_size=5)

In [26]:
# Training-loop settings.
num_epochs = 100                          # upper bound on the number of epochs

# Early stopping and checkpointing state.
patience = 7                              # epochs without improvement before stopping
counter = 0                               # consecutive epochs without improvement
# Despite its name, the training loop compares this value against the
# validation loss, so it starts at +inf and is lowered on each improvement.
best_f1 = float("inf")
best_model_path = "best_model_a_v3.pth"   # file the best weights are written to

In [28]:
# Training and validation
# Threshold chosen in the same epoch as the checkpoint currently on disk.
checkpoint_threshold = None

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss, train_acc, train_f1, best_threshold = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_f1 = validate(model, val_loader, criterion, device, best_threshold)

    print(f"Train Loss: {train_loss:.4f}, F1: {train_f1:.4f}, Accuracy: {train_acc:.4f}, Threshold: {best_threshold}")
    print(f"Val Loss: {val_loss:.4f}, F1: {val_f1:.4f}")

    scheduler.step()

    # Early stopping monitors the validation LOSS (lower is better).
    # `best_f1` is a historical name: it stores the lowest validation loss so far.
    if val_loss < best_f1:
        print("Validation loss improved. Saving model...")
        best_f1 = val_loss
        checkpoint_threshold = best_threshold
        torch.save(model.state_dict(), best_model_path)
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

    print('-' * 150)

# The following cells reload the saved checkpoint and evaluate it with
# `best_threshold`. Point that name at the threshold selected in the
# checkpoint's epoch instead of the one from the last epoch that ran.
if checkpoint_threshold is not None:
    best_threshold = checkpoint_threshold

Epoch 1/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.61it/s]


Train Loss: 0.6141, F1: 0.6272, Accuracy: 0.6581, Threshold: 0.4
Val Loss: 0.6060, F1: 0.6257
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 2/100


Training: 100%|██████████| 1731/1731 [06:39<00:00,  4.33it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.57it/s]


Train Loss: 0.6070, F1: 0.6355, Accuracy: 0.6673, Threshold: 0.4
Val Loss: 0.6001, F1: 0.6476
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 3/100


Training: 100%|██████████| 1731/1731 [06:38<00:00,  4.34it/s]
Validating: 100%|██████████| 433/433 [01:04<00:00,  6.71it/s]


Train Loss: 0.6011, F1: 0.6421, Accuracy: 0.6742, Threshold: 0.4
Val Loss: 0.6277, F1: 0.6477
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 4/100


Training: 100%|██████████| 1731/1731 [06:36<00:00,  4.37it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.57it/s]


Train Loss: 0.5958, F1: 0.6480, Accuracy: 0.6752, Threshold: 0.4
Val Loss: 0.6167, F1: 0.6166
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 5/100


Training: 100%|██████████| 1731/1731 [06:36<00:00,  4.37it/s]
Validating: 100%|██████████| 433/433 [01:06<00:00,  6.47it/s]


Train Loss: 0.5950, F1: 0.6490, Accuracy: 0.6783, Threshold: 0.4
Val Loss: 0.5969, F1: 0.6251
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 6/100


Training: 100%|██████████| 1731/1731 [06:38<00:00,  4.34it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.65it/s]


Train Loss: 0.5884, F1: 0.6537, Accuracy: 0.6819, Threshold: 0.4
Val Loss: 0.5941, F1: 0.6274
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 7/100


Training: 100%|██████████| 1731/1731 [06:37<00:00,  4.36it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.64it/s]


Train Loss: 0.5857, F1: 0.6568, Accuracy: 0.6879, Threshold: 0.4
Val Loss: 0.5883, F1: 0.6402
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 8/100


Training: 100%|██████████| 1731/1731 [06:39<00:00,  4.33it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.59it/s]


Train Loss: 0.5823, F1: 0.6584, Accuracy: 0.6889, Threshold: 0.4
Val Loss: 0.5880, F1: 0.6395
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 9/100


Training: 100%|██████████| 1731/1731 [06:37<00:00,  4.36it/s]
Validating: 100%|██████████| 433/433 [01:04<00:00,  6.67it/s]


Train Loss: 0.5807, F1: 0.6623, Accuracy: 0.6940, Threshold: 0.4
Val Loss: 0.5872, F1: 0.6604
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 10/100


Training: 100%|██████████| 1731/1731 [06:38<00:00,  4.34it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.64it/s]


Train Loss: 0.5755, F1: 0.6669, Accuracy: 0.6959, Threshold: 0.4
Val Loss: 0.5887, F1: 0.6521
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 11/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:04<00:00,  6.67it/s]


Train Loss: 0.5713, F1: 0.6720, Accuracy: 0.6989, Threshold: 0.4
Val Loss: 0.5860, F1: 0.6537
Validation F1 improved. Saving model...
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 12/100


Training: 100%|██████████| 1731/1731 [06:37<00:00,  4.35it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.61it/s]


Train Loss: 0.5655, F1: 0.6742, Accuracy: 0.7010, Threshold: 0.4
Val Loss: 0.5882, F1: 0.6511
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 13/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.58it/s]


Train Loss: 0.5607, F1: 0.6813, Accuracy: 0.7058, Threshold: 0.4
Val Loss: 0.5962, F1: 0.6461
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 14/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.60it/s]


Train Loss: 0.5547, F1: 0.6856, Accuracy: 0.7122, Threshold: 0.4
Val Loss: 0.5919, F1: 0.6416
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 15/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.56it/s]


Train Loss: 0.5479, F1: 0.6919, Accuracy: 0.7168, Threshold: 0.4
Val Loss: 0.5953, F1: 0.6531
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 16/100


Training: 100%|██████████| 1731/1731 [06:37<00:00,  4.36it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.57it/s]


Train Loss: 0.5354, F1: 0.7009, Accuracy: 0.7267, Threshold: 0.4
Val Loss: 0.5988, F1: 0.6418
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 17/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.62it/s]


Train Loss: 0.5220, F1: 0.7117, Accuracy: 0.7355, Threshold: 0.4
Val Loss: 0.6006, F1: 0.6533
------------------------------------------------------------------------------------------------------------------------------------------------------
Epoch 18/100


Training: 100%|██████████| 1731/1731 [06:35<00:00,  4.38it/s]
Validating: 100%|██████████| 433/433 [01:05<00:00,  6.59it/s]

Train Loss: 0.5077, F1: 0.7243, Accuracy: 0.7460, Threshold: 0.4
Val Loss: 0.6166, F1: 0.6473
Early stopping triggered.





In [29]:
# Restore the weights of the best checkpoint written by the training loop.
# `map_location=device` lets a checkpoint saved from a GPU run load on a
# CPU-only machine; `weights_only=True` restricts unpickling to tensor data.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

<All keys matched successfully>

In [30]:
# Evaluate the restored model on the test loader with the current threshold.
all_preds, all_labels, test_f1 = test(
    model=model,
    dataloader=test_loader,
    device=device,
    threshold=best_threshold,
)

Testing: 100%|██████████| 659/659 [01:38<00:00,  6.67it/s]


In [31]:
# Binary F1 returned by `test`.
print(f"Test F1: {test_f1}")

Test F1: 0.7945465002198994


In [32]:
# f1_score, precision_score and recall_score are imported in the notebook's
# import cell, so the functions defined earlier can use them on a fresh kernel.

# Calculate F1 score (micro and macro)
f1_micro = f1_score(all_labels, all_preds, average='micro')
f1_macro = f1_score(all_labels, all_preds, average='macro')

print(f"F1 Micro: {f1_micro}, F1 Macro: {f1_macro}")

# Precision and recall (macro-averaged over the two classes)
precision = precision_score(all_labels, all_preds, average='macro')
recall = recall_score(all_labels, all_preds, average='macro')
print(f"Precision: {precision}, Recall: {recall}")

F1 Micro: 0.7119206868744367, F1 Macro: 0.6563385520221927
Precision: 0.7118190441523605, Recall: 0.6530530158371544


In [33]:
# classification_report is imported in the notebook's import cell.
# Pass the class names as an explicit list ordered by their integer label, so
# the report rows cannot be mislabelled if the mapping's insertion order and
# its label values ever diverge.
print(classification_report(
    all_labels,
    all_preds,
    target_names=sorted(category_mapping, key=category_mapping.get),
))

              precision    recall  f1-score   support

  No Finding       0.71      0.41      0.52      8015
     Finding       0.71      0.90      0.79     13066

    accuracy                           0.71     21081
   macro avg       0.71      0.65      0.66     21081
weighted avg       0.71      0.71      0.69     21081



In [34]:
# Print the first 10 true labels next to their predictions.
# Slicing (rather than indexing 0..9) also copes with fewer than 10 test rows.
for true_label, predicted_label in zip(all_labels[:10], all_preds[:10]):
    print(f"True: {true_label}, Predicted: {predicted_label}")


True: [0.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [1.], Predicted: [1.]
True: [0.], Predicted: [1.]
True: [1.], Predicted: [1.]
