In [1]:
import numpy as np
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torchmetrics.classification import Accuracy
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm

import matplotlib.pyplot as plt

import os

In [2]:
# Setup device-agnocstic code
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = '1' #LocalResponseNorm is not available on Apple MPS / this line still does not fix it so i used google colab

In [3]:
# Define dataloaders for the CIFAR 10 dataset
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True):
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.ToTensor(),
            normalize,
    ])
    #perform augmentation to strengthen model training
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(), #horizontal flip
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.ToTensor(),
            normalize,
        ])

    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )

    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler)

    return (train_loader, valid_loader)

def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True):
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transform
    transform = transforms.Compose([
        transforms.Resize((227,227)),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle
    )

    return data_loader

In [4]:
# Load CIFAR10 data
train_data, validation_data = get_train_valid_loader(data_dir='datasets', augment=False, batch_size=128, random_seed=42)
test_data = get_test_loader(data_dir='datasets', batch_size=128, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
#Create AlexNet architecture

class AlexNet(nn.Module):
    def __init__(self, in_channels, classes=10):
        super().__init__()
        self.layer_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=96,
                kernel_size=11,
                stride=4,
                padding=0
                ),
            nn.ReLU(),
            nn.LocalResponseNorm(5, alpha=10**-4, beta=0.75, k=2), # 'brightness normalization' to reduce overfitting
            nn.MaxPool2d(kernel_size=3, stride=2) #overlapping pooling
        )
        self.layer_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=96,
                out_channels=256,
                kernel_size=5,
                stride=1,
                padding=2
            ),
            nn.ReLU(),
            nn.LocalResponseNorm(5, alpha=10**-4, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.layer_3 = nn.Sequential(
            nn.Conv2d(
                in_channels=256,
                out_channels=384,
                kernel_size=3,
                stride=1,
                padding=1
                ),
            nn.ReLU()
        )
        self.layer_4 = nn.Sequential(
            nn.Conv2d(in_channels=384,
                      out_channels=384,
                      kernel_size=3,
                      stride=1,
                      padding=1
                      ),
            nn.ReLU()
        )
        self.layer_5 = nn.Sequential(
            nn.Conv2d(
                in_channels=384,
                out_channels=256,
                kernel_size=3,
                stride=1,
                padding=1
                ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.layer_6 = nn.Sequential(
            nn.Flatten(),
            #nn.Dropout(0.5),
            nn.Linear(
                in_features=256*6*6,
                out_features=4096,
                ),
            nn.ReLU()
        )
        self.layer_7 = nn.Sequential(
            #nn.Dropout(0.5),
            nn.Linear(
                in_features=4096,
                out_features=4096,
                ),
            nn.ReLU()
        )
        self.layer_8 = nn.Sequential(
            nn.Linear(
                in_features=4096,
                out_features=classes,
                ),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        x = self.layer_4(x)
        x = self.layer_5(x)
        x = self.layer_6(x)
        x = self.layer_7(x)
        x = self.layer_8(x)
        return x


    def _initialize_weights(self):
        biases = [1, 3, 4, 5, 6, 7]
        for i in self.modules():
            if isinstance(i, (nn.Conv2d, nn.Linear)):
                nn.init.normal_(i.weight, mean=0, std=0.01),
                if i in biases:
                    nn.init.constant_(i.bias, 1)
                else:
                    nn.init.constant_(i.bias ,0)




In [6]:
#initialize neural network
net = AlexNet(in_channels=3, classes=10).to(device=device)

In [7]:
#define loss function, optimizer and accuracy function
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=net.parameters(),
    momentum=0.9,
    weight_decay=0.0005,
    lr=0.01
    )
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10) #divide lr by 10 everytime validation loss plateaus
accuracy_fn = Accuracy(task='multiclass', num_classes=10).to(device)


In [8]:
# training step
epochs = 90

torch.manual_seed(42)
for epoch in tqdm(range(epochs), desc="Epochs"):
    train_loss = 0
    print(f"Epoch: {epoch}/{epochs}\n-----------------------")
    for i, (images, labels) in tqdm(enumerate(train_data), total=len(train_data), desc="Batches"):
        train_loss = 0
        images = images.to(device)
        labels = labels.to(device)

        #training
        net.train()

        #forward pass
        train_pred = net(images)

        #calculate loss
        train_loss += loss_fn(train_pred, labels)

        #backprop
        optimizer.zero_grad()
        train_loss.backward()

        #gradient descent
        optimizer.step()

    train_loss /= len(train_data)

    #evaluation
    eval_loss, eval_acc = 0, 0
    net.eval()
    with torch.inference_mode():
        for i, (images, labels) in enumerate(validation_data):
            images = images.to(device)
            labels = labels.to(device)
            #forward pass
            eval_logits = net(images)
            eval_pred = eval_logits.argmax(dim=1)

            #calculate loss and accuracy
            eval_loss += loss_fn(eval_logits, labels)
            eval_acc += accuracy_fn(eval_pred, labels)

        eval_loss /= len(validation_data)
        eval_acc /= len(validation_data)
        eval_acc *= 100

    print(f"\nTrain loss: {train_loss:.10f} | Val loss: {eval_loss:.10f}, Val acc: {eval_acc:.2f}%\n")
    scheduler.step(eval_loss)


Epochs:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch: 0/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0047697062 | Val loss: 1.8508930206, Val acc: 33.48%

Epoch: 1/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0044572721 | Val loss: 1.4921401739, Val acc: 47.50%

Epoch: 2/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0024846550 | Val loss: 1.1853249073, Val acc: 57.50%

Epoch: 3/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0027217683 | Val loss: 0.9653658271, Val acc: 66.05%

Epoch: 4/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0022593471 | Val loss: 0.7819209695, Val acc: 72.36%

Epoch: 5/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0023035763 | Val loss: 0.7249559164, Val acc: 74.51%

Epoch: 6/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0020786212 | Val loss: 0.7487882972, Val acc: 75.18%

Epoch: 7/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0014424960 | Val loss: 0.7242773175, Val acc: 76.17%

Epoch: 8/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0009701046 | Val loss: 0.7643483281, Val acc: 77.17%

Epoch: 9/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0006795788 | Val loss: 0.8795191050, Val acc: 75.82%

Epoch: 10/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0004886299 | Val loss: 0.8605499268, Val acc: 76.93%

Epoch: 11/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0005532595 | Val loss: 0.9232516289, Val acc: 77.91%

Epoch: 12/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0001344234 | Val loss: 1.0290602446, Val acc: 76.64%

Epoch: 13/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0002244466 | Val loss: 1.0588530302, Val acc: 77.05%

Epoch: 14/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000850531 | Val loss: 1.0331050158, Val acc: 77.19%

Epoch: 15/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0001917945 | Val loss: 1.1173080206, Val acc: 78.52%

Epoch: 16/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000713645 | Val loss: 1.1497893333, Val acc: 77.38%

Epoch: 17/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0001984637 | Val loss: 1.0374205112, Val acc: 78.07%

Epoch: 18/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000130976 | Val loss: 1.0603317022, Val acc: 79.36%

Epoch: 19/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000025104 | Val loss: 1.0541453362, Val acc: 80.37%

Epoch: 20/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000016115 | Val loss: 1.0412575006, Val acc: 80.88%

Epoch: 21/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000031642 | Val loss: 1.0623774529, Val acc: 80.70%

Epoch: 22/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008278 | Val loss: 1.0738581419, Val acc: 80.92%

Epoch: 23/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000047149 | Val loss: 1.1085926294, Val acc: 80.72%

Epoch: 24/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000023472 | Val loss: 1.1223104000, Val acc: 80.21%

Epoch: 25/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000003486 | Val loss: 1.1366704702, Val acc: 79.92%

Epoch: 26/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009762 | Val loss: 1.1282448769, Val acc: 80.51%

Epoch: 27/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000012168 | Val loss: 1.1495804787, Val acc: 80.21%

Epoch: 28/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000015618 | Val loss: 1.1570423841, Val acc: 80.51%

Epoch: 29/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000011762 | Val loss: 1.1330759525, Val acc: 80.00%

Epoch: 30/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000012648 | Val loss: 1.1071894169, Val acc: 81.21%

Epoch: 31/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006653 | Val loss: 1.1075164080, Val acc: 80.92%

Epoch: 32/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000003758 | Val loss: 1.1249127388, Val acc: 80.04%

Epoch: 33/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000002950 | Val loss: 1.1381329298, Val acc: 80.33%

Epoch: 34/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008244 | Val loss: 1.1344959736, Val acc: 80.62%

Epoch: 35/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010770 | Val loss: 1.1386672258, Val acc: 80.33%

Epoch: 36/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010371 | Val loss: 1.1103249788, Val acc: 80.92%

Epoch: 37/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007484 | Val loss: 1.1193327904, Val acc: 80.62%

Epoch: 38/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000011971 | Val loss: 1.1754598618, Val acc: 80.62%

Epoch: 39/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010590 | Val loss: 1.1450567245, Val acc: 80.33%

Epoch: 40/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000003625 | Val loss: 1.1527012587, Val acc: 80.33%

Epoch: 41/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000012397 | Val loss: 1.1073124409, Val acc: 81.21%

Epoch: 42/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000025478 | Val loss: 1.1209276915, Val acc: 80.64%

Epoch: 43/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007024 | Val loss: 1.1088550091, Val acc: 80.94%

Epoch: 44/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000011177 | Val loss: 1.1275120974, Val acc: 80.64%

Epoch: 45/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010786 | Val loss: 1.1506389380, Val acc: 80.94%

Epoch: 46/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008132 | Val loss: 1.1222347021, Val acc: 80.94%

Epoch: 47/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010088 | Val loss: 1.1629413366, Val acc: 80.64%

Epoch: 48/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008678 | Val loss: 1.1153240204, Val acc: 80.64%

Epoch: 49/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000015219 | Val loss: 1.1055387259, Val acc: 81.23%

Epoch: 50/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006978 | Val loss: 1.1478399038, Val acc: 80.94%

Epoch: 51/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009531 | Val loss: 1.1249961853, Val acc: 80.94%

Epoch: 52/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008914 | Val loss: 1.1075168848, Val acc: 81.23%

Epoch: 53/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000026508 | Val loss: 1.1069381237, Val acc: 81.23%

Epoch: 54/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010927 | Val loss: 1.1056684256, Val acc: 81.23%

Epoch: 55/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006847 | Val loss: 1.1055704355, Val acc: 81.23%

Epoch: 56/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000023933 | Val loss: 1.1455148458, Val acc: 80.64%

Epoch: 57/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000004389 | Val loss: 1.1663665771, Val acc: 80.35%

Epoch: 58/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009118 | Val loss: 1.1487179995, Val acc: 80.94%

Epoch: 59/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007552 | Val loss: 1.1457711458, Val acc: 80.64%

Epoch: 60/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000016405 | Val loss: 1.2144373655, Val acc: 79.77%

Epoch: 61/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006104 | Val loss: 1.1092015505, Val acc: 80.94%

Epoch: 62/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000014908 | Val loss: 1.1331778765, Val acc: 80.94%

Epoch: 63/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000004155 | Val loss: 1.1626747847, Val acc: 80.06%

Epoch: 64/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000018462 | Val loss: 1.1060770750, Val acc: 81.23%

Epoch: 65/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006103 | Val loss: 1.1675651073, Val acc: 80.35%

Epoch: 66/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006099 | Val loss: 1.1116968393, Val acc: 80.94%

Epoch: 67/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000023509 | Val loss: 1.1465991735, Val acc: 80.64%

Epoch: 68/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000011131 | Val loss: 1.1785622835, Val acc: 80.35%

Epoch: 69/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008941 | Val loss: 1.1067184210, Val acc: 81.23%

Epoch: 70/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009918 | Val loss: 1.1389993429, Val acc: 80.64%

Epoch: 71/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006511 | Val loss: 1.1248801947, Val acc: 80.64%

Epoch: 72/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000006997 | Val loss: 1.1057062149, Val acc: 81.23%

Epoch: 73/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000003326 | Val loss: 1.1164494753, Val acc: 80.64%

Epoch: 74/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009225 | Val loss: 1.1677863598, Val acc: 80.64%

Epoch: 75/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000003990 | Val loss: 1.1801409721, Val acc: 80.35%

Epoch: 76/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000020258 | Val loss: 1.1707881689, Val acc: 80.35%

Epoch: 77/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000009809 | Val loss: 1.1072356701, Val acc: 81.23%

Epoch: 78/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010529 | Val loss: 1.2136138678, Val acc: 80.06%

Epoch: 79/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008941 | Val loss: 1.1582396030, Val acc: 80.35%

Epoch: 80/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007413 | Val loss: 1.1430708170, Val acc: 80.64%

Epoch: 81/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007628 | Val loss: 1.1224471331, Val acc: 80.94%

Epoch: 82/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008316 | Val loss: 1.1194874048, Val acc: 80.64%

Epoch: 83/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000007424 | Val loss: 1.1214221716, Val acc: 80.94%

Epoch: 84/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000008169 | Val loss: 1.1077891588, Val acc: 80.94%

Epoch: 85/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000011540 | Val loss: 1.1058866978, Val acc: 81.23%

Epoch: 86/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010018 | Val loss: 1.1699570417, Val acc: 80.64%

Epoch: 87/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000019617 | Val loss: 1.1159338951, Val acc: 80.94%

Epoch: 88/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010176 | Val loss: 1.1588777304, Val acc: 80.64%

Epoch: 89/90
-----------------------


Batches:   0%|          | 0/352 [00:00<?, ?it/s]


Train loss: 0.0000010846 | Val loss: 1.1338763237, Val acc: 80.35%



In [10]:
net.eval()
test_acc = 0
with torch.inference_mode():
    for images, labels in test_data:
        images = images.to(device)
        labels = labels.to(device)
        
        predictions = net(images).argmax(dim=1)
        test_acc += accuracy_fn(predictions, labels)
    test_acc = (test_acc / len(test_data)) *100
    
    print(f"Accuracy of AlexNet on test images: {test_acc:.2f}%")

Accuracy of AlexNet on test images: 80.36%
