<a href="https://colab.research.google.com/github/GoTudering/Final_Project_MTH3033/blob/main/SAlexNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils import data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm
import numpy as np
from glob import glob

In [2]:
# Hyper-parameters shared by the training and evaluation cells below.
BATCH_SIZE = 128
NUM_EPOCHS = 30

# Prefer the GPU whenever CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Folder receiving the per-epoch checkpoints; an existing folder is reused.
CHECKPOINTS_DIR = './models.1'
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

In [3]:
# Build the training split first, then the test split. Both are stored under
# the same root folder, are downloaded when missing, and yield tensors
# through ToTensor().
train_set, test_set = (
    datasets.FashionMNIST(
        root='FashionMNIST_data/',
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to FashionMNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting FashionMNIST_data/FashionMNIST/raw/train-images-idx3-ubyte.gz to FashionMNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to FashionMNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting FashionMNIST_data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to FashionMNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to FashionMNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting FashionMNIST_data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to FashionMNIST_data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to FashionMNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting FashionMNIST_data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to FashionMNIST_data/FashionMNIST/raw



In [4]:
class SAlexNet(nn.Module):
    """Scaled-down AlexNet-style CNN for single-channel 28x28 images.

    Five convolutions interleaved with three max-pools reduce a
    ``(N, 1, 28, 28)`` batch to ``(N, 32, 2, 2)``; a three-layer MLP with
    dropout then maps the 128 flattened features to ``num_classes`` logits.

    Args:
        num_classes: Width of the final linear layer. The default of 1000 is
            kept so that checkpoints written with the previously hard-coded
            head still load. Fashion-MNIST only has 10 labels, so pass
            ``num_classes=10`` when training a fresh model for a tight head.
        dropout: Drop probability used by both ``nn.Dropout`` layers.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()

        # Shape comments use PyTorch's (N, C, H, W) layout.
        # Input: (N, 1, 28, 28)
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=6, stride=2, padding=0),  # (N, 16, 12, 12)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=5, stride=1, padding=0),  # (N, 16, 8, 8)
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),  # (N, 32, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=4, stride=1, padding=0),  # (N, 32, 5, 5)
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # (N, 64, 5, 5)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # (N, 64, 5, 5)
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),  # (N, 32, 5, 5)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),  # (N, 32, 2, 2)
        )  # Output: (N, 32, 2, 2)

        # Input: (N, 32 * 2 * 2) once forward() has flattened the feature maps
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(32 * 2 * 2, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        )  # Output: (N, num_classes)

    def _initialize(self):
        """Re-initialise the convolutions in place.

        Every conv weight is drawn from N(0, 0.01) and every conv bias is set
        to 0, except for the 2nd, 4th and 5th convolutions whose biases are
        set to 1. The constructor does not call this method; invoke it
        explicitly to replace PyTorch's default initialisation.
        """
        conv_layers = [layer for layer in self.features if isinstance(layer, nn.Conv2d)]
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0, std=0.01)
            nn.init.constant_(conv.bias, 0)
        # 2nd, 4th and 5th convolutions, i.e. self.features[3], [8] and [10].
        for conv_idx in (1, 3, 4):
            nn.init.constant_(conv_layers[conv_idx].bias, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(N, num_classes)`` logits for a ``(N, 1, 28, 28)`` batch."""
        x = self.features(x)
        # Flatten everything but the batch axis. Unlike view(-1, 128), a wrong
        # input size then fails in the first Linear layer instead of silently
        # changing the batch dimension.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [8]:
# Record the default generator's seed; it is stored in every checkpoint below.
seed = torch.initial_seed()
print(f"Seed: {seed}")

model = SAlexNet()
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# NOTE: the scheduler is stepped once per epoch, so with step_size=30 the
# learning rate is only reduced after 30 epochs have completed.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
train_loader = data.DataLoader(train_set, shuffle=True,
                               pin_memory=True, num_workers=2,
                               drop_last=True, batch_size=BATCH_SIZE)

print("Start training...")
model.train()  # be explicit: dropout must be active while training
# 1-based index of the *next* optimisation step; after an epoch it is therefore
# one larger than the number of steps actually taken.
total_steps = 1
for epoch in range(NUM_EPOCHS):
    pbar = tqdm(train_loader, ascii=True, ncols=80)
    loss_records = []
    for imgs, classes in pbar:
        imgs, classes = imgs.to(device), classes.to(device)

        # Calculate the loss
        output = model(imgs)
        loss = F.cross_entropy(output, classes)

        # Update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_records.append(loss.item())
        total_steps += 1

    # Report the mean training loss over the epoch.
    loss_mean = np.mean(loss_records)
    print(f"Epoch: {epoch + 1} \tStep: {total_steps} \tLoss: {loss_mean:.4f}")
    lr_scheduler.step()

    # NOTE: the SAlexNet2 training cell writes the same file names into the same
    # directory, so whichever model is trained last overwrites the other's
    # checkpoints.
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"alexnet_states_e{epoch + 1:04d}.pkl")
    state = {
        'epoch': epoch,  # 0-based; the evaluation cell prints epoch + 1
        'total_steps': total_steps,
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),  # needed to resume the LR schedule
        'model': model.state_dict(),
        'seed': seed,
    }
    torch.save(state, checkpoint_path)

