In [1]:
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

from gaminet import GAMINet
from gaminet.utils import feature_importance
from gaminet.utils import local_visualize
from gaminet.utils import global_visualize_density
from gaminet.utils import global_visualize_wo_density
from gaminet.utils import plot_trajectory
from gaminet.utils import plot_regularization

## Generate simulation data

In [2]:
def metric_wrapper(metric, scaler):
    """Bind ``scaler`` to ``metric`` and return a two-argument scoring function.

    Parameters
    ----------
    metric : callable
        Invoked as ``metric(label, pred, scaler=scaler)``.
    scaler : object
        Forwarded unchanged to ``metric`` on every call.

    Returns
    -------
    callable
        ``wrapper(label, pred)``, which returns whatever ``metric`` returns.
    """
    bound_kwargs = {"scaler": scaler}

    def wrapper(label, pred):
        return metric(label, pred, **bound_kwargs)

    return wrapper

def rmse(label, pred, scaler):
    """Root-mean-squared error computed after undoing the target scaling.

    Parameters
    ----------
    label : array-like
        Scaled ground-truth values; flattened into a single column.
    pred : array-like
        Scaled predictions; flattened into a single column. Expected to hold
        as many elements as ``label``.
    scaler : object
        Must expose ``inverse_transform`` accepting an ``(n, 1)`` array.

    Returns
    -------
    The value of ``np.sqrt(np.mean((pred - label) ** 2))`` on the
    inverse-transformed columns.
    """
    # np.asarray lets lists/tuples through; calling ``.reshape`` directly on
    # them would raise AttributeError.
    pred = scaler.inverse_transform(np.asarray(pred).reshape([-1, 1]))
    label = scaler.inverse_transform(np.asarray(label).reshape([-1, 1]))
    return np.sqrt(np.mean((pred - label) ** 2))

def data_generator1(datanum, random_state=0):
    """Simulate a 10-feature regression dataset and split it into train/test sets.

    Ten columns are drawn from ``uniform(0, 1)``. The response is built from
    columns 1-6 only (a quadratic in X1, an exponential in X2, a sine of
    X3*X4 and the ``cliff`` surface of X5 and X6) plus standard-normal noise;
    columns 7-10 do not enter the response.

    Parameters
    ----------
    datanum : int
        Number of samples to generate.
    random_state : int, default 0
        Seeds the global NumPy RNG and is also passed to ``train_test_split``.

    Returns
    -------
    tuple
        ``(train_x, test_x, train_y, test_y, task_type, meta_info, get_metric)``
        where ``meta_info`` maps each column name to its type and fitted
        scaler, and ``get_metric`` is ``rmse`` bound to the target scaler.
    """
    np.random.seed(random_state)

    # One draw per column, in column order, so the random stream is consumed
    # column by column.
    x = np.hstack([np.random.uniform(0, 1, [datanum, 1]) for _ in range(10)])
    x1, x2, x3, x4, x5, x6 = (x[:, [j]] for j in range(6))

    def cliff(u, v):
        # u is stretched to [-20, 20] and v to [-10, 5]
        u = (2 * u - 1) * 20
        v = (2 * v - 1) * 7.5 - 2.5
        along_u = -0.5 * u ** 2 / 100
        along_v = -0.5 * (v + 0.03 * u ** 2 - 3) ** 2
        return 10 * np.exp(along_u + along_v)

    quadratic = 8 * (x1 - 0.5) ** 2
    exponential = 0.1 * np.exp(-8 * x2 + 4)
    sine_pair = 3 * np.sin(2 * np.pi * x3 * x4)
    signal = (quadratic + exponential + sine_pair + cliff(x5, x6)).reshape([-1, 1])
    noise = 1 * np.random.normal(0, 1, [datanum, 1])
    y = signal + noise

    task_type = "Regression"
    meta_info = {"X%d" % (j + 1): {"type": "continuous"} for j in range(10)}
    meta_info["Y"] = {"type": "target"}

    # Dict order matches column order, so the enumerate index is the column
    # of ``x`` for every feature entry; the final entry is the target.
    for col, spec in enumerate(meta_info.values()):
        if spec["type"] != "target":
            # Feature scalers are fitted on the endpoints 0 and 1 rather than
            # on the sampled column.
            sx = MinMaxScaler((0, 1))
            sx.fit([[0], [1]])
            x[:, [col]] = sx.transform(x[:, [col]])
            spec["scaler"] = sx
            continue
        sy = MinMaxScaler((0, 1))
        y = sy.fit_transform(y)
        spec["scaler"] = sy

    split = train_test_split(x, y, test_size=0.2, random_state=random_state)
    return (*split, task_type, meta_info, metric_wrapper(rmse, sy))

# 5000 simulated samples with seed 0; `get_metric` is rmse bound to the target scaler.
(train_x, test_x, train_y, test_y,
 task_type, meta_info, get_metric) = data_generator1(datanum=5000, random_state=0)

## Train GAMI-Net 

In [None]:
# Configure and fit GAMI-Net; the batch size is capped at 20% of the training rows.
model = GAMINet(
    meta_info=meta_info,
    interact_num=20,
    interact_arch=[20, 10],
    subnet_arch=[20, 10],
    task_type=task_type,
    activation_func=tf.tanh,
    main_grid_size=41,
    interact_grid_size=41,
    batch_size=min(500, int(0.2 * train_x.shape[0])),
    lr_bp=0.001,
    main_effect_epochs=2000,
    interaction_epochs=2000,
    tuning_epochs=50,
    loss_threshold=0.01,
    verbose=True,
    val_ratio=0.2,
    early_stop_thres=100,
)
model.fit(train_x, train_y)

