# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc

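# Paths
""" Define your own paths before running; the rest of this script uses these names:
dataset_root =      (ImageFolder root holding the train+val images)
test_path =         (ImageFolder root holding the test images)
model_save_path =   (directory the trained weights are saved into)
os.makedirs(model_save_path, exist_ok=True)
"""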

# Define HybridConv-Net for 3 classes
class HybridConvNet(nn.Module):
    """Classifier that fuses features from three pretrained backbones.

    DenseNet-121, MobileNetV3-Large and EfficientNet-B0 each have their
    ``classifier`` replaced by ``nn.Identity``. Their outputs are
    concatenated along dim 1 and passed through a fully connected head
    (512 -> 256 -> ``num_classes``) with ReLU and dropout between layers.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes: int = 3):
        super().__init__()

        # Load pretrained models.
        # NOTE(review): `pretrained=True` is the legacy torchvision flag; newer
        # releases prefer `weights=` -- confirm against the installed version
        # before changing it.
        self.densenet = models.densenet121(pretrained=True)
        self.mobilenet = models.mobilenet_v3_large(pretrained=True)
        self.efficientnet = models.efficientnet_b0(pretrained=True)

        # Remove classifiers so each backbone returns its feature vector
        self.densenet.classifier = nn.Identity()
        self.mobilenet.classifier = nn.Identity()
        self.efficientnet.classifier = nn.Identity()

        # Get concatenated feature size (input width of the head)
        self._feature_dim = self._get_concat_feature_dim()

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(self._feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def _get_concat_feature_dim(self) -> int:
        """Return the width of the concatenated backbone feature vector.

        A dummy 1x3x224x224 batch is pushed through each backbone and the
        sizes of dim 1 of the three outputs are summed.

        The backbones are put in eval mode for this probe. ``torch.no_grad()``
        only disables gradient tracking; BatchNorm layers in training mode
        would still fold the random dummy batch into their running
        statistics. Every module's original ``training`` flag is restored
        afterwards, even if the probe raises.
        """
        backbones = (self.densenet, self.mobilenet, self.efficientnet)

        # Remember the mode of every sub-module so it can be restored exactly.
        saved_modes = [
            (module, module.training)
            for backbone in backbones
            for module in backbone.modules()
        ]
        for backbone in backbones:
            backbone.eval()

        dummy_input = torch.randn(1, 3, 224, 224)
        try:
            with torch.no_grad():
                widths = [backbone(dummy_input).shape[1] for backbone in backbones]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        return sum(widths)

    def forward(self, x):
        """Run ``x`` through all three backbones and classify the fused features.

        Args:
            x: Input batch fed unchanged to each backbone.

        Returns:
            Output of the fully connected head applied to the backbone
            outputs concatenated along dim 1.
        """
        d_out = self.densenet(x)
        m_out = self.mobilenet(x)
        e_out = self.efficientnet(x)
        x_concat = torch.cat((d_out, m_out, e_out), dim=1)
        return self.fc(x_concat)

# Transformations
# Training pipeline: resize, random augmentation, then tensor conversion and
# per-channel normalization.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation pipeline: same resize and normalization but no random flips,
# rotations or colour jitter, so validation/test metrics are measured on
# unperturbed images and are repeatable between runs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Original name kept as an alias of the training pipeline.
transform = train_transform

# Load full dataset (train+val).
# Two ImageFolder views of the same root are created so the train and
# validation subsets can use different transforms. Both views are expected to
# list the same files in the same order (ImageFolder sorts its directory scan),
# which is what makes sharing indices between them valid.
full_dataset = datasets.ImageFolder(root=dataset_root, transform=train_transform)
full_dataset_eval = datasets.ImageFolder(root=dataset_root, transform=eval_transform)

# Split: 90% train, 10% val
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_split = random_split(full_dataset, [train_size, val_size])
# Re-point the held-out indices at the eval-transform view of the data.
val_dataset = torch.utils.data.Subset(full_dataset_eval, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# Load test dataset separately (no split), without augmentation
test_dataset = datasets.ImageFolder(root=test_path, transform=eval_transform)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Device: use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialize model, loss, optimizer
model = HybridConvNet(num_classes=3)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Per-epoch accuracy history, in percent
train_acc_list = []
val_acc_list = []

# Training loop
epochs = 30
for epoch in range(epochs):
    # ---- optimisation pass over the training split ----
    model.train()
    correct_train = 0
    total_train = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Index of the largest logit per sample is the predicted class
        predicted = outputs.max(dim=1).indices
        correct_train += (predicted == labels).sum().item()
        total_train += labels.shape[0]

    train_accuracy = 100.0 * correct_train / total_train
    train_acc_list.append(train_accuracy)

    # ---- gradient-free pass over the validation split ----
    model.eval()
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            predicted = model(images).max(dim=1).indices
            correct_val += (predicted == labels).sum().item()
            total_val += labels.shape[0]

    val_accuracy = 100.0 * correct_val / total_val
    val_acc_list.append(val_accuracy)

    print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_accuracy:.2f}% - Validation Acc: {val_accuracy:.2f}%")

# Plot Training and Validation Accuracy vs Epochs BEFORE Testing
epoch_axis = range(1, epochs + 1)
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(epoch_axis, train_acc_list, label='Training Accuracy', marker='o')
ax.plot(epoch_axis, val_acc_list, label='Validation Accuracy', marker='s')
ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy (%)")
ax.set_title("Training vs Validation Accuracy")
ax.legend()
ax.grid()
plt.show()

# Save model: only the state dict (weights) is written, not the module object
checkpoint_file = os.path.join(model_save_path, "HybridConvNet_3class.pth")
torch.save(model.state_dict(), checkpoint_file)
print(" Model training complete and saved!")

# Testing: collect predictions, true labels and softmax probabilities for
# every test sample while counting correct predictions
model.eval()
correct_test = 0
total_test = 0
all_labels = []
all_preds = []
all_probs = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        probs = outputs.softmax(dim=1)
        predicted = outputs.max(dim=1).indices

        correct_test += (predicted == labels).sum().item()
        total_test += labels.shape[0]

        # Move results to host memory; each extend appends one entry per sample
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

test_accuracy = 100.0 * correct_test / total_test
print(f" Final Test Accuracy: {test_accuracy:.2f}%")

# Confusion Matrix
all_labels, all_preds, all_probs = np.array(all_labels), np.array(all_preds), np.array(all_probs)
# Class names come from the train/val ImageFolder
# (['Chickenpox', 'Measles', 'Monkeypox'] or similar).
# NOTE(review): assumes the test folder uses the same class sub-folders, so
# its label indices mean the same thing -- confirm for your data.
class_names = full_dataset.classes
# Pin the label set to every known class id. Without this, scikit-learn infers
# the labels from the data, so a class missing from both the test labels and
# the predictions would shrink the matrix and make `target_names` mismatch.
label_ids = np.arange(len(class_names))
cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
plt.figure(figsize=(7,6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix - Test Set")
plt.show()

# Classification Report (per-class precision / recall / F1)
report = classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names)
print(f"\nClassification Report:\n{report}")

# ROC Curve for multi-class (One-vs-Rest): each class in turn is treated as
# the positive label and scored by its own softmax probability column
fig, ax = plt.subplots(figsize=(8, 6))
for class_idx, cls in enumerate(class_names):
    is_positive = (all_labels == class_idx).astype(int)
    class_scores = all_probs[:, class_idx]

    fpr, tpr, _ = roc_curve(is_positive, class_scores)
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f'{cls} (AUC = {roc_auc:.2f})')

ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve - Test Set")
ax.legend()
ax.grid()
plt.show()
