In [None]:
import sys

import matplotlib.pyplot as plt 
%matplotlib inline  
import seaborn as sns
import numpy as np
import scipy.stats # for creating a simple dataset 
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from torch.nn.functional import one_hot
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.transforms import ToTensor

from hungarian_algorithm import algorithm
from scipy.optimize import linear_sum_assignment

# Put the parent directory on the import path (only once) so that the local
# `dataset` and `stg` imports that follow can be resolved.
stg_path = '../'
if all(entry != stg_path for entry in sys.path):
    sys.path.append(stg_path)

from dataset import create_twomoon_dataset
from stg import STG, train_net_to_output_float

In [None]:
# Raw MNIST splits, downloaded into ./data on first use. The model inputs below are
# built straight from `.data`, so the `ToTensor()` transform is not applied to them.
traindt = MNIST(root='data', train=True, transform=ToTensor(), download=True)
testdt = MNIST(root='data', train=False, transform=ToTensor())

# Normalise with (0.1307, 0.3081) (presumably the usual MNIST mean / std) and pad
# 2 pixels on every side (28x28 -> 32x32).
tf = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,)), transforms.Pad(2)])

# Scale to [0, 1], add a channel axis at dim 1, then normalise + pad.
X_train = tf((traindt.data / 255).float().unsqueeze(1))
X_test = tf((testdt.data / 255).float().unsqueeze(1))
y_train, y_test = traindt.targets, testdt.targets

X_train.shape, y_train.shape, X_test.shape, y_test.shape

In [None]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
args_cuda = torch.cuda.is_available()
device = torch.device('cuda') if args_cuda else torch.device('cpu')

In [None]:
# STG model for the 'encoding_unet' task with the gating (feature-selection) network enabled.
# NOTE(review): input_dim is X_train.shape[1], i.e. the channel axis (size 1) of the
# (N, 1, 32, 32) inputs built above — confirm this is what STG expects.
feature_selection = True
model = STG(
    task_type='encoding_unet',
    input_dim=X_train.shape[1],
    output_dim=20,
    hidden_dims=32,
    activation='none',
    optimizer='SGD',
    learning_rate=0.01,
    batch_size=128,
    feature_selection=feature_selection,
    sigma=1,
    lam=0.001,
    random_state=1,
    device=device,
    extra_args={
        'gating_net_hidden_dims': [50],
        'noise_sigma': 0,
        'lam_sim': 0.5,
    },
)

In [None]:
# Short handle on the gating sub-network inside the STG model; it is probed and pre-trained below.
gating_net = model._model.FeatureSelector


In [None]:
# Probe the gating network with random inputs and report the mean output together with
# the fraction of outputs that are positive / exactly zero / negative.
rand_gating_input = torch.randn((10, 64), device=device)
res_gating = gating_net(rand_gating_input)[0]
# Use the actual element count instead of the hard-coded 640 (= 10 * 64). The fractions
# are then correct whatever `[0]` selects (a tuple element or the first row).
n_gate_outputs = res_gating.numel()
print('mean output', res_gating.mean().item(), '>0 =0 <0 percentage',
      (res_gating > 0).sum().item() / n_gate_outputs,
      (res_gating == 0).sum().item() / n_gate_outputs,
      (res_gating < 0).sum().item() / n_gate_outputs)

In [None]:
# Two successive runs of `train_net_to_output_float` on the gating network (target value
# 0.9), differing only in the last positional argument.
# NOTE(review): the meaning of (32, 64), 10000 and 0.1 / 0.01 is defined in
# `stg.train_net_to_output_float` (presumably input shape, steps, learning rate) — confirm.
for final_arg in (0.1, 0.01):
    train_net_to_output_float(gating_net, 0.9, device, (32, 64), 10000, final_arg)