# Rows selected by model.tr_idx / model.val_idx
# (presumably the model's internal train/validation split -- confirm against gaminet).
tr_x, tr_y = train_x[model.tr_idx, :], train_y[model.tr_idx, :]
val_x, val_y = train_x[model.val_idx, :], train_y[model.val_idx, :]

pred_train = model.predict(tr_x)
pred_val = model.predict(val_x)
pred_test = model.predict(test_x)

# Metric on train / validation / test, each rounded to 5 decimals.
gaminet_stat = np.hstack([
    np.round(get_metric(truth, guess), 5)
    for truth, guess in ((tr_y, pred_train), (val_y, pred_val), (test_y, pred_test))
])
print(gaminet_stat)

Main effects training epoch: 1, train loss: 0.13117, val loss: 0.13101
Main effects training epoch: 2, train loss: 0.12540, val loss: 0.12523
Main effects training epoch: 3, train loss: 0.12052, val loss: 0.12037
Main effects training epoch: 4, train loss: 0.11592, val loss: 0.11590
Main effects training epoch: 5, train loss: 0.11178, val loss: 0.11174
Main effects training epoch: 6, train loss: 0.10781, val loss: 0.10770
Main effects training epoch: 7, train loss: 0.10394, val loss: 0.10392
Main effects training epoch: 8, train loss: 0.10021, val loss: 0.10031
Main effects training epoch: 9, train loss: 0.09677, val loss: 0.09690
Main effects training epoch: 10, train loss: 0.09353, val loss: 0.09369
Main effects training epoch: 11, train loss: 0.09005, val loss: 0.09022
Main effects training epoch: 12, train loss: 0.08719, val loss: 0.08743
Main effects training epoch: 13, train loss: 0.08436, val loss: 0.08457
Main effects training epoch: 14, train loss: 0.08103, val loss: 0.08112
M

Main effects training epoch: 118, train loss: 0.01003, val loss: 0.00996
Main effects training epoch: 119, train loss: 0.01001, val loss: 0.00991
Main effects training epoch: 120, train loss: 0.00997, val loss: 0.00985
Main effects training epoch: 121, train loss: 0.01007, val loss: 0.00998
Main effects training epoch: 122, train loss: 0.00998, val loss: 0.00987
Main effects training epoch: 123, train loss: 0.00994, val loss: 0.00983
Main effects training epoch: 124, train loss: 0.01002, val loss: 0.00993
Main effects training epoch: 125, train loss: 0.00996, val loss: 0.00981
Main effects training epoch: 126, train loss: 0.00995, val loss: 0.00985
Main effects training epoch: 127, train loss: 0.00996, val loss: 0.00987
Main effects training epoch: 128, train loss: 0.00994, val loss: 0.00983
Main effects training epoch: 129, train loss: 0.00993, val loss: 0.00980
Main effects training epoch: 130, train loss: 0.00994, val loss: 0.00982
Main effects training epoch: 131, train loss: 0.009

Main effects training epoch: 232, train loss: 0.00973, val loss: 0.00963
Main effects training epoch: 233, train loss: 0.00972, val loss: 0.00970
Main effects training epoch: 234, train loss: 0.00971, val loss: 0.00971
Main effects training epoch: 235, train loss: 0.00971, val loss: 0.00963
Main effects training epoch: 236, train loss: 0.00970, val loss: 0.00964
Main effects training epoch: 237, train loss: 0.00971, val loss: 0.00966
Main effects training epoch: 238, train loss: 0.00970, val loss: 0.00968
Main effects training epoch: 239, train loss: 0.00969, val loss: 0.00963
Main effects training epoch: 240, train loss: 0.00969, val loss: 0.00963
Main effects training epoch: 241, train loss: 0.00968, val loss: 0.00967
Main effects training epoch: 242, train loss: 0.00968, val loss: 0.00962
Main effects training epoch: 243, train loss: 0.00966, val loss: 0.00961
Main effects training epoch: 244, train loss: 0.00966, val loss: 0.00965
Main effects training epoch: 245, train loss: 0.009

## Visualization

In [None]:
# Output folder for every figure saved below. exist_ok=True keeps the cell
# re-runnable and avoids the check-then-create race of os.path.exists().
simu_dir = "./results/"
os.makedirs(simu_dir, exist_ok=True)

In [None]:
# Build the global explanation here: on a fresh kernel this cell runs before
# the global-visualization cell, so `data_dict` would otherwise be undefined.
data_dict = model.global_explain(save_dict=False)
feature_importance(data_dict, save_png=True, folder=simu_dir, name='s1_feature')

In [None]:
# Global explanation of the fitted model, plotted with
# global_visualize_wo_density and written to `simu_dir` as 's1_global'.
data_dict = model.global_explain(save_dict=False)
global_visualize_wo_density(
    data_dict,
    save_png=True,
    folder=simu_dir,
    name='s1_global',
)

In [None]:
# Local explanation for the first training sample (row 0, kept 2-D via [[0]]).
data_dict_local = model.local_explain(
    train_x[[0]],
    train_y[[0]],
    save_dict=False,
)
local_visualize(
    data_dict_local,
    save_png=True,
    folder=simu_dir,
    name='s1_local',
)

In [None]:
# Training logs -> trajectory and regularization plots.
# Everything is written to `simu_dir`, the results folder used by the other
# cells; the name `folder` was never defined in this notebook and raised
# NameError.
data_dict_logs = model.summary_logs(save_dict=False, folder=simu_dir, name="s1_logs")
plot_trajectory(data_dict_logs, folder=simu_dir, name="s1_traj_plot", log_scale=True, save_png=True, save_eps=False)
plot_regularization(data_dict_logs, folder=simu_dir, name="s1_regu_plot", log_scale=True, save_png=True, save_eps=False)