In [None]:
# Create MNIST data arrays by running the generate_mnist_dataset notebook.
# NOTE(review): presumably this has to run once before the data-loading cell
# can find the MNIST arrays -- confirm against generate_mnist_dataset.ipynb.
%run ./generate_mnist_dataset.ipynb

In [1]:
# Load and prepare the data

import torch
import os
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from training_data import DataCollection
from PIL import Image
from matplotlib import pyplot as plt

def print_data_infos(data_train, data_test):
    """Print a four-line summary of a train/test dataset pair.

    Reports the number of entries in ``data_train.data`` and
    ``data_test.data``, the ``shape`` of the first training entry, and
    ``data_train.no_labels``.

    Args:
        data_train: object exposing a sized, indexable ``data`` attribute
            (whose items have a ``shape``) and a ``no_labels`` attribute.
        data_test: object exposing a sized ``data`` attribute.
    """
    train_samples = data_train.data
    print(f"Train data length: {len(train_samples)}")
    print(f"Test data length: {len(data_test.data)}")
    # Only the first training sample is inspected for the shape.
    print(f"Img Shape: {train_samples[0].shape}")
    print(f"Number of Labels: {data_train.no_labels}")
    
# Default collection (no source restrictions passed).
data_all_train = DataCollection()
data_all_test = DataCollection(train=False)

# Collection read from the 'plus-min-div' folder with HASY and MNIST disabled.
_ops_source = dict(use_hasy=False, use_mnist=False, own_path='plus-min-div')
data_ops_train = DataCollection(**_ops_source)
data_ops_test = DataCollection(train=False, **_ops_source)

# Collection read from the 'plus-brckts' folder with HASY and MNIST disabled.
_brckts_source = dict(use_hasy=False, use_mnist=False, own_path='plus-brckts')
data_brckts_train = DataCollection(**_brckts_source)
data_brckts_test = DataCollection(train=False, **_brckts_source)

# Summarise every train/test pair in the order they were built.
for _train_set, _test_set in (
    (data_all_train, data_all_test),
    (data_ops_train, data_ops_test),
    (data_brckts_train, data_brckts_test),
):
    print_data_infos(_train_set, _test_set)


100%|██████████| 151241/151241 [00:00<00:00, 813804.36it/s]
100%|██████████| 60000/60000 [00:06<00:00, 9041.95it/s] 
100%|██████████| 60000/60000 [00:00<00:00, 368102.23it/s]
100%|██████████| 16992/16992 [00:00<00:00, 525432.13it/s]
 10%|▉         | 999/10000 [00:00<00:00, 9983.06it/s]

No training data for ). Skipping


100%|██████████| 10000/10000 [00:01<00:00, 8790.98it/s]
100%|██████████| 10000/10000 [00:00<00:00, 292638.79it/s]


No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for brckts. Skipping
No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping
No training data for 5. Skipping
No training data for 6. Skipping
No training data for 7. Skipping
No training data for 8. Skipping
No training data for 9. Skipping
No training data for brckts. Skipping
No training data for ). Skipping
No training data for 0. Skipping
No training data for 1. Skipping
No training data for 2. Skipping
No training data for 3. Skipping
No training data for 4. Skipping


In [2]:
# Declare the network and some utilities

from torchvision import models
from torch.nn import Conv2d


# Use the first CUDA device when one is available, otherwise the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

def train(train_loader, test_loader, model_name, print_step, num_classes=15, epochs=5):
    """Train an AlexNet variant and write its weights to ``<model_name>.ckpt``.

    A torchvision AlexNet is created with ``num_classes`` outputs and its
    first feature layer is replaced by a ``Conv2d`` taking one input channel.
    After every epoch the model is evaluated on ``test_loader``; whenever the
    accuracy exceeds 98 an additional checkpoint named
    ``<model_name>-<accuracy>.ckpt`` is saved.

    Args:
        train_loader: iterable yielding ``(inputs, targets)`` training batches.
        test_loader: iterable yielding ``(inputs, targets)`` evaluation batches.
        model_name: file-name stem used for the saved checkpoints.
        print_step: the current batch loss is printed every ``print_step`` steps.
        num_classes: number of output classes of the network.
        epochs: number of passes over ``train_loader``; must be at least 1.

    Raises:
        ValueError: if ``epochs`` is smaller than 1 (previously this surfaced
            as a ``NameError`` on the final accuracy print).
    """
    if epochs < 1:
        raise ValueError("epochs must be at least 1, got {0}".format(epochs))

    model = models.alexnet(num_classes=num_classes)
    # Replace the first feature layer with a 1-input-channel convolution.
    # NOTE(review): presumably because the images are single-channel
    # (the recorded output shows shape [1, 32, 32]) -- confirm.
    model.features[0] = Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # `device` is the CPU when CUDA is unavailable, so this is safe either way.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.8, 0.99), weight_decay=0.001)
    criterion = nn.CrossEntropyLoss()

    acc = None
    for epoch in range(epochs):
        print("Epoch {0}".format(epoch))
        # Evaluation below switches to eval mode; re-enable training mode
        # (dropout active) at the start of every epoch.
        model.train()
        for step, (x_train, y_train) in enumerate(tqdm(train_loader)):
            x_train, y_train = x_train.to(device), y_train.to(device)
            optimizer.zero_grad()
            train_pred = model(x_train)
            loss = criterion(train_pred, y_train)
            loss.backward()
            optimizer.step()
            if step % print_step == 0:
                # .item() prints the plain number without keeping the graph alive.
                print('Loss: {}'.format(loss.item()))

        # Measure accuracy with dropout disabled and without building an
        # autograd graph, so the number (and the checkpoint decision that
        # depends on it) is not perturbed by training-time randomness.
        model.eval()
        with torch.no_grad():
            acc = calc_accuracy(model, test_loader)
        print("Accuracy: {0}".format(acc))
        if acc > 98:
            torch.save(model.state_dict(), '{0}-{1}.ckpt'.format(model_name,acc))
    print("Accuracy: {0}".format(acc))
    torch.save(model.state_dict(), '{0}.ckpt'.format(model_name))

