In [1]:
import sys
sys.path.append('../')
sys.path.append('../skimfa')  # TODO: replace once code is a python package

import math
from skimfa.kernels import BlockPairwiseSKIMFABasisKernel
from feature_maps import LinearFeatureMap
from fit import *

In [2]:
# Set seed for reproducibility
torch.manual_seed(32312)

<torch._C.Generator at 0x12bed3cf0>

# Generate synthetic data:

- p = 500 covariates
- N = 500 training datapoints 
- First 5 covariates have main / interaction effects w/ response; remaining 495 covariates have no influence on the response
- Linear main and interaction effects
- The signal variance / total variance (i.e., R^2) equals .8

In [3]:
### Generate Covariates ###
p = 500
N_train = 500
N_test = 100
N_valid = 100

X_train = torch.normal(mean=0., std=1., size=(N_train, p))
X_test = torch.normal(mean=0., std=1., size=(N_test, p))
X_valid = torch.normal(mean=0., std=1., size=(N_valid, p))

In [4]:
### Generate Main and Interaction Effects ###
K = 5 # First 5 covariates influence response
main_effects = dict()
interaction_effects = dict()

# Generate main effects
for cov_ix in range(5):
    main_effects[cov_ix] = torch.normal(mean=1., std=1., size=(1, )).item()

# Generate 4 pairwise interaction effects between covariates  
interaction_pairs = [(0, 1), (1, 2), (2, 3), (3, 4)]
for cov_ix1, cov_ix2 in interaction_pairs:
    interaction_effects[(cov_ix1, cov_ix2)] = torch.normal(mean=1., std=1., size=(1, )).item()    

In [5]:
### Generate Response ###
def generate_noiseless_response(X, main_effects, interaction_effects):
    Y_signal = torch.zeros(X.shape[0])
    for cov_ix, effect in main_effects.items():
        Y_signal += effect * X[:, cov_ix]
    
    for cov_pair, effect in interaction_effects.items():
        cov_ix1, cov_ix2 = cov_pair
        Y_signal += effect * X[:, cov_ix1] * X[:, cov_ix2]
    
    return Y_signal

Y_train_noiseless = generate_noiseless_response(X_train, main_effects, interaction_effects)
Y_test_noiseless = generate_noiseless_response(X_test, main_effects, interaction_effects)
Y_valid_noiseless = generate_noiseless_response(X_valid, main_effects, interaction_effects)

# Add noise so that R^2 = .8
R2 = .8
approx_signal_var = Y_train_noiseless.var().item()
noise_var = (1 - R2) * approx_signal_var / R2

Y_train = Y_train_noiseless + math.sqrt(noise_var)*torch.normal(mean=0., std=1., size=(N_train, ))
Y_test = Y_test_noiseless + math.sqrt(noise_var)*torch.normal(mean=0., std=1., size=(N_test, ))
Y_valid = Y_valid_noiseless + math.sqrt(noise_var)*torch.normal(mean=0., std=1., size=(N_valid, ))

# Fit SKIM-FA Model
- Includes all main and pairwise interaction effects (linear)
- Performs variable selection
- Estimtates effects (ANOVA decomposition)

In [6]:
# Step 1: Make feature map 
# for linear interaction case, this just standardizes the covariates to be 0 mean and unit variances
# the means variances and variances are estimates from training data, and then stored for future uses
# e.g., to standardize new test data

covariate_dims = list(range(p))
covariate_types = ['continuous'] * p # irrelevant for now (in the future the selected feature map will depend on the covariate type)
linfeatmap = LinearFeatureMap(covariate_dims, covariate_types)
linfeatmap.make_feature_map(X_train) 

# Step 2: Make kernel configuration
kernel_config = dict()
kernel_config['uncorrected'] = True
kernel_config['rescale'] = 1.
kernel_config['feat_map'] = linfeatmap
kernel_config['cache'] = True
kernel_config['Q'] = 2 # include up to pairwise interaction effects

kernel_config['pair_indcs'] = torch.arange(p)[:(p-10)] # Have all main and pairwise interaction effects for first p - 10 covariates
kernel_config['main_indcs'] = torch.arange(p)[(p-10):] # Only have main effects for last 10 covariates


