# In [None]:
import copy

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.data import DataLoader

# Pick the compute device: use CUDA when it is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the updated LeNet-5 model
class LeNet5(nn.Module):
    """LeNet-5 style CNN sized for 224x224 inputs.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by three fully
    connected layers. For a 224x224 input the spatial size evolves as
    224 -> 220 -> 110 -> 106 -> 53, so the flattened feature vector has
    16 * 53 * 53 elements, which is what ``fc1`` expects.

    Args:
        num_classes: Number of output logits (default 10).
        in_channels: Number of input image channels (default 3).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 53 * 53, 120)  # 53x53 feature map for 224x224 inputs
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                16 * 53 * 53 features that ``fc1`` expects.
        """
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 16 * 53 * 53)``, this can never fold extra features into
        # additional batch rows when the input size is wrong; the mismatch
        # surfaces as a shape error in ``fc1`` instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Load the pre-trained models
# Every dataset used below has 10 classes, while the ImageNet checkpoints end
# in a 1000-way classifier. Each head is therefore replaced with a freshly
# initialised 10-way layer so all models share the same output space as LeNet5.
NUM_CLASSES = 10

alexnet = models.alexnet(weights='IMAGENET1K_V1')
alexnet.classifier[-1] = nn.Linear(alexnet.classifier[-1].in_features, NUM_CLASSES)

vggnet = models.vgg16(weights='IMAGENET1K_V1')
vggnet.classifier[-1] = nn.Linear(vggnet.classifier[-1].in_features, NUM_CLASSES)

resnet = models.resnet18(weights='IMAGENET1K_V1')
resnet.fc = nn.Linear(resnet.fc.in_features, NUM_CLASSES)

googlenet = models.googlenet(weights='IMAGENET1K_V1')
googlenet.fc = nn.Linear(googlenet.fc.in_features, NUM_CLASSES)

# NOTE: torchvision does not ship Xception; MobileNetV3-Large is used as a
# stand-in (both rely on depthwise-separable convolutions). Results reported
# under the 'Xception' key are MobileNetV3-Large results.
xception = models.mobilenet_v3_large(weights='IMAGENET1K_V1')
xception.classifier[-1] = nn.Linear(xception.classifier[-1].in_features, NUM_CLASSES)

# NOTE: SqueezeNet is a different architecture from SENet
# (Squeeze-and-Excitation Networks); it is only a placeholder here. Results
# reported under the 'SENet' key are SqueezeNet 1.0 results.
senet = models.squeezenet1_0(weights='IMAGENET1K_V1')
# SqueezeNet classifies with a 1x1 convolution rather than a Linear layer.
senet.classifier[1] = nn.Conv2d(senet.classifier[1].in_channels, NUM_CLASSES, kernel_size=1)
senet.num_classes = NUM_CLASSES

# Define models dictionary
models_dict = {
    'LeNet-5': LeNet5(),
    'AlexNet': alexnet,
    'VGGNet': vggnet,
    'ResNet': resnet,
    'GoogLeNet': googlenet,
    'Xception': xception,
    'SENet': senet
}

# Move models to the appropriate device
for model_name, model in models_dict.items():
    models_dict[model_name] = model.to(device)

# Preprocessing pipelines: every image is resized to 224x224 and scaled to
# [-1, 1] per channel. The grayscale datasets are additionally expanded to
# three channels.
transform_mnist = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

transform_cifar = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Datasets (downloaded into ./data on first use)
_gray_kwargs = dict(root='./data', download=True, transform=transform_mnist)
_color_kwargs = dict(root='./data', download=True, transform=transform_cifar)

mnist_train = torchvision.datasets.MNIST(train=True, **_gray_kwargs)
mnist_test = torchvision.datasets.MNIST(train=False, **_gray_kwargs)
fmnist_train = torchvision.datasets.FashionMNIST(train=True, **_gray_kwargs)
fmnist_test = torchvision.datasets.FashionMNIST(train=False, **_gray_kwargs)
cifar10_train = torchvision.datasets.CIFAR10(train=True, **_color_kwargs)
cifar10_test = torchvision.datasets.CIFAR10(train=False, **_color_kwargs)

