In [None]:
from LANAM.models import  NAM, LaNAM
from LANAM.trainer import *
from LANAM.trainer.nam_trainer import train
from LANAM.trainer import test
from LANAM.config.default import * 
from LANAM.data import *
from LANAM.data.base import LANAMDataset, LANAMSyntheticDataset
from LANAM.utils.plotting import * 
from LANAM.utils.hsic import *

import matplotlib.pyplot as plt
import seaborn as sns
import copy 

In [None]:
# Load (or re-load) IPython's autoreload extension; `%autoreload 2` reloads modules
# before executing code, so edits to the LANAM package are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

# Concurvity Regularization
## Preliminary
### NAM 
$$
h(y) = f_1(x_1) + \cdots + f_d(x_d) + \beta_0 
$$
where $h$ is the link function and $\beta_0$ is the global bias.
### Concurvity
1. Why we don't like concurvity: <br>
the fitted model becomes less interpretable, as each feature's contribution to the target is no longer immediately apparent.
2. Target: pairwise uncorrelatedness, $\text{corr}(f_i, f_j) = 0$<br>
where $\text{corr}(\cdot, \cdot)$ is the Pearson correlation coefficient: 
$$ 
r_{xy} = \frac{\sum_i(x_i-\overline x)(y_i - \overline y)}{\sqrt{\sum_i (x_i-\overline x)^2}\sqrt{\sum_i (y_i-\overline y)^2}}
$$

3. Method: concurvity regularization $\frac{1}{p(p-1)/2}\sum_{i=1}^p \sum_{j=i+1}^p \left|\text{corr}\left(f_i(X_i), f_j(X_j)\right)\right|$

4. Evaluation: three different strategies are used for evaluation, 
    - Pairwise correlation: $\text{corr}(f_i, f_j)$
    - Correlation between target and transformed features: $\text{corr}(f_i, y)$
    - Estimated feature importance (sensitivity): spread (mean absolute deviation) of each shape function on the training set. 
        - $\text{FI}_i[f_i(x_i)] = \frac{1}{N}\sum_{j=1}^{N}|f_i(x_{ij}) - \overline{f_i}|$ for transformed features, where $\overline{f_i}$ is the mean value of the $i$-th shape function on the training data.  
        - $\text{FI}_i[x_i] = \frac{1}{N}\sum_{j=1}^{N}|x_{ij} - \overline{x_i}|$ for untransformed features. 
   
5. Performance and concurvity trade-off.

## Questions
1. Feature importance. 

## About toy examples
1. Experimental setup sharing between synthetic examples: 
    - $10000$ samples, dataset split: 7: 2: 1.
    - activation function: GELU
    - three hidden layers, each of which contains $128$ units.
    - concurvity regularization parameter $\lambda \in [10^{-6}, 1]$. (fig. 3(b))
    
2. Questions: 
    - isn't the training dataset too large? => $10000 \rightarrow 1000$, no impact. 
    - different activation functions => 
    - different hidden layers => single layer: requires more training epochs.
    - behaviours: 
        - strongly correlated features are **all** muted. 
            - when strongly correlated features are important? performance-concurvity trade-off. 
        - mostly uncorrelated features remain.
    - training samples are shuffled, so each mini-batch approximates the global distribution and hence the global correlation. 

3. Comments: 
    - sampling on subsets.

### Ex2
#### Ex2.0
$$
\begin{aligned}
&Y = 0 \cdot X_1 + 1 \cdot X_2, \\
&X_1 =  Z, \\
&X_2 = |Z|, \\
&Z \sim \mathcal{N}(0, 1), \quad \text{truncated to } (-1, 1)
\end{aligned}
$$

In [None]:
# Ex2.0 data: default dependent function of the loader (per the Ex2.0 description, X2 = |X1|),
# 1000 samples. The trailing note refers to corr(X1, X2), which is printed in the next cell.
nonlinearly_dependent_data = load_nonlinearly_dependent_2D_examples(num_samples=1000) # uncorrelated features 
# the second and fourth loaders returned by train_dataloaders() are not used here
nd_train_dl, _, nd_val_dl, _ = nonlinearly_dependent_data.train_dataloaders()
nd_test_samples = nonlinearly_dependent_data.get_test_samples()

