In [6]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

In [7]:
# Hyper-parameters shared by every network trained below.
BATCH_SIZE, LEARNING_RATE, EPOCHS = 125, 0.01, 10

# How many CNNs to train, each from its own fresh (unseeded) random initialisation.
number_trained_cnns = 50

In [8]:
# MNIST train/test splits, stored under ./data and downloaded on first use.
# ToTensor turns each image into a float tensor; both splits share one transform instance.
MNIST_TRANSFORM = transforms.ToTensor()
DATASET_TRAIN = torchvision.datasets.MNIST(root='data', train=True, transform=MNIST_TRANSFORM, download=True)
DATASET_TEST = torchvision.datasets.MNIST(root='data', train=False, transform=MNIST_TRANSFORM, download=True)

# Only the training data needs shuffling. Evaluation visits every test sample regardless of order,
# so the test loader keeps a fixed order, which makes evaluation reproducible.
LOADER_TRAIN = DataLoader(dataset=DATASET_TRAIN, batch_size=BATCH_SIZE, shuffle=True)
LOADER_TEST = DataLoader(dataset=DATASET_TEST, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
class ConvNet(nn.Module):
    """Two-stage convolutional classifier for single-channel images (MNIST in this notebook).

    Layout:
        conv1: Conv2d(1 -> 64, 3x3, 'same') -> ReLU -> MaxPool2d(2)
        conv2: Conv2d(64 -> 32, 3x3, 'same') -> ReLU -> MaxPool2d(2)
        flatten -> fc: Linear(1568 -> 32) -> ReLU -> out: Linear(32 -> 10)

    ``forward`` returns raw class scores (no softmax); the training cell feeds them to
    ``nn.CrossEntropyLoss``. The attribute names ``conv1``/``conv2``/``fc``/``out`` are relied
    upon elsewhere (e.g. ``model.conv1[0]`` is snapshotted), so keep them stable.
    """

    def __init__(self):
        super().__init__()
        # 'same' padding preserves height/width; each 2x2 max-pool halves them.
        # For 28x28 inputs that is 28 -> 14 -> 7, hence 32 * 7 * 7 = 1568 flattened features.
        last_conv_channels = 32
        pooled_side = 7
        hidden_units = 32

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=last_conv_channels, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=last_conv_channels * pooled_side * pooled_side, out_features=hidden_units),
            nn.ReLU(),
        )
        self.out = nn.Sequential(nn.Linear(in_features=hidden_units, out_features=10))

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        # Collapse (channels, height, width) into one feature axis per sample.
        flat = features.view(features.size(0), -1)
        return self.out(self.fc(flat))

In [10]:
# Train `number_trained_cnns` freshly initialised ConvNets on MNIST. For each network:
#   * at the START of every epoch index in SNAPSHOT_EPOCHS (0-based), save the first conv
#     layer's weights, so the file tagged `epoch{k}` holds the weights after k completed epochs;
#   * after EPOCHS epochs, measure accuracy over the whole test set and append it to RESULTS_FILE.
# NOTE(review): the recorded run below averaged ~594 s per network (it ran on CPU, since the
# original code never moved anything to a device), i.e. roughly 8 h for 50 networks. A GPU is
# now used automatically when available.
# NOTE(review): no random seed is set, so every run yields different networks. The weights
# after the final epoch are never snapshotted; extend SNAPSHOT_EPOCHS logic if those are needed.
SNAPSHOT_EPOCHS = (1, 4, 9)
WEIGHTS_DIR = 'trained_cnns'
RESULTS_FILE = 'test_data_cnns.txt'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def evaluate_accuracy(model, loader, device=DEVICE):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Correct predictions are counted over *every* batch and divided by the total number of
    samples, so the result is the accuracy on the whole dataset rather than on a single batch.
    Returns ``nan`` if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predictions = model(images).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else float('nan')


os.makedirs(WEIGHTS_DIR, exist_ok=True)  # np.save does not create missing directories
criterion = nn.CrossEntropyLoss()  # holds no learned state, so one instance serves every model

progress_bar = tqdm(range(number_trained_cnns))
for n in progress_bar:
    model = ConvNet().to(DEVICE)
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()
    for epoch in range(EPOCHS):
        progress_bar.set_description(f'CNN [{n+1}/{number_trained_cnns}], Epoch [{epoch+1}/{EPOCHS}]')
        if epoch in SNAPSHOT_EPOCHS:
            # Taken before this epoch trains, i.e. after `epoch` full passes over the data.
            weights_path = os.path.join(WEIGHTS_DIR, f'model_{n+1}_epoch{epoch}_batch{BATCH_SIZE}_weights')
            np.save(weights_path, model.conv1[0].weight.detach().cpu().numpy())
        for images, labels in LOADER_TRAIN:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    accuracy = evaluate_accuracy(model, LOADER_TEST)
    with open(RESULTS_FILE, "a") as results_file:
        results_file.write("CNN {}: Test Accuracy {}\n".format(n + 1, accuracy))


CNN [10/50], Epoch [1/10]:  18%|█▊        | 9/50 [1:29:06<6:45:58, 594.10s/it]


KeyboardInterrupt: 