In [1]:
import os
import json
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from plotly.offline import init_notebook_mode, iplot
from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
# Put plotly into offline notebook mode so `iplot` figures can be rendered in
# the notebook. NOTE(review): connected=True presumably loads plotly.js from a
# CDN instead of embedding it in the notebook file — confirm against plotly docs.
init_notebook_mode(connected=True)

from gamixnn import GAMIxNN

In [2]:
def metric_wrappper(metric, scaler):
    """Bind ``scaler`` into ``metric`` and return a two-argument callable.

    ``metric`` must accept ``(label, pred, scaler=...)``. The returned function
    takes ``(label, pred)`` and forwards the captured ``scaler`` as a keyword
    argument. (The triple-"p" spelling of this function's name is kept because
    existing callers use it.)
    """
    def _bound(label, pred):
        # The captured scaler is passed by keyword, matching metric's signature.
        return metric(label, pred, scaler=scaler)

    return _bound


def rmse(label, pred, scaler):
    """Root-mean-squared error in the original (unscaled) target units.

    ``label`` and ``pred`` are flattened to a single column and passed through
    ``scaler.inverse_transform`` before the error is computed. The result is
    therefore in the units ``scaler`` was fitted on.

    Parameters
    ----------
    label, pred : array_like
        Scaled targets and predictions. Any shape is accepted as long as both
        hold the same number of elements (e.g. ``(n,)`` vs ``(n, 1)``).
    scaler : object
        Fitted transformer exposing ``inverse_transform`` on ``(n, 1)`` arrays
        (in this notebook, the target ``MinMaxScaler`` from ``meta_info``).

    Returns
    -------
    float
        ``sqrt(mean((pred - label) ** 2))`` after inverse scaling.

    Raises
    ------
    ValueError
        If ``label`` and ``pred`` contain different numbers of elements.
    """
    # np.asarray lets lists / other array-likes through; calling .reshape
    # directly would fail on anything that is not already an ndarray.
    label_col = np.asarray(label).reshape(-1, 1)
    pred_col = np.asarray(pred).reshape(-1, 1)
    # Without this check a length-1 input silently broadcasts against the
    # other array and produces a meaningless score.
    if label_col.shape[0] != pred_col.shape[0]:
        raise ValueError(
            "label and pred must have the same number of elements, got %d and %d"
            % (label_col.shape[0], pred_col.shape[0]))
    pred_col = scaler.inverse_transform(pred_col)
    label_col = scaler.inverse_transform(label_col)
    return np.sqrt(np.mean((pred_col - label_col) ** 2))


