In [1]:
import os
from config import *
from tqdm import tqdm
from utils.dataset import DrawingsDataset

import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def resnet50():
    model = models.resnet50(pretrained=False)
    
    conv1_out_channels = model.conv1.out_channels
    model.conv1 = nn.Conv2d(1, conv1_out_channels, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.MaxPool2d(kernel_size=2)
    fc_features = model.fc.in_features
    model.fc = nn.Linear(fc_features, len(CLASSES))
    
    return model

In [3]:
train_data = DrawingsDataset(mtype="train")
train_loader = torch.utils.data.DataLoader(train_data, batch_size=MODEL_CFG['batch_size'], shuffle=True)

test_data = DrawingsDataset(mtype="test")
test_loader = torch.utils.data.DataLoader(test_data, batch_size=MODEL_CFG['batch_size'], shuffle=True)

print("Train images: %d" % len(train_data))
print("Test images: %d" % len(test_data))

net = resnet50().cuda()
optimizer = torch.optim.SGD(net.parameters(), 0.1, momentum=MODEL_CFG['momentum'],
                            weight_decay=MODEL_CFG['weight_decay'])

train_accuracy = 0.0
test_accuracy = 0.0
best_accuracy = 0.0
    
def train():
    net.train()
    loss_avg = 0.0
    correct = 0
    
    data_loader = tqdm(train_loader, desc='Training')
    for batch_idx, (drawings, labels) in enumerate(data_loader):
        drawings, labels = torch.autograd.Variable(drawings.cuda()), torch.autograd.Variable(labels.cuda())
        drawings = drawings.view(-1, 1, IMAGE_SIZE, IMAGE_SIZE)
        drawings /= 255.0

        # forward
        output = net(drawings)

        # backward
        optimizer.zero_grad()
        loss = F.cross_entropy(output, labels)
        loss.backward()
        optimizer.step()

        # accuracy
        pred = output.data.max(1)[1]
        correct += float(pred.eq(labels.data).sum())

        # exp moving average
        loss_avg = loss_avg*0.2+float(loss)*0.8

    global train_accuracy
    train_accuracy = correct/len(train_loader.dataset)

def test():
    net.eval()
    loss_avg = 0.0
    correct = 0
    
    data_loader = tqdm(test_loader, desc='Testing')
    for batch_idx, (drawings, labels) in enumerate(data_loader):
        drawings, labels = torch.autograd.Variable(drawings.cuda()), torch.autograd.Variable(labels.cuda())
        
        drawings = drawings.view(-1, 1, IMAGE_SIZE, IMAGE_SIZE)
        drawings /= 255.0

        # forward
        output = net(drawings)
        loss = F.cross_entropy(output, labels)

        # accuracy
        pred = output.data.max(1)[1]
        correct += float(pred.eq(labels.data).sum())

        # test loss average
        loss_avg += float(loss)

    print(f'test loss: {loss_avg/len(test_loader)}')
    
    global test_accuracy
    test_accuracy = correct/len(test_loader.dataset)
    
for epoch in range(MODEL_CFG['epochs']):
    print("epoch: "+str(epoch+1))
    
    if epoch+1 in MODEL_CFG['lr_decay_step']:
        MODEL_CFG['learning_rate'] *= MODEL_CFG['gamma']
        for param_group in optimizer.param_groups:
            param_group['lr'] = MODEL_CFG['learning_rate']
            
    train()
    test()
    
    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        torch.save(net.state_dict(), os.path.join(MODELS_DIR, 'model.pth'))
        print("Best accuracy: %.4f" % best_accuracy)

Train images: 40000
Test images: 10000


Training:   0%|          | 0/313 [00:00<?, ?it/s]

epoch: 1


Training: 100%|██████████| 313/313 [01:40<00:00,  3.11it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.63it/s]


test loss: 1.9616224795957156


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.2312
epoch: 2


Training: 100%|██████████| 313/313 [01:40<00:00,  3.11it/s]
Testing: 100%|██████████| 79/79 [00:07<00:00, 11.24it/s]


test loss: 0.8274026437650753


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.7391
epoch: 3


Training: 100%|██████████| 313/313 [01:40<00:00,  3.11it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.73it/s]


test loss: 0.5046511835689786


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.8448
epoch: 4


Training: 100%|██████████| 313/313 [01:40<00:00,  3.12it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.63it/s]


test loss: 0.3468429438675506


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.8924
epoch: 5


Training: 100%|██████████| 313/313 [01:40<00:00,  3.11it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.60it/s]


test loss: 0.3169728123311755


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.9026
epoch: 6


Training: 100%|██████████| 313/313 [01:40<00:00,  3.11it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.62it/s]


test loss: 0.26243886691105517


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.9229
epoch: 7


Training: 100%|██████████| 313/313 [01:41<00:00,  3.08it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.64it/s]


test loss: 0.2190104786165153


Training:   0%|          | 0/313 [00:00<?, ?it/s]

Best accuracy: 0.9355
epoch: 8


Training: 100%|██████████| 313/313 [01:42<00:00,  3.06it/s]
Testing: 100%|██████████| 79/79 [00:07<00:00, 11.28it/s]
Training:   0%|          | 0/313 [00:00<?, ?it/s]

test loss: 0.3324043126423148
epoch: 9


Training: 100%|██████████| 313/313 [01:42<00:00,  3.06it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.63it/s]
Training:   0%|          | 0/313 [00:00<?, ?it/s]

test loss: 0.24467892416670353
epoch: 10


Training: 100%|██████████| 313/313 [01:42<00:00,  3.05it/s]
Testing: 100%|██████████| 79/79 [00:06<00:00, 11.64it/s]

test loss: 0.2752262195454368



