In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as tt
import matplotlib.pyplot as plt

In [2]:
N_NETWORKS = 500
BATCH_SIZE = 60
N_EPOCHS   = 1

In [3]:
DATA_PATH = './notMNIST'
MODEL_STORE_PATH = './models/model_'
model_paths = []
model_files = []
for i in range(N_NETWORKS):
    model_paths.append(MODEL_STORE_PATH + str(i) + '.pt')
    model_files.append(open(model_paths[i], 'w'))

In [4]:
#transform data to tensor and normalize (values state for MNIST!)
trans = tt.Compose([tt.ToTensor(), tt.Normalize((0.1307,), (0.3081,))]) 

train_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=True,  transform=trans, download=True) 
test_dataset  = torchvision.datasets.MNIST(root=DATA_PATH, train=False, transform=trans)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True) 
test_loader  = DataLoader(dataset=test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
        
    def forward(self, x):
        return x.view(BATCH_SIZE, -1)

In [6]:
class FConvMNIST(nn.Module):
    def __init__(self, seq, clf):
        super(FConvMNIST, self).__init__()
        self.seq = seq
        self.clf = clf
        
    def forward(self, x):
        return self.clf(self.seq(x)) 

In [7]:
#architecture of the network copied from the article
seq = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=(7, 7), stride=(1, 1)),
    nn.LeakyReLU(negative_slope=0.01),
    nn.MaxPool2d(kernel_size=2,stride=2, padding=0, dilation=1, ceil_mode=False),
    nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1)),
    nn.LeakyReLU(negative_slope=0.01),
    nn.MaxPool2d(kernel_size=2,stride=2, padding=0, dilation=1, ceil_mode=False),
    Flatten()
)
clf = nn.Linear(in_features=4608, out_features=10, bias=True)

In [8]:
device = 'cuda'

In [9]:
model = FConvMNIST(seq, clf).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [10]:
loss_list = []
acc_list  = []
N_STEPS = len(train_loader)
print(N_STEPS)

#train network
for i in range(N_NETWORKS):
    for j in range(N_EPOCHS):
        for k, (images, labels) in enumerate(train_loader):
            #forward
            labels = labels.to(device)
            pred = model(images.to(device))
            loss = criterion(pred, labels)
            loss_list.append(loss.item())
            
            #backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
            #statictics
            total = labels.size(0)
            _, predicted = torch.max(pred.data, 1)
            
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)

            if k % 500 == 0:
                print('Network {}/{} Epoch {}/{}, Step {}/{}, Loss: {}, Accuracy: {}%'
                      .format(i, N_NETWORKS, j, N_EPOCHS, k, N_STEPS, loss.item(),
                              (sum(acc_list) / len(acc_list)) * 100))
                
    torch.save(model.state_dict(), model_paths[i])
    model = FConvMNIST(seq, clf)
    acc_list = []
    loss_list = []
    

1000
Network 0/500 Epoch 0/1, Step 0/1000, Loss: 2.328864812850952, Accuracy: 10.0%
Network 0/500 Epoch 0/1, Step 500/1000, Loss: 0.15957410633563995, Accuracy: 94.66067864271436%
Network 1/500 Epoch 0/1, Step 0/1000, Loss: 0.08129095286130905, Accuracy: 98.33333333333333%
Network 1/500 Epoch 0/1, Step 500/1000, Loss: 0.07425698637962341, Accuracy: 98.24351297405174%
Network 2/500 Epoch 0/1, Step 0/1000, Loss: 0.029152972623705864, Accuracy: 100.0%
Network 2/500 Epoch 0/1, Step 500/1000, Loss: 0.01693369634449482, Accuracy: 98.60944777112437%
Network 3/500 Epoch 0/1, Step 0/1000, Loss: 0.011718817986547947, Accuracy: 100.0%
Network 3/500 Epoch 0/1, Step 500/1000, Loss: 0.008396422490477562, Accuracy: 98.70259481037927%
Network 4/500 Epoch 0/1, Step 0/1000, Loss: 0.003955245018005371, Accuracy: 100.0%
Network 4/500 Epoch 0/1, Step 500/1000, Loss: 0.14697568118572235, Accuracy: 98.88223552894206%
Network 5/500 Epoch 0/1, Step 0/1000, Loss: 0.008318042382597923, Accuracy: 100.0%
Network 5

