In [None]:
# default_exp data

In [None]:
# hide
%load_ext autoreload
%autoreload 2

In [None]:
# hide
from nbdev import *

In [None]:
# export
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import gc
import typing
from pathlib import Path
from math import isclose
from typing import Sequence, Union, Tuple

In [None]:
# export
from nn4tab.test_utils import fake_data, test_normalized, test_categorical, test_nans, test_df_processed

# Data

## Processors

In [None]:
# export
class TabularProc():
    """
    Base class for tabular data processors.

    Subclasses override `setup` (fit state on data), `encode` (transform data)
    and optionally `decode` (undo `encode`). `_order` is the key ProcPipeline
    sorts its processors by.
    """
    _order = 1      # sort key used by ProcPipeline
    isset = False   # subclasses set this to True once setup has run

    def setup(self):
        """Fit processor state; does nothing in the base class."""

    def checkup(self):
        """Hook subclasses call at the start of their `setup`; does nothing here."""

    def encode(self, x):
        """Transform `x`; must be implemented by subclasses."""
        raise NotImplementedError

    def decode(self, x):
        """Undo `encode`; does nothing in the base class."""

In [None]:
# export
def _readargs(**kwargs):
    ds = kwargs.get('ds', None)
    if ds is not None:
        return vars(ds)
    df = kwargs.get('df', None)
    if df is None:
        raise RuntimeError("Either dataset or dataframe should be in arguments")
    cont_names = kwargs.get('cont_names', None)
    cat_names = kwargs.get('cat_names', None)
    return {'data':df,
            'cont_names':cont_names,
            'cat_names':cat_names}

### Normalize proc

In [None]:
def print_stat(df, cols=None):
    """
    Print a separator line followed by mean and std of each column in `cols`.

    Parameters
    ----------
    df : pd.DataFrame
    cols : sequence of column names, optional
        Columns to summarize. When omitted, falls back to the notebook-level
        `cont_names` so existing `print_stat(df)` calls keep working; pass it
        explicitly to avoid depending on hidden global state.
    """
    if cols is None:
        cols = cont_names  # global set by the surrounding test cells
    print('*' * 15)
    for col in cols:
        print(f'{col}: mean= {df[col].mean():.4f}, std = {df[col].std():.4f}')

In [None]:
# export
class Normalize(TabularProc):
    """
    Normalizes continuous features to zero mean and unit variance.

    Mean and std are stored by `setup` (called lazily by the first `encode`
    if it was not called explicitly) and reused by every later
    `encode`/`decode`. Data encoded after setup is therefore scaled with the
    statistics of the setup data.
    """
    def setup(self, data:Union[Dataset, pd.DataFrame], cont_names:Sequence=[]):
        """
        Store mean and std for columns in cont_names.

        A column whose std is 0 or undefined (constant column, single row) is
        given std 1, so encoding maps it to 0 instead of NaN/inf.
        """
        self.checkup()
        data, cont_names = self._argcheck(data, cont_names)
        self.mean = {col: data[col].mean() for col in cont_names}
        self.std = {col: self._safe_std(data[col]) for col in cont_names}
        self.isset = True

    @staticmethod
    def _safe_std(s:pd.Series):
        """Sample std of `s`, or 1.0 when that std is 0 or NaN."""
        std = s.std()
        return 1.0 if pd.isna(std) or std == 0 else std

    def _argcheck(self, data, cont_names):
        """
        Unwrap a Dataset into its frame and resolve the column list.

        For a Dataset, empty `cont_names` defaults to `data.cont_names`.
        For a DataFrame, empty `cont_names` raises `Warning` (as an exception).
        """
        if isinstance(data, Dataset):
            if not cont_names: cont_names = data.cont_names
            data = data.data
        else:
            if not cont_names:
                raise Warning("Given no columns to process")
        return data, cont_names

    def encode_one(self, df:pd.DataFrame, col:str):
        """Standardized values of `df[col]` using the stored statistics."""
        return (df[col] - self.mean[col])/self.std[col]

    def encode(self, data:Union[Dataset, pd.DataFrame], cont_names:Sequence=[]):
        """Standardize the continuous columns in place (running setup first if needed)."""
        data, cont_names = self._argcheck(data, cont_names)
        if not self.isset: self.setup(data, cont_names)
        for col in cont_names:
            data[col] = self.encode_one(data, col)

    def decode_one(self, df:pd.DataFrame, col:str):
        """Undo `encode_one` for `df[col]`."""
        return df[col]*self.std[col] + self.mean[col]

    def decode(self, data:Union[Dataset, pd.DataFrame], cont_names:Sequence=[]):
        """Map standardized continuous columns back to their original scale, in place."""
        data, cont_names = self._argcheck(data, cont_names)
        for col in cont_names:
            data[col] = self.decode_one(data, col)

