In [None]:
# Create MNIST data arrays
%run ./generate_mnist_dataset.ipynb

In [1]:
# Load and prepare the data

import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from training_data import DataCollection
from PIL import Image
from matplotlib import pyplot as plt

def print_data_infos(data_train, data_test):
    print("Train data length: {0}".format(len(data_train.data)))
    print("Test data length: {0}".format(len(data_test.data)))
    print("Img Shape: {0}".format(data_train.data[0].shape))
    print("Number of Labels: {0}".format(data_train.no_labels))
    
data_all_train = DataCollection()
data_all_test = DataCollection(train=False)

data_ops_train = DataCollection(use_hasy=False, use_mnist=False, own_path='plus-min-div')
data_ops_test = DataCollection(use_hasy=False, use_mnist=False, own_path='plus-min-div', train=False)

data_brckts_train = DataCollection(use_hasy=False, use_mnist=False, own_path='plus-brckts')
data_brckts_test = DataCollection(use_hasy=False, use_mnist=False, own_path='plus-brckts', train=False)

print_data_infos(data_all_train, data_all_test)
print_data_infos(data_ops_train, data_ops_test)
print_data_infos(data_brckts_train, data_brckts_test)


100%|██████████| 151241/151241 [00:00<00:00, 813804.36it/s]
100%|██████████| 60000/60000 [00:06<00:00, 9041.95it/s] 
100%|██████████| 60000/60000 [00:00<00:00, 368102.23it/s]
100%|██████████| 16992/16992 [00:00<00:00, 525432.13it/s]
 10%|▉         | 999/10000 [00:00<00:00, 9983.06it/s]

No training data for ). Skipping


100%|██████████| 10000/10000 [00:01<00:00, 8790.98it/s]
100%|██████████| 10000/10000 [00:00<00:00, 292638.79it/s]


No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for brckts. Skipping
No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for brckts. Skipping
No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping


In [2]:
# Declare the network and some utilities

from torchvision import models
from torch.nn import Conv2d


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train(train_loader, test_loader, model_name, print_step, num_classes=15, epochs=5):
    model = models.alexnet(num_classes=num_classes)
    model.features[0] = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    if torch.cuda.is_available():
        model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.8, 0.99), weight_decay=0.001)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        print("Epoch {0}".format(epoch))
        for step, [x_train, y_train] in enumerate(tqdm(train_loader)):
            if torch.cuda.is_available():
                 x_train, y_train = x_train.to(device), y_train.to(device)
            optimizer.zero_grad()
            train_pred = model(x_train)
            loss = criterion(train_pred, y_train)
            loss.backward()
            optimizer.step()
            if step % print_step == 0:
                print('Loss: {}'.format(loss))
        
        acc = calc_accuracy(model, test_loader)
        print("Accuracy: {0}".format(acc))
        if acc > 98:
            torch.save(model.state_dict(), '{0}-{1}.ckpt'.format(model_name,acc))
    print("Accuracy: {0}".format(acc))
    torch.save(model.state_dict(), '{0}.ckpt'.format(model_name))

def calc_accuracy(model, test_loader):
    accuracies = []
    for idx, [x_test, y_test] in enumerate(tqdm(test_loader)):
        if torch.cuda.is_available():
            x_test, y_test = x_test.to(device), y_test.to(device)
        test_pred = model(x_test)
        accuracy = 100 * torch.mean((torch.argmax(test_pred, dim=1) == y_test).float())
        accuracies.append(accuracy.item() if torch.cuda.is_available() else accuracy)
    return np.mean(accuracies)  

train_all_loader = DataLoader(data_all_train, batch_size=16, shuffle=True)
test_all_loader = DataLoader(data_all_test, batch_size=16, shuffle=False)

train_ops_loader = DataLoader(data_ops_train, batch_size=16, shuffle=True)
test_ops_loader = DataLoader(data_ops_test, batch_size=16, shuffle=False)

train_brckts_loader = DataLoader(data_brckts_train, batch_size=16, shuffle=True)
test_brckts_loader = DataLoader(data_brckts_test, batch_size=16, shuffle=False)



