In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.datasets as dataset
from torchvision import transforms
from torch.optim import Adam
import torch.nn.functional as F

In [9]:
class LeNet(nn.Module):
    """LeNet-5 style convolutional classifier.

    For a 32x32 input the spatial size goes 32 -> 28 (conv1) -> 14 (pool)
    -> 10 (conv2) -> 5 (pool) -> 1 (conv3). ``conv3`` therefore yields a
    120-dim feature vector per sample, which feeds ``fc1``/``fc2``. Inputs whose
    size leaves ``conv3`` with more than a 1x1 map make ``fc1`` reject the
    flattened features.

    Args:
        transform: Optional object stored as ``self.transform``. ``forward``
            never applies it; it is kept only so existing callers that pass it
            keep working.
        in_channels: Number of input image channels (1 for greyscale MNIST).
        num_classes: Number of output logits.
    """

    def __init__(self, transform=None, in_channels=1, num_classes=10):
        super().__init__()
        # Registration order matches the original so state_dict keys/order are unchanged.
        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0))
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)
        self.transform = transform

    def forward(self, x):
        """Map a (batch, in_channels, H, W) tensor to (batch, num_classes) raw logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.relu(self.conv3(x))
        # Flatten all but the batch dim; unlike reshape(batch, -1) this also
        # works for an empty batch.
        x = torch.flatten(x, start_dim=1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
        
        

In [29]:
# Prefer GPU index 2 (presumably a specific card on a shared multi-GPU box),
# but fall back to the default GPU when fewer devices exist and to the CPU when
# CUDA is unavailable. A bare 'cuda:2' fails with "invalid device ordinal" as
# soon as a tensor is moved to it on a machine with fewer than three GPUs.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    device = torch.device(f'cuda:{PREFERRED_GPU}' if torch.cuda.device_count() > PREFERRED_GPU else 'cuda:0')
else:
    device = torch.device('cpu')

In [14]:
batch_size = 32
input_size = (32, 32)  # images are resized to this; LeNet's conv stack reduces 32x32 to 1x1
epochs = 5
learning_rate = 0.001
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible when the notebook is re-run top to bottom.
seed = 0
torch.manual_seed(seed)

In [7]:
# One shared preprocessing pipeline for both splits: resize to input_size
# (what LeNet expects), then convert the PIL image to a tensor.
mnist_transform = transforms.Compose([transforms.Resize(input_size), transforms.ToTensor()])

train_dataset = dataset.MNIST(root='dataset/', train=True, transform=mnist_transform, download=True)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = dataset.MNIST(root='dataset/', train=False, transform=mnist_transform, download=True)
# Evaluation does not benefit from shuffling; a fixed order keeps test passes deterministic.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [30]:
# Build the network, then place its parameters on the selected device.
model = LeNet()
model = model.to(device)

In [31]:
# Adam over every LeNet parameter; CrossEntropyLoss takes raw logits from fc2.
optimiser = Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [32]:
for epoch in range(epochs):
    # Re-enter training mode every epoch in case an evaluation cell left the
    # model in eval mode (a no-op for this LeNet, which has no dropout/batch-norm,
    # but required once such layers are added).
    model.train()
    running_loss = 0.0
    for data, targets in train_dataloader:
        data = data.to(device)
        targets = targets.to(device)

        preds = model(data)
        loss = criterion(preds, targets)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        # .item() yields a plain float, so the autograd graph is not retained.
        running_loss += loss.item()

    mean_loss = running_loss / max(len(train_dataloader), 1)
    print(f"epoch {epoch + 1}/{epochs}: mean train loss {mean_loss:.4f}")

In [37]:
def check_accuracy(net=None, loader=None, target_device=None):
    """Return top-1 accuracy of a classifier over a data loader.

    Args:
        net: Model to evaluate. Defaults to the notebook-level ``model``.
        loader: Iterable of ``(data, targets)`` batches. Defaults to the
            notebook-level ``test_dataloader``.
        target_device: Device each batch is moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        A 0-dim tensor equal to correct predictions / total samples.
        The model's previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    # Globals are looked up at call time, so rebinding `model` later is honoured.
    net = model if net is None else net
    loader = test_dataloader if loader is None else loader
    target_device = device if target_device is None else target_device

    was_training = net.training
    net.eval()
    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(target_device)
                targets = targets.to(target_device)

                preds = net(data)
                # Index of the largest logit along the class dimension.
                _, predictions = preds.max(1)
                num_correct += (predictions == targets).sum()
                num_samples += targets.size(0)
    finally:
        # Restore the caller's mode rather than forcing train(), even on error.
        net.train(was_training)

    if num_samples == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return num_correct / num_samples

In [38]:
accuracy = check_accuracy()
# Format as a percentage; the raw float32 ratio otherwise prints with rounding noise.
print(f"accuracy of test dataset is {accuracy:.2%}")

accuracy of test dataset is 0.9876999855041504