#### Test 1

In [None]:
# Raw (not preprocessed) fake data; `df` stays as the reference, `test_df` gets transformed.
df, cont_names, cat_names = fake_data(preproc=False)
test_df = df.copy()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  1000 non-null   float32
 1   cont_1  1000 non-null   float32
 2   cont_2  1000 non-null   float32
 3   cont_3  1000 non-null   float32
 4   cont_4  1000 non-null   float32
 5   cat_0   1000 non-null   object 
 6   cat_1   1000 non-null   object 
 7   targ    1000 non-null   float32
dtypes: float32(6), object(2)
memory usage: 39.2+ KB


In [None]:
# Fit statistics, standardize in place, then check the columns are standardized.
norm = Normalize()
norm.setup(test_df, cont_names)

print_stat(test_df)

norm.encode(test_df, cont_names)

print_stat(test_df)

test_normalized(test_df, cont_names)

***************
cont_0: mean= -1.6134, std = 2.5976
cont_1: mean= 4.8120, std = 2.7086
cont_2: mean= -1.8440, std = 2.4517
cont_3: mean= 2.8474, std = 1.4774
cont_4: mean= 3.6936, std = 2.9721
***************
cont_0: mean= 0.0000, std = 1.0000
cont_1: mean= 0.0000, std = 1.0000
cont_2: mean= -0.0000, std = 1.0000
cont_3: mean= 0.0000, std = 1.0000
cont_4: mean= -0.0000, std = 1.0000


In [None]:
# decode must invert encode: per-column sum of absolute differences to the original stays below 1e-4.
norm.decode(test_df, cont_names)

print_stat(test_df)

for x in (test_df[cont_names] - df[cont_names]).abs().sum():
    assert x < 1e-4

***************
cont_0: mean= -1.6134, std = 2.5976
cont_1: mean= 4.8120, std = 2.7086
cont_2: mean= -1.8440, std = 2.4517
cont_3: mean= 2.8474, std = 1.4774
cont_4: mean= 3.6936, std = 2.9721


#### Test 2

In [None]:
# fake_data() with default arguments (presumably already preprocessed).
df, cont_names, cat_names = fake_data()
test_df = df.copy()

In [None]:
# Same round as Test 1, on the default fake data.
norm = Normalize()
norm.setup(test_df, cont_names)

print_stat(test_df)

norm.encode(test_df, cont_names)

print_stat(test_df)

test_normalized(test_df, cont_names)

***************
cont_0: mean= 0.0436, std = 1.0146
cont_1: mean= -0.0250, std = 0.9849
cont_2: mean= -0.0127, std = 0.9841
cont_3: mean= -0.0235, std = 1.0335
cont_4: mean= -0.0028, std = 1.0229
***************
cont_0: mean= 0.0000, std = 1.0000
cont_1: mean= 0.0000, std = 1.0000
cont_2: mean= -0.0000, std = 1.0000
cont_3: mean= -0.0000, std = 1.0000
cont_4: mean= 0.0000, std = 1.0000


In [None]:
# decode must invert encode: per-column sum of absolute differences to the original stays below 1e-4.
norm.decode(test_df, cont_names)

print_stat(test_df)

for x in (test_df[cont_names] - df[cont_names]).abs().sum():
    assert x < 1e-4

***************
cont_0: mean= 0.0436, std = 1.0146
cont_1: mean= -0.0250, std = 0.9849
cont_2: mean= -0.0127, std = 0.9841
cont_3: mean= -0.0235, std = 1.0335
cont_4: mean= -0.0028, std = 1.0229


### FillMissing proc

