In [1]:
%load_ext autoreload

import torch
import numpy as np

### This notebook shows how to use the Kernel- and Neural-FGEL estimators directly on a given instrumental variable (IV) regression problem.

## Generate some data

In [2]:
def generate_data(n_sample, rng=None):
    """Sample a toy non-linear instrumental-variable regression dataset.

    The data-generating process is
        z ~ U(-3, 3)              (instrument)
        t = z + e + gamma         (treatment)
        y = |t| + e + delta       (outcome)
    with e ~ N(0, 1) and gamma, delta ~ N(0, 0.1^2). The noise ``e`` enters
    both ``t`` and ``y``, so ``t`` is endogenous and regressing y on t directly
    is biased. ``z`` affects ``y`` only through ``t``, which is what makes it
    usable as an instrument.

    Args:
        n_sample: Number of samples to draw.
        rng: Optional int seed or ``np.random.Generator`` for reproducible
            sampling. If None (default), NumPy's global random state is used,
            exactly as in the original implementation.

    Returns:
        dict with keys 't', 'y', 'z', each an array of shape (n_sample, 1).
    """
    # Keep the legacy global RNG as the default so np.random.seed(...) still controls sampling.
    sampler = np.random if rng is None else np.random.default_rng(rng)

    # Draw order (e, gamma, delta, z) is unchanged so seeded runs reproduce the original data.
    e = sampler.normal(loc=0, scale=1.0, size=[n_sample, 1])
    gamma = sampler.normal(loc=0, scale=0.1, size=[n_sample, 1])
    delta = sampler.normal(loc=0, scale=0.1, size=[n_sample, 1])
    z = sampler.uniform(low=-3, high=3, size=[n_sample, 1])

    # z already has shape (n_sample, 1); no reshape needed.
    t = z + e + gamma
    y = np.abs(t) + e + delta
    return {'t': t, 'y': y, 'z': z}

# generate_data samples from NumPy's global RNG; seed it so that
# Restart & Run All reproduces the same train/validation split and results.
np.random.seed(0)
train_data = generate_data(n_sample=100)
validation_data = generate_data(n_sample=100)

## Define a general non-linear IV model in PyTorch

In [3]:
# Seed PyTorch so the network's random weight initialization is reproducible across runs.
torch.manual_seed(0)

# Small MLP mapping the 1-d treatment t to a 1-d prediction of y.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 20),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(20, 3),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(3, 1),
)

## Train the model using Kernel/Neural-FGEL

In [4]:
from fgel.iv_estimator import fgel_iv_estimation

# Fit the IV model with FGEL. Data arguments come first, followed by the estimator options.
trained_model, other_stats = fgel_iv_estimation(
    model=model,                        # any PyTorch model works here
    train_data=train_data,
    validation_data=validation_data,
    version='kernel',                   # estimator variant: 'kernel' or 'neural'
    divergence=None,                    # None -> tuned as a hyperparameter; otherwise one of ['chi2', 'kl', 'log']
    reg_param=None,                     # None -> tuned as a hyperparameter
    val_loss_func=None,                 # optional custom validation loss, signature (model, validation_data)
    verbose=True,
)

Running:  divergence=chi2, reg_param=0.1
Running:  divergence=chi2, reg_param=0.01
Running:  divergence=chi2, reg_param=0.001
Running:  divergence=chi2, reg_param=0.0001
Running:  divergence=chi2, reg_param=1e-06
Running:  divergence=chi2, reg_param=1e-08
Best config:  {'divergence': 'chi2', 'reg_param': 0.1}