def data_generator1(datanum, testnum=10000, noise_sigma=1, rand_seed=0):
    """Simulate a 10-feature regression dataset and split it into train/test.

    All ten features are drawn from U(0, 1). Only X1-X6 enter the response::

        y = 8 (x1 - 0.5)^2 + 0.1 exp(-8 x2 + 4) + 3 sin(2 pi x3 x4)
            + cliff(x5, x6) + noise_sigma * N(0, 1)

    X3/X4 and X5/X6 therefore act through pairwise terms, and X7-X10 do not
    affect y. Every feature column and y are then min-max scaled to [0, 1] on
    the full (train + test) data before splitting. The fitted scalers are
    stored in ``meta_info[...]["scaler"]``.

    Parameters
    ----------
    datanum : int
        Number of training rows.
    testnum : int, default 10000
        Number of test rows (passed to ``train_test_split`` as ``test_size``).
    noise_sigma : float, default 1
        Multiplier on the standard-normal noise added to y.
    rand_seed : int, default 0
        Seed applied via ``np.random.seed`` (this re-seeds NumPy's global RNG)
        and used as ``random_state`` for the train/test split.

    Returns
    -------
    tuple
        ``(train_x, test_x, train_y, test_y, task_type, meta_info, metric)``.
        ``metric(label, pred)`` is ``rmse`` bound to the target scaler.
    """
    n_total = datanum + testnum
    np.random.seed(rand_seed)
    # One uniform draw per column, in column order, so the RNG stream is
    # consumed exactly as a column-by-column fill would consume it.
    x = np.hstack([np.random.uniform(0, 1, [n_total, 1]) for _ in range(10)])
    x1, x2, x3, x4, x5, x6 = (x[:, [k]] for k in range(6))

    def cliff(u, v):
        # Rescale u from [0, 1] to [-20, 20] and v from [0, 1] to [-10, 5].
        a = (2 * u - 1) * 20
        b = (2 * v - 1) * 7.5 - 2.5
        radial = -0.5 * a ** 2 / 100
        ridge = -0.5 * (b + 0.03 * a ** 2 - 3) ** 2
        return 10 * np.exp(radial + ridge)

    signal = (8 * (x1 - 0.5) ** 2
              + 0.1 * np.exp(-8 * x2 + 4)
              + 3 * np.sin(2 * np.pi * x3 * x4)
              + cliff(x5, x6))
    # The noise is drawn only after all feature columns, preserving RNG order.
    noise = noise_sigma * np.random.normal(0, 1, [n_total, 1])
    y = signal.reshape([-1, 1]) + noise

    task_type = "Regression"
    meta_info = {"X" + str(k): {"type": "continuous"} for k in range(1, 11)}
    meta_info["Y"] = {"type": "target"}

    # Column position in x follows meta_info insertion order (X1..X10, then Y).
    for col, (_, info) in enumerate(meta_info.items()):
        kind = info['type']
        if kind == "target":
            sy = MinMaxScaler((0, 1))
            y = sy.fit_transform(y)
            info["scaler"] = sy
        elif kind == "categorical":
            enc = OrdinalEncoder()
            enc.fit(x[:, [col]])
            x[:, [col]] = enc.transform(x[:, [col]])
            info["values"] = enc.categories_[0].tolist()
        else:
            sx = MinMaxScaler((0, 1))
            x[:, [col]] = sx.fit_transform(x[:, [col]])
            info["scaler"] = sx

    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=testnum, random_state=rand_seed)
    return train_x, test_x, train_y, test_y, task_type, meta_info, metric_wrappper(rmse, sy)

In [3]:
# Output directory handed to `model.global_explain` in the next cell
# (presumably where it saves its explanation plots).
simu_dir = "./results/test/"
# exist_ok avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(simu_dir, exist_ok=True)

# 10,000 training rows + 10,000 test rows, unit noise, fixed seed.
# get_metric(label, pred) is RMSE in the original (unscaled) target units.
train_x, test_x, train_y, test_y, task_type, meta_info, get_metric = data_generator1(10000, 10000, noise_sigma=1, rand_seed=0)

In [None]:
# Fit GAMIxNN on the simulated training data.
# NOTE(review): model.tr_x / model.val_x / model.tr_y / model.val_y are
# presumably the internal train/validation split made inside `fit`
# (val_ratio=0.2) — confirm against the GAMIxNN implementation.
model = GAMIxNN(
    meta_info=meta_info,
    interact_num=10,
    interact_arch=[40, 20],
    subnet_arch=[10, 6],
    task_type=task_type,
    activation_func=tf.tanh,
    grid_size=101,
    batch_size=1000,
    lr_bp=0.001,
    beta_threshold=0.05,
    gamma_threshold=0.1,
    init_training_epochs=2000,
    interact_training_epochs=2000,
    tuning_epochs=100,
    verbose=True,
    val_ratio=0.2,
    early_stop_thres=100,
)
model.fit(train_x, train_y)

# Predictions on the internal train / validation splits and the held-out test set.
pred_train = model.predict(model.tr_x)
pred_val = model.predict(model.val_x)
pred_test = model.predict(test_x)