In [None]:
# export
class FillMissing(TabularProc):
    """
    Fills missing values in continuous columns.

    Parameters
    ----------
    add_bool : bool
        If True, each continuous column that had missing values when `setup`
        ran gets an int8 indicator column `<col>_na` (1 where the value was
        missing) on every `encode`, and that name is appended to `cat_names`.
        Fixing the set of indicator columns at setup keeps it identical across
        datasets that share this processor (e.g. train and validation).
    method : {'mean', 'median'}
        Statistic, computed at setup, used as the fill value.

    Raises
    ------
    ValueError
        If `method` is not supported.
    """
    _fill_methods = {'mean': pd.Series.mean, 'median': pd.Series.median}

    def __init__(self, add_bool=True, method='mean'):
        if method not in self._fill_methods:
            raise ValueError(f"Unknown fill method {method!r}; "
                             f"expected one of {sorted(self._fill_methods)}")
        self.add_bool = add_bool
        self.method = method

    def setup(self, data:Union[Dataset, pd.DataFrame], cont_names:Sequence=None, cat_names:Sequence=None):
        """Compute the fill values and the columns that get a `_na` indicator."""
        self.checkup()
        data, cont_names, cat_names = self._argcheck(data, cont_names, cat_names)
        fill = self._fill_methods[self.method]
        self.values = {col: fill(data[col]) for col in cont_names}
        # only columns with missing values at setup time get an indicator column
        self.na_cols = [col for col in cont_names if data[col].isna().any()] if self.add_bool else []
        self.cont_names = cont_names
        self.cat_names = cat_names
        self.isset = True

    def _argcheck(self, data, cont_names, cat_names):
        """
        Unwrap a Dataset into its frame and resolve the column lists.

        For a Dataset, empty lists default to the Dataset's own `cont_names` /
        `cat_names` (so appended indicator names land in the Dataset's list).
        For a DataFrame, empty `cont_names` raises `Warning` (as an exception).
        A missing `cat_names` becomes a fresh empty list, never a shared default.
        """
        if isinstance(data, Dataset):
            if not cont_names: cont_names = data.cont_names
            if not cat_names: cat_names = data.cat_names
            data = data.data
        else:
            if not cont_names:
                raise Warning("Given no columns to process")
        if cat_names is None:
            cat_names = []
        return data, cont_names, cat_names

    def encode(self, data:Union[Dataset, pd.DataFrame], cont_names:Sequence=None, cat_names:Sequence=None):
        """
        Fill NaNs in the continuous columns in place and add indicator columns.

        Indicator columns are added for the columns recorded at setup. Their
        names are appended to `cat_names` unless already present.
        """
        data, cont_names, cat_names = self._argcheck(data, cont_names, cat_names)
        if not self.isset: self.setup(data, cont_names, cat_names)
        for col in cont_names:
            if col in self.na_cols:
                na_name = f'{col}_na'
                # computed before filling so it records the original NaN positions
                data[na_name] = data[col].isna().astype(np.int8)
                if na_name not in cat_names:
                    cat_names.append(na_name)
            if data[col].isna().any():
                # assign back rather than `fillna(inplace=True)` on a column,
                # which does not propagate to `data` under pandas copy-on-write
                data[col] = data[col].fillna(self.values[col])

    def decode(self, *args, **kwargs):
        """No-op: filled values are not restored."""

#### Tests

In [None]:
# Fake data with missing values in both continuous and categorical columns.
df, cont_names, cat_names = fake_data(nans=True)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  883 non-null    float32
 1   cont_1  899 non-null    float32
 2   cont_2  912 non-null    float32
 3   cont_3  902 non-null    float32
 4   cont_4  894 non-null    float32
 5   cat_0   891 non-null    float64
 6   cat_1   902 non-null    float64
 7   targ    1000 non-null   float32
dtypes: float32(6), float64(2)
memory usage: 39.2 KB


In [None]:
fillproc = FillMissing()
# Fit on the frame that is encoded below. `test_df` at this point is a leftover
# NaN-free frame from the Normalize tests, which gives the wrong fill statistics.
fillproc.setup(df, cont_names, cat_names)

In [None]:
# Fill NaNs in the continuous columns of `df` in place; `_na` indicator columns are added.
fillproc.encode(df, cont_names, cat_names)

In [None]:
# Continuous columns are now complete; categorical NaNs remain (handled by Categorify).
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 13 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   cont_0     1000 non-null   float32
 1   cont_1     1000 non-null   float32
 2   cont_2     1000 non-null   float32
 3   cont_3     1000 non-null   float32
 4   cont_4     1000 non-null   float32
 5   cat_0      891 non-null    float64
 6   cat_1      902 non-null    float64
 7   targ       1000 non-null   float32
 8   cont_0_na  1000 non-null   int8   
 9   cont_1_na  1000 non-null   int8   
 10  cont_2_na  1000 non-null   int8   
 11  cont_3_na  1000 non-null   int8   
 12  cont_4_na  1000 non-null   int8   
dtypes: float32(6), float64(2), int8(5)
memory usage: 44.1 KB


In [None]:
# The indicator column names were appended to the caller's cat_names list.
cat_names

['cat_0',
 'cat_1',
 'cont_0_na',
 'cont_1_na',
 'cont_2_na',
 'cont_3_na',
 'cont_4_na']

In [None]:
# Currently, NaN values in categorical columns are handled by the Categorify proc.
test_nans(df, cont_names, cat_names=[])

### Categorify proc

In [None]:
# export
def _catlist(s:pd.Series):
    c = set(s)
    c.discard('#na')
    return ['#na'] + list(c)