In [None]:
train(train_all_loader, test_all_loader, 'model-all-symbols', 500)
train(train_ops_loader, test_ops_loader, 'model-plus-minus-div', 60)
train(train_brckts_loader, test_brckts_loader, 'model-plus-brackets', 60)

In [3]:
train_no_strokes = DataCollection(own_path='digits-plus-brackets', no_strokes=True)
test_no_strokes = DataCollection(own_path='digits-plus-brackets', train=False, no_strokes=True)
print_data_infos(train_no_strokes, test_no_strokes)

train_no_strokes_loader = DataLoader(train_no_strokes, batch_size=16, shuffle=True)
test_no_strokes_loader = DataLoader(test_no_strokes, batch_size=16, shuffle=False)

train(train_no_strokes_loader, test_no_strokes_loader, 'model-no_strokes', 500, num_classes=13)

100%|██████████| 151241/151241 [00:00<00:00, 799802.72it/s]
100%|██████████| 60000/60000 [00:06<00:00, 9473.36it/s] 
100%|██████████| 60000/60000 [00:00<00:00, 342260.38it/s]


No training data for -. Skipping
No training data for div. Skipping


100%|██████████| 16992/16992 [00:00<00:00, 765812.92it/s]
 10%|▉         | 983/10000 [00:00<00:00, 9826.14it/s]

No training data for ). Skipping


100%|██████████| 10000/10000 [00:01<00:00, 9435.66it/s]
100%|██████████| 10000/10000 [00:00<00:00, 227711.22it/s]


No training data for -. Skipping
No training data for div. Skipping
No training data for ). Skipping
Train data length: 70388
Test data length: 13912
Img Shape: torch.Size([1, 32, 32])
Number of Labels: 13


  0%|          | 0/4400 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/4400 [00:00<52:47,  1.39it/s]

Loss: 2.561375379562378


 11%|█▏        | 501/4400 [05:19<36:42,  1.77it/s]  

Loss: 0.7594608664512634


 23%|██▎       | 1001/4400 [14:59<1:10:39,  1.25s/it]

Loss: 0.16282838582992554


 34%|███▍      | 1501/4400 [23:01<42:38,  1.13it/s]  

Loss: 0.027018070220947266


 45%|████▌     | 2001/4400 [30:16<32:53,  1.22it/s]

Loss: 0.07095861434936523


 57%|█████▋    | 2501/4400 [38:38<30:09,  1.05it/s]

Loss: 0.05632273107767105


 68%|██████▊   | 3001/4400 [46:15<21:58,  1.06it/s]

Loss: 0.6697338819503784


 80%|███████▉  | 3501/4400 [54:03<12:42,  1.18it/s]

Loss: 0.060564074665308


 91%|█████████ | 4001/4400 [1:00:56<05:33,  1.20it/s]

Loss: 0.00879921019077301


100%|██████████| 4400/4400 [1:06:24<00:00,  1.26it/s]
100%|██████████| 870/870 [00:52<00:00, 16.81it/s]
  0%|          | 0/4400 [00:00<?, ?it/s]

Accuracy: 93.57040405273438
Epoch 1


  0%|          | 1/4400 [00:00<1:03:03,  1.16it/s]

Loss: 0.23687896132469177


 11%|█▏        | 501/4400 [06:51<53:32,  1.21it/s] 

Loss: 0.10531184822320938


 23%|██▎       | 1001/4400 [13:40<46:23,  1.22it/s]

Loss: 0.008107677102088928


 34%|███▍      | 1501/4400 [20:23<37:21,  1.29it/s]

Loss: 0.08047392219305038


 45%|████▌     | 2001/4400 [26:59<32:03,  1.25it/s]

Loss: 0.29267072677612305


 57%|█████▋    | 2501/4400 [33:46<25:08,  1.26it/s]

Loss: 0.09119364619255066


 68%|██████▊   | 3001/4400 [40:19<18:13,  1.28it/s]

Loss: 0.14240600168704987


 80%|███████▉  | 3501/4400 [46:49<11:28,  1.31it/s]

Loss: 0.09061107039451599


 91%|█████████ | 4001/4400 [53:23<05:19,  1.25it/s]

Loss: 0.002049356698989868


100%|██████████| 4400/4400 [58:46<00:00,  1.25it/s]
100%|██████████| 870/870 [00:56<00:00, 16.36it/s]
  0%|          | 0/4400 [00:00<?, ?it/s]