Network 44/500 Epoch 0/1, Step 500/1000, Loss: 0.004357361700385809, Accuracy: 99.39454424484374%
Network 45/500 Epoch 0/1, Step 0/1000, Loss: 0.01381978951394558, Accuracy: 100.0%
Network 45/500 Epoch 0/1, Step 500/1000, Loss: 0.006002123933285475, Accuracy: 99.43446440452429%
Network 46/500 Epoch 0/1, Step 0/1000, Loss: 0.003321313764899969, Accuracy: 100.0%
Network 46/500 Epoch 0/1, Step 500/1000, Loss: 0.016072798520326614, Accuracy: 99.46440452428479%
Network 47/500 Epoch 0/1, Step 0/1000, Loss: 0.002705351449549198, Accuracy: 100.0%
Network 47/500 Epoch 0/1, Step 500/1000, Loss: 0.05705731734633446, Accuracy: 99.50432468396554%
Network 48/500 Epoch 0/1, Step 0/1000, Loss: 0.003729184390977025, Accuracy: 100.0%
Network 48/500 Epoch 0/1, Step 500/1000, Loss: 0.02994094230234623, Accuracy: 99.46773120425819%
Network 49/500 Epoch 0/1, Step 0/1000, Loss: 0.002916780998930335, Accuracy: 100.0%
Network 49/500 Epoch 0/1, Step 500/1000, Loss: 0.0018274069298058748, Accuracy: 99.4610778443

Network 89/500 Epoch 0/1, Step 0/1000, Loss: 0.0007407108787447214, Accuracy: 100.0%
Network 89/500 Epoch 0/1, Step 500/1000, Loss: 0.0067326705902814865, Accuracy: 99.48436460412506%
Network 90/500 Epoch 0/1, Step 0/1000, Loss: 0.010847274214029312, Accuracy: 100.0%
Network 90/500 Epoch 0/1, Step 500/1000, Loss: 0.03388407081365585, Accuracy: 99.49101796407192%
Network 91/500 Epoch 0/1, Step 0/1000, Loss: 0.020801614969968796, Accuracy: 100.0%
Network 91/500 Epoch 0/1, Step 500/1000, Loss: 0.07063005864620209, Accuracy: 99.48436460412512%
Network 92/500 Epoch 0/1, Step 0/1000, Loss: 0.002236660337075591, Accuracy: 100.0%
Network 92/500 Epoch 0/1, Step 500/1000, Loss: 0.004226319026201963, Accuracy: 99.52095808383241%
Network 93/500 Epoch 0/1, Step 0/1000, Loss: 0.020080693066120148, Accuracy: 98.33333333333333%
Network 93/500 Epoch 0/1, Step 500/1000, Loss: 0.016483156010508537, Accuracy: 99.48769128409857%
Network 94/500 Epoch 0/1, Step 0/1000, Loss: 0.0036732673179358244, Accuracy: 

Network 133/500 Epoch 0/1, Step 0/1000, Loss: 0.0036721944343298674, Accuracy: 100.0%
Network 133/500 Epoch 0/1, Step 500/1000, Loss: 0.01414632797241211, Accuracy: 99.50099800399207%
Network 134/500 Epoch 0/1, Step 0/1000, Loss: 0.0112726129591465, Accuracy: 100.0%
Network 134/500 Epoch 0/1, Step 500/1000, Loss: 0.003055524779483676, Accuracy: 99.50432468396544%
Network 135/500 Epoch 0/1, Step 0/1000, Loss: 0.01270426157861948, Accuracy: 100.0%
Network 135/500 Epoch 0/1, Step 500/1000, Loss: 0.0011114120716229081, Accuracy: 99.45109780439128%
Network 136/500 Epoch 0/1, Step 0/1000, Loss: 0.06609079986810684, Accuracy: 98.33333333333333%
Network 136/500 Epoch 0/1, Step 500/1000, Loss: 0.008533875457942486, Accuracy: 99.46107784431142%
Network 137/500 Epoch 0/1, Step 0/1000, Loss: 0.0030562798492610455, Accuracy: 100.0%
Network 137/500 Epoch 0/1, Step 500/1000, Loss: 0.054593149572610855, Accuracy: 99.47438456420495%
Network 138/500 Epoch 0/1, Step 0/1000, Loss: 0.022218545898795128, Ac