In [None]:
# export
class Categorify(TabularProc):
    """
    Numericalizes categorical columns.

    `setup` collects the categories of each column, with '#na' at code 0.
    `encode` replaces values with their integer codes in place, and `decode`
    maps codes back to category values.
    """
    def setup(self, data:Union[Dataset, pd.DataFrame], cat_names:Sequence=[]):
        """Store the category list of every column in cat_names."""
        self.checkup()
        data, cat_names = self._argcheck(data, cat_names)
        self.cat = {col: _catlist(data[col].dropna()) for col in cat_names}
        # NOTE(review): despite its name this maps column names (keys of
        # self.cat) to their position; it is not used in this notebook.
        self.i2c = {c: i for i, c in enumerate(self.cat)}
        self.isset = True

    def _argcheck(self, data, cat_names):
        """
        Unwrap a Dataset into its frame and resolve the column list.

        For a Dataset, empty `cat_names` defaults to `data.cat_names`.
        For a DataFrame, empty `cat_names` raises `Warning` (as an exception).
        """
        if isinstance(data, Dataset):
            if not cat_names: cat_names = data.cat_names
            data = data.data
        else:
            if not cat_names:
                raise Warning("Given no columns to process")
        return data, cat_names

    def encode_one(self, df:pd.DataFrame, col:str):
        """
        Integer codes of `df[col]`, as a Series aligned with `df.index`.

        Missing values become '#na' (code 0). Values not seen during setup
        would get pandas' code -1, which is presumably invalid as an index for
        a downstream embedding lookup, so they are also mapped to 0 ('#na').
        """
        cats = pd.Categorical(df[col].fillna('#na'), categories=self.cat[col])
        codes = cats.codes.copy()  # the codes array exposed by pandas is not writable
        codes[codes < 0] = 0
        # keep df's index so the assignment in `encode` aligns row by row
        return pd.Series(codes, index=df.index)

    def encode(self, data:Union[Dataset, pd.DataFrame], cat_names:Sequence=[]):
        """Replace categorical values with integer codes in place (running setup first if needed)."""
        data, cat_names = self._argcheck(data, cat_names)
        if not self.isset: self.setup(data, cat_names)
        for col in cat_names:
            data[col] = self.encode_one(data, col)

    def decode_one(self, df:pd.DataFrame, col:str):
        """Category values for the codes in `df[col]`, aligned with `df.index`."""
        return pd.Series(pd.Categorical.from_codes(df[col], categories=self.cat[col]),
                         index=df.index)

    def decode(self, data:Union[Dataset, pd.DataFrame], cat_names:Sequence=[]):
        """Map integer codes back to category values in place."""
        data, cat_names = self._argcheck(data, cat_names)
        for col in cat_names:
            data[col] = self.decode_one(data, col)

#### Test 1

In [None]:
# Raw fake data with NaNs; categorical columns hold strings (object dtype).
df, cont_names, cat_names = fake_data(preproc=False, nans=True)
test_df = df.copy()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  883 non-null    float32
 1   cont_1  899 non-null    float32
 2   cont_2  912 non-null    float32
 3   cont_3  902 non-null    float32
 4   cont_4  894 non-null    float32
 5   cat_0   891 non-null    object 
 6   cat_1   901 non-null    object 
 7   targ    1000 non-null   float32
dtypes: float32(6), object(2)
memory usage: 39.2+ KB


In [None]:
# setup collects the categories of each column, with '#na' first.
cproc = Categorify()
cproc.setup(test_df, cat_names)

assert cproc.isset
print(cproc.cat)

{'cat_0': ['#na', 'A', 'B', 'C'], 'cat_1': ['#na', 'A', 'B', 'C']}


In [None]:
# encode replaces categories with integer codes in place.
cproc.encode(test_df, cat_names)

test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  883 non-null    float32
 1   cont_1  899 non-null    float32
 2   cont_2  912 non-null    float32
 3   cont_3  902 non-null    float32
 4   cont_4  894 non-null    float32
 5   cat_0   1000 non-null   int8   
 6   cat_1   1000 non-null   int8   
 7   targ    1000 non-null   float32
dtypes: float32(6), int8(2)
memory usage: 25.5 KB


In [None]:
# NaNs must be encoded as code 0 ('#na') and the encoded columns must be integer-typed.
for col in cat_names:
    assert sum(test_df.loc[df[col].isna(), col]) == 0, f'Error when handling nans in {col}'
    assert np.issubdtype(test_df[col].dtype, np.integer), f'{col} dtype is not int'

In [None]:
# decode maps codes back to category values in place.
cproc.decode(test_df, cat_names)

In [None]:
# The round trip must restore every non-missing value.
for col in cat_names:
    assert (df.loc[df[col].notna(), col] == test_df.loc[df[col].notna(), col]).all()

