In [1]:
import torch
import pandas as pd

from typing import Dict, List
from helper.config import TrainConfig
from helper.global_variables import TRAIN_YAML_PATH
from helper.models import get_model, evaluate_model, get_model_size
from helper.datasets import get_dataset_loader

In [2]:
# Main Functions
@dataclass
class TrainAccResult:
    dataset: str
    model: str
    size: int
    epochs: int
    accuracy: float

def eval_wrapper(
    model_name: str,
    dataset_name: str,
    model_path: str,
    epochs: int
) -> Dict:

    """Evaluate a single model and return results."""
    # Load and prepare model
    model = get_model(model_name)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    model_size = get_model_size(model)

    # Get data and evaluate
    _, test_loader = get_dataset_loader(dataset_name)
    accuracy = evaluate_model(model, test_loader)

    return TrainAccResult(dataset_name, 
                          model_name, 
                          model_size, 
                          epochs, 
                          accuracy)

def evaluate_model_accuracies(train_config: TrainConfig) -> None:
    results: List[TrainAccResult] = []

    # Evaluate each model
    for model_name, model_path in train_config.get_models():
        
        # Parse config key
        dataset_name = model_name.split("-")[0]
        
        # Evaluate model
        try:
            result = eval_wrapper(
                model_name=model_name,
                dataset_name=dataset_name,
                model_path=model_path,
                epochs=train_config.epochs
            )
            results.append(result)
            
        except Exception as e:
            print(f"Error evaluating {model_name} on {dataset_name}: {str(e)}")

    # Save and display results
    return pd.DataFrame(results)

NameError: name 'dataclass' is not defined

In [3]:
# Load configuration
train_config: TrainConfig = TrainConfig(TRAIN_YAML_PATH)

df = evaluate_model_accuracies(train_config)
df.to_csv("models/model-eval-results.csv", index=False)

print("\nEvaluation Results:")
print(df)

  model.load_state_dict(torch.load(model_path))
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 239.14it/s]
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 264.37it/s]
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 258.08it/s]
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 202.12it/s]
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 267.42it/s]
Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 269.12it/s]


Files already downloaded and verified
Files already downloaded and verified


Evaluate Model ...: 100%|██████████| 157/157 [00:04<00:00, 34.96it/s]


Files already downloaded and verified
Files already downloaded and verified


Evaluate Model ...: 100%|██████████| 157/157 [00:05<00:00, 29.63it/s]



Evaluation Results:
   Dataset               Model     Size  Epochs  Accuracy (%)
0    mnist     mnist-dense8x20    19270       5         91.22
1    mnist   mnist-dense10x100   180510       5         95.47
2    mnist         mnist-dense   576810       5         97.28
3    mnist          mnist-conv    28534       5         98.63
4   fmnist        fmnist-dense   576810       5         87.42
5   fmnist         fmnist-conv    28534       5         88.65
6  cifar10   cifar10-mobilenet  2542856       5         87.74
7  cifar10  cifar10-squeezenet  1235496       5         79.00