# DataLoaders keyed by dataset name: shuffled for training, ordered for testing
_train_sets = {'MNIST': mnist_train, 'FMNIST': fmnist_train, 'CIFAR-10': cifar10_train}
_test_sets = {'MNIST': mnist_test, 'FMNIST': fmnist_test, 'CIFAR-10': cifar10_test}

train_loaders = {
    name: DataLoader(dataset, batch_size=64, shuffle=True)
    for name, dataset in _train_sets.items()
}

test_loaders = {
    name: DataLoader(dataset, batch_size=64, shuffle=False)
    for name, dataset in _test_sets.items()
}

# Define the training function
def train_model(model, dataloader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and return it.

    Args:
        model: The ``nn.Module`` to optimise.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        The same ``model`` object, left in training mode.

    The mean batch loss is printed after every epoch; it is reported as
    ``nan`` when the dataloader yields no batches.
    """
    # Send each batch to wherever the model's weights live, so the function
    # does not rely on a module-level device matching the model's placement.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()  # Set the model to training mode
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: avoids dividing by zero on an empty loader
        # and works for iterables that do not implement __len__.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f'Epoch {epoch + 1}/{epochs}, Loss: {mean_loss}')
    return model

# Define the evaluation function
def evaluate_model(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and return classification metrics.

    Args:
        model: The ``nn.Module`` to evaluate; it is switched to eval mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``. Precision, recall and F1
        are support-weighted averages over classes.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    # Send each batch to wherever the model's weights live, so the function
    # does not rely on a module-level device matching the model's placement.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_labels.extend(labels.tolist())
            all_preds.extend(predicted.tolist())

    if total == 0:
        raise ValueError("evaluate_model received a dataloader with no samples")

    accuracy = correct / total
    # zero_division=0 keeps the score at 0 for classes that are never
    # predicted, without emitting an UndefinedMetricWarning for each call.
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return accuracy, precision, recall, f1

# Training and evaluating models on datasets
criterion = nn.CrossEntropyLoss()
epochs = 10
results = {}

for model_name, model in models_dict.items():
    for dataset_name, train_loader in train_loaders.items():
        print(f"Training {model_name} on {dataset_name}...")

        # Train an independent deep copy so every (model, dataset) run starts
        # from the same initial weights. Plain assignment would alias the
        # original module, and weights learned on one dataset would leak into
        # the runs on the following datasets.
        model_copy = copy.deepcopy(model)

        # A fresh optimizer per run, bound to the copy's parameters
        optimizer = optim.Adam(model_copy.parameters(), lr=0.001)

        # Train the model
        model_copy = train_model(model_copy, train_loader, criterion, optimizer, epochs=epochs)

        # Evaluate the model on the matching test split
        acc, precision, recall, f1 = evaluate_model(model_copy, test_loaders[dataset_name])

        # Store the results keyed by (model, dataset)
        results[(model_name, dataset_name)] = {
            'accuracy': acc,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }

        print(f"Results for {model_name} on {dataset_name}:")
        print(f"Accuracy: {acc}, Precision: {precision}, Recall: {recall}, F1-score: {f1}\n")

# Bar-chart helper for comparing runs on a single metric
def plot_comparison(results, metric='accuracy'):
    """Show a bar chart with one bar per (model, dataset) entry of ``results``.

    Args:
        results: Mapping from ``(model_name, dataset_name)`` tuples to a dict
            of metric values.
        metric: Key of the metric to plot from each entry's dict.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # One bar() call per entry, labelled "<model>-<dataset>"
    for (model_name, dataset_name), scores in results.items():
        bar_label = f"{model_name}-{dataset_name}"
        ax.bar(bar_label, scores[metric])

    pretty_metric = metric.capitalize()
    ax.set_xlabel('Model-Dataset')
    ax.set_ylabel(pretty_metric)
    ax.set_title(f'Model Performance Comparison by {pretty_metric}')

    # Slant the tick labels so the long model-dataset names stay readable
    plt.xticks(rotation=45, ha="right")
    plt.show()

# Plot the accuracy of every (model, dataset) entry collected in `results`
plot_comparison(results, metric='accuracy')