def calc_accuracy(model, test_loader):
    """Return the classification accuracy of ``model`` on ``test_loader`` in percent.

    The model is put into evaluation mode for the duration of the call (so
    layers such as dropout behave deterministically) and gradients are
    disabled; the model's previous train/eval mode is restored afterwards.
    Accuracy is computed over all samples, so a smaller final batch does not
    carry the same weight as a full batch.

    Args:
        model: a ``torch.nn.Module`` producing per-class scores along dim 1.
        test_loader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``numpy.float64`` accuracy in the range 0-100, or ``nan`` when
        ``test_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x_test, y_test in tqdm(test_loader):
                # `device` is the CPU when CUDA is unavailable, so the move is a no-op there.
                x_test, y_test = x_test.to(device), y_test.to(device)
                test_pred = model(x_test)
                matches = torch.argmax(test_pred, dim=1) == y_test
                correct += matches.sum().item()
                total += matches.numel()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if total == 0:
        return np.float64("nan")
    return np.float64(100.0 * correct / total)

# Training loaders reshuffle every epoch; test loaders keep a fixed order.
_train_loader_opts = dict(batch_size=16, shuffle=True)
_test_loader_opts = dict(batch_size=16, shuffle=False)

train_all_loader = DataLoader(data_all_train, **_train_loader_opts)
test_all_loader = DataLoader(data_all_test, **_test_loader_opts)

train_ops_loader = DataLoader(data_ops_train, **_train_loader_opts)
test_ops_loader = DataLoader(data_ops_test, **_test_loader_opts)

train_brckts_loader = DataLoader(data_brckts_train, **_train_loader_opts)
test_brckts_loader = DataLoader(data_brckts_test, **_test_loader_opts)



In [None]:
# One row per model: (train loader, test loader, checkpoint name, loss print interval).
# All three runs use train()'s default num_classes and epochs.
_training_runs = (
    (train_all_loader, test_all_loader, 'model-all-symbols', 500),
    (train_ops_loader, test_ops_loader, 'model-plus-minus-div', 60),
    (train_brckts_loader, test_brckts_loader, 'model-plus-brackets', 60),
)
for _run_args in _training_runs:
    train(*_run_args)

In [None]:
# Train/test collections built from 'digits-plus-brackets' with no_strokes=True.
_no_strokes_source = dict(own_path='digits-plus-brackets', no_strokes=True)
train_no_strokes = DataCollection(**_no_strokes_source)
test_no_strokes = DataCollection(train=False, **_no_strokes_source)
print_data_infos(train_no_strokes, test_no_strokes)

train_no_strokes_loader = DataLoader(train_no_strokes, shuffle=True, batch_size=16)
test_no_strokes_loader = DataLoader(test_no_strokes, shuffle=False, batch_size=16)

# num_classes=13 matches the "Number of Labels: 13" line in this cell's recorded output.
train(
    train_no_strokes_loader,
    test_no_strokes_loader,
    model_name='model-no_strokes',
    print_step=500,
    num_classes=13,
)

100%|██████████| 151241/151241 [00:00<00:00, 799802.72it/s]
100%|██████████| 60000/60000 [00:06<00:00, 9473.36it/s] 
100%|██████████| 60000/60000 [00:00<00:00, 342260.38it/s]


No training data for -. Skipping
No training data for div. Skipping


100%|██████████| 16992/16992 [00:00<00:00, 765812.92it/s]
 10%|▉         | 983/10000 [00:00<00:00, 9826.14it/s]

No training data for ). Skipping


100%|██████████| 10000/10000 [00:01<00:00, 9435.66it/s]
100%|██████████| 10000/10000 [00:00<00:00, 227711.22it/s]


No training data for -. Skipping
No training data for div. Skipping
No training data for ). Skipping
Train data length: 70388
Test data length: 13912
Img Shape: torch.Size([1, 32, 32])
Number of Labels: 13


  0%|          | 0/4400 [00:00<?, ?it/s]

Epoch 0


  0%|          | 1/4400 [00:00<52:47,  1.39it/s]

Loss: 2.561375379562378


  1%|          | 33/4400 [00:18<41:18,  1.76it/s]