In [None]:
# Re-probe the gating network on the same random inputs after pre-training: mean output
# and fraction of outputs that are positive / exactly zero / negative.
res_gating = gating_net(rand_gating_input)[0]
# Use the actual element count instead of the hard-coded 640 (= 10 * 64).
n_gate_outputs = res_gating.numel()
print('mean output', res_gating.mean().item(), '>0 =0 <0 percentage',
      (res_gating > 0).sum().item() / n_gate_outputs,
      (res_gating == 0).sum().item() / n_gate_outputs,
      (res_gating < 0).sum().item() / n_gate_outputs)

In [None]:
# One epoch of training. No targets are passed (y=None); the test split is supplied as
# validation data.
model.fit(
    X_train,
    None,
    nr_epochs=1,
    valid_X=X_test,
    valid_y=y_test,
    print_interval=1,
    is_tensor_input=True,
)

In [None]:
# Visual check of the model on the first few test digits. Columns: prediction, model input
# (input + noise), clean input, |prediction - clean|, |model input - clean|.
sample_size = 5
noise_scale = 0  # 0 -> feed the clean images
aaa = X_test[0:sample_size]
aaa_noise = torch.randn_like(aaa) * noise_scale + aaa
res = model.predict(aaa_noise)

# squeeze=False keeps `axs` 2-D even when sample_size == 1.
fig, axs = plt.subplots(sample_size, 5, squeeze=False)
fig.set_size_inches(20, 4 * sample_size)

for i in range(sample_size):
    axs[i, 0].imshow(res[i].squeeze())
    # Select channel 0 explicitly: `.squeeze()[i]` would also drop the batch axis
    # when sample_size == 1 and then index a single image row.
    axs[i, 1].imshow(aaa_noise[i, 0])
    axs[i, 2].imshow(aaa[i, 0])
    axs[i, 3].imshow(np.abs(res[i] - aaa[i].numpy())[0])
    axs[i, 4].imshow(torch.abs(aaa_noise[i] - aaa[i])[0])

In [None]:
# Gate probabilities for the first `r` test samples; a gate counts as open when > 0.
r = 50
# `.to(device)` instead of `.cuda()` so the cell also runs on CPU-only machines.
test_res = model._model.get_gates("prob", X_test[:r].to(device))
vals = test_res > 0
# For every digit, list the open gates of each of its samples among the first `r`.
for k in range(10):
    print(k)
    for i in range(r):
        if y_test[i] == k:
            print(y_test[i], vals[i].nonzero())

In [None]:
# Number of the first `r` test samples in which each gate is open. The gate count comes
# from `vals` instead of the hard-coded 64.
plt.bar(np.arange(vals.shape[1]), vals.astype(int).sum(0))

In [None]:
def similarity_loss(x, eps=1e-8):
    """Pairwise cosine-similarity matrix between the rows of ``x``.

    Despite the name, this returns the full similarity matrix, not a scalar loss; the
    callers in this notebook plot it or aggregate it themselves.

    Args:
        x: 2-D tensor of shape (n, d); each row is one vector (e.g. one sample's gates).
        eps: lower bound applied to every row norm. An all-zero row therefore gets a
            similarity of 0 with everything instead of NaN (0 / 0).

    Returns:
        (n, n) tensor (same dtype as ``x``) whose entry (i, j) is cos(x[i], x[j]).
        The matrix is symmetric, with 1 on the diagonal for non-zero rows.
    """
    row_norms = torch.linalg.norm(x, dim=1, keepdim=True).clamp(min=eps)  # (n, 1)
    # Broadcasting (n, 1) * (1, n) gives the outer product of the row norms.
    return (x @ x.T) / (row_norms * row_norms.T)

In [None]:
# Cosine similarity between the gate vectors of the first `r` test samples.
yy = similarity_loss(torch.from_numpy(test_res))
plt.imshow(yy)

In [None]:
# Binary gate pattern (prob > 1e-5 -> 1) for every test sample, computed batch by batch and
# stacked into one (n_test, n_gates) int array.
test_data_loader = model.get_dataloader(X_test, y_test, False, True)
all_res = np.vstack([
    # `.to(device)` instead of `.cuda()` so this also runs without a GPU.
    (model._model.get_gates("prob", batch['input'].to(device)) > 1e-5).astype(int)
    for batch in test_data_loader
])

