In [1]:
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import os
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
class MultiLabelResNet50(nn.Module):
    """ResNet-50 backbone with a small MLP head ending in a sigmoid.

    The backbone's ``fc`` layer is replaced by
    ``Linear(in_features, hidden_size) -> ReLU -> Dropout -> Linear(hidden_size, num_classes) -> Sigmoid``,
    so every output is an independent score in (0, 1) (multi-label style).

    Args:
        num_classes: number of output scores.
        hidden_size: width of the intermediate layer of the head.
        dropout: dropout probability inside the head (default 0.5, the original value).
        pretrained: if True (default), initialise the backbone from torchvision's
            ``ResNet50_Weights.IMAGENET1K_V1``. Pass False when a full checkpoint is
            loaded right after construction, to skip the unnecessary weight download.
    """

    def __init__(self, num_classes, hidden_size, dropout=0.5, pretrained=True):
        super().__init__()

        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.resnet50(weights=weights)

        # Replace the backbone's classifier with the multi-label head. Keep the layer
        # order unchanged: saved state dicts address these layers by their Sequential index.
        self.base_model.fc = nn.Sequential(
            nn.Linear(self.base_model.fc.in_features, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid scores produced by the backbone + head for input batch ``x``."""
        return self.base_model(x)

class MultiLabelResNet50_2(nn.Module):
    """ResNet-50 backbone with a narrower MLP head: 256 hidden units and dropout 0.6.

    Same layout as ``MultiLabelResNet50`` (Linear -> ReLU -> Dropout -> Linear -> Sigmoid
    replacing the backbone's ``fc``), but with these fixed head hyper-parameters.

    Args:
        num_classes: number of output scores.
        pretrained: if True (default), initialise the backbone from torchvision's
            ``ResNet50_Weights.IMAGENET1K_V1``. Pass False when a full checkpoint is
            loaded right after construction, to skip the unnecessary weight download.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()

        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.resnet50(weights=weights)

        # Replace the backbone's classifier with the multi-label head. Keep the layer
        # order unchanged: saved state dicts address these layers by their Sequential index.
        self.base_model.fc = nn.Sequential(
            nn.Linear(self.base_model.fc.in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(256, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid scores produced by the backbone + head for input batch ``x``."""
        return self.base_model(x)

class MultiLabelDenseNet121(nn.Module):
    """DenseNet-121 backbone with a small MLP head ending in a sigmoid.

    The backbone's ``classifier`` is replaced by
    ``Linear(in_features, hidden_size) -> ReLU -> Dropout -> Linear(hidden_size, num_classes) -> Sigmoid``.

    Args:
        num_classes: number of output scores.
        hidden_size: width of the intermediate layer of the head.
        dropout: dropout probability inside the head (default 0.5, the original value).
        pretrained: if True (default), initialise the backbone from torchvision's
            ``DenseNet121_Weights.IMAGENET1K_V1``. Pass False when a full checkpoint is
            loaded right after construction, to skip the unnecessary weight download.
    """

    def __init__(self, num_classes, hidden_size, dropout=0.5, pretrained=True):
        super().__init__()

        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.densenet121(weights=weights)

        # Replace the classifier with the multi-label head (layer order matters for
        # loading saved state dicts).
        self.base_model.classifier = nn.Sequential(
            nn.Linear(self.base_model.classifier.in_features, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid scores produced by the backbone + head for input batch ``x``."""
        return self.base_model(x)

In [9]:
def load_eval_model(model: nn.Module, checkpoint_path: str) -> nn.Module:
    """Move ``model`` to ``device``, load the state dict at ``checkpoint_path``, return it in eval mode.

    ``map_location=device`` makes checkpoints saved from GPU tensors loadable on a
    CPU-only machine (without it ``torch.load`` tries to restore them onto CUDA).
    """
    model = model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model.eval()


# Single-output classifiers, one per label (mapping as used in the inference cell).
# Pleural Effusion
model_pe = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=512),
                           'models/best_pe1_model.pth')
# Cardiomegaly
model_cm = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=512),
                           'models/best_cm0_model.pth')
# Enlarged Cardiomediastinum
model_ecm = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=512),
                            'models/best_ec1_model.pth')

# Lung Opacity: one view-agnostic model plus two frontal and two lateral models.
model_lo = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                           'amb_models/lo_partial/epoch_3.pth')
model_lo_fr = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                              'amb_models/lo_partial_frontal/epoch_6.pth')
model_lo_l = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                             'amb_models/lo_partial_lateral/epoch_4.pth')
model_lo_fr1 = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                               'amb_models/lo_partial_frontal_2/epoch_7.pth')
model_lo_l1 = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                              'amb_models/lo_partial_lateral_2/epoch_8.pth')

# Fracture
model_fr = load_eval_model(MultiLabelResNet50_2(num_classes=1),
                           'models/best_fr4_model.pth')
# No Finding
model_nf = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
                           'amb_models/nf_partial/epoch_3.pth')

# Disabled models (the matching lines in the inference cell are commented out too):
# model_sd = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
#                            'amb_models/sd_partial/epoch_2.pth')  # Support Devices
# model_pn = load_eval_model(MultiLabelResNet50(num_classes=1, hidden_size=256),
#                            'amb_models/pn_partial/epoch_6.pth')  # Pneumonia

