In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as tt
import matplotlib.pyplot as plt

In [2]:
N_NETWORKS = 100
BATCH_SIZE = 60
N_EPOCHS   = 1

In [3]:
DATA_PATH = './notMNIST'
MODEL_STORE_PATH = './models/model_'
model_paths = []
model_files = []
for i in range(100):
    model_paths.append(MODEL_STORE_PATH + str(i) + '.pt')
    model_files.append(open(model_paths[i], 'w'))

In [4]:
#transform data to tensor and normalize (values state for MNIST!)
trans = tt.Compose([tt.ToTensor(), tt.Normalize((0.1307,), (0.3081,))]) 

train_dataset = torchvision.datasets.MNIST(root=DATA_PATH, train=True,  transform=trans, download=True) 
test_dataset  = torchvision.datasets.MNIST(root=DATA_PATH, train=False, transform=trans)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True) 
test_loader  = DataLoader(dataset=test_dataset,  batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
        
    def forward(self, x):
        return x.view(BATCH_SIZE, -1)

In [None]:
class FConvMNIST(nn.Module):
    def __init__(self, seq, clf):
        super(FConvMNIST, self).__init__()
        self.seq = seq
        self.clf = clf
        
    def forward(self, x):
        return self.clf(self.seq(x)) 

In [None]:
#architecture of the network copied from the article
seq = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=(7, 7), stride=(1, 1)),
    nn.LeakyReLU(negative_slope=0.01),
    nn.MaxPool2d(kernel_size=2,stride=2, padding=0, dilation=1, ceil_mode=False),
    nn.Conv2d(256, 512, kernel_size=(5, 5), stride=(1, 1)),
    nn.LeakyReLU(negative_slope=0.01),
    nn.MaxPool2d(kernel_size=2,stride=2, padding=0, dilation=1, ceil_mode=False),
    Flatten()
)
clf = nn.Linear(in_features=4608, out_features=10, bias=True)

In [None]:
model = FConvMNIST(seq, clf)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)

In [None]:
loss_list = []
acc_list  = []
N_STEPS = len(train_loader)
print(N_STEPS)

#train network
for i in range(N_NETWORKS):
    for j in range(N_EPOCHS):
        for k, (images, labels) in enumerate(train_loader):
            #forward
            pred = model(images)
            loss = criterion(pred, labels)
            loss_list.append(loss.item())
            
            #backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
            #statictics
            total = labels.size(0)
            _, predicted = torch.max(pred.data, 1)
            
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)

            if k % 10 == 0:
                print('Network {}/{} Epoch {}/{}, Step {}/{}, Loss: {}, Accuracy: {}%'
                      .format(i, N_NETWORKS, j, N_EPOCHS, k, N_STEPS, loss.item(),
                              (sum(acc_list) / len(acc_list)) * 100))
                
    torch.save(model.state_dict(), model_paths[i])
    model_files[i].close()
    model = FConvMNIST(seq, clf)
    acc_list = []
    loss_list = []
    

1000
Network 0/100 Epoch 0/1, Step 0/1000, Loss: 2.3283114433288574, Accuracy: 6.666666666666667%
Network 0/100 Epoch 0/1, Step 10/1000, Loss: 1.1032981872558594, Accuracy: 38.18181818181819%
Network 0/100 Epoch 0/1, Step 20/1000, Loss: 0.8052517175674438, Accuracy: 56.746031746031754%
Network 0/100 Epoch 0/1, Step 30/1000, Loss: 0.4089963138103485, Accuracy: 66.50537634408602%
Network 0/100 Epoch 0/1, Step 40/1000, Loss: 0.2271326631307602, Accuracy: 72.39837398373983%
Network 0/100 Epoch 0/1, Step 50/1000, Loss: 0.17532040178775787, Accuracy: 75.75163398692808%
Network 0/100 Epoch 0/1, Step 60/1000, Loss: 0.3253006637096405, Accuracy: 78.63387978142073%
Network 0/100 Epoch 0/1, Step 70/1000, Loss: 0.24281950294971466, Accuracy: 80.75117370892015%
Network 0/100 Epoch 0/1, Step 80/1000, Loss: 0.15409746766090393, Accuracy: 82.3251028806584%
Network 0/100 Epoch 0/1, Step 90/1000, Loss: 0.11857176572084427, Accuracy: 83.77289377289378%
Network 0/100 Epoch 0/1, Step 100/1000, Loss: 0.2075

Network 0/100 Epoch 0/1, Step 860/1000, Loss: 0.02640523947775364, Accuracy: 95.76461478900544%
Network 0/100 Epoch 0/1, Step 870/1000, Loss: 0.030338048934936523, Accuracy: 95.79027937236934%
Network 0/100 Epoch 0/1, Step 880/1000, Loss: 0.01729539968073368, Accuracy: 95.81536133182034%
Network 0/100 Epoch 0/1, Step 890/1000, Loss: 0.050620436668395996, Accuracy: 95.83800972689906%
Network 0/100 Epoch 0/1, Step 900/1000, Loss: 0.030645830556750298, Accuracy: 95.86015538290835%
Network 0/100 Epoch 0/1, Step 910/1000, Loss: 0.0036620378959923983, Accuracy: 95.8909623124776%
Network 0/100 Epoch 0/1, Step 920/1000, Loss: 0.14892181754112244, Accuracy: 95.91024249004752%
Network 0/100 Epoch 0/1, Step 930/1000, Loss: 0.021513843908905983, Accuracy: 95.93626924454041%
Network 0/100 Epoch 0/1, Step 940/1000, Loss: 0.0049939872696995735, Accuracy: 95.96174282678051%
Network 0/100 Epoch 0/1, Step 950/1000, Loss: 0.054866451770067215, Accuracy: 95.97967052225778%
Network 0/100 Epoch 0/1, Step 96

Network 1/100 Epoch 0/1, Step 720/1000, Loss: 0.07162056863307953, Accuracy: 97.01820646910807%
Network 1/100 Epoch 0/1, Step 730/1000, Loss: 0.030954984948039055, Accuracy: 97.02580396687945%
Network 1/100 Epoch 0/1, Step 740/1000, Loss: 0.02566676214337349, Accuracy: 97.03522879571224%
Network 1/100 Epoch 0/1, Step 750/1000, Loss: 0.027263641357421875, Accuracy: 97.04549781077576%
Network 1/100 Epoch 0/1, Step 760/1000, Loss: 0.0936790481209755, Accuracy: 97.04523944728471%
Network 1/100 Epoch 0/1, Step 770/1000, Loss: 0.027397099882364273, Accuracy: 97.05533596838042%
Network 1/100 Epoch 0/1, Step 780/1000, Loss: 0.04913235455751419, Accuracy: 97.06251169754917%
Network 1/100 Epoch 0/1, Step 790/1000, Loss: 0.0036131381057202816, Accuracy: 97.07053787455895%
Network 1/100 Epoch 0/1, Step 800/1000, Loss: 0.04983728006482124, Accuracy: 97.0719970386832%
Network 1/100 Epoch 0/1, Step 810/1000, Loss: 0.02096511609852314, Accuracy: 97.08264310694005%
Network 1/100 Epoch 0/1, Step 820/100

Network 2/100 Epoch 0/1, Step 580/1000, Loss: 0.0034953991416841745, Accuracy: 97.52486116492261%


In [None]:
parameters = model.parameters()

#get 256 tensors with shape [1, 7, 7]
layer_weights = next(parameters)

#draw one weight
imgplot = plt.imshow(layer_weights[0].view(7, -1).data)