#### Test 2

In [None]:
# Default fake data: categorical columns are int64 here.
df, cont_names, cat_names = fake_data()
test_df = df.copy()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  1000 non-null   float32
 1   cont_1  1000 non-null   float32
 2   cont_2  1000 non-null   float32
 3   cont_3  1000 non-null   float32
 4   cont_4  1000 non-null   float32
 5   cat_0   1000 non-null   int64  
 6   cat_1   1000 non-null   int64  
 7   targ    1000 non-null   float32
dtypes: float32(6), int64(2)
memory usage: 39.2 KB


In [None]:
# For integer columns the categories are the integers themselves, after '#na'.
cproc = Categorify()
cproc.setup(test_df, cat_names)

assert cproc.isset
print(cproc.cat)

{'cat_0': ['#na', 0, 1, 2], 'cat_1': ['#na', 0, 1, 2]}


In [None]:
# Encode in place and inspect the resulting dtypes.
cproc.encode(test_df, cat_names)

test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  1000 non-null   float32
 1   cont_1  1000 non-null   float32
 2   cont_2  1000 non-null   float32
 3   cont_3  1000 non-null   float32
 4   cont_4  1000 non-null   float32
 5   cat_0   1000 non-null   int8   
 6   cat_1   1000 non-null   int8   
 7   targ    1000 non-null   float32
dtypes: float32(6), int8(2)
memory usage: 25.5 KB


In [None]:
# decode maps codes back to category values in place.
cproc.decode(test_df, cat_names)

In [None]:
# The round trip must restore every non-missing value.
for col in cat_names:
    assert (df.loc[df[col].notna(), col] == test_df.loc[df[col].notna(), col]).all()

### Processing pipeline

In [None]:
# export
class ProcPipeline:
    """
    Combines data processors into pipeline.

    `procs` is a sequence of processor classes. `reset` instantiates them and
    orders them by their `_order` attribute; the sort is stable, so ties keep
    the given order.
    """
    def __init__(self, procs:Sequence["TabularProc"]):
        self._procs = procs
        self.reset()

    def setup(self, data):
        """
        No-op kept for API compatibility; leaves `isset` unchanged.

        Processors set themselves up lazily on their first `encode`. Setting
        them all up front would leave Categorify without the `_na` columns
        that FillMissing appends to cat_names only during `encode`.
        """
        return None

    def encode(self, data):
        """Encode `data` in place with every processor, in pipeline order."""
        for proc in self.procs:
            proc.encode(data)

    def decode(self, data):
        """Undo `encode`: decode with every processor in reverse pipeline order."""
        for proc in reversed(self.procs):
            proc.decode(data)

    def __getitem__(self, i):
        """The i-th processor instance in pipeline order."""
        return self.procs[i]

    def reset(self):
        """Create fresh processor instances from the stored classes, sorted by `_order`."""
        self.procs = sorted((p() for p in self._procs), key=lambda p: p._order)
        self.isset = False

## Dataset and dataloader

In [None]:
# export
def cont_cat_split(df, dep_var=None, max_card=np.inf, ignore=[]):
    """
    Suggests a split of the dataframe's columns into continuous and categorical
    ones, omitting `dep_var` and `ignore`.

    The split is based on column dtype. Float columns, and integer columns with
    more than `max_card` distinct values, are treated as continuous; all other
    columns are categorical.

    Parameters
    ----------
    df : pd.DataFrame
    dep_var : None, a single column name, or a sequence of column names
    max_card : number, default inf (every integer column is categorical)
    ignore : sequence of column names to skip

    Returns
    -------
    (cont_names, cat_names) : two lists of column names, in column order.
    """
    # Normalize dep_var to a list. Previously `col in None` raised TypeError,
    # and a string dep_var made `col in dep_var` a substring test.
    if dep_var is None:
        excluded = []
    elif isinstance(dep_var, str) or not hasattr(dep_var, '__iter__'):
        excluded = [dep_var]
    else:
        excluded = list(dep_var)
    excluded += list(ignore)

    cont, cat = [], []
    for col in df.columns:
        if col in excluded:
            continue
        dtype = df[col].dtype
        # pandas dtype checks also accept extension dtypes, which np.issubdtype rejects
        is_float = pd.api.types.is_float_dtype(dtype)
        is_int = pd.api.types.is_integer_dtype(dtype)
        if is_float or (is_int and df[col].nunique(dropna=False) > max_card):
            cont.append(col)
        else:
            cat.append(col)
    return cont, cat

