In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# --- Reproducibility ---------------------------------------------------------
# Seed the global NumPy and PyTorch generators once, before the first random
# draw, so a fresh "Restart & Run All" regenerates the same data. Anything
# that later draws from these global generators becomes repeatable as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Synthetic dataset: three 2-D Gaussian clusters --------------------------
n_per_clust = 300   # number of points drawn per cluster
blur = 1            # scale applied to the standard-normal noise around a center

# cluster centers as [x, y]
A = [1, 1]
B = [5, 1]
C = [4, 3]

# each cluster is a two-element list [x-values, y-values], each of length n_per_clust
a = [ A[0]+np.random.randn(n_per_clust)*blur, A[1]+np.random.randn(n_per_clust)*blur ]
b = [ B[0]+np.random.randn(n_per_clust)*blur, B[1]+np.random.randn(n_per_clust)*blur ]
c = [ C[0]+np.random.randn(n_per_clust)*blur, C[1]+np.random.randn(n_per_clust)*blur ]

# class labels 0 / 1 / 2, in the same a, b, c order used for stacking below
labels_np = np.hstack(( np.zeros((n_per_clust)),
                        np.ones((n_per_clust)),
                      1+np.ones((n_per_clust))))

# side-by-side stacking gives (2, 3*n_per_clust); transpose to (samples, features)
data_np = np.hstack((a,b,c)).T

# float32 features and int64 labels, the dtypes nn.CrossEntropyLoss works with
data = torch.tensor(data_np).float()
labels = torch.tensor(labels_np).long()

# Scatter the three classes, one marker style per class label.
fig = plt.figure(figsize=(5,5))
for class_id, marker in enumerate(('bs', 'ko', 'r^')):
    members = data[labels == class_id]   # rows whose label equals class_id
    plt.plot(members[:, 0], members[:, 1], marker, alpha=.5)
plt.title('The qwerties!')
plt.xlabel('qwerty dimension 1')
plt.ylabel('qwerty dimension 2')
plt.show()

In [None]:
# Hold out 10% of the samples for testing.
X_train, X_test, train_labels, test_labels = train_test_split(data, labels, test_size=.1)

# Pair features with labels so each DataLoader yields (X, y) tuples.
train_data = TensorDataset(X_train, train_labels)
test_data = TensorDataset(X_test, test_labels)

# Training: shuffled mini-batches of 16; a trailing partial batch is dropped.
batch_size = 16
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

# Testing: a single batch containing the whole test set.
test_loader = DataLoader(test_data, batch_size=len(test_data))

In [None]:
def create_model(momentum, lr=.01):
    """Build a fresh classifier together with its loss function and optimizer.

    Parameters
    ----------
    momentum : float
        Momentum factor passed to ``torch.optim.SGD``.
    lr : float, optional
        Learning rate passed to ``torch.optim.SGD``. Defaults to 0.01, the
        value that used to be hard-coded, so existing calls behave as before.

    Returns
    -------
    model : nn.Module
        Untrained network mapping 2 input features to 3 output logits through
        two 8-unit ReLU layers.
    loss_fun : nn.CrossEntropyLoss
        Loss applied to the raw logits.
    optimizer : torch.optim.SGD
        Optimizer bound to ``model.parameters()``.
    """
    class ANN(nn.Module):
        def __init__(self):
            super().__init__()

            self.input = nn.Linear(2,8)    # 2 features -> 8 hidden units
            self.fc1 = nn.Linear(8,8)      # hidden -> hidden
            self.output = nn.Linear(8,3)   # hidden -> 3 logits

        def forward(self, x):
            x = F.relu(self.input(x))
            x = F.relu(self.fc1(x))
            # raw logits are returned; CrossEntropyLoss applies log-softmax itself
            return self.output(x)

    model = ANN()

    loss_fun = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    return model, loss_fun, optimizer

In [None]:
# Show the optimizer built for a momentum of .9; the model and loss are discarded.
optim = create_model(momentum=.9)[-1]
optim

In [None]:
num_of_epochs = 50   # epochs each model is trained for


def _accuracy_pct(logits, targets):
    """Return the percentage of rows whose arg-max logit equals the target, as a Python float."""
    return 100 * (torch.argmax(logits, dim=1) == targets).float().mean().item()


def train_model(momentum):
    """Train a freshly created model with the given SGD momentum.

    Reads the module-level ``num_of_epochs``, ``train_loader`` and
    ``test_loader``, and obtains the model, loss and optimizer from
    ``create_model(momentum)``.

    Parameters
    ----------
    momentum : float
        Momentum value forwarded to ``create_model``.

    Returns
    -------
    train_acc : list of float
        Per-epoch training accuracy in percent, averaged over the mini-batches.
    test_acc : list of float
        Per-epoch accuracy in percent on the first batch of ``test_loader``.
    losses : torch.Tensor
        1-D tensor of length ``num_of_epochs`` with the mean mini-batch loss
        of each epoch.
    model : nn.Module
        The trained model.
    """
    model, loss_fun, optimizer = create_model(momentum)

    losses = torch.zeros(num_of_epochs)
    train_acc = []
    test_acc = []

    for epoch in range(num_of_epochs):
        model.train()

        batch_acc = []
        batch_loss = []
        for X, y in train_loader:
            yHat = model(X)
            loss = loss_fun(yHat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # store plain floats so np.mean never has to convert a list of tensors
            batch_loss.append(loss.item())
            batch_acc.append(_accuracy_pct(yHat, y))

        train_acc.append(np.mean(batch_acc))
        losses[epoch] = np.mean(batch_loss)

        # evaluate on one test batch without tracking gradients
        model.eval()
        X, y = next(iter(test_loader))
        with torch.no_grad():
            yHat = model(X)

        test_acc.append(_accuracy_pct(yHat, y))

    return train_acc, test_acc, losses, model

In [None]:
# Momentum values to compare; train_model is called once per value.
momenta = [0, .5, .9, .95, .99]

# results[epoch, momentum index, metric], metric: 0 = loss, 1 = train acc, 2 = test acc
results = np.zeros((num_of_epochs, len(momenta), 3))

for idx, mom in enumerate(momenta):
    train_acc, test_acc, losses, net = train_model(mom)
    for metric, series in enumerate((losses, train_acc, test_acc)):
        results[:, idx, metric] = series

In [None]:
# One panel per metric stored in the last axis of `results`; one line per momentum value.
fig, ax = plt.subplots(1,3, figsize=(12, 5))

panel_titles = ['Losses', 'Train', 'Test']
for i, title in enumerate(panel_titles):
    is_loss_panel = (i == 0)

    ax[i].plot(results[:, :, i])
    ax[i].legend(momenta)
    ax[i].set_xlabel('Epochs')
    ax[i].set_ylabel('Loss' if is_loss_panel else 'Accuracy (%)')
    if not is_loss_panel:
        # the two accuracy panels share a fixed y-range
        ax[i].set_ylim([20, 100])
    ax[i].set_title(title)

plt.show()