In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
class CNN(nn.Module):
  def __init__(self, in_channels=1, num_classes=10):
    super(CNN, self).__init__()
    self.conv1= nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3,3), stride=(1,1), padding=(1,1))
    self.pool= nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))
    self.conv2=nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(3,3), stride=(1,1), padding=(1,1))
    self.fc1=nn.Linear(16*7*7, num_classes)

  def forward(self,x):
    x=F.relu(self.conv1(x))
    x=self.pool(x)
    x=F.relu(self.conv2(x))
    x=self.pool(x)
    x=x.reshape(x.shape[0],-1)
    x=self.fc1(x)

    return x

In [3]:
def save_checkpoint(state, filename='my_checkpoint.pth.tar'):
  print("saving checkpoint")
  torch.save(state, filename )

In [12]:
def load_checkpoint(checkpoint):
  print("loading checkpoint")
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])

In [4]:
device= torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
in_channels=1
num_classes=10
learning_rate=1e-4
batch_size=1024
num_epochs=10
load_model=True

In [6]:
train_dataset=datasets.MNIST(root='dataaet/', train=True, transform=transforms.ToTensor(), download=True)
train_loader= DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset= datasets.MNIST(root='dataset/', train=False, transform=transforms.ToTensor(), download=True)
test_loader= DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataaet/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataaet/MNIST/raw/train-images-idx3-ubyte.gz to dataaet/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataaet/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataaet/MNIST/raw/train-labels-idx1-ubyte.gz to dataaet/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataaet/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataaet/MNIST/raw/t10k-images-idx3-ubyte.gz to dataaet/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataaet/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataaet/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataaet/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw



In [7]:
model=CNN().to(device)

In [8]:
criterion=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(), lr=learning_rate)

In [15]:
if load_model:
  load_checkpoint(torch.load("my_checkpoint.pth.tar"))

loading checkpoint


In [17]:
for epoch in range(num_epochs):
  losses=[]

  if epoch%3== 0:
    checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    save_checkpoint(checkpoint)

  for batch_idx, (data, targets) in enumerate(train_loader):
    data=data.to(device=device)
    targets=targets.to(device=device)

    scores=model(data)
    loss=criterion(scores, targets)
    losses.append(loss.item())

    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

  mean_loss=sum(losses)/len(losses)
  print(f'loss at epoch (epoch) was {mean_loss:.5f}')

saving checkpoint
loss at epoch (epoch) was 0.19494
loss at epoch (epoch) was 0.19094
loss at epoch (epoch) was 0.18655
saving checkpoint
loss at epoch (epoch) was 0.18267
loss at epoch (epoch) was 0.17900
loss at epoch (epoch) was 0.17546
saving checkpoint
loss at epoch (epoch) was 0.17208
loss at epoch (epoch) was 0.16869
loss at epoch (epoch) was 0.16518
saving checkpoint
loss at epoch (epoch) was 0.16222


In [None]:
def check_accuracy(loader, model):
  num_correct=0
  num_samples=0
  model.eval()

  with torch.no_grad():
    for x,y in loader:
      x=x.to(device=device)
      y=y.to(device=device)

      scores=model(x)
      _, predictions= scores.max(1)
      num_correct+= (predictions==y).sum()
      num_samples+= predictions.size(0)

    print(f'Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}')

  model.train()

In [None]:
check_accuracy(train_loader, model)
check_accuracy(test_loader, model)