In [None]:
#default_exp tab_ae

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import fastai

In [None]:
#export
from fastai.tabular.all import *

In [None]:
# Fetch fastai's ADULT_SAMPLE dataset and load `adult.csv` from it
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path.joinpath('adult.csv'))

In [None]:
# Column roles and preprocessing shared by the models below
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
y_names = 'salary'
y_block = CategoryBlock()
# Seeded so the train/valid split (and therefore the baseline metrics) is reproducible
# on Restart & Run All; same seed as the autoencoder's split further down.
splits = RandomSplitter(seed=32)(range_of(df))

In [None]:
# Reuse the configuration variables from the previous cell rather than repeating the
# literals, so changing the config cell actually changes this object.
to = TabularPandas(df, procs=procs, cat_names=cat_names, cont_names=cont_names,
                   splits=splits, y_names=y_names, y_block=y_block)

In [None]:
# DataLoaders for the supervised baseline below, batch size 1024
dls = to.dataloaders(bs=1024)

In [None]:
# train_dl = TabDataLoader(to.train, bs = 1280)
# valid_dl = TabDataLoader(to.valid, bs = 1280)
# dls = DataLoaders(train_dl, valid_dl)

In [None]:
# Supervised baseline: a plain fastai tabular classifier on the same features
learn = tabular_learner(dls, metrics=[accuracy], layers=[200, 100])
# learn.fit(5, 1e-2)  # commented out: the results table below will not reappear on a fresh run

epoch,train_loss,valid_loss,accuracy,time
0,0.389453,0.420394,0.779484,00:01
1,0.368908,0.363753,0.836763,00:00
2,0.359075,0.35315,0.836149,00:00
3,0.352354,0.353868,0.838759,00:00
4,0.346758,0.358574,0.834306,00:00


# Custom Transforms and Dataloaders

In [None]:
# export
class ReadTabBatchIdentity(ItemTransform):
    "Read a batch of data and return the inputs as both `x` and `y`"
    def __init__(self, to):
        # Stored but not read by `encodes`, which works on the batch it is given
        self.to = to

    def encodes(self, to):
        def _inputs():
            # Build the model-input tuple; called once for `x` and once for `y`
            cats = tensor(to.cats).long()
            if not to.with_cont: return (cats,)
            return (cats, tensor(to.conts).float())

        res = _inputs() + _inputs()
        if to.device is not None: res = to_device(res, to.device)
        return res

class TabularPandasIdentity(TabularPandas):
    "`TabularPandas` subclass whose only role is to pick `TabDataLoaderIdentity` via `_dl_type` (set in a later cell)"
    pass

In [None]:
# export
@delegates()
class TabDataLoaderIdentity(TabDataLoader):
    "A transformed `DataLoader` for AutoEncoder problems with Tabular data"
    # Sampled indices pass through untouched; `create_batch` does the row lookup
    do_item = noops

    def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
        if after_batch is None:
            # Default batch pipeline with the identity reader appended, so every batch
            # carries its own inputs as targets
            after_batch = L(TransformBlock().batch_tfms) + ReadTabBatchIdentity(dataset)
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch,
                         num_workers=num_workers, **kwargs)

    def create_batch(self, b):
        "Return the rows of `dataset` at positions `b`"
        return self.dataset.iloc[b]

In [None]:
# export
# fastai's `dataloaders(...)` builds its loaders from `_dl_type`: point it at the identity loader
setattr(TabularPandasIdentity, '_dl_type', TabDataLoaderIdentity)

In [None]:
# Unsupervised view of the same data: no target column, the identity loader reuses the
# inputs as targets. Reuses `procs` from the config cell, and splits over `range_of(df)`
# like the baseline (RandomSplitter only uses the length, so the split is unchanged).
to = TabularPandasIdentity(df, procs, cat_names, cont_names,
                           splits=RandomSplitter(seed=32)(range_of(df)))
dls = to.dataloaders(bs=1024)

