In [33]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as tt

import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [34]:
# Experiment configuration. The defaults describe the full run
# (N_NETWORKS x N_EPOCHS passes over notMNIST_large); reduce them for quick experiments.
N_NETWORKS = 100   # number of networks to train (one saved checkpoint each)
BATCH_SIZE = 32
N_EPOCHS = 200     # epochs per network
STAT_STEP = 100    # print running loss/accuracy every STAT_STEP batches
SEED = 0           # global torch seed: makes weight init and loader shuffling reproducible
torch.manual_seed(SEED)

In [35]:
# Remove notMNIST images that cannot be loaded (presumably the ones the disabled
# scan in the preprocessing cell reports as "broken image"); they would otherwise
# make iterating the ImageFolder dataset fail. Already-deleted files are skipped,
# so this cell is safe to re-run.
BROKEN_IMAGES = [
    'D/VHJhbnNpdCBCb2xkLnR0Zg==.png',
    'B/TmlraXNFRi1TZW1pQm9sZEl0YWxpYy5vdGY=.png',
    'A/RnJlaWdodERpc3BCb29rSXRhbGljLnR0Zg==.png',
    'A/Um9tYW5hIEJvbGQucGZi.png',
    'A/SG90IE11c3RhcmQgQlROIFBvc3Rlci50dGY=.png',
]
for broken_rel_path in BROKEN_IMAGES:
    broken_path = os.path.join('./notMNIST_large', broken_rel_path)
    if os.path.exists(broken_path):
        os.remove(broken_path)

In [36]:
TRAIN_DATA_PATH = './notMNIST_large'
MODEL_STORE_PATH = './models_notMNIST/model_'
# torch.save creates each checkpoint file itself but not its parent directory,
# so only the directory needs to exist before training.
os.makedirs(os.path.dirname(MODEL_STORE_PATH), exist_ok=True)
# One checkpoint path per network: ./models_notMNIST/model_<i>.pt
model_paths = [MODEL_STORE_PATH + str(i) + '.pt' for i in range(N_NETWORKS)]

In [37]:
# Convert each image to a single-channel float tensor in [0, 1].
# No normalization is applied. Grayscale() is needed because ImageFolder's
# default loader opens images as RGB.
trans = tt.Compose([
    tt.Grayscale(),
    tt.ToTensor(),
])

# One class per sub-directory of TRAIN_DATA_PATH (A/, B/, D/, ...).
train_dataset = torchvision.datasets.ImageFolder(TRAIN_DATA_PATH, transform=trans)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [38]:
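# Preprocessing check (disabled): iterate over the whole loader once and report any image that fails to load with an IOError.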

# item = iter(train_loader)

# i = 0
# try:
#     while 1:
#         print(i)
#         i+= 1
#         try:
#             item.__next__()
#         except IOError as e:
#             print("found broken image", e)
# except StopIteration:
#     print("All data is allright")

In [39]:
class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one.

    The batch size is taken from the input rather than from the global
    BATCH_SIZE, so a short final batch or single-image inference is reshaped
    correctly instead of being silently mis-reshaped (or raising).
    """

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs.
        return x.reshape(x.size(0), -1)

In [40]:
class FConvMNIST(nn.Module):
    """Feature extractor followed by a classifier head.

    Args:
        seq: module applied first to the input batch (in this notebook an
            nn.Sequential of conv/pool layers ending in Flatten).
        clf: module applied to seq's output to produce the class scores.
    """

    def __init__(self, seq, clf):
        super().__init__()
        self.seq = seq
        self.clf = clf

    def forward(self, x):
        features = self.seq(x)
        logits = self.clf(features)
        return logits

In [41]:
# Architecture copied from the article.
# Shape trace (consistent with a 1x28x28 input):
#   Conv 7x7  -> 256x22x22,  MaxPool 2 -> 256x11x11
#   Conv 5x5  -> 512x7x7,    MaxPool 2 -> 512x3x3
# so the classifier sees 512 * 3 * 3 = 4608 features.
seq = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=7),
    nn.LeakyReLU(0.01),
    nn.MaxPool2d(2, stride=2),
    nn.Conv2d(256, 512, kernel_size=5),
    nn.LeakyReLU(0.01),
    nn.MaxPool2d(2, stride=2),
    Flatten(),
)
# 10 outputs: presumably one per letter class directory (A-J); confirm against train_dataset.classes.
clf = nn.Linear(512 * 3 * 3, 10)

In [42]:
# Cross-entropy over raw logits (clf has no softmax).
criterion = nn.CrossEntropyLoss()
model = FConvMNIST(seq, clf)
# Adam with an L2 penalty (weight_decay) on all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=1e-3, lr=1e-3)

In [None]:
loss_list = []
acc_list  = []
N_STEPS = len(train_loader)
print(N_STEPS)


def _fresh_model_like(template):
    """Return a deep copy of ``template`` with all layer weights re-initialized.

    Wrapping the shared ``seq``/``clf`` modules in a new ``FConvMNIST`` would
    reuse the already-trained layers, so every "new" network would just keep
    training the previous one. Copy-and-reset yields an independent network.
    """
    fresh = copy.deepcopy(template)
    for module in fresh.modules():
        # Conv2d/Linear define reset_parameters(); LeakyReLU, MaxPool2d and
        # Flatten hold no parameters and are skipped.
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()
    return fresh


#train network
for i in range(N_NETWORKS):
    if i > 0:
        # New network and a new optimizer with the same hyper-parameters
        # (optimizer.defaults), so no Adam moment estimates carry over.
        model = _fresh_model_like(model)
        optimizer = type(optimizer)(model.parameters(), **optimizer.defaults)
        # Per-network histories; after the loop they hold the last network's curves.
        loss_list = []
        acc_list = []
    model.train()
    # Running totals for this network. Only full batches are counted, so
    # n_correct / n_seen equals the mean of the per-batch accuracies, without
    # re-summing the ever-growing acc_list at every report.
    n_correct = 0
    n_seen = 0
    for j in range(N_EPOCHS):
        for k, (images, labels) in enumerate(train_loader):
            # Only the last batch of an epoch can be short; skip it (same effect as drop_last=True).
            if images.shape[0] < BATCH_SIZE:
                break
            #forward
            pred = model(images)
            loss = criterion(pred, labels)
            loss_value = loss.item()
            loss_list.append(loss_value)

            #backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #statistics
            total = labels.size(0)
            predicted = pred.argmax(dim=1)
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)
            n_correct += correct
            n_seen += total

            if k % STAT_STEP == 0:
                print('Network {}/{} Epoch {}/{}, Step {}/{}, Loss: {}, Accuracy: {}%'
                      .format(i, N_NETWORKS, j, N_EPOCHS, k, N_STEPS, loss_value,
                              (n_correct / n_seen) * 100))

    # torch.save creates/overwrites the checkpoint file; no manual open/close needed.
    torch.save(model.state_dict(), model_paths[i])
    

16535
Network 0/100 Epoch 0/200, Step 0/16535, Loss: 2.330111503601074, Accuracy: 3.125%
Network 0/100 Epoch 0/200, Step 10/16535, Loss: 1.4021883010864258, Accuracy: 28.40909090909091%
Network 0/100 Epoch 0/200, Step 20/16535, Loss: 1.2587320804595947, Accuracy: 40.476190476190474%
Network 0/100 Epoch 0/200, Step 30/16535, Loss: 1.145156741142273, Accuracy: 47.07661290322581%
Network 0/100 Epoch 0/200, Step 40/16535, Loss: 1.0366491079330444, Accuracy: 51.0670731707317%
Network 0/100 Epoch 0/200, Step 50/16535, Loss: 1.0923550128936768, Accuracy: 54.595588235294116%
Network 0/100 Epoch 0/200, Step 60/16535, Loss: 0.8272674679756165, Accuracy: 57.120901639344254%
Network 0/100 Epoch 0/200, Step 70/16535, Loss: 0.8831173181533813, Accuracy: 59.198943661971825%
Network 0/100 Epoch 0/200, Step 80/16535, Loss: 0.8665003776550293, Accuracy: 60.60956790123457%
Network 0/100 Epoch 0/200, Step 90/16535, Loss: 0.7365061044692993, Accuracy: 62.019230769230774%
Network 0/100 Epoch 0/200, Step 100