In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from dataclasses import dataclass
import numpy as np
from torch.optim.sgd import SGD
from torch.optim.lr_scheduler import ExponentialLR
import tqdm
import wandb

@dataclass
class TrainConfig:
    """
    Hyperparameters and switches for a single training run.

    Every field has a default, so ``TrainConfig()`` is a valid configuration.
    The string fields are interpreted by the factory helpers in this file:
    ``optimizer_type`` by ``get_optimizer``, ``scheduler_type`` by
    ``get_scheduler`` and ``model_type`` by ``GenericModel``.
    """
    # Learning rate handed to the optimizer.
    lr: float = 0.1
    # The validation loss is computed once every `eval_every` iterations.
    eval_every: int = 10
    # Number of optimization steps performed by the training loop.
    total_iterations: int = 3000
    # "exp" selects ExponentialLR, "none" disables scheduling.
    scheduler_type: str = "none"
    # "batch_norm" inserts a BatchNorm1d layer into the model.
    model_type: str = "batch_norm"
    # "sgd" or "adam".
    optimizer_type: str = "sgd"
    # Multiplicative decay factor; only read when scheduler_type is "exp".
    gamma: float = 0.99

def set_seed(seed: int) -> None:
    """
    Seed the NumPy and PyTorch (CPU and CUDA) random number generators.

    Args:
        seed (int): Value every generator is seeded with.
    """
    # The three calls are independent of each other, so their order is irrelevant.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def get_datasets() -> tuple:
    """
    Build the FashionMNIST train and test splits.

    Both splits are stored under ``./data`` (downloaded when missing) and each
    gets its own ``ToTensor`` transform.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    # The generator is consumed in order: the train split is built first, then test.
    train_dataset, test_dataset = (
        FashionMNIST(root='./data', train=is_train, download=True, transform=ToTensor())
        for is_train in (True, False)
    )
    return train_dataset, test_dataset

class GenericModel(nn.Module):
    """
    A one-hidden-layer MLP classifier: Linear -> ReLU -> (BatchNorm1d) -> Linear.

    Args:
        num_classes (int): Number of output classes (width of the last layer).
        model_type (str): ``"batch_norm"`` inserts a ``BatchNorm1d`` after the
            ReLU; any other value uses ``nn.Identity`` in that slot.
        hidden_dim (int): Width of the hidden layer. Defaults to 512.
        in_features (int): Number of features per sample after flattening.
            Defaults to ``28*28`` (one FashionMNIST image).
    """
    def __init__(
        self,
        num_classes: int = 10,
        model_type: str = "batch_norm",
        hidden_dim: int = 512,
        in_features: int = 28*28,
    ):
        super().__init__()
        self.model_type = model_type
        self.in_features = in_features
        # The normalization slot is always present (Identity when disabled), so
        # the indices inside `net` are the same for every `model_type`.
        normalization = (
            nn.BatchNorm1d(num_features=hidden_dim)
            if model_type == "batch_norm"
            else nn.Identity()
        )
        self.net = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden_dim),
            nn.ReLU(),
            normalization,
            nn.Linear(in_features=hidden_dim, out_features=num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Flatten the input to ``(-1, in_features)`` and run it through the network.

        Args:
            x (torch.Tensor): Input whose total element count is a multiple of
                ``in_features``.
        Returns:
            torch.Tensor: Output of the final linear layer, one row per sample.
        """
        x = x.reshape((-1, self.in_features))
        return self.net(x)

def get_optimizer(model: nn.Module, config: TrainConfig) -> optim.Optimizer:
    """
    Build the optimizer selected by ``config.optimizer_type``.

    Args:
        model (nn.Module): Model whose parameters will be optimized.
        config (TrainConfig): Supplies ``optimizer_type`` ("sgd" or "adam")
            and the learning rate ``lr``.
    Returns:
        optim.Optimizer: ``SGD`` or ``Adam`` over ``model.parameters()``.
    Raises:
        ValueError: If ``config.optimizer_type`` is neither "sgd" nor "adam".
    """
    choice = config.optimizer_type
    # Reject unknown names before touching the model.
    if choice not in ("sgd", "adam"):
        raise ValueError(f"Unknown optimizer type: {config.optimizer_type}")

    optimizer_cls = SGD if choice == "sgd" else optim.Adam
    return optimizer_cls(model.parameters(), lr=config.lr)

