In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.models as models
import random
import numpy as np
from sklearn.metrics import roc_auc_score

from EnsembleBench.frameworks.pytorchUtility import (
    plurality_voting,
    majority_voting

)

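Download the weight files for both MNIST and FashionMNIST from this [Google Drive folder](https://drive.google.com/drive/folders/1OZfwj9iVruzta9VAhNzVqEgh-jOXKHhb?usp=sharing). The notebook loads `./resnet18_best_model.pth` and `./resnet34_best_model.pth`, so place the files for the dataset selected by `data` at those paths.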

In [None]:
seed = 42

# Seed every random number generator the notebook touches so runs are reproducible.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Use the first GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Experiment switches.
data = 'fashMNIST'   # 'fashMNIST' or 'MNIST'
voting = 'soft'      # 'soft', 'plural' or 'major'

In [None]:
model_names = ['resnet18', 'resnet34']


def load_model(model_name, num_classes=10, pretrained=True):
    """Build a torchvision classifier whose output head is resized to ``num_classes``.

    The backbone is created with ``torchvision.models.<model_name>``. Its final
    classification layer is replaced by a freshly initialised layer with
    ``num_classes`` outputs. The input width of that layer is read from the layer
    being replaced instead of being hard-coded.

    Args:
        model_name: Name of a torchvision model constructor, e.g. ``'resnet18'``.
        num_classes: Number of outputs of the new head (default 10).
        pretrained: Forwarded to the torchvision constructor (default True, as before).
            Pass False to skip downloading ImageNet weights when a fine-tuned
            checkpoint will be loaded over the model anyway.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_name`` belongs to a family whose head is not handled
            here. Without this check such a model would keep its original head.
    """
    # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
    # It is kept here because it still works across torchvision versions.
    model = getattr(models, model_name)(pretrained=pretrained)

    if any(family in model_name for family in ('resnet', 'resnext', 'shufflenet', 'googlenet', 'inception')):
        # These families expose a single Linear head at `model.fc`.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif 'vgg' in model_name or 'alexnet' in model_name:
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif 'densenet' in model_name:
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif 'squeezenet' in model_name:
        # SqueezeNet classifies with a 1x1 convolution instead of a Linear layer.
        model.classifier[1] = nn.Conv2d(model.classifier[1].in_channels, num_classes,
                                        kernel_size=(1, 1), stride=(1, 1))
        model.num_classes = num_classes
    elif any(family in model_name for family in ('mnasnet', 'mobilenet', 'efficientnet', 'convnext')):
        # The Linear head is the last module of the `classifier` Sequential for these families.
        # Indexing with -1 also covers MobileNetV3, whose classifier[1] is an activation
        # with no `in_features`. torchvision's EfficientNet has no `_fc` attribute.
        model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
    else:
        raise ValueError(f"load_model: no classification-head replacement defined for {model_name!r}")

    return model.to(device)

In [None]:
# Preprocessing applied to every test image.
# NOTE(review): presumably this must match the preprocessing used when the checkpoints
# were trained (299x299 resize, no normalisation). Confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.Grayscale(num_output_channels=3),  # replicate the gray channel for 3-channel backbones
    transforms.ToTensor(),
])

if data == 'fashMNIST':
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
elif data == 'MNIST':
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
else:
    # Raise rather than print. Otherwise the DataLoader line below fails with a NameError,
    # or silently reuses a `test_dataset` left over from an earlier run of this cell.
    raise ValueError(f"data not selected: expected 'fashMNIST' or 'MNIST', got {data!r}")

# shuffle=False keeps a fixed sample order across evaluations.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Fine-tuned checkpoints (state dicts), one per ensemble member.
model_paths = {
    'resnet18': r'./resnet18_best_model.pth',
    'resnet34': r'./resnet34_best_model.pth'
}

models_ensemble = {}
for model_name in model_names:
    model = load_model(model_name)
    # map_location lets checkpoints saved on a GPU be restored when `device` is the CPU.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    state_dict = torch.load(model_paths[model_name], map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: use running BatchNorm statistics, disable dropout
    models_ensemble[model_name] = model

In [None]:
# Validate the voting rule before running any inference. An unknown value must not
# leave `final_predictions` undefined (or stale from a previous batch) inside the loop.
VOTING_LABELS = {
    'soft': 'Soft Voting',
    'plural': 'Plurality Voting',
    'major': 'Majority Voting',
}
if voting not in VOTING_LABELS:
    raise ValueError(
        f"voting method not selected: expected one of {sorted(VOTING_LABELS)}, got {voting!r}"
    )

correct = 0
total = 0
all_labels = []
all_predictions = []
all_scores = []  # per-sample mean softmax over the ensemble (used for AUC)

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Raw outputs (for the hard-voting helpers) and softmax probabilities
        # (for soft voting and AUC) from every ensemble member, moved to the CPU.
        batch_predictions = []
        batch_scores = []
        for model_name in model_names:
            outputs = models_ensemble[model_name](inputs)
            batch_predictions.append(outputs.cpu())
            batch_scores.append(torch.softmax(outputs, dim=1).cpu())

        # Stacked along a new leading model axis; the mean over dim 0 averages the members.
        predictionVectorsStack = torch.stack(batch_predictions)
        scoresVectorsStack = torch.stack(batch_scores)
        mean_scores = scoresVectorsStack.mean(dim=0)

        if voting == 'soft':
            final_predictions = torch.argmax(mean_scores, dim=1).to(device)
        elif voting == 'plural':
            final_predictions = torch.tensor(plurality_voting(predictionVectorsStack)).to(device)
        else:  # 'major' is the only remaining option after the validation above
            final_predictions = torch.tensor(majority_voting(predictionVectorsStack)).to(device)

        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(final_predictions.cpu().numpy())
        all_scores.extend(mean_scores.numpy())

        correct += final_predictions.eq(labels).sum().item()
        total += labels.size(0)

test_accuracy = 100 * correct / total
# Report the voting rule that was actually used.
print(f'Ensemble Test Accuracy with {VOTING_LABELS[voting]}: {test_accuracy:.2f}%')

# AUC is always computed from the soft-voting scores (mean softmax), whichever
# voting rule produced the hard predictions above.
all_labels = np.array(all_labels)
all_scores = np.array(all_scores)
if len(np.unique(all_labels)) == 2:
    # NOTE(review): assumes the two labels present are 0 and 1, so that column 1 holds
    # the positive-class probability. Confirm if this branch is ever reached.
    auc_score = roc_auc_score(all_labels, all_scores[:, 1])
else:
    auc_score = roc_auc_score(all_labels, all_scores, multi_class='ovr')
print(f'Ensemble AUC Score: {auc_score:.4f}')