In [6]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import transforms, datasets
from time import time

torch.set_printoptions(linewidth=120)
torch.set_grad_enabled(True)

from torch.utils.tensorboard import SummaryWriter

In [2]:
@torch.no_grad()
def get_num_corrects(preds,labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [3]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1,6,kernel_size=5) #1 channel image input and 6 channel image as output using 5x5 kernel
        self.pool1 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.conv2 = nn.Conv2d(6,16,kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2,stride=2)
        self.fc1 = nn.Linear(400,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)
        
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1,400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x),dim=1)
        return x
        

In [4]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
apply_transform = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

train_set = datasets.MNIST(root='./data', train=True, download=False, transform=apply_transform)

In [10]:
net = LeNet()
net.load_state_dict(torch.load('./Saved_models/Lenet_params.pt', map_location=device),strict=True)
net.to(device)
net.train()
print(net)

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [15]:
batch_size = 256
lr = 1e-4
optimizer = optim.Adam(net.parameters(), lr=lr)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,num_workers=4)

In [13]:
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)

comment = f'Batch_Size: {batch_size} lr: {lr}'
tb = SummaryWriter(comment=comment)
tb.add_image('images', grid)
tb.add_graph(net, images)

n = 1
for epoch in range(n):
    tot_loss = 0
    tot_correct = 0
    previous_epoch_timestamp = time()

    if epoch % 10 == 0:
      if epoch: lr *= 0.98
    optimizer = optim.Adam(net.parameters(), lr=lr)

    for data in train_loader:
        inputs,labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        preds = net(inputs)
        loss = F.cross_entropy(preds,labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        tot_loss += loss.item()*batch_size
        tot_correct += get_num_corrects(preds,labels)

    tb.add_scalar('Loss', tot_loss, epoch)
    tb.add_scalar('Number Correct', tot_correct, epoch)
    tb.add_scalar('Accuracy', tot_correct/len(train_set), epoch)

    for name, weight in net.named_parameters():
        tb.add_histogram(name, weight, epoch)
        tb.add_histogram(f'{name}.grad', weight.grad, epoch)

    print(f"epoch: {epoch+1}/{n}, train loss: {tot_loss:.6f}, train accuracy: {tot_correct/len(train_set):.6f}, \
time Used: {time()-previous_epoch_timestamp:.3f}s") 
tb.close()

epoch: 1/1, train loss: 343.936741, train accuracy: 0.997583, time Used: 108.991s
