In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tt
from torch.utils.data import DataLoader

In [2]:
# Experiment configuration.
N_NETWORKS: int = 100  # number of independent networks to train and save
BATCH_SIZE: int = 60   # mini-batch size used by both data loaders
N_EPOCHS: int = 1      # passes over the training set per network

In [3]:
# Root directory of the MNIST dataset files.
DATA_PATH = './MNIST'
# Common prefix of every checkpoint file; network i is stored at
# MODEL_STORE_PATH + str(i) + '.pt'.
MODEL_STORE_PATH = './models_MNIST/model_'

# Make sure the checkpoint directory exists before anything is saved into it;
# saving to a path inside a missing directory would otherwise fail.
os.makedirs(os.path.dirname(MODEL_STORE_PATH), exist_ok=True)

# One checkpoint path per network, indexed by network number.
model_paths = [MODEL_STORE_PATH + str(i) + '.pt' for i in range(N_NETWORKS)]
# Left empty here; not populated in this cell.
model_files = []

In [4]:
# Convert images to tensors, then normalize them with the mean / std values
# stated for MNIST.
trans = tt.Compose([
    tt.ToTensor(),
    tt.Normalize((0.1307,), (0.3081,)),
])

# Only the training split requests a download; the test split is read from
# the same DATA_PATH root.
train_dataset = torchvision.datasets.MNIST(DATA_PATH, train=True, transform=trans, download=True)
test_dataset = torchvision.datasets.MNIST(DATA_PATH, train=False, transform=trans)

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    The batch size is read from the input tensor itself, so the layer works
    for batches of any size (e.g. a smaller final batch of a data loader).
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 2-D tensor of shape ``(x.size(0), -1)``."""
        return x.view(x.size(0), -1)

In [6]:
class FConvMNIST(nn.Module):
    """Two-stage model: a feature extractor followed by a classifier head.

    Args:
        seq: module applied to the input first.
        clf: module applied to the output of ``seq``.
    """

    def __init__(self, seq, clf):
        super().__init__()
        self.seq = seq
        self.clf = clf

    def forward(self, x):
        """Run ``x`` through ``seq`` and feed the result to ``clf``."""
        features = self.seq(x)
        return self.clf(features)

In [7]:
def build_new_net():
    """Create a freshly initialised feature extractor and classifier head.

    The architecture is copied from the article: two conv / LeakyReLU /
    max-pool stages followed by a flatten layer, and a linear classifier.

    Returns:
        (seq, clf): the ``nn.Sequential`` feature extractor and the
        ``nn.Linear`` head with 4608 inputs and 10 outputs.
    """
    # (in_channels, out_channels, square kernel size) of each conv stage.
    conv_stages = ((1, 256, 7), (256, 512, 5))

    layers = []
    for in_channels, out_channels, kernel in conv_stages:
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=(kernel, kernel), stride=(1, 1)))
        layers.append(nn.LeakyReLU(negative_slope=0.01))
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0,
                                   dilation=1, ceil_mode=False))
    layers.append(Flatten())

    # Layers are kept in one flat Sequential so their indices stay 0..6.
    seq = nn.Sequential(*layers)
    # 4608 presumably equals 512 channels * 3 * 3 for 28x28 inputs - TODO confirm.
    clf = nn.Linear(in_features=4608, out_features=10, bias=True)
    return seq, clf

In [8]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initial network, loss function and optimizer.
seq, clf = build_new_net()
model = FConvMNIST(seq, clf).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [9]:
# Per-step loss / accuracy history of the network currently being trained.
loss_list = []
acc_list  = []
N_STEPS = len(train_loader)
print(N_STEPS)

# Train N_NETWORKS independent networks one after another and save the
# weights of network i to model_paths[i].
for i in range(N_NETWORKS):
    # Every network starts from a fresh model, optimizer and history, so the
    # cell does not depend on model state left over from earlier cells or runs.
    seq, clf = build_new_net()
    model = FConvMNIST(seq, clf).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
    loss_list = []
    acc_list = []

    for j in range(N_EPOCHS):
        for k, (images, labels) in enumerate(train_loader):
            # forward
            images = images.to(device)
            labels = labels.to(device)
            pred = model(images)
            loss = criterion(pred, labels)
            loss_value = loss.item()
            loss_list.append(loss_value)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # statistics: accuracy of this batch, computed from the
            # predictions made before the optimizer step
            with torch.no_grad():
                predicted = pred.argmax(dim=1)
                correct = (predicted == labels).sum().item()
            total = labels.size(0)
            acc_list.append(correct / total)

            # Report the latest loss and the running mean accuracy of the
            # current network every 500 steps.
            if k % 500 == 0:
                print('Network {}/{} Epoch {}/{}, Step {}/{}, Loss: {}, Accuracy: {}%'
                      .format(i, N_NETWORKS, j, N_EPOCHS, k, N_STEPS, loss_value,
                              (sum(acc_list) / len(acc_list)) * 100))

    torch.save(model.state_dict(), model_paths[i])
    

1000
Network 0/100 Epoch 0/1, Step 0/1000, Loss: 2.2992560863494873, Accuracy: 8.333333333333332%
Network 0/100 Epoch 0/1, Step 500/1000, Loss: 0.04317813366651535, Accuracy: 94.79707252162324%
Network 1/100 Epoch 0/1, Step 0/1000, Loss: 2.3374359607696533, Accuracy: 11.666666666666666%
Network 1/100 Epoch 0/1, Step 500/1000, Loss: 0.03549010679125786, Accuracy: 94.4344644045241%
Network 2/100 Epoch 0/1, Step 0/1000, Loss: 2.3870561122894287, Accuracy: 11.666666666666666%
Network 2/100 Epoch 0/1, Step 500/1000, Loss: 0.09873221814632416, Accuracy: 94.84364604125061%
Network 3/100 Epoch 0/1, Step 0/1000, Loss: 2.285431385040283, Accuracy: 10.0%
Network 3/100 Epoch 0/1, Step 500/1000, Loss: 0.03419412299990654, Accuracy: 94.5874916832999%
Network 4/100 Epoch 0/1, Step 0/1000, Loss: 2.3226678371429443, Accuracy: 10.0%
Network 4/100 Epoch 0/1, Step 500/1000, Loss: 0.12091253697872162, Accuracy: 94.44111776447086%
Network 5/100 Epoch 0/1, Step 0/1000, Loss: 2.3212859630584717, Accuracy: 10.

Network 45/100 Epoch 0/1, Step 0/1000, Loss: 2.381880044937134, Accuracy: 6.666666666666667%
Network 45/100 Epoch 0/1, Step 500/1000, Loss: 0.023847250267863274, Accuracy: 94.25482368596116%
Network 46/100 Epoch 0/1, Step 0/1000, Loss: 2.3471853733062744, Accuracy: 8.333333333333332%
Network 46/100 Epoch 0/1, Step 500/1000, Loss: 0.20572826266288757, Accuracy: 94.00199600798383%
Network 47/100 Epoch 0/1, Step 0/1000, Loss: 2.3711800575256348, Accuracy: 13.333333333333334%
Network 47/100 Epoch 0/1, Step 500/1000, Loss: 0.1126551702618599, Accuracy: 93.65934797072491%
Network 48/100 Epoch 0/1, Step 0/1000, Loss: 2.359180212020874, Accuracy: 11.666666666666666%
Network 48/100 Epoch 0/1, Step 500/1000, Loss: 0.018173059448599815, Accuracy: 94.19494344644022%
Network 49/100 Epoch 0/1, Step 0/1000, Loss: 2.3427772521972656, Accuracy: 10.0%
Network 49/100 Epoch 0/1, Step 500/1000, Loss: 0.046115126460790634, Accuracy: 94.61743180306023%
Network 50/100 Epoch 0/1, Step 0/1000, Loss: 2.317239999

Network 89/100 Epoch 0/1, Step 0/1000, Loss: 2.3202693462371826, Accuracy: 11.666666666666666%
Network 89/100 Epoch 0/1, Step 500/1000, Loss: 0.059264425188302994, Accuracy: 94.92348636061195%
Network 90/100 Epoch 0/1, Step 0/1000, Loss: 2.3298983573913574, Accuracy: 15.0%
Network 90/100 Epoch 0/1, Step 500/1000, Loss: 0.042090900242328644, Accuracy: 94.26813040585469%
Network 91/100 Epoch 0/1, Step 0/1000, Loss: 2.3243649005889893, Accuracy: 20.0%
Network 91/100 Epoch 0/1, Step 500/1000, Loss: 0.10549089312553406, Accuracy: 94.58416500332639%
Network 92/100 Epoch 0/1, Step 0/1000, Loss: 2.3419034481048584, Accuracy: 15.0%
Network 92/100 Epoch 0/1, Step 500/1000, Loss: 0.014242657460272312, Accuracy: 94.17498336659992%
Network 93/100 Epoch 0/1, Step 0/1000, Loss: 2.3285727500915527, Accuracy: 13.333333333333334%
Network 93/100 Epoch 0/1, Step 500/1000, Loss: 0.23510035872459412, Accuracy: 94.494344644045%
Network 94/100 Epoch 0/1, Step 0/1000, Loss: 2.3605499267578125, Accuracy: 5.0%
N