In [None]:
from sklearn.cluster import AgglomerativeClustering

# Cluster the 0/1 gate patterns into 10 groups using complete linkage under the L1
# distance (for 0/1 vectors this counts differing gates).
# scikit-learn renamed `affinity` to `metric` in 1.2 and removed `affinity` in 1.4:
# try the new keyword first and fall back for older installations.
try:
    cl = AgglomerativeClustering(n_clusters=10, metric="l1", linkage="complete")
except TypeError:
    cl = AgglomerativeClustering(n_clusters=10, affinity="l1", linkage="complete")
cl.fit(all_res)

In [None]:
# Histogram of cluster sizes (one bin per cluster).
plt.hist(x=cl.labels_, bins=cl.n_clusters);

In [None]:
def get_matches(a1, a2, n_a1=None, n_a2=None):
    """Best one-to-one matching between two labelings of the same samples.

    Builds the contingency matrix ``scores_m[i, j]`` = number of positions ``k`` with
    ``a1[k] == i`` and ``a2[k] == j``, then solves the maximum-weight assignment on it
    (e.g. to map cluster ids onto digit labels).

    Args:
        a1: 1-D integer labels (e.g. cluster ids).
        a2: 1-D integer labels aligned with ``a1`` (e.g. true classes).
        n_a1: number of ``a1`` label values; defaults to ``a1.max() + 1``.
        n_a2: number of ``a2`` label values; defaults to ``a2.max() + 1``.

    Returns:
        res: list of ``((str(i), "a" + str(j)), count)`` for every matched pair, i.e. the
            list format previously obtained from ``hungarian_algorithm``.
        total: sum of the matched counts (samples that agree under the matching).
        scores_m: the ``(n_a1, n_a2)`` float contingency matrix.
    """
    a1 = np.asarray(a1).ravel()
    a2 = np.asarray(a2).ravel()
    if n_a1 is None:
        n_a1 = max(int(a1.max()) + 1, 0) if a1.size else 0
    if n_a2 is None:
        n_a2 = max(int(a2.max()) + 1, 0) if a2.size else 0

    # Count only positions present in both arrays and labels inside the table. This matches
    # the former intersection of per-label index sets, in O(n) instead of O(n_a1*n_a2*n log n).
    n = min(a1.size, a2.size)
    l1, l2 = a1[:n], a2[:n]
    keep = (l1 >= 0) & (l1 < n_a1) & (l2 >= 0) & (l2 < n_a2)
    scores_m = np.zeros((n_a1, n_a2))
    np.add.at(scores_m, (l1[keep].astype(int), l2[keep].astype(int)), 1)

    # scipy's Hungarian solver handles rectangular matrices; maximize=True picks the
    # assignment with the largest total agreement.
    rows, cols = linear_sum_assignment(scores_m, maximize=True)
    res = [((str(i), "a" + str(j)), int(scores_m[i, j])) for i, j in zip(rows, cols)]
    return res, sum(weight for _, weight in res), scores_m

In [None]:
# Fraction of test samples whose cluster is matched to their true digit.
match_res, sums, scores = get_matches(cl.labels_, y_test.numpy())
print(sums / y_test.shape[0])

In [None]:
# For each digit among the first `r` test samples (the samples `yy` was computed on):
# the smallest, over samples of that digit, mean cosine similarity of a sample's gates to
# those of all samples of the same digit (self-similarity included).
for label in range(10):
    # Use `r` rather than a hard-coded 50 so this stays aligned with the r x r `yy`.
    label_idx = torch.as_tensor(np.flatnonzero(y_test[:r].numpy() == label))
    if label_idx.numel() == 0:
        continue
    # Select rows and then columns -> (k, k) block of the similarity matrix.
    same_label_sim = yy[label_idx][:, label_idx]
    print(label, '\t', (same_label_sim.sum(0).min() / label_idx.numel()).item())

