In [1]:
import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from training_data import CombinedData
from PIL import Image
from matplotlib import pyplot as plt

data_train = CombinedData('HASY')
data_test = CombinedData('HASY', train=False)

print("Train data length: {0}".format(len(data_train.data)))
print("Test data length: {0}".format(len(data_test.data)))
print("Img Shape: {0}".format(data_train.data[0].shape))
print("Number of Labels: {0}".format(data_train.no_labels))

100%|██████████| 151241/151241 [00:00<00:00, 234716.00it/s]
100%|██████████| 60000/60000 [00:08<00:00, 6923.07it/s]
100%|██████████| 60000/60000 [00:00<00:00, 359589.34it/s]
100%|██████████| 16992/16992 [00:00<00:00, 224735.95it/s]
100%|██████████| 10000/10000 [00:01<00:00, 6830.53it/s]
100%|██████████| 10000/10000 [00:00<00:00, 367399.31it/s]

Train data length: 61958
Test data length: 10224
Img Shape: torch.Size([1, 28, 28])
Number of Labels: 15





In [None]:
# Train with a pretrained model from PyTorch
from torchvision import models
from torch.nn import Conv2d

pretrained = False
torch_model = models.densenet201(num_classes=15)
torch_model.features.conv0 = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
if pretrained:
    torch_model.load_state_dict(torch.load('pretrained-model-01.ckpt'))
epochs = 10
optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001, betas=(0.8, 0.925), weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

def calc_accuracy(model):
    accuracies = []
    for idx, [x_test, y_test] in enumerate(tqdm(test_loader)):
        test_pred = model(x_test)
        accuracy = 100 * torch.mean((torch.argmax(test_pred, dim=1) == y_test).float())
        accuracies.append(accuracy)
    print("Accuracy: {0}".format(np.mean(accuracies)))  

train_loader = DataLoader(data_train, batch_size=16, shuffle=True)
test_loader = DataLoader(data_test, batch_size=16, shuffle=False)

for epoch in range(epochs):
        print("Epoch {0}".format(epoch))
        calc_accuracy(torch_model)
        for step, [x_train, y_train] in enumerate(tqdm(train_loader)):
            optimizer.zero_grad()
            train_pred = torch_model(x_train)
            loss = criterion(train_pred, y_train)
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print('Loss: {}'.format(loss))
torch.save(torch_model.state_dict(), 'combined-model.ckpt')

  0%|          | 1/639 [00:00<02:00,  5.30it/s]

Epoch 0


100%|██████████| 639/639 [02:51<00:00,  3.98it/s]
  0%|          | 0/3873 [00:00<?, ?it/s]

Accuracy: 5.3012518882751465


  0%|          | 1/3873 [00:01<1:12:55,  1.13s/it]

Loss: 2.8040692806243896


  3%|▎         | 101/3873 [01:42<59:05,  1.06it/s] 

Loss: 2.139547348022461


  5%|▌         | 201/3873 [03:26<1:04:34,  1.06s/it]

Loss: 0.854712724685669


  8%|▊         | 301/3873 [05:08<1:02:50,  1.06s/it]

Loss: 1.0882377624511719


 10%|█         | 401/3873 [06:51<58:44,  1.02s/it]  

Loss: 0.5661975741386414


 13%|█▎        | 501/3873 [08:33<57:13,  1.02s/it]  

Loss: 0.26401156187057495


 16%|█▌        | 601/3873 [10:15<56:58,  1.04s/it]

Loss: 0.31051433086395264


 18%|█▊        | 701/3873 [11:58<53:51,  1.02s/it]

Loss: 0.5047825574874878


 21%|██        | 801/3873 [13:40<51:53,  1.01s/it]

Loss: 0.17832981050014496


 23%|██▎       | 901/3873 [15:22<51:09,  1.03s/it]

Loss: 0.15005695819854736


 26%|██▌       | 1001/3873 [17:04<49:06,  1.03s/it]

Loss: 0.4251668453216553


 28%|██▊       | 1101/3873 [18:55<52:39,  1.14s/it]

Loss: 0.1380361020565033


 31%|███       | 1201/3873 [20:57<44:31,  1.00it/s]  

Loss: 0.3740825355052948


 34%|███▎      | 1301/3873 [22:36<48:41,  1.14s/it]

Loss: 0.1535593420267105


 36%|███▌      | 1401/3873 [24:08<36:25,  1.13it/s]

Loss: 0.3227876126766205


 39%|███▉      | 1501/3873 [26:10<1:00:30,  1.53s/it]

Loss: 0.23552048206329346


 41%|████▏     | 1601/3873 [28:26<33:04,  1.15it/s]  

Loss: 0.3329312205314636


 44%|████▍     | 1701/3873 [29:52<29:20,  1.23it/s]

Loss: 0.23281103372573853


 47%|████▋     | 1801/3873 [31:20<30:31,  1.13it/s]

Loss: 0.5235540866851807


 49%|████▉     | 1901/3873 [33:02<42:20,  1.29s/it]

Loss: 0.18608656525611877


 52%|█████▏    | 2001/3873 [34:51<28:35,  1.09it/s]

Loss: 0.33666473627090454


 54%|█████▍    | 2101/3873 [36:21<26:38,  1.11it/s]

Loss: 0.0869893878698349


 57%|█████▋    | 2201/3873 [38:01<26:42,  1.04it/s]

Loss: 0.12622679769992828


 59%|█████▉    | 2301/3873 [39:47<28:06,  1.07s/it]

Loss: 0.458736777305603


 62%|██████▏   | 2383/3873 [41:08<24:45,  1.00it/s]

In [None]:
print(data_train.data)