In [None]:
# Inspect the raw (untransformed) Ex2.0 features before any model is fitted:
# plot the dataset, then report per-feature importance and corr(X1, X2).
nonlinearly_dependent_data.plot_dataset()
nd_raw_features = nonlinearly_dependent_data.features
untransformed_nd_feature_importance = feature_importance(nd_raw_features)
untransformed_nd_feature_correlation = pairwise_correlation(nd_raw_features)
print(f'[nonlinearly dependent dataset]: untransformed feature importance: {untransformed_nd_feature_importance}')
print(f'[nonlinearly dependent dataset]: corr(X1, X2): {untransformed_nd_feature_correlation[0][1]: .6f}')

In [None]:
# Fresh default toy config for the Ex2.0 runs; the training loss is logged every 200 steps/epochs.
cfg = toy_default()
cfg.log_loss_frequency = 200

In [None]:
# Ex2.0 baseline: concurvity regularization switched off, trained with ensemble=True.
cfg.concurvity_regularization = 0
nd_wo_model = train(
    config=cfg,
    train_loader=nd_train_dl,
    val_loader=nd_val_dl,
    test_samples=nd_test_samples,
    ensemble=True,
)

In [None]:
# Inspect the feature contributions learned by the unregularized Ex2.0 model on the test set.
X, y, shape_functions, names = nd_test_samples

# Assumed return order: (prediction mean, per-feature contribution mean,
# prediction variance, per-feature contribution variance). Previously the third value was
# also unpacked into `prediction_mean`, which silently overwrote the mean with the variance.
prediction_mean, feature_contribution_mean, prediction_var, feature_contribution_var = get_prediction(nd_wo_model, nd_test_samples)

# Column k of the contribution matrix is the contribution of feature k.
fig, ax = plt.subplots(figsize=(4, 3))
ax.scatter(feature_contribution_mean[:, 0], feature_contribution_mean[:, 1])
ax.set_xlabel('f1(X1)')
ax.set_ylabel('f2(X2)')
ax.set_title('Learned feature contributions (no concurvity reg.)');

In [None]:
# Ex2.0 with concurvity regularization (lambda = 0.1), trained with ensemble=False.
cfg.concurvity_regularization = 0.1
nd_w_model = train(
    config=cfg,
    train_loader=nd_train_dl,
    val_loader=nd_val_dl,
    test_samples=nd_test_samples,
    ensemble=False,
)

#### Ex2.2
$$
\begin{aligned}
&Y = 0 \cdot X_1 + 1 \cdot X_2, \\
&X_1 =  Z, \\
&X_2 = \sin(3Z), \\
&Z \sim \mathcal{N}(0, 1), \quad \text{truncated to } (-1, 1)
\end{aligned}
$$

In [None]:
# Ex2.2 data: X2 = sin(3 * Z) as implemented here, with normally sampled Z.
# Unlike Ex2.0 (X2 = |Z|), these features are NOT uncorrelated: if Z is truncated to (-1, 1),
# then Z and sin(3Z) always share sign, so corr(X1, X2) > 0. The value is printed at the end of this cell.
data = load_nonlinearly_dependent_2D_examples(num_samples=1000, sampling_type='normal',
                                              dependent_functions=lambda x: torch.sin(3*x))
train_dl, _, val_dl, _ = data.train_dataloaders()
test_samples = data.get_test_samples()

data.plot_dataset()
fig, axs = plt.subplots(figsize=(4,3))
axs.set_title('Relation between untransformed features')
# column 0 (X1) is drawn on the x-axis and column 1 (X2) on the y-axis
axs.set_xlabel('X1')
axs.set_ylabel('X2')
axs.scatter(data.features[:, 0], data.features[:, 1])
untransformed_nd_feature_importance = feature_importance(data.features)
untransformed_nd_feature_correlation = pairwise_correlation(data.features)
print(f'[nonlinearly dependent dataset]: untransformed feature importance: {untransformed_nd_feature_importance}')
print(f'[nonlinearly dependent dataset]: corr(X1, X2): {untransformed_nd_feature_correlation[0][1]: .6f}')

In [None]:
# Fresh default toy config for Ex2.2: 400 epochs, loss logged every 100; printed for the record.
cfg = toy_default()
cfg.num_epochs = 400
cfg.log_loss_frequency = 100
print(cfg)

In [None]:
# Ex2.2 baseline: no concurvity regularization, trained with ensemble=False.
cfg.concurvity_regularization = 0
nd_wo_model = train(
    config=cfg,
    train_loader=train_dl,
    val_loader=val_dl,
    test_samples=test_samples,
    ensemble=False,
)

