In [1]:
import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from training_data import CombinedData
from PIL import Image
from matplotlib import pyplot as plt

data_train = CombinedData('HASY')
data_test = CombinedData('HASY', train=False)

print("Train data length: {0}".format(len(data_train.data)))
print("Test data length: {0}".format(len(data_test.data)))
print("Img Shape: {0}".format(data_train.data[0].shape))
print("Number of Labels: {0}".format(data_train.no_labels))

100%|██████████| 151241/151241 [00:02<00:00, 74350.53it/s]
100%|██████████| 60000/60000 [00:14<00:00, 4162.49it/s]
100%|██████████| 60000/60000 [00:00<00:00, 210981.62it/s]
100%|██████████| 16992/16992 [00:00<00:00, 79507.73it/s] 
100%|██████████| 10000/10000 [00:02<00:00, 4005.66it/s]
100%|██████████| 10000/10000 [00:00<00:00, 233256.62it/s]

Train data length: 65690
Test data length: 10644
Img Shape: torch.Size([1, 28, 28])
Number of Labels: 15





In [None]:
# Train with a pretrained model from PyTorch
from torchvision import models
from torch.nn import Conv2d

pretrained = False
torch_model = models.densenet201(num_classes=15)
torch_model.features.conv0 = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
if pretrained:
    torch_model.load_state_dict(torch.load('pretrained-model-01.ckpt'))
epochs = 10
optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001, betas=(0.8, 0.925), weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

def calc_accuracy(model):
    accuracies = []
    for idx, [x_test, y_test] in enumerate(tqdm(test_loader)):
        test_pred = model(x_test)
        accuracy = 100 * torch.mean((torch.argmax(test_pred, dim=1) == y_test).float())
        accuracies.append(accuracy)
    return np.mean(accuracies)  

train_loader = DataLoader(data_train, batch_size=16, shuffle=True)
test_loader = DataLoader(data_test, batch_size=16, shuffle=False)

for epoch in range(epochs):
    print("Epoch {0}".format(epoch))
    acc = calc_accuracy(torch_model)
    print("Accuracy: {0}".format(acc))
    if acc > 95:
        torch.save(torch_model.state_dict(), 'combined-model-{0}.ckpt'.format(acc))
    for step, [x_train, y_train] in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        train_pred = torch_model(x_train)
        loss = criterion(train_pred, y_train)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print('Loss: {}'.format(loss))
torch.save(torch_model.state_dict(), 'combined-model.ckpt')

  0%|          | 0/666 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 666/666 [04:27<00:00,  2.72it/s]
  0%|          | 0/4106 [00:00<?, ?it/s]

Accuracy: 7.272897720336914


  0%|          | 1/4106 [00:01<1:40:32,  1.47s/it]

Loss: 2.765979528427124


  2%|▏         | 101/4106 [02:36<1:40:30,  1.51s/it]

Loss: 1.2769800424575806


  5%|▍         | 201/4106 [05:16<2:34:13,  2.37s/it]

Loss: 0.8478285074234009


  7%|▋         | 301/4106 [08:05<1:39:03,  1.56s/it]

Loss: 0.4498547613620758


 10%|▉         | 401/4106 [10:59<2:12:40,  2.15s/it]

Loss: 0.7750710248947144


 10%|▉         | 404/4106 [11:05<2:10:03,  2.11s/it]

In [None]:
torch.save(torch_model.state_dict(), 'combined-model.ckpt')
calc_accuracy(torch_model)