In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18, mobilenet_v2, vit_b_16
from tqdm.auto import tqdm
import time
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import gc
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, classification_report

In [None]:
# Configure the PyTorch CUDA allocator through its environment variable
# (memory optimization requested by this notebook).
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

# Fix the random seed (CPU always, CUDA when present) and pick the device
# that every later cell reads through the global `device`.
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

In [None]:
# Function to prepare model
def prepare_model(model_name, num_classes, dataset_name):
    """
    Build a pretrained backbone adapted to the dataset's input channels and classes.

    The classification head is always replaced by a fresh linear layer with
    `num_classes` outputs. The first convolution is only touched when the
    dataset is grayscale: for SVHN (RGB) the pretrained 3-channel stem is kept
    as is, otherwise it is converted to a 1-channel convolution whose kernels
    are derived from the pretrained RGB kernels.

    Args:
        model_name: model name ('resnet', 'mobilenet', 'vit')
        num_classes: number of output classes for the new head
        dataset_name: dataset the model will be trained on; 'svhn' is treated
            as 3-channel input, every other name as 1-channel input

    Returns:
        model: prepared model, moved to the global `device`

    Raises:
        ValueError: if `model_name` is not one of the supported names
    """
    # Only SVHN is RGB in this notebook; everything else is single-channel.
    grayscale_input = dataset_name != 'svhn'

    if model_name == 'resnet':
        model = resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        if grayscale_input:
            model.conv1 = _to_single_channel_conv(model.conv1)
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif model_name == 'mobilenet':
        model = mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.DEFAULT)
        if grayscale_input:
            model.features[0][0] = _to_single_channel_conv(model.features[0][0])
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    elif model_name == 'vit':
        model = vit_b_16(weights=torchvision.models.ViT_B_16_Weights.DEFAULT)
        if grayscale_input:
            model.conv_proj = _to_single_channel_conv(model.conv_proj)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

    else:
        raise ValueError(f"Unknown model: {model_name}")

    return model.to(device)


def _to_single_channel_conv(rgb_conv):
    """
    Return a 1-input-channel copy of a (pretrained) convolution.

    The new layer keeps the kernel size, stride, padding, dilation, padding
    mode and bias setting of `rgb_conv`. Its kernels are the sum of the
    original kernels over the input-channel axis, so a grayscale image gives
    the same response as that image replicated across all input channels of
    the original layer.

    Args:
        rgb_conv: nn.Conv2d with groups == 1 (the stem layers used above)

    Returns:
        gray_conv: nn.Conv2d with in_channels == 1
    """
    gray_conv = nn.Conv2d(
        1,
        rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        dilation=rgb_conv.dilation,
        bias=rgb_conv.bias is not None,
        padding_mode=rgb_conv.padding_mode,
    )
    with torch.no_grad():
        # (out, in, kh, kw) -> (out, 1, kh, kw)
        gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            gray_conv.bias.copy_(rgb_conv.bias)
    return gray_conv