In [None]:
# Pairwise gate similarity among the samples of one digit within the first `r` test samples.
label = 9
# `ind` stays an np.where tuple: later cells index with it and use len(ind[0]).
ind = np.where(y_test[:r].numpy() == label)
label_idx = torch.as_tensor(ind[0])
# Index rows and columns separately so the result is always (k, k). The previous
# `yy[ind][:, ind].squeeze()` collapsed a single-sample selection to a 0-d tensor.
sim_on_label = yy[label_idx][:, label_idx]
sns.heatmap(sim_on_label, annot=True)

In [None]:
# Total similarity of each selected sample to all samples of the same digit.
torch.sum(sim_on_label, dim=0)

In [None]:
# Open/closed gate pattern of just the samples selected by `ind`.
rel_vals = vals[ind[0]]

In [None]:
# Rows: selected samples; columns: gates.
sns.heatmap(data=rel_vals)

In [None]:
# Model check on the samples selected by `ind`, with unit-scale Gaussian noise added to
# the inputs. Columns: prediction, noisy input, clean input, |prediction - clean|,
# |noisy - clean|.
sample_size = len(ind[0])
noise_scale = 1
aaa = X_test[ind]
aaa_noise = torch.randn_like(aaa) * noise_scale + aaa
res = model.predict(aaa_noise)

# squeeze=False keeps `axs` 2-D even when only one sample is selected.
fig, axs = plt.subplots(sample_size, 5, squeeze=False)
fig.set_size_inches(20, 4 * sample_size)

for i in range(sample_size):
    axs[i, 0].imshow(res[i].squeeze())
    # Select channel 0 explicitly: `.squeeze()[i]` would also drop the batch axis
    # when a single sample is selected.
    axs[i, 1].imshow(aaa_noise[i, 0])
    axs[i, 2].imshow(aaa[i, 0])
    axs[i, 3].imshow(np.abs(res[i] - aaa[i].numpy())[0])
    axs[i, 4].imshow(torch.abs(aaa_noise[i] - aaa[i])[0])

In [None]:
# Number of open gates for each selected sample.
vals[ind[0]].sum(axis=1)

In [None]:
# (#gates open in exactly one selected sample, #gates open in the first selected sample)
(rel_vals.sum(axis=0) == 1).sum(), rel_vals[0].sum()

In [None]:
# Gates that are open in the first selected sample and in no other selected sample.
indices_of_interest = np.nonzero((rel_vals.sum(axis=0) == 1) & rel_vals[0])
indices_of_interest

In [None]:
# Further nines from the test set: skip the first five (presumably those already among the
# first `r` samples inspected above) and take at most the next 195.
second_batch = X_test[np.where(y_test == 9)][5:200]
print(len(second_batch))

# `.to(device)` instead of `.cuda()` so this also runs without a GPU.
second_batch_res = model._model.get_gates("prob", second_batch.to(device))
second_batch_vals = second_batch_res > 0

In [None]:
# Show up to 10 of those nines whose open gates include at least 5 of `indices_of_interest`.
count = 0
for i, gates_i in enumerate(second_batch_vals):
    inter_score = len(np.intersect1d(np.where(gates_i), indices_of_interest))
    if inter_score < 5:
        continue
    count += 1
    print(i, inter_score)
    plt.figure()
    plt.imshow(second_batch[i].squeeze())
    if count >= 10:
        break

In [None]:
# Show up to 5 of those nines that share none of the `indices_of_interest` gates.
count = 0
for i, gates_i in enumerate(second_batch_vals):
    inter_score = len(np.intersect1d(np.where(gates_i), indices_of_interest))
    if inter_score != 0:
        continue
    count += 1
    print(i, inter_score)
    plt.figure()
    plt.imshow(second_batch[i].squeeze())
    if count >= 5:
        break


