In [43]:
import torch # type: ignore
import torch.nn as nn # type: ignore
from torch.utils.data import DataLoader, TensorDataset # type: ignore
import numpy as np # type: ignore
import torchvision # type: ignore
import torchvision.transforms as transforms # type: ignore

In [44]:
# Preprocessing applied to every image: convert it to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
# The pipeline is bound to `transform` (singular) so that the imported
# `transforms` module is not overwritten; shadowing the module makes this cell
# raise AttributeError as soon as it is executed a second time.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load the CIFAR-10 train and test splits (downloaded into ./data when missing)
# and apply the same preprocessing to both.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create dataloaders to feed the model in batches of 128.
# Only the training split is shuffled; the evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=2)



In [45]:
# Reference only (kept commented out): the classic LeNet-5 architecture
# with tanh activations. The active model is defined in a later cell.
# class LeNet(nn.Module):
#     def __init__(self):
#         super(LeNet, self).__init__()
#         self.conv1 = nn.Conv2d(3, 6, 5)
#         self.pool = nn.AvgPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)
    
#     def forward(self,x):
#         x = self.pool(torch.tanh(self.conv1(x)))
#         x = self.pool(torch.tanh(self.conv2(x)))
#         x = x.view(-1, 16 * 5 * 5)
#         x = torch.tanh(self.fc1(x))
#         x = torch.tanh(self.fc2(x))
#         x = self.fc3(x)
#         return x


In [46]:
# 
class LeNet(nn.Module):
    """LeNet-5-style CNN using ReLU activations and 2x2 average pooling.

    Layout: conv(3->6, 5x5) -> pool -> conv(6->16, 5x5) -> pool ->
    fc(400->120) -> fc(120->84) -> fc(84->num_classes).

    The first fully connected layer expects 16 feature maps of 5x5, which is
    what the two conv/pool stages produce for a 32x32 three-channel input.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    # Shape (channels, height, width) the conv stack must produce for fc1.
    _FEATURE_SHAPE = (16, 5, 5)

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.AvgPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class logits (no softmax applied) for the input images.

        Raises:
            RuntimeError: If the convolutional features are not 16x5x5, i.e.
                the input images do not have the spatial size the head
                was built for.
        """
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))

        # `view(-1, 400)` alone would happily fold surplus spatial positions
        # into the batch axis (one 52x52 image would come out as four rows),
        # so reject mismatched feature maps explicitly before flattening.
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                f"LeNet expects conv features of shape {self._FEATURE_SHAPE}, "
                f"got {tuple(x.shape[-3:])}; check the input image size"
            )
        x = x.view(-1, 16 * 5 * 5)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [53]:
#create the model instance and define the device, loss function and optimizer
# `torch.accelerator` is only present in recent PyTorch releases, so guard the
# lookup and fall back to the plain CUDA check instead of raising
# AttributeError on older installations.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = LeNet().to(device)
criteria = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


In [54]:
#train the model
def train(dataloader, model, criterion, optimizer, log_every=128):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; its ``dataset``
            attribute must support ``len()`` (used for the progress print).
        model: Module to optimize. It is switched to training mode.
        criterion: Callable ``criterion(pred, target)`` returning a scalar loss.
        optimizer: Optimizer that holds the parameters of ``model``.
        log_every: Print the current batch loss every ``log_every`` batches.
            A falsy value (0 / None) disables the progress output.

    Returns:
        None. Progress is reported on stdout.
    """
    size = len(dataloader.dataset)

    # Move batches to wherever the model already lives instead of relying on a
    # notebook-level `device` variable, so the function also works when that
    # global is missing or out of sync with the model.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(target_device), y.to(target_device)

        #forwardprop
        pred = model(x)
        loss = criterion(pred, y)

        #backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and batch % log_every == 0:
            # `current` is the number of samples consumed before this batch.
            current = batch * len(x)
            print(f"Loss: {loss.item():>7f}   [{current:>5d}/{size:>5d}]")
        

In [55]:
#test the model
def test(dataloader, model, criterion):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += criterion(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [57]:
# Alternate one optimisation pass over the training split with one evaluation
# pass over the test split, for a fixed number of epochs.
epochs = 20
for t in range(epochs):
    # Epochs are reported 1-based.
    print(f"Epoch {t+1}\n-------------------------------")
    train(dataloader=train_loader, model=model, criterion=criteria, optimizer=optimizer)
    test(dataloader=test_loader, model=model, criterion=criteria)

Epoch 1
-------------------------------
Loss: 0.881621   [    0/50000]
Loss: 1.095250   [16384/50000]
Loss: 1.005521   [32768/50000]
Loss: 0.964553   [49152/50000]
Test Error: 
 Accuracy: 60.6%, Avg loss: 1.173707 

Epoch 2
-------------------------------
Loss: 0.724464   [    0/50000]
Loss: 0.948532   [16384/50000]
Loss: 0.777144   [32768/50000]
Loss: 0.804659   [49152/50000]
Test Error: 
 Accuracy: 60.8%, Avg loss: 1.165594 

Epoch 3
-------------------------------
Loss: 0.790396   [    0/50000]
Loss: 0.827974   [16384/50000]
Loss: 0.869041   [32768/50000]
Loss: 1.084978   [49152/50000]
Test Error: 
 Accuracy: 59.1%, Avg loss: 1.229746 

Epoch 4
-------------------------------
Loss: 0.882367   [    0/50000]
Loss: 0.789731   [16384/50000]
Loss: 0.903621   [32768/50000]
Loss: 0.939173   [49152/50000]
Test Error: 
 Accuracy: 59.5%, Avg loss: 1.202842 

Epoch 5
-------------------------------
Loss: 0.763737   [    0/50000]
Loss: 0.921361   [16384/50000]
Loss: 0.678659   [32768/50000]
Los

In [58]:
# Persist only the learned parameters (the state dict), not the module object.
torch.save(obj=model.state_dict(), f="Lenet_with_relu_and_adam.pth")