In [3]:
import torch
import random
import numpy as np

# Seed every RNG used in the notebook from one constant so runs are reproducible.
SEED = 0

random.seed(SEED)
np.random.seed(SEED)  # also drives the np.random.permutation shuffle inside `trainig`
torch.manual_seed(SEED)
# manual_seed_all seeds every visible GPU; torch.cuda.manual_seed only seeds the current one.
torch.cuda.manual_seed_all(SEED)
# Force deterministic cuDNN kernels and disable the auto-tuner, which may otherwise
# select different (possibly non-deterministic) algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
import matplotlib.pyplot as plt
import torchvision.datasets

In [None]:
# Colab only: mount Google Drive so the model modules (LeNet5, FirstCustomNN,
# SecondCustomNN) stored there become importable in the next cell.
# `sys` and `drive` are not imported in any earlier cell, so import them here.
import os
import sys
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Build the module directory from the mount point rather than relying on the cwd,
# and avoid appending a duplicate entry when the cell is re-run.
MODULE_DIR = os.path.join(DRIVE_MOUNT_POINT, 'My Drive', 'Colab Notebooks')
if MODULE_DIR not in sys.path:
    sys.path.append(MODULE_DIR)

In [None]:
from LeNet5 import *
from FirstCustomNN import *
from SecondCustomNN import *

In [None]:
def trainig(neural_network, X_test, y_test, train_data=None, train_labels=None,
            epochs=60, batch_size=100, lr=1.0e-3):
    """Train a classifier with Adam + cross-entropy and track held-out accuracy per epoch.

    Args:
        neural_network: ``torch.nn.Module`` whose output is treated as class logits
            (it is fed to ``CrossEntropyLoss`` and ``argmax(dim=1)``).
        X_test: evaluation inputs. Pass data that is NOT used for training,
            otherwise the history measures training accuracy.
        y_test: integer class labels for ``X_test``.
        train_data: training inputs. Defaults to the notebook-global ``X_train``
            (the behaviour of the original version of this function).
        train_labels: training labels. Defaults to the notebook-global ``y_train``.
        epochs: number of passes over the training data.
        batch_size: mini-batch size used for both training and evaluation.
        lr: Adam learning rate.

    Returns:
        ``(neural_network, accuracy_history)``. The first element is the trained model,
        moved to the selected device and left in eval mode. The second is a list with
        one 0-dim CPU tensor per epoch, holding the accuracy on ``(X_test, y_test)``.

    Raises:
        ValueError: if inputs and labels differ in length, the evaluation set is
            empty, or ``batch_size`` is not positive.
    """
    if train_data is None:
        train_data = X_train
    if train_labels is None:
        train_labels = y_train
    if len(train_data) != len(train_labels):
        raise ValueError("train_data and train_labels must have the same length")
    if len(X_test) != len(y_test):
        raise ValueError("X_test and y_test must have the same length")
    if len(X_test) == 0:
        raise ValueError("evaluation set is empty")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    neural_network = neural_network.to(device)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(neural_network.parameters(), lr=lr)

    test_accuracy_history = []

    for _ in range(epochs):
        # Training mode: dropout / batch-norm layers (if the model has any) act as in training.
        neural_network.train()
        # NumPy's global RNG, so the shuffle order follows np.random.seed.
        order = np.random.permutation(len(train_data))
        for start_index in range(0, len(train_data), batch_size):
            batch_indexes = order[start_index:start_index + batch_size]
            X_batch = train_data[batch_indexes].to(device)
            y_batch = train_labels[batch_indexes].to(device)

            optimizer.zero_grad()
            loss_value = criterion(neural_network(X_batch), y_batch)
            loss_value.backward()
            optimizer.step()

        # Evaluate in eval mode and without autograd graphs. Work in batches so the whole
        # evaluation set never has to pass through the network at once.
        neural_network.eval()
        correct = 0
        with torch.no_grad():
            for start_index in range(0, len(X_test), batch_size):
                X_batch = X_test[start_index:start_index + batch_size].to(device)
                y_batch = y_test[start_index:start_index + batch_size].to(device)
                correct += (neural_network(X_batch).argmax(dim=1) == y_batch).sum().item()

        accuracy = torch.tensor(correct / len(X_test))
        test_accuracy_history.append(accuracy)

        print("error rate: " + str(1 - accuracy))

    return neural_network, test_accuracy_history