Seed: 9722867967758336569
Start training...


100%|#########################################| 468/468 [00:53<00:00,  8.71it/s]


Epoch: 1 	Step: 469 	Loss: 1.6550


100%|#########################################| 468/468 [00:59<00:00,  7.85it/s]


Epoch: 2 	Step: 937 	Loss: 0.6087


100%|#########################################| 468/468 [01:01<00:00,  7.63it/s]


Epoch: 3 	Step: 1405 	Loss: 0.5055


100%|#########################################| 468/468 [01:04<00:00,  7.29it/s]


Epoch: 4 	Step: 1873 	Loss: 0.4588


100%|#########################################| 468/468 [01:10<00:00,  6.66it/s]


Epoch: 5 	Step: 2341 	Loss: 0.4235


100%|#########################################| 468/468 [01:13<00:00,  6.40it/s]


Epoch: 6 	Step: 2809 	Loss: 0.3979


100%|#########################################| 468/468 [01:14<00:00,  6.26it/s]


Epoch: 7 	Step: 3277 	Loss: 0.3761


100%|#########################################| 468/468 [01:19<00:00,  5.87it/s]


Epoch: 8 	Step: 3745 	Loss: 0.3609


100%|#########################################| 468/468 [01:21<00:00,  5.76it/s]


Epoch: 9 	Step: 4213 	Loss: 0.3455


100%|#########################################| 468/468 [01:19<00:00,  5.85it/s]


Epoch: 10 	Step: 4681 	Loss: 0.3335


100%|#########################################| 468/468 [01:21<00:00,  5.74it/s]


Epoch: 11 	Step: 5149 	Loss: 0.3276


100%|#########################################| 468/468 [01:22<00:00,  5.69it/s]


Epoch: 12 	Step: 5617 	Loss: 0.3159


100%|#########################################| 468/468 [01:22<00:00,  5.68it/s]


Epoch: 13 	Step: 6085 	Loss: 0.3127


100%|#########################################| 468/468 [01:22<00:00,  5.70it/s]


Epoch: 14 	Step: 6553 	Loss: 0.3024


100%|#########################################| 468/468 [01:22<00:00,  5.65it/s]


Epoch: 15 	Step: 7021 	Loss: 0.2941


100%|#########################################| 468/468 [01:21<00:00,  5.75it/s]


Epoch: 16 	Step: 7489 	Loss: 0.2930


100%|#########################################| 468/468 [01:24<00:00,  5.56it/s]


Epoch: 17 	Step: 7957 	Loss: 0.2811


100%|#########################################| 468/468 [01:19<00:00,  5.91it/s]


Epoch: 18 	Step: 8425 	Loss: 0.2729


100%|#########################################| 468/468 [01:19<00:00,  5.86it/s]


Epoch: 19 	Step: 8893 	Loss: 0.2711


100%|#########################################| 468/468 [01:19<00:00,  5.86it/s]


Epoch: 20 	Step: 9361 	Loss: 0.2667


100%|#########################################| 468/468 [01:17<00:00,  6.02it/s]


Epoch: 21 	Step: 9829 	Loss: 0.2635


100%|#########################################| 468/468 [01:17<00:00,  6.01it/s]


Epoch: 22 	Step: 10297 	Loss: 0.2622


100%|#########################################| 468/468 [01:16<00:00,  6.13it/s]