In [10]:
import torch
import os
import pandas as pd
from tqdm import tqdm

# `device` and all model_* objects come from the setup cells above.

# Directories of pre-processed test tensors (*.pt), keyed by view type.
dir_dict = {
    "frontal": "input_images/test_frontal",
    "lateral": "input_images/test_lateral"
}

# Submission columns, in output order ("Id" first).
columns = ["Id", "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
           "Pneumonia", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"]

# Constant scores for labels that have no dedicated model; labels missing here
# and without a model fall back to 0.
average_values = {
    "No Finding": -0.734655,
    "Enlarged Cardiomediastinum": -0.275805,
    "Cardiomegaly": 0.190770,
    "Lung Opacity": 0.836288,
    "Pneumonia": 0.031183,
    "Pleural Other": 0.521795,
    "Fracture": 0.392374,
    "Support Devices": 0.888289
}

batch_size = 64
predictions = []  # reset so re-running this cell does not append duplicate rows


def rescaled_scores(model: nn.Module, input_batch: torch.Tensor) -> np.ndarray:
    """Run a single-output sigmoid model and rescale its scores from [0, 1] to [-1, 1]."""
    with torch.no_grad():
        return model(input_batch)[:, 0].cpu().numpy() * 2 - 1


def predict_batch(input_batch: torch.Tensor, lo_view_models) -> dict:
    """Return ``{column: per-image score array}`` for every label that has a model.

    ``lo_view_models`` is the pair of view-specific Lung Opacity models for this batch's view.
    """
    lo_general = rescaled_scores(model_lo, input_batch)
    lo_view, lo_view2 = (rescaled_scores(m, input_batch) for m in lo_view_models)
    return {
        "No Finding": rescaled_scores(model_nf, input_batch),
        "Enlarged Cardiomediastinum": rescaled_scores(model_ecm, input_batch),
        "Cardiomegaly": rescaled_scores(model_cm, input_batch),
        # Lung Opacity is the mean of the view-agnostic and both view-specific models.
        "Lung Opacity": (lo_general + lo_view + lo_view2) / 3,
        "Pleural Effusion": rescaled_scores(model_pe, input_batch),
        "Fracture": rescaled_scores(model_fr, input_batch),
        # "Support Devices": rescaled_scores(model_sd, input_batch),  # model disabled
        # "Pneumonia": rescaled_scores(model_pn, input_batch),        # model disabled
    }


for view_type, test_dir in dir_dict.items():
    if view_type == "frontal":
        lo_view_models = (model_lo_fr, model_lo_fr1)
    else:
        lo_view_models = (model_lo_l, model_lo_l1)
    file_list = [f for f in os.listdir(test_dir) if f.endswith(".pt")]

    with tqdm(total=len(file_list), desc=f"Processing {view_type}") as progress:
        # Fixed-size chunks; the last chunk simply holds the remainder.
        for start in range(0, len(file_list), batch_size):
            batch_files = file_list[start:start + batch_size]
            # map_location avoids failures when tensors were saved on a different device.
            input_batch = torch.stack([
                torch.load(os.path.join(test_dir, f), map_location=device) for f in batch_files
            ])
            scores = predict_batch(input_batch, lo_view_models)

            for i, filename in enumerate(batch_files):
                image_id = filename.split('.')[0]  # Id = filename up to the first '.'
                predictions.append([image_id] + [
                    scores[col][i] if col in scores else average_values.get(col, 0)
                    for col in columns[1:]
                ])
            progress.update(len(batch_files))

# Save all predictions.
# NOTE(review): Ids are strings, so the sort is lexicographic; and an Id present in
# both view directories would produce two rows — confirm against the submission format.
df_predictions = pd.DataFrame(predictions, columns=columns)
df_predictions = df_predictions.sort_values(by="Id")
df_predictions.to_csv('amb_test_predictions.csv', index=False)

print("✅ Predictions saved to 'amb_test_predictions.csv'")

Processing frontal:   0%|          | 0/19347 [00:00<?, ?it/s]

Processing frontal: 100%|██████████| 19347/19347 [09:58<00:00, 32.31it/s] 
Processing lateral: 100%|██████████| 3249/3249 [01:48<00:00, 29.90it/s]


✅ Predictions saved to 'amb_test_predictions.csv'


In [6]:
# df_predictions was already built from `predictions` and sorted by Id in the
# inference cell; rebuilding it here is redundant, so just preview it.
df_predictions.head()

Unnamed: 0,Id,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Pneumonia,Pleural Effusion,Pleural Other,Fracture,Support Devices
811,100018,-0.960788,-0.551876,0.209412,0.77288,0.031183,-0.925029,0.521795,0.177863,0.888289
12203,100019,-0.965522,-0.620867,-0.716572,0.929377,0.031183,0.126851,0.521795,0.282742,0.888289
14779,100022,-0.996358,0.052508,-0.716739,0.944411,0.031183,0.740513,0.521795,0.602488,0.888289
12851,100023,-0.946412,0.035041,-0.613445,0.994,0.031183,0.949374,0.521795,0.29663,0.888289
6047,100053,-0.274712,0.094983,0.718571,0.930785,0.031183,-0.666675,0.521795,-0.006521,0.888289