In [None]:
# Contrast set: test samples that are NOT nines (skipping the first five, at most 195 kept).
second_batch_diff = X_test[np.where(y_test != 9)][5:200]
print(len(second_batch_diff))

# `.to(device)` instead of `.cuda()` so this also runs without a GPU.
second_batch_diff_res = model._model.get_gates("prob", second_batch_diff.to(device))
# Threshold the gate probabilities, as for `second_batch_vals`. The previous code
# thresholded the raw input images (`second_batch_diff > 0`) and ignored the gates.
second_batch_diff_vals = second_batch_diff_res > 0

In [None]:
# Show up to 5 non-nines whose open gates include at least 2 of `indices_of_interest`.
count = 0
for i in range(len(second_batch_diff)):
    inter_score = len(np.intersect1d(np.where(second_batch_diff_vals[i]), indices_of_interest))
    if inter_score < 2:
        continue
    count += 1
    print(i, inter_score)
    plt.figure()
    plt.imshow(second_batch_diff[i].squeeze())
    if count >= 5:
        break

## Grid search over `lam` and `lam_sim`

In [None]:
from itertools import product

# Grid search over the regularisation weights `lam` and `lam_sim`. For each pair, train a
# fresh model for 15 epochs and plot the gate cosine-similarity matrix of the first `r`
# test samples.
# NOTE: this cell rebinds `model`, `r`, `test_res` and `yy`; cells below it operate on the
# model from the LAST grid point.
lams = [0.2, 0.05, 0.01, 0.005, 0.001]
lam_sims = [10, 2, 0.5, 0.1, 0.02, 0.005]
for lam, lam_sim in product(lams, lam_sims):
    print('lam ', lam, ' lam_sim', lam_sim)
    print('------------------------------------')
    feature_selection = True
    model = STG(
        task_type='encoding_unet', input_dim=X_train.shape[1], output_dim=20,
        hidden_dims=32, activation='none', optimizer='SGD', learning_rate=0.01,
        batch_size=128, feature_selection=feature_selection, sigma=1, lam=lam,
        random_state=1, device=device,
        extra_args={'gating_net_hidden_dims': [200, 200], 'noise_sigma': 1, 'lam_sim': lam_sim},
    )
    model.fit(X_train, None, nr_epochs=15, valid_X=X_test, valid_y=y_test,
              print_interval=1, is_tensor_input=True)
    r = 50
    # `.to(device)` instead of `.cuda()` so this also runs without a GPU.
    test_res = model._model.get_gates("prob", X_test[:r].to(device))
    plt.figure()
    yy = similarity_loss(torch.from_numpy(test_res))
    plt.imshow(yy)
    plt.title('lam ' + str(lam) + ' lam_sim ' + str(lam_sim))

In [None]:
# NOTE(review): this cell was left unfinished. It contained `params = [` followed by an
# empty `for i in range(10):`, which is a SyntaxError that aborts "Restart & Run All".
# It is kept only as a placeholder: fill in the intended parameter list or delete the cell.


In [None]:
# Gate probabilities for the whole test set, laid out as one square map per sample.
# NOTE(review): earlier cells treat the gates as 64-dimensional (8x8), while the previous
# hard-coded reshape(-1, 28, 28) assumed one gate per unpadded MNIST pixel. The side length
# is now inferred from the gate count instead — confirm which layout is intended.
# `.to(device)` instead of `.cuda()` so this also runs without a GPU.
gates_flat = model._model.get_gates('prob', X_test.float().to(device))
gates_flat = gates_flat.reshape(len(gates_flat), -1)
side = int(round(np.sqrt(gates_flat.shape[1])))
assert side * side == gates_flat.shape[1], f'{gates_flat.shape[1]} gates per sample do not form a square map'
prob = gates_flat.reshape(-1, side, side)
prob_means = prob.mean(0)
plt.imshow(prob_means)
# Gates whose mean probability over the test set exceeds 0.99.
prob_fixed = (prob_means > 0.99).astype(float)
plt.figure()
plt.imshow(prob_fixed);