In [None]:
# Function to prepare datasets
def prepare_datasets(dataset_name, batch_size=32, split_seed=42):  # Reduced batch size
    """
    Prepare datasets for training and evaluation

    The official training set is split 80/20 into training and validation
    subsets. The split uses its own seeded generator, so it does not depend on
    how much of the global RNG stream earlier cells have consumed.

    Args:
        dataset_name: dataset name ('emnist', 'kmnist', 'svhn')
        batch_size: batch size
        split_seed: seed of the generator used for the train/validation split

    Returns:
        train_loader: training data loader (shuffled)
        val_loader: validation data loader
        test_loader: test data loader
        num_classes: number of classes in the dataset

    Raises:
        ValueError: if `dataset_name` is not one of the supported names
    """
    def build_transform(mean, std):
        # All datasets share the same pipeline; only the normalization differs.
        return transforms.Compose([
            transforms.Resize((224, 224)),  # Resize all images to 224x224
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    if dataset_name == 'emnist':
        # EMNIST dataset (letters)
        transform = build_transform((0.1751,), (0.3331,))

        train_dataset = torchvision.datasets.EMNIST(
            root='./data', split='letters', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.EMNIST(
            root='./data', split='letters', train=False, download=True, transform=transform
        )

        # Shift the labels down by one so they start at 0, matching the
        # 26-way classifier head.
        train_dataset.targets -= 1
        test_dataset.targets -= 1

        num_classes = 26  # A-Z (26 classes)

    elif dataset_name == 'kmnist':
        # KMNIST dataset (Kuzushiji-MNIST)
        transform = build_transform((0.1918,), (0.3483,))

        train_dataset = torchvision.datasets.KMNIST(
            root='./data', train=True, download=True, transform=transform
        )
        test_dataset = torchvision.datasets.KMNIST(
            root='./data', train=False, download=True, transform=transform
        )

        num_classes = 10  # KMNIST has 10 classes

    elif dataset_name == 'svhn':
        # SVHN dataset
        transform = build_transform((0.4380, 0.4440, 0.4730), (0.1751, 0.1771, 0.1744))

        train_dataset = torchvision.datasets.SVHN(
            root='./data', split='train', download=True, transform=transform
        )
        test_dataset = torchvision.datasets.SVHN(
            root='./data', split='test', download=True, transform=transform
        )

        num_classes = 10  # Digits 0-9

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Split training dataset into training and validation (same rule for
    # every dataset), using a dedicated generator for a repeatable split.
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    split_generator = torch.Generator().manual_seed(split_seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, [train_size, val_size], generator=split_generator
    )

    # Prepare data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader, num_classes

In [None]:
# Function to calculate metrics (accuracy, precision, recall, f1-score)
def calculate_metrics(y_true, y_pred, num_classes):
    """
    Calculate evaluation metrics

    Args:
        y_true: true labels (tensor, numpy array or sequence of ints)
        y_pred: predicted labels (same length as y_true)
        num_classes: number of classes; fixes the confusion matrix to
            num_classes x num_classes with rows/columns ordered 0..num_classes-1,
            even when some class never appears. If None, the matrix covers
            only the labels that occur in the data.

    Returns:
        metrics: dictionary with 'accuracy', precision/recall/F1 for the
            'micro', 'macro' and 'weighted' averages (all in percent) and
            'confusion_matrix'
        per_class_metrics: dictionary produced by classification_report
            (output_dict=True)
    """
    # Convert tensors to numpy arrays if needed
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.cpu().numpy()
    # Plain sequences would otherwise fail on the element-wise comparison below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    metrics = {}

    # Accuracy
    metrics['accuracy'] = (y_true == y_pred).mean() * 100

    # Precision, Recall, and F1-Score for each averaging method; keys are
    # '<metric>_<average>', e.g. 'precision_weighted'.
    scorers = (('precision', precision_score), ('recall', recall_score), ('f1', f1_score))
    for averaging in ('micro', 'macro', 'weighted'):
        for metric_name, scorer in scorers:
            metrics[f'{metric_name}_{averaging}'] = (
                scorer(y_true, y_pred, average=averaging, zero_division=0) * 100
            )

    # Per-class metrics; zero_division=0 keeps this consistent with the
    # scorer calls above and silences undefined-metric warnings.
    per_class_metrics = classification_report(y_true, y_pred, output_dict=True, zero_division=0)

    # Confusion matrix with a fixed label set so its shape does not depend on
    # which classes happen to be present.
    class_labels = np.arange(num_classes) if num_classes is not None else None
    metrics['confusion_matrix'] = confusion_matrix(y_true, y_pred, labels=class_labels)

    return metrics, per_class_metrics

In [None]:
# Function to count trainable parameters in a model
def count_trainable_parameters(model):
    """
    Sum the element counts of every parameter that requires gradients

    Args:
        model: PyTorch model

    Returns:
        total_params: number of trainable parameters
    """
    total_params = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not counted
        total_params += param.numel()
    return total_params

In [None]:
# Function to compute model size in megabytes
def get_model_size_mb(model):
    """
    Report how many megabytes the model's parameters and buffers occupy

    Args:
        model: PyTorch model

    Returns:
        size_mb: float, combined size of all parameters and buffers in MB
    """
    total_bytes = 0

    # Learnable tensors first ...
    for tensor in model.parameters():
        total_bytes += tensor.numel() * tensor.element_size()

    # ... then the registered buffers.
    for tensor in model.buffers():
        total_bytes += tensor.numel() * tensor.element_size()

    bytes_per_mb = 1024 ** 2
    return total_bytes / bytes_per_mb

In [None]:
# Function for training and evaluating the model with memory optimization
def train_and_evaluate(model, train_loader, val_loader, test_loader, epochs=5, lr=0.001):
    """
    Train the model with Adam and cross-entropy loss, validate after every
    epoch, then evaluate once on the test set.

    Args:
        model: model to train (already on the global `device`)
        train_loader: training data loader
        val_loader: validation data loader
        test_loader: test data loader
        epochs: number of epochs
        lr: learning rate

    Returns:
        results: dictionary with
            - 'train_loss', 'val_loss': per-epoch average batch loss
            - 'val_acc', 'val_precision', 'val_recall', 'val_f1': per-epoch
              validation metrics in percent (precision/recall/F1 weighted)
            - 'test_loss', 'test_acc', 'test_precision', 'test_recall',
              'test_f1': final test-set values
            - 'training_time': wall-clock seconds spent in the epoch loop
              (the per-epoch validation passes are included)
            - 'class_metrics', 'confusion_matrix': test-set per-class report
              and confusion matrix
            - 'peak_vram': highest peak of allocated VRAM (MB) seen in any
              train/validation/test phase, or None when running on CPU

    Raises:
        AttributeError: if the model's classification head cannot be located
    """

    # Log model statistics
    print(f"\nModel statistics:")
    print(f"- Trainable parameters: {count_trainable_parameters(model):,}")
    print(f"- Model size: {get_model_size_mb(model):.2f} MB")

    if device.type == 'cuda':
        torch.cuda.reset_peak_memory_stats()

    # Setup loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Get number of classes
    num_classes = _infer_num_classes(model)

    # Track results
    results = {
        'train_loss': [],
        'val_loss': [],
        'val_acc': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': [],
        'test_loss': None,
        'test_acc': None,
        'test_precision': None,
        'test_recall': None,
        'test_f1': None,
        'training_time': 0,
        'class_metrics': None,
        'confusion_matrix': None
    }

    # The CUDA peak counter is reset after every phase report, so the overall
    # peak has to be tracked here.
    overall_peak_mb = 0.0

    # Start training time
    start_time = time.time()

    # Training loop
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]") as train_bar:
            for batches_done, (inputs, labels) in enumerate(train_bar, start=1):
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Show the running mean over the batches processed so far.
                train_loss += loss.item()
                train_bar.set_postfix(loss=train_loss / batches_done)

        overall_peak_mb = max(overall_peak_mb, _report_memory(f"Epoch {epoch+1} Train"))

        # Calculate average loss on training set
        avg_train_loss = train_loss / len(train_loader)
        results['train_loss'].append(avg_train_loss)

        # Validation
        avg_val_loss, val_labels, val_preds = _run_evaluation(
            model, val_loader, criterion, f"Epoch {epoch+1}/{epochs} [Val]"
        )
        overall_peak_mb = max(overall_peak_mb, _report_memory(f"Epoch {epoch+1} Val"))

        # Calculate validation metrics
        val_metrics, _ = calculate_metrics(val_labels, val_preds, num_classes)

        # Store validation results
        results['val_loss'].append(avg_val_loss)
        results['val_acc'].append(val_metrics['accuracy'])
        results['val_precision'].append(val_metrics['precision_weighted'])
        results['val_recall'].append(val_metrics['recall_weighted'])
        results['val_f1'].append(val_metrics['f1_weighted'])

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        print(f"Val Metrics: Acc: {val_metrics['accuracy']:.2f}%, Precision: {val_metrics['precision_weighted']:.2f}%, " +
              f"Recall: {val_metrics['recall_weighted']:.2f}%, F1: {val_metrics['f1_weighted']:.2f}%")

        # Release cached memory once per epoch. Doing this after every batch
        # only makes the allocator re-request the same blocks on the next step.
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Total training time
    results['training_time'] = time.time() - start_time
    print(f"Training completed in {results['training_time']:.2f} seconds")

    # Evaluate on test set
    results['test_loss'], test_labels, test_preds = _run_evaluation(
        model, test_loader, criterion, "Testing"
    )

    # Calculate test metrics
    test_metrics, class_metrics = calculate_metrics(test_labels, test_preds, num_classes)

    # Store test results
    results['test_acc'] = test_metrics['accuracy']
    results['test_precision'] = test_metrics['precision_weighted']
    results['test_recall'] = test_metrics['recall_weighted']
    results['test_f1'] = test_metrics['f1_weighted']
    results['confusion_matrix'] = test_metrics['confusion_matrix']
    results['class_metrics'] = class_metrics

    print(f"Test Loss: {results['test_loss']:.4f}")

    print(f"Test Metrics: Acc: {results['test_acc']:.2f}%, Precision: {results['test_precision']:.2f}%, " +
          f"Recall: {results['test_recall']:.2f}%, F1: {results['test_f1']:.2f}%")

    if device.type == 'cuda':
        # Combine the test-phase peak with the peaks recorded during training.
        test_peak_mb = torch.cuda.max_memory_allocated() / 1024**2
        peak_memory_mb = max(overall_peak_mb, test_peak_mb)
        print(f"Peak VRAM used: {peak_memory_mb:.2f} MB")
        results['peak_vram'] = peak_memory_mb
    else:
        results['peak_vram'] = None

    return results


def _infer_num_classes(model):
    """Read the output size of the classification head of a supported model."""
    head_locators = (
        lambda m: m.fc,             # ResNet
        lambda m: m.classifier[1],  # MobileNet
        lambda m: m.heads.head,     # ViT
        lambda m: m.head,
    )
    for locate_head in head_locators:
        try:
            return locate_head(model).out_features
        except (AttributeError, IndexError, KeyError, TypeError):
            continue
    raise AttributeError(
        "Could not determine the number of classes: the model has no "
        "'fc', 'classifier[1]', 'heads.head' or 'head' layer with 'out_features'"
    )


def _run_evaluation(model, loader, criterion, desc):
    """
    Run one gradient-free pass over `loader`.

    Returns:
        (average batch loss, numpy array of true labels, numpy array of predictions)
    """
    model.eval()
    total_loss = 0.0
    label_batches = []
    pred_batches = []

    with torch.no_grad():
        with tqdm(loader, desc=desc) as progress_bar:
            for batches_done, (inputs, labels) in enumerate(progress_bar, start=1):
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                total_loss += criterion(outputs, labels).item()

                # Collect labels and predictions for metric calculation
                label_batches.append(labels.cpu())
                pred_batches.append(outputs.argmax(dim=1).cpu())

                progress_bar.set_postfix(loss=total_loss / batches_done)

    avg_loss = total_loss / len(loader)
    return avg_loss, torch.cat(label_batches).numpy(), torch.cat(pred_batches).numpy()


def _report_memory(stage):
    """
    Print memory usage for a finished phase and return its peak VRAM in MB.

    On CUDA the peak counter is reset afterwards so the next phase is measured
    on its own. On CPU the process RSS is printed when psutil is installed and
    0.0 is returned.
    """
    if device.type == 'cuda':
        current_mem = torch.cuda.memory_allocated() / 1024**2
        peak_mem = torch.cuda.max_memory_allocated() / 1024**2
        print(f"{stage} VRAM usage: Current {current_mem:.2f} MB, Peak {peak_mem:.2f} MB")
        torch.cuda.reset_peak_memory_stats()
        return peak_mem

    # psutil is optional: it is not imported at the top of the notebook, and
    # CPU runs must not fail just because it is missing.
    try:
        import psutil
    except ImportError:
        print(f"{stage} CPU RAM usage: unavailable (psutil is not installed)")
        return 0.0

    mem_cpu = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
    print(f"{stage} CPU RAM usage: {mem_cpu:.2f} MB")
    return 0.0

In [None]:
# Function to compare models on a given dataset
def compare_models_on_dataset(dataset_name, models=['resnet', 'mobilenet', 'vit'], 
                             epochs=3, batch_size=32, lr=0.001):  # Reduced batch size
    """
    Train and evaluate each requested model on one dataset

    The data loaders are built once and shared by all models; each model is
    released before the next one is created.

    Args:
        dataset_name: dataset name ('emnist', 'kmnist', 'svhn')
        models: list of models to compare
        epochs: number of epochs
        batch_size: batch size
        lr: learning rate

    Returns:
        comparison_results: dictionary mapping model name to the results
            dictionary returned by train_and_evaluate
    """
    print(f"\n{'='*20} Comparing models on dataset {dataset_name.upper()} {'='*20}\n")

    # One set of loaders for every model
    loaders_and_classes = prepare_datasets(dataset_name, batch_size)
    train_loader, val_loader, test_loader, num_classes = loaders_and_classes

    results_by_model = {}

    for name in models:
        print(f"\n{'-'*10} Model: {name.upper()} {'-'*10}")

        # Build, train and score this backbone
        network = prepare_model(name, num_classes, dataset_name)
        results_by_model[name] = train_and_evaluate(
            network, train_loader, val_loader, test_loader, epochs, lr
        )

        # Drop the network before the next one is created
        del network
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return results_by_model

In [None]:
# Function to plot confusion matrix
def plot_confusion_matrix(conf_matrix, class_names, title='Confusion Matrix'):
    """
    Plot confusion matrix

    Draws the matrix as a heat map on a new 10x8 figure, writes each cell's
    value into the cell and labels the axes. The new figure is left as the
    current pyplot figure.

    Args:
        conf_matrix: 2-D confusion matrix (rows: true label, columns: predicted
            label); integer matrices are printed as integers, anything else
            with two decimals
        class_names: list of class names used as tick labels
        title: title for the plot

    Returns:
        the matplotlib.pyplot module, so callers can call savefig()/show() on
        the current figure
    """
    conf_matrix = np.asarray(conf_matrix)

    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(conf_matrix, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title(title)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names)

    # Display values in cells; 'd' only works for integer counts.
    cell_format = 'd' if np.issubdtype(conf_matrix.dtype, np.integer) else '.2f'
    # Cells darker than the midpoint get white text for contrast.
    thresh = conf_matrix.max() / 2.
    for i in range(conf_matrix.shape[0]):
        for j in range(conf_matrix.shape[1]):
            ax.text(j, i, format(conf_matrix[i, j], cell_format),
                    horizontalalignment="center",
                    color="white" if conf_matrix[i, j] > thresh else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    # Lay out only after the axis labels exist so they are not clipped.
    fig.tight_layout()

    return plt

In [None]:
# Function to plot comparison results
def plot_comparison_results(all_results):
    """
    Plot comparison charts

    Produces one 3x2 figure (test accuracy, precision, recall, F1, training
    time as grouped bars, plus validation F1 curves), one confusion-matrix
    figure per dataset/model pair, and a summary table.

    Files written:
        - 'model_comparison_metrics.png' (overwritten on every call)
        - '<datasets joined by "_">_model_comparison_metrics.png', a copy whose
          name depends on the datasets so separate runs do not overwrite each
          other
        - '<dataset>_<model>_confusion_matrix.png' for each pair that has a
          confusion matrix

    Args:
        all_results: dictionary {dataset: {model: results}} where results is
            the dictionary returned by train_and_evaluate. The model list is
            taken from the first dataset and is expected to be the same for
            every dataset.

    Returns:
        df_comparison: DataFrame with one row per '<dataset>-<model>' and the
            test metrics and training time as columns (empty DataFrame when
            there is nothing to plot)
    """
    datasets = list(all_results.keys())
    models = list(all_results[datasets[0]].keys()) if datasets else []
    if not models:
        print("No results to plot.")
        return pd.DataFrame()

    def metric_table(metric_key):
        # Rows are models, columns are datasets.
        return pd.DataFrame({
            dataset: {model: all_results[dataset][model][metric_key] for model in models}
            for dataset in datasets
        })

    # (results key, chart title, y-axis label) for the five bar charts
    bar_panels = [
        ('test_acc', 'Test Accuracy (%)', 'Accuracy (%)'),
        ('test_precision', 'Test Precision (Weighted) (%)', 'Precision (%)'),
        ('test_recall', 'Test Recall (Weighted) (%)', 'Recall (%)'),
        ('test_f1', 'Test F1 Score (Weighted) (%)', 'F1 Score (%)'),
        ('training_time', 'Training Time (seconds)', 'Time (s)'),
    ]

    # Plot metrics charts
    fig, axes = plt.subplots(3, 2, figsize=(20, 15))
    flat_axes = axes.flatten()

    for ax, (metric_key, panel_title, y_label) in zip(flat_axes, bar_panels):
        metric_table(metric_key).plot(kind='bar', ax=ax, rot=45)
        ax.set_title(panel_title)
        ax.set_xlabel('Model')
        ax.set_ylabel(y_label)
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Validation metrics during training
    curve_ax = flat_axes[5]
    for dataset in datasets:
        for model in models:
            curve_ax.plot(all_results[dataset][model]['val_f1'], label=f"{dataset}-{model} (F1)")
    curve_ax.set_title('Validation F1 Score During Training')
    curve_ax.set_xlabel('Epoch')
    curve_ax.set_ylabel('F1 Score (%)')
    curve_ax.grid(True, linestyle='--', alpha=0.7)
    curve_ax.legend()

    fig.tight_layout()
    fig.savefig('model_comparison_metrics.png')  # Save the figure
    # Dataset-specific copy that survives later calls for other datasets.
    fig.savefig(f"{'_'.join(datasets)}_model_comparison_metrics.png")
    plt.show()

    # Plot confusion matrices for all models
    for dataset in datasets:
        for model in models:
            # The key always exists in train_and_evaluate results, so check
            # the value rather than the key.
            conf_matrix = all_results[dataset][model].get('confusion_matrix')
            if conf_matrix is None:
                continue

            # Get class names based on dataset
            if dataset == 'emnist':
                class_names = [chr(i + ord('A')) for i in range(26)]  # A-Z
            elif dataset in ['kmnist', 'svhn']:
                class_names = [str(i) for i in range(10)]  # 0-9
            else:
                class_names = [str(i) for i in range(conf_matrix.shape[0])]

            # Plot and save confusion matrix
            plt_cm = plot_confusion_matrix(
                conf_matrix,
                class_names,
                f'Confusion Matrix for {model.upper()} on {dataset.upper()}'
            )
            plt_cm.savefig(f'{dataset}_{model}_confusion_matrix.png')
            plt_cm.show()
            # Free the figure; otherwise figures pile up across pairs.
            plt_cm.close()

    # Create comparison table
    comparison_table = {}

    for dataset in datasets:
        for model in models:
            result = all_results[dataset][model]
            comparison_table[f"{dataset}-{model}"] = {
                'Test Accuracy (%)': result['test_acc'],
                'Test Precision (%)': result['test_precision'],
                'Test Recall (%)': result['test_recall'],
                'Test F1 Score (%)': result['test_f1'],
                'Training Time (s)': result['training_time']
            }

    df_comparison = pd.DataFrame(comparison_table).T
    print("\nComparison table of all models on all datasets:")
    print(df_comparison)

    return df_comparison

In [None]:
# Main function to run the comparison
def run_models_comparison(datasets=['emnist', 'kmnist', 'svhn'], 
                          models=['resnet', 'mobilenet', 'vit'],
                          epochs=3, batch_size=32, lr=0.001):  # Reduced batch size
    """
    Entry point: benchmark every model on every dataset, plot, and export

    For each dataset the models are trained and evaluated, then the combined
    results are plotted and each pair's per-class metrics are written to
    '<dataset>_<model>_class_metrics.csv'.

    Args:
        datasets: list of datasets to compare
        models: list of models to compare
        epochs: number of epochs
        batch_size: batch size
        lr: learning rate

    Returns:
        df_comparison: table with comparison results
    """
    results_by_dataset = {}

    for dataset_name in datasets:
        results_by_dataset[dataset_name] = compare_models_on_dataset(
            dataset_name, models, epochs, batch_size, lr
        )

        # Free memory before moving on to the next dataset
        gc.collect()
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Charts and summary table
    summary_table = plot_comparison_results(results_by_dataset)

    # Export the detailed per-class metrics, one CSV per dataset/model pair
    for dataset_name in datasets:
        for model_name in models:
            model_results = results_by_dataset[dataset_name][model_name]
            if 'class_metrics' not in model_results:
                continue
            csv_path = f"{dataset_name}_{model_name}_class_metrics.csv"
            pd.DataFrame(model_results['class_metrics']).to_csv(csv_path)

    return summary_table

In [None]:
# Re-detect the compute device (this cell only sets and prints `device`)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# EMNIST run: all three backbones, 15 epochs, batch size 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Release whatever earlier cells left behind before starting the run
gc.collect()
if device.type == 'cuda':
    torch.cuda.empty_cache()

df_results = run_models_comparison(
    datasets=['emnist'], models=['resnet', 'mobilenet', 'vit'], epochs=15, batch_size=32
)

In [None]:
# KMNIST run: all three backbones, 15 epochs, batch size 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Release whatever earlier cells left behind before starting the run
gc.collect()
if device.type == 'cuda':
    torch.cuda.empty_cache()

df_results = run_models_comparison(
    datasets=['kmnist'], models=['resnet', 'mobilenet', 'vit'], epochs=15, batch_size=32
)

In [None]:
# SVHN run: all three backbones, 15 epochs, batch size 32
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Release whatever earlier cells left behind before starting the run
gc.collect()
if device.type == 'cuda':
    torch.cuda.empty_cache()

df_results = run_models_comparison(
    datasets=['svhn'], models=['resnet', 'mobilenet', 'vit'], epochs=15, batch_size=32
)