In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
from PIL import Image
import numpy as np
import random
import os

In [2]:
def set_seed(seed=42):
    """
    Set random seeds for reproducibility across multiple libraries.
    
    Args
    ----
    seed (int): 
        Seed value to use
    """
    # Set PyTorch seed
    torch.manual_seed(seed)
    
    # Set CUDA seed (if available)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
        
    
    # Set NumPy seed
    np.random.seed(seed)
    
    # Set Python's random seed
    random.seed(seed)
    
    # Set environment variable for some PyTorch operations
    os.environ['PYTHONHASHSEED'] = str(seed)

# Example usage
set_seed(42)  # or any other seed value

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the model architecture in a function that returns a fresh DenseNet instance
def create_densenet(growth_rate=12, block_config=(6, 12, 8), 
                    num_init_features=32, bn_size=4, drop_rate=0.3, num_classes=10):
    
    # Inner classes for this specific model instance
    class DenseLayer(nn.Module):
        def __init__(self, in_channels, growth_rate, bn_size, drop_rate):
            super(DenseLayer, self).__init__()
            # BN-ReLU-Conv(1x1)
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.relu1 = nn.ReLU(inplace=True)
            self.conv1 = nn.Conv2d(in_channels, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
            
            # BN-ReLU-Conv(3x3)
            self.bn2 = nn.BatchNorm2d(bn_size * growth_rate)
            self.relu2 = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
            
            self.drop_rate = drop_rate
            
        def forward(self, x):
            new_features = self.conv1(self.relu1(self.bn1(x)))
            new_features = self.conv2(self.relu2(self.bn2(new_features)))
            if self.drop_rate > 0:
                new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
            return torch.cat([x, new_features], 1)

    class DenseBlock(nn.Module):
        def __init__(self, num_layers, in_channels, growth_rate, bn_size, drop_rate):
            super(DenseBlock, self).__init__()
            self.layers = nn.ModuleList()
            for i in range(num_layers):
                self.layers.add_module('denselayer%d' % (i + 1),
                                      DenseLayer(in_channels + i * growth_rate, growth_rate, bn_size, drop_rate))
                
        def forward(self, x):
            features = x
            for layer in self.layers:
                features = layer(features)
            return features

    class Transition(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(Transition, self).__init__()
            self.bn = nn.BatchNorm2d(in_channels)
            self.relu = nn.ReLU(inplace=True)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
            
        def forward(self, x):
            x = self.bn(x)
            x = self.relu(x)
            x = self.conv(x)
            x = self.pool(x)
            return x

    class DenseNet(nn.Module):
        def __init__(self):
            super(DenseNet, self).__init__()
            
            # First convolution
            self.features = nn.Sequential(
                nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(num_init_features),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2)
            )
            
            # Each denseblock
            num_features = num_init_features
            for i, num_layers in enumerate(block_config):
                # Add a dense block
                block = DenseBlock(
                    num_layers=num_layers,
                    in_channels=num_features,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate
                )
                self.features.add_module('denseblock%d' % (i + 1), block)
                num_features = num_features + num_layers * growth_rate
                
                # Add a transition layer between dense blocks (except after the last block)
                if i != len(block_config) - 1:
                    trans = Transition(in_channels=num_features, out_channels=num_features // 2)
                    self.features.add_module('transition%d' % (i + 1), trans)
                    num_features = num_features // 2
            
            # Final batch norm
            self.features.add_module('norm5', nn.BatchNorm2d(num_features))
            self.features.add_module('relu5', nn.ReLU(inplace=True))
            
            # Linear layer
            self.classifier = nn.Linear(num_features, num_classes)
            
            # Initialize weights
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.constant_(m.bias, 0)
        
        def forward(self, x):
            features = self.features(x)
            out = F.adaptive_avg_pool2d(features, (1, 1))
            out = torch.flatten(out, 1)
            out = self.classifier(out)
            return out
    
    # Create and return a new model instance
    return DenseNet()


# 
# model2 = create_densenet(growth_rate=12, block_config=(6, 12, 8), num_classes=10)
# state_dict2 = torch.load("../DenseNET/model_id_2.pth")
# model2.load_state_dict(state_dict2)
# model2.eval()
# 
# model3 = create_densenet(growth_rate=12, block_config=(6, 12, 8), num_classes=10)
# state_dict3 = torch.load("../DenseNET/model_id_3.pth")
# model3.load_state_dict(state_dict3)
# model3.eval()

In [15]:
model_rotation = create_densenet(growth_rate=12, block_config=(6, 12, 8), num_classes=10)
state_dict1 = torch.load("../basic_CNN/model_id_12_DenseNET.pth")
model_rotation.load_state_dict(state_dict1)
model_rotation.eval()
model_rotation = model_rotation.to(device)

model_flip = create_densenet(growth_rate=12, block_config=(6, 12, 8), num_classes=10)
state_dict2 = torch.load("../DenseNET/model_id_15.pth")
model_flip.load_state_dict(state_dict2)
model_flip.eval()
model_flip = model_flip.to(device)

model_brightness_contrast = create_densenet(growth_rate=12, block_config=(6, 12, 8), num_classes=10)
state_dict3 = torch.load("../DenseNET/model_id_14.pth")
model_brightness_contrast.load_state_dict(state_dict3)
model_brightness_contrast.eval()
model_brightness_contrast = model_brightness_contrast.to(device)

In [13]:
image_path = "../../data/raw/train/airplane/cifar10-train-10008.png"
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

criterion = nn.CrossEntropyLoss()
trainset_raw = torchvision.datasets.ImageFolder('../../data/raw/train/', transform=preprocess)
trainloader_raw = torch.utils.data.DataLoader(trainset_raw, batch_size=32,
                                          shuffle=True, num_workers=2)
valset = torchvision.datasets.ImageFolder('../../data/raw/valid/', transform=preprocess)
valloader = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=False, num_workers=2)

testset = torchvision.datasets.ImageFolder('../../data/raw/test/', transform=preprocess)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

In [16]:
val_error = 0
correct = 0
with torch.no_grad():
    model_brightness_contrast.eval()
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model_brightness_contrast(images)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 0.00010590700549073517, acc: 0.7426222222222222


In [17]:
val_error = 0
correct = 0
with torch.no_grad():
    model_rotation.eval()
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model_rotation(images)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 4.198949318379164e-05, acc: 0.7332888888888889


In [18]:
val_error = 0
correct = 0
with torch.no_grad():
    model_flip.eval()
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model_flip(images)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 6.321733962977305e-05, acc: 0.7505222222222222


# Soft voting

In [19]:
def ensemble_predict(image, model1, model2, model3):
    # Get predictions from both models
    with torch.no_grad():
        output1 = model1(image)
        output2 = model2(image)
        output3 = model3(image)
    
    # For classification (voting)
    if output1.shape[1] > 1:  # Multi-class
        # Average the probabilities
        avg_output = (output1 + output2 + output3) / 3
        # Or use hard voting
        # pred1 = output1.argmax(dim=1)
        # pred2 = output2.argmax(dim=1)
        # Use most common prediction
    
    # For regression
    else:
        avg_output = (output1 + output2 + output3) / 3
        
    return avg_output

def ensemble_predict_two_models(image, model1, model2):
    # Get predictions from both models
    with torch.no_grad():
        output1 = model1(image)
        output2 = model2(image)
    
    # For classification (voting)
    if output1.shape[1] > 1:  # Multi-class
        # Average the probabilities
        avg_output = (output1 + output2) / 2
        # Or use hard voting
        # pred1 = output1.argmax(dim=1)
        # pred2 = output2.argmax(dim=1)
        # Use most common prediction
    
    # For regression
    else:
        avg_output = (output1 + output2) / 2
        
    return avg_output

In [20]:
val_error = 0
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = ensemble_predict(images, model_brightness_contrast, model_rotation, model_flip)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 5.0692764489213005e-05, acc: 0.7813777777777777


In [21]:
val_error = 0
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = ensemble_predict_two_models(images, model_brightness_contrast, model_rotation)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 5.498751124832779e-05, acc: 0.7667555555555555


In [22]:
val_error = 0
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = ensemble_predict_two_models(images, model_brightness_contrast, model_flip)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 7.193325291154906e-05, acc: 0.7772111111111111


In [23]:
val_error = 0
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = ensemble_predict_two_models(images, model_rotation, model_flip)
        val_error = val_error + criterion(outputs, labels) * images.size(0)
        correct += (torch.argmax(outputs, 1) == labels).float().sum().item()
        val_error = val_error / len(valloader.dataset)
    print(f'epoch NONE TEST error: {val_error}, acc: {correct/len(valloader.dataset)}')

epoch NONE TEST error: 4.0082122723106295e-05, acc: 0.7670333333333333


# Confusion matrix

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import os
import glob
from torch.utils.data import DataLoader

def accuracy(y_true: List[int], y_pred: List[int]) -> float:
    """Calculate the accuracy of predictions."""
    correct = 0
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            correct += 1
    return correct / len(y_true)

def confusion_matrix(y_true: List[int], y_pred: List[int]):
    """Calculate confusion matrix."""
    n_classes = max(max(y_true), max(y_pred)) + 1
    confusion = np.zeros((n_classes, n_classes))

    for true, pred in zip(y_true, y_pred):
        confusion[true, pred] += 1

    return confusion

def micro_f1_score(y_true: List[int], y_pred: List[int]) -> float:
    """Calculate micro-averaged F1 score."""
    confusion = confusion_matrix(y_true, y_pred)

    tp = np.sum(np.diag(confusion))
    fp = np.sum(confusion, axis=0) - np.diag(confusion)
    fn = np.sum(confusion, axis=1) - np.diag(confusion)

    total_tp = tp
    total_fp = np.sum(fp)
    total_fn = np.sum(fn)

    precision = total_tp / (total_tp + total_fp) if total_tp + total_fp > 0 else 0.0
    recall = total_tp / (total_tp + total_fn) if total_tp + total_fn > 0 else 0.0

    if precision + recall == 0:
        return 0.0

    return 2 * (precision * recall) / (precision + recall)

def macro_f1_score(y_true: List[int], y_pred: List[int]) -> float:
    """Calculate macro-averaged F1 score."""
    confusion = confusion_matrix(y_true, y_pred)

    tp = np.diag(confusion)
    fp = np.sum(confusion, axis=0) - tp
    fn = np.sum(confusion, axis=1) - tp

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    precision = np.nan_to_num(precision)
    recall = np.nan_to_num(recall)

    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)  # Add epsilon to prevent division by zero
    return np.mean(f1)