In [None]:
# Identity batches look like (cats, conts, cats, conts): the first two feed the model,
# the last two are the targets handed to the loss
setattr(dls, 'n_inp', 2)

In [None]:
import fastcore, fastai

In [None]:
# Number of classes per categorical column; the categorical decoder emits one logit per class
total_cats = {col: len(vocab) for col, vocab in to.classes.items()}
total_cats

{'workclass': 10,
 'education': 17,
 'marital-status': 8,
 'occupation': 16,
 'relationship': 7,
 'race': 6,
 'education-num_na': 3}

In [None]:
# Width of the categorical decoder head: total number of classes across all columns
sum(total_cats.values())

67

In [None]:
# Per-column means recorded by the `Normalize` proc; reused below to normalise the column bounds
to.means

{'age': 38.5793696495067,
 'fnlwgt': 190006.02011593536,
 'education-num': 10.079158508963875}

In [None]:
# One-row frames of the normalisation statistics, so their `.values` broadcast against
# the (1, n_cont) min/max arrays computed next
means = pd.DataFrame([to.means])
stds = pd.DataFrame([to.stds])

In [None]:
# Normalised min/max of every continuous column as (1, n_cont) arrays; these become the
# `SigmoidRange` limits of the continuous decoder head
low = (df[cont_names].min().values[None] - means.values) / stds.values
high = (df[cont_names].max().values[None] - means.values) / stds.values

In [None]:
# export
class RecreatedLoss(Module):
    "Measures how well we have created the original tabular inputs"
    def __init__(self, cat_dict):
        """
        `cat_dict` maps each categorical column to its number of classes. Its insertion
        order must match the column order of the categorical targets and the order of
        the per-column logit blocks in the categorical predictions.
        """
        self.cat_dict = cat_dict
        self.ce = CrossEntropyLossFlat(reduction='sum')
        self.mse = MSELossFlat(reduction='sum')

    def forward(self, preds, cat_targs, cont_targs):
        """
        Reconstruction loss for one batch.

        `preds` is `(cats, conts)`: `cats` holds the concatenated per-column logits and
        `conts` the reconstructed continuous values. `cat_targs[:, i]` holds the class
        indices of the i-th categorical column; `cont_targs` the original continuous values.

        Returns a 0-dim tensor:
        (summed CE / n categorical columns + summed squared error / n continuous columns) / batch size.
        """
        cats, conts = preds
        tot_ce, pos = cats.new_zeros(()), 0
        for i, n_classes in enumerate(self.cat_dict.values()):
            # Each categorical column owns a contiguous slice of `n_classes` logits
            tot_ce = tot_ce + self.ce(cats[:, pos:pos+n_classes], cat_targs[:, i])
            pos += n_classes

        # Average over columns; max(..., 1) avoids 0/0 when there are no categorical
        # or no continuous columns
        cat_loss = tot_ce / max(len(self.cat_dict), 1)
        cont_loss = self.mse(conts, cont_targs) / max(conts.size(1), 1)
        return (cat_loss + cont_loss) / cats.size(0)

In [None]:
# Reconstruction loss over the categorical columns listed in `total_cats` plus the continuous ones
loss_func = RecreatedLoss(cat_dict=total_cats)

## The model

In [None]:
# export
class BatchSwapNoise(Module):
    "Swap Noise Module"
    def __init__(self, p): self.p = p

    def forward(self, x):
        if self.training:
            mask = torch.rand(x.size()) > (1 - self.p)
            l1 = torch.floor(torch.rand(x.size()) * x.size(0)).type(torch.LongTensor)
            l2 = (mask.type(torch.LongTensor) * x.size(1))
            res = (l1 * l2).view(-1)
            idx = torch.arange(x.nelement()) + res
            idx[idx>=x.nelement()] = idx[idx>=x.nelement()]-x.nelement()
            return x.flatten()[idx].view(x.size())
        else:
            return x

