In [23]:
import torch
from torch import nn, optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms


In [24]:
# TensorBoard event files for this run are written under LOG_DIR.
LOG_DIR = 'runs/first'
writer = SummaryWriter(log_dir=LOG_DIR)

In [25]:
# Resize each digit image to 224x224 for the ResNet backbone, then convert it
# to a tensor.
# NOTE(review): no transforms.Normalize is applied, although the pretrained
# weights were presumably fit on normalized ImageNet inputs — confirm intent.
transformer = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [26]:
# Build the MNIST train split (train=True) and the held-out split (train=False),
# downloading into the working directory if needed and applying `transformer`.
train_dataset, valid_dataset = (
    datasets.MNIST(root='.', download=True, train=use_train_split, transform=transformer)
    for use_train_split in (True, False)
)

In [27]:
# Mini-batch loaders. Only the training data is shuffled: shuffling the
# validation set changes nothing about the accuracy but makes evaluation order
# nondeterministic across runs.
BATCH_SIZE = 32
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [28]:
# ResNet-18 with pretrained weights, displayed so its layer layout is visible.
# NOTE(review): `pretrained=` is deprecated in newer torchvision releases in
# favour of `weights=...ResNet18_Weights.DEFAULT`; switch once the installed
# version is confirmed to support it.
model = torchvision.models.resnet18(pretrained=True)
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [29]:
# Grab the pretrained first-layer filters and show their shape
# (out_channels, in_channels, kH, kW) — three input channels before adaptation.
conv1_weights = model.conv1.weight
conv1_weights.size()

torch.Size([64, 3, 7, 7])

In [30]:
# Collapse the 3-channel filters into one channel by summing over the
# input-channel axis (dim=1); keepdim preserves the 4-D conv weight layout.
# Summing makes a 1-channel input respond like a grayscale image replicated
# across all three original channels.
summed_filters = conv1_weights.sum(dim=1, keepdim=True)
model.conv1.weight = nn.Parameter(summed_filters)
model.conv1.weight.shape

torch.Size([64, 1, 7, 7])

In [31]:
# Keep conv1's metadata consistent with its new single-channel weight.
GRAYSCALE_CHANNELS = 1
model.conv1.in_channels = GRAYSCALE_CHANNELS

In [32]:
# Freeze every parameter that exists at this point (including the rebuilt
# conv1); only modules created afterwards will be trainable.
for frozen_param in model.parameters():
    frozen_param.requires_grad_(False)

In [33]:
# Replace the classification head with a fresh 10-way linear layer. It is
# created after the freeze above, so it is the part that receives gradients.
NUM_CLASSES = 10
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

In [34]:
# Pick the first GPU when one is present, otherwise the CPU.
# `is_available` must be *called*: the bare function object is always truthy,
# which would select CUDA even on CPU-only machines. Both branches are wrapped
# in torch.device so `device` always has the same type.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [35]:
# Move the model's parameters and buffers to `device` (Module.to returns the
# module itself) and display it.
model = model.to(device)
model

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [36]:
# Log one training batch to TensorBoard: an image grid for a visual sanity
# check, and the model graph traced with that same batch.
# NOTE(review): the model is still in its default train mode here, so the
# tracing forward pass may update BatchNorm running statistics — confirm this
# is acceptable.
batch_iter = iter(train_loader)
images, labels = next(batch_iter)
images, labels = images.to(device), labels.to(device)
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, global_step=0)
writer.add_graph(model, images)

In [37]:
# Only parameters that still require gradients (the new `fc` head; the rest of
# the network was frozen above) are handed to Adam. This keeps the frozen
# backbone out of the optimizer and its state_dict.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001)
criterion = nn.CrossEntropyLoss()

In [38]:
# Number of passes over the training data, plus the dataset sizes used to turn
# correct-prediction counts into accuracies.
epochs = 10
num_train_data, num_valid_data = len(train_dataset), len(valid_dataset)

In [39]:
# Train for `epochs` epochs. Per epoch: log the mean training loss and the
# train/valid accuracy to TensorBoard, and checkpoint every CHECKPOINT_EVERY
# epochs and after the final one.
CHECKPOINT_PATH = 'checkpoint.pth'
CHECKPOINT_EVERY = 2

for epoch in range(epochs):
    # --- training ------------------------------------------------------------
    # NOTE(review): train() also lets the frozen BatchNorm layers keep updating
    # their running statistics; confirm that is intended for this fine-tune.
    model.train()
    correct_train = 0
    running_loss = 0.0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        yhat = model(x)
        loss = criterion(yhat, y)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by batch size so the
        # epoch average is exact even when the final batch is smaller.
        running_loss += loss.item() * x.size(0)
        correct_train += (yhat.argmax(dim=1) == y).sum().item()

    train_loss = running_loss / num_train_data
    train_acc = correct_train / num_train_data

    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)

    print(f'train_acc epoch: {epoch}: {train_acc}')

    # --- validation ----------------------------------------------------------
    model.eval()
    correct_valid = 0

    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for x_val, y_val in valid_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat_val = model(x_val)
            correct_valid += (yhat_val.argmax(dim=1) == y_val).sum().item()

    valid_acc = correct_valid / num_valid_data

    writer.add_scalar('Accuracy/valid', valid_acc, epoch)

    print(f'valid_acc epoch: {epoch}: {valid_acc}')

    # --- checkpoint ----------------------------------------------------------
    # Saved after this epoch's training, so the stored weights include epoch
    # `epoch`; resume from epoch + 1. The final epoch is always saved.
    if epoch % CHECKPOINT_EVERY == 0 or epoch == epochs - 1:
        checkpoint = {
            'model_state': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'epoch': epoch,
        }
        torch.save(checkpoint, CHECKPOINT_PATH)




train_acc epoch: 0: 0.9147500395774841
valid_acc epoch: 0: 0.9563999772071838
train_acc epoch: 1: 0.949999988079071
valid_acc epoch: 1: 0.9569000005722046
train_acc epoch: 2: 0.9561499953269958
valid_acc epoch: 2: 0.960599958896637
train_acc epoch: 3: 0.9586499929428101
valid_acc epoch: 3: 0.9650999903678894
train_acc epoch: 4: 0.9608166813850403
valid_acc epoch: 4: 0.9693999886512756
train_acc epoch: 5: 0.9612833261489868
valid_acc epoch: 5: 0.9668999910354614
train_acc epoch: 6: 0.9639166593551636
valid_acc epoch: 6: 0.9661999940872192
train_acc epoch: 7: 0.9637666940689087
valid_acc epoch: 7: 0.9690999984741211
train_acc epoch: 8: 0.9651333689689636
valid_acc epoch: 8: 0.9688999652862549
train_acc epoch: 9: 0.9665166735649109
valid_acc epoch: 9: 0.9678999781608582


In [40]:
# Flush any pending TensorBoard events and close the writer.
writer.close()