Epoch: 23 	Step: 10765 	Loss: 0.2555


100%|#########################################| 468/468 [01:15<00:00,  6.16it/s]


Epoch: 24 	Step: 11233 	Loss: 0.2474


100%|#########################################| 468/468 [01:18<00:00,  5.97it/s]


Epoch: 25 	Step: 11701 	Loss: 0.2453


100%|#########################################| 468/468 [01:16<00:00,  6.15it/s]


Epoch: 26 	Step: 12169 	Loss: 0.2440


100%|#########################################| 468/468 [01:16<00:00,  6.11it/s]


Epoch: 27 	Step: 12637 	Loss: 0.2390


100%|#########################################| 468/468 [01:17<00:00,  6.07it/s]


Epoch: 28 	Step: 13105 	Loss: 0.2303


100%|#########################################| 468/468 [01:16<00:00,  6.09it/s]


Epoch: 29 	Step: 13573 	Loss: 0.2292


100%|#########################################| 468/468 [01:15<00:00,  6.21it/s]


Epoch: 30 	Step: 14041 	Loss: 0.2267


In [9]:
# Accuracy does not depend on sample order, so the test set is not shuffled.
test_loader = data.DataLoader(test_set, shuffle=False,
                              pin_memory=True, num_workers=2,
                              batch_size=BATCH_SIZE)
# Zero-padded epoch numbers make lexicographic order equal to epoch order.
model_paths = sorted(glob(os.path.join(CHECKPOINTS_DIR, 'alexnet_states_e????.pkl')))

accuracy_records = []

for model_path in model_paths:
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints you created yourself.
    chk = torch.load(model_path, map_location=device)
    epoch = chk['epoch']

    model = SAlexNet()
    model.load_state_dict(chk['model'])  # Load parameters from the checkpoint
    model = model.to(device)
    # Switch to eval mode BEFORE the first forward pass; a freshly built module
    # is in training mode, which would score the first batch with dropout on.
    model.eval()

    correct_counter = 0
    # tqdm derives the total (number of batches) from the loader and advances
    # once per iteration, so no manual total/update bookkeeping is needed.
    pbar = tqdm(test_loader, ascii=True, ncols=80)
    with torch.no_grad():
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            output = model(imgs)
            _, preds = torch.max(output, 1)
            correct_counter += torch.sum(preds == labels).item()

    print(f"Epoch: {epoch + 1} \tAcc: {correct_counter}/{len(test_set)}")
    accuracy_records.append(dict(epoch=(epoch + 1), correct=correct_counter, path=model_path))

if accuracy_records:
    # Best checkpoint first.
    accuracy_records.sort(reverse=True, key=lambda x: x['correct'])
    print(accuracy_records[0])
else:
    print(f"No checkpoints found in {CHECKPOINTS_DIR}")

  1%|3                                       | 79/10000 [00:03<07:55, 20.85it/s]


Epoch: 1 	Acc: 7553/10000


  1%|3                                       | 79/10000 [00:03<07:50, 21.10it/s]


Epoch: 2 	Acc: 8032/10000


  1%|3                                       | 79/10000 [00:03<07:37, 21.66it/s]


Epoch: 3 	Acc: 8314/10000


  1%|3                                       | 79/10000 [00:03<07:47, 21.23it/s]


Epoch: 4 	Acc: 8421/10000


  1%|3                                       | 79/10000 [00:03<07:32, 21.92it/s]


Epoch: 5 	Acc: 8539/10000


  1%|3                                       | 79/10000 [00:03<07:39, 21.59it/s]


Epoch: 6 	Acc: 8567/10000


  1%|3                                       | 79/10000 [00:03<07:40, 21.53it/s]


Epoch: 7 	Acc: 8586/10000


  1%|3                                       | 79/10000 [00:04<09:04, 18.23it/s]


Epoch: 8 	Acc: 8651/10000


  1%|3                                       | 79/10000 [00:03<07:47, 21.24it/s]


Epoch: 9 	Acc: 8620/10000


  1%|3                                       | 79/10000 [00:03<07:44, 21.36it/s]


Epoch: 10 	Acc: 8645/10000


  1%|3                                       | 79/10000 [00:03<07:44, 21.35it/s]


Epoch: 11 	Acc: 8741/10000


  1%|3                                       | 79/10000 [00:03<07:49, 21.12it/s]