In [5]:
def _load_splits(dataset_cls, root='./'):
    """Return the (train, test) splits of a torchvision dataset, downloading into `root` if absent."""
    return tuple(dataset_cls(root, download=True, train=is_train) for is_train in (True, False))


# Raw PIL-backed datasets; tensors are pulled out of them in a later cell.
MNIST_train, MNIST_test = _load_splits(torchvision.datasets.MNIST)

FASHION_train, FASHION_test = _load_splits(torchvision.datasets.FashionMNIST)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST\raw\train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./MNIST\raw\train-images-idx3-ubyte.gz to ./MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST\raw\train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./MNIST\raw\train-labels-idx1-ubyte.gz to ./MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST\raw\t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./MNIST\raw\t10k-images-idx3-ubyte.gz to ./MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST\raw\t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./MNIST\raw\t10k-labels-idx1-ubyte.gz to ./MNIST\raw
Processing...
Done!






In [None]:
# Pull the raw image / label tensors out of the datasets.
# `train_data`, `train_labels`, `test_data` and `test_labels` are deprecated aliases in
# torchvision (they emit a warning); `.data` and `.targets` are the supported attributes.
X_train = MNIST_train.data
y_train = MNIST_train.targets
X_test = MNIST_test.data
y_test = MNIST_test.targets

In [6]:
# Insert the channel axis expected by the conv nets ((N, 28, 28) -> (N, 1, 28, 28)) and
# convert to float. The dim() guard keeps the cell idempotent: re-running it no longer
# stacks a second channel dimension onto already-processed tensors.
if X_train.dim() == 3:
    X_train = X_train.unsqueeze(1).float()
if X_test.dim() == 3:
    X_test = X_test.unsqueeze(1).float()



In [7]:
# `trainig` trains on the global X_train / y_train; its arguments are the evaluation set.
# Passing the held-out test split makes the history report test (not training) accuracy.
lenet5 = LeNet5()
lenet5, lenet5_accuracy_history = trainig(lenet5, X_test, y_test)

<Figure size 640x480 with 1 Axes>

tensor(5)


In [8]:
# Evaluate on the held-out test split. Also unpack the (model, history) tuple: assigning the
# whole tuple to the history variable would break the comparison plot.
first_custom = FirstCustomNN()
first_custom, first_custom_accuracy_history = trainig(first_custom, X_test, y_test)

In [9]:
# Evaluate on the held-out test split. Also unpack the (model, history) tuple: assigning the
# whole tuple to the history variable would break the comparison plot.
second_custom = SecondCustomNN()
second_custom, second_custom_accuracy_history = trainig(second_custom, X_test, y_test)

torch.Size([60000, 1, 28, 28])

In [10]:
# Compare per-epoch accuracy of the three architectures. Without labels and a legend,
# the three curves cannot be told apart.
fig, ax = plt.subplots()
ax.plot(lenet5_accuracy_history, color='blue', label='LeNet5')
ax.plot(first_custom_accuracy_history, color='darkorange', label='FirstCustomNN')
ax.plot(second_custom_accuracy_history, color='green', label='SecondCustomNN')
ax.set(title='Accuracy per epoch (MNIST)', xlabel='epoch', ylabel='accuracy')
ax.legend();

In [11]:
# Sanity check: display the first test image with its label.
# `.data` / `.targets` replace the deprecated `test_data` / `test_labels` aliases.
fig, ax = plt.subplots()
ax.imshow(MNIST_test.data[0, :, :], cmap='gray')
ax.set_title(f'label: {MNIST_test.targets[0].item()}')
ax.axis('off')
plt.show()
print(MNIST_test.targets[0])