In [1]:
import torch
from torch import optim, nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

In [2]:
train_data = datasets.MNIST(root='.', download=True, train=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root='.', download=True, train=False, transform=transforms.ToTensor())

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: .\MNIST\raw\train-images-idx3-ubyte.gz
Extracting .\MNIST\raw\train-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: .\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting .\MNIST\raw\train-labels-idx1-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Using downloaded and verified file: .\MNIST\raw\t10k-images-idx3-ubyte.gz
Extracting .\MNIST\raw\t10k-images-idx3-ubyte.gz to .\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to .\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%

Extracting .\MNIST\raw\t10k-labels-idx1-ubyte.gz to .\MNIST\raw






In [3]:
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=100, shuffle=True)

In [4]:
pre_trained_model =  models.resnet18(pretrained = True)




In [5]:
for param in pre_trained_model.parameters():
    param.required_grad = False

In [6]:
num_in_features = pre_trained_model.fc.in_features

In [7]:
pre_trained_model.fc = nn.Linear(in_features=num_in_features, out_features=10)
pre_trained_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [8]:
torch.cuda.device_count()

1

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pre_trained_model.to(device)


ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
device

device(type='cuda', index=0)

In [11]:
criterion = nn.CrossEntropyLoss()

In [12]:
opmtimizer = optim.Adam(pre_trained_model.parameters(), lr = 0.001)

In [13]:
epochs = 25
num_valid_data = len(test_data)
num_valid_data

10000

In [14]:
for epoch in range(0, epochs):

    pre_trained_model.train()

    if epoch % 3 == 0:
        checkpoint = {
            'model_state' : pre_trained_model.state_dict(),
            'optimizer_state' : opmtimizer.state_dict(),
            'epoch' : epoch
        }
        torch.save(checkpoint, 'checkpoint.pth')

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        yhat = pre_trained_model(x)
        opmtimizer.zero_grad()
        
        with torch.set_grad_enabled(True):
            loss = criterion(yhat, y)
            loss.backward()
            opmtimizer.step()

    pre_trained_model.eval()
    correct = 0

    for x_val, y_val in test_loader:
        x_val = x_val.to(device)
        y_val = y_val.to(device)
        
        yhat_val = pre_trained_model(x_val)
        _, val_label = torch.max(yhat_val ,1)
        correct += (val_label == y_val).sum()

    val_accuracy = correct / num_valid_data
    print(f'val_accuracy epoch{epoch}: {val_accuracy}')

val_accuracy epoch0: 0.9840999841690063
val_accuracy epoch1: 0.9809999465942383
val_accuracy epoch2: 0.9894999861717224
val_accuracy epoch3: 0.9888999462127686
val_accuracy epoch4: 0.9899999499320984
val_accuracy epoch5: 0.9901999831199646
val_accuracy epoch6: 0.9916999936103821
val_accuracy epoch7: 0.9921999573707581
val_accuracy epoch8: 0.9937999844551086
val_accuracy epoch9: 0.9914999604225159
val_accuracy epoch10: 0.9911999702453613
val_accuracy epoch11: 0.9923999905586243
val_accuracy epoch12: 0.9932000041007996
val_accuracy epoch13: 0.9767999649047852
val_accuracy epoch14: 0.9939000010490417
val_accuracy epoch15: 0.9912999868392944
val_accuracy epoch16: 0.9890999794006348
val_accuracy epoch17: 0.991599977016449
val_accuracy epoch18: 0.9936999678611755
val_accuracy epoch19: 0.993399977684021
val_accuracy epoch20: 0.9914000034332275
val_accuracy epoch21: 0.9908999800682068
val_accuracy epoch22: 0.9936999678611755
val_accuracy epoch23: 0.9936999678611755
val_accuracy epoch24: 0.9932