In [None]:
## Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

In [None]:
## Model
class CNN(nn.Module):
  """Small two-stage convolutional classifier.

  Each stage is a 3x3 'same' convolution (padding 1) followed by ReLU and a
  2x2 max-pool with stride 2, so the spatial size is halved twice. The final
  linear layer expects 16 feature maps of 7x7, which is what a 28x28 input
  (such as MNIST) produces after the two pools.

  Args:
    in_channels: number of channels in the input images.
    num_classes: number of output scores produced per sample.
  """

  def __init__(self, in_channels, num_classes):
    super().__init__()
    # Submodules are registered in the order conv1, pool, conv2, fc1 so the
    # parameter order (and hence saved optimizer state) stays the same.
    self.conv1 = nn.Conv2d(in_channels, 8, kernel_size=3, padding=1)
    self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
    self.fc1 = nn.Linear(16 * 7 * 7, num_classes)  # 28 -> 14 -> 7 after two pools

  def forward(self, x):
    """Return raw class scores of shape (batch, num_classes); no softmax applied."""
    features = self.pool(F.relu(self.conv1(x)))
    features = self.pool(F.relu(self.conv2(features)))
    flat = features.reshape(features.size(0), -1)
    return self.fc1(flat)

In [None]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
  """Announce and serialize `state` to `filename` with torch.save.

  Args:
    state: object to persist; in this notebook a dict holding the model's
      'state_dict' and the 'optimizer' state.
    filename: destination path for the checkpoint file.
  """
  banner = ' => Saving Checkpoint'
  print(banner)
  torch.save(obj=state, f=filename)

In [None]:
def load_checkpoint(checkpoint, model=None, optimizer=None):
  """Restore model and optimizer state from a checkpoint dict.

  Args:
    checkpoint: dict with a 'state_dict' entry (model weights) and an
      'optimizer' entry (optimizer state), e.g. the result of torch.load on
      a file written by save_checkpoint.
    model: module to load the weights into. Defaults to the notebook-global
      `model` (the previous behaviour).
    optimizer: optimizer whose state is restored. Defaults to the
      notebook-global `optimizer`.
  """
  print(' => Loading Checkpoint')
  # Fall back to the notebook globals so existing calls keep working, while
  # allowing explicit targets instead of a hidden dependency on them.
  if model is None:
    model = globals()['model']
  if optimizer is None:
    optimizer = globals()['optimizer']
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])

In [None]:
## Set Device
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [None]:
## Hyperparameters
seed = 42             # fixed seed so weight init and data shuffling are reproducible
torch.manual_seed(seed)

in_channels = 1       # MNIST images have a single (grayscale) channel
input_size = 784      # 28*28; not used by the CNN model
num_classes = 10
learning_rate = 1e-4
batch_size = 1024
num_epochs = 10
load_model = True     # try to resume from the checkpoint file before training

In [None]:
## Load Data
# ToTensor converts each PIL image into a float tensor scaled to [0, 1].
mnist_transform = transforms.ToTensor()

train_dataset = datasets.MNIST(root='dataset/', train=True, transform=mnist_transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Evaluation does not depend on sample order, so the test loader is not
# shuffled (keeps evaluation deterministic and avoids needless RNG use).
test_dataset = datasets.MNIST(root='dataset/', train=False, transform=mnist_transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
## Initialize Network
# Build the classifier, then move its parameters onto the selected device.
model = CNN(num_classes=num_classes, in_channels=in_channels)
model = model.to(device)

In [None]:
## Loss and Optimizer
# CrossEntropyLoss is given the raw fc1 scores (CNN.forward applies no softmax).
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the CNN.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
## Resume from checkpoint (optional)
if load_model:
  try:
    # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load("my_checkpoint.pth.tar", map_location=device)
  except FileNotFoundError:
    # Fresh kernel / first run: there is nothing to resume from yet.
    print(' => No checkpoint found, starting from scratch')
  else:
    load_checkpoint(checkpoint)

 => Loading Checkpoint


In [None]:
## Train Network
checkpoint_every = 3  # number of epochs between checkpoints

model.train()
for epoch in range(num_epochs):
  losses = []

  for data, targets in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
    # Move the batch to the same device as the model
    data = data.to(device=device)
    targets = targets.to(device=device)

    # Forward pass
    scores = model(data)
    loss = criterion(scores, targets)
    losses.append(loss.item())

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Optimizer (Adam) update step
    optimizer.step()

  if losses:
    print(f' => Epoch {epoch + 1}: mean training loss {sum(losses) / len(losses):.4f}')

  # Checkpoint AFTER this epoch's updates, every `checkpoint_every` epochs and
  # after the final epoch. Saving before the epoch (as before) wrote an
  # untrained state at epoch 0 and never saved the last epoch's weights.
  is_last_epoch = epoch == num_epochs - 1
  if (epoch + 1) % checkpoint_every == 0 or is_last_epoch:
    checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    save_checkpoint(checkpoint)

 => Saving Checkpoint


100%|██████████| 59/59 [00:04<00:00, 12.43it/s]
100%|██████████| 59/59 [00:04<00:00, 12.47it/s]
100%|██████████| 59/59 [00:04<00:00, 12.45it/s]


 => Saving Checkpoint


100%|██████████| 59/59 [00:04<00:00, 12.37it/s]
100%|██████████| 59/59 [00:04<00:00, 12.64it/s]
100%|██████████| 59/59 [00:04<00:00, 12.74it/s]


 => Saving Checkpoint


100%|██████████| 59/59 [00:04<00:00, 12.64it/s]
100%|██████████| 59/59 [00:04<00:00, 12.67it/s]
100%|██████████| 59/59 [00:04<00:00, 12.45it/s]


 => Saving Checkpoint


100%|██████████| 59/59 [00:04<00:00, 12.48it/s]


In [None]:
## Check accuracy on the training and test sets to see how good the model is
def check_accuracy(loader, model, device=None):
  """Return the fraction of samples in `loader` that `model` classifies correctly.

  Args:
    loader: iterable of (inputs, labels) batches, e.g. a DataLoader.
    model: classifier producing one score per class along dim 1.
    device: device to move batches to. Defaults to the device of the model's
      first parameter (CPU if the model has no parameters).

  Returns:
    Accuracy as a Python float in [0, 1].

  Raises:
    ValueError: if `loader` yields no samples.

  The model runs in eval mode under torch.no_grad(). Its previous train/eval
  mode is restored afterwards, even if an error occurs.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  num_correct = 0
  num_samples = 0
  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for x, y in loader:
        x = x.to(device=device)
        y = y.to(device=device)

        scores = model(x)                      # (batch, num_classes)
        predictions = scores.argmax(dim=1)     # index of the highest score per sample

        num_correct += (predictions == y).sum().item()
        num_samples += predictions.size(0)     # batch size
  finally:
    # Restore whatever mode the caller had the model in
    model.train(was_training)

  if num_samples == 0:
    raise ValueError('loader yielded no samples')
  return num_correct / num_samples

# Evaluate on both splits, then report each as a percentage.
train_acc = check_accuracy(train_loader, model)
test_acc = check_accuracy(test_loader, model)
print(f'Accuracy on training data: {train_acc*100:.2f}')
print(f"Accuracy on test set: {test_acc*100:.2f}")

Accuracy on training data: 91.60
Accuracy on test set: 91.87
