In [None]:
import sys

import matplotlib.pyplot as plt 
%matplotlib inline  
import numpy as np
import scipy.stats # for creating a simple dataset 
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Directory (relative to this notebook) that holds the `dataset` and `stg` modules.
stg_path = '../'

# Register it on the import path once, so re-running the cell does not add duplicates.
if sys.path.count(stg_path) == 0:
    sys.path.append(stg_path)

from dataset import create_twomoon_dataset
from stg import STG

In [None]:
# Only the first `sub_train_size` training images are kept for this experiment.
sub_train_size = 50

# Training split, downloaded into ./data when it is not already present.
traindt = MNIST(root='data', train=True, transform=ToTensor(), download=True)
# Test split, read from the same root directory.
testdt = MNIST(root='data', train=False, transform=ToTensor())

# Flatten every image into a single row of pixels.
# NOTE(review): `.data` / `.targets` are read directly, so the `ToTensor` transform is
# presumably bypassed and pixel values stay in their stored dtype — confirm if scaling matters.
X_train = traindt.data.reshape(len(traindt.data), -1)[:sub_train_size]
y_train = traindt.targets[:sub_train_size]
X_test = testdt.data.reshape(len(testdt.data), -1)
y_test = testdt.targets

# Show the four shapes as a quick sanity check.
X_train.shape, y_train.shape, X_test.shape, y_test.shape

In [None]:
# Positions of every digit 0..9 inside the train / test label tensors.
train_indices = [torch.nonzero(y_train == digit).reshape(-1) for digit in range(10)]
test_indices = [torch.nonzero(y_test == digit).reshape(-1) for digit in range(10)]

# Size of the rarest digit in each split.
min_train = min(len(positions) for positions in train_indices)
min_test = min(len(positions) for positions in test_indices)

# Truncate every digit to that size so all digits contribute equally many samples.
new_train_ind = [positions[:min_train] for positions in train_indices]
new_test_ind = [positions[:min_test] for positions in test_indices]

In [None]:
# Each tuple lists the digits whose samples are interleaved into one dataset.
number_sets = [(1,1,1), (2,2,2)]

data_sets = []
total_lengths = {'train':0, 'test':0}

for digits in number_sets:
    # Stacking along dim=1 interleaves the per-digit index vectors:
    # first sample of every listed digit, then the second sample of every digit, and so on.
    train_ind_tmp = torch.stack([new_train_ind[d] for d in digits], dim=1).reshape(-1)
    test_ind_tmp = torch.stack([new_test_ind[d] for d in digits], dim=1).reshape(-1)

    data_sets.append({
        'X_train': X_train[train_ind_tmp],
        'y_train': y_train[train_ind_tmp],
        'X_test': X_test[test_ind_tmp],
        'y_test': y_test[test_ind_tmp],
    })

    # Running totals over all datasets.
    total_lengths['train'] += len(train_ind_tmp)
    total_lengths['test'] += len(test_ind_tmp)

print('Datasets train/test lengths')
for subset in data_sets:
    print(len(subset['y_train']), len(subset['y_test']))

In [None]:
# batch_size is passed to STG as `batch_size`; recurrent_split_dim is passed to STG and is also
# the multiple to which the train/test tensors are trimmed before fitting.
batch_size, recurrent_split_dim = 128, 8

# Passed to STG as `lam` and `dropout` respectively.
lam, dropout = 0.0, 0.1

In [None]:
# Every dataset is cut into `split_factor` consecutive chunks; the chunks are then emitted
# round-robin (chunk 0 of every dataset, chunk 1 of every dataset, ...).
split_factor = 4

X_train_chunks, y_train_chunks = [], []
X_test_chunks, y_test_chunks = [], []

for f in range(split_factor):
    for subset in data_sets:
        for t, X_chunks, y_chunks in (
            ('train', X_train_chunks, y_train_chunks),
            ('test', X_test_chunks, y_test_chunks),
        ):
            X_temp = subset['X_' + t]
            y_temp = subset['y_' + t]
            n_samples = len(y_temp)
            # Integer arithmetic keeps chunk boundaries exact: the end of chunk f is always
            # the start of chunk f + 1, so no sample is dropped or duplicated by float rounding.
            start = f * n_samples // split_factor
            end = (f + 1) * n_samples // split_factor
            X_chunks.append(X_temp[start:end])
            y_chunks.append(y_temp[start:end])

# Concatenate the chunks in emission order.
X_train_new = torch.cat(X_train_chunks)
y_train_new = torch.cat(y_train_chunks)
X_test_new = torch.cat(X_test_chunks)
y_test_new = torch.cat(y_test_chunks)


# Fix the lengths for recurrent neural network splitting: drop trailing samples so that
# every tensor's length is a multiple of `recurrent_split_dim`.
X_train_new, y_train_new, X_test_new, y_test_new = (
    tensor[: len(tensor) - len(tensor) % recurrent_split_dim]
    for tensor in (X_train_new, y_train_new, X_test_new, y_test_new)
)

In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
args_cuda = torch.cuda.is_available()
device = torch.device("cuda") if args_cuda else torch.device("cpu")
feature_selection = True

# Configuration without recurrent splitting (recurrent_split_dim=None).
# NOTE: the following cell rebinds `model`; this instance is only used if that cell is skipped.
model = STG(
    task_type='classification',
    input_dim=X_train.shape[1],
    output_dim=10,
    hidden_dims=[10],
    activation='none',
    optimizer='SGD',
    learning_rate=0.0001,
    batch_size=batch_size,
    feature_selection=feature_selection,
    sigma=1,
    lam=lam,
    random_state=1,
    device=device,
    extra_args={'gating_net_hidden_dims':2000},
    recurrent_split_dim=None,
    dropout=dropout,
)