In [None]:
# Ex2.2 with concurvity regularization (lambda = 0.1).
# Stored as `nd_w_model`, matching the with/without naming used in Ex2.0, so the unregularized
# baseline in `nd_wo_model` is not overwritten.
cfg.concurvity_regularization = 0.1
nd_w_model = train(config=cfg, train_loader=train_dl, val_loader=val_dl, test_samples=test_samples, ensemble=False)

#### Ex2.1
$$
\begin{aligned}
&Y = 0 \cdot X_1 + 1 \cdot X_2, \\
&X_1 =  Z, \\
&X_2 = \sin(2Z), \\
&Z \sim \mathcal{N}(0, 1), \quad \text{truncated to } (-1, 1)
\end{aligned}
$$

In [None]:
# Ex2.1 data: X2 = sin(2 * Z) as implemented here.
# These features are NOT uncorrelated: if Z is truncated to (-1, 1), then Z and sin(2Z) always share
# sign, so corr(X1, X2) > 0. The value is printed at the end of this cell.
# NOTE: this reuses (and overwrites) the Ex2.0 names `nonlinearly_dependent_data`, `nd_*_dl`, `nd_test_samples`.
nonlinearly_dependent_data = load_nonlinearly_dependent_2D_examples(num_samples=1000, dependent_functions=lambda x: torch.sin(2*x))
nd_train_dl, _, nd_val_dl, _ = nonlinearly_dependent_data.train_dataloaders()
nd_test_samples = nonlinearly_dependent_data.get_test_samples()

nonlinearly_dependent_data.plot_dataset()
fig, axs = plt.subplots(figsize=(4,3))
axs.set_title('Relation between untransformed features')
# column 0 (X1) is drawn on the x-axis and column 1 (X2) on the y-axis
axs.set_xlabel('X1')
axs.set_ylabel('X2')
axs.scatter(nonlinearly_dependent_data.features[:, 0], nonlinearly_dependent_data.features[:, 1])
untransformed_nd_feature_importance = feature_importance(nonlinearly_dependent_data.features)
untransformed_nd_feature_correlation = pairwise_correlation(nonlinearly_dependent_data.features)
print(f'[nonlinearly dependent dataset]: untransformed feature importance: {untransformed_nd_feature_importance}')
print(f'[nonlinearly dependent dataset]: corr(X1, X2): {untransformed_nd_feature_correlation[0][1]: .6f}')

In [None]:
# Ex2.1 with concurvity regularization (lambda = 0.06).
# Stored as `nd_w_model`: the model is regularized, so the "wo" (without) name was misleading.
cfg = toy_default()
cfg.concurvity_regularization = 0.06
nd_w_model = train(config=cfg, train_loader=nd_train_dl, val_loader=nd_val_dl, test_samples=nd_test_samples, ensemble=False)

### Ex3: concurvity examples
$$
\begin{aligned}
&X_1, X_2, X_3 \sim U(0,1),\\
&X_4 = X_2^3 + X_3^2 + \sigma_1,\\
&X_5 = X_3^2 + \sigma_1,\\
&X_6 = X_2^2 + X_4^3 + \sigma_1,\\
&X_7 = X_1 \times X_4 + \sigma_1,\\
&Y = 2X_1^2 + X_5^3 + 2\sin X_6 + \sigma_2
\end{aligned}
$$
#### Ex3.1: different hidden sizes

In [None]:
# Ex3 concurvity data; sigma_1 / sigma_2 correspond to the sigma terms in the equations above.
concurvity_data = load_concurvity_data(
    sigma_1=0.05,
    sigma_2=0.5,
    num_samples=1000,
)
con_train_dl, con_train_dl_fnn, con_val_dl, _ = concurvity_data.train_dataloaders()
con_test_samples = concurvity_data.get_test_samples()
concurvity_data.plot_dataset()
# optional (disabled): concurvity_data.plot_scatterplot_matrix()

In [None]:
# LA-NAM on the Ex3 data: Kronecker-factored Hessian restricted to the last layer,
# trained for 500 epochs via `marglik_training` (presumably marginal-likelihood based — see LANAM).
cfg = defaults()
lanam = LaNAM(
    config=cfg,
    name="LA-NAM",
    in_features=concurvity_data.in_features,
    hessian_structure='kron',
    subset_of_weights='last_layer',
)

