In [None]:
import sys

import matplotlib.pyplot as plt 
%matplotlib inline  
import numpy as np
import scipy.stats # for creating a simple dataset 
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Register the parent directory on the module search path. The membership check keeps
# the entry unique when this cell is re-run. The `dataset` and `stg` imports that follow
# presumably resolve from this directory.
stg_path = '../'
already_registered = stg_path in sys.path
if not already_registered:
    sys.path.append(stg_path)

from dataset import create_twomoon_dataset
from stg import STG

In [None]:
# Only the first `sub_train_size` training images are kept for fitting.
sub_train_size = 600

# Both splits share root='data'; download=True is requested on the training split only.
traindt = MNIST(root='data', train=True, transform=ToTensor(), download=True)
testdt = MNIST(root='data', train=False, transform=ToTensor())

# Flatten each image into one feature vector per sample.
# NOTE(review): `.data` is read directly, so the ToTensor() transform passed above is
# presumably never applied to these tensors (raw pixel values) — confirm this is intended.
X_train = traindt.data[:sub_train_size].flatten(start_dim=1)
y_train = traindt.targets[:sub_train_size]
X_test = testdt.data.flatten(start_dim=1)
y_test = testdt.targets

# Shown as the cell output: a quick sanity check of the four shapes.
X_train.shape, y_train.shape, X_test.shape, y_test.shape

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
args_cuda = torch.cuda.is_available()
device = torch.device("cuda") if args_cuda else torch.device("cpu")
feature_selection = True

# STG classifier over the flattened pixels (one input per pixel, ten output classes).
model = STG(
    task_type='classification',
    input_dim=X_train.shape[1],
    output_dim=10,
    hidden_dims=[50, 50],
    activation='none',
    optimizer='SGD',
    learning_rate=0.001,
    batch_size=128,
    feature_selection=feature_selection,
    sigma=1,
    lam=5,
    random_state=1,
    device=device,
    extra_args={'gating_net_hidden_dims': 2000},
    dropout=0.2,
)

In [None]:
epochs = 100
print_interval = 10

# Fit on the MNIST subset; the test split is also passed to fit() as the validation set.
model.fit(
    X_train, y_train,
    nr_epochs=epochs,
    valid_X=X_test, valid_y=y_test,
    print_interval=print_interval,
    is_tensor_input=True,
)

# Accuracy in percent. Taking the mean of the per-sample hit/miss array keeps both
# numbers on the same 0-100 scale regardless of how many samples each split holds
# (the previous hard-coded divisors 600 and 100 put the two values on different scales).
train_acc = 100 * np.mean(y_train.cpu().numpy() == model.predict(X_train))
test_acc = 100 * np.mean(y_test.cpu().numpy() == model.predict(X_test))

print(f'train accuracy: {train_acc:.2f}% test accuracy: {test_acc:.2f}%')

In [None]:
# Forward the first ten training images through the model's MLP. The batch is moved to
# `device` (the device the model was constructed with) instead of unconditionally to
# CUDA, so the cell also runs on CPU-only machines.
out = model._model.mlp(X_train[:10].float().to(device))

In [None]:
# Loss of the softmax-ed MLP outputs against the first ten training labels. Labels are
# moved to `device` rather than hard-coded CUDA so the cell works without a GPU.
# NOTE(review): if `loss` expects raw logits, applying softmax first would normalise
# twice — confirm against the STG model definition.
model._model.loss(model._model.softmax(out), y_train[:10].to(device))

In [None]:
# Accuracy in percent for both splits. The mean of the per-sample hit/miss array is
# independent of the split size, so the two figures share the same 0-100 scale
# (the previous hard-coded divisors 600 and 100 did not guarantee that).
train_acc = 100 * np.mean(y_train.cpu().numpy() == model.predict(X_train))
test_acc = 100 * np.mean(y_test.cpu().numpy() == model.predict(X_test))

print(f'train accuracy: {train_acc:.2f}% test accuracy: {test_acc:.2f}%')

In [None]:
# Gate probabilities for every test image, reshaped to one 28x28 map per image.
# The input is moved to `device` instead of hard-coded CUDA so the cell also runs on CPU.
prob = model._model.get_gates('prob', X_test.float().to(device)).reshape(-1, 28, 28)

## Per-digit examples and distributions

In [None]:
for i in range(10):
    # One figure per digit class, laid out as a single row of five panels.
    f, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(25, 4))
    f.suptitle(f'Data distribution for digit {i}', fontsize=16)

    digit_mask = y_test == i
    digit_gates = prob[digit_mask]

    # Panel 3: gate probability averaged over all test images of this digit.
    mean_gates = digit_gates.mean(0)
    im = ax3.imshow(mean_gates, interpolation='None')
    f.colorbar(im, ax=ax3)
    ax3.set_title('Mean prob')

    # Panel 1: histogram of the averaged probabilities that exceed 0.001.
    ax1.hist(mean_gates[mean_gates > 0.001].reshape(-1))
    ax1.set_title('Mean probability > 0.001 hist')

    # Panel 2: per image, the number of pixels whose gate probability exceeds 0.001.
    active_pixels = (digit_gates > 0.001).sum(1).sum(1)
    ax2.hist(active_pixels)
    ax2.set_title('Number of point > 0.001 hist')

    # Panels 4 and 5: the first test image of this digit and the gates obtained for it.
    ax4.imshow(X_test[digit_mask][0].reshape(28, 28))
    ax4.set_title('Sample digit')
    ax5.imshow(digit_gates[0].reshape(28, 28))
    ax5.set_title('Sample gates')

## Average gate probability

In [None]:
# Per-pixel gate probability averaged over all test images. The mean (rather than the
# sum) matches the section title and keeps the colour scale within the range of a
# single probability; the colour bar makes that scale visible.
fig, ax = plt.subplots()
im = ax.imshow(prob.mean(0))
fig.colorbar(im, ax=ax)
ax.set_title('Average gate probability');

## Some experiments

In [None]:
# Output of `FeatureSelector.net.mlp[0]` (the first module of that MLP) for the first
# 124 test images, as a NumPy array. The input is moved to `device` instead of
# hard-coded CUDA so the cell also runs on CPU-only machines.
aa = model._model.FeatureSelector.net.mlp[0](X_test[:124].float().to(device)).detach().cpu().numpy()

In [None]:
# Split the activations by label: rows belonging to digit 0, and rows of any other
# digit truncated to at most the same number of rows.
zero_mask = y_test[:124] == 0
similar = aa[zero_mask]
different = aa[~zero_mask][:len(similar)]
len(similar), len(different)

In [None]:
from scipy.spatial import distance_matrix
# Pairwise Minkowski distances with p=1 between activation rows:
# d1 compares the digit-0 rows with themselves, d2 compares them with the other digits.
d1, d2 = (distance_matrix(similar, other, p=1) for other in (similar, different))

In [None]:
# d1 is a self-distance matrix, so its diagonal is all zeros and its plain mean is
# scaled down by (n - 1) / n relative to the mean over distinct pairs. The same factor
# is applied to d2 here, presumably so that the two numbers are directly comparable.
n_rows = len(d2)
d1.mean(), d2.mean() * (n_rows - 1) / n_rows