In [None]:
# Use the GPU when torch can see one, otherwise stay on the CPU.
args_cuda = torch.cuda.is_available()
device = torch.device("cuda") if args_cuda else torch.device("cpu")
feature_selection = True

# Configuration with recurrent splitting enabled (recurrent_split_dim taken from the
# hyperparameter cell).
model = STG(
    task_type='classification',
    input_dim=X_train.shape[1],
    output_dim=10,
    hidden_dims=[10],
    activation='none',
    optimizer='SGD',
    learning_rate=0.001,
    batch_size=batch_size,
    feature_selection=feature_selection,
    sigma=1,
    lam=lam,
    random_state=1,
    device=device,
    extra_args={'gating_net_hidden_dims':2000},
    recurrent_split_dim=recurrent_split_dim,
    dropout=dropout,
)

In [None]:
# Tensors used for fitting and evaluation below: the interleaved, trimmed splits.
# To run on the plain subset instead, assign X_train, y_train, X_test, y_test here.
X_train_run, y_train_run = X_train_new, y_train_new
X_test_run, y_test_run = X_test_new, y_test_new

In [None]:
epochs = 1000
print_interval = 100

# Train on the selected tensors, validating on the matching test tensors.
model.fit(
    X_train_run,
    y_train_run,
    nr_epochs=epochs,
    valid_X=X_test_run,
    valid_y=y_test_run,
    print_interval=print_interval,
    is_tensor_input=True,
)

# Accuracy in percent: share of predictions equal to the labels.
train_acc = 100 * (y_train_run.cpu().numpy()==model.predict(X_train_run)).sum()/len(y_train_run)
test_acc = 100 *(y_test_run.cpu().numpy()==model.predict(X_test_run)).sum()/len(y_test_run)

print(f'train accuracy: {train_acc:.2f}% test accuracy: {test_acc:.2f}%')

if model.has_feature_selection:
    # Move the inputs to the device selected when the model was built instead of
    # unconditionally calling .cuda(), so the cell also runs on CPU-only machines.
    gate_inputs = X_test_run.float().to(device)
    # One 28x28 gate map per evaluated sample.
    prob = model._model.get_gates('prob', gate_inputs).reshape(-1,28,28)
    print(len(y_test_run), prob.mean(), prob.std())
    plt.imshow(prob.mean(0));

## Per-digit examples and distributions

In [None]:
# Gate probabilities at or below this value are ignored in the histograms.
gate_threshold = 0.001

for i in range(10):
    # `prob` was computed from X_test_run, so it has to be masked with the labels of that
    # same tensor (y_test_run); the full y_test has a different length and ordering.
    digit_mask = y_test_run == i
    if int(digit_mask.sum()) == 0:
        # The evaluation set may not contain every digit; there is nothing to plot then.
        print(f'No samples of digit {i} in the evaluation set, skipping')
        continue

    fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(25, 4))
    fig.suptitle(f'Data distribution for digit {i}',fontsize=16)

    filtered_prob = prob[digit_mask]
    # Per sample: number of pixels whose gate exceeds the threshold.
    counts = (filtered_prob > gate_threshold).sum(1).sum(1)

    # Mean gate map over all samples of this digit.
    num_prob = filtered_prob.mean(0)
    im = ax3.imshow(num_prob, interpolation='None')
    fig.colorbar(im, ax=ax3)
    ax3.title.set_text('Mean prob')

    # Histogram of the mean gate values above the threshold.
    num_prob = num_prob[num_prob > gate_threshold]
    ax1.hist(num_prob.reshape(-1));
    ax1.title.set_text('Mean probability > 0.001 hist')

    ax2.hist(counts)
    ax2.title.set_text('Number of point > 0.001 hist')

    # First sample of this digit and its gate map, taken from the same tensor so they correspond.
    ax4.imshow(X_test_run[digit_mask][0].reshape(28,28))
    ax4.title.set_text('Sample digit')
    ax5.imshow(filtered_prob[0].reshape(28,28))
    ax5.title.set_text('Sample gates')

## Average gate probability

In [None]:
# Gate maps summed over all evaluated samples. The sum differs from the mean only by a
# constant factor, so with imshow's default colour scaling the picture is the same.
gate_total = prob.sum(0)
plt.imshow(gate_total);

## Some experiments

In [None]:
# Output of the first module of the feature selector's MLP for the first 124 test images.
# Inputs are moved to the device chosen at model construction rather than calling .cuda()
# unconditionally, so the cell also works without a GPU.
first_layer = model._model.FeatureSelector.net.mlp[0]
probe_inputs = X_test[:124].float().to(device)
aa = first_layer(probe_inputs).detach().cpu().numpy()

In [None]:
# Labels of the same 124 test samples whose activations are stored in `aa`.
is_zero = y_test[:124] == 0

# Activations of digit 0 versus an equally sized group of other digits.
similar = aa[is_zero]
different = aa[~is_zero][:len(similar)]
len(similar), len(different)

In [None]:
from scipy.spatial import distance_matrix

# Pairwise Minkowski distances with p=1 (sum of absolute differences).
d1 = distance_matrix(similar, similar, p=1)    # digit 0 against digit 0
d2 = distance_matrix(similar, different, p=1)  # digit 0 against the other digits

In [None]:
# d1 compares a group with itself, so its diagonal is zero and drags its mean down.
# Presumably the (n - 1) / n factor rescales d2's mean to compensate — verify intent.
n_rows = len(d2)
(d1.mean(), d2.mean() * (n_rows - 1) / n_rows)