# Global explanation, written under simu_dir with this file prefix (presumably plots).
model.global_explain(simu_dir, "demo_gaminet_simu1")

# RMSE in original target units, rounded to 5 decimals: [train, validation, test].
gaminet_stat = np.hstack([
    np.round(get_metric(truth, pred), 5)
    for truth, pred in ((model.tr_y, pred_train),
                        (model.val_y, pred_val),
                        (test_y, pred_test))
])
print(gaminet_stat)

Main Effects Training.
Main effects training epoch: 1, train loss: 0.15108, val loss: 0.15592
Main effects training epoch: 2, train loss: 0.14088, val loss: 0.14685
Main effects training epoch: 3, train loss: 0.13277, val loss: 0.13953
Main effects training epoch: 4, train loss: 0.12640, val loss: 0.13364
Main effects training epoch: 5, train loss: 0.12099, val loss: 0.12838
Main effects training epoch: 6, train loss: 0.11604, val loss: 0.12332
Main effects training epoch: 7, train loss: 0.11118, val loss: 0.11820
Main effects training epoch: 8, train loss: 0.10648, val loss: 0.11319
Main effects training epoch: 9, train loss: 0.10189, val loss: 0.10835
Main effects training epoch: 10, train loss: 0.09749, val loss: 0.10377
Main effects training epoch: 11, train loss: 0.09327, val loss: 0.09944
Main effects training epoch: 12, train loss: 0.08922, val loss: 0.09524
Main effects training epoch: 13, train loss: 0.08521, val loss: 0.09110
Main effects training epoch: 14, train loss: 0.081

Main effects training epoch: 117, train loss: 0.01066, val loss: 0.01142
Main effects training epoch: 118, train loss: 0.01065, val loss: 0.01139
Main effects training epoch: 119, train loss: 0.01065, val loss: 0.01138
Main effects training epoch: 120, train loss: 0.01064, val loss: 0.01139
Main effects training epoch: 121, train loss: 0.01064, val loss: 0.01140
Main effects training epoch: 122, train loss: 0.01064, val loss: 0.01136
Main effects training epoch: 123, train loss: 0.01063, val loss: 0.01138
Main effects training epoch: 124, train loss: 0.01063, val loss: 0.01139
Main effects training epoch: 125, train loss: 0.01063, val loss: 0.01137
Main effects training epoch: 126, train loss: 0.01064, val loss: 0.01136
Main effects training epoch: 127, train loss: 0.01062, val loss: 0.01137
Main effects training epoch: 128, train loss: 0.01061, val loss: 0.01136
Main effects training epoch: 129, train loss: 0.01061, val loss: 0.01136
Main effects training epoch: 130, train loss: 0.010

Main effects training epoch: 232, train loss: 0.01043, val loss: 0.01115
Main effects training epoch: 233, train loss: 0.01043, val loss: 0.01119
Main effects training epoch: 234, train loss: 0.01043, val loss: 0.01118
Main effects training epoch: 235, train loss: 0.01042, val loss: 0.01116
Main effects training epoch: 236, train loss: 0.01042, val loss: 0.01121
Main effects training epoch: 237, train loss: 0.01042, val loss: 0.01117
Main effects training epoch: 238, train loss: 0.01042, val loss: 0.01116
Main effects training epoch: 239, train loss: 0.01042, val loss: 0.01117
Main effects training epoch: 240, train loss: 0.01042, val loss: 0.01114
Main effects training epoch: 241, train loss: 0.01042, val loss: 0.01117
Main effects training epoch: 242, train loss: 0.01042, val loss: 0.01118
Main effects training epoch: 243, train loss: 0.01041, val loss: 0.01115
Main effects training epoch: 244, train loss: 0.01041, val loss: 0.01117
Main effects training epoch: 245, train loss: 0.010

