In [None]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from PIL import Image
from sklearn.metrics import multilabel_confusion_matrix, classification_report, roc_curve, auc
import os

from PIL import Image, ImageOps
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from sklearn.model_selection import train_test_split
import json

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

In [None]:
# Directory prefix that the inference loop prepends to every relative image path.
BASE_PATH = "path_to_directory"

# Label table; the inference loop looks rows up by the 'lateral_image' column.
# NOTE(review): this is the same placeholder string as BASE_PATH -- confirm it
# points at the CSV file rather than at the image directory.
dataset = pd.read_csv("path_to_directory")

# Row count of the label table, shown as the cell output.
dataset.shape[0]

In [None]:
class ResNet50(nn.Module):
    """ResNet-50 backbone whose final layer is replaced by a linear + sigmoid head.

    Every one of the ``out_size`` outputs goes through its own sigmoid, so the
    forward pass returns one independent value in (0, 1) per output unit
    (a multi-label head, not a softmax over mutually exclusive classes).

    Args:
        out_size: Number of output units of the replacement head.
        weights: Forwarded to ``torchvision.models.resnet50`` to initialise the
            backbone. Defaults to ``ResNet50_Weights.IMAGENET1K_V1``, which is
            what this class always used. Pass ``None`` to skip pretrained
            initialisation, e.g. when a checkpoint is loaded right after
            construction and the ImageNet weights would be overwritten anyway.
    """

    def __init__(self, out_size, weights=models.ResNet50_Weights.IMAGENET1K_V1):
        super().__init__()
        self.resnet50 = models.resnet50(weights=weights)
        num_ftrs = self.resnet50.fc.in_features
        # Keep the attribute name and the Sequential(Linear, Sigmoid) layout:
        # the parameter names stored in a state dict are derived from them.
        self.resnet50.fc = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the sigmoid-activated head output for the image batch ``x``."""
        return self.resnet50(x)

In [None]:
# Size of the model head; the evaluation cells compare its outputs against the
# 14 entries of `classes`, so the two must stay in sync.
N_LABELS = 14

# Rebuild the architecture, then restore the saved weights into it.
best_model = ResNet50(out_size=N_LABELS)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written from a GPU still loads when this notebook falls back to the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints that come
# from a trusted source.
best_model.load_state_dict(torch.load('path_to_best_model.pth', map_location=device))
best_model = best_model.to(device)
# eval() switches layers such as batch norm to inference behaviour; it returns
# the module, so the architecture is displayed as the cell output.
best_model.eval()

In [None]:
# Relative paths of the test images, stored under the 'lateral_images' key.
with open('lateral_test.json', 'r') as paths_file:
    test_image_paths = json.load(paths_file)['lateral_images']

# Label columns, in the order used for every per-class array and plot below.
classes = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Enlarged Cardiomediastinum',
    'Fracture',
    'Lung Lesion',
    'Lung Opacity',
    'No Finding',
    'Pleural Effusion',
    'Pleural Other',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

# Accumulators filled by the inference loop: ground truth, thresholded
# predictions, and raw per-class scores.
y_true, y_pred, y_pred_proba = [], [], []

# Preprocessing applied to each image before it is fed to the model.
# Resize with a single int rescales the shorter edge to 256 and keeps the
# aspect ratio; there is no crop, so the tensor size follows the input image.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# Output folder for the RGB copies written by the inference loop.
os.makedirs('processed_images', exist_ok=True)

In [None]:
# Run the model on every test image and collect ground truth, thresholded
# predictions and raw scores.
#
# The per-image rows are gathered in fresh local lists so this cell can be
# re-run: after a first run y_true / y_pred / y_pred_proba are NumPy arrays,
# which have no .append().
true_rows, pred_rows, proba_rows = [], [], []

# A class counts as predicted when its sigmoid score exceeds this value.
DECISION_THRESHOLD = 0.5

for i, img_path in enumerate(test_image_paths):
    full_path = f"{BASE_PATH}/{img_path}"

    # Ground truth: first label row whose 'lateral_image' equals this path.
    matches = dataset.loc[dataset['lateral_image'] == img_path, classes]
    if matches.empty:
        raise KeyError(f"No label row found in dataset for test image {img_path!r}")
    true_labels = matches.to_numpy()[0]
    true_rows.append(true_labels)

    # convert() returns an independent in-memory image, so the file handle can
    # be released as soon as the with-block exits.
    with Image.open(full_path) as raw_img:
        img = raw_img.convert('RGB')

    # Keep a copy of the RGB-converted input (saved before `transform` is applied).
    save_path = f"processed_images/image_{i}.jpg"
    img.save(save_path)

    # Batch of one; gradients are not needed for evaluation.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = best_model(img_tensor)
    pred_proba = output.cpu().numpy()[0]
    pred_labels = (pred_proba > DECISION_THRESHOLD).astype(float)
    pred_rows.append(pred_labels)
    proba_rows.append(pred_proba)

# Stack the per-image rows into (n_images, n_classes) arrays.
y_true = np.array(true_rows)
y_pred = np.array(pred_rows)
y_pred_proba = np.array(proba_rows)

# One 2x2 confusion matrix per class.
mcm = multilabel_confusion_matrix(y_true, y_pred)

In [None]:
# Plot confusion matrices for each class
def plot_multilabel_confusion_matrix(confusion_mtx, class_names, ncols=3,
                                     save_path='confusion_matrices.png'):
    """Draw one 2x2 confusion matrix per class on a subplot grid and save it.

    Args:
        confusion_mtx: Array of shape (n_classes, 2, 2). Each matrix is expected
            in the ``[[TN, FP], [FN, TP]]`` layout produced by
            ``sklearn.metrics.multilabel_confusion_matrix`` (rows = true label,
            columns = predicted label, negative class first).
        class_names: Panel titles, indexed like ``confusion_mtx``.
        ncols: Number of panels per row of the grid.
        save_path: File the figure is written to.

    Each panel is displayed positive-first, i.e. as ``[[TP, FN], [FP, TN]]``,
    so the cells line up with the 'Positive' / 'Negative' tick labels. The
    figure is saved and then closed; nothing is shown or returned.
    """
    num_classes = confusion_mtx.shape[0]
    nrows = (num_classes + ncols - 1) // ncols  # ceil(num_classes / ncols)

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(ncols * 4, nrows * 4), squeeze=False)
    axes = axes.flatten()

    for i in range(num_classes):
        ax = axes[i]
        # Reverse both axes so index 0 is the positive class; without this the
        # TN cell would sit under the 'Positive'/'Positive' tick labels.
        panel = confusion_mtx[i][::-1, ::-1]
        ax.matshow(panel, cmap=plt.cm.Blues, alpha=0.5)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(class_names[i])

        ax.set_xticks([0, 1])
        ax.set_xticklabels(['Positive', 'Negative'])
        ax.set_yticks([0, 1])
        ax.set_yticklabels(['Positive', 'Negative'])

        # Write the count into each cell (x = column, y = row).
        for j in range(panel.shape[0]):
            for k in range(panel.shape[1]):
                ax.text(k, j, panel[j, k], ha='center', va='center')

    # Hide the grid cells that have no class assigned.
    for i in range(num_classes, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)

In [None]:
# Save the per-class confusion-matrix figure.
plot_multilabel_confusion_matrix(mcm, classes)

# Print and save the classification report.
# zero_division=0 reports the same 0 that the default setting produces for
# classes with no predicted or no true samples, without the warning it emits.
report = classification_report(y_true, y_pred, target_names=classes, zero_division=0)
print(report)
with open('classification_report.txt', 'w') as f:
    f.write(report)

# Per-class precision / recall / F1 derived from the confusion matrices.
precision = []
recall = []
f1_score = []

for cm in mcm:
    # Each matrix is laid out [[TN, FP], [FN, TP]].
    tn, fp, fn, tp = cm.ravel()
    class_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    class_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    denominator = class_precision + class_recall
    class_f1 = 2 * class_precision * class_recall / denominator if denominator > 0 else 0
    precision.append(class_precision)
    recall.append(class_recall)
    f1_score.append(class_f1)

# Grouped bar chart: three bars (precision, recall, F1) per class.
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(classes))
width = 0.25

ax.bar(x - width, precision, width, label='Precision')
ax.bar(x, recall, width, label='Recall')
ax.bar(x + width, f1_score, width, label='F1-score')

ax.set_ylabel('Scores')
ax.set_title('Precision, Recall, and F1-score for each class')
ax.set_xticks(x)
ax.set_xticklabels(classes, rotation=45, ha='right')
ax.legend()

plt.tight_layout()
plt.savefig('overall_metrics.png')
plt.close(fig)

# ROC curve per class; each AUC is kept so the average below does not have to
# recompute the curves.
plt.figure(figsize=(10, 8))
class_aucs = []

for i, class_name in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_true[:, i], y_pred_proba[:, i])
    roc_auc = auc(fpr, tpr)
    class_aucs.append(roc_auc)

    plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.2f})')

# Diagonal = chance-level reference.
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc="lower right")
plt.savefig('roc_curves.png')
plt.close()

# Average AUC over the classes for which it is defined. A NaN AUC (presumably a
# class with only one label value in this test split -- verify) would otherwise
# turn the whole average into NaN.
class_aucs = np.asarray(class_aucs, dtype=float)
undefined_classes = [name for name, value in zip(classes, class_aucs) if np.isnan(value)]
if undefined_classes:
    print(f"AUC undefined for: {', '.join(undefined_classes)} (excluded from the average)")
if len(undefined_classes) < len(classes):
    avg_auc = np.nanmean(class_aucs)
else:
    avg_auc = float('nan')
with open('average_auc.txt', 'w') as f:
    f.write(f"Average AUC: {avg_auc:.3f}")