lanam, margs, losses, perfs = marglik_training(
    lanam,
    con_train_dl,
    con_train_dl_fnn,
    con_val_dl,
    likelihood='regression',
    test_samples=con_test_samples,
    n_epochs=500,
    use_wandb=False,
    optimizer_kwargs={'lr': 1e-2},
)

In [None]:
# Predict with the trained LA-NAM on the Ex3 test set (mean/variance for the output and for each
# feature network, per the f_mu / f_var naming), then plot feature importance and recovered shape functions.
X_test, y_test, fnn_test, _ = con_test_samples
f_mu, f_var, f_mu_fnn, f_var_fnn = lanam.predict(X_test)

importance_fig = plot_feature_importance(lanam, con_test_samples)

recover_fig = plot_recovered_functions(
    X_test, y_test, fnn_test, f_mu_fnn, f_var_fnn.flatten(start_dim=1),
    center=False,
)

In [None]:
# Ex3 baseline: ensemble of 5 NAMs, no concurvity regularization, early-stopping patience 40.
cfg = toy_default()
# cfg.output_regularization = 0.05   (disabled)
cfg.log_loss_frequency = 100
cfg.concurvity_regularization = 0
cfg.num_ensemble = 5
cfg.num_epochs = 400
cfg.early_stopping_patience = 40
con_wo_model = train(
    config=cfg,
    train_loader=con_train_dl,
    val_loader=con_val_dl,
    test_samples=con_test_samples,
    ensemble=True,
)

In [None]:
# Ex3 with concurvity regularization (lambda = 0.5): ensemble of 5, early-stopping patience 40.
# Stored as `con_w_model` so the unregularized baseline in `con_wo_model` is kept for comparison.
cfg = toy_default()
# cfg.output_regularization = 0.05   (disabled)
cfg.log_loss_frequency = 100
cfg.concurvity_regularization = 0.5
cfg.num_ensemble = 5
cfg.num_epochs = 400
cfg.early_stopping_patience = 40
con_w_model = train(config=cfg, train_loader=con_train_dl, val_loader=con_val_dl, test_samples=con_test_samples, ensemble=True)

In [None]:
# Ex3 baseline again (no concurvity regularization), but with early-stopping patience 20 instead of 40.
# Stored under its own name so that it does not silently overwrite `con_wo_model` (patience 40).
cfg = toy_default()
# cfg.output_regularization = 0.05   (disabled)
cfg.log_loss_frequency = 100
cfg.concurvity_regularization = 0
cfg.num_ensemble = 5
cfg.num_epochs = 400
cfg.early_stopping_patience = 20
con_wo_model_patience20 = train(config=cfg, train_loader=con_train_dl, val_loader=con_val_dl, test_samples=con_test_samples, ensemble=True)

### Ex1: multicollinearity 
given linear model 
$$
\begin{aligned}
Y = 1\cdot X_1+0\cdot X_2
\end{aligned}
$$
we generate feature $X_1$ and $X_2$ by sampling from a *uniform* distribution with two different settings: 
- independently sampled;
- fixed to identical samples (perfectly correlated).

Except for the output penalty, all other regularization terms of the vanilla NAM are set to zero. 

### Ex1.1: natural preference of NAM 
https://wandb.ai/xinyu-zhang/NAM_preference_multicolinearity?workspace=user-xinyu-zhang

### Ex1.0
#### build dataset

In [None]:
# Ex1.0: Y = 1 * X1 + 0 * X2.
generate_funcs = [lambda x: x, lambda x: torch.zeros_like(x)]

# Settings shared by both datasets: sigma=0, x_lims=(-1, 1), uniform sampling, 1000 samples.
ex1_data_kwargs = dict(generate_functions=generate_funcs, x_lims=(-1, 1), num_samples=1000,
                       sigma=0, sampling_type='uniform')

# independently sampled (uncorrelated) features
uncorrelated_data = load_synthetic_data(**ex1_data_kwargs)
uc_train_dl, _, uc_val_dl, _ = uncorrelated_data.train_dataloaders()
uc_test_samples = uncorrelated_data.get_test_samples()

# identical (perfectly correlated) features
perfect_correlated_data = load_multicollinearity_data(**ex1_data_kwargs)
pc_train_dl, _, pc_val_dl, _ = perfect_correlated_data.train_dataloaders()
pc_test_samples = perfect_correlated_data.get_test_samples()

