In [None]:

import os
import sys

# Directory containing the `stg` python package.
# Set the STG_PATH environment variable to point at your own checkout; when it
# is unset (or empty) the original author's local path is used as the default.
stg_path = os.environ.get('STG_PATH') or '/home/eli/Eli/Projects/stg/python'

# Append only once so that re-running this cell does not grow sys.path.
if stg_path not in sys.path:
    sys.path.append(stg_path)

In [None]:
from stg import STG
import numpy as np
import scipy.stats # for creating a simple dataset 
import matplotlib.pyplot as plt 
from sklearn.model_selection import train_test_split
from dataset import create_twomoon_dataset
import torch


In [None]:
# Sample count and feature count for the two-moon generator call that is
# currently commented out below.
n_size, p_size = 1000, 20
# X_data, y_data=create_twomoon_dataset(n_size,p_size)
# print(X_data.shape)
# print(y_data.shape)

In [None]:
# Create a simple synthetic dataset: 11 standard-normal features, 2000 samples.
# A fixed-seed generator makes the data identical after every kernel restart
# without touching NumPy's global random state.
x = np.random.RandomState(0).normal(0, 1, (11, 2000))  # rows = features, columns = samples

# "Switch" score: the sign of feature 10 selects which features drive the value.
#   x[10] <  0  ->  exp(x0 * x1)
#   x[10] >= 0  ->  exp(x2 + x3 + x4 + x5 - 4)
y_data = np.exp(x[0]*x[1]) * (x[10] < 0) + np.exp(x[2]+x[3]+x[4]+x[5]-4) * (x[10] >= 0)

# Binarise the score. Whenever the score is >= 1 the ratio y / (1 - y) is
# negative or a division by zero, so np.log yields NaN/inf and the comparison
# evaluates to False for those samples. The floating-point warnings this raises
# are silenced here; the resulting labels are unchanged.
# NOTE(review): this is not a standard logistic transform — confirm that the
# labelling rule is the intended one.
with np.errstate(divide='ignore', invalid='ignore'):
    y_data = 1 / (1 + np.log(y_data / (1 - y_data))) > 0.5

# Samples along axis 0, features along axis 1.
X_data = x.transpose()

print(X_data.shape)
print(y_data.shape)

In [None]:
# Scatter the binary target over two pairs of input features, one pair per panel.
f, ax = plt.subplots(1, 2, figsize=(10, 5))

# (x-axis column, y-axis column) of X_data shown in each panel.
feature_pairs = [(0, 1), (2, 3)]

for panel, (col_x, col_y) in zip(ax, feature_pairs):
    # Pass the colormap by name: plt.cm.get_cmap is deprecated in recent
    # Matplotlib releases, while a name string is accepted by all versions.
    panel.scatter(x=X_data[:, col_x], y=X_data[:, col_y], s=150,
                  c=y_data.reshape(-1), alpha=0.4, cmap='RdYlBu')
    # Axis labels are 1-based ($x_1$, $x_2$, ...), columns are 0-based.
    panel.set_xlabel('$x_%d$' % (col_x + 1), fontsize=20)
    panel.set_ylabel('$x_%d$' % (col_y + 1), fontsize=20)
    panel.set_title('Target y')
    # Set tick size on each axes explicitly; plt.tick_params would only
    # affect the current (last created) axes.
    panel.tick_params(labelsize=10)

In [None]:
# Split the data with fixed seeds so the train/validation/test partitions are
# reproducible across runs.
# First split: 30% of all rows form the training pool, the remaining 70% are
# held out as the test set.
X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, train_size=0.3, random_state=0)
# Second split: 80% of the training pool is used for fitting, 20% for validation.
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, train_size=0.8, random_state=0)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
args_cuda = torch.cuda.is_available()
device = torch.device("cuda") if args_cuda else torch.device("cpu")

feature_selection = True

