In [28]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns
sns.set(style='ticks') 

import warnings
warnings.filterwarnings('ignore')

from lazyvi import LazyVI
from data_generating_funcs import *
from networks import *
from utils import *
import numpy as np

from sklearn.metrics import mean_squared_error as mse

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
# generate data
X, y = generate_linear_data(beta=[1.5, 1.2, 1, 0, 0, 0], corr=.75)
n, p = X.shape
dm = FlexDataModule(X,y) # feed into data module

# extract train/test from data module
dm.setup()
X_train = dm.train.dataset.tensors[0]
y_train = dm.train.dataset.tensors[1]

X_test = dm.test.tensors[0]
y_test = dm.test.tensors[1]


# initialize network
full_nn = NN4vi(p, [50], 1)

# train full network
early_stopping = EarlyStopping('val_loss', min_delta=1e-3)
trainer = pl.Trainer(callbacks=[early_stopping], max_epochs=100)
trainer.fit(full_nn, dm)

tensor([[ 1.0278],
        [-1.7377],
        [ 1.8756],
        ...,
        [ 3.6448],
        [ 0.4229],
        [ 3.7112]])

In [26]:
# extract full model MSE
test_loss_full = nn.MSELoss()(y_test, full_nn(X_test))
test_loss_full

tensor(0.0137, grad_fn=<MseLossBackward0>)

In [30]:
# dropout first variable
X0_train = dropout(X_train, 0)
X0_test = dropout(X_test,0)

# initialize lazy object with parameters
lv = LazyVI(full_nn)

# fit LazyVI on modified data
lv.fit(X0_train, y_train)

# calculate mse of reduced dataset using lazyvi
mse(y_test.detach().numpy(), lv.predict(X0_test).detach().numpy())

1.090149

In [31]:
# can use the LazyVI object to extract gradients, look at regularization path, etc
lv.extract_grad(X0_train)
grad = lv.grads


In [32]:
grad

array([[-0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         3.3728093e-01,  0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         2.5300747e-01,  0.0000000e+00,  1.0000000e+00],
       [ 4.7275194e-04,  4.1070670e-01,  4.0224460e-01, ...,
         0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       ...,
       [-0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [ 4.7275194e-04,  3.5979858e-01, -8.4723070e-02, ...,
         0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  0.0000000e+00, -0.0000000e+00, ...,
         0.0000000e+00,  0.0000000e+00,  1.0000000e+00]], dtype=float32)