In [1]:
import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from network import CharacterClassifier
from tqdm import tqdm
from training_data import CombinedData
from PIL import Image
from matplotlib import pyplot as plt

hasy_train = CombinedData('HASY')
hasy_test = CombinedData('HASY', train=False)

print("Train data length: {0}".format(len(hasy_train.data)))
print("Test data length: {0}".format(len(hasy_test.data)))
print("Img Shape: {0}".format(hasy_train.data[0].shape))
print("Number of Labels: {0}".format(hasy_train.no_labels))

100%|██████████| 151241/151241 [00:00<00:00, 697768.85it/s]
100%|██████████| 60000/60000 [00:07<00:00, 7881.00it/s]
100%|██████████| 60000/60000 [00:00<00:00, 362894.99it/s]
100%|██████████| 16992/16992 [00:00<00:00, 623492.99it/s]
  8%|▊         | 812/10000 [00:00<00:01, 8116.37it/s]

No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for +. Skipping


100%|██████████| 10000/10000 [00:01<00:00, 8259.05it/s]
100%|██████████| 10000/10000 [00:00<00:00, 363253.28it/s]

No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for +. Skipping
Train data length: 60405
Test data length: 10045
Img Shape: torch.Size([1, 32, 32])
Number of Labels: 11





In [3]:
from torchvision import models
from torch.nn import Conv2d

train_loader = DataLoader(hasy_train, batch_size=16, shuffle=True)
test_loader = DataLoader(hasy_test, batch_size=16, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def calc_accuracy(model):
    accuracies = []
    for idx, [x_test, y_test] in enumerate(test_loader):
        x_test, y_test = x_test.to(device), y_test.to(device)
        test_pred = model(x_test)
        accuracy = 100 * torch.mean((torch.argmax(test_pred, dim=1) == y_test).float())
        accuracies.append(accuracy.item())
    return np.mean(accuracies) 

def print_stats(lr, beta1, beta2, decay, accuracy):
    print("Learning rate: {0}".format(lr))
    print("Beta 1: {0}".format(beta1))
    print("Beta 2: {0}".format(beta2))
    print("Weight decay: {0}".format(decay))
    print("Accuracy: {0}".format(accuracy))

learning_rates = [0.01, 0.001, 0.0001]
betas1 = [0.8, 0.85, 0.9, 0.95]
betas2 = [0.9, 0.925, 0.95, 0.99]
weight_decays = [0, 0.01, 0.001, 0.0001]

best_lr = 0
best_beta1 = 0
best_beta2 = 0
best_decay = 0
best_accuracy = 0

for lr in learning_rates:
    for beta1 in betas1:
        for beta2 in betas2:
            for decay in weight_decays:
                model = models.alexnet(num_classes=11)
                model.features[0] = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
                model.to(device)
                optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=decay)
                criterion = nn.CrossEntropyLoss()
                for step, [x_train, y_train] in enumerate(tqdm(train_loader)):
                    x_train, y_train = x_train.to(device), y_train.to(device)
                    optimizer.zero_grad()
                    train_pred = model(x_train)
                    loss = criterion(train_pred, y_train)
                    loss.backward()
                    optimizer.step()
                accuracy = calc_accuracy(model)
                print_stats(lr, beta1, beta2, decay, accuracy)
                if accuracy > best_accuracy:
                    best_lr = lr
                    best_beta1 = beta1
                    best_beta2 = beta2
                    best_decay = decay
                    best_accuracy = accuracy
                    
print("Best hyperparameters:")
print_stats(best_lr, best_beta1, best_beta2, best_decay, best_accuracy)
                