def weighted_f1_score(y_true: List[int], y_pred: List[int]) -> float:
    """Calculate weighted F1 score."""
    confusion = confusion_matrix(y_true, y_pred)

    tp = np.diag(confusion)
    fp = np.sum(confusion, axis=0) - tp
    fn = np.sum(confusion, axis=1) - tp

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    precision = np.nan_to_num(precision)
    recall = np.nan_to_num(recall)

    weights = np.sum(confusion, axis=1) / np.sum(confusion)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)  # Add epsilon to prevent division by zero

    return np.sum(weights * f1)

def f1_score(y_true: List[int], y_pred: List[int], average: str = "macro") -> float:
    """Calculate F1 score with different averaging methods."""
    if average == "micro":
        return micro_f1_score(y_true, y_pred)
    elif average == "macro":
        return macro_f1_score(y_true, y_pred)
    elif average == "weighted":
        return weighted_f1_score(y_true, y_pred)
    else:
        raise ValueError("Invalid average type. Choose from 'micro', 'macro', or 'weighted'")

In [None]:
def predict_with_model(model, dataloader, device):
    """
    Get predictions from a PyTorch model using a dataloader
    
    Args:
        model: The PyTorch model to use for predictions
        dataloader: DataLoader containing the test data
        device: Device to run inference on ('cuda' or 'cpu')
        
    Returns:
        y_true: List of true labels
        y_pred: List of predicted labels
    """
    model.eval()  # Set model to evaluation mode
    y_true = []
    y_pred = []
    
    with torch.no_grad():  # No need to track gradients
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            # Add batch results to our lists
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    
    return y_true, y_pred

