In [None]:
# Import all necessary libraries first
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, classification_report, f1_score

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Input-pipeline settings shared by the cells below.
IMG_SIZE = 240     # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32    # samples per DataLoader batch
NUM_WORKERS = 4    # DataLoader worker processes

class GaussianBlur:
    """Transform that blurs an image with OpenCV and returns it as a tensor.

    Calling an instance converts the input to a NumPy array, applies
    ``cv2.GaussianBlur`` with a 3x3 kernel and a sigma argument of 0, and
    passes the blurred array through ``transforms.functional.to_tensor``.
    """

    # Arguments forwarded to cv2.GaussianBlur on every call.
    _KERNEL = (3, 3)
    _SIGMA = 0

    def __call__(self, img):
        pixels = np.array(img)
        return transforms.functional.to_tensor(
            cv2.GaussianBlur(pixels, self._KERNEL, self._SIGMA)
        )

# Channel statistics passed to Normalize. These are the standard ImageNet
# values used with torchvision's pretrained models.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_step():
    """Return a fresh Resize to the square IMG_SIZE x IMG_SIZE input size."""
    return transforms.Resize((IMG_SIZE, IMG_SIZE))


def _blur_and_normalize_steps():
    """Return the blur -> tensor -> normalise tail shared by every pipeline."""
    return [
        GaussianBlur(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# Deterministic preprocessing applied to every split.
core_transforms = transforms.Compose([_resize_step(), *_blur_and_normalize_steps()])

# Training pipeline: the same preprocessing with random flip / rotation /
# affine augmentation inserted between the resize and the blur.
train_transforms = transforms.Compose([
    _resize_step(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.RandomAffine(0, translate=(0.2, 0.2), scale=(0.8, 1.2)),
    *_blur_and_normalize_steps(),
])

# Validation and test use the un-augmented pipeline.
val_transforms = core_transforms
test_transforms = core_transforms

# Load datasets
def load_data(data_root='/kaggle/input/chest-xray/chest_xray_jj_811',
              batch_size=None, num_workers=None):
    """Build the train, validation and test DataLoaders.

    Each split is read with ``datasets.ImageFolder`` from
    ``<data_root>/train``, ``<data_root>/val`` and ``<data_root>/test``.
    The train split uses ``train_transforms`` and is shuffled; the val and
    test splits use ``val_transforms`` / ``test_transforms`` and keep their
    on-disk order.

    Args:
        data_root: Directory containing the ``train``, ``val`` and ``test``
            sub-directories.
        batch_size: Samples per batch; ``None`` means the module-level
            ``BATCH_SIZE``.
        num_workers: DataLoader worker processes; ``None`` means the
            module-level ``NUM_WORKERS``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    # Resolve defaults at call time so edits to the config constants are
    # picked up without re-defining this function.
    if batch_size is None:
        batch_size = BATCH_SIZE
    if num_workers is None:
        num_workers = NUM_WORKERS

    # (sub-directory, transform pipeline, shuffle?) for each split.
    split_specs = (
        ('train', train_transforms, True),
        ('val', val_transforms, False),
        ('test', test_transforms, False),
    )

    loaders = []
    for split, split_transforms, shuffle in split_specs:
        dataset = datasets.ImageFolder(
            root=f'{data_root}/{split}',
            transform=split_transforms
        )
        loaders.append(DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers
        ))

    return tuple(loaders)

# Build the train/val/test loaders from load_data()'s built-in dataset location and loader settings.
train_loader, val_loader, test_loader = load_data()

In [None]:
class ChestXrayModel(nn.Module):
    """EfficientNet-B1 backbone with a small replacement classification head.

    All backbone parameters are frozen at construction. The stock classifier
    is replaced by Dropout -> Linear(in_features, 128) -> ReLU -> Dropout ->
    Linear(128, num_classes), whose parameters are trainable.

    Args:
        num_classes: Number of output logits (default 3).
        pretrained: Forwarded to ``models.efficientnet_b1``. Pass ``False``
            to skip loading pretrained weights, e.g. when a checkpoint is
            about to be loaded over the model anyway.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
        # `weights=`. The keyword is kept so older installs still work.
        self.efficientnet = models.efficientnet_b1(pretrained=pretrained)

        # Freeze the whole backbone. The new head created below is built
        # afterwards, so its parameters keep requires_grad=True.
        for param in self.efficientnet.parameters():
            param.requires_grad = False

        # Size the new head from the input width of the stock final Linear.
        num_ftrs = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return the class logits produced by the wrapped network for ``x``."""
        return self.efficientnet(x)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. It is
            assumed to return the batch *mean* (the default reduction for
            ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    # NOTE(review): train() also puts any BatchNorm layers in a frozen
    # backbone into training mode, so their running statistics keep
    # updating. Confirm this is intended.
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for inputs, labels in tqdm(loader, desc='Training'):
        inputs, labels = inputs.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    return loss_sum / total, correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. It is
            assumed to return the batch *mean* (the default reduction for
            ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over samples.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc='Validation'):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one in the average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    return loss_sum / total, correct / total

# Epoch budget for each training phase.
PHASE1_EPOCHS = 20
PHASE2_EPOCHS = 20


def _run_phase(model, optimizer, criterion, train_loader, val_loader, device,
               num_epochs, phase_history, epoch_label):
    """Train and validate for ``num_epochs`` epochs, recording metrics.

    After every epoch the train/val loss and accuracy are appended to the
    lists in ``phase_history`` and printed. ``epoch_label`` is the prefix of
    the progress line (e.g. ``'Epoch'``).
    """
    for epoch in range(num_epochs):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        phase_history['train_loss'].append(train_loss)
        phase_history['val_loss'].append(val_loss)
        phase_history['train_acc'].append(train_acc)
        phase_history['val_acc'].append(val_acc)

        # The total comes from num_epochs so the progress line cannot drift
        # from the real loop length.
        print(f'{epoch_label} {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}\n')


# Per-phase metric history; each phase gets its own independent lists.
history = {
    phase: {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    for phase in ('phase1', 'phase2')
}

# train_loader / val_loader were already built by the earlier load_data()
# call and are reused here rather than re-scanning the dataset folders.

# Initialize model
model = ChestXrayModel().to(device)
criterion = nn.CrossEntropyLoss()

# Phase 1: only the classifier head is handed to the optimizer.
print("Phase 1: Training with frozen base model...")
optimizer = optim.Adam(model.efficientnet.classifier.parameters(), lr=0.001)
_run_phase(model, optimizer, criterion, train_loader, val_loader, device,
           PHASE1_EPOCHS, history['phase1'], 'Epoch')

# Phase 2: unfreeze the last three feature stages and fine-tune with a
# smaller learning rate.
print("\nPhase 2: Fine-tuning model...")
for param in model.efficientnet.features[-3:].parameters():
    param.requires_grad = True

# A fresh optimizer over every parameter that is now trainable
# (the head plus the newly unfrozen stages).
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001)
_run_phase(model, optimizer, criterion, train_loader, val_loader, device,
           PHASE2_EPOCHS, history['phase2'], 'Fine-tuning Epoch')

# Save model weights together with the full metric history.
torch.save({
    'model_state_dict': model.state_dict(),
    'history': history
}, 'chest_xray_model_complete.pth')

In [None]:
def load_test_data(test_dir='/kaggle/input/chest-xray/chest_xray_jj_811/test', batch_size=32):
    """Build a DataLoader over the test split.

    The images are preprocessed with the module-level ``test_transforms``
    (the un-augmented pipeline also used for validation) rather than a
    locally re-declared copy, so evaluation preprocessing cannot drift from
    training. The worker count comes from ``NUM_WORKERS``.

    Args:
        test_dir: Root of the test ``ImageFolder`` (one sub-directory per class).
        batch_size: Samples per batch.

    Returns:
        A non-shuffling ``DataLoader`` over the test dataset.
    """
    test_dataset = datasets.ImageFolder(
        root=test_dir,
        transform=test_transforms
    )

    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep on-disk order
        num_workers=NUM_WORKERS
    )
    
def evaluate_model(model_path, test_loader, device):
    """Evaluate a saved checkpoint on the test set and report metrics.

    Loads the checkpoint, rebuilds ``ChestXrayModel``, predicts every batch
    in ``test_loader``, then plots a confusion matrix, prints the macro F1
    and the classification report, and plots the training history when the
    checkpoint contains one.

    Args:
        model_path: Path to a checkpoint holding ``'model_state_dict'`` and,
            optionally, ``'history'``.
        test_loader: Loader of ``(inputs, labels)`` batches whose
            ``dataset.classes`` lists the class names in label order.
        device: Device used for loading the checkpoint and for inference.

    Returns:
        Dict with keys ``'f1_score'``, ``'confusion_matrix'``,
        ``'predictions'``, ``'true_labels'`` and ``'class_names'``.
    """
    # map_location remaps the saved tensors onto `device`, so a checkpoint
    # written on a GPU also loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model = ChestXrayModel()  # must match the architecture that was trained
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Collect predictions and ground truth over the whole test set.
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            predictions = outputs.argmax(dim=1)  # class with the highest logit

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)

    class_names = test_loader.dataset.classes
    class_indices = list(range(len(class_names)))

    f1 = f1_score(all_labels, all_predictions, average='macro')
    # Pin the label set so the matrix always has one row/column per class,
    # matching the tick labels even if a class never appears.
    conf_matrix = confusion_matrix(all_labels, all_predictions, labels=class_indices)
    class_report = classification_report(all_labels, all_predictions,
                                         labels=class_indices,
                                         target_names=class_names)

    # Plot confusion matrix with class labels
    plt.figure(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title('Confusion Matrix on Test Set')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # Print metrics
    print("\nTest Set Metrics:")
    print(f"Macro F1 Score: {f1:.4f}")
    print("\nClassification Report:")
    print(class_report)

    # Plot training history if the checkpoint carries one
    if 'history' in checkpoint:
        plot_complete_training_history(checkpoint['history'])

    return {
        'f1_score': f1,
        'confusion_matrix': conf_matrix,
        'predictions': all_predictions,
        'true_labels': all_labels,
        'class_names': class_names
    }

def plot_per_class_metrics(results):
    """Plot per-class precision, recall and F1 as a grouped bar chart.

    Args:
        results: Dict with ``'class_names'`` (names in label-index order),
            ``'true_labels'`` and ``'predictions'``, as returned by
            ``evaluate_model``.
    """
    class_names = results['class_names']
    class_indices = list(range(len(class_names)))

    # Passing `labels` makes the report contain an entry for every class
    # index, even one that occurs in neither the labels nor the predictions
    # (otherwise the lookup below raises KeyError for it).
    report = classification_report(results['true_labels'],
                                   results['predictions'],
                                   labels=class_indices,
                                   output_dict=True)

    # The report is keyed by the stringified class index.
    metrics = [
        {
            'Class': class_name,
            'Precision': report[str(class_idx)]['precision'],
            'Recall': report[str(class_idx)]['recall'],
            'F1-Score': report[str(class_idx)]['f1-score'],
        }
        for class_idx, class_name in enumerate(class_names)
    ]
    df_metrics = pd.DataFrame(metrics)

    plt.figure(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.25

    # One bar group per class: precision left, recall centred, F1 right.
    plt.bar(x - width, df_metrics['Precision'], width, label='Precision')
    plt.bar(x, df_metrics['Recall'], width, label='Recall')
    plt.bar(x + width, df_metrics['F1-Score'], width, label='F1-Score')

    plt.ylabel('Score')
    plt.title('Per-class Performance Metrics')
    plt.xticks(x, class_names, rotation=45)
    plt.legend()
    plt.tight_layout()
    plt.show()

def plot_complete_training_history(history):
    """Draw loss and accuracy curves spanning both training phases.

    ``history`` must hold ``'phase1'`` and ``'phase2'`` dicts, each with
    ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``
    lists. The two phases are concatenated per metric and drawn in two
    side-by-side panels, with a dashed vertical line at the number of
    phase-1 epochs.
    """
    first, second = history['phase1'], history['phase2']

    # Phase-1 values followed by phase-2 values, one series per metric.
    combined = {
        key: first[key] + second[key]
        for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')
    }

    boundary = len(first['train_loss'])
    epochs = range(1, len(combined['train_loss']) + 1)

    # (panel title, y-axis label, ((history key, legend label), ...))
    panels = (
        ('Model Loss', 'Loss',
         (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss'))),
        ('Model Accuracy', 'Accuracy',
         (('train_acc', 'Train Accuracy'), ('val_acc', 'Validation Accuracy'))),
    )

    plt.figure(figsize=(15, 5))
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, legend_label in curves:
            plt.plot(epochs, combined[key], label=legend_label)
        plt.axvline(x=boundary, color='r', linestyle='--',
                    label='Start of Fine-tuning')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Usage: evaluate the saved checkpoint on the test split, then chart the
# per-class precision / recall / F1 from the returned results.
test_loader = load_test_data()
results = evaluate_model(
    model_path='chest_xray_model_complete.pth',
    test_loader=test_loader,
    device=device,
)
plot_per_class_metrics(results=results)