In [None]:
# export
class TabularDataset(Dataset):
    """
    Dataset for tabular data with categorical and continuous features.
    Produces tuple containing numpy arrays:
        x_cat (int64), x_cont (float32), y (float32)

    The frame is copied (unless copy=False) and encoded with `procs` on creation.
    `procs` may be a ProcPipeline, which is used as-is and so can be shared,
    e.g. with a validation set. It may also be a sequence of processor classes,
    which is wrapped in a new ProcPipeline.
    """
    def __init__(self, df:pd.DataFrame, cont_names:Sequence, cat_names:Sequence, dep_var:Sequence,
                 procs=[], copy=True):
        self.data = df.copy() if copy else df
        # Name lists are stored by reference: processors (e.g. FillMissing) may
        # append to cat_names while encoding.
        self.cat_names = cat_names
        self.cont_names = cont_names
        self.dep_var = dep_var
        self.procs = procs if isinstance(procs, ProcPipeline) else ProcPipeline(procs)
        self.procs.encode(self)

    def __getitem__(self, idx):
        """Row `idx` as (x_cat, x_cont, y) numpy arrays."""
        # np.long was removed in NumPy 1.24 (AttributeError); int64 matches torch.long
        return (self.data[self.cat_names].iloc[idx].to_numpy(dtype=np.int64),
                self.data[self.cont_names].iloc[idx].to_numpy(dtype=np.float32),
                self.data[self.dep_var].iloc[idx].to_numpy(dtype=np.float32))

    def __len__(self):
        return len(self.data)

    def _decode(self):
        """Decode `self.data` in place with the dataset's processors."""
        self.procs.decode(self)

#### Test pipe

In [None]:
df, cont_names, cat_names = fake_data(preproc=False, nans=True)
dep_var = ['targ']
test_df = df.copy()
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 8 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   cont_0  883 non-null    float32
 1   cont_1  899 non-null    float32
 2   cont_2  912 non-null    float32
 3   cont_3  902 non-null    float32
 4   cont_4  894 non-null    float32
 5   cat_0   891 non-null    object 
 6   cat_1   901 non-null    object 
 7   targ    1000 non-null   float32
dtypes: float32(6), object(2)
memory usage: 39.2+ KB


In [None]:
# ProcPipeline takes processor classes and instantiates them itself.
procs = [Normalize, FillMissing, Categorify]
pipe = ProcPipeline(procs)

In [None]:
# No procs given: ds.data stays raw so the pipeline can be applied explicitly below.
ds = TabularDataset(df, cont_names, cat_names, dep_var)
ds.data.head()

Unnamed: 0,cont_0,cont_1,cont_2,cont_3,cont_4,cat_0,cat_1,targ
0,0.237776,4.726856,0.328851,5.035037,,,C,1.0
1,-0.16513,4.75535,1.774184,1.690559,2.819697,A,A,1.0
2,-0.801118,5.582966,-5.896749,3.399724,3.353442,C,,0.0
3,-2.345072,4.445452,-1.267434,0.292013,8.225743,A,A,0.0
4,-3.939306,-1.294417,-2.01709,4.966501,2.186796,C,B,0.0


In [None]:
# Apply the pipeline to the dataset in place; ds.cat_names gains the `_na` columns.
pipe.encode(ds)
ds.data.head()

Unnamed: 0,cont_0,cont_1,cont_2,cont_3,cont_4,cat_0,cat_1,targ,cont_0_na,cont_1_na,cont_2_na,cont_3_na,cont_4_na
0,0.7157,-0.030405,0.886719,1.466304,2.773549e-08,0,3,1.0,1,1,1,1,2
1,0.55919,-0.019687,1.466657,-0.767848,-0.2901395,1,1,1.0,1,1,1,1,1
2,0.312139,0.291596,-1.611295,0.373895,-0.1118203,3,0,0.0,1,1,1,1,1
3,-0.287613,-0.136246,0.246211,-1.702093,1.515969,1,1,0.0,1,1,1,1,1
4,-0.906897,-2.295129,-0.054587,1.420521,-0.5015854,3,2,0.0,1,1,1,1,1


In [None]:
# Check the processed frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names, dep_var=dep_var)

In [None]:
# Same pipeline on raw data without NaNs; the dataset is created without procs.
df, cont_names, cat_names = fake_data(preproc=False)
dep_var = ['targ']
procs = [Normalize, FillMissing, Categorify]
pipe = ProcPipeline(procs)
ds = TabularDataset(df, cont_names, cat_names, dep_var)

In [None]:
# Encode the dataset in place.
pipe.encode(ds)

In [None]:
# Check the processed frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names, dep_var=dep_var)

In [None]:
# TODO: decoding through the whole pipeline is not exercised yet: pipe.decode(ds)

#### Test dataset