# Build the STG classifier. The input width follows the training matrix and
# the batch size equals the number of training rows (full-batch updates).
model = STG(
    task_type='classification',
    input_dim=X_train.shape[1],
    output_dim=2,
    hidden_dims=[60, 20],
    activation='relu',
    optimizer='SGD',
    learning_rate=0.1,
    batch_size=X_train.shape[0],
    feature_selection=feature_selection,
    sigma=0.5,
    lam=0.5,
    random_state=1,
    device=device,
    extra_args={'gating_net_hidden_dims': [200, 200]},
)


In [None]:
# Sanity check: show the dtypes of the training features and training labels.
print(X_train.dtype, y_train.dtype)

In [None]:
# Train for 6000 epochs on the training split, passing the validation split
# for monitoring and a print interval of 1000.
model.fit(
    X_train,
    y_train,
    nr_epochs=6000,
    valid_X=X_valid,
    valid_y=y_valid,
    print_interval=1000,
)

In [None]:
# Display the gates reported by the trained model in 'prob' mode.
# NOTE(review): presumably one gate-open probability per input feature —
# confirm against STG.get_gates.
model.get_gates(mode='prob')

In [None]:
# Display the gates reported by the trained model in 'raw' mode.
# NOTE(review): presumably the underlying gate parameters before they are
# mapped to probabilities — confirm against STG.get_gates.
model.get_gates(mode='raw') 

## Testing the model

In [None]:
# Predict labels for every row of X_data. This is the full dataset, so it
# includes the rows that were used for training and validation.
y_pred=model.predict(X_data)

In [None]:
# Number of class-0 samples in the ground truth vs. in the predictions,
# displayed as a (true_count, predicted_count) tuple.
(y_data==0).sum(), (y_pred==0).sum()

In [None]:
# Peek at the first ten predicted labels.
y_pred[:10]

In [None]:
# First ten ground-truth labels, for comparison with the predictions.
y_data[:10]

In [None]:
# Side-by-side comparison over the first two features: ground-truth labels on
# the left, model predictions on the right.
f, ax = plt.subplots(1, 2, figsize=(10, 5))

# (point colours, panel title) for each panel.
panel_specs = [
    (y_data, 'Target y'),
    (y_pred, 'Classification output '),
]

for panel, (labels, title) in zip(ax, panel_specs):
    # Pass the colormap by name: plt.cm.get_cmap is deprecated in recent
    # Matplotlib releases, while a name string is accepted by all versions.
    panel.scatter(x=X_data[:, 0], y=X_data[:, 1], s=150,
                  c=labels.reshape(-1), alpha=0.4, cmap='RdYlBu')
    panel.set_xlabel('$x_1$', fontsize=20)
    panel.set_ylabel('$x_2$', fontsize=20)
    panel.set_title(title)
    # Set tick size on each axes explicitly; plt.tick_params would only
    # affect the current (last created) axes.
    panel.tick_params(labelsize=10)

## Model saving / loading

In [None]:
# Save the trained model to the relative path 'trained_model.pt'.
model.save_checkpoint('trained_model.pt')

In [None]:
# Build a fresh, untrained model to receive the saved checkpoint.
# Its configuration must mirror the trained `model` (same activation and the
# same gating-network sizes); otherwise the restored network would differ from
# the one that was saved.
# NOTE(review): assumes load_checkpoint restores weights into an identically
# structured network — confirm against the STG implementation.
model_tmp = STG(
    task_type='classification',
    input_dim=X_train.shape[1],
    output_dim=2,
    hidden_dims=[60, 20],
    activation='relu',
    optimizer='SGD',
    learning_rate=0.1,
    batch_size=X_train.shape[0],
    feature_selection=feature_selection,
    sigma=0.5,
    lam=0.5,
    random_state=1,
    device=device,
    extra_args={'gating_net_hidden_dims': [200, 200]},
)

In [None]:
# Load the checkpoint file 'trained_model.pt' into the freshly built model_tmp.
model_tmp.load_checkpoint('trained_model.pt')