Epoch: 12 	Acc: 8750/10000


  1%|3                                       | 79/10000 [00:03<07:48, 21.16it/s]


Epoch: 13 	Acc: 8760/10000


  1%|3                                       | 79/10000 [00:03<07:44, 21.34it/s]


Epoch: 14 	Acc: 8783/10000


  1%|3                                       | 79/10000 [00:03<07:38, 21.62it/s]


Epoch: 15 	Acc: 8776/10000


  1%|3                                       | 79/10000 [00:03<07:43, 21.38it/s]


Epoch: 16 	Acc: 8809/10000


  1%|3                                       | 79/10000 [00:03<07:42, 21.47it/s]


Epoch: 17 	Acc: 8730/10000


  1%|3                                       | 79/10000 [00:03<07:46, 21.25it/s]


Epoch: 18 	Acc: 8790/10000


  1%|3                                       | 79/10000 [00:03<07:38, 21.65it/s]


Epoch: 19 	Acc: 8848/10000


  1%|3                                       | 79/10000 [00:03<07:39, 21.58it/s]


Epoch: 20 	Acc: 8832/10000


  1%|3                                       | 79/10000 [00:03<07:40, 21.56it/s]


Epoch: 21 	Acc: 8760/10000


  1%|3                                       | 79/10000 [00:03<07:40, 21.54it/s]


Epoch: 22 	Acc: 8784/10000


  1%|3                                       | 79/10000 [00:03<07:37, 21.69it/s]


Epoch: 23 	Acc: 8746/10000


  1%|3                                       | 79/10000 [00:03<07:40, 21.54it/s]


Epoch: 24 	Acc: 8829/10000


  1%|3                                       | 79/10000 [00:03<07:36, 21.73it/s]


Epoch: 25 	Acc: 8850/10000


  1%|3                                       | 79/10000 [00:03<07:42, 21.47it/s]


Epoch: 26 	Acc: 8794/10000


  1%|3                                       | 79/10000 [00:03<07:40, 21.56it/s]


Epoch: 27 	Acc: 8826/10000


  1%|3                                       | 79/10000 [00:03<07:37, 21.69it/s]


Epoch: 28 	Acc: 8794/10000


  1%|3                                       | 79/10000 [00:03<07:39, 21.60it/s]


Epoch: 29 	Acc: 8830/10000


  1%|3                                       | 79/10000 [00:03<07:43, 21.42it/s]

Epoch: 30 	Acc: 8870/10000
{'epoch': 30, 'correct': 8870, 'path': './models.1/alexnet_states_e0030.pkl'}





In [5]:
class SAlexNet2(nn.Module):
    """Wider variant of the scaled-down AlexNet-style CNN for 28x28 images.

    Five convolutions interleaved with three max-pools reduce a
    ``(N, 1, 28, 28)`` batch to ``(N, 96, 2, 2)``; a three-layer MLP with
    dropout then maps the 384 flattened features to ``num_classes`` logits.

    Args:
        num_classes: Width of the final linear layer. The default of 1000 is
            kept so that checkpoints written with the previously hard-coded
            head still load. Fashion-MNIST only has 10 labels, so pass
            ``num_classes=10`` when training a fresh model for a tight head.
        dropout: Drop probability used by both ``nn.Dropout`` layers.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()

        # Shape comments use PyTorch's (N, C, H, W) layout.
        # Input: (N, 1, 28, 28)
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=6, stride=2, padding=0),  # (N, 32, 12, 12)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=5, stride=1, padding=0),  # (N, 32, 8, 8)
            nn.Conv2d(32, 96, kernel_size=5, stride=1, padding=2),  # (N, 96, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=4, stride=1, padding=0),  # (N, 96, 5, 5)
            nn.Conv2d(96, 128, kernel_size=3, stride=1, padding=1),  # (N, 128, 5, 5)
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),  # (N, 128, 5, 5)
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1),  # (N, 96, 5, 5)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=0),  # (N, 96, 2, 2)
        )  # Output: (N, 96, 2, 2)

        # Input: (N, 96 * 2 * 2) once forward() has flattened the feature maps
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(96 * 2 * 2, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )  # Output: (N, num_classes)

    def _initialize(self):
        """Re-initialise the convolutions in place.

        Every conv weight is drawn from N(0, 0.01) and every conv bias is set
        to 0, except for the 2nd, 4th and 5th convolutions whose biases are
        set to 1. The constructor does not call this method; invoke it
        explicitly to replace PyTorch's default initialisation.
        """
        conv_layers = [layer for layer in self.features if isinstance(layer, nn.Conv2d)]
        for conv in conv_layers:
            nn.init.normal_(conv.weight, mean=0, std=0.01)
            nn.init.constant_(conv.bias, 0)
        # 2nd, 4th and 5th convolutions, i.e. self.features[3], [8] and [10].
        for conv_idx in (1, 3, 4):
            nn.init.constant_(conv_layers[conv_idx].bias, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(N, num_classes)`` logits for a ``(N, 1, 28, 28)`` batch."""
        x = self.features(x)
        # Flatten everything but the batch axis. Unlike view(-1, 384), a wrong
        # input size then fails in the first Linear layer instead of silently
        # changing the batch dimension.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [6]:
# Record the default generator's seed; it is stored in every checkpoint below.
seed = torch.initial_seed()
print(f"Seed: {seed}")

model = SAlexNet2()
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# NOTE: the scheduler is stepped once per epoch, so with step_size=30 the
# learning rate is only reduced after 30 epochs have completed.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
train_loader = data.DataLoader(train_set, shuffle=True,
                               pin_memory=True, num_workers=2,
                               drop_last=True, batch_size=BATCH_SIZE)

print("Start training...")
model.train()  # be explicit: dropout must be active while training
# 1-based index of the *next* optimisation step; after an epoch it is therefore
# one larger than the number of steps actually taken.
total_steps = 1
for epoch in range(NUM_EPOCHS):
    pbar = tqdm(train_loader, ascii=True, ncols=80)
    loss_records = []
    for imgs, classes in pbar:
        imgs, classes = imgs.to(device), classes.to(device)

        # Calculate the loss
        output = model(imgs)
        loss = F.cross_entropy(output, classes)

        # Update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_records.append(loss.item())
        total_steps += 1

    # Report the mean training loss over the epoch.
    loss_mean = np.mean(loss_records)
    print(f"Epoch: {epoch + 1} \tStep: {total_steps} \tLoss: {loss_mean:.4f}")
    lr_scheduler.step()

    # NOTE: the SAlexNet training cell writes the same file names into the same
    # directory, so whichever model is trained last overwrites the other's
    # checkpoints.
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"alexnet_states_e{epoch + 1:04d}.pkl")
    state = {
        'epoch': epoch,  # 0-based; the evaluation cell prints epoch + 1
        'total_steps': total_steps,
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),  # needed to resume the LR schedule
        'model': model.state_dict(),
        'seed': seed,
    }
    torch.save(state, checkpoint_path)

Seed: 11538375912735639844
Start training...


100%|#########################################| 468/468 [02:27<00:00,  3.18it/s]


Epoch: 1 	Step: 469 	Loss: 1.5458


100%|#########################################| 468/468 [02:25<00:00,  3.21it/s]


Epoch: 2 	Step: 937 	Loss: 0.5713


100%|#########################################| 468/468 [02:22<00:00,  3.28it/s]


Epoch: 3 	Step: 1405 	Loss: 0.4451


100%|#########################################| 468/468 [02:26<00:00,  3.20it/s]


Epoch: 4 	Step: 1873 	Loss: 0.3892


100%|#########################################| 468/468 [02:27<00:00,  3.16it/s]


Epoch: 5 	Step: 2341 	Loss: 0.3545


100%|#########################################| 468/468 [02:40<00:00,  2.92it/s]


Epoch: 6 	Step: 2809 	Loss: 0.3357


100%|#########################################| 468/468 [02:48<00:00,  2.77it/s]


Epoch: 7 	Step: 3277 	Loss: 0.3182


100%|#########################################| 468/468 [02:54<00:00,  2.68it/s]


Epoch: 8 	Step: 3745 	Loss: 0.3047


100%|#########################################| 468/468 [02:46<00:00,  2.82it/s]


Epoch: 9 	Step: 4213 	Loss: 0.2897


100%|#########################################| 468/468 [02:29<00:00,  3.12it/s]


Epoch: 10 	Step: 4681 	Loss: 0.2772


100%|#########################################| 468/468 [02:28<00:00,  3.16it/s]


