In [5]:
import os
import sys

# Make the project root (the parent of the kernel's working directory)
# importable so that `src.*` resolves.
# NOTE(review): this assumes the kernel's cwd is the notebook folder — confirm
# when running the notebook from elsewhere.
root_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
if root_dir not in sys.path:  # keep the cell idempotent on re-run
    sys.path.append(root_dir)

import numpy as np
from omegaconf import DictConfig, OmegaConf

from src.data.synthetic_dataset import MarkovianHeteroDynamicDataset
# Configuration for the synthetic dataset. It is wrapped in an OmegaConf
# DictConfig and handed to MarkovianHeteroDynamicDataset.
# NOTE(review): per-key semantics live in src/data/synthetic_dataset.py. From the
# cells below, n_periods x n_treatments is the treatment-plan shape, and
# sequence_length - n_periods + 1 (= 4) matches the number of effect columns.
params = dict(
    n_units=10,
    n_periods=3,
    sequence_length=6,
    n_x=10,
    n_treatments=2,
    s_x=40,
    s_t=1,
    sigma_x=0.5,
    sigma_t=0.5,
    sigma_y=1.0,
    gamma=0.2,
    autoreg=0.25,
    state_effect=0.25,
    hetero_strength=0.5,
    hetero_inds=[8, 9],
    conf_str=5,
    train_val_split=0.8,
    seed=2024,
)
dictconfig = OmegaConf.create(params)
dataset = MarkovianHeteroDynamicDataset(dictconfig)

In [6]:
# Draw the observational sample. Then build two treatment plans of shape
# (n_periods, n_treatments): an all-ones intervention and an all-zeros baseline.
Y_obs, T_obs, X_obs = dataset.generate_observational_data()
T_intv = np.ones((dataset.n_periods, params['n_treatments']))
T_base = np.zeros_like(T_intv)

In [7]:
# Treatment effect of T_intv vs. T_base, as computed by the dataset class itself.
# It serves below as the reference for the effect rebuilt from dynamic effects (`te`).
computed_effect = dataset.compute_treatment_effect(T_intv, T_base)

In [8]:
# Per-unit dynamic effects for every unit. Slicing with n_units instead of a
# hard-coded 10 keeps `te` row-aligned with `computed_effect` if the config changes.
# NOTE(review): assumes axis 0 of X_obs indexes units — confirm.
de = dataset.compute_individual_dynamic_effects(X_obs[:params['n_units']])
def compute_TE_from_de(dynamic_effects, T_intv, T_base):
    """Rebuild the treatment effect of ``T_intv`` vs ``T_base`` from dynamic effects.

    Each dynamic-effect coefficient is weighted by the matching entry of the
    treatment difference ``T_intv - T_base``. The weighted values are then
    summed over the last two axes (period, treatment).

    Parameters
    ----------
    dynamic_effects : array_like, shape (N, SL - m + 1, m, n_t)
        Per-unit dynamic effects, e.g. as returned by
        ``dataset.compute_individual_dynamic_effects``.
    T_intv, T_base : array_like
        Intervention and baseline treatment plans. Their difference must be a
        2-D array of shape (m, n_t). Either operand may be broadcastable to
        that shape, e.g. a single row shared by all periods.

    Returns
    -------
    numpy.ndarray, shape (N, SL - m + 1)
        Weighted sum of ``dynamic_effects`` over its last two axes.

    Raises
    ------
    ValueError
        If ``T_intv - T_base`` is not 2-D, or does not broadcast against the
        trailing two axes of ``dynamic_effects``.
    """
    dynamic_effects = np.asarray(dynamic_effects)
    # Use the shape of the *difference*, not of T_intv. T_intv may be a list
    # (no ``.shape``) or a single row broadcast against a full T_base.
    T_diff = np.asarray(T_intv) - np.asarray(T_base)
    if T_diff.ndim != 2:
        raise ValueError(
            f"T_intv - T_base must be 2-D (m, n_t); got shape {T_diff.shape}"
        )
    # Two leading singleton axes align T_diff with the trailing (m, n_t) axes of
    # dynamic_effects, so each (unit, window) slice is weighted elementwise.
    T_diff = T_diff.reshape((1, 1) + T_diff.shape)
    return (dynamic_effects * T_diff).sum(axis=(-2, -1))
# Effect of T_intv vs. T_base, rebuilt from the per-unit dynamic effects.
te = compute_TE_from_de(dynamic_effects=de, T_intv=T_intv, T_base=T_base)
te

array([[1.19514041, 1.08064844, 1.54551722, 1.13485549],
       [1.57175071, 0.96373379, 1.28561419, 1.78461289],
       [1.53418676, 1.39184915, 0.98335063, 1.04426429],
       [1.31051117, 1.18474239, 0.78236685, 1.09016039],
       [1.52361413, 1.1400899 , 1.08616167, 2.02469077],
       [0.88906735, 1.32301276, 1.82041255, 1.48606231],
       [1.88490196, 1.04810296, 1.04737194, 1.26500802],
       [1.66046518, 1.36262972, 1.20551524, 1.6540534 ],
       [1.41547781, 1.3504944 , 1.03554458, 0.58633424],
       [1.44950104, 1.39179079, 2.0852993 , 1.13456156]])

In [9]:
# The effect rebuilt from per-unit dynamic effects (`te`) must reproduce the
# dataset's own computation. Fail loudly rather than relying on eyeballing.
np.testing.assert_allclose(te, computed_effect)
computed_effect

array([[1.19514041, 1.08064844, 1.54551722, 1.13485549],
       [1.57175071, 0.96373379, 1.28561419, 1.78461289],
       [1.53418676, 1.39184915, 0.98335063, 1.04426429],
       [1.31051117, 1.18474239, 0.78236685, 1.09016039],
       [1.52361413, 1.1400899 , 1.08616167, 2.02469077],
       [0.88906735, 1.32301276, 1.82041255, 1.48606231],
       [1.88490196, 1.04810296, 1.04737194, 1.26500802],
       [1.66046518, 1.36262972, 1.20551524, 1.6540534 ],
       [1.41547781, 1.3504944 , 1.03554458, 0.58633424],
       [1.44950104, 1.39179079, 2.0852993 , 1.13456156]])

In [14]:
# Shape check: the treatment difference reshaped for broadcasting against the
# 4-D dynamic-effects array, as done inside compute_TE_from_de.
np.reshape(T_intv - T_base, (1, 1, params['n_periods'], params['n_treatments'])).shape

(1, 1, 3, 2)

In [10]:
# Sum of `de` at index 0 of axis 1 (the SL - m + 1 axis), over all units,
# periods and treatments.
# NOTE(review): with T_intv - T_base all ones, this should equal te[:, 0].sum().
# That is ~14.43 from the saved `te` output, yet the saved output here is 6.35.
# The stored outputs come from inconsistent kernel state; Restart & Run All.
de[:, 0].sum()

6.350037038353347