Network 177/500 Epoch 0/1, Step 0/1000, Loss: 0.005124012473970652, Accuracy: 100.0%
Network 177/500 Epoch 0/1, Step 500/1000, Loss: 0.028403857722878456, Accuracy: 99.52095808383234%
Network 178/500 Epoch 0/1, Step 0/1000, Loss: 0.004187361337244511, Accuracy: 100.0%
Network 178/500 Epoch 0/1, Step 500/1000, Loss: 0.011871297843754292, Accuracy: 99.57418496340664%
Network 179/500 Epoch 0/1, Step 0/1000, Loss: 0.003486029338091612, Accuracy: 100.0%
Network 179/500 Epoch 0/1, Step 500/1000, Loss: 0.0006601810455322266, Accuracy: 99.53426480372595%
Network 180/500 Epoch 0/1, Step 0/1000, Loss: 0.0010008891113102436, Accuracy: 100.0%
Network 180/500 Epoch 0/1, Step 500/1000, Loss: 0.03508945181965828, Accuracy: 99.50765136393883%
Network 181/500 Epoch 0/1, Step 0/1000, Loss: 0.011312929913401604, Accuracy: 100.0%
Network 181/500 Epoch 0/1, Step 500/1000, Loss: 0.004680578131228685, Accuracy: 99.5675316034598%
Network 182/500 Epoch 0/1, Step 0/1000, Loss: 0.004065441899001598, Accuracy: 10

Network 221/500 Epoch 0/1, Step 500/1000, Loss: 0.0010622978443279862, Accuracy: 99.60745176314043%
Network 222/500 Epoch 0/1, Step 0/1000, Loss: 0.002546024275943637, Accuracy: 100.0%
Network 222/500 Epoch 0/1, Step 500/1000, Loss: 0.004663141444325447, Accuracy: 99.61410512308719%
Network 223/500 Epoch 0/1, Step 0/1000, Loss: 0.014244954101741314, Accuracy: 100.0%
Network 223/500 Epoch 0/1, Step 500/1000, Loss: 0.008935952559113503, Accuracy: 99.54757152361947%
Network 224/500 Epoch 0/1, Step 0/1000, Loss: 0.011333211325109005, Accuracy: 100.0%
Network 224/500 Epoch 0/1, Step 500/1000, Loss: 0.018925031647086143, Accuracy: 99.44777112441793%
Network 225/500 Epoch 0/1, Step 0/1000, Loss: 0.0015730857849121094, Accuracy: 100.0%
Network 225/500 Epoch 0/1, Step 500/1000, Loss: 0.012246577069163322, Accuracy: 99.6606786427146%
Network 226/500 Epoch 0/1, Step 0/1000, Loss: 0.009440501220524311, Accuracy: 100.0%
Network 226/500 Epoch 0/1, Step 500/1000, Loss: 0.006674551870673895, Accuracy:

Network 266/500 Epoch 0/1, Step 0/1000, Loss: 0.0030523459427058697, Accuracy: 100.0%
Network 266/500 Epoch 0/1, Step 500/1000, Loss: 0.06253591924905777, Accuracy: 99.54757152361945%
Network 267/500 Epoch 0/1, Step 0/1000, Loss: 0.06906847655773163, Accuracy: 96.66666666666667%
Network 267/500 Epoch 0/1, Step 500/1000, Loss: 0.002650483511388302, Accuracy: 99.55422488356625%
Network 268/500 Epoch 0/1, Step 0/1000, Loss: 0.003235801123082638, Accuracy: 100.0%
Network 268/500 Epoch 0/1, Step 500/1000, Loss: 0.0022755861282348633, Accuracy: 99.61077844311379%
Network 269/500 Epoch 0/1, Step 0/1000, Loss: 0.0027716001495718956, Accuracy: 100.0%
Network 269/500 Epoch 0/1, Step 500/1000, Loss: 0.005520860198885202, Accuracy: 99.53759148369937%
Network 270/500 Epoch 0/1, Step 0/1000, Loss: 0.005332032684236765, Accuracy: 100.0%
Network 270/500 Epoch 0/1, Step 500/1000, Loss: 0.008482447825372219, Accuracy: 99.61410512308723%
Network 271/500 Epoch 0/1, Step 0/1000, Loss: 0.035958193242549896,

Network 310/500 Epoch 0/1, Step 0/1000, Loss: 0.005864095874130726, Accuracy: 100.0%
Network 310/500 Epoch 0/1, Step 500/1000, Loss: 0.00263934931717813, Accuracy: 99.55422488356628%
Network 311/500 Epoch 0/1, Step 0/1000, Loss: 0.0016183058032765985, Accuracy: 100.0%
Network 311/500 Epoch 0/1, Step 500/1000, Loss: 0.0011404514079913497, Accuracy: 99.5375914836993%
Network 312/500 Epoch 0/1, Step 0/1000, Loss: 0.009950748644769192, Accuracy: 100.0%
Network 312/500 Epoch 0/1, Step 500/1000, Loss: 0.006562304683029652, Accuracy: 99.63406520292747%
Network 313/500 Epoch 0/1, Step 0/1000, Loss: 0.04430270940065384, Accuracy: 96.66666666666667%
Network 313/500 Epoch 0/1, Step 500/1000, Loss: 0.0487239845097065, Accuracy: 99.5276114437791%
Network 314/500 Epoch 0/1, Step 0/1000, Loss: 0.008259598165750504, Accuracy: 100.0%
Network 314/500 Epoch 0/1, Step 500/1000, Loss: 0.0039262929931283, Accuracy: 99.51097804391222%
Network 315/500 Epoch 0/1, Step 0/1000, Loss: 0.0016856353031471372, Accur

Network 355/500 Epoch 0/1, Step 500/1000, Loss: 0.015684008598327637, Accuracy: 99.50099800399211%
Network 356/500 Epoch 0/1, Step 0/1000, Loss: 0.02775915525853634, Accuracy: 98.33333333333333%
Network 356/500 Epoch 0/1, Step 500/1000, Loss: 0.007305463310331106, Accuracy: 99.56087824351299%
Network 357/500 Epoch 0/1, Step 0/1000, Loss: 0.010963177308440208, Accuracy: 100.0%
Network 357/500 Epoch 0/1, Step 500/1000, Loss: 0.0403122715651989, Accuracy: 99.55422488356628%
Network 358/500 Epoch 0/1, Step 0/1000, Loss: 0.016233762726187706, Accuracy: 98.33333333333333%
Network 358/500 Epoch 0/1, Step 500/1000, Loss: 0.004262630362063646, Accuracy: 99.52428476380582%
Network 359/500 Epoch 0/1, Step 0/1000, Loss: 0.030659206211566925, Accuracy: 98.33333333333333%
Network 359/500 Epoch 0/1, Step 500/1000, Loss: 0.003420718479901552, Accuracy: 99.5276114437792%
Network 360/500 Epoch 0/1, Step 0/1000, Loss: 0.007487567141652107, Accuracy: 100.0%
Network 360/500 Epoch 0/1, Step 500/1000, Loss: 

