In [2]:
import datetime
import numpy as np
import pandas as pd
import joblib
import warnings
import logging
import os
import gc
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import collections
import re
import copy
import torch
import shap
import utils.utils as util

import utils_

from functools import reduce
from tqdm import tqdm
from dateutil.relativedelta import relativedelta
from joblib import Parallel, delayed
from scipy.stats import norm
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
from sklearn.inspection import permutation_importance
from model.mmoe_condition_2 import MMOE
from torch.utils.data import DataLoader
from utils.dataset import DatasetLoader, DatasetLoader_w, DatasetLoader_www
from utils.warmup_lr import GradualWarmupScheduler
from loss.BCE_weighted_multi_task import BCEWL_weighted_multi_task


# Optional display settings; uncomment when inspecting wide or long frames.
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_rows', 500)

# Silence every warning category globally (note: this also hides deprecation notices).
warnings.simplefilter('ignore')
# Register tqdm's progress_apply / progress_map on pandas objects.
tqdm.pandas(desc='pandas bar')

In [None]:
# Display the installed PyTorch version (useful when reproducing these results).
torch.__version__

In [None]:
# Load the evaluation feature matrix and label frame produced by the preprocessing step.
# NOTE(review): utils_.load_pickle presumably unpickles the file — only load trusted files.
EVAL_DATA_DIR = '../data/exp2suc/loss_pp_cb'
X_eval = utils_.load_pickle(f'{EVAL_DATA_DIR}/df_X_pi_train_eval_transform_20230716_20231001.pickle')
y_eval = utils_.load_pickle(f'{EVAL_DATA_DIR}/df_y_train_eval_20230716_20231001.pickle')

# Features and labels are later paired positionally (via .values) in the dataset,
# so a row-count mismatch would silently misalign samples.
assert len(X_eval) == len(y_eval), f'row mismatch: X_eval={len(X_eval)}, y_eval={len(y_eval)}'

print(X_eval.shape)
print(y_eval.shape)

In [None]:
# Per-sample loss weights passed to DatasetLoader_www alongside X/y.
# The columns are compared against 0/1/2 below, so they are presumably 0/1 indicator
# columns — TODO confirm. All three are computed vectorized instead of via per-element
# or per-row .apply(), which is much slower on large frames.

# w_pre_link: 5 when neither item_id_4 nor item_id_5 is set (row max == 0), else 1.
w_pre_link_eval = pd.Series(
    np.where(X_eval[['item_id_4', 'item_id_5']].max(axis=1) == 0, 5, 1),
    index=X_eval.index,
)
print(w_pre_link_eval.shape)

# w_gf_pp: 1 only when item_id_0 + table_type_1 == 2, else 0.
w_gf_pp_eval = (X_eval[['item_id_0', 'table_type_1']].sum(axis=1) == 2).astype('int64')
print(w_gf_pp_eval.shape)

# w_after_link: 0 when item_id_0 == 1 and table_type_1 == 1; otherwise 1 when item_id_4 == 1
# or item_id_5 == 1; otherwise 5. np.select takes the first matching condition, which
# reproduces the original if / elif / else precedence.
w_after_link_eval = pd.Series(
    np.select(
        [
            (X_eval['item_id_0'] == 1) & (X_eval['table_type_1'] == 1),
            (X_eval['item_id_4'] == 1) | (X_eval['item_id_5'] == 1),
        ],
        [0, 1],
        default=5,
    ),
    index=X_eval.index,
)
print(w_after_link_eval.shape)

EVAL_BATCH_SIZE = 1024
EVAL_NUM_WORKERS = 8

# shuffle=False keeps batches in the original row order.
test_loader = DataLoader(
    DatasetLoader_www(X_eval.values, y_eval.values,
                      w_pre_link_eval.values, w_gf_pp_eval.values, w_after_link_eval.values),
    EVAL_BATCH_SIZE, shuffle=False, num_workers=EVAL_NUM_WORKERS)

In [None]:
# Prefer the first GPU, falling back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Presumably seeds Python/NumPy/torch RNGs (project helper) — done before model construction.
util.seed_everything(2023)

# MMOE hyper-parameters: 7 experts shared across 5 task towers.
mmoe_hparams = dict(
    num_experts=7,
    expert_hidden_units=[512, 256, 128],
    units=64,
    num_tasks=5,
    tower_hidden_units=[64, 32, 16],
    dropout=0.5,
    use_bn=False,
)
config = {'Model': mmoe_hparams}

feats_columns = utils_.load_pickle('../data/exp2suc/loss_pp_cb/feats_columns.pickle')

model = MMOE(config, feats_columns).to(device)
model

In [None]:
# Re-initialize parameters in model.modules() order (keeps RNG consumption deterministic):
# Xavier-uniform weights for Conv2d/Linear (their biases keep PyTorch defaults),
# and weight=1 / bias=0 for any BatchNorm1d layers.
XAVIER_INIT_TYPES = (torch.nn.Conv2d, torch.nn.Linear)

for layer in model.modules():
    if isinstance(layer, XAVIER_INIT_TYPES):
        torch.nn.init.xavier_uniform_(layer.weight)
        # alternative: torch.nn.init.kaiming_uniform_(layer.weight)
        continue
    if isinstance(layer, torch.nn.BatchNorm1d):
        torch.nn.init.constant_(layer.weight, 1)
        torch.nn.init.constant_(layer.bias, 0)

