In [1]:
import torch
from torch import optim, nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

In [2]:
transformer = transforms.Compose([transforms.Resize(size=(224, 224)), transforms.ToTensor()])

In [3]:
train_data = datasets.MNIST(root='.', download=True, train=True, transform=transformer)
test_data = datasets.MNIST(root='.', download=True, train=False, transform=transformer)

In [4]:
train_loader = DataLoader(dataset=train_data, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=100, shuffle=True)

In [5]:
pre_trained_model =  models.resnet18(pretrained = True)



In [6]:
conv1_weights = pre_trained_model.conv1.weight
conv1_weights.shape

torch.Size([64, 3, 7, 7])

In [7]:
pre_trained_model.conv1.weight = torch.nn.Parameter(conv1_weights.sum(dim=1, keepdim=True))
pre_trained_model.conv1.weight.shape

torch.Size([64, 1, 7, 7])

In [8]:
pre_trained_model.conv1.in_channels = 1

In [9]:
for param in pre_trained_model.parameters():
    param.requires_grad = False

In [10]:
num_in_features = pre_trained_model.fc.in_features

In [11]:
pre_trained_model.fc = nn.Linear(in_features=num_in_features, out_features=10)

In [12]:
torch.cuda.device_count()

1

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pre_trained_model.to(device)


ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [14]:
device

device(type='cuda', index=0)

In [15]:
criterion = nn.CrossEntropyLoss()

In [16]:
opmtimizer = optim.Adam(pre_trained_model.parameters(), lr = 0.001)

In [17]:
epochs = 24
num_valid_data = len(test_data)
num_valid_data

10000

In [18]:
for epoch in range(0, epochs):

    pre_trained_model.train()

    if epoch % 3 == 0:
        checkpoint = {
            'model_state' : pre_trained_model.state_dict(),
            'optimizer_state' : opmtimizer.state_dict(),
            'epoch' : epoch
        }
        torch.save(checkpoint, 'checkpoint.pth')

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        yhat = pre_trained_model(x)
        opmtimizer.zero_grad()
        
        with torch.set_grad_enabled(True):
            loss = criterion(yhat, y)
            loss.backward()
            opmtimizer.step()

    pre_trained_model.eval()
    correct = 0

    for x_val, y_val in test_loader:
        x_val = x_val.to(device)
        y_val = y_val.to(device)
        
        yhat_val = pre_trained_model(x_val)
        _, val_label = torch.max(yhat_val ,1)
        correct += (val_label == y_val).sum()

    val_accuracy = correct / num_valid_data
    print(f'val_accuracy epoch{epoch}: {val_accuracy}')

val_accuracy epoch0: 0.950499951839447
val_accuracy epoch1: 0.9592999815940857
val_accuracy epoch2: 0.9614999890327454
val_accuracy epoch3: 0.9642999768257141
val_accuracy epoch4: 0.9666999578475952
val_accuracy epoch5: 0.9661999940872192
val_accuracy epoch6: 0.9664999842643738
val_accuracy epoch7: 0.9684000015258789
val_accuracy epoch8: 0.9682999849319458
val_accuracy epoch9: 0.9670999646186829
val_accuracy epoch10: 0.9691999554634094
val_accuracy epoch11: 0.9695000052452087
val_accuracy epoch12: 0.9684999585151672
val_accuracy epoch13: 0.9695000052452087
val_accuracy epoch14: 0.968999981880188
val_accuracy epoch15: 0.9684000015258789
val_accuracy epoch16: 0.9675999879837036
val_accuracy epoch17: 0.9704999923706055
val_accuracy epoch18: 0.9693999886512756
val_accuracy epoch19: 0.9692999720573425
val_accuracy epoch20: 0.9688999652862549
val_accuracy epoch21: 0.9690999984741211
val_accuracy epoch22: 0.9709999561309814
val_accuracy epoch23: 0.9698999524116516
val_accuracy epoch24: 0.9700