In [8]:
import numpy as np
import pandas as pd
from skillmodels.config import TEST_DIR
import yaml
from skillmodels.simulate_data import simulate_dataset
from skillmodels.likelihood_function import get_maximization_inputs

# How to simulate a dataset



Below we show how to simulate a dataset for a test model.

## Getting inputs

For more details on this, check out the introductory tutorial.

In [3]:
# Model specification for the test model, read from the skillmodels test directory.
with open(TEST_DIR / "model2.yaml") as model_file:
    model_dict = yaml.load(model_file, Loader=yaml.FullLoader)

# Dataset providing observed factors / control variables, indexed by individual and period.
data = pd.read_stata(TEST_DIR / "model2_simulated_data.dta")
data = data.set_index(["caseid", "period"])

# Parameter vector used for the simulation, with the same MultiIndex layout
# ("category", "period", "name1", "name2") as the CSV's index columns.
params = pd.read_csv(TEST_DIR / "regression_vault" / "one_stage_anchoring.csv")
params = params.set_index(["category", "period", "name1", "name2"])


## Simulated data without policies

In [5]:
# Simulate latent states from the model specification and parameters alone;
# `data` only supplies the observed factors and control variables.
initial_data = simulate_dataset(
    data=data,
    params=params,
    model_dict=model_dict,
)
initial_data["anchored_states"]["states"]

Unnamed: 0,fac1,fac2,fac3,period,id
0,-0.378972,-0.266733,-0.566692,0,0
0,-1.757949,1.557306,-0.566692,1,0
0,-0.485243,2.210941,-0.566692,2,0
0,-0.381358,1.450439,-0.566692,3,0
0,0.950597,3.544418,-0.566692,4,0
...,...,...,...,...,...
3999,-0.835985,-0.499080,0.291446,3,3999
3999,0.487642,-1.542402,0.291446,4,3999
3999,-2.614247,-3.657758,0.291446,5,3999
3999,-4.985279,-2.273178,0.291446,6,3999


## Why do I need data to simulate data?

The data you pass to `simulate_dataset` contains information on observed factors and control variables. These are not part of the latent factor model, and a standard model specification does not contain enough information to generate them.

If you have a model without control variables or observed factors, you can simply pass `n_obs` instead of `data`.

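## Simulated data with policies

Each policy is a dictionary describing a shock to `factor` in `period`, with mean `effect_size` and standard deviation `standard_deviation`.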

In [5]:
# Two policy shocks: one to fac1 in period 0 and one to fac2 in period 1.
# A standard deviation of 0.0 presumably makes each shock a deterministic shift.
policies = [
    {
        "period": 0,
        "factor": "fac1",
        "effect_size": 0.2,
        "standard_deviation": 0.0,
    },
    {
        "period": 1,
        "factor": "fac2",
        "effect_size": 0.1,
        "standard_deviation": 0.0,
    },
]

In [7]:
# Same simulation as in the policy-free case, but with the shocks defined in
# `policies` applied. Without `policies=...` this cell would silently
# reproduce a policy-free simulation.
data_after_policies = simulate_dataset(
    model_dict=model_dict,
    params=params,
    data=data,
    policies=policies,
)
data_after_policies["anchored_states"]["states"]

Unnamed: 0,fac1,fac2,fac3,period,id
0,-0.553380,0.127660,0.104046,0,0
0,0.616927,0.799906,0.104046,1,0
0,0.759767,-0.375969,0.104046,2,0
0,1.585738,-1.628875,0.104046,3,0
0,0.037964,-1.701462,0.104046,4,0
...,...,...,...,...,...
3999,-0.935151,-1.446857,-0.392653,3,3999
3999,0.189946,-0.029642,-0.392653,4,3999
3999,-0.280098,0.880144,-0.392653,5,3999
3999,0.943714,-1.078038,-0.392653,6,3999
