In [1]:
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

In [2]:
# TensorBoard event writer for this run; events are written under 'runs/first'.
writer = SummaryWriter(log_dir='runs/first')

In [3]:
# Fashion-MNIST splits: train=True selects the training set, train=False the
# held-out set used for validation below. Both are fetched with download=True
# into root='' and use a ToTensor transform.
train_data = datasets.FashionMNIST(
    root='',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
valid_data = datasets.FashionMNIST(
    root='',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [4]:
# Mini-batch iterators over the two splits (32 samples per batch).
# Training batches are reshuffled every epoch.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
# Validation only counts correct predictions, so sample order is irrelevant;
# keep it fixed (shuffle=False) so evaluation is deterministic and no time is
# spent permuting indices.
valid_loader = DataLoader(dataset=valid_data, batch_size=32, shuffle=False)

In [5]:
# Build a ResNet-18 with pretrained=True.
model = models.resnet18(pretrained=True)

# Disable gradient tracking on every parameter the network has at this point.
for param in model.parameters():
    param.requires_grad_(False)

# Last expression of the cell: display the architecture.
model



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
# Replace the stem convolution with one that takes a single input channel, and
# the final fully connected layer with a 10-output one. Both layers are newly
# constructed here, so their parameters are not touched by the earlier loop
# that set requires_grad to False.
model.conv1 = nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
model.fc = nn.Linear(512, 10)

In [7]:
# Optimisation setup: Adam over all model parameters with a starting learning
# rate of 0.05, a StepLR schedule that multiplies the rate by 0.1 every
# 10 scheduler steps, and cross-entropy as the training criterion.
optimizer = optim.Adam(params=model.parameters(), lr=0.05)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

In [8]:
# Number of passes over the training set.
epochs = 50
# Dataset sizes, used as denominators when computing accuracy.
num_train_data, num_valid_data = len(train_data), len(valid_data)

In [9]:
# Pick the first CUDA device when one is available, otherwise the CPU.
# The conditional is evaluated inside torch.device(...) so `device` is a
# torch.device on both paths (previously it was a plain str on the CPU path).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model to the chosen device; as the cell's last expression this also
# displays the module.
model.to(device)

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
# Pull one batch from the training loader and move both tensors to `device`.
images, labels = (tensor.to(device) for tensor in next(iter(train_loader)))

# Log the batch as a single image grid (step 0), then record the model graph
# using the same batch as the example input.
grid = torchvision.utils.make_grid(images)
writer.add_image(tag='images', img_tensor=grid, global_step=0)
writer.add_graph(model, input_to_model=images)

In [11]:
# Train for `epochs` epochs. Each epoch: optionally checkpoint, run one pass
# over train_loader with parameter updates, evaluate on valid_loader, log
# loss / accuracy / learning rate to TensorBoard, then advance the LR schedule.
for epoch in range(epochs):
    model.train()

    # Checkpoint every second epoch. The save is inside the `if` so the file is
    # only rewritten with a freshly built dict; previously it ran every epoch
    # and re-saved a dict whose 'epoch' and optimizer state were one epoch old.
    # The scheduler state is stored as an extra key so the LR schedule can be
    # restored on resume. The file name is kept exactly as before so existing
    # loaders still find it.
    if epoch % 2 == 0:
        checkpoint = {
            'model_state': model.state_dict(),
            'optim_state': optimizer.state_dict(),
            'sched_state': scheduler.state_dict(),
            'epoch': epoch,
        }
        torch.save(checkpoint, 'chechpoint.pth')

    correct_train = 0
    running_loss = 0.0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        yhat = model(x)
        loss = criterion(yhat, y)
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest logit.
        train_label = yhat.argmax(dim=1)
        correct_train += (train_label == y).sum().item()
        # criterion returns the batch mean; weight by batch size so the epoch
        # average stays correct when the last batch is smaller.
        running_loss += loss.item() * x.size(0)

    train_acc = correct_train / num_train_data
    train_loss = running_loss / num_train_data

    # Log the epoch-mean loss rather than whatever the last mini-batch produced.
    writer.add_scalar('Loss/train', train_loss, epoch)
    writer.add_scalar('Accuracy/train', train_acc, epoch)

    print(f'train_acc epoch: {epoch}: {train_acc}')

    model.eval()
    correct_valid = 0

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for x_val, y_val in valid_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat_val = model(x_val)
            yhat_label = yhat_val.argmax(dim=1)

            correct_valid += (yhat_label == y_val).sum().item()

    valid_acc = correct_valid / num_valid_data

    writer.add_scalar('Accuracy/valid', valid_acc, epoch)

    print(f'valid_acc epoch: {epoch}: {valid_acc}')

    # Record the learning rate used during this epoch before stepping the schedule.
    writer.add_scalar('LR', optimizer.param_groups[0]["lr"], epoch)
    print(f'LR epoch: {epoch}: {optimizer.param_groups[0]["lr"]}')
    scheduler.step()


train_acc epoch: 0: 0.5878000259399414
valid_acc epoch: 0: 0.6797999739646912
LR epoch: 0: 0.05
train_acc epoch: 1: 0.633233368396759
valid_acc epoch: 1: 0.6897000074386597
LR epoch: 1: 0.05
train_acc epoch: 2: 0.6489666700363159
valid_acc epoch: 2: 0.6940000057220459
LR epoch: 2: 0.05
train_acc epoch: 3: 0.65420001745224
valid_acc epoch: 3: 0.6692000031471252
LR epoch: 3: 0.05
train_acc epoch: 4: 0.661300003528595
valid_acc epoch: 4: 0.7084999680519104
LR epoch: 4: 0.05
train_acc epoch: 5: 0.6640166640281677
valid_acc epoch: 5: 0.6879000067710876
LR epoch: 5: 0.05
train_acc epoch: 6: 0.6674166917800903
valid_acc epoch: 6: 0.6807000041007996
LR epoch: 6: 0.05
train_acc epoch: 7: 0.6715666651725769
valid_acc epoch: 7: 0.6847000122070312
LR epoch: 7: 0.05
train_acc epoch: 8: 0.6741666793823242
valid_acc epoch: 8: 0.7106999754905701
LR epoch: 8: 0.05
train_acc epoch: 9: 0.6789166927337646
valid_acc epoch: 9: 0.7128999829292297
LR epoch: 9: 0.05
train_acc epoch: 10: 0.7228666543960571
vali

In [12]:
writer.close()