In [1]:
import torch
from torch import nn, optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms


In [2]:
# TensorBoard event writer; every scalar, image and graph logged below goes to runs/first.
writer = SummaryWriter(log_dir='runs/first')

In [3]:
# MNIST training split, downloaded into the current directory if it is not
# already there; each sample is converted to a tensor by ToTensor().
train_dataset = datasets.MNIST(
    root='.',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST held-out split (train=False), used as the validation set in this notebook.
valid_dataset = datasets.MNIST(
    root='.',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [4]:
# Mini-batch loaders (32 samples per batch).
# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
# Validation accuracy does not depend on sample order, so the validation loader
# is not shuffled: this avoids a needless random permutation per epoch and keeps
# the batch order reproducible.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=32, shuffle=False)

In [5]:
# ResNet-18 backbone loaded with torchvision's pretrained weights.
model = models.resnet18(
    pretrained=True,
)
# Bare last expression so the notebook prints the layer-by-layer architecture.
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
# Freeze every parameter currently in the network so backprop leaves the
# pretrained weights untouched. Layers assigned to the model after this cell
# are new modules and are not affected by this loop.
for param in model.parameters():
    param.requires_grad_(False)

In [7]:
# Swap the 3-channel stem for a 1-channel one (same 7x7 / stride 2 / padding 3
# geometry as the layer it replaces).
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# Swap the classifier head for a 10-way linear layer on the 512-d features.
model.fc = nn.Linear(512, 10)

In [8]:
# Select the compute device. `torch.cuda.is_available` must be *called*: the
# bare function object is always truthy, which would select 'cuda:0' even on a
# CPU-only machine and make the later `.to(device)` calls fail.
# Both branches now yield a torch.device for consistency.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [9]:
# Move the model onto the selected device. As the last expression in the cell,
# the module returned by .to() is displayed as the cell output.
model.to(device)

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
# Take one batch from the training loader and place it on the model's device.
batch_iter = iter(train_loader)
images, labels = next(batch_iter)
images, labels = images.to(device), labels.to(device)

# Log the batch as a single image grid (step 0) and trace the model on the
# same batch to record its graph in TensorBoard.
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, global_step=0)
writer.add_graph(model, images)

In [11]:
# Multi-class classification loss applied to the model's raw outputs.
criterion = nn.CrossEntropyLoss()
# Adam over the model's parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [12]:
# Number of full passes over the training set.
epochs = 10
# Dataset sizes, used as denominators when computing per-epoch accuracy.
num_train_data, num_valid_data = len(train_dataset), len(valid_dataset)

In [13]:
# Train for `epochs` passes over the training set, validating after each pass.
# Per-epoch loss and accuracy are written to TensorBoard and accuracy is printed.
for epoch in range(epochs):
    # Checkpoint the state at the start of every second epoch. The save is
    # inside the `if` so the file is only written together with a freshly built
    # dict; previously it ran every epoch and, on odd epochs, wrote the current
    # weights under the previous epoch's 'epoch' label.
    # NOTE(review): the file name spelling is kept as-is so that any existing
    # code loading 'chechpoint.pth' keeps working.
    if epoch % 2 == 0:
        checkpoint = {
            'model_state' : model.state_dict(),
            'optim_state' : optimizer.state_dict(),
            'epoch' : epoch
        }
        torch.save(checkpoint, 'chechpoint.pth')

    # ---- training pass ----
    model.train()
    correct_train = 0    # number of correctly classified training samples
    running_loss = 0.0   # sum of per-sample losses over the epoch

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        yhat = model(x)
        loss = criterion(yhat, y)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest output per sample.
        train_label = yhat.argmax(dim=1)
        # .item() keeps the counters as plain Python numbers rather than
        # device tensors that would be carried across iterations.
        correct_train += (train_label == y).sum().item()
        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays correct when the last batch is smaller.
        running_loss += loss.item() * x.size(0)

    train_acc = correct_train / num_train_data
    train_loss = running_loss / num_train_data

    # Log the epoch-average loss (not just the final batch's loss).
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)

    print(f'train_acc epoch: {epoch}: {train_acc}')

    # ---- validation pass ----
    model.eval()
    correct_valid = 0

    # No gradients are needed for evaluation; disabling autograd avoids
    # building the graph and saves memory and time.
    with torch.no_grad():
        for x_val, y_val in valid_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat_val = model(x_val)
            yhat_label = yhat_val.argmax(dim=1)

            correct_valid += (yhat_label == y_val).sum().item()

    valid_acc = correct_valid / num_valid_data

    writer.add_scalar('Accuracy/valid', valid_acc, epoch)

    print(f'valid_acc epoch: {epoch}: {valid_acc}')




train_acc epoch: 0: 0.720466673374176
valid_acc epoch: 0: 0.8622999787330627
train_acc epoch: 1: 0.8399333357810974
valid_acc epoch: 1: 0.8930999636650085
train_acc epoch: 2: 0.8578667044639587
valid_acc epoch: 2: 0.9104999899864197
train_acc epoch: 3: 0.8698000311851501
valid_acc epoch: 3: 0.9062999486923218
train_acc epoch: 4: 0.8784833550453186
valid_acc epoch: 4: 0.917199969291687
train_acc epoch: 5: 0.8771666884422302
valid_acc epoch: 5: 0.9061999917030334
train_acc epoch: 6: 0.885033369064331
valid_acc epoch: 6: 0.9161999821662903
train_acc epoch: 7: 0.8849000334739685
valid_acc epoch: 7: 0.9103999733924866
train_acc epoch: 8: 0.8891000151634216
valid_acc epoch: 8: 0.9188999533653259
train_acc epoch: 9: 0.8909833431243896
valid_acc epoch: 9: 0.9185000061988831


In [14]:
# Close the TensorBoard writer now that all logging for this run is done.
writer.close()