In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from methods import load_data, NeuralNet, loader_eval, load_data_with_validation

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters
input_size = 28 * 28   # 784 flattened input features; presumably 28x28 single-channel images — confirm against the loader
hidden_size = 500      # hidden size passed to NeuralNet
num_classes = 10       # number of output logits scored by CrossEntropyLoss
num_epochs = 5         # passes over train_loader per seed
batch_size = 100       # NOTE(review): not passed to load_data_with_validation() — confirm the loader uses this value
learning_rate = 0.001  # Adam learning rate

In [4]:
# Train one fresh model per random seed. After every epoch, record the loss that
# `loader_eval` reports on the validation and test loaders. The result is one
# list of per-epoch losses per seed, in matching order for validation and test.
seeds = [0, 1, 2, 3, 4]

val_errors_per_seed = []
test_errors_per_seed = []

for seed in seeds:
    # Seed torch (weight init and, presumably, shuffling/splitting inside
    # load_data_with_validation) and numpy, in case the data helpers use it.
    torch.manual_seed(seed)
    np.random.seed(seed)
    train_loader, test_loader, validation_loader = load_data_with_validation()
    val_errors = []
    test_errors = []

    # Fully connected neural network
    model = NeuralNet(input_size, hidden_size, num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch. If loader_eval switches the model
        # to eval mode (common, but not visible here — verify), any dropout /
        # batch-norm layers would otherwise stay in inference mode.
        model.train()
        for images, labels in train_loader:
            # Flatten each image into a vector of `input_size` features.
            images = images.reshape(-1, input_size).to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation loss after this epoch (used for model selection).
        val_loss = loader_eval(validation_loader, model, criterion)
        val_errors.append(val_loss)

        # Test loss after this epoch; reported only for the selected epoch.
        test_loss = loader_eval(test_loader, model, criterion)
        test_errors.append(test_loss)

    val_errors_per_seed.append(val_errors)
    test_errors_per_seed.append(test_errors)

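## Select the best run by validation loss

Pick the (seed, epoch) with the lowest validation loss, as computed by `loader_eval`, and report the test loss measured at that same point.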

In [5]:
val_errors_per_seed = np.array(val_errors_per_seed)
test_errors_per_seed = np.array(test_errors_per_seed)

In [6]:
min_validation_error = np.min(val_errors_per_seed)
min_val_error_idx = np.where(val_errors_per_seed == min_validation_error)
i, j = min_val_error_idx
min_pair = [min_validation_error, test_errors_per_seed[i[0]][j[0]]]

In [7]:
print(f'minimum pair {min_pair}')

minimum pair [0.10306332632899284, 0.2823956326271097]
