# Training a classifier

In this notebook we train a simple classifier on the two-moons dataset.

In [None]:
%load_ext autoreload
%autoreload 2
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [None]:
# Set to True to draw a fresh sample with make_moons instead of loading the stored one.
resample = False
if resample:
    X, Y = make_moons(noise=0.1, random_state=0, n_samples=1000)
else:
    # np.load on an .npz returns an NpzFile that holds the file open; the context
    # manager closes it. Indexing reads each array fully into memory first.
    with np.load('data-weights/two_moons.npz') as data:
        X = data['X']
        Y = data['Y']

# visualize data: one scatter layer per class label (0 -> red, 1 -> blue)
for i in [0, 1]:
    plt.scatter(
        X[Y == i, 0],
        X[Y == i, 1],
        color=['Red', 'Blue'][i],
        alpha=0.2,
        label="Class " + str(i)
    )
plt.xlim(-1.5, 2.5)
plt.ylim(-1., 1.5)
plt.legend()
plt.tight_layout()
plt.savefig('results/two_moons.png')

In [None]:
# Set to True to (over)write the stored dataset. The path is the same one the
# data-loading cell reads, so a resampled dataset is picked up on the next run.
save_data = False
if save_data:
    np.savez('data-weights/two_moons.npz', X=X, Y=Y)

# Define torch dataset and loader

In [None]:
# Both features and labels are cast to float32 (labels too, since they are
# compared against the network output by the loss).
features = torch.tensor(X, dtype=torch.float32)
targets = torch.tensor(Y, dtype=torch.float32)
dataset = TensorDataset(features, targets)
# Mini-batches of 100; DataLoader does not shuffle by default, so batches come
# in stored order every epoch.
loader = DataLoader(dataset, batch_size=100)

# Define Device

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Operating on device: ' + device)

# Define neural model

We define a simple neural network. Its structure is copied from the ["A 'Hello World' for PyTorch"](https://seanhoward.me/blog/2022/hello_world_pytorch/) tutorial by Sean T. Howard.

In [None]:
from model import get_two_moons_model

# Name of the activation passed to the model factory; it also appears in the
# figure and weight-file names written later in the notebook.
act_fun = 'ReLU'
model = get_two_moons_model(act_fun=act_fun)
# nn.Module.to moves parameters in place and returns the module itself.
model = model.to(device)
model

# Train network

In [None]:
loss = nn.MSELoss()
opt = torch.optim.Adam(model.parameters())
epochs = 100

for epoch in range(epochs):
    # Accumulates the per-batch mean-squared errors over one pass of the data.
    epoch_loss = 0
    for xb, yb in loader:
        xb = xb.to(device)
        yb = yb.to(device)
        opt.zero_grad()
        # yb[:, None] adds a trailing axis to the targets; assumes the model
        # outputs shape (batch, 1) — TODO confirm against model.py.
        batch_loss = loss(model(xb), yb[:, None])
        batch_loss.backward()
        opt.step()
        epoch_loss += batch_loss.item()
    print('Loss: ' + str(epoch_loss))

# Check output

In [None]:
# Evaluate the network on a regular 100x100 grid spanning the plotted region and
# draw its output as a filled contour map.
Xgrid, Ygrid = torch.meshgrid(
    torch.linspace(-1.5, 2.5, 100),
    torch.linspace(-1, 1.5, 100),
    indexing='ij',
)
inp = torch.stack([Xgrid.ravel(), Ygrid.ravel()], dim=1).to(device)
# Inference only: no_grad skips building an autograd graph for the grid points,
# so the result can be converted to numpy without .detach().
with torch.no_grad():
    Z = model(inp).reshape(Xgrid.shape)

plt.contourf(Xgrid.numpy(), Ygrid.numpy(), Z.cpu().numpy(), cmap='coolwarm_r', levels=100, alpha=1)
plt.xlim(-1.5, 2.5)
plt.ylim(-1., 1.5)
plt.tight_layout()
# Explicit extension, consistent with the dataset figure saved earlier.
plt.savefig('results/netvis_' + act_fun + '.png')

# Save weights

In [None]:
# Flip to True to overwrite the stored checkpoint for this activation with the
# current model parameters.
save_weights = False
if save_weights:
    weights_path = 'data-weights/two_moons_' + act_fun + '.pt'
    torch.save(model.state_dict(), weights_path)

# Load weights

In [None]:
# When enabled, replaces the parameters currently held by `model` (including any
# trained earlier in this session) with the stored checkpoint for `act_fun`.
load_weights = True
if load_weights:
    checkpoint_path = 'data-weights/two_moons_' + act_fun + '.pt'
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)