In [None]:
# Default fake data; `pipe` is built but not passed, so the dataset applies no processors.
df, cont_names, cat_names = fake_data()
dep_var = ['targ']
procs = [Normalize, FillMissing, Categorify]
pipe = ProcPipeline(procs)
ds = TabularDataset(df, cont_names, cat_names, dep_var)

In [None]:
# A sample is the tuple (x_cat, x_cont, y) of numpy arrays.
ds[0]

(array([2, 1]),
 array([-1.7382663 , -1.3366427 , -1.3611068 , -0.35161713, -2.3125815 ],
       dtype=float32),
 array([0.], dtype=float32))

In [None]:
# Check the dataset frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names)

In [None]:
# Raw data; the dataset applies the pipeline on construction.
df, cont_names, cat_names = fake_data(preproc=False)
dep_var = ['targ']
procs = [Normalize, FillMissing, Categorify]
pipe = ProcPipeline(procs)
ds = TabularDataset(df, cont_names, cat_names, dep_var, procs=pipe)

In [None]:
# Check the dataset frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names)

In [None]:
# Continuous columns are standardized and categorical columns are integer codes.
ds.data.head()

Unnamed: 0,cont_0,cont_1,cont_2,cont_3,cont_4,cat_0,cat_1,targ
0,0.712654,-0.031417,0.886253,1.480729,-0.639405,3,3,1.0
1,0.557549,-0.020897,1.475774,-0.782973,-0.29402,1,1,1.0
2,0.312715,0.284651,-1.653041,0.373872,-0.114436,3,2,0.0
3,-0.281655,-0.135309,0.235161,-1.729575,1.524899,1,1,0.0
4,-0.89538,-2.254419,-0.070608,1.434341,-0.506966,3,2,0.0


In [None]:
# Raw data with NaNs through the full pipeline on construction.
df, cont_names, cat_names = fake_data(preproc=False, nans=True)
dep_var = ['targ']
procs = [Normalize, FillMissing, Categorify]
pipe = ProcPipeline(procs)
ds = TabularDataset(df, cont_names, cat_names, dep_var, procs=pipe)

In [None]:
# x_cat now also contains the `_na` indicator columns.
ds[0]

(array([0, 3, 1, 1, 1, 1, 2]),
 array([ 7.1570003e-01, -3.0404553e-02,  8.8671887e-01,  1.4663039e+00,
         2.7735494e-08], dtype=float32),
 array([1.], dtype=float32))

In [None]:
# Check the dataset frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names)