In [None]:
# export
class TabularAE(TabularModel):
    "A simple AutoEncoder model"
    def __init__(self, emb_szs, n_cont, hidden_size, cats, low, high, ps=0.2, embed_p=0.01, bswap=None):
        """
        Encoder: the `TabularModel` trunk with layers 1024-512-256, whose final layer is
        replaced by a `LinBnDrop(256, hidden_size)` with Mish activation and dropout `ps`.
        Decoder: a hidden_size-256-512-1024 MLP followed by two heads: one emitting the
        concatenated per-column logits of the categorical columns, the other the
        continuous columns squashed into [`low`, `high`] by `SigmoidRange`.

        - `cats`: column name -> number of classes; its values sum to the categorical head width.
        - `low`/`high`: bounds for the continuous head.
          NOTE(review): `SigmoidRange` appears to keep these as plain tensor attributes,
          so they presumably should already be on the training device.
        - `bswap`: swap-noise probability applied to the inputs; `None` disables it.
        """
        super().__init__(emb_szs, n_cont, layers=[1024, 512, 256], out_sz=hidden_size, embed_p=embed_p)

        self.bswap = bswap
        self.cats = cats
        self.activation_cats = sum(cats.values())

        # Drop the trunk's last layer and append a 256 -> hidden_size block with Mish + dropout
        self.layers = nn.Sequential(*L(self.layers.children())[:-1] + nn.Sequential(LinBnDrop(256, hidden_size, p=ps, act=Mish())))

        if bswap is not None: self.noise = BatchSwapNoise(bswap)
        self.decoder = nn.Sequential(
            LinBnDrop(hidden_size, 256, p=ps, act=Mish()),
            LinBnDrop(256, 512, p=ps, act=Mish()),
            LinBnDrop(512, 1024, p=ps, act=Mish())
        )

        self.decoder_cont = nn.Sequential(
            LinBnDrop(1024, n_cont, p=ps, bn=False, act=None),
            SigmoidRange(low=low, high=high)
        )

        self.decoder_cat = LinBnDrop(1024, self.activation_cats, p=ps, bn=False, act=None)

    def forward(self, x_cat, x_cont=None, encode=False):
        """
        Return `(decoded_cats, decoded_conts)`, or the `hidden_size` encoding when `encode=True`.
        Swap noise (if enabled) is applied to each input that is present.
        """
        if self.bswap is not None:
            x_cat = self.noise(x_cat)
            # `x_cont` defaults to None; only perturb it when it was actually given
            if x_cont is not None: x_cont = self.noise(x_cont)
        encoded = super().forward(x_cat, x_cont)
        if encode: return encoded  # return the representation
        decoded_trunk = self.decoder(encoded)
        decoded_cats = self.decoder_cat(decoded_trunk)
        decoded_conts = self.decoder_cont(decoded_trunk)
        return decoded_cats, decoded_conts

In [None]:

loss_func = RecreatedLoss(total_cats)
emb_szs = get_emb_sz(to.train)

# `SigmoidRange` appears to keep `low`/`high` as plain tensor attributes (not buffers), so
# moving the model to `dls.device` would presumably leave them on the CPU.
# NOTE(review): confirm. Creating them on `dls.device` up front is harmless either way.
model = TabularAE(emb_szs, len(cont_names), 128, ps=0.1, cats=total_cats, embed_p=0.01,
                  bswap=.1, low=tensor(low).to(dls.device), high=tensor(high).to(dls.device))
# NOTE: the non-None `wd=0.1` passed to `fit_one_cycle` below takes precedence over this `wd=0.01`
learn = Learner(dls, model, loss_func=loss_func, wd=0.01, opt_func=ranger)

learn.fit_one_cycle(n_epoch = 5, lr_max = 1e-3, wd=0.1,  cbs=[EarlyStoppingCallback(min_delta=0.05, patience = 2)])

epoch,train_loss,valid_loss,time
0,10.602939,9.226448,00:07
1,5.434954,1.711648,00:06
2,3.295231,1.030882,00:07
3,2.273335,0.725697,00:07
4,1.731029,0.694927,00:07
