In [1]:
# Interactive matplotlib backend: figures are rendered as JavaScript objects in
# the classic Jupyter Notebook. NOTE(review): this backend is not available in
# JupyterLab, where `%matplotlib widget` (ipympl) is the equivalent.
%matplotlib notebook
# Calling the magic a second time is a workaround: it may prevent some
# graphics errors where the first call does not render figures correctly.
%matplotlib notebook 

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [3]:
from model import ResFFN
from data import TwoMoonsDataset
import train
import sklearn.mixture

In [4]:
# Seed the global torch and NumPy RNGs once, from a single named constant, so
# the whole notebook is reproducible. This assumes model initialisation and
# TwoMoonsDataset draw from these global RNGs (TODO confirm in model.py/data.py).
SEED = 1337
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train a model
First, we train a feed-forward network (FFN) with residual connections to encourage smoothness and sensitivity of the learned feature space, following the specifications of [[1]](https://arxiv.org/abs/2102.11582).

In [5]:
# Training configuration.
epochs = 150
batch_size = 128

# Residual feed-forward network. The positional arguments are forwarded to
# model.ResFFN unchanged; presumably (input dim, number of classes, hidden
# width, number of residual blocks) -- TODO confirm against model.py.
model = ResFFN(2, 2, 128, 4)
model = model.cuda() if torch.cuda.is_available() else model

# Noisy two-moons data: 10,000 training points and 200 validation points.
data_train = TwoMoonsDataset(10000, noise=0.1)
data_val = TwoMoonsDataset(200, noise=0.1)

# train_model prints validation loss/accuracy after every epoch and returns two
# values, unpacked here as the loss and accuracy histories (format: see train.py).
loss_history, accuracy_history = train.train_model(
    model,
    data_train,
    data_val,
    epochs=epochs,
    batch_size=batch_size,
)

### Epoch 1 / 150
Validation loss 0.2708; Validation accuracy 0.8850
### Epoch 2 / 150
Validation loss 0.2072; Validation accuracy 0.9100
### Epoch 3 / 150
Validation loss 0.1968; Validation accuracy 0.9300
### Epoch 4 / 150
Validation loss 0.1534; Validation accuracy 0.9450
### Epoch 5 / 150
Validation loss 0.1255; Validation accuracy 0.9450
### Epoch 6 / 150
Validation loss 0.1201; Validation accuracy 0.9500
### Epoch 7 / 150
Validation loss 0.1341; Validation accuracy 0.9500
### Epoch 8 / 150
Validation loss 0.1088; Validation accuracy 0.9600
### Epoch 9 / 150
Validation loss 0.1037; Validation accuracy 0.9550
### Epoch 10 / 150
Validation loss 0.0964; Validation accuracy 0.9650
### Epoch 11 / 150
Validation loss 0.1002; Validation accuracy 0.9650
### Epoch 12 / 150
Validation loss 0.0798; Validation accuracy 0.9750
### Epoch 13 / 150
Validation loss 0.0863; Validation accuracy 0.9800
### Epoch 14 / 150
Validation loss 0.0917; Validation accuracy 0.9750
### Epoch 15 / 150
Validation

Validation loss 0.0570; Validation accuracy 0.9850
### Epoch 119 / 150
Validation loss 0.0615; Validation accuracy 0.9900
### Epoch 120 / 150
Validation loss 0.0543; Validation accuracy 0.9850
### Epoch 121 / 150
Validation loss 0.0523; Validation accuracy 0.9850
### Epoch 122 / 150
Validation loss 0.0545; Validation accuracy 0.9900
### Epoch 123 / 150
Validation loss 0.0533; Validation accuracy 0.9900
### Epoch 124 / 150
Validation loss 0.0598; Validation accuracy 0.9900
### Epoch 125 / 150
Validation loss 0.0455; Validation accuracy 0.9900
### Epoch 126 / 150
Validation loss 0.0520; Validation accuracy 0.9850
### Epoch 127 / 150
Validation loss 0.0492; Validation accuracy 0.9900
### Epoch 128 / 150
Validation loss 0.0687; Validation accuracy 0.9900
### Epoch 129 / 150
Validation loss 0.0519; Validation accuracy 0.9900
### Epoch 130 / 150
Validation loss 0.0528; Validation accuracy 0.9850
### Epoch 131 / 150
Validation loss 0.0526; Validation accuracy 0.9900
### Epoch 132 / 150
Valida

# Fit Gaussian Mixture Model
I found that estimating the covariance matrix of each class directly from the feature representations leads to singular matrices. While investigating this, I tried several workarounds:
- Use a Leaky ReLU to avoid hard zeros in the feature representations. Such zeros create zero-variance dimensions in the feature space, which make the covariance matrix singular.
- Apply no non-linearity when extracting features, to retain more variance.

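Finally, instead of estimating the mean and covariance of each class directly from the feature representations, I fitted a one-component GMM per class to obtain a Gaussian distribution for each class in feature space. This presumably avoids the singularity because `sklearn.mixture.GaussianMixture` adds a small regularization term, `reg_covar` (1e-6 by default), to the diagonal of each covariance matrix.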

In [6]:
# Embed the whole training set into feature space, with the model in eval mode
# and autograd disabled. Afterwards `z` holds the feature representations and
# `y` the matching labels, both as NumPy arrays in loader order.
model.eval()
zs, ys = [], []
data_loader_train = torch.utils.data.DataLoader(
    data_train, batch_size=128, drop_last=False
)
with torch.no_grad():
    for x, y in tqdm(data_loader_train):
        x = x.float().cuda() if torch.cuda.is_available() else x.float()
        zs.append(model(x, return_features=True))
        ys.append(y)
# Concatenate per-batch results and bring them back to host memory as NumPy.
z = torch.cat(zs).detach().cpu().numpy()
y = torch.cat(ys).detach().cpu().numpy()

100%|██████████| 79/79 [00:00<00:00, 341.27it/s]


In [7]:
# One single-component, full-covariance Gaussian per class, fitted on that
# class's training features, plus the empirical class prior pi_c = N_c / N.
# reg_covar=1e-6 is sklearn's default, written out explicitly because the small
# diagonal ridge it adds is presumably what keeps the per-class covariance
# estimates non-singular (see the markdown above).
components = {
    label: sklearn.mixture.GaussianMixture(
        n_components=1, covariance_type='full', reg_covar=1e-6
    ).fit(z[y == label])
    for label in np.unique(y)
}
mixture_coefs = {label: (y == label).mean() for label in components}

In [8]:
# Evaluation grid over the input plane: a MESH_RESOLUTION x MESH_RESOLUTION
# lattice spanning [-MESH_LIMIT, MESH_LIMIT] on both axes. Every grid point is
# later pushed through the model to get its feature representation.
MESH_LIMIT = 3.0
MESH_RESOLUTION = 100
MESH_BATCH_SIZE = 128

X, Y = np.meshgrid(
    np.linspace(-MESH_LIMIT, MESH_LIMIT, MESH_RESOLUTION),
    np.linspace(-MESH_LIMIT, MESH_LIMIT, MESH_RESOLUTION),
)
# One (x, y) row per grid point, in row-major (ravel) order, so per-point
# results can be reshaped back to X.shape.
XX = np.column_stack([X.ravel(), Y.ravel()])
# torch.tensor with an explicit dtype instead of the legacy torch.Tensor(...)
# constructor; float32 matches what the legacy constructor produced by default.
data_mesh = torch.utils.data.TensorDataset(torch.tensor(XX, dtype=torch.float32))
data_loader_mesh = torch.utils.data.DataLoader(
    data_mesh, batch_size=MESH_BATCH_SIZE, drop_last=False
)

In [9]:
# Embed every mesh point into feature space. model.eval() is called here as
# well, so this cell does not silently depend on the train/eval mode left
# behind by earlier cells (e.g. after re-running the training cell).
model.eval()
zs_mesh = []
with torch.no_grad():
    for (x,) in tqdm(data_loader_mesh):
        x = x.float()
        if torch.cuda.is_available():
            x = x.cuda()
        zs_mesh.append(model(x, return_features=True))
# One feature row per mesh point, in the same (row-major) order as the mesh.
z_mesh = torch.cat(zs_mesh).detach().cpu().numpy()

100%|██████████| 79/79 [00:00<00:00, 253.32it/s]


In [10]:
# Log-density of each mesh point under the class-conditional Gaussian mixture:
#   log p(z) = log sum_c pi_c * N(z | mu_c, Sigma_c)
#            = logsumexp_c [ log pi_c + log N(z | mu_c, Sigma_c) ]
# GaussianMixture.score_samples returns per-sample log-likelihoods
# log N(z | mu_c, Sigma_c), so the per-class terms must be combined in log
# space. Note: a weighted sum  sum_c pi_c * log N(...)  is NOT the log of the
# mixture density. np.logaddexp.reduce performs a numerically stable logsumexp.
scores = np.logaddexp.reduce(
    np.stack([
        np.log(mixture_coefs[label]) + components[label].score_samples(z_mesh)
        for label in components
    ]),
    axis=0,
)
# Back to the (rows, cols) layout of the mesh for plotting.
scores = scores.reshape(X.shape)

In [11]:
# Plot a rescaled version of the density score over the input-space mesh.
# DENSITY_SCALE multiplies `scores` before exponentiating; presumably chosen to
# compress the dynamic range for display -- tune as needed.
DENSITY_SCALE = 5e-3
fig, ax = plt.subplots()
# shading='nearest' centres one quad on each (X, Y) grid point. With 'flat'
# shading and X, Y, C all of the same shape, matplotlib drops the last row and
# column of C (deprecated since matplotlib 3.3).
quad = ax.pcolormesh(X, Y, np.exp(DENSITY_SCALE * scores), shading='nearest')
fig.colorbar(quad, ax=ax, label=f'exp({DENSITY_SCALE:g} * score)')
ax.set(xlabel='$x_1$', ylabel='$x_2$', title='GMM feature-density score on input mesh');

<IPython.core.display.Javascript object>

  """Entry point for launching an IPython kernel.


<matplotlib.collections.QuadMesh at 0x7fc0b8b35908>