<center>
    <tr>
    <td><img src="images/Quansight_Logo_Lockup_1.png" width="25%"></img></td>
    </tr>
</center>

---
# PyTorch Classification with One Hidden Layer
---

## Lesson plan

We will construct a neural network to solve a binary classification problem. The network will consist of a single hidden layer.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pprint as pp
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn import datasets

## Generating a toy dataset

In [None]:
# Seed both libraries. make_moons below takes its own random_state, but the
# model's weight initialization and the DataLoader's shuffling draw from
# PyTorch's RNG, which np.random.seed does not affect.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
n_samples = 120
# Two interleaving half-moons; random_state pins the sample positions.
x, y = datasets.make_moons(n_samples=n_samples, random_state=0, noise=0.1)
# Shift each feature to zero mean so the moons sit around the origin.
x = x - x.mean(axis=0)

fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(x[:, 0], x[:, 1], c=y, cmap=cm.bwr)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

## Data utilities

Deep learning models are data intensive. In many cases a large fraction of the time is spent organizing data to support training deep neural networks. PyTorch's `torch.utils.data` module provides the `Dataset` and `DataLoader` classes for building data pipelines suited to deep network training.

### Constructing a `Dataset`

In [None]:
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset pairing rows of ``x`` with entries of ``y``.

    Each item is a dict with a float32 ``'feature'`` tensor built from
    ``x[idx]`` and a float32 ``'label'`` tensor of shape ``(1,)`` holding
    ``y[idx]``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        feature = torch.tensor(self.x[idx], dtype=torch.float32)
        # Wrap the scalar label in a length-1 array so each label has
        # shape (1,) and a batch of labels has shape (batch, 1).
        label = torch.tensor(np.array([self.y[idx]]), dtype=torch.float32)
        return {'feature': feature, 'label': label}

### Constructing a `DataLoader`

We use the `DataLoader` class to construct the batches needed during training.

In [None]:
dataset = MyDataset(x, y)
batch_size = 4
shuffle = True
# The data is tiny and already in memory, so load it in the main process.
# Worker processes only add start-up cost here. Under the "spawn" start
# method (macOS/Windows) they can also fail to unpickle a Dataset class
# that was defined inside a notebook.
num_workers = 0
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
for i_batch, samples in enumerate(dataloader):
    print(f'\nbatch# = {i_batch}')
    print('samples: ')
    pp.pprint(samples)
    break  # Only show the first batch; printing all of them is too verbose.

## Neural network model

In [None]:
class BinaryClassification(nn.Module):
    """Feed-forward binary classifier with one tanh hidden layer.

    The network maps ``input_size`` features to a single sigmoid output,
    interpreted as the probability that a sample belongs to class 1.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Number of units in the hidden layer.
    """

    def __init__(self, input_size=2, hidden_size=10):
        super().__init__()

        self.hidden = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.hidden_activation = nn.Tanh()

        # One output unit is enough for two classes: p(class 1) = sigmoid(z).
        self.output = nn.Linear(in_features=hidden_size, out_features=1, bias=True)
        self.output_activation = nn.Sigmoid()

    def forward(self, x):
        """Map ``x`` of shape ``(..., input_size)`` to probabilities of shape ``(..., 1)``."""
        h = self.hidden_activation(self.hidden(x))
        return self.output_activation(self.output(h))

### Model summary

In [None]:
# Build an untrained instance with the default sizes purely to display
# its layer structure.
dummy = BinaryClassification()
print(repr(dummy))

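## Loss

Binary cross-entropy loss, summed over the samples in a batch:

$$\mathcal{L} = -\sum_i \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]$$

where $p_i$ is the predicted probability of class 1 and $y_i \in \{0, 1\}$ is the label.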

In [None]:
class MyLoss(nn.Module):
    """Binary cross-entropy loss summed over all elements.

    ``predictions`` are expected to be probabilities in [0, 1] (e.g. sigmoid
    outputs). ``targets`` hold 0/1 labels and must have the same shape.

    Args:
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before taking
            logs, so a saturated prediction cannot produce an infinite loss
            or a NaN gradient.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, predictions, targets):
        """Return ``-sum(t*log(p) + (1-t)*log(1-p))`` as a scalar tensor."""
        p = predictions.clamp(min=self.eps, max=1.0 - self.eps)
        # torch.where needs a boolean mask (uint8 masks are deprecated).
        # Both branches are evaluated, so the clamp is also what keeps the
        # unselected branch from feeding NaN into the backward pass.
        log_probs = torch.where(targets.bool(), torch.log(p), torch.log1p(-p))
        return -torch.sum(log_probs)

## Accuracy

Count how many predictions are correct. A predicted probability above 0.5 is treated as class 1, otherwise as class 0.

In [None]:
def accuracy(predictions, targets):
    """Return the number of predictions that match their targets.

    A prediction above 0.5 counts as class 1 and anything else as class 0.
    Both tensors are moved to the CPU before comparison. The count is
    returned as a Python ``int``.
    """
    predicted = (predictions.cpu() > 0.5).float()
    matches = predicted == targets.cpu()
    return int(matches.sum())

## Training

In [None]:
# Prefer the first GPU when one is visible to PyTorch; otherwise use the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'Using device: {device}')
if use_cuda:
    print(torch.cuda.get_device_name(0))

In [None]:
learning_rate = 0.01

# Model and loss live on the selected device; plain SGD updates the weights.
model = BinaryClassification()
model = model.to(device)
criterion = MyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
dataset = MyDataset(x, y)
batch_size = 4
shuffle = True
# Load batches in the main process. Each epoch re-spawns the workers
# (persistent_workers is off by default), and over 1000 epochs that start-up
# cost dwarfs the work of slicing this tiny in-memory dataset. Workers can
# also fail to unpickle a notebook-defined Dataset under "spawn" (macOS/Windows).
num_workers = 0
training_sample_generator = DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       num_workers=num_workers)

In [None]:
num_epochs = 1000
for epoch in range(num_epochs):
    n = 0  # samples classified correctly during this epoch
    epoch_error = 0.0  # loss summed over every batch of this epoch
    for samples in training_sample_generator:
        features = samples['feature'].to(device)
        targets = samples['label'].to(device)
        predictions = model(features)
        error = criterion(predictions, targets)
        # Accuracy is measured on the predictions made before this batch's
        # parameter update.
        n += accuracy(predictions, targets)
        epoch_error += error.item()
        optimizer.zero_grad()
        error.backward()
        optimizer.step()

    # Report the whole-epoch loss; `error` by itself is only the last batch's loss.
    if epoch % 100 == 0:
        print(f'epoch={epoch:03}. error={epoch_error:<7.4}. accuracy={n}')

    # If we have achieved 99% accuracy, then stop.
    if n > .99 * n_samples:
        break

print(f'epoch={epoch:03}. error={epoch_error:<7.4}. accuracy={n}')

## Results

Colors represent whether or not points are classified correctly.

In [None]:
# Run the trained model on every sample; inference needs no gradients.
with torch.no_grad():
    features_all = torch.tensor(x, dtype=torch.float32).to(device)
    prob_of_one = model(features_all).cpu().numpy().flatten()
predicted_labels = (prob_of_one > 0.5).astype(float)

# Color 1 represents a correct classification, 0 otherwise.
colors = np.where(predicted_labels == y, 1, 0)

# Use a distinct name: assigning to `accuracy` would shadow the accuracy()
# helper and break the training loop if that cell is re-run.
num_correct = np.sum(colors)
print(f'accuracy = {num_correct}')

In [None]:
# Scatter the samples colored by whether they were classified correctly.
fig, ax = plt.subplots(figsize=(7, 7))
ax.set_title('Classification results')
ax.scatter(x[:, 0], x[:, 1], c=colors, cmap=cm.bwr)
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

## Visualizing results

In [None]:
# Evaluate the model on a regular grid covering the plotting window.
xcoord = np.linspace(-3, 3)
ycoord = np.linspace(-3, 3)
xx, yy = np.meshgrid(xcoord, ycoord)
# One row per grid point: (x1, x2) pairs in row-major grid order.
grid = torch.tensor(np.column_stack([xx.ravel(), yy.ravel()]), dtype=torch.float32)

with torch.no_grad():
    m = model(grid.to(device))
    # Reshape to the grid's own shape instead of a hard-coded (50, 50).
    mm = m.cpu().numpy().reshape(xx.shape)

    x_try = torch.tensor(x, dtype=torch.float32)
    y_try = model(x_try.to(device))
    yy_try = (y_try.squeeze() > 0.5).cpu().numpy()

plt.figure(figsize=(7,7))
extent = -3, 3, -3, 3
prob_image = plt.imshow(mm, cmap=cm.BuGn, interpolation='bilinear', extent=extent, alpha=.5, origin='lower')
plt.scatter(x[:,0], x[:,1], c=yy_try, cmap=cm.viridis)
# Attach the colorbar to the probability map. Without an explicit mappable
# it would describe the most recent artist, the 0/1 scatter of predicted labels.
plt.colorbar(prob_image)
plt.xlabel('$x_1$')
plt.ylabel('$x_2$')
plt.xlim(-3,3)
plt.ylim(-3,3)
plt.title('Classification results')

<center>
    <tr>
    <td><img src="images/Quansight_Logo_Lockup_1.png" width="25%"></img></td>
    </tr>
</center>