Epoch: 11 	Step: 5149 	Loss: 0.2688


100%|#########################################| 468/468 [02:31<00:00,  3.08it/s]


Epoch: 12 	Step: 5617 	Loss: 0.2637


100%|#########################################| 468/468 [02:37<00:00,  2.96it/s]


Epoch: 13 	Step: 6085 	Loss: 0.2514


100%|#########################################| 468/468 [02:40<00:00,  2.92it/s]


Epoch: 14 	Step: 6553 	Loss: 0.2416


100%|#########################################| 468/468 [02:36<00:00,  2.99it/s]


Epoch: 15 	Step: 7021 	Loss: 0.2377


100%|#########################################| 468/468 [02:26<00:00,  3.18it/s]


Epoch: 16 	Step: 7489 	Loss: 0.2293


100%|#########################################| 468/468 [02:28<00:00,  3.15it/s]


Epoch: 17 	Step: 7957 	Loss: 0.2207


100%|#########################################| 468/468 [02:28<00:00,  3.15it/s]


Epoch: 18 	Step: 8425 	Loss: 0.2116


100%|#########################################| 468/468 [02:33<00:00,  3.04it/s]


Epoch: 19 	Step: 8893 	Loss: 0.2077


100%|#########################################| 468/468 [02:41<00:00,  2.89it/s]


Epoch: 20 	Step: 9361 	Loss: 0.2044


100%|#########################################| 468/468 [02:33<00:00,  3.04it/s]


Epoch: 21 	Step: 9829 	Loss: 0.2015


100%|#########################################| 468/468 [02:29<00:00,  3.14it/s]


Epoch: 22 	Step: 10297 	Loss: 0.1910


100%|#########################################| 468/468 [02:30<00:00,  3.10it/s]


Epoch: 23 	Step: 10765 	Loss: 0.1851


100%|#########################################| 468/468 [02:29<00:00,  3.12it/s]


Epoch: 24 	Step: 11233 	Loss: 0.1772


100%|#########################################| 468/468 [02:29<00:00,  3.13it/s]


Epoch: 25 	Step: 11701 	Loss: 0.1754


100%|#########################################| 468/468 [02:32<00:00,  3.06it/s]


Epoch: 26 	Step: 12169 	Loss: 0.1689


100%|#########################################| 468/468 [02:40<00:00,  2.91it/s]


Epoch: 27 	Step: 12637 	Loss: 0.1661


100%|#########################################| 468/468 [02:43<00:00,  2.87it/s]


Epoch: 28 	Step: 13105 	Loss: 0.1635


100%|#########################################| 468/468 [02:29<00:00,  3.13it/s]


Epoch: 29 	Step: 13573 	Loss: 0.1603


100%|#########################################| 468/468 [02:33<00:00,  3.04it/s]

Epoch: 30 	Step: 14041 	Loss: 0.1521





In [7]:
# Accuracy does not depend on sample order, so the test set is not shuffled.
test_loader = data.DataLoader(test_set, shuffle=False,
                              pin_memory=True, num_workers=2,
                              batch_size=BATCH_SIZE)
# Zero-padded epoch numbers make lexicographic order equal to epoch order.
model_paths = sorted(glob(os.path.join(CHECKPOINTS_DIR, 'alexnet_states_e????.pkl')))

accuracy_records = []

for model_path in model_paths:
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints you created yourself.
    chk = torch.load(model_path, map_location=device)
    epoch = chk['epoch']

    model = SAlexNet2()
    model.load_state_dict(chk['model'])  # Load parameters from the checkpoint
    model = model.to(device)
    # Switch to eval mode BEFORE the first forward pass; a freshly built module
    # is in training mode, which would score the first batch with dropout on.
    model.eval()

    correct_counter = 0
    # tqdm derives the total (number of batches) from the loader and advances
    # once per iteration, so no manual total/update bookkeeping is needed.
    pbar = tqdm(test_loader, ascii=True, ncols=80)
    with torch.no_grad():
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            output = model(imgs)
            _, preds = torch.max(output, 1)
            correct_counter += torch.sum(preds == labels).item()

    print(f"Epoch: {epoch + 1} \tAcc: {correct_counter}/{len(test_set)}")
    accuracy_records.append(dict(epoch=(epoch + 1), correct=correct_counter, path=model_path))