#### NOTE: when measuring accuracy and concurvity trade-off...
DON'T use ensembling.

In [None]:
# Ensemble trained on the perfectly correlated Ex1.0 data. The next cell uses it to compare the
# concurvity measured on the ensemble mean with that of each individual member.
# The config is built explicitly here instead of reusing whatever `cfg` an earlier cell left behind.
# These values reproduce the `cfg` state after the last Ex3 cell in a top-to-bottom run.
cfg = toy_default()
cfg.log_loss_frequency = 100
cfg.concurvity_regularization = 0
cfg.num_ensemble = 5
cfg.num_epochs = 400
cfg.early_stopping_patience = 20
model = train(config=cfg, train_loader=pc_train_dl, val_loader=pc_val_dl, ensemble=True)

In [None]:
# Concurvity of the ensemble-mean shape functions vs. each individual member
# (perfectly correlated Ex1.0 data, ensemble `model` from the previous cell).
fig, axs = plt.subplots(2, 2, figsize=(8, 8))
fig.supxlabel('X')
fig.supylabel('f(X)')

X_pc_test = pc_test_samples[0]
# pred / fnn are stacked over ensemble members along dim 0 (reduced with .mean/.var(dim=0) and
# indexed per member below); presumably fnn is (members, samples, features) — TODO confirm.
pred, fnn = get_ensemble_prediction(model, X_pc_test, pc_test_samples[1])
f_mu, f_mu_fnn, f_var, f_var_fnn = pred.mean(dim=0), fnn.mean(dim=0), pred.var(dim=0), fnn.var(dim=0)
r = concurvity_loss(f_mu_fnn)
print(f'measured concurvity with ensembling: {r.item(): .4f}')

# top row: ensemble-mean shape functions
axs[0][0].scatter(X_pc_test[:, 0], f_mu_fnn[:, 0])
axs[0][1].scatter(X_pc_test[:, 1], f_mu_fnn[:, 1])
axs[0][0].set_title('ensemble mean: f1')
axs[0][1].set_title('ensemble mean: f2')

# Bottom row: all members overlaid. Iterate over the members actually returned instead of
# cfg.num_ensemble, which may describe a different config from the one `model` was trained with.
# Separate names keep the ensemble-level f_mu_fnn / r intact after the loop.
individual_r = []
for idx in range(fnn.shape[0]):
    member_fnn = fnn[idx]
    axs[1][0].scatter(X_pc_test[:, 0], member_fnn[:, 0])
    axs[1][1].scatter(X_pc_test[:, 1], member_fnn[:, 1])

    member_r = concurvity_loss(member_fnn)
    print(f'measured concurvity for individual model_{idx}: {member_r.item(): .4f}')
    individual_r.append(member_r.item())
axs[1][0].set_title('individual members: f1')
axs[1][1].set_title('individual members: f2')

print(f'individual concurvity: max {max(individual_r): .4f}, min {min(individual_r): .4f}')

In [None]:
# Raw-feature statistics of the uncorrelated Ex1.0 data. The correlation matrix also includes the
# target (concatenated after the features), so entry [0][1] is still corr(X1, X2).
uncorrelated_data.plot_dataset()
uc_raw_features = uncorrelated_data.features
uc_features_and_target = torch.concatenate([uc_raw_features, uncorrelated_data.targets], dim=1)
untransformed_uc_feature_correlation = pairwise_correlation(uc_features_and_target)
untransformed_uc_feature_importance = feature_importance(uc_raw_features)
print(f'[uncorrelated dataset]: untransformed feature importance: {untransformed_uc_feature_importance}')
print(f'[uncorrelated dataset]: corr(X1, X2): {untransformed_uc_feature_correlation[0][1]: .6f}')

In [None]:
# Raw-feature statistics of the perfectly correlated Ex1.0 data (corr(X1, X2) is expected to be 1).
perfect_correlated_data.plot_dataset()
pc_raw_features = perfect_correlated_data.features
untransformed_pc_feature_importance = feature_importance(pc_raw_features)
untransformed_pc_feature_correlation = pairwise_correlation(pc_raw_features)
print(f'[perfectly correlated dataset]: untransformed feature importance: {untransformed_pc_feature_importance}')
print(f'[perfectly correlated dataset]: corr(X1, X2): {untransformed_pc_feature_correlation[0][1]: .6f}')