def plot_confusion_matrix(cm, classes, title='Confusion Matrix', filename='confusion_matrix.png'):
    """
    Plot confusion matrix heatmap and save it to a file.
    
    Parameters:
    -----------
    cm : array-like
        Confusion matrix array
    classes : list
        List of class names
    title : str, default='Confusion Matrix'
        Title of the plot
    filename : str, default='confusion_matrix.png'
        Filename to save the plot
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    
    # Add text annotations in the cells
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # Format as integer if the value is an integer, otherwise use float format
            if cm[i, j] == int(cm[i, j]):
                plt.text(j, i, format(int(cm[i, j])),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
            else:
                plt.text(j, i, format(cm[i, j], '.1f'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
    
    # Improve layout and save the figure
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close()

def evaluate_model(model, dataloader, device, classes, model_name="Model"):
    """
    Evaluate a model and print performance metrics
    
    Args:
        model: The PyTorch model to evaluate
        dataloader: DataLoader containing the test data
        device: Device to run inference on ('cuda' or 'cpu')
        classes: List of class names
        model_name: Name to display in the output
    """
    y_true, y_pred = predict_with_model(model, dataloader, device)
    
    print(f"=== {model_name} Evaluation ===")
    print(f"Accuracy: {accuracy(y_true, y_pred):.4f}")
    print(f"Micro F1 Score: {micro_f1_score(y_true, y_pred):.4f}")
    print(f"Macro F1 Score: {macro_f1_score(y_true, y_pred):.4f}")
    print(f"Weighted F1 Score: {weighted_f1_score(y_true, y_pred):.4f}")
    
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, classes, title=f'{model_name} Confusion Matrix', filename=f'{model_name} Confusion Matrix')
    
    return {
        'accuracy': accuracy(y_true, y_pred),
        'micro_f1': micro_f1_score(y_true, y_pred),
        'macro_f1': macro_f1_score(y_true, y_pred),
        'weighted_f1': weighted_f1_score(y_true, y_pred),
        'confusion_matrix': cm,
        'y_true': y_true,
        'y_pred': y_pred
    }

def compare_models(models, model_names, dataloader, device, classes):
    """
    Compare multiple models side by side
    
    Args:
        models: List of PyTorch models to compare
        model_names: List of model names (for display)
        dataloader: DataLoader containing the test data
        device: Device to run inference on ('cuda' or 'cpu')
        classes: List of class names
    """
    results = []
    
    for model, name in zip(models, model_names):
        result = evaluate_model(model, dataloader, device, classes, name)
        results.append(result)
    
    # Create a comparison table
    print("\n=== Model Comparison ===")
    metrics = ['accuracy', 'micro_f1', 'macro_f1', 'weighted_f1']
    
    # Print header
    header = "Metric"
    for name in model_names:
        header += f" | {name}"
    print(header)
    print("-" * len(header))
    
    # Print each metric
    for metric in metrics:
        row = metric
        for result in results:
            row += f" | {result[metric]:.4f}"
        print(row)
    
    return results


# Define class names (replace with your actual class names)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Evaluate individual models
evaluate_model(model_brightness_contrast, testloader, device, class_names, "model_brightness_contrast")
evaluate_model(model_rotation, testloader, device, class_names, "model_rotation")
evaluate_model(model_flip, testloader, device, class_names, "model_flip")

=== model_brightness_contrast Evaluation ===
Accuracy: 0.7426
Micro F1 Score: 0.7426
Macro F1 Score: 0.7402
Weighted F1 Score: 0.7402
=== model_rotation Evaluation ===
Accuracy: 0.7333
Micro F1 Score: 0.7333
Macro F1 Score: 0.7316
Weighted F1 Score: 0.7316
=== model_flip Evaluation ===
Accuracy: 0.7505
Micro F1 Score: 0.7505
Macro F1 Score: 0.7505
Weighted F1 Score: 0.7505


{'accuracy': 0.7505222222222222,
 'micro_f1': np.float64(0.7505222222222222),
 'macro_f1': np.float64(0.7504965865073746),
 'weighted_f1': np.float64(0.7504965865073746),
 'confusion_matrix': array([[6.817e+03, 9.100e+01, 8.910e+02, 6.800e+01, 1.340e+02, 5.800e+01,
         6.700e+01, 8.600e+01, 6.830e+02, 1.050e+02],
        [1.470e+02, 7.128e+03, 6.900e+01, 4.900e+01, 3.700e+01, 7.300e+01,
         8.700e+01, 5.800e+01, 2.680e+02, 1.084e+03],
        [1.150e+02, 2.400e+01, 6.876e+03, 2.560e+02, 4.490e+02, 4.050e+02,
         6.230e+02, 1.140e+02, 1.200e+02, 1.800e+01],
        [3.000e+01, 2.300e+01, 4.770e+02, 5.402e+03, 6.010e+02, 1.232e+03,
         9.890e+02, 1.260e+02, 8.500e+01, 3.500e+01],
        [3.200e+01, 2.000e+01, 4.650e+02, 3.780e+02, 6.396e+03, 8.730e+02,
         2.510e+02, 4.400e+02, 1.100e+02, 3.500e+01],
        [2.900e+01, 5.400e+01, 4.810e+02, 1.253e+03, 7.240e+02, 5.592e+03,
         3.700e+02, 3.570e+02, 9.800e+01, 4.200e+01],
        [1.500e+01, 6.000e+00, 4.37

In [36]:
image_path = "../../data/raw/train/truck/cifar10-train-10138.png"
image = Image.open(image_path).convert('RGB')
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)  # Add batch dimension

input_batch = input_batch.to(device)

# Make prediction
with torch.no_grad():  # Disable gradient calculation
    output = ensemble_predict(input_batch, model_brightness_contrast, model_rotation, model_flip)

# Process the output based on your model type
# For classification:
_, predicted_class = torch.max(output, 1)
print(f"Predicted class: {predicted_class.item()}")

Predicted class: 9


# Evaluating ensemble

In [40]:
def predict_with_model(ensemble_func, model1, model2, model3, dataloader, device):
    """
    Get predictions from a PyTorch model using a dataloader
    
    Args:
        model: The PyTorch model to use for predictions
        dataloader: DataLoader containing the test data
        device: Device to run inference on ('cuda' or 'cpu')
        
    Returns:
        y_true: List of true labels
        y_pred: List of predicted labels
    """
    y_true = []
    y_pred = []
    
    with torch.no_grad():  # No need to track gradients
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = ensemble_func(images, model1, model2, model3)
            _, predicted = torch.max(outputs, 1)
            
            # Add batch results to our lists
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    
    return y_true, y_pred

def plot_confusion_matrix(cm, classes, title='Confusion Matrix', filename='confusion_matrix.png'):
    """
    Plot confusion matrix heatmap and save it to a file.
    
    Parameters:
    -----------
    cm : array-like
        Confusion matrix array
    classes : list
        List of class names
    title : str, default='Confusion Matrix'
        Title of the plot
    filename : str, default='confusion_matrix.png'
        Filename to save the plot
    """
    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    
    # Add text annotations in the cells
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            # Format as integer if the value is an integer, otherwise use float format
            if cm[i, j] == int(cm[i, j]):
                plt.text(j, i, format(int(cm[i, j])),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
            else:
                plt.text(j, i, format(cm[i, j], '.1f'),
                        ha="center", va="center",
                        color="white" if cm[i, j] > thresh else "black")
    
    # Improve layout and save the figure
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close()

def evaluate_model_ensemble(ensemble_func, model1, model2, model3, dataloader, device, classes, model_name="Model"):
    """
    Evaluate a model and print performance metrics
    
    Args:
        model: The PyTorch model to evaluate
        dataloader: DataLoader containing the test data
        device: Device to run inference on ('cuda' or 'cpu')
        classes: List of class names
        model_name: Name to display in the output
    """
    y_true, y_pred = predict_with_model(ensemble_func, model1, model2, model3, dataloader, device)
    
    print(f"=== {model_name} Evaluation ===")
    print(f"Accuracy: {accuracy(y_true, y_pred):.4f}")
    print(f"Micro F1 Score: {micro_f1_score(y_true, y_pred):.4f}")
    print(f"Macro F1 Score: {macro_f1_score(y_true, y_pred):.4f}")
    print(f"Weighted F1 Score: {weighted_f1_score(y_true, y_pred):.4f}")
    
    cm = confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(cm, classes, title=f'{model_name} Confusion Matrix', filename=f'{model_name} Confusion Matrix')
    
    return {
        'accuracy': accuracy(y_true, y_pred),
        'micro_f1': micro_f1_score(y_true, y_pred),
        'macro_f1': macro_f1_score(y_true, y_pred),
        'weighted_f1': weighted_f1_score(y_true, y_pred),
        'confusion_matrix': cm,
        'y_true': y_true,
        'y_pred': y_pred
    }

evaluate_model_ensemble(ensemble_predict, model_brightness_contrast, model_rotation, model_flip, testloader, device, class_names, "Ensemble model")

=== Ensemble model Evaluation ===
Accuracy: 0.7814
Micro F1 Score: 0.7814
Macro F1 Score: 0.7804
Weighted F1 Score: 0.7804


{'accuracy': 0.7813777777777777,
 'micro_f1': np.float64(0.7813777777777777),
 'macro_f1': np.float64(0.7803803813463441),
 'weighted_f1': np.float64(0.780380381346344),
 'confusion_matrix': array([[7.697e+03, 9.800e+01, 2.630e+02, 5.400e+01, 7.100e+01, 4.200e+01,
         5.100e+01, 9.100e+01, 5.640e+02, 6.900e+01],
        [1.350e+02, 7.515e+03, 3.300e+01, 3.100e+01, 2.900e+01, 5.500e+01,
         3.500e+01, 7.800e+01, 2.580e+02, 8.310e+02],
        [2.530e+02, 2.200e+01, 6.877e+03, 2.430e+02, 4.280e+02, 4.110e+02,
         4.980e+02, 9.900e+01, 1.570e+02, 1.200e+01],
        [6.000e+01, 3.000e+01, 4.190e+02, 5.671e+03, 5.570e+02, 1.356e+03,
         6.560e+02, 1.370e+02, 8.800e+01, 2.600e+01],
        [6.300e+01, 2.000e+01, 3.590e+02, 3.540e+02, 6.657e+03, 7.130e+02,
         1.840e+02, 5.170e+02, 1.080e+02, 2.500e+01],
        [4.500e+01, 5.000e+01, 3.590e+02, 1.120e+03, 7.400e+02, 5.851e+03,
         2.450e+02, 4.750e+02, 9.300e+01, 2.200e+01],
        [2.200e+01, 1.200e+01, 4.170

In [None]:
# Compare all models
compare_models(
    [model1, model2, model3],
    ["Model 1", "Model 2", "Model 3"],
    testloader,
    device,
    class_names
)