if accuracy_records:
    # Best checkpoint first.
    accuracy_records.sort(reverse=True, key=lambda x: x['correct'])
    print(accuracy_records[0])
else:
    print(f"No checkpoints found in {CHECKPOINTS_DIR}")

  1%|3                                       | 79/10000 [00:09<20:44,  7.97it/s]


Epoch: 1 	Acc: 7542/10000


  1%|3                                       | 79/10000 [00:09<19:10,  8.62it/s]


Epoch: 2 	Acc: 8334/10000


  1%|3                                       | 79/10000 [00:09<18:54,  8.74it/s]


Epoch: 3 	Acc: 8523/10000


  1%|3                                       | 79/10000 [00:08<18:49,  8.78it/s]


Epoch: 4 	Acc: 8398/10000


  1%|3                                       | 79/10000 [00:09<19:00,  8.70it/s]


Epoch: 5 	Acc: 8575/10000


  1%|3                                       | 79/10000 [00:09<18:55,  8.73it/s]


Epoch: 6 	Acc: 8751/10000


  1%|3                                       | 79/10000 [00:08<18:30,  8.93it/s]


Epoch: 7 	Acc: 8739/10000


  1%|3                                       | 79/10000 [00:09<18:59,  8.70it/s]


Epoch: 8 	Acc: 8844/10000


  1%|3                                       | 79/10000 [00:08<18:36,  8.88it/s]


Epoch: 9 	Acc: 8824/10000


  1%|3                                       | 79/10000 [00:08<18:34,  8.90it/s]


Epoch: 10 	Acc: 8821/10000


  1%|3                                       | 79/10000 [00:08<18:35,  8.89it/s]


Epoch: 11 	Acc: 8859/10000


  1%|3                                       | 79/10000 [00:08<18:47,  8.80it/s]


Epoch: 12 	Acc: 8740/10000


  1%|3                                       | 79/10000 [00:09<20:42,  7.99it/s]


Epoch: 13 	Acc: 8885/10000


  1%|3                                       | 79/10000 [00:08<18:39,  8.86it/s]


Epoch: 14 	Acc: 8895/10000


  1%|3                                       | 79/10000 [00:08<18:18,  9.03it/s]


Epoch: 15 	Acc: 8903/10000


  1%|3                                       | 79/10000 [00:08<18:09,  9.11it/s]


Epoch: 16 	Acc: 8897/10000


  1%|3                                       | 79/10000 [00:08<18:15,  9.06it/s]


Epoch: 17 	Acc: 8920/10000


  1%|3                                       | 79/10000 [00:08<18:00,  9.18it/s]


Epoch: 18 	Acc: 8915/10000


  1%|3                                       | 79/10000 [00:08<18:13,  9.07it/s]


Epoch: 19 	Acc: 8930/10000


  1%|3                                       | 79/10000 [00:08<18:11,  9.09it/s]


Epoch: 20 	Acc: 8892/10000


  1%|3                                       | 79/10000 [00:08<18:30,  8.93it/s]


Epoch: 21 	Acc: 8876/10000


  1%|3                                       | 79/10000 [00:10<21:01,  7.86it/s]


Epoch: 22 	Acc: 8843/10000


  1%|3                                       | 79/10000 [00:08<18:26,  8.96it/s]


Epoch: 23 	Acc: 8897/10000


  1%|3                                       | 79/10000 [00:09<19:19,  8.56it/s]


Epoch: 24 	Acc: 8851/10000


  1%|3                                       | 79/10000 [00:08<18:35,  8.90it/s]


Epoch: 25 	Acc: 8944/10000


  1%|3                                       | 79/10000 [00:09<18:51,  8.77it/s]


Epoch: 26 	Acc: 8914/10000


  1%|3                                       | 79/10000 [00:08<18:40,  8.86it/s]


Epoch: 27 	Acc: 8846/10000


  1%|3                                       | 79/10000 [00:08<18:49,  8.78it/s]


Epoch: 28 	Acc: 8895/10000


  1%|3                                       | 79/10000 [00:09<19:01,  8.69it/s]


Epoch: 29 	Acc: 8908/10000


  1%|3                                       | 79/10000 [00:09<18:52,  8.76it/s]

Epoch: 30 	Acc: 8884/10000
{'epoch': 25, 'correct': 8944, 'path': './models.1/alexnet_states_e0025.pkl'}