100%|██████████| 3776/3776 [01:43<00:00, 36.46it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:55<00:00, 32.78it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 10.733096522130785


100%|██████████| 3776/3776 [01:54<00:00, 32.89it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 9.7554813820845


100%|██████████| 3776/3776 [01:55<00:00, 32.66it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 11.34018250939193


100%|██████████| 3776/3776 [01:45<00:00, 35.90it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0
Accuracy: 10.27299730565138


100%|██████████| 3776/3776 [01:57<00:00, 32.13it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:53<00:00, 33.21it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 9.77768250939193


100%|██████████| 3776/3776 [01:54<00:00, 32.99it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 10.054048261065391


100%|██████████| 3776/3776 [01:43<00:00, 36.54it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.43it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 10.253092847052653


100%|██████████| 3776/3776 [01:52<00:00, 33.47it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.54it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:42<00:00, 36.90it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:52<00:00, 33.64it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.62it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 10.429936305732484


100%|██████████| 3776/3776 [01:52<00:00, 33.61it/s]


Learning rate: 0.01
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 9.864955904377494


100%|██████████| 3776/3776 [01:42<00:00, 36.95it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:52<00:00, 33.65it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 10.153570554059023


100%|██████████| 3776/3776 [01:55<00:00, 32.67it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:56<00:00, 32.40it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 9.538829006207218


100%|██████████| 3776/3776 [01:44<00:00, 36.09it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:55<00:00, 32.75it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 10.054048261065391


100%|██████████| 3776/3776 [01:55<00:00, 32.61it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:53<00:00, 33.41it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 10.066297159073459


100%|██████████| 3776/3776 [01:47<00:00, 35.18it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:59<00:00, 31.63it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 10.633574229137153


100%|██████████| 3776/3776 [01:58<00:00, 31.97it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:58<00:00, 31.91it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:49<00:00, 34.52it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:57<00:00, 32.13it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 10.263045076352016


100%|██████████| 3776/3776 [01:53<00:00, 33.17it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 28.338590152704032


100%|██████████| 3776/3776 [01:54<00:00, 33.07it/s]


Learning rate: 0.01
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 10.253092847052653


100%|██████████| 3776/3776 [01:43<00:00, 36.52it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0
Accuracy: 10.04409603176603


100%|██████████| 3776/3776 [01:54<00:00, 33.11it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 9.695768006288322


100%|██████████| 3776/3776 [01:54<00:00, 32.98it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 9.887157031684923


100%|██████████| 3776/3776 [01:53<00:00, 33.26it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 24.347746203659447


100%|██████████| 3776/3776 [01:43<00:00, 36.54it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:55<00:00, 32.60it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 10.27299730565138


100%|██████████| 3776/3776 [01:54<00:00, 33.04it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 10.27299730565138


100%|██████████| 3776/3776 [01:54<00:00, 33.10it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:44<00:00, 36.29it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:54<00:00, 32.98it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 8.941695248245434


100%|██████████| 3776/3776 [01:54<00:00, 32.97it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:53<00:00, 33.30it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 71.32839295211112


100%|██████████| 3776/3776 [01:42<00:00, 36.87it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:53<00:00, 33.34it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 10.312806222848831


100%|██████████| 3776/3776 [01:55<00:00, 32.76it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 9.85500367507813


100%|██████████| 3776/3776 [01:56<00:00, 32.51it/s]


Learning rate: 0.01
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 10.755297651716099


100%|██████████| 3776/3776 [01:44<00:00, 36.07it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0
Accuracy: 10.153570554059023


100%|██████████| 3776/3776 [01:55<00:00, 32.78it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 77.00805366418923


100%|██████████| 3776/3776 [01:54<00:00, 33.03it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 40.489496577317546


100%|██████████| 3776/3776 [01:54<00:00, 32.96it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 30.935356449929014


100%|██████████| 3776/3776 [01:43<00:00, 36.50it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:54<00:00, 33.09it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 64.0112077445741


100%|██████████| 3776/3776 [01:54<00:00, 32.84it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 51.88250245258307


100%|██████████| 3776/3776 [01:54<00:00, 32.95it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 55.22338927323651


100%|██████████| 3776/3776 [01:45<00:00, 35.78it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.47it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 9.695768006288322


100%|██████████| 3776/3776 [01:53<00:00, 33.21it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 52.008819209542246


100%|██████████| 3776/3776 [01:54<00:00, 32.98it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 58.13939245794989


100%|██████████| 3776/3776 [01:45<00:00, 35.96it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:55<00:00, 32.79it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 10.223236159154563


100%|██████████| 3776/3776 [01:54<00:00, 33.05it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 10.454434101748618


100%|██████████| 3776/3776 [01:53<00:00, 33.14it/s]


Learning rate: 0.01
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 30.32138045426387


100%|██████████| 3776/3776 [01:44<00:00, 36.26it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0
Accuracy: 95.59116242038216


100%|██████████| 3776/3776 [01:55<00:00, 32.73it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:54<00:00, 33.10it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 94.34483710367968


100%|██████████| 3776/3776 [01:52<00:00, 33.42it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 95.92953821656052


100%|██████████| 3776/3776 [01:42<00:00, 36.81it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0
Accuracy: 88.63991303656512


100%|██████████| 3776/3776 [01:52<00:00, 33.46it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 9.695768006288322


100%|██████████| 3776/3776 [01:52<00:00, 33.62it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 97.08399681528662


100%|██████████| 3776/3776 [01:55<00:00, 32.61it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 94.32722929936305


100%|██████████| 3776/3776 [01:45<00:00, 35.64it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0
Accuracy: 95.81776703998541


100%|██████████| 3776/3776 [01:54<00:00, 32.91it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:55<00:00, 32.55it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 97.30294585987261


100%|██████████| 3776/3776 [01:54<00:00, 32.93it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 94.50636942675159


100%|██████████| 3776/3776 [01:42<00:00, 36.75it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0
Accuracy: 90.85926017032307


100%|██████████| 3776/3776 [01:52<00:00, 33.66it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.46it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 97.73089171974522


100%|██████████| 3776/3776 [01:54<00:00, 32.96it/s]


Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 97.0342356687898


100%|██████████| 3776/3776 [01:44<00:00, 36.12it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0
Accuracy: 92.93391719745223


100%|██████████| 3776/3776 [01:54<00:00, 32.95it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 9.695768006288322


100%|██████████| 3776/3776 [01:55<00:00, 32.81it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 96.35748407643312


100%|██████████| 3776/3776 [01:53<00:00, 33.20it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 94.80493630573248


100%|██████████| 3776/3776 [01:44<00:00, 36.15it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0
Accuracy: 87.09731749516384


100%|██████████| 3776/3776 [01:55<00:00, 32.65it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:53<00:00, 33.20it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 95.44187898089172


100%|██████████| 3776/3776 [01:54<00:00, 32.99it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 96.28781847133757


100%|██████████| 3776/3776 [01:43<00:00, 36.55it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0
Accuracy: 91.64778296355229


100%|██████████| 3776/3776 [01:52<00:00, 33.61it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.61it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 97.68113057324841


100%|██████████| 3776/3776 [01:52<00:00, 33.52it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 95.99920382165605


100%|██████████| 3776/3776 [01:42<00:00, 36.84it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0
Accuracy: 96.4945186323421


100%|██████████| 3776/3776 [01:52<00:00, 33.71it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:52<00:00, 33.60it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 96.6062898089172


100%|██████████| 3776/3776 [01:52<00:00, 33.49it/s]


Learning rate: 0.001
Beta 1: 0.85
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 95.21297770700637


100%|██████████| 3776/3776 [01:42<00:00, 36.88it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0
Accuracy: 92.97372611464968


100%|██████████| 3776/3776 [01:51<00:00, 33.71it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 10.054048261065391


100%|██████████| 3776/3776 [01:52<00:00, 33.71it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 95.93949044585987


100%|██████████| 3776/3776 [01:52<00:00, 33.71it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 95.40207006369427


100%|██████████| 3776/3776 [01:42<00:00, 36.86it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0
Accuracy: 88.81139761323382


100%|██████████| 3776/3776 [01:53<00:00, 33.40it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 9.837395885188108


100%|██████████| 3776/3776 [01:54<00:00, 33.11it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 96.10867834394904


100%|██████████| 3776/3776 [01:53<00:00, 33.35it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 96.96457006369427


100%|██████████| 3776/3776 [01:43<00:00, 36.50it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0
Accuracy: 90.70461783439491


100%|██████████| 3776/3776 [01:53<00:00, 33.34it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 9.964478197371124


100%|██████████| 3776/3776 [01:53<00:00, 33.34it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 96.94466560509554


100%|██████████| 3776/3776 [01:53<00:00, 33.25it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 96.765525477707


100%|██████████| 3776/3776 [01:43<00:00, 36.48it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0
Accuracy: 92.62539808917198


100%|██████████| 3776/3776 [01:54<00:00, 32.99it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 9.538829006207218


100%|██████████| 3776/3776 [01:54<00:00, 32.84it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 97.30294585987261


100%|██████████| 3776/3776 [01:55<00:00, 32.68it/s]


Learning rate: 0.001
Beta 1: 0.9
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 94.41679936305732


100%|██████████| 3776/3776 [01:46<00:00, 35.38it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0
Accuracy: 93.89698678520834


100%|██████████| 3776/3776 [01:53<00:00, 33.34it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 9.538829006207218


100%|██████████| 3776/3776 [01:53<00:00, 33.29it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 96.44705414012739


100%|██████████| 3776/3776 [01:53<00:00, 33.14it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.9
Weight decay: 0.0001
Accuracy: 95.86982484076434


100%|██████████| 3776/3776 [01:47<00:00, 35.07it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0
Accuracy: 92.08338437414473


100%|██████████| 3776/3776 [01:54<00:00, 33.01it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:54<00:00, 33.07it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.001
Accuracy: 96.54657643312102


100%|██████████| 3776/3776 [01:53<00:00, 33.16it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.925
Weight decay: 0.0001
Accuracy: 96.48686305732484


100%|██████████| 3776/3776 [01:43<00:00, 36.55it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0
Accuracy: 96.23805732484077


100%|██████████| 3776/3776 [01:53<00:00, 33.41it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.01
Accuracy: 10.233188388453927


100%|██████████| 3776/3776 [01:53<00:00, 33.27it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.001
Accuracy: 96.54657643312102


100%|██████████| 3776/3776 [01:53<00:00, 33.28it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.95
Weight decay: 0.0001
Accuracy: 96.67595541401273


100%|██████████| 3776/3776 [01:43<00:00, 36.46it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0
Accuracy: 95.55135350318471


100%|██████████| 3776/3776 [01:53<00:00, 33.17it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.01
Accuracy: 11.298076923485775


100%|██████████| 3776/3776 [01:54<00:00, 32.93it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 97.5218949044586


100%|██████████| 3776/3776 [01:53<00:00, 33.37it/s]


Learning rate: 0.001
Beta 1: 0.95
Beta 2: 0.99
Weight decay: 0.0001
Accuracy: 96.25796178343948


100%|██████████| 3776/3776 [01:43<00:00, 36.52it/s]


Learning rate: 0.0001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0
Accuracy: 96.36743630573248


100%|██████████| 3776/3776 [01:55<00:00, 32.63it/s]


Learning rate: 0.0001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.01
Accuracy: 92.8818593966733


 49%|████▉     | 1856/3776 [00:56<00:57, 33.21it/s]

KeyboardInterrupt: 

In [5]:
print("Best hyperparameters:")
print_stats(best_lr, best_beta1, best_beta2, best_decay, best_accuracy)

Best hyperparameters:
Learning rate: 0.001
Beta 1: 0.8
Beta 2: 0.99
Weight decay: 0.001
Accuracy: 97.73089171974522


In [7]:
model = models.alexnet(num_classes=11)
model.features[0] = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.8, 0.99), weight_decay=0.001)
criterion = nn.CrossEntropyLoss()
for step, [x_train, y_train] in enumerate(tqdm(train_loader)):
    x_train, y_train = x_train.to(device), y_train.to(device)
    optimizer.zero_grad()
    train_pred = model(x_train)
    loss = criterion(train_pred, y_train)
    loss.backward()
    optimizer.step()
accuracy = calc_accuracy(model)
print_stats(lr, beta1, beta2, decay, accuracy)
if accuracy > best_accuracy:
    best_lr = lr
    best_beta1 = beta1
    best_beta2 = beta2
    best_decay = decay
    best_accuracy = accuracy


  0%|          | 0/3776 [00:00<?, ?it/s][A
  0%|          | 4/3776 [00:00<01:38, 38.20it/s][A
  0%|          | 7/3776 [00:00<01:47, 35.06it/s][A
  0%|          | 11/3776 [00:00<01:52, 33.41it/s][A
  0%|          | 15/3776 [00:00<01:56, 32.38it/s][A
  1%|          | 19/3776 [00:00<01:58, 31.77it/s][A
  1%|          | 22/3776 [00:00<02:00, 31.14it/s][A
  1%|          | 26/3776 [00:00<02:01, 30.98it/s][A
  1%|          | 30/3776 [00:00<02:01, 30.77it/s][A
  1%|          | 34/3776 [00:01<02:02, 30.67it/s][A
  1%|          | 38/3776 [00:01<02:02, 30.62it/s][A
  1%|          | 42/3776 [00:01<02:02, 30.43it/s][A
  1%|          | 46/3776 [00:01<02:02, 30.54it/s][A
  1%|▏         | 50/3776 [00:01<02:02, 30.44it/s][A
  1%|▏         | 54/3776 [00:01<02:02, 30.44it/s][A
  2%|▏         | 58/3776 [00:01<02:02, 30.44it/s][A
  2%|▏         | 62/3776 [00:02<02:02, 30.41it/s][A
  2%|▏         | 66/3776 [00:02<02:02, 30.29it/s][A
  2%|▏         | 70/3776 [00:02<02:00, 30.71it/s][A
  2

 16%|█▌        | 606/3776 [00:19<01:35, 33.03it/s][A
 16%|█▌        | 610/3776 [00:19<01:35, 33.07it/s][A
 16%|█▋        | 614/3776 [00:19<01:35, 33.23it/s][A
 16%|█▋        | 618/3776 [00:19<01:34, 33.32it/s][A
 16%|█▋        | 622/3776 [00:19<01:34, 33.34it/s][A
 17%|█▋        | 626/3776 [00:19<01:34, 33.41it/s][A
 17%|█▋        | 630/3776 [00:20<01:34, 33.47it/s][A
 17%|█▋        | 634/3776 [00:20<01:33, 33.49it/s][A
 17%|█▋        | 638/3776 [00:20<01:33, 33.50it/s][A
 17%|█▋        | 642/3776 [00:20<01:33, 33.52it/s][A
 17%|█▋        | 646/3776 [00:20<01:33, 33.55it/s][A
 17%|█▋        | 650/3776 [00:20<01:33, 33.48it/s][A
 17%|█▋        | 654/3776 [00:20<01:33, 33.51it/s][A
 17%|█▋        | 658/3776 [00:20<01:33, 33.52it/s][A
 18%|█▊        | 662/3776 [00:21<01:32, 33.54it/s][A
 18%|█▊        | 666/3776 [00:21<01:32, 33.54it/s][A
 18%|█▊        | 670/3776 [00:21<01:32, 33.55it/s][A
 18%|█▊        | 674/3776 [00:21<01:32, 33.49it/s][A
 18%|█▊        | 678/3776 [0

 32%|███▏      | 1206/3776 [00:37<01:22, 31.19it/s][A
 32%|███▏      | 1210/3776 [00:37<01:23, 30.88it/s][A
 32%|███▏      | 1214/3776 [00:37<01:23, 30.81it/s][A
 32%|███▏      | 1218/3776 [00:37<01:22, 30.91it/s][A
 32%|███▏      | 1222/3776 [00:37<01:21, 31.20it/s][A
 32%|███▏      | 1226/3776 [00:37<01:21, 31.46it/s][A
 33%|███▎      | 1230/3776 [00:38<01:20, 31.54it/s][A
 33%|███▎      | 1234/3776 [00:38<01:20, 31.68it/s][A
 33%|███▎      | 1238/3776 [00:38<01:20, 31.61it/s][A
 33%|███▎      | 1242/3776 [00:38<01:19, 31.92it/s][A
 33%|███▎      | 1246/3776 [00:38<01:18, 32.09it/s][A
 33%|███▎      | 1250/3776 [00:38<01:18, 32.25it/s][A
 33%|███▎      | 1254/3776 [00:38<01:17, 32.42it/s][A
 33%|███▎      | 1258/3776 [00:38<01:18, 32.03it/s][A
 33%|███▎      | 1262/3776 [00:39<01:18, 31.92it/s][A
 34%|███▎      | 1266/3776 [00:39<01:18, 31.88it/s][A
 34%|███▎      | 1270/3776 [00:39<01:18, 31.72it/s][A
 34%|███▎      | 1274/3776 [00:39<01:19, 31.54it/s][A
 34%|███▍ 

 47%|████▋     | 1767/3776 [00:55<01:03, 31.55it/s][A
 47%|████▋     | 1771/3776 [00:55<01:03, 31.63it/s][A
 47%|████▋     | 1775/3776 [00:55<01:03, 31.68it/s][A
 47%|████▋     | 1779/3776 [00:55<01:02, 32.02it/s][A
 47%|████▋     | 1783/3776 [00:56<01:01, 32.26it/s][A
 47%|████▋     | 1787/3776 [00:56<01:01, 32.36it/s][A
 47%|████▋     | 1791/3776 [00:56<01:01, 32.51it/s][A
 48%|████▊     | 1795/3776 [00:56<01:00, 32.62it/s][A
 48%|████▊     | 1799/3776 [00:56<01:00, 32.62it/s][A
 48%|████▊     | 1803/3776 [00:56<01:00, 32.69it/s][A
 48%|████▊     | 1807/3776 [00:56<01:00, 32.73it/s][A
 48%|████▊     | 1811/3776 [00:56<01:00, 32.70it/s][A
 48%|████▊     | 1815/3776 [00:56<01:00, 32.42it/s][A
 48%|████▊     | 1819/3776 [00:57<01:00, 32.49it/s][A
 48%|████▊     | 1823/3776 [00:57<00:59, 32.59it/s][A
 48%|████▊     | 1827/3776 [00:57<00:59, 32.67it/s][A
 48%|████▊     | 1831/3776 [00:57<00:59, 32.73it/s][A
 49%|████▊     | 1835/3776 [00:57<00:59, 32.69it/s][A
 49%|████▊

 62%|██████▏   | 2359/3776 [01:13<00:43, 32.79it/s][A
 63%|██████▎   | 2363/3776 [01:13<00:43, 32.80it/s][A
 63%|██████▎   | 2367/3776 [01:13<00:42, 32.81it/s][A
 63%|██████▎   | 2371/3776 [01:13<00:42, 32.82it/s][A
 63%|██████▎   | 2375/3776 [01:14<00:42, 32.82it/s][A
 63%|██████▎   | 2379/3776 [01:14<00:42, 32.76it/s][A
 63%|██████▎   | 2383/3776 [01:14<00:42, 32.77it/s][A
 63%|██████▎   | 2387/3776 [01:14<00:42, 32.79it/s][A
 63%|██████▎   | 2391/3776 [01:14<00:42, 32.73it/s][A
 63%|██████▎   | 2395/3776 [01:14<00:42, 32.76it/s][A
 64%|██████▎   | 2399/3776 [01:14<00:41, 32.79it/s][A
 64%|██████▎   | 2403/3776 [01:14<00:41, 32.74it/s][A
 64%|██████▎   | 2407/3776 [01:15<00:41, 32.76it/s][A
 64%|██████▍   | 2411/3776 [01:15<00:41, 32.77it/s][A
 64%|██████▍   | 2415/3776 [01:15<00:41, 32.59it/s][A
 64%|██████▍   | 2419/3776 [01:15<00:41, 32.66it/s][A
 64%|██████▍   | 2423/3776 [01:15<00:41, 32.70it/s][A
 64%|██████▍   | 2427/3776 [01:15<00:41, 32.74it/s][A
 64%|█████

 78%|███████▊  | 2951/3776 [01:31<00:25, 32.69it/s][A
 78%|███████▊  | 2955/3776 [01:31<00:25, 32.76it/s][A
 78%|███████▊  | 2959/3776 [01:31<00:25, 32.65it/s][A
 78%|███████▊  | 2963/3776 [01:32<00:24, 32.71it/s][A
 79%|███████▊  | 2967/3776 [01:32<00:24, 32.67it/s][A
 79%|███████▊  | 2971/3776 [01:32<00:24, 32.69it/s][A
 79%|███████▉  | 2975/3776 [01:32<00:24, 32.72it/s][A
 79%|███████▉  | 2979/3776 [01:32<00:24, 32.62it/s][A
 79%|███████▉  | 2983/3776 [01:32<00:24, 32.65it/s][A
 79%|███████▉  | 2987/3776 [01:32<00:24, 32.70it/s][A
 79%|███████▉  | 2991/3776 [01:32<00:24, 32.64it/s][A
 79%|███████▉  | 2995/3776 [01:33<00:23, 32.71it/s][A
 79%|███████▉  | 2999/3776 [01:33<00:23, 32.74it/s][A
 80%|███████▉  | 3003/3776 [01:33<00:23, 32.71it/s][A
 80%|███████▉  | 3007/3776 [01:33<00:23, 32.74it/s][A
 80%|███████▉  | 3011/3776 [01:33<00:23, 32.70it/s][A
 80%|███████▉  | 3015/3776 [01:33<00:23, 32.68it/s][A
 80%|███████▉  | 3019/3776 [01:33<00:23, 32.72it/s][A
 80%|█████

 94%|█████████▍| 3543/3776 [01:50<00:07, 31.34it/s][A
 94%|█████████▍| 3547/3776 [01:50<00:07, 31.45it/s][A
 94%|█████████▍| 3551/3776 [01:50<00:07, 31.52it/s][A
 94%|█████████▍| 3555/3776 [01:50<00:06, 32.05it/s][A
 94%|█████████▍| 3559/3776 [01:51<00:06, 32.17it/s][A
 94%|█████████▍| 3563/3776 [01:51<00:06, 32.41it/s][A
 94%|█████████▍| 3567/3776 [01:51<00:06, 32.32it/s][A
 95%|█████████▍| 3571/3776 [01:51<00:06, 31.39it/s][A
 95%|█████████▍| 3575/3776 [01:51<00:06, 30.80it/s][A
 95%|█████████▍| 3579/3776 [01:51<00:06, 30.44it/s][A
 95%|█████████▍| 3583/3776 [01:51<00:06, 30.38it/s][A
 95%|█████████▍| 3587/3776 [01:51<00:06, 30.07it/s][A
 95%|█████████▌| 3591/3776 [01:52<00:06, 29.86it/s][A
 95%|█████████▌| 3594/3776 [01:52<00:06, 29.74it/s][A
 95%|█████████▌| 3597/3776 [01:52<00:06, 29.67it/s][A
 95%|█████████▌| 3600/3776 [01:52<00:05, 29.63it/s][A
 95%|█████████▌| 3603/3776 [01:52<00:05, 29.68it/s][A
 95%|█████████▌| 3606/3776 [01:52<00:05, 29.65it/s][A
 96%|█████

Learning rate: 0.0001
Beta 1: 0.8
Beta 2: 0.9
Weight decay: 0.001
Accuracy: 96.1484872611465