Main effects training epoch: 349, train loss: 0.00993, val loss: 0.01074
Main effects training epoch: 350, train loss: 0.00992, val loss: 0.01069
Main effects training epoch: 351, train loss: 0.00992, val loss: 0.01072
Main effects training epoch: 352, train loss: 0.00991, val loss: 0.01069
Main effects training epoch: 353, train loss: 0.00990, val loss: 0.01072
Main effects training epoch: 354, train loss: 0.00989, val loss: 0.01069
Main effects training epoch: 355, train loss: 0.00989, val loss: 0.01071
Main effects training epoch: 356, train loss: 0.00989, val loss: 0.01065
Main effects training epoch: 357, train loss: 0.00988, val loss: 0.01072
Main effects training epoch: 358, train loss: 0.00987, val loss: 0.01064
Main effects training epoch: 359, train loss: 0.00986, val loss: 0.01064
Main effects training epoch: 360, train loss: 0.00986, val loss: 0.01067
Main effects training epoch: 361, train loss: 0.00985, val loss: 0.01064
Main effects training epoch: 362, train loss: 0.009

Main effects training epoch: 462, train loss: 0.00971, val loss: 0.01056
Main effects training epoch: 463, train loss: 0.00972, val loss: 0.01056
Main effects training epoch: 464, train loss: 0.00971, val loss: 0.01057
Main effects training epoch: 465, train loss: 0.00971, val loss: 0.01053
Main effects training epoch: 466, train loss: 0.00971, val loss: 0.01053
Main effects training epoch: 467, train loss: 0.00972, val loss: 0.01054
Main effects training epoch: 468, train loss: 0.00971, val loss: 0.01052
Main effects training epoch: 469, train loss: 0.00971, val loss: 0.01059
Main effects training epoch: 470, train loss: 0.00971, val loss: 0.01054
Main effects training epoch: 471, train loss: 0.00971, val loss: 0.01055
Main effects training epoch: 472, train loss: 0.00971, val loss: 0.01052
Main effects training epoch: 473, train loss: 0.00971, val loss: 0.01054
Main effects training epoch: 474, train loss: 0.00971, val loss: 0.01050
Main effects training epoch: 475, train loss: 0.009

Main effects tunning epoch: 62, train loss: 0.00970, val loss: 0.01051
Main effects tunning epoch: 63, train loss: 0.00970, val loss: 0.01051
Main effects tunning epoch: 64, train loss: 0.00970, val loss: 0.01050
Main effects tunning epoch: 65, train loss: 0.00970, val loss: 0.01051
Main effects tunning epoch: 66, train loss: 0.00971, val loss: 0.01057
Main effects tunning epoch: 67, train loss: 0.00970, val loss: 0.01049
Main effects tunning epoch: 68, train loss: 0.00970, val loss: 0.01051
Main effects tunning epoch: 69, train loss: 0.00970, val loss: 0.01054
Main effects tunning epoch: 70, train loss: 0.00969, val loss: 0.01052
Main effects tunning epoch: 71, train loss: 0.00970, val loss: 0.01053
Main effects tunning epoch: 72, train loss: 0.00970, val loss: 0.01054
Main effects tunning epoch: 73, train loss: 0.00970, val loss: 0.01050
Main effects tunning epoch: 74, train loss: 0.00970, val loss: 0.01051
Main effects tunning epoch: 75, train loss: 0.00970, val loss: 0.01051
Main e

Interaction training epoch: 80, train loss: 0.00313, val loss: 0.00332
Interaction training epoch: 81, train loss: 0.00310, val loss: 0.00331
Interaction training epoch: 82, train loss: 0.00311, val loss: 0.00324
Interaction training epoch: 83, train loss: 0.00308, val loss: 0.00327
Interaction training epoch: 84, train loss: 0.00309, val loss: 0.00323
Interaction training epoch: 85, train loss: 0.00307, val loss: 0.00322
Interaction training epoch: 86, train loss: 0.00305, val loss: 0.00320
Interaction training epoch: 87, train loss: 0.00308, val loss: 0.00323
Interaction training epoch: 88, train loss: 0.00306, val loss: 0.00326
Interaction training epoch: 89, train loss: 0.00303, val loss: 0.00324
Interaction training epoch: 90, train loss: 0.00304, val loss: 0.00326
Interaction training epoch: 91, train loss: 0.00298, val loss: 0.00315
Interaction training epoch: 92, train loss: 0.00299, val loss: 0.00321
Interaction training epoch: 93, train loss: 0.00298, val loss: 0.00313
Intera