In [None]:
cfg = toy_default()  # fresh default config for the uncorrelated / perfectly correlated runs below

#### uncorrelated data, without and with concurvity regularization
**Claim**: Page 13, 'concurvity regularizer R does not automatically affect the predictive performance of a GAM'. 

**Experimental result**: 
1. the correlation of untransformed $X_1$ and $X_2$ ($corr(X_1, X_2)$): within $\pm 0.01$. 
2. <span style='color: red'>with concurvity regularization parameter $\lambda = 1$, Val. RMSE increases from $10^{-5}$ to $10^{-3}$. </span>

In [None]:
# Uncorrelated data, no concurvity regularization (single model).
cfg.concurvity_regularization = 0
uc_wo_model = train(
    config=cfg,
    train_loader=uc_train_dl,
    val_loader=uc_val_dl,
    test_samples=uc_test_samples,
    ensemble=False,
)

In [None]:
# Uncorrelated data with concurvity regularization lambda = 1 (single model).
cfg.concurvity_regularization = 1
uc_w_model = train(
    config=cfg,
    train_loader=uc_train_dl,
    val_loader=uc_val_dl,
    test_samples=uc_test_samples,
    ensemble=False,
)

#### perfectly correlated data, with and without concurvity regularization
number of ensemble members: $40$. 

**Claim**: decrease $\text{corr}(f_1(X_1), f_2(X_2))$.

**Experimental result**:
- overall correlation $\text{corr}(f_1(X_1), f_2(X_2))$ decreases from $1$ to $10^{-3}$. 
- <span style='color:red'>the model fails to recover individual functions.</span> Piecewise correlations are generated whose additive impacts counteract. 
    - identical features, no natural bias?

In [None]:
# Perfectly correlated data, no concurvity regularization, ensemble of 40 members.
cfg.num_ensemble = 40
cfg.concurvity_regularization = 0
pc_wo_model = train(
    config=cfg,
    train_loader=pc_train_dl,
    val_loader=pc_val_dl,
    test_samples=pc_test_samples,
    ensemble=True,
)

In [None]:
# Perfectly correlated data with concurvity regularization lambda = 1
# (relies on cfg.num_ensemble = 40 having been set in the previous cell).
cfg.concurvity_regularization = 1
pc_w_model = train(
    config=cfg,
    train_loader=pc_train_dl,
    val_loader=pc_val_dl,
    test_samples=pc_test_samples,
    ensemble=True,
)

## Tabular data 
### California housing

In [None]:
# California housing data: build train/val and test loaders (the second and fourth train loaders
# and the second test loader are unused here), then report the feature count and dataset size.
california_housing_data = load_sklearn_housing_data()
cal_train_dl, _, cal_val_dl, _ = california_housing_data.train_dataloaders()
cal_test_dl, _ = california_housing_data.test_dataloaders()
print(f'number of features: {california_housing_data.in_features}, dataset size: {len(california_housing_data.features)}')

In [None]:
# NAM config for California housing: 5 ReLU hidden layers of 72 units, 40 epochs, batch size 512,
# loss logged every epoch/step, no concurvity regularization.
chcfg = toy_default()
chcfg.decay_rate=3.73e-3  # original note: "necessary" (reason not recorded)
chcfg.activation_cls = 'relu'
chcfg.hidden_sizes=[72, 72, 72, 72, 72]
chcfg.num_epochs = 40
chcfg.batch_size = 512
chcfg.lr = 9.46e-3
chcfg.log_loss_frequency = 1
chcfg.concurvity_regularization = 0
print(chcfg)

In [None]:
# Train a single (non-ensemble) NAM on California housing; no test_samples are passed.
cal_model = train(
    config=chcfg,
    train_loader=cal_train_dl,
    val_loader=cal_val_dl,
    ensemble=False,
)

In [None]:
# Test-set metric of the trained California-housing model. The square root is taken so the printed
# value is an RMSE, which assumes `test` returns a mean squared error — TODO confirm.
cal_test_loss = test('regression', cal_model[0], cal_test_dl)
test_rmse = torch.sqrt(cal_test_loss)
print(test_rmse)

In [None]:
# Display the first element of the value returned by `train(..., ensemble=False)`;
# presumably the trained model itself — TODO confirm what `train` returns.
cal_model[0]