In [5]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm as tqdm
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabNet, WideDeep
from pytorch_widedeep.metrics import Accuracy, F1Score
from pytorch_widedeep.datasets import load_adult
import warnings
from torchmetrics import AveragePrecision, AUROC
warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<zmq.*>")

In [6]:
lukup = {'defaultCredit':'default.payment.next.month', 'bank':'y'}
name = 'defaultCredit'
label = lukup[name]
fold = 1
train_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/train/data.csv'.format(name, fold))
valid_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/valid/data.csv'.format(name, fold))
test_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/test/data.csv'.format(name, fold))
train_df.head()

Unnamed: 0,LIMIT_BAL,SEX,EDUCATION,MARRIAGE,AGE,PAY_0,PAY_2,PAY_3,PAY_4,PAY_5,...,BILL_AMT4,BILL_AMT5,BILL_AMT6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,default.payment.next.month
0,-0.134759,0,1,1,-1.029047,2,1,1,1,1,...,-0.428605,-0.369043,-0.239816,0.503316,-0.03998,-0.012818,0.194622,0.536758,-0.180878,0
1,1.483795,1,0,1,1.357652,2,1,1,1,1,...,0.579916,-0.525994,-0.512933,-0.070252,-0.126784,2.543032,0.228134,0.230763,0.436037,0
2,-0.674276,1,0,1,-0.378129,0,1,1,1,1,...,0.37445,0.561493,0.571863,-0.160815,-0.083165,-0.15481,0.330267,-0.314136,-0.012122,1
3,-0.75135,0,2,0,1.249166,2,1,1,1,1,...,0.425778,0.504203,0.545752,-0.14874,-0.109423,-0.132091,-0.132522,-0.107827,-0.141502,0
4,-0.365981,1,1,1,0.598248,2,1,1,1,1,...,1.126163,1.336657,1.378068,0.020312,-0.061681,-0.041216,0.196217,-0.117776,-0.067081,0


In [7]:
# Define the 'column set up'
wide_cols = [
    "SEX",
    "EDUCATION",
    "MARRIAGE",
    "PAY_0",
    "PAY_2",
    "PAY_3",
    "PAY_4",
    "PAY_5",
    "PAY_6"
]

cat_embed_cols = [
    "SEX",
    "EDUCATION",
    "MARRIAGE",
    "PAY_0",
    "PAY_2",
    "PAY_3",
    "PAY_4",
    "PAY_5",
    "PAY_6"
]
continuous_cols = ["LIMIT_BAL", "BILL_AMT1", "BILL_AMT1", "BILL_AMT2", \
     "BILL_AMT3", "BILL_AMT4", "BILL_AMT5", "BILL_AMT6", 'PAY_AMT1', 'PAY_AMT1',\
        'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']
target = "default.payment.next.month"
target = train_df[target].values

In [8]:
# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols)
X_wide = wide_preprocessor.fit_transform(train_df)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(train_df)

In [11]:
# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabNet(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
    dropout=0.5,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)



In [12]:
# train and validate
trainer = Trainer(model, objective="binary", accelerator="gpu",\
                  metrics=[AUROC(task='binary'), F1Score, AveragePrecision(task='binary')])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=100,
    batch_size=256,
)

epoch 1: 100%|██████████| 85/85 [00:02<00:00, 32.23it/s, loss=0.839, metrics={'BinaryAUROC': 0.48, 'f1': 0.1944, 'BinaryAveragePrecision': 0.2196}]  
epoch 2: 100%|██████████| 85/85 [00:02<00:00, 41.15it/s, loss=0.743, metrics={'BinaryAUROC': 0.5047, 'f1': 0.1838, 'BinaryAveragePrecision': 0.2328}]
epoch 3: 100%|██████████| 85/85 [00:02<00:00, 41.26it/s, loss=0.664, metrics={'BinaryAUROC': 0.5469, 'f1': 0.2109, 'BinaryAveragePrecision': 0.2681}]
epoch 4: 100%|██████████| 85/85 [00:02<00:00, 42.25it/s, loss=0.639, metrics={'BinaryAUROC': 0.562, 'f1': 0.2407, 'BinaryAveragePrecision': 0.2838}] 
epoch 5: 100%|██████████| 85/85 [00:02<00:00, 42.31it/s, loss=0.595, metrics={'BinaryAUROC': 0.5905, 'f1': 0.2642, 'BinaryAveragePrecision': 0.31}]  
epoch 6: 100%|██████████| 85/85 [00:02<00:00, 41.10it/s, loss=0.571, metrics={'BinaryAUROC': 0.6126, 'f1': 0.2706, 'BinaryAveragePrecision': 0.3287}]
epoch 7: 100%|██████████| 85/85 [00:02<00:00, 40.16it/s, loss=0.546, metrics={'BinaryAUROC': 0.6284,

In [13]:
# predict on test
X_wide_te = wide_preprocessor.transform(test_df)
X_tab_te = tab_preprocessor.transform(test_df)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
pred_probs = trainer.predict_proba(X_wide=X_wide_te, X_tab=X_tab_te)

predict: 100%|██████████| 24/24 [00:00<00:00, 43.22it/s]
predict: 100%|██████████| 24/24 [00:00<00:00, 44.99it/s]


In [14]:
from sklearn.metrics import average_precision_score, roc_auc_score
target = lukup[name]
y = test_df[target].values
print("ROC-AUC:{}".format(roc_auc_score(y, pred_probs[:, 1])))
print("PrecisionRecall-AUC:{}".format(average_precision_score(y, pred_probs[:, 1])))

ROC-AUC:0.7512364801214417
PrecisionRecall-AUC:0.5330399295236273


In [11]:
from pytorch_widedeep import Tab2Vec
t2v = Tab2Vec(model=model, tab_preprocessor=tab_preprocessor)
X_vec, y = t2v.transform(train_df, target_col=target)

In [12]:
X_vec

array([[-1.4172393e-01, -2.9816014e-01,  2.1386049e+00, ...,
         5.1322991e-01,  8.6423665e-01, -3.0246025e-03],
       [-1.8721935e+00,  8.3828169e-01, -4.5036829e-01, ...,
         5.5625874e-01,  4.9738172e-01,  8.4056890e-01],
       [-1.8721935e+00,  8.3828169e-01, -4.5036829e-01, ...,
         6.8739414e-01, -1.5589465e-01,  2.2773863e-01],
       ...,
       [-1.4172393e-01, -2.9816014e-01, -4.5036829e-01, ...,
        -5.1062021e-02, -8.0248013e-02, -8.1868723e-02],
       [-1.8721935e+00,  8.3828169e-01,  2.1386049e+00, ...,
         3.1717189e-02,  1.0486166e-03, -3.0246025e-03],
       [-1.8721935e+00,  8.3828169e-01, -4.5036829e-01, ...,
        -6.6634342e-02, -7.7423006e-02, -9.5329911e-02]], dtype=float32)