In [None]:
# No missing values remain; categorical and indicator columns are int8 codes.
ds.data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1000 entries, 0 to 999
Data columns (total 13 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   cont_0     1000 non-null   float32
 1   cont_1     1000 non-null   float32
 2   cont_2     1000 non-null   float32
 3   cont_3     1000 non-null   float32
 4   cont_4     1000 non-null   float32
 5   cat_0      1000 non-null   int8   
 6   cat_1      1000 non-null   int8   
 7   targ       1000 non-null   float32
 8   cont_0_na  1000 non-null   int8   
 9   cont_1_na  1000 non-null   int8   
 10  cont_2_na  1000 non-null   int8   
 11  cont_3_na  1000 non-null   int8   
 12  cont_4_na  1000 non-null   int8   
dtypes: float32(6), int8(7)
memory usage: 30.4 KB


#### Test adult ds

In [None]:
# Relative path: expects ./datasets/adult.csv under the current working directory.
df = pd.read_csv(Path('./datasets/adult.csv'))

In [None]:
# Binary target: 1 where salary is '>=50k', else 0 (vectorized comparison).
df['salary'] = (df['salary'] == '>=50k').astype(np.int8)

In [None]:
# max_card=10: integer columns with more than 10 distinct values are treated as continuous.
dep_var = ['salary']
cont_names, cat_names = cont_cat_split(df, dep_var, max_card=10)
cont_names, cat_names

(['age',
  'fnlwgt',
  'education-num',
  'capital-gain',
  'capital-loss',
  'hours-per-week'],
 ['workclass',
  'education',
  'marital-status',
  'occupation',
  'relationship',
  'race',
  'sex',
  'native-country'])

In [None]:
# A list of processor classes is wrapped in a ProcPipeline by TabularDataset.
procs = [Normalize, FillMissing, Categorify]
ds = TabularDataset(df, cont_names, cat_names, dep_var, procs=procs)

In [None]:
# Check the processed adult frame with the shared test helper.
test_df_processed(ds.data, ds.cont_names, ds.cat_names, ds.dep_var)

In [None]:
# First sample: (x_cat, x_cont, y).
ds[0]

(array([7, 5, 4, 0, 2, 2, 1, 7, 1]),
 array([ 0.76378465, -0.8380709 ,  0.74628264, -0.14591825,  4.5034127 ,
        -0.0354289 ], dtype=float32),
 array([1.], dtype=float32))

In [None]:
# hide
class Datasets:
    """
    Placeholder for a collection of datasets built from several dataframes.

    NOTE(review): not implemented; `__init__` ignores all of its arguments.
    """

    def __init__(self, *dfs, dsclass:Dataset=TabularDataset,
                 cat_names:Sequence, cont_names:Sequence, dep_var:Sequence, procs=None):
        """TODO: presumably meant to build one `dsclass` per frame in `dfs`; currently does nothing."""

### Utils

In [None]:
# export
def get_dsets(df:pd.DataFrame, cont_names:Sequence, cat_names:Sequence, dep_var:Sequence,
              procs=[], splits=None, stratify=True, valid_pct=0.2, seed=None):
    """
    Split `df` into a training and a validation TabularDataset.

    Parameters
    ----------
    splits : optional pair (train_idx, valid_idx)
        Row positions (integer arrays) or boolean masks, in the row order of
        `df`, selecting the training and validation rows. When not given, a
        random split is made with `train_test_split`.
    stratify : bool
        Stratify the random split by the column `dep_var[0]`.
    valid_pct : float
        Fraction of rows used for validation in the random split.
    seed : int, optional
        `random_state` for the random split; set it for a reproducible split.

    The validation dataset reuses the training dataset's processors, so it is
    encoded with the statistics fitted on the training rows.

    Returns
    -------
    (train_ds, valid_ds)
    """
    if splits is not None and len(splits):
        train_idx, valid_idx = splits[0], splits[1]
        # positional row selection: `df[list_of_ints]` would select columns
        train_df = df.iloc[np.asarray(train_idx)].copy()
        valid_df = df.iloc[np.asarray(valid_idx)].copy()
    else:
        s = df[dep_var[0]] if stratify else None
        train_df, valid_df = train_test_split(df, test_size=valid_pct, stratify=s,
                                              random_state=seed)
    train_df = train_df.reset_index(drop=True)
    valid_df = valid_df.reset_index(drop=True)
    train_ds = TabularDataset(train_df, cont_names, cat_names, dep_var, procs=procs)
    valid_ds = TabularDataset(valid_df, cont_names, cat_names, dep_var, procs=train_ds.procs)
    return (train_ds, valid_ds)

In [None]:
# Random stratified 80/20 split; no seed is passed, so the split changes between runs.
procs = [Normalize, FillMissing, Categorify]
train_ds, valid_ds = get_dsets(df, cont_names, cat_names, dep_var, procs=procs)

In [None]:
# Check the processed training frame.
test_df_processed(train_ds.data, train_ds.cont_names, train_ds.cat_names, train_ds.dep_var)

In [None]:
# Check the processed validation frame.
test_df_processed(valid_ds.data, valid_ds.cont_names, valid_ds.cat_names, valid_ds.dep_var)

In [None]:
# Training means are ~0 by construction (statistics were fitted on these rows).
train_ds.data[train_ds.cont_names].mean()

age              -1.690311e-16
fnlwgt            8.387196e-17
education-num     1.673339e-15
capital-gain      1.216058e-15
capital-loss     -6.312987e-17
hours-per-week   -1.571971e-16
dtype: float64

In [None]:
# Validation means are close to, but not exactly, 0: they are scaled with training statistics.
valid_ds.data[valid_ds.cont_names].mean()

age               0.018169
fnlwgt           -0.002425
education-num    -0.020961
capital-gain     -0.014558
capital-loss      0.000794
hours-per-week   -0.018492
dtype: float64

In [None]:
# export
def get_dl(ds, bs=512, train=True, drop_last=None, **kwargs):
    """
    Wrap `ds` in a torch DataLoader.

    Parameters
    ----------
    ds : dataset to load from
    bs : batch size
    train : shuffle the data when True
    drop_last : drop the last incomplete batch. Defaults to `train`, so a
        validation loader (train=False) keeps every sample instead of silently
        dropping up to `bs - 1` of them.
    **kwargs : forwarded to DataLoader (e.g. num_workers, pin_memory)
    """
    if drop_last is None:
        drop_last = bool(train)
    return DataLoader(ds, batch_size=bs, shuffle=train, drop_last=drop_last, **kwargs)

In [None]:
# hide
# nbdev: convert the project's notebooks (listed in the output) into library modules.
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 00a_test_utils.ipynb.
Converted 01_data.ipynb.
Converted 02_model.ipynb.
Converted 03_learner.ipynb.
Converted index.ipynb.
