In [1]:
import platform

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as tf


In [2]:
# Select the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Sanity check of library versions and the selected device.
print(torch.__version__)
# Report is_available() (usable at runtime), not is_built() (merely compiled in),
# so this line always agrees with the DEVICE chosen above.
print(f"MPS Available : {torch.backends.mps.is_available()}")
print(f"CUDA Available : {torch.cuda.is_available()}")
print(f"DEVICE - {DEVICE}")
# Version of the interpreter running this kernel (a shell `python` may be a different one).
print(f"Python {platform.python_version()}")

2.0.0
MPS Available : True
CUDA Available : False
DEVICE - mps
Python 3.10.10


# Datasets

In [3]:
# Dataset configuration, defined once so both splits stay in sync.
DATA_ROOT = "CIFAR10_data"
BATCH_SIZE = 128
# Per-channel mean / std handed to tf.Normalize (one value per RGB channel).
CIFAR10_MEAN = (0.49139968, 0.48215827, 0.44653124)
CIFAR10_STD = (0.24703233, 0.24348505, 0.26158768)

# Single preprocessing pipeline shared by the train and test splits:
# convert to a tensor, then normalise each channel.
cifar10_transform = tf.Compose(
    [
        tf.ToTensor(),
        tf.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    DATA_ROOT,
    train=True,
    download=True,
    transform=cifar10_transform,
)

test_dataset = torchvision.datasets.CIFAR10(
    DATA_ROOT,
    train=False,
    download=True,
    transform=cifar10_transform,
)

# Training batches are reshuffled on every pass over the loader.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# Evaluation does not depend on sample order, so the test split is not shuffled.
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)


Files already downloaded and verified
Files already downloaded and verified


# Model

In [4]:
class DenseOnlyCNN(nn.Module):
    """Fully connected classifier that outputs raw class logits.

    The input is flattened (all dimensions after the batch dimension) and passed
    through a stack of Linear/ReLU layers with two dropout layers. No activation
    is applied to the final layer because ``nn.CrossEntropyLoss`` expects
    unnormalised logits.

    Args:
        in_features: Number of values per sample after flattening. The default
            (3 * 32 * 32) matches a 3-channel 32x32 image such as CIFAR-10.
        num_classes: Number of output logits. The default (10) matches CIFAR-10.
    """

    def __init__(self, in_features=3 * 32 * 32, num_classes=10):
        super().__init__()
        self.model = nn.Sequential(
            # Collapse (N, C, H, W) image batches to (N, C*H*W) for the Linear layers.
            nn.Flatten(),
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            # Final layer emits logits; softmax is applied inside the loss.
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for the input batch ``x``."""
        return self.model(x)


# Hyperparameters

In [5]:
# Optimisation settings.
LR = 1e-3
EPOCHS = 130
# wd = 0.01  (weight decay is currently disabled)

# Network, optimiser and training objective used by the cells below.
model = DenseOnlyCNN()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()


# Training and Evaluation Functions

In [6]:
def get_accuracy(net=None, loader=None, device=None):
    """Evaluate classification accuracy and return it as a percentage.

    Args:
        net: Module to evaluate. Defaults to the notebook-level ``model``.
        loader: Iterable of ``(images, labels)`` batches. Defaults to the
            notebook-level ``test_loader``.
        device: Device the batches are moved to. Defaults to ``DEVICE``.

    Returns:
        Accuracy in percent (0-100) over every sample yielded by ``loader``;
        ``0.0`` if the loader yields no samples.

    Side effects:
        Puts ``net`` in eval mode and prints the accuracy.
    """
    # Fall back to the notebook globals only when an argument is not supplied.
    net = model if net is None else net
    loader = test_loader if loader is None else loader
    device = DEVICE if device is None else device

    correct = 0
    total = 0
    net.eval()

    with torch.no_grad():
        # Iterating over the evaluation data in batches
        for images, labels in loader:
            images = images.to(device)
            y_true = labels.to(device)

            # Predicted class = index of the largest output for each sample
            y_pred = net(images).argmax(dim=1)

            # Comparing predicted and true labels
            correct += (y_pred == y_true).sum().item()
            # Count samples actually evaluated instead of relying on a global dataset name
            total += y_true.size(0)

    accuracy = 100 * correct / total if total else 0.0
    print(f"Test set accuracy = {accuracy} %")
    return accuracy


def get_loss():
    """Run one optimisation pass over ``train_loader`` and record the mean batch loss.

    Works on the notebook-level ``model``, ``loss_fn``, ``optimizer`` and ``DEVICE``.
    The average of the per-batch losses is appended to ``train_loss_list`` and
    printed. Nothing is returned.
    """
    # Switch to training mode before iterating over the batches.
    model.train()

    running_loss = 0
    for batch_images, batch_labels in train_loader:
        # Move the current batch onto the compute device.
        batch_images = batch_images.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        # Forward pass and loss for this batch.
        batch_loss = loss_fn(model(batch_images), batch_labels)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Store and report the mean loss across all batches of this pass.
    mean_loss = running_loss / len(train_loader)
    train_loss_list.append(mean_loss)
    print(f"Training loss = {train_loss_list[-1]}")


# Training Loop

In [None]:
# Mean training loss per epoch; get_loss() appends one entry per call.
train_loss_list = []
model = model.to(device=DEVICE)

# Best test accuracy seen so far and the epoch index that produced it.
best_acc, best_epoch = 0, -1

for epoch in range(EPOCHS):
    print(f"Epoch {epoch}/{EPOCHS} +++++++++++++++++")
    get_loss()  # one training pass over the training loader
    this_acc = get_accuracy()  # accuracy after this epoch's updates
    if this_acc > best_acc:
        best_acc, best_epoch = this_acc, epoch

print(f"Best accuracy occurred at {best_epoch} and was: {best_acc} %")