Interaction training epoch: 196, train loss: 0.00230, val loss: 0.00248
Interaction training epoch: 197, train loss: 0.00228, val loss: 0.00241
Interaction training epoch: 198, train loss: 0.00227, val loss: 0.00242
Interaction training epoch: 199, train loss: 0.00226, val loss: 0.00242
Interaction training epoch: 200, train loss: 0.00227, val loss: 0.00243
Interaction training epoch: 201, train loss: 0.00227, val loss: 0.00241
Interaction training epoch: 202, train loss: 0.00227, val loss: 0.00243
Interaction training epoch: 203, train loss: 0.00230, val loss: 0.00239
Interaction training epoch: 204, train loss: 0.00231, val loss: 0.00245
Interaction training epoch: 205, train loss: 0.00226, val loss: 0.00241
Interaction training epoch: 206, train loss: 0.00226, val loss: 0.00243
Interaction training epoch: 207, train loss: 0.00225, val loss: 0.00241
Interaction training epoch: 208, train loss: 0.00228, val loss: 0.00247
Interaction training epoch: 209, train loss: 0.00227, val loss: 

Interaction training epoch: 310, train loss: 0.00220, val loss: 0.00231
Interaction training epoch: 311, train loss: 0.00216, val loss: 0.00227
Interaction training epoch: 312, train loss: 0.00216, val loss: 0.00229
Interaction training epoch: 313, train loss: 0.00215, val loss: 0.00231
Interaction training epoch: 314, train loss: 0.00217, val loss: 0.00229
Interaction training epoch: 315, train loss: 0.00214, val loss: 0.00227
Interaction training epoch: 316, train loss: 0.00214, val loss: 0.00226
Interaction training epoch: 317, train loss: 0.00216, val loss: 0.00226
Interaction training epoch: 318, train loss: 0.00214, val loss: 0.00228
Interaction training epoch: 319, train loss: 0.00215, val loss: 0.00230
Interaction training epoch: 320, train loss: 0.00215, val loss: 0.00228
Interaction training epoch: 321, train loss: 0.00215, val loss: 0.00226
Interaction training epoch: 322, train loss: 0.00216, val loss: 0.00231
Interaction training epoch: 323, train loss: 0.00218, val loss: 

Interaction training epoch: 424, train loss: 0.00212, val loss: 0.00225
Interaction training epoch: 425, train loss: 0.00211, val loss: 0.00223
Interaction training epoch: 426, train loss: 0.00212, val loss: 0.00227
Interaction training epoch: 427, train loss: 0.00212, val loss: 0.00222
Interaction training epoch: 428, train loss: 0.00211, val loss: 0.00223
Interaction training epoch: 429, train loss: 0.00211, val loss: 0.00222
Interaction training epoch: 430, train loss: 0.00211, val loss: 0.00224
Interaction training epoch: 431, train loss: 0.00211, val loss: 0.00224
Interaction training epoch: 432, train loss: 0.00210, val loss: 0.00221
Interaction training epoch: 433, train loss: 0.00211, val loss: 0.00223
Interaction training epoch: 434, train loss: 0.00215, val loss: 0.00228
Interaction training epoch: 435, train loss: 0.00212, val loss: 0.00225
Interaction training epoch: 436, train loss: 0.00210, val loss: 0.00221
Interaction training epoch: 437, train loss: 0.00211, val loss: 