# Step 3: Make optimization configuration
optimization_config = dict()
optimization_config['T'] = 2000 # 2000 total gradient steps
optimization_config['M'] = 100 # size of cross-validation random sample
optimization_config['param_save_freq'] = 100 # save model weights every 100 iterations
optimization_config['valid_report_freq'] = 100 # how often to report MSE on validation set 
optimization_config['lr'] = .1
optimization_config['train_noise'] = False
optimization_config['noise_var_init'] = Y_train.var().detach().item()
optimization_config['truncScheduler'] = adaptive_cutoff_scheduler

In [7]:
# Fit SKIM-FA
train_valid_data = dict()
train_valid_data['X_train'] = X_train
train_valid_data['Y_train'] = Y_train
train_valid_data['X_valid'] = X_valid
train_valid_data['Y_valid'] = Y_valid

# VERY STRANGE error on my computer where I need to invert a matrix to not get a segmentation 11 fault error...
import numpy as np
X_weird = np.random.normal(size=(500, 100))
np.linalg.inv(X_weird.T.dot(X_weird))

array([[ 2.47917909e-03, -9.89167493e-05, -2.36180769e-04, ...,
         1.78580883e-04, -5.60477038e-05,  4.67199751e-05],
       [-9.89167493e-05,  2.50588907e-03, -1.69032628e-05, ...,
         2.23976045e-04,  2.01503322e-05, -8.96354784e-05],
       [-2.36180769e-04, -1.69032628e-05,  2.58347688e-03, ...,
        -1.56860674e-04, -2.09240085e-05, -1.21383942e-04],
       ...,
       [ 1.78580883e-04,  2.23976045e-04, -1.56860674e-04, ...,
         2.37904044e-03,  2.71590032e-05, -6.73756756e-05],
       [-5.60477038e-05,  2.01503322e-05, -2.09240085e-05, ...,
         2.71590032e-05,  2.20055923e-03, -1.58906512e-07],
       [ 4.67199751e-05, -8.96354784e-05, -1.21383942e-04, ...,
        -6.73756756e-05, -1.58906512e-07,  2.24975443e-03]])

In [8]:
skimfit = SKIMFA()
skimfit.fit(train_valid_data, BlockPairwiseSKIMFABasisKernel, kernel_config, optimization_config)

  0%|          | 5/2000 [00:00<01:39, 20.02it/s]

Mean-Squared Prediction Error on Validation (Iteration=0): 10.061
Number Covariates Selected=500


  5%|▌         | 104/2000 [00:04<01:14, 25.53it/s]

Mean-Squared Prediction Error on Validation (Iteration=100): 10.229
Number Covariates Selected=500


 10%|█         | 206/2000 [00:08<01:10, 25.31it/s]

Mean-Squared Prediction Error on Validation (Iteration=200): 10.196
Number Covariates Selected=500


 15%|█▌        | 305/2000 [00:12<01:09, 24.46it/s]

Mean-Squared Prediction Error on Validation (Iteration=300): 10.091
Number Covariates Selected=500


 20%|██        | 404/2000 [00:16<01:04, 24.58it/s]

Mean-Squared Prediction Error on Validation (Iteration=400): 9.311
Number Covariates Selected=500


 25%|██▌       | 503/2000 [00:20<01:02, 23.77it/s]

Mean-Squared Prediction Error on Validation (Iteration=500): 6.0
Number Covariates Selected=500


 30%|███       | 605/2000 [00:24<00:59, 23.43it/s]

Mean-Squared Prediction Error on Validation (Iteration=600): 4.093
Number Covariates Selected=138


 35%|███▌      | 704/2000 [00:28<00:55, 23.38it/s]

Mean-Squared Prediction Error on Validation (Iteration=700): 2.479
Number Covariates Selected=25


 40%|████      | 803/2000 [00:32<00:51, 23.22it/s]

Mean-Squared Prediction Error on Validation (Iteration=800): 2.337
Number Covariates Selected=10


 45%|████▌     | 905/2000 [00:37<00:47, 22.86it/s]

Mean-Squared Prediction Error on Validation (Iteration=900): 2.26
Number Covariates Selected=10


 50%|█████     | 1004/2000 [00:41<00:42, 23.17it/s]

Mean-Squared Prediction Error on Validation (Iteration=1000): 2.263
Number Covariates Selected=10


 55%|█████▌    | 1103/2000 [00:45<00:39, 22.76it/s]

