In [37]:
import torch
from torch import optim, nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

In [38]:
# Preprocessing applied to every image: resize to 224x224, then convert the
# image to a tensor.
transformer = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [39]:
# MNIST training and test splits, stored under the current directory
# (downloaded if not already present) and configured with `transformer`
# as their image transform.
train_data = datasets.MNIST('.', train=True, transform=transformer, download=True)
test_data = datasets.MNIST('.', train=False, transform=transformer, download=True)

In [40]:
# Mini-batch loaders over the two MNIST splits, 16 images per batch.
# Training batches are reshuffled on every pass. The evaluation loader keeps
# a fixed order: shuffling cannot change the accuracy and only makes
# per-batch results non-reproducible.
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=16, shuffle=False)

In [41]:
# Load ResNet-18 with its pretrained weights; the bare name on the last line
# displays the architecture as the cell output.
# NOTE(review): newer torchvision releases replace `pretrained=` with
# `weights=` -- confirm the installed version before switching.
pre_trained_model = models.resnet18(pretrained=True)
pre_trained_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [42]:
# Keep a handle on the first convolution's kernel and show its dimensions
# (the recorded output is [64, 3, 7, 7]: 3 input channels).
conv1_weights = pre_trained_model.conv1.weight
conv1_weights.size()

torch.Size([64, 3, 7, 7])

In [43]:
# Collapse the 3 input-channel kernels into a single channel by summing over
# the channel axis (dim=1, kept as size 1), then install the result as the
# layer's new trainable weight so conv1 accepts single-channel images.
pre_trained_model.conv1.weight = nn.Parameter(
    torch.sum(conv1_weights, dim=1, keepdim=True)
)
pre_trained_model.conv1.weight.shape

torch.Size([64, 1, 7, 7])

In [44]:
# Update the layer's recorded input-channel count so it agrees with the
# single-channel kernel assigned to conv1.weight (the model repr then shows
# Conv2d(1, 64, ...)).
pre_trained_model.conv1.in_channels = 1

In [45]:
# Display the model again to confirm conv1 now reports one input channel.
pre_trained_model

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [46]:
# Input width of the existing final fully-connected layer; reused when the
# classifier head is replaced.
num_in_features = pre_trained_model.fc.in_features

In [47]:
# Swap the classifier head for a freshly initialised linear layer with
# 10 outputs (one per MNIST digit class), keeping the same input width.
pre_trained_model.fc = nn.Linear(num_in_features, 10)

In [48]:
# Show how many CUDA devices PyTorch can see (informational only; the value
# is not stored).
torch.cuda.device_count()

1

In [49]:
# Use the first CUDA device when one is available, otherwise fall back to
# the CPU, and move the model's parameters and buffers there. The `.to(...)`
# call is the cell's last expression, so the model is displayed as output.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
pre_trained_model.to(device)


ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [50]:
# Display which device was selected.
device

device(type='cuda', index=0)

In [51]:
# Multi-class classification loss; takes the model's raw (unnormalised)
# class scores together with integer class labels.
criterion = torch.nn.CrossEntropyLoss()

In [52]:
# Adam over every parameter of the model, learning rate 1e-3.
# NOTE(review): the variable name is misspelled ("opmtimizer"); it is kept
# because the training cell refers to the optimizer by this exact name.
opmtimizer = optim.Adam(
    pre_trained_model.parameters(),
    lr=0.001,
)

In [53]:
# Number of full passes over the training set.
epochs = 10
# Size of the evaluation split; used as the denominator when computing
# per-epoch accuracy.
num_valid_data = len(test_data)
num_valid_data

10000

In [54]:
# Fine-tune the model for `epochs` passes over the training set and report
# accuracy on the MNIST test split after every pass.
for epoch in range(epochs):

    pre_trained_model.train()

    # Every third epoch, snapshot model and optimizer state *before* that
    # epoch's training runs, so the stored 'epoch' is the one to resume from.
    # Each save overwrites the same file.
    # NOTE(review): the weights produced by the final epoch are never written
    # out; add a save after the loop if the finished model is needed.
    if epoch % 3 == 0:
        checkpoint = {
            'model_state': pre_trained_model.state_dict(),
            'optimizer_state': opmtimizer.state_dict(),
            'epoch': epoch,
        }
        torch.save(checkpoint, 'checkpoint.pth')

    # ---- training pass ----
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        # Clear gradients left over from the previous step, then do the usual
        # forward / loss / backward / update sequence.
        opmtimizer.zero_grad()
        yhat = pre_trained_model(x)
        loss = criterion(yhat, y)
        loss.backward()
        opmtimizer.step()

    # ---- validation pass ----
    pre_trained_model.eval()
    correct = 0  # plain Python int: number of correctly classified images

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch and saves memory and time.
    with torch.no_grad():
        for x_val, y_val in test_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat_val = pre_trained_model(x_val)
            # Predicted class = index of the largest score along dim 1.
            _, val_label = torch.max(yhat_val, 1)
            # .item() moves the per-batch count to a Python int so the
            # accuracy below is ordinary float division rather than tensor
            # arithmetic on the device.
            correct += (val_label == y_val).sum().item()

    val_accuracy = correct / num_valid_data
    print(f'val_accuracy epoch{epoch}: {val_accuracy}')

val_accuracy epoch0: 0.9767999649047852
val_accuracy epoch1: 0.9837999939918518
val_accuracy epoch2: 0.9922999739646912
val_accuracy epoch3: 0.9894999861717224
val_accuracy epoch4: 0.9935999512672424
val_accuracy epoch5: 0.995199978351593
val_accuracy epoch6: 0.9952999949455261
val_accuracy epoch7: 0.9939999580383301
val_accuracy epoch8: 0.9934999942779541
val_accuracy epoch9: 0.995199978351593