In [None]:
# Gate maps of the first 10 test samples minus the almost-always-open gates.
for i in range(10):
    # `prob` was computed on X_test, so the matching label is y_test[i]. The previous
    # code printed y_train[i], which belongs to a different image.
    print(y_test[i].item())
    plt.figure()
    plt.imshow(prob[i] - prob_fixed)
    plt.title(f'test sample {i} (digit {y_test[i].item()})')

In [None]:
# Mean gate map per digit, minus the almost-always-open gates.
# `prob` is a numpy array, so select its rows with a numpy mask rather than a torch tensor.
test_labels = y_test.numpy()
for i in range(10):
    plt.figure()
    plt.imshow(prob[test_labels == i].mean(0) - prob_fixed)
    plt.title(f'digit {i}: mean gate prob minus always-open gates')

## Per-digit examples and gate-probability distributions

In [None]:
# One row of five panels per digit:
#   1. histogram of the per-position mean gate probability (values above the threshold only),
#   2. histogram of the number of gates above the threshold per sample,
#   3. mean gate-probability map,
#   4. an example input of that digit,
#   5. that example's gate map.
GATE_THRESHOLD = 0.001
test_labels = y_test.numpy()  # numpy mask for indexing the numpy `prob`
for i in range(10):
    f, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(25, 4))
    # A stray `5` after this call used to make the whole cell a SyntaxError.
    f.suptitle(f'Data distribution for digit {i}', fontsize=16)
    filtered_prob = prob[test_labels == i]
    # Per-sample number of gates above the threshold.
    counts = (filtered_prob > GATE_THRESHOLD).sum(axis=(1, 2))
    num_prob = filtered_prob.mean(0)
    im = ax3.imshow(num_prob, interpolation='none')
    f.colorbar(im, ax=ax3)
    ax3.title.set_text('Mean prob')
    active_mean_prob = num_prob[num_prob > GATE_THRESHOLD]
    ax1.hist(active_mean_prob.reshape(-1))
    ax1.title.set_text(f'Mean probability > {GATE_THRESHOLD} hist')
    ax2.hist(counts)
    ax2.title.set_text(f'Number of point > {GATE_THRESHOLD} hist')
    # squeeze() instead of reshape(28, 28): the inputs were padded to 32x32 when loaded.
    ax4.imshow(X_test[y_test == i][0].squeeze())
    ax4.title.set_text('Sample digit')
    # `prob` is already (n, side, side), so no reshape is needed here.
    ax5.imshow(filtered_prob[0])
    ax5.title.set_text('Sample gates')

## Average gate probability

In [None]:
# Gate probability averaged over the whole test set.
plt.imshow(prob.mean(axis=0))

## Some experiments: first-layer activations of the gating network

In [None]:
# Output of the first layer of the gating network's MLP for the first 124 test samples,
# as a numpy array. `.to(device)` instead of `.cuda()` so this also runs without a GPU.
aa = (
    model._model.FeatureSelector.net.mlp[0](X_test[:124].float().to(device))
    .detach()
    .cpu()
    .numpy()
)

In [None]:
# Activations of the zeros ("similar") and an equally sized set of other digits ("different").
similar = aa[y_test[:124] == 0]
different = aa[y_test[:124] != 0][:similar.shape[0]]
similar.shape[0], different.shape[0]

In [None]:
from scipy.spatial import distance_matrix

# Pairwise L1 (p=1) distances: zeros vs zeros, and zeros vs other digits.
d1 = distance_matrix(similar, similar, p=1)
d2 = distance_matrix(similar, different, p=1)

In [None]:
# d1 (n x n) includes n zero self-distances on its diagonal, which pulls its mean down by
# (n - 1) / n. The same factor is applied to d2's mean so the two are comparable.
np.mean(d1), np.mean(d2) * (d2.shape[0] - 1) / d2.shape[0]