In [None]:
#| default_exp forktable

In [None]:
#| export
import os
from typing import List, Optional

from folktables import ACSDataSource, ACSIncome
from folktables.load_acs import state_list, _STATE_CODES, initialize_and_download
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from relax.import_essentials import  *
from fastcore.parallel import parallel
from functools import partial

In [None]:
#| export
def download_data(
    states: Optional[List[str]] = None,
    years: Optional[List[int]] = None,
    data_dir: str = "assets/data/acs",
    random_state: Optional[int] = None,
):
    """Download ACS income data and save one shuffled CSV per (year, state).

    Each file is written to ``{data_dir}/{year}_{state}.csv``. It holds the
    `ACSIncome` feature columns followed by the label column (cast to float).
    Feature and label rows are shuffled together before writing.

    Args:
        states: State abbreviations to download. Defaults to ``["CA"]``.
        years: Survey years to download. Defaults to ``[2018]``.
        data_dir: Output directory. It is created if it does not exist yet.
        random_state: Seed forwarded to ``sklearn.utils.shuffle``. ``None``
            keeps the previous unseeded behaviour; pass an int to get
            reproducible files.
    """
    # None sentinels avoid shared mutable default arguments.
    states = ["CA"] if states is None else states
    years = [2018] if years is None else years

    # `to_csv` does not create parent directories, so ensure the target exists.
    os.makedirs(data_dir, exist_ok=True)

    for year in years:
        # The data source depends only on the survey year; build it once per year.
        data_source = ACSDataSource(
            survey_year=year, horizon='1-Year', survey='person'
        )
        for state in states:
            raw = data_source.get_data(states=[state], download=True)
            feats, labels, _ = ACSIncome.df_to_pandas(raw)
            labels = labels.astype(float)
            # Shuffle features and labels jointly so rows stay aligned.
            feats, labels = shuffle(feats, labels, random_state=random_state)
            pd.concat([feats, labels], axis=1).to_csv(
                os.path.join(data_dir, f"{year}_{state}.csv"), index=False
            )


In [None]:
# Fetch the 2018 ACS income data for every state in `state_list`.
# Each state is downloaded (download=True) and written to
# assets/data/acs/2018_<STATE>.csv.
download_data(states=state_list, years=[2018], data_dir="assets/data/acs")

Downloading data for 2018 1-Year person survey for AL...
Downloading data for 2018 1-Year person survey for AK...
Downloading data for 2018 1-Year person survey for AZ...
Downloading data for 2018 1-Year person survey for AR...
Downloading data for 2018 1-Year person survey for CA...
Downloading data for 2018 1-Year person survey for CO...
Downloading data for 2018 1-Year person survey for CT...
Downloading data for 2018 1-Year person survey for DE...
Downloading data for 2018 1-Year person survey for FL...
Downloading data for 2018 1-Year person survey for GA...
Downloading data for 2018 1-Year person survey for HI...
Downloading data for 2018 1-Year person survey for ID...
Downloading data for 2018 1-Year person survey for IL...
Downloading data for 2018 1-Year person survey for IN...
Downloading data for 2018 1-Year person survey for IA...
Downloading data for 2018 1-Year person survey for KS...
Downloading data for 2018 1-Year person survey for KY...
Downloading data for 2018 1-Yea

In [None]:
# Sanity check: every downloaded state file must share the same column layout.
# Each file is compared against the first state's header, so a mismatch in any
# file is caught. Only the header row is read (`nrows=0`); the full CSVs are
# not loaded.
reference_cols = pd.read_csv(
    f"assets/data/acs/2018_{state_list[0]}.csv", nrows=0
).columns
for state in state_list[1:]:
    state_cols = pd.read_csv(f"assets/data/acs/2018_{state}.csv", nrows=0).columns
    assert np.array_equal(state_cols, reference_cols), (
        f"Column mismatch for {state}: {list(state_cols)} != {list(reference_cols)}"
    )

In [None]:
# Peek at one state's file; `.head()` keeps the rendered output small.
pd.read_csv("assets/data/acs/2018_CA.csv").head()

Unnamed: 0,AGEP,COW,SCHL,MAR,OCCP,POBP,RELP,WKHP,SEX,RAC1P,PINCP
0,58.0,5.0,19.0,2.0,5240.0,6.0,0.0,40.0,2.0,2.0,0.0
1,35.0,1.0,21.0,1.0,4850.0,217.0,0.0,40.0,1.0,6.0,1.0
2,23.0,1.0,21.0,1.0,4700.0,6.0,1.0,45.0,2.0,1.0,1.0
3,29.0,1.0,21.0,5.0,4700.0,6.0,2.0,40.0,1.0,1.0,1.0
4,64.0,2.0,21.0,1.0,440.0,15.0,1.0,40.0,1.0,9.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...
195660,48.0,7.0,21.0,1.0,4621.0,54.0,0.0,25.0,1.0,1.0,1.0
195661,57.0,6.0,21.0,1.0,2016.0,18.0,1.0,36.0,2.0,1.0,0.0
195662,31.0,1.0,19.0,5.0,3930.0,6.0,2.0,40.0,1.0,8.0,0.0
195663,54.0,8.0,16.0,1.0,4540.0,39.0,0.0,12.0,1.0,1.0,0.0


In [None]:
# Experiment configuration for the ACS income ("census") task.
# NOTE(review): the key spellings "continous_cols" / "discret_cols" are kept
# verbatim. They are presumably the names the consuming config class expects,
# so confirm that before renaming them.
acs_configs = {
    "data_config": {
        "data_name": "census",
        # NOTE(review): OCCP / POBP / RELP look like coded categories (see the
        # CSV preview), yet they are listed as continuous. Confirm this is
        # intended.
        "continous_cols": [
            "AGEP", "OCCP", "POBP", "RELP", "WKHP"
        ],
        "discret_cols": [
            "COW", "SCHL", "MAR", "SEX", "RAC1P"
        ],
    },
    "m_config": {
        # model structure
        "enc_sizes": [100,50],
        "dec_sizes": [20],
        "exp_sizes": [20],
        "dropout_rate": 0.3,
        "sizes": [50, 10, 50],
        # training module
        'lr': 0.003,
        "lambda_1": 1.0,
        "lambda_3": 0.1,
        "lambda_2": 1.0,
        # adv training
        "epsilon": 0.1,
        "n_steps": 10,
        "k": 2,
        "adv_lr": 0.03
    },
    "t_configs": {
        'n_epochs': 50,
        "batch_size": 256,
        'monitor_metrics': 'val/val_loss'
    },
    # One CSV per state, matching the paths `download_data` writes. This is a
    # list, not a set comprehension, so the order follows `state_list` and is
    # stable across runs. Set iteration order for strings depends on hash
    # randomization.
    'data_dir_list': [
        f"assets/data/acs/2018_{state}.csv" for state in state_list
    ],
}