Accuracy: 89.712646484375
Epoch 2


  0%|          | 1/4400 [00:00<1:03:18,  1.16it/s]

Loss: 0.006799433380365372


 11%|█▏        | 501/4400 [06:40<51:29,  1.26it/s] 

Loss: 0.06629476696252823


 23%|██▎       | 1001/4400 [13:16<44:30,  1.27it/s]

Loss: 0.14119109511375427


 34%|███▍      | 1501/4400 [19:52<37:37,  1.28it/s]

Loss: 0.004509627819061279


 45%|████▌     | 2001/4400 [26:27<31:58,  1.25it/s]

Loss: 0.3568568527698517


 57%|█████▋    | 2501/4400 [32:58<24:14,  1.31it/s]

Loss: 0.11677229404449463


 68%|██████▊   | 3001/4400 [39:24<18:12,  1.28it/s]

Loss: 0.062406666576862335


 80%|███████▉  | 3501/4400 [46:00<12:07,  1.24it/s]

Loss: 0.2244892120361328


 91%|█████████ | 4001/4400 [52:28<05:03,  1.32it/s]

Loss: 0.0011208504438400269


100%|██████████| 4400/4400 [57:35<00:00,  1.33it/s]
100%|██████████| 870/870 [00:49<00:00, 17.46it/s]
  0%|          | 0/4400 [00:00<?, ?it/s]

Accuracy: 93.72126770019531
Epoch 3


  0%|          | 1/4400 [00:00<57:48,  1.27it/s]

Loss: 0.29021593928337097


 11%|█▏        | 501/4400 [06:27<50:49,  1.28it/s]

Loss: 0.03896717727184296


 23%|██▎       | 1001/4400 [13:06<45:16,  1.25it/s]

Loss: 0.3410634994506836


 34%|███▍      | 1501/4400 [19:42<38:07,  1.27it/s]

Loss: 0.14641141891479492


 45%|████▌     | 2001/4400 [26:08<30:38,  1.30it/s]

Loss: 0.10902576893568039


 57%|█████▋    | 2501/4400 [32:39<24:40,  1.28it/s]

Loss: 0.08757573366165161


 68%|██████▊   | 3001/4400 [39:17<18:31,  1.26it/s]

Loss: 0.09684185683727264


 80%|███████▉  | 3501/4400 [45:36<11:05,  1.35it/s]

Loss: 0.09042443335056305


 91%|█████████ | 4001/4400 [51:52<05:01,  1.32it/s]

Loss: 0.16747312247753143


100%|██████████| 4400/4400 [57:00<00:00,  1.32it/s]
100%|██████████| 870/870 [00:49<00:00, 17.50it/s]
  0%|          | 0/4400 [00:00<?, ?it/s]

Accuracy: 94.21695709228516
Epoch 4


  0%|          | 1/4400 [00:00<58:59,  1.24it/s]

Loss: 0.0007902234792709351


 11%|█▏        | 501/4400 [06:23<49:39,  1.31it/s]

Loss: 0.025169074535369873


 23%|██▎       | 1001/4400 [12:55<45:07,  1.26it/s]

Loss: 0.01474711298942566


 34%|███▍      | 1501/4400 [19:27<37:14,  1.30it/s]

Loss: 0.013735935091972351


 45%|████▌     | 2001/4400 [25:43<30:20,  1.32it/s]

Loss: 0.12396945804357529


 57%|█████▋    | 2501/4400 [32:07<24:12,  1.31it/s]

Loss: 0.25853127241134644


 68%|██████▊   | 3001/4400 [38:31<17:56,  1.30it/s]

Loss: 0.008723840117454529


 80%|███████▉  | 3501/4400 [45:11<12:15,  1.22it/s]

Loss: 0.18409977853298187


 91%|█████████ | 4001/4400 [51:45<05:13,  1.27it/s]

Loss: 0.05311565473675728


100%|██████████| 4400/4400 [56:52<00:00,  1.34it/s]
100%|██████████| 870/870 [00:48<00:00, 17.93it/s]


Accuracy: 92.67241668701172
Accuracy: 92.67241668701172
