In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm as tqdm
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy, F1Score
from pytorch_widedeep.datasets import load_adult
import warnings
from torchmetrics import AveragePrecision, AUROC
warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<zmq.*>")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
lukup = {'defaultCredit':'default.payment.next.month', 'bank':'y'}
name = 'defaultCredit'
label = lukup[name]
fold = 1
train_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/train/data.csv'.format(name, fold))
valid_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/valid/data.csv'.format(name, fold))
test_df = pd.read_csv('/home/vineeth/Documents/GitWorkSpace/PytorchRecipes/SimpleMLP/Dataset/{}/fold{}/test/data.csv'.format(name, fold))
train_df.head()

Unnamed: 0,LIMIT_BAL,SEX,EDUCATION,MARRIAGE,AGE,PAY_0,PAY_2,PAY_3,PAY_4,PAY_5,...,BILL_AMT4,BILL_AMT5,BILL_AMT6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,default.payment.next.month
0,-0.134759,0,1,1,-1.029047,2,1,1,1,1,...,-0.428605,-0.369043,-0.239816,0.503316,-0.03998,-0.012818,0.194622,0.536758,-0.180878,0
1,1.483795,1,0,1,1.357652,2,1,1,1,1,...,0.579916,-0.525994,-0.512933,-0.070252,-0.126784,2.543032,0.228134,0.230763,0.436037,0
2,-0.674276,1,0,1,-0.378129,0,1,1,1,1,...,0.37445,0.561493,0.571863,-0.160815,-0.083165,-0.15481,0.330267,-0.314136,-0.012122,1
3,-0.75135,0,2,0,1.249166,2,1,1,1,1,...,0.425778,0.504203,0.545752,-0.14874,-0.109423,-0.132091,-0.132522,-0.107827,-0.141502,0
4,-0.365981,1,1,1,0.598248,2,1,1,1,1,...,1.126163,1.336657,1.378068,0.020312,-0.061681,-0.041216,0.196217,-0.117776,-0.067081,0


In [3]:
# Define the 'column set up'
wide_cols = [
    "SEX",
    "EDUCATION",
    "MARRIAGE",
    "PAY_0",
    "PAY_2",
    "PAY_3",
    "PAY_4",
    "PAY_5",
    "PAY_6"
]

cat_embed_cols = [
    "SEX",
    "EDUCATION",
    "MARRIAGE",
    "PAY_0",
    "PAY_2",
    "PAY_3",
    "PAY_4",
    "PAY_5",
    "PAY_6"
]
continuous_cols = ["LIMIT_BAL", "BILL_AMT1", "BILL_AMT1", "BILL_AMT2", \
     "BILL_AMT3", "BILL_AMT4", "BILL_AMT5", "BILL_AMT6", 'PAY_AMT1', 'PAY_AMT1',\
        'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']
target = "default.payment.next.month"
target = train_df[target].values

In [4]:
# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols)
X_wide = wide_preprocessor.fit_transform(train_df)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(train_df)

In [5]:
# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
    mlp_hidden_dims=[400, 200],
    mlp_dropout=0.5,
    mlp_activation="leaky_relu"
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

In [6]:
# train and validate
trainer = Trainer(model, objective="binary", metrics=[AUROC(task='binary'), F1Score, AveragePrecision(task='binary')])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=100,
    batch_size=256,
)

epoch 1: 100%|███████████████████████████████████████████████████| 85/85 [00:01<00:00, 54.94it/s, loss=0.528, metrics={'BinaryAUROC': 0.6647, 'f1': 0.3431, 'BinaryAveragePrecision': 0.3797}]
epoch 2: 100%|█████████████████████████████████████████████████████| 85/85 [00:01<00:00, 76.10it/s, loss=0.47, metrics={'BinaryAUROC': 0.7231, 'f1': 0.3803, 'BinaryAveragePrecision': 0.468}]
epoch 3: 100%|███████████████████████████████████████████████████| 85/85 [00:01<00:00, 79.25it/s, loss=0.458, metrics={'BinaryAUROC': 0.7406, 'f1': 0.4051, 'BinaryAveragePrecision': 0.4924}]
epoch 4: 100%|███████████████████████████████████████████████████| 85/85 [00:01<00:00, 75.90it/s, loss=0.455, metrics={'BinaryAUROC': 0.7428, 'f1': 0.4227, 'BinaryAveragePrecision': 0.5023}]
epoch 5: 100%|████████████████████████████████████████████████████| 85/85 [00:01<00:00, 75.46it/s, loss=0.45, metrics={'BinaryAUROC': 0.7481, 'f1': 0.4256, 'BinaryAveragePrecision': 0.5125}]
epoch 6: 100%|███████████████████████████████

In [7]:
# predict on test
X_wide_te = wide_preprocessor.transform(test_df)
X_tab_te = tab_preprocessor.transform(test_df)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 39.26it/s]


In [8]:
pred_probs = trainer.predict_proba(X_wide=X_wide_te, X_tab=X_tab_te)

predict: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 40.85it/s]


In [9]:
from sklearn.metrics import average_precision_score, roc_auc_score
target = lukup[name]
y = test_df[target].values
print("ROC-AUC:{}".format(roc_auc_score(y, pred_probs[:, 1])))
print("PrecisionRecall-AUC:{}".format(average_precision_score(y, pred_probs[:, 1])))

ROC-AUC:0.7599003822481956
PrecisionRecall-AUC:0.5211694191896986


In [24]:
from pytorch_widedeep import Tab2Vec
t2v = Tab2Vec(model=model, tab_preprocessor=tab_preprocessor)
X_vec, y = t2v.transform(train_df, target_col=target)

In [27]:
X_vec

array([[-0.4756467 , -0.02291009,  0.56716114, ...,  0.21514589,
         0.54453516, -0.16239928],
       [ 0.17567964, -0.4121741 , -0.30569577, ...,  0.24836496,
         0.2460111 ,  0.47163504],
       [ 0.17567964, -0.4121741 , -0.30569577, ...,  0.34960404,
        -0.28558525,  0.01103947],
       ...,
       [-0.4756467 , -0.02291009, -0.30569577, ..., -0.22049856,
        -0.22402865, -0.22165753],
       [ 0.17567964, -0.4121741 ,  0.56716114, ..., -0.15659139,
        -0.15787444, -0.16239928],
       [ 0.17567964, -0.4121741 , -0.        , ..., -0.2325207 ,
        -0.22172983, -0.2317748 ]], dtype=float32)