In [3]:
# Import
import torch
import os

from dataclasses import dataclass

from helper.global_variables import TRAIN_YAML_PATH
from helper.models import evaluate_model, train_model, get_model
from helper.datasets import get_dataset_loader
from helper.general import ensure_directory_exists
from helper.config import TrainConfig

In [None]:
# Training Process Class
@dataclass()
class TrainProcess:
    """One training job: how long to train, which model, and where to save it."""

    # Number of epochs handed to the training routine for this job.
    epochs: int
    # Model identifier; the part before the first "-" selects the dataset preset.
    model_name: str
    # Destination file for the trained model's state_dict.
    model_path: str

In [5]:
# Training Functions
def train_model_wrapper(
    tp: TrainProcess,
) -> None:
    """Train, evaluate and save a single model described by ``tp``.

    The dataset preset is the part of ``tp.model_name`` before the first
    ``-`` (e.g. ``"cifar10-mobilenet"`` -> ``"cifar10"``) and is passed to
    ``get_dataset_loader``. The model comes from ``get_model``, is trained via
    ``train_model(..., tp.epochs)``, evaluated with ``evaluate_model`` on the
    test loader, and its ``state_dict`` is written to ``tp.model_path``.

    Args:
        tp: Training job (epochs, model name, output path).
    """
    print(f"Training model {tp.model_name}")

    # The dataset preset is encoded as the model-name prefix.
    dataset_prefix = tp.model_name.split("-")[0]
    train_loader, test_loader = get_dataset_loader(dataset_prefix)

    # Initialize and train model
    model = get_model(tp.model_name)
    model = train_model(model, train_loader, tp.epochs)

    # Evaluate on the test loader
    accuracy = evaluate_model(model, test_loader)
    print(f'-> Test Accuracy ({tp.model_name}): {accuracy:.2f}%')

    # Save only the weights (state_dict), not the whole module object.
    # os.path.dirname() returns "" for a bare filename like "model.pth";
    # there is no directory to create then, so skip the helper instead of
    # asking it to create an empty path.
    model_dir = os.path.dirname(tp.model_path)
    if model_dir:
        ensure_directory_exists(model_dir)
    torch.save(model.state_dict(), tp.model_path)

    print(f"-> Model saved to {tp.model_path}")
    print("-" * 30)
    
    
def get_train_process_list(train_config: TrainConfig) -> list[TrainProcess]:
    """Build one ``TrainProcess`` per model entry in ``train_config``.

    Every job shares ``train_config.epochs``; the (name, path) pairs come
    from ``train_config.get_models()`` in the order it yields them.
    """
    return [
        TrainProcess(
            epochs=train_config.epochs,
            model_name=name,
            model_path=path,
        )
        for name, path in train_config.get_models()
    ]


In [6]:
# Main Training Loop
# Fixed seed so that re-running the notebook reproduces the reported results.
SEED = 42

train_config: TrainConfig = TrainConfig(TRAIN_YAML_PATH)
tp_list = get_train_process_list(train_config)

for tp in tp_list:
    # Reseed torch's default RNG per model: weight init (and DataLoader
    # shuffling, when it draws from the default generator) then does not
    # depend on which models were trained before this one.
    torch.manual_seed(SEED)
    train_model_wrapper(tp)
    # Return cached, no-longer-used GPU memory from the previous model
    # before starting the next one.
    torch.cuda.empty_cache()

Training model mnist-dense8x20


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 171.24it/s]


Epoch [1/5], Loss: 1.3109


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 190.29it/s]


Epoch [2/5], Loss: 0.7582


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 175.80it/s]


Epoch [3/5], Loss: 0.5400


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 188.95it/s]


Epoch [4/5], Loss: 0.4256


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 218.29it/s]


Epoch [5/5], Loss: 0.3510


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 251.90it/s]


-> Test Accuracy (mnist-dense8x20): 91.64%
Directory already exists: models
-> Model saved to models/mnist-dense8x20.pth
------------------------------
Training model mnist-dense10x100


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 159.94it/s]


Epoch [1/5], Loss: 0.7002


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 176.55it/s]


Epoch [2/5], Loss: 0.2576


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 179.57it/s]


Epoch [3/5], Loss: 0.2013


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 178.00it/s]


Epoch [4/5], Loss: 0.1671


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 156.41it/s]


Epoch [5/5], Loss: 0.1531


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 268.81it/s]


-> Test Accuracy (mnist-dense10x100): 95.65%
Directory already exists: models
-> Model saved to models/mnist-dense10x100.pth
------------------------------
Training model mnist-dense


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 255.93it/s]


Epoch [1/5], Loss: 0.4095


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 234.64it/s]


Epoch [2/5], Loss: 0.1610


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 245.12it/s]


Epoch [3/5], Loss: 0.1222


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 211.95it/s]


Epoch [4/5], Loss: 0.0993


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 220.98it/s]


Epoch [5/5], Loss: 0.0856


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 292.62it/s]


-> Test Accuracy (mnist-dense): 96.73%
Directory already exists: models
-> Model saved to models/mnist-dense.pth
------------------------------
Training model mnist-conv


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 179.22it/s]


Epoch [1/5], Loss: 0.3364


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 221.43it/s]


Epoch [2/5], Loss: 0.0879


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 210.54it/s]


Epoch [3/5], Loss: 0.0627


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 201.91it/s]


Epoch [4/5], Loss: 0.0498


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 173.49it/s]


Epoch [5/5], Loss: 0.0437


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 237.41it/s]


-> Test Accuracy (mnist-conv): 98.44%
Directory already exists: models
-> Model saved to models/mnist-conv.pth
------------------------------
Training model fmnist-dense


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 229.90it/s]


Epoch [1/5], Loss: 0.5728


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 241.66it/s]


Epoch [2/5], Loss: 0.3975


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 299.88it/s]


Epoch [3/5], Loss: 0.3538


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 207.87it/s]


Epoch [4/5], Loss: 0.3237


Train Model ...: 100%|██████████| 938/938 [00:03<00:00, 292.29it/s]


Epoch [5/5], Loss: 0.3070


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 259.60it/s]


-> Test Accuracy (fmnist-dense): 87.36%
Directory already exists: models
-> Model saved to models/fmnist-dense.pth
------------------------------
Training model fmnist-conv


Train Model ...: 100%|██████████| 938/938 [00:05<00:00, 169.96it/s]


Epoch [1/5], Loss: 0.6816


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 192.16it/s]


Epoch [2/5], Loss: 0.4195


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 222.36it/s]


Epoch [3/5], Loss: 0.3577


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 194.71it/s]


Epoch [4/5], Loss: 0.3210


Train Model ...: 100%|██████████| 938/938 [00:04<00:00, 192.26it/s]


Epoch [5/5], Loss: 0.2962


Evaluate Model ...: 100%|██████████| 157/157 [00:00<00:00, 259.70it/s]


-> Test Accuracy (fmnist-conv): 88.48%
Directory already exists: models
-> Model saved to models/fmnist-conv.pth
------------------------------
Training model cifar10-mobilenet
Files already downloaded and verified
Files already downloaded and verified


Train Model ...: 100%|██████████| 782/782 [00:38<00:00, 20.46it/s]


Epoch [1/5], Loss: 0.4705


Train Model ...: 100%|██████████| 782/782 [00:37<00:00, 20.96it/s]


Epoch [2/5], Loss: 0.2248


Train Model ...: 100%|██████████| 782/782 [00:37<00:00, 20.85it/s]


Epoch [3/5], Loss: 0.1644


Train Model ...: 100%|██████████| 782/782 [00:37<00:00, 20.79it/s]


Epoch [4/5], Loss: 0.1321


Train Model ...: 100%|██████████| 782/782 [00:38<00:00, 20.49it/s]


Epoch [5/5], Loss: 0.1142


Evaluate Model ...: 100%|██████████| 157/157 [00:04<00:00, 32.21it/s]


-> Test Accuracy (cifar10-mobilenet): 90.89%
Directory already exists: models
-> Model saved to models/cifar10-mobilenet-v3-small.pth
------------------------------
Training model cifar10-squeezenet
Files already downloaded and verified
Files already downloaded and verified


Train Model ...: 100%|██████████| 782/782 [00:46<00:00, 16.92it/s]


Epoch [1/5], Loss: 1.6009


Train Model ...: 100%|██████████| 782/782 [00:45<00:00, 17.19it/s]


Epoch [2/5], Loss: 1.0133


Train Model ...: 100%|██████████| 782/782 [00:45<00:00, 17.10it/s]


Epoch [3/5], Loss: 0.7826


Train Model ...: 100%|██████████| 782/782 [00:47<00:00, 16.44it/s]


Epoch [4/5], Loss: 0.6580


Train Model ...: 100%|██████████| 782/782 [00:46<00:00, 16.85it/s]


Epoch [5/5], Loss: 0.5694


Evaluate Model ...: 100%|██████████| 157/157 [00:05<00:00, 29.56it/s]


-> Test Accuracy (cifar10-squeezenet): 79.48%
Directory already exists: models
-> Model saved to models/cifar10-squeezenet.pth
------------------------------