In [None]:
# One weighted BCE criterion per task — presumably indexed to match the task heads.
# The count is derived from the model config so it cannot drift from num_tasks.
criterion_train = tuple(
    BCEWL_weighted_multi_task().to(device)
    for _ in range(config['Model']['num_tasks'])
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Multiply the LR by 0.1 after the 10th, 20th, 30th and 40th scheduler.step() call.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 40], gamma=0.1)

In [None]:
# Take only the first batch and run a single forward pass to inspect shapes and loss terms.
# NOTE(review): train mode keeps dropout (p=0.5) active, presumably to mirror a training
# step; switch to model.eval() if deterministic outputs are wanted.
# No backward pass happens here, so no optimizer.zero_grad() is needed.
model.train()

first_batch = next(iter(test_loader))
# Batch order follows the DatasetLoader_www arguments (presumably x, y, w_pre_link, w_gf_pp,
# w_after_link); move every tensor to `device` as float32 in one call each.
x, y, w1, w2, w3 = (t.to(device=device, dtype=torch.float32) for t in first_batch)
output = model(x)

In [None]:
# Feature batch shape; presumably (batch_size, n_features) — TODO confirm.
x.size()

In [None]:
# Label batch shape; later indexed as y[:, 4], so presumably (batch_size, num_tasks).
y.size()

In [None]:
# Model output shape; later indexed as output[:, 4], so presumably one logit column per task.
output.size()

In [None]:
import torch.nn.functional as F

In [None]:
# Unreduced (per-sample) BCE-with-logits for task column 4.
# NOTE(review): presumably the task that w3 is meant to weight — confirm the pairing.
task_idx = 4
loss_BCEWL = F.binary_cross_entropy_with_logits(
    output[:, task_idx],
    y[:, task_idx],
    reduction='none',
)
loss_BCEWL

In [None]:
# Summarize the batch's w3 weights as (distinct values, counts) rather than dumping every element;
# the next cell already shows the first few raw values.
torch.unique(w3, return_counts=True)

In [None]:
# Peek at the first ten per-sample w3 weights of the batch.
w3[0:10]

In [None]:
# Index tensor of non-zero w3 entries; nonzero() returns shape (n_nonzero, w3.ndim).
w3_nozero_index = w3.nonzero()
print(w3_nozero_index.size(0))
w3_nozero_index

In [None]:
# Gather the non-zero weights. Indexing with the (n_nonzero, ndim) index tensor keeps
# a trailing dimension when w3 is 1-D, which the reshape cell below removes.
w3_nozero = w3[w3_nozero_index]
print(len(w3_nozero))
w3_nozero

In [None]:
# Shape of the full w3 weight batch, for comparison with the gathered tensor.
w3.size()

In [None]:
# Collapse the gathered weights to 1-D (drops the trailing dim left by index-tensor gathering).
w3_nozero = torch.flatten(w3_nozero)
print(w3_nozero.size())
w3_nozero

In [None]:
# Per-sample loss shape; must line up with the weight tensor before multiplying.
loss_BCEWL.size()

In [None]:
# Choose which per-sample weight tensor the loss cells below use.
# `w` aliases w3 (same tensor object, no copy); swap in the commented line to try w1 instead.
# w = w1
w = w3
w

In [None]:
# Index tensor of samples whose selected weight is non-zero; shape (n_nonzero, w.ndim).
w_nozero_index = w.nonzero()
print(w_nozero_index.size(0))
w_nozero_index

In [None]:
# Element-wise weighted loss over the whole batch; zero-weight samples contribute 0
# (for finite per-sample losses).
loss_BCEWL_weighted = torch.mul(loss_BCEWL, w)
print(loss_BCEWL_weighted.size())
loss_BCEWL_weighted

In [None]:
# Batch mean of the weighted loss: zero-weight samples still count in the denominator.
loss = loss_BCEWL_weighted.mean()
loss

In [None]:
# Weighted loss restricted to samples whose weight is non-zero.
# The boolean mask is recomputed from the current `w` instead of reusing w_nozero_index from
# an earlier cell, which would be stale if `w` had been reassigned since. Boolean masking also
# returns a 1-D tensor directly, so no reshape(-1) is needed. Element order matches nonzero().
w_nonzero_mask = w != 0
loss_BCEWL_weighted = loss_BCEWL[w_nonzero_mask] * w[w_nonzero_mask]
print(loss_BCEWL_weighted.shape)
loss_BCEWL_weighted

In [None]:
# Mean of the weighted loss over the selected (non-zero-weight) samples only.
# If every sample in the batch had weight 0 (e.g. w3 is 0 whenever item_id_0 == 1 and
# table_type_1 == 1), the selection is empty and torch.mean would return NaN, poisoning any
# gradient step. In that case fall back to sum(), which is 0 on an empty tensor, keeps the
# same dtype/device, and stays attached to the autograd graph.
if loss_BCEWL_weighted.numel() > 0:
    loss = loss_BCEWL_weighted.mean()
else:
    loss = loss_BCEWL_weighted.sum()
loss