In [None]:
import numpy as np
import pandas as pd
import random
from math import isclose
import pdb

import torch
from torch import nn

## Fake data generator

In [None]:
from nn4tab.test_utils import fake_data

In [None]:
# Build the synthetic demo frame; fake_data also returns the continuous and
# categorical column-name lists that the split checks further down compare against.
# NOTE(review): no RNG seed is set before this call, so the values captured in the
# outputs below may not reproduce on a fresh kernel — confirm whether fake_data seeds itself.
df, cont_names, cat_names = fake_data(preproc=False)

In [None]:
# Peek at the first rows of the frame generated with preproc=False.
df.head()

Unnamed: 0,cont_0,cont_1,cont_2,cont_3,cont_4,cat_0,cat_1,targ
0,0.237776,4.726856,0.328851,5.035037,1.793174,C,C,1.0
1,-0.16513,4.75535,1.774184,1.690559,2.819697,A,A,1.0
2,-0.801118,5.582966,-5.896749,3.399724,3.353442,C,B,0.0
3,-2.345072,4.445452,-1.267434,0.292013,8.225743,A,A,0.0
4,-3.939306,-1.294417,-2.01709,4.966501,2.186796,C,B,0.0


In [None]:
def print_stat(df, cols=None):
    """Print the mean and standard deviation of selected columns of ``df``.

    Args:
        df: DataFrame holding the columns to summarize.
        cols: Iterable of column names to report. When ``None`` (the default),
            the notebook-level ``cont_names`` list is used, which keeps
            existing ``print_stat(df)`` calls working unchanged.

    Returns:
        None. One line per column is written to stdout, formatted to four
        decimal places.
    """
    if cols is None:
        # Fall back to the global list so the one-argument call still works.
        cols = cont_names
    for col in cols:
        print(f'{col}: mean={df[col].mean():.4f}, std ={df[col].std():.4f}')

## Data preprocessing

In [None]:
from nn4tab.data import Normalize, FillMissing, Categorify

In [None]:
from nn4tab.data import cont_cat_split, TabularDataset, get_dsets, get_dl

In [None]:
# Target column name(s); kept as a list because it is later used to index the
# datasets' frames (train_ds.data[dep_var]).
dep_var = ['targ']

In [None]:
# Infer the continuous and categorical feature columns from the frame.
# (presumably dep_var is excluded from both lists — the equality checks that follow rely on it)
cont, cat = cont_cat_split(df, dep_var)

In [None]:
# Sanity check: the inferred continuous columns must match the generator's list.
# The message shows both sides so a failure is diagnosable from the traceback alone.
assert cont == cont_names, f'continuous columns differ: {cont} != {cont_names}'

In [None]:
# Sanity check: the inferred categorical columns must match the generator's list.
# The message shows both sides so a failure is diagnosable from the traceback alone.
assert cat == cat_names, f'categorical columns differ: {cat} != {cat_names}'

In [None]:
# Preprocessing steps handed to get_dsets: Normalize, FillMissing, Categorify.
# NOTE(review): whether get_dsets applies them in list order, and how it splits
# train/valid, is not visible from this notebook — confirm in nn4tab.data.
procs = [Normalize, FillMissing, Categorify]
# Build the train and validation datasets from the raw frame.
train_ds, valid_ds = get_dsets(df, cont, cat, dep_var, procs)

In [None]:
# Inspect one processed sample. Judging by the captured output it is a tuple of
# three arrays (presumably categorical codes, normalized continuous values,
# target — confirm in TabularDataset).
train_ds[0]

(array([1, 2]),
 array([-1.1851418 ,  0.19633709, -1.4352983 , -0.07934862,  0.8162202 ],
       dtype=float32),
 array([0.], dtype=float32))

In [None]:
# Compare the mean of the target column in the train and validation frames
# (for a 0/1 target this is the positive-class rate) to see whether the split is balanced.
train_ds.data[dep_var].mean(), valid_ds.data[dep_var].mean()

(targ    0.5125
 dtype: float32,
 targ    0.51
 dtype: float32)

In [None]:
# Tuple of (train loader, validation loader), both with batch size 16.
# NOTE(review): shuffling and drop-last behaviour are decided inside get_dl and are
# not visible here — confirm the validation loader does not drop a partial final batch.
dataloaders = get_dl(train_ds, bs=16), get_dl(valid_ds, bs=16)

## Model and training

In [None]:
from nn4tab.model import get_tabular_model

In [None]:
from nn4tab.learner import LearnerV0, accuracy_binary

In [None]:
# Build the network from the training dataset. The positional 1 is presumably the
# output size (a single logit) and layers the hidden-layer widths — confirm
# against get_tabular_model.
tabnn = get_tabular_model(train_ds, 1, layers=[100, 50])

In [None]:
# Wire model, loaders, optimizer, loss and metric into the learner.
# The optimizer is passed as a class (torch.optim.Adam), not an instance.
# nn.BCEWithLogitsLoss applies the sigmoid itself, so the model is expected to emit raw logits.
learn = LearnerV0(tabnn, dataloaders, torch.optim.Adam, nn.BCEWithLogitsLoss(), accuracy_binary)

In [None]:
# Train the model; the captured progress output shows a single epoch, so the
# argument is presumably the number of epochs.
learn.fit(1)

epoch 1: train loss 0.3273: 100%|██████████████████████████████████████████████████████| 50/50 [00:03<00:00, 14.09it/s]
epoch 1: valid loss 0.1400, accuracy 0.9531: 100%|█████████████████████████████████████| 12/12 [00:00<00:00, 21.40it/s]