Mean-Squared Prediction Error on Validation (Iteration=1100): 2.199
Number Covariates Selected=10


 60%|██████    | 1205/2000 [00:50<00:34, 22.76it/s]

Mean-Squared Prediction Error on Validation (Iteration=1200): 2.38
Number Covariates Selected=9


 65%|██████▌   | 1304/2000 [00:54<00:30, 22.57it/s]

Mean-Squared Prediction Error on Validation (Iteration=1300): 2.284
Number Covariates Selected=9


 70%|███████   | 1403/2000 [00:58<00:26, 22.66it/s]

Mean-Squared Prediction Error on Validation (Iteration=1400): 2.183
Number Covariates Selected=9


 75%|███████▌  | 1505/2000 [01:03<00:22, 21.67it/s]

Mean-Squared Prediction Error on Validation (Iteration=1500): 2.253
Number Covariates Selected=8


 80%|████████  | 1604/2000 [01:07<00:17, 22.40it/s]

Mean-Squared Prediction Error on Validation (Iteration=1600): 2.208
Number Covariates Selected=7


 85%|████████▌ | 1703/2000 [01:12<00:13, 22.26it/s]

Mean-Squared Prediction Error on Validation (Iteration=1700): 2.388
Number Covariates Selected=7


 90%|█████████ | 1805/2000 [01:16<00:09, 21.29it/s]

Mean-Squared Prediction Error on Validation (Iteration=1800): 2.141
Number Covariates Selected=7


 95%|█████████▌| 1904/2000 [01:21<00:04, 21.79it/s]

Mean-Squared Prediction Error on Validation (Iteration=1900): 2.153
Number Covariates Selected=7


100%|██████████| 2000/2000 [01:25<00:00, 23.36it/s]

Mean-Squared Prediction Error on Validation (Iteration=1999): 2.254
Number Covariates Selected=7





# See how well SKIM-FA does in terms of variable selection, estimation, and prediction

In [11]:
# Variable selection
selected_covs = set([cov_ix.item() for cov_ix in skimfit.get_selected_covariates()])
correct_covs = set(range(K))

print(f'Correct Selected: {selected_covs & correct_covs}')
print(f'Correct Not Selected: {correct_covs - selected_covs}')
print(f'Wrong Selected: {selected_covs - correct_covs}')

Correct Selected: {0, 1, 2, 3, 4}
Correct Not Selected: set()
Wrong Selected: {362, 171}


In [12]:
# Prediction
Y_test_pred = skimfit.predict(X_test)
print(f'Mean-Squared Prediction Error on Test: {round(torch.mean((Y_test - Y_test_pred) ** 2).item(), 2)}')
print(f'True Noise Variance: {round(noise_var, 2)}')

Mean-Squared Prediction Error on Test: 2.65
True Noise Variance: 2.43


In [15]:
# Estimation - look at first main effect as an example

# Estimate vs. Truth: Main Effects
for i in range(5):
    print(f'Main Effect {i} (Estimate, Truth): ({round(get_linear_iteraction_effect(skimfit, [i]), 2)}, {round(main_effects[i], 2)})')

print('\n')
    
# Estimate vs. Truth: Pairwise Effects
for cov_ix1, cov_ix2 in interaction_pairs:
    print(f'Interaction Effect {(cov_ix1, cov_ix2)} (Estimate, Truth): ({round(get_linear_iteraction_effect(skimfit, [cov_ix1, cov_ix2]), 2)}, {round(interaction_effects[(cov_ix1, cov_ix2)], 2)})')

Main Effect 0 (Estimate, Truth): (0.77, 0.68)
Main Effect 1 (Estimate, Truth): (1.98, 1.91)
Main Effect 2 (Estimate, Truth): (0.22, 0.09)
Main Effect 3 (Estimate, Truth): (-1.1, -1.09)
Main Effect 4 (Estimate, Truth): (0.56, 0.53)


Interaction Effect (0, 1) (Estimate, Truth): (1.45, 1.37)
Interaction Effect (1, 2) (Estimate, Truth): (0.71, 0.73)
Interaction Effect (2, 3) (Estimate, Truth): (1.05, 1.01)
Interaction Effect (3, 4) (Estimate, Truth): (1.13, 1.19)