def get_scheduler(optimizer: optim.Optimizer, config: TrainConfig):
    """
    Build the learning-rate scheduler selected by ``config.scheduler_type``.

    Args:
        optimizer (optim.Optimizer): Optimizer the scheduler will drive.
        config (TrainConfig): Supplies ``scheduler_type`` ("exp" or "none")
            and, for "exp", the decay factor ``gamma``.
    Returns:
        ``ExponentialLR`` for "exp", or ``None`` for "none".
    Raises:
        ValueError: If ``config.scheduler_type`` is neither "exp" nor "none".
    """
    kind = config.scheduler_type
    if kind != "exp":
        # Only "none" is an acceptable alternative to "exp".
        if kind == "none":
            return None
        raise ValueError(f"Unknown scheduler type: {config.scheduler_type}")
    return ExponentialLR(optimizer, gamma=config.gamma)

def calculate_accuracy(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    """
    Compute the fraction of samples whose highest-scoring class matches the label.

    Args:
        y_pred (torch.Tensor): Per-class scores; the maximum is taken along dim 1.
        y_true (torch.Tensor): True labels, compared element-wise with the
            predicted class indices.
    Returns:
        float: Number of matches divided by ``y_true.shape[0]``.
    """
    predicted_labels = y_pred.max(dim=1).indices
    num_correct = predicted_labels.eq(y_true).float().sum()
    return (num_correct / y_true.shape[0]).item()

def train_loop(
    model: nn.Module,
    X_train: torch.Tensor,
    y_train: torch.Tensor,
    X_val: torch.Tensor,
    y_val: torch.Tensor,
    config: TrainConfig,
    run_name: str | None = None,
    group: str | None = None,
    best_metrics: dict | None = None
) -> None:
    """
    Train the model with full-batch gradient steps and log metrics to wandb.

    Every iteration runs one optimizer step on the whole of ``X_train`` and
    logs the training loss and the learning rate. Every ``config.eval_every``
    iterations the loss on ``X_val`` is logged as well. After training, the
    accuracy on ``X_val`` is computed and ``best_metrics`` is updated in place
    if this run beats the stored best.

    The target device is read from the module-level ``device`` variable. The
    model is left on that device and in eval mode when the function returns.

    Args:
        model (nn.Module): Model to be trained (modified in place).
        X_train (torch.Tensor): Training data.
        y_train (torch.Tensor): Training labels.
        X_val (torch.Tensor): Validation data.
        y_val (torch.Tensor): Validation labels.
        config (TrainConfig): Training configuration.
        run_name (str | None): WandB run name.
        group (str | None): Group name for wandb to group runs.
        best_metrics (dict | None): Dictionary tracking the best accuracy seen so
            far under the keys ``"best_accuracy"``, ``"best_lr"`` and
            ``"best_total_iterations"``. When ``None``, the final accuracy
            evaluation and the bookkeeping are skipped.
    """
    wandb.init(
        project="model_train_fashion_mnist",
        notes="version2",
        name=run_name,
        group=group,
        config=config
    )
    # The finally-clause closes the wandb run even if training raises, so a
    # failed run does not stay open.
    try:
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        model.to(device).train()

        # The full batch is reused on every step, so move it to the device once
        # instead of once per iteration.
        X_train, y_train = X_train.to(device), y_train.to(device)
        X_val, y_val = X_val.to(device), y_val.to(device)

        for i in tqdm.trange(config.total_iterations):
            optimizer.zero_grad()
            loss = F.cross_entropy(model(X_train), y_train)
            loss.backward()
            optimizer.step()

            metrics = {"iteration": i, "loss_train": loss.item()}

            if (i + 1) % config.eval_every == 0:
                # Switch to eval mode for the validation pass, then back to
                # train mode for the next step.
                model.eval()
                with torch.no_grad():
                    loss_val = F.cross_entropy(model(X_val), y_val)
                model.train()
                metrics["loss_val"] = loss_val.item()

            if scheduler is not None:
                scheduler.step()
                # Logged after the scheduler step, i.e. the rate the next
                # iteration will use.
                metrics["lr"] = scheduler.get_last_lr()[0]
            else:
                metrics["lr"] = config.lr

            wandb.log(metrics)
    finally:
        wandb.finish()

    # Callers may run inference right after this function, so always leave the
    # model in eval mode.
    model.eval()
    if best_metrics is None:
        return

    # Evaluate accuracy on the validation tensors after training
    with torch.no_grad():
        test_accuracy = calculate_accuracy(model(X_val), y_val)

    # Update best_metrics if the current run's accuracy is the best so far
    if best_metrics.get("best_accuracy", float("-inf")) < test_accuracy:
        best_metrics["best_accuracy"] = test_accuracy
        best_metrics["best_lr"] = config.lr
        best_metrics["best_total_iterations"] = config.total_iterations

# Main execution: grid search over learning rate and iteration count.
seed = 0
set_seed(seed)
train_dataset, test_dataset = get_datasets()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# NOTE(review): the tensors are read from `.data` directly, which appears to
# bypass the ToTensor transform given to the datasets, so the model presumably
# sees raw pixel values rather than values scaled to [0, 1] - confirm intended.
X_train = train_dataset.data.float().to(device)
y_train = train_dataset.targets.to(device)
X_test = test_dataset.data.float().to(device)
y_test = test_dataset.targets.to(device)

num_classes = 10  # Number of classes in FashionMNIST

learning_rates = [0.01, 0.1, 1.0, 2.0]
total_iterations_list = [100, 130, 500, 1000]
group_name = "lr_iterations_comparison"

# Updated in place by train_loop whenever a run beats the stored accuracy.
best_metrics = {"best_accuracy": 0.0, "best_lr": 0.0, "best_total_iterations": 0}

# NOTE(review): the test split is used both as the validation set during
# training and to pick the best configuration, so the reported best accuracy
# is selected on the same data it is measured on.
for lr in learning_rates:
    for total_iterations in total_iterations_list:
        config = TrainConfig(eval_every=20, lr=lr, total_iterations=total_iterations, scheduler_type="exp", model_type="batch_norm", optimizer_type="sgd")
        model = GenericModel(num_classes=num_classes, model_type=config.model_type)
        run_name = f"lr_{lr}_iterations_{total_iterations}"
        train_loop(model, X_train, y_train, X_test, y_test, config=config, run_name=run_name, group=group_name, best_metrics=best_metrics)

        # Inference only: set eval mode explicitly and skip autograd
        # bookkeeping. X_test / y_test already live on `device`.
        model.eval()
        with torch.no_grad():
            accuracy = calculate_accuracy(model(X_test), y_test)
        print(f"Run: {run_name}, Test Accuracy: {accuracy:.4f}")

print(f"Best Test Accuracy: {best_metrics['best_accuracy']:.4f} with lr={best_metrics['best_lr']} and total_iterations={best_metrics['best_total_iterations']}")


Using device: cuda


100%|██████████| 100/100 [00:03<00:00, 27.03it/s]


0,1
iteration,▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
loss_train,█▆▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▃▂▁
lr,███▇▇▇▆▆▆▆▆▆▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁

0,1
iteration,99.0
loss_train,0.62582
loss_val,0.64527
lr,0.00366


Run: lr_0.01_iterations_100, Test Accuracy: 0.7899


100%|██████████| 130/130 [00:04<00:00, 26.70it/s]


0,1
iteration,▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
loss_train,█▇▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▃▂▁▁
lr,██▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▁▁▁▁

0,1
iteration,129.0
loss_train,0.60745
loss_val,0.63669
lr,0.00271


Run: lr_0.01_iterations_130, Test Accuracy: 0.7925


100%|██████████| 500/500 [00:19<00:00, 26.04it/s]


0,1
iteration,▁▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████
loss_train,█▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▆▆▅▅▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,499.0
loss_train,0.55945
loss_val,0.58343
lr,7e-05


Run: lr_0.01_iterations_500, Test Accuracy: 0.8073


100%|██████████| 1000/1000 [00:37<00:00, 26.55it/s]


0,1
iteration,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇██
loss_train,█▇▇▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▇▄▄▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,999.0
loss_train,0.56371
loss_val,0.58734
lr,0.0


Run: lr_0.01_iterations_1000, Test Accuracy: 0.8078


100%|██████████| 100/100 [00:03<00:00, 27.02it/s]


0,1
iteration,▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████
loss_train,█▆▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▄▂▂▁
lr,████▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁

0,1
iteration,99.0
loss_train,0.37255
loss_val,0.41627
lr,0.0366


Run: lr_0.1_iterations_100, Test Accuracy: 0.8551


100%|██████████| 130/130 [00:04<00:00, 27.12it/s]


0,1
iteration,▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇█
loss_train,█▆▇▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▃▂▁▁
lr,███▇▇▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
iteration,129.0
loss_train,0.36247
loss_val,0.41691
lr,0.02708


Run: lr_0.1_iterations_130, Test Accuracy: 0.8543


100%|██████████| 500/500 [00:18<00:00, 26.54it/s]


0,1
iteration,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██
loss_train,█▆▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▇▆▆▅▅▅▅▄▄▄▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,499.0
loss_train,0.33425
loss_val,0.39144
lr,0.00066


Run: lr_0.1_iterations_500, Test Accuracy: 0.8628


100%|██████████| 1000/1000 [00:37<00:00, 26.49it/s]


0,1
iteration,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇█
loss_train,█▆▆▄▄▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▆▅▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,999.0
loss_train,0.33209
loss_val,0.38695
lr,0.0


Run: lr_0.1_iterations_1000, Test Accuracy: 0.8630


100%|██████████| 100/100 [00:03<00:00, 27.36it/s]


0,1
iteration,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
loss_train,▃▂█▅▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▂▁▁▁
lr,███▇▇▇▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
iteration,99.0
loss_train,0.34796
loss_val,0.43367
lr,0.36603


Run: lr_1.0_iterations_100, Test Accuracy: 0.8495


100%|██████████| 130/130 [00:04<00:00, 27.62it/s]


0,1
iteration,▁▁▁▁▁▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████
loss_train,▅█▅▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▃▂▁▁▁
lr,██▇▇▇▇▆▆▆▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
iteration,129.0
loss_train,0.34138
loss_val,0.43178
lr,0.27075


Run: lr_1.0_iterations_130, Test Accuracy: 0.8575


100%|██████████| 500/500 [00:18<00:00, 27.14it/s]


0,1
iteration,▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇█
loss_train,█▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▇▇▅▅▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,499.0
loss_train,0.29299
loss_val,0.37236
lr,0.00657


Run: lr_1.0_iterations_500, Test Accuracy: 0.8715


100%|██████████| 1000/1000 [00:36<00:00, 27.20it/s]


0,1
iteration,▁▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇███
loss_train,█▆▅▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▆▅▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,999.0
loss_train,0.28033
loss_val,0.43588
lr,4e-05


Run: lr_1.0_iterations_1000, Test Accuracy: 0.8741


100%|██████████| 100/100 [00:03<00:00, 27.55it/s]


0,1
iteration,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██
loss_train,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▃▁▁▁
lr,████▇▇▇▇▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁

0,1
iteration,99.0
loss_train,0.38172
loss_val,0.43921
lr,0.73206


Run: lr_2.0_iterations_100, Test Accuracy: 0.8440


100%|██████████| 130/130 [00:04<00:00, 27.57it/s]


0,1
iteration,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
loss_train,▂█▆▄▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▅▃▂▁▁
lr,███▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
iteration,129.0
loss_train,0.36109
loss_val,0.42719
lr,0.54151


Run: lr_2.0_iterations_130, Test Accuracy: 0.8553


100%|██████████| 500/500 [00:18<00:00, 27.17it/s]


0,1
iteration,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
loss_train,▂█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▃▂▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
lr,█▇▇▇▇▆▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,499.0
loss_train,0.32005
loss_val,0.47733
lr,0.01314


Run: lr_2.0_iterations_500, Test Accuracy: 0.8675


100%|██████████| 1000/1000 [00:37<00:00, 26.98it/s]


0,1
iteration,▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇████
loss_train,██▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss_val,█▆▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,█▇▆▆▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
iteration,999.0
loss_train,0.30578
loss_val,0.4208
lr,9e-05


Run: lr_2.0_iterations_1000, Test Accuracy: 0.8687
Best Test Accuracy: 0.8741 with lr=1.0 and total_iterations=1000