Network 399/500 Epoch 0/1, Step 500/1000, Loss: 0.027194397523999214, Accuracy: 99.56753160345981%
Network 400/500 Epoch 0/1, Step 0/1000, Loss: 0.004998938180506229, Accuracy: 100.0%
Network 400/500 Epoch 0/1, Step 500/1000, Loss: 0.016370518133044243, Accuracy: 99.4278110445776%
Network 401/500 Epoch 0/1, Step 0/1000, Loss: 0.03808455541729927, Accuracy: 98.33333333333333%
Network 401/500 Epoch 0/1, Step 500/1000, Loss: 0.008259566500782967, Accuracy: 99.58749168330013%
Network 402/500 Epoch 0/1, Step 0/1000, Loss: 0.0054277339950203896, Accuracy: 100.0%
Network 402/500 Epoch 0/1, Step 500/1000, Loss: 0.01913630962371826, Accuracy: 99.51763140385897%
Network 403/500 Epoch 0/1, Step 0/1000, Loss: 0.0026629765052348375, Accuracy: 100.0%
Network 403/500 Epoch 0/1, Step 500/1000, Loss: 0.017422445118427277, Accuracy: 99.57751164338%
Network 404/500 Epoch 0/1, Step 0/1000, Loss: 0.009349417872726917, Accuracy: 100.0%
Network 404/500 Epoch 0/1, Step 500/1000, Loss: 0.006953525356948376, Ac

Network 444/500 Epoch 0/1, Step 0/1000, Loss: 0.02984170988202095, Accuracy: 98.33333333333333%
Network 444/500 Epoch 0/1, Step 500/1000, Loss: 0.02901042252779007, Accuracy: 99.48103792415172%
Network 445/500 Epoch 0/1, Step 0/1000, Loss: 0.006000836845487356, Accuracy: 100.0%
Network 445/500 Epoch 0/1, Step 500/1000, Loss: 0.004235283471643925, Accuracy: 99.57418496340654%
Network 446/500 Epoch 0/1, Step 0/1000, Loss: 0.02423095703125, Accuracy: 100.0%
Network 446/500 Epoch 0/1, Step 500/1000, Loss: 0.01429982203990221, Accuracy: 99.54757152361947%
Network 447/500 Epoch 0/1, Step 0/1000, Loss: 0.01876303367316723, Accuracy: 100.0%
Network 447/500 Epoch 0/1, Step 500/1000, Loss: 0.006832432933151722, Accuracy: 99.50432468396544%
Network 448/500 Epoch 0/1, Step 0/1000, Loss: 0.004079611971974373, Accuracy: 100.0%
Network 448/500 Epoch 0/1, Step 500/1000, Loss: 0.01923046074807644, Accuracy: 99.54091816367271%
Network 449/500 Epoch 0/1, Step 0/1000, Loss: 0.009791803546249866, Accuracy:

Network 488/500 Epoch 0/1, Step 500/1000, Loss: 0.0693100169301033, Accuracy: 99.53759148369939%
Network 489/500 Epoch 0/1, Step 0/1000, Loss: 0.0008728345273993909, Accuracy: 100.0%
Network 489/500 Epoch 0/1, Step 500/1000, Loss: 0.004082536790519953, Accuracy: 99.48103792415172%
Network 490/500 Epoch 0/1, Step 0/1000, Loss: 0.024096226319670677, Accuracy: 98.33333333333333%
Network 490/500 Epoch 0/1, Step 500/1000, Loss: 0.009505399502813816, Accuracy: 99.57418496340661%
Network 491/500 Epoch 0/1, Step 0/1000, Loss: 0.014897195622324944, Accuracy: 100.0%
Network 491/500 Epoch 0/1, Step 500/1000, Loss: 0.0018334706546738744, Accuracy: 99.46107784431146%
Network 492/500 Epoch 0/1, Step 0/1000, Loss: 0.005643145181238651, Accuracy: 100.0%
Network 492/500 Epoch 0/1, Step 500/1000, Loss: 0.0018876473186537623, Accuracy: 99.53426480372599%
Network 493/500 Epoch 0/1, Step 0/1000, Loss: 0.0035384814254939556, Accuracy: 100.0%
Network 493/500 Epoch 0/1, Step 500/1000, Loss: 0.0594335645437240