# Library

In [1]:
#===========================================================
# Base code from https://www.kaggle.com/phoenix9032/pytorch-bert-plain
#===========================================================
import os
import sys
import gc
import time
import glob
import multiprocessing
import re
from urllib.parse import urlparse
from tqdm import tqdm
from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
from contextlib import contextmanager
from functools import partial

import numpy as np
import pandas as pd
import scipy as sp
from scipy.stats import spearmanr
import math
from math import floor, ceil
import random

#from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit, MultilabelStratifiedKFold
from sklearn.model_selection import GroupKFold
import category_encoders as ce
import re
from urllib.parse import urlparse

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import transformers
from transformers import (
    BertTokenizer, BertModel, BertForSequenceClassification, BertConfig,
    WEIGHTS_NAME, CONFIG_NAME, AdamW, get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)
from transformers.modeling_bert import BertPreTrainedModel 


#===========================================================
# Utils
#===========================================================
def get_logger(filename='log'):
    """Return this module's logger, writing to the console and to ``<filename>.log``.

    The function is safe to call repeatedly (for example when a notebook cell
    is re-run): a console handler is attached only once, and a file handler is
    attached only once per distinct log file. Previously every call stacked a
    new pair of handlers on the same logger, so each message was emitted
    multiple times.

    Args:
        filename: Log file path without the ``.log`` suffix.

    Returns:
        The configured ``logging.Logger`` named after this module.
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # FileHandler stores the absolute path in ``baseFilename``; compare like with like.
    log_path = os.path.abspath(f"{filename}.log")

    # ``type(...) is`` on purpose: FileHandler subclasses StreamHandler.
    has_console = any(type(h) is StreamHandler for h in logger.handlers)
    has_file = any(
        isinstance(h, FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )

    if not has_console:
        handler1 = StreamHandler()
        handler1.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler1)
    if not has_file:
        handler2 = FileHandler(filename=f"{filename}.log")
        handler2.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler2)
    return logger

# Module-wide logger used by timer(); the default argument makes this write to ./log.log.
logger = get_logger()


@contextmanager
def timer(name):
    """Context manager that logs a start line and, on normal exit, the elapsed seconds.

    Args:
        name: Label placed in square brackets in both log lines.
    """
    started_at = time.time()
    logger.info(f'[{name}] start')
    yield
    elapsed = time.time() - started_at
    logger.info(f'[{name}] done in {elapsed:.0f} s')


def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    and puts cuDNN into deterministic mode. ``cudnn.benchmark`` is switched off
    as well, because the autotuner may pick different kernels from run to run
    even when ``cudnn.deterministic`` is set.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only inherited by processes started afterwards; the hash seed of the
    # already-running interpreter cannot be changed from inside it.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [2]:
import lightgbm as lgb
import pickle

# Models

In [3]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain attribute container for the run's hyper-parameters and stage switches.

    Every constructor argument is stored unchanged on the instance under the
    same name; no validation or conversion is performed.
    """

    def __init__(self, lr, warmup, accum_steps, epochs, seed, expname,
                 head_tail, head, freeze, question_weight, answer_weight, fold, train, cv, test):
        # Optimisation schedule.
        self.lr, self.warmup = lr, warmup
        self.accum_steps, self.epochs = accum_steps, epochs
        # Run identity.
        self.seed, self.expname = seed, expname
        # Truncation strategy (read by the *_trim_input helpers) and model options.
        self.head_tail, self.head = head_tail, head
        self.freeze = freeze
        # Loss weights.
        self.question_weight, self.answer_weight = question_weight, answer_weight
        # Fold count and stage switches.
        self.fold = fold
        self.train, self.cv, self.test = train, cv, test

# Global run configuration. `head` and `head_tail` are read by the *_trim_input helpers
# below; the train/cv/test switches are presumably consumed by later cells (not visible here).
config = PipeLineConfig(lr=1e-4, warmup=0.1, accum_steps=1, epochs=6,
                        seed=42, expname='uncased_6', head_tail=True, head=0.3, freeze=False,
                        question_weight=0., answer_weight=0., fold=5, train=False, cv=False, test=True)

DEBUG = False
# Name of the row-identifier column.
ID = 'qa_id'
# The 30 target columns; val_model sizes its prediction matrix with len(target_cols).
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice',
               'question_type_compare', 'question_type_consequence',
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful',
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']
NUM_FOLDS = config.fold
# Input data directory.
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'
SEED = config.seed
# Seed all RNGs once, at definition time of this cell.
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory holding model weights (presumably the fine-tuned checkpoints -- loading code is not visible here).
MODEL_DIR = '../input/googlequestchallenge-weights1/'
#MODEL_DIR = './'
# False selects the "split" code paths below: title+question and answer are encoded as two separate sequences.
COMBINE_INPUT = False
# Per-field token budgets (excluding special tokens).
T_MAX_LEN = 30
Q_MAX_LEN = 479 # 382
A_MAX_LEN = 510 # 254 
# Combined layout: [CLS] title <sep> question <sep> answer [SEP] -> 4 special tokens (1023 in total).
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4
# Split layout, question side: [CLS] title [SEP] question [SEP] -> 3 special tokens (512 in total).
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3
# Split layout, answer side: [CLS] answer [SEP] -> 2 special tokens (512 in total).
a_max_sequence_length = A_MAX_LEN + 2

#===========================================================
# Model
#===========================================================
class OptimizedRounder5(object):
    """Searches four cut points that bin continuous scores into five ordinal levels.

    The objective is the Spearman correlation between the targets and the
    binned scores; each cut point is tuned in turn with a few golden-section
    steps inside a fixed bracket, and the sweep over all cut points is
    repeated 100 times.
    """

    def __init__(self):
        # Initial cut points; fit() overwrites them in place.
        self.coef = [0.3333, 0.5, 0.6667, 1.]

    def _loss(self, X, y):
        """Negative Spearman correlation of ``y`` against ``X`` binned by the current cut points."""
        binned = np.digitize(X, self.coef)
        return -spearmanr(y, binned).correlation

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Tune ``self.coef`` in place so that binning ``X`` correlates best with ``y``."""
        shrink_lo = 0.618
        shrink_hi = 1 - shrink_lo
        brackets = [(0., 0.3333), (0.3333, 0.5), (0.5, 0.6667), (0.6667, 1.)]
        for _ in range(100):
            for idx, (lo, hi) in enumerate(brackets):
                # Evaluate both ends of the bracket for this cut point.
                self.coef[idx] = lo
                loss_lo = self._loss(X, y)
                self.coef[idx] = hi
                loss_hi = self._loss(X, y)
                for _step in range(4):
                    # Move the worse end towards the better one.
                    if loss_lo > loss_hi:
                        lo = hi - (hi - lo) * shrink_lo
                        self.coef[idx] = lo
                        loss_lo = self._loss(X, y)
                    else:
                        hi = hi - (hi - lo) * shrink_hi
                        self.coef[idx] = hi
                        loss_hi = self._loss(X, y)

    def predict(self, X, coef):
        """Bin ``X`` with the supplied cut points ``coef`` (not necessarily the fitted ones)."""
        return np.digitize(X, coef)

    def coefficients(self):
        """Return the current cut points (the live list, not a copy)."""
        return self.coef


class OptimizedRounder9(object):
    """Searches eight cut points that bin continuous scores into nine ordinal levels.

    Same procedure as the five-level rounder: each cut point is tuned in turn
    with golden-section steps (eight per cut point here) inside a fixed
    bracket, maximising the Spearman correlation, and the sweep is repeated
    100 times.
    """

    def __init__(self):
        # Initial cut points; fit() overwrites them in place.
        self.coef = [0.3333, 0.4444, 0.5, 0.5555, 0.6667, 0.7777, 0.8333, 0.8889]

    def _loss(self, X, y):
        """Negative Spearman correlation of ``y`` against ``X`` binned by the current cut points."""
        binned = np.digitize(X, self.coef)
        return -spearmanr(y, binned).correlation

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Tune ``self.coef`` in place so that binning ``X`` correlates best with ``y``."""
        shrink_lo = 0.618
        shrink_hi = 1 - shrink_lo
        brackets = [(0, 0.3333), (0.3333, 0.4444), (0.4444, 0.5), (0.5, 0.5555), (0.5555, 0.6667),
                    (0.6667, 0.7777), (0.7777, 0.8333), (0.8333, 1)]
        for _ in range(100):
            for idx, (lo, hi) in enumerate(brackets):
                # Evaluate both ends of the bracket for this cut point.
                self.coef[idx] = lo
                loss_lo = self._loss(X, y)
                self.coef[idx] = hi
                loss_hi = self._loss(X, y)
                for _step in range(8):
                    # Move the worse end towards the better one.
                    if loss_lo > loss_hi:
                        lo = hi - (hi - lo) * shrink_lo
                        self.coef[idx] = lo
                        loss_lo = self._loss(X, y)
                    else:
                        hi = hi - (hi - lo) * shrink_hi
                        self.coef[idx] = hi
                        loss_hi = self._loss(X, y)

    def predict(self, X, coef):
        """Bin ``X`` with the supplied cut points ``coef`` (not necessarily the fitted ones)."""
        return np.digitize(X, coef)

    def coefficients(self):
        """Return the current cut points (the live list, not a copy)."""
        return self.coef

    
class OptimizedRounder3(object):
    """Searches two cut points that bin continuous scores into three ordinal levels.

    Same procedure as the larger rounders: each cut point is tuned in turn with
    golden-section steps (two per cut point here) inside a fixed bracket,
    maximising the Spearman correlation, and the sweep is repeated 100 times.
    """

    def __init__(self):
        # Initial cut points; fit() overwrites them in place.
        self.coef = [0.01, 0.02]

    def _loss(self, X, y):
        """Negative Spearman correlation of ``y`` against ``X`` binned by the current cut points."""
        binned = np.digitize(X, self.coef)
        return -spearmanr(y, binned).correlation

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Tune ``self.coef`` in place so that binning ``X`` correlates best with ``y``."""
        shrink_lo = 0.618
        shrink_hi = 1 - shrink_lo
        brackets = [(0., 0.1), (0.1, 0.3)]
        for _ in range(100):
            for idx, (lo, hi) in enumerate(brackets):
                # Evaluate both ends of the bracket for this cut point.
                self.coef[idx] = lo
                loss_lo = self._loss(X, y)
                self.coef[idx] = hi
                loss_hi = self._loss(X, y)
                for _step in range(2):
                    # Move the worse end towards the better one.
                    if loss_lo > loss_hi:
                        lo = hi - (hi - lo) * shrink_lo
                        self.coef[idx] = lo
                        loss_lo = self._loss(X, y)
                    else:
                        hi = hi - (hi - lo) * shrink_hi
                        self.coef[idx] = hi
                        loss_hi = self._loss(X, y)

    def predict(self, X, coef):
        """Bin ``X`` with the supplied cut points ``coef`` (not necessarily the fitted ones)."""
        return np.digitize(X, coef)

    def coefficients(self):
        """Return the current cut points (the live list, not a copy)."""
        return self.coef


def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    
    # 350+128+30 = 508 +4 = 512
    
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))
        # Head+Tail method 
        q_len_head = round(q_new_len * config.head)
        q_len_tail = -1 * (q_new_len - q_len_head)
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        t_len_head = round(t_new_len * config.head)
        t_len_tail = -1 * (t_new_len - t_len_head)  
        #t = t[:t_new_len]
        if config.head_tail :
            q = q[:q_len_head]+q[q_len_tail:]
            a = a[:a_len_head]+a[a_len_tail:]
            #t = t[:t_len_head]+t[t_len_tail:]
            t = t[:t_new_len]
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]
    
    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenize title and question and trim them to fit the question-side sequence.

    Three positions are reserved for special tokens. If both fields already
    fit, they are returned untouched. Otherwise the budget one field leaves
    unused is given to the other, and both are cut either head+tail
    (``config.head_tail``; the first ``config.head`` fraction plus the end) or
    head-only.

    Args:
        tokenizer: Object exposing ``tokenize(text)``.
        title, question: Raw text fields.
        q_max_sequence_length: Total budget including the 3 special tokens.
        t_max_len, q_max_len: Per-field token budgets.

    Returns:
        Tuple ``(title_tokens, question_tokens)``.
    """

    def _head_and_tail(tokens, keep):
        # Keep the first `head` and the last `tail` tokens. A zero-length tail
        # must be handled explicitly: tokens[-0:] is the WHOLE list, which
        # would make the "trimmed" sequence longer than the input.
        head = round(keep * config.head)
        tail = keep - head
        return tokens[:head] + (tokens[-tail:] if tail > 0 else [])

    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len = len(t)
    q_len = len(q)

    if (t_len+q_len+3) > q_max_sequence_length:

        # A short title donates its unused budget to the question.
        if t_max_len > t_len:
            t_new_len = t_len
            q_max_len = q_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        # A short question donates its unused budget to the title.
        if q_max_len > q_len:
            q_new_len = q_len
            t_new_len = t_max_len + (q_max_len - q_len)
        else:
            q_new_len = q_max_len

        if config.head_tail:
            # Head+Tail method, applied to both fields on this path.
            q = _head_and_tail(q, q_new_len)
            t = _head_and_tail(t, t_new_len)
        else:
            # No Head+Tail, usual processing
            q = q[:q_new_len]
            t = t[:t_new_len]

    return t, q


def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):
    """Tokenize the answer and trim it to fit the answer-side sequence.

    Two positions are reserved for special tokens. An answer that already fits
    is returned untouched; otherwise it is cut to ``a_max_len`` tokens, either
    head+tail (``config.head_tail``; the first ``config.head`` fraction plus
    the end) or head-only.

    Args:
        tokenizer: Object exposing ``tokenize(text)``.
        answer: Raw answer text.
        a_max_sequence_length: Total budget including the 2 special tokens.
        a_max_len: Token budget for the answer itself.

    Returns:
        List of answer tokens.
    """
    a = tokenizer.tokenize(answer)

    a_len = len(a)

    if (a_len+2) > a_max_sequence_length:

        a_new_len = a_max_len

        if config.head_tail:
            # Head+Tail method
            a_len_head = round(a_new_len * config.head)
            a_tail = a_new_len - a_len_head
            # a[-0:] would be the WHOLE list, so an empty tail is handled explicitly.
            a = a[:a_len_head] + (a[-a_tail:] if a_tail > 0 else [])
        else:
            # No Head+Tail, usual processing
            a = a[:a_new_len]

    return a


def _convert_to_bert_inputs(title, question, answer, tokenizer, max_sequence_length):
    """Converts tokenized input to ids, masks and segments for BERT"""
    if COMBINE_INPUT:
        stoken = ["[CLS]"] + title + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title  + question  + answer + ["[SEP]"]
    
        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]
    else:
        q_token = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP"]
        q_input_ids = _get_ids(q_token, tokenizer, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_masks = _get_masks(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_segments = _get_segments(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        
        a_token = ["[CLS]"] + answer + ["[SEP]"]
        a_input_ids = _get_ids(a_token, tokenizer, A_MAX_LEN+2)
        a_input_masks = _get_masks(a_token, A_MAX_LEN+2)
        a_input_segments = _get_segments(a_token, A_MAX_LEN+2)

        return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features, 
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Tokenize every row of ``df[columns]`` and pack the model inputs as tensors.

    Each row must expose ``question_title``, ``question_body`` and ``answer``.

    Returns:
        Combined path (``COMBINE_INPUT`` true): ``[ids, masks, segments,
        num_features, cat_features]``.
        Split path: ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments,
        num_features, cat_features]``.
        All tensors are ``long`` except ``num_features``, which is ``float``.
    """

    def _as_long(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    def _as_float(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).float()

    if COMBINE_INPUT:
        buckets = ([], [], [])
        for _, row in df[columns].iterrows():
            title, body, reply = row.question_title, row.question_body, row.answer
            title, body, reply = _trim_input(tokenizer, title, body, reply, max_sequence_length,
                                             t_max_len, q_max_len, a_max_len)
            ids, masks, segments = _convert_to_bert_inputs(title, body, reply, tokenizer, max_sequence_length)
            for bucket, values in zip(buckets, (ids, masks, segments)):
                bucket.append(values)
    else:
        buckets = ([], [], [], [], [], [])
        for _, row in df[columns].iterrows():
            title, body, reply = row.question_title, row.question_body, row.answer
            title, body = q_trim_input(tokenizer, title, body, q_max_sequence_length, t_max_len, q_max_len)
            reply = a_trim_input(tokenizer, reply, a_max_sequence_length, a_max_len)
            q_ids, q_masks, q_segments, a_ids, a_masks, a_segments = _convert_to_bert_inputs(
                title, body, reply, tokenizer, max_sequence_length)
            encoded = (q_ids, q_masks, q_segments, a_ids, a_masks, a_segments)
            for bucket, values in zip(buckets, encoded):
                bucket.append(values)

    # Token tensors first (in bucket order), then the numeric and categorical side features.
    tensors = [_as_long(bucket) for bucket in buckets]
    tensors.append(_as_float(num_features))
    tensors.append(_as_long(cat_features))
    return tensors


def compute_output_arrays(df, columns):
    """Return the selected ``columns`` of ``df`` as a NumPy array."""
    selected = df[columns]
    return np.asarray(selected)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the combined-input tensors.

        ``inputs`` holds five parallel sequences: ids, masks, segments, numeric
        features and categorical features. Items are returned as
        ``(ids, masks, segments, num, cat, [labels,] length)``; the label is
        present only when ``labels`` was given.
        """

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            features = tuple(self.inputs[slot][idx] for slot in range(5))
            length = self.lengths[idx]
            if self.labels is None:
                return features + (length,)
            return features + (self.labels[idx], length)

        def __len__(self):
            # All input sequences are parallel, so the first one gives the size.
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """BERT encoder whose pooled output is joined with side features before one linear head.

        The head consumes ``[pooled_output, num_features, categorical embeddings]``.
        """

        def __init__(self, config, cat_dims):
            """
            Args:
                config: BERT configuration; ``hidden_size`` and ``num_labels`` are read from it.
                cat_dims: Iterable of ``(cardinality, embedding_dim)`` pairs, one per
                    categorical feature column.
            """
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            # The "+4" is the width of num_features; it must match the caller's tensor.
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss?), logits, (extra encoder outputs...)``.

            ``loss`` is prepended only when ``labels`` is given: MSE when
            ``num_labels == 1``, cross-entropy otherwise.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            # One embedding lookup per categorical column, concatenated feature-wise.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are referenced through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are never imported in this file and
                # raised NameError as soon as labels were supplied.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the split (question / answer) input tensors.

        ``inputs`` holds eight parallel sequences: question ids, masks and
        segments, answer ids, masks and segments, numeric features and
        categorical features. Items are returned in that order, followed by
        the label (only when ``labels`` was given) and the length entry.
        """

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            features = tuple(self.inputs[slot][idx] for slot in range(8))
            length = self.lengths[idx]
            if self.labels is None:
                return features + (length,)
            return features + (self.labels[idx], length)

        def __len__(self):
            # All input sequences are parallel, so the first one gives the size.
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """One shared BERT encoder run on the question and the answer sequences separately.

        The head consumes ``[q_pooled, a_pooled, num_features, categorical embeddings]``.
        """

        def __init__(self, config, cat_dims):
            """
            Args:
                config: BERT configuration; ``hidden_size`` and ``num_labels`` are read from it.
                cat_dims: Iterable of ``(cardinality, embedding_dim)`` pairs, one per
                    categorical feature column.
            """
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.q_dropout = nn.Dropout(0.2)
            self.a_dropout = nn.Dropout(0.2)
            # Two pooled vectors (question + answer); the "+4" is the width of num_features.
            self.classifier_final = nn.Linear(config.hidden_size*2+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss?), logits, (extra question outputs...), (extra answer outputs...)``.

            ``loss`` is prepended only when ``labels`` is given: MSE when
            ``num_labels == 1``, cross-entropy otherwise. ``position_ids``,
            ``head_mask`` and ``inputs_embeds`` are forwarded to both encoder passes.
            """

            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            q_pooled_output = q_outputs[1]
            q_pooled_output = self.q_dropout(q_pooled_output)

            # Same encoder weights, second pass over the answer sequence.
            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            a_pooled_output = a_outputs[1]
            a_pooled_output = self.a_dropout(a_pooled_output)

            # One embedding lookup per categorical column, concatenated feature-wise.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are referenced through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are never imported in this file and
                # raised NameError as soon as labels were supplied.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one training epoch with optional gradient accumulation.

    Each batch's loss is divided by ``config.accum_steps`` before
    ``backward()`` so the accumulated gradient is the mean over the
    accumulation window, and the reported epoch loss is the plain mean of the
    per-batch losses. With ``accum_steps == 1`` this is identical to the
    previous behaviour; with larger values the old code summed gradients
    (``accum_steps`` times too large) and under-reported the loss by the same
    factor.

    Args:
        model: Called with keyword tensors; ``output[0]`` must be the logits.
        train_loader: Sized iterable of batches whose last element is ignored.
        optimizer, scheduler: Stepped together every ``config.accum_steps`` batches.
        criterion: ``criterion(logits, labels)`` returning a scalar tensor.
        config: Object exposing ``accum_steps``.

    Returns:
        ``(avg_loss, 0., 0., 0., 0., 0.)`` -- the five trailing values are
        placeholders for per-target-group losses that are currently not computed.
    """
    model.train()
    avg_loss = 0.
    # Per-group losses are disabled; kept only so the return shape stays stable.
    avg_loss_1 = 0.
    avg_loss_2 = 0.
    avg_loss_3 = 0.
    avg_loss_4 = 0.
    avg_loss_5 = 0.
    n_batches = len(train_loader)
    optimizer.zero_grad()
    for idx, batch in enumerate(train_loader):
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids, input_masks, input_segments, num_features, cat_features, labels = (
                tensor.to(device)
                for tensor in (input_ids, input_masks, input_segments, num_features, cat_features, labels)
            )

            output_train = model(input_ids = input_ids.long(),
                             labels = None,
                             attention_mask = input_masks,
                             token_type_ids = input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
        else:
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = (
                tensor.to(device)
                for tensor in (q_input_ids, q_input_masks, q_input_segments,
                               a_input_ids, a_input_masks, a_input_segments,
                               num_features, cat_features, labels)
            )

            output_train = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
        logits = output_train[0] #output preds

        loss = criterion(logits, labels)
        # Scale only what is back-propagated; `loss` itself stays unscaled for reporting.
        (loss / config.accum_steps).backward()
        if (idx + 1) % config.accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss += loss.item() / n_batches

        # Drop batch tensors eagerly to keep peak device memory down.
        if COMBINE_INPUT:
            del input_ids, input_masks, input_segments, labels
        else:
            del q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, labels

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader``; return the mean loss and mean Spearman rho.

    Predictions and labels are written into pre-allocated
    ``(val_shape, len(target_cols))`` arrays using a running row offset taken
    from each batch's real size. Placement therefore no longer depends on
    ``batch_size`` matching the loader (a mismatch used to misplace rows or
    raise a broadcast error), and a smaller last batch is handled naturally.

    Args:
        model: Called with keyword tensors; ``output[0]`` must be the logits.
        criterion: ``criterion(logits, labels)`` returning a scalar tensor.
        val_loader: Sized iterable of batches whose last element is ignored.
        val_shape: Total number of validation rows.
        batch_size: Unused; kept for backward compatibility with existing callers.

    Returns:
        ``(avg_val_loss, score)`` where ``score`` is the mean per-target
        Spearman correlation with NaN correlations counted as 0.
    """
    avg_val_loss = 0.
    model.eval() # eval mode

    valid_preds = np.zeros((val_shape, len(target_cols)))
    original = np.zeros((val_shape, len(target_cols)))

    offset = 0
    with torch.no_grad():

        for idx, batch in enumerate(val_loader):
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = (
                    tensor.to(device)
                    for tensor in (input_ids, input_masks, input_segments, num_features, cat_features, labels)
                )

                output_val = model(input_ids = input_ids.long(),
                               labels = None,
                               attention_mask = input_masks,
                               token_type_ids = input_segments,
                               num_features = num_features,
                               cat_features = cat_features,
                              )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = (
                    tensor.to(device)
                    for tensor in (q_input_ids, q_input_masks, q_input_segments,
                                   a_input_ids, a_input_masks, a_input_segments,
                                   num_features, cat_features, labels)
                )

                output_val = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
            logits = output_val[0] #output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # Reshape (rather than squeeze) so a batch of one row, or a single
            # target column, keeps its 2-D (rows, targets) layout.
            batch_preds = logits.detach().cpu().numpy()
            n_rows = batch_preds.shape[0]
            valid_preds[offset : offset + n_rows] = batch_preds.reshape(n_rows, -1)
            original[offset : offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

        preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

        # One Spearman correlation per target column, reused for both summaries.
        per_target = [spearmanr(original[:, i], preds[:, i]).correlation for i in range(preds.shape[1])]

        # Printed value propagates NaN; the returned score treats NaN as 0.
        rho_val = np.mean(per_target)
        print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

        score = 0
        for rho in per_target:
            score += np.nan_to_num(rho)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labelled loader and return sigmoid scores with their targets.

    Args:
        model: callable network; ``outputs[0]`` is used as the logits. Assumed to
            be batch-first with one column per target -- TODO confirm against
            the model class.
        val_loader: iterable of batches. With ``COMBINE_INPUT`` a batch unpacks to
            (input_ids, input_masks, input_segments, num_features, cat_features,
            labels, _); otherwise to (q_input_ids, q_input_masks,
            q_input_segments, a_input_ids, a_input_masks, a_input_segments,
            num_features, cat_features, labels, _).
        val_length: total number of rows the loader yields.
        batch_size: kept for backward compatibility only. Rows are placed by a
            running offset taken from each batch's real length, so the result no
            longer depends on this value matching the loader's batch size.

    Returns:
        ``(output, original)``: two ``(val_length, len(target_cols))`` arrays,
        the sigmoid of the collected logits and the collected labels.
    """
    val_preds = np.zeros((val_length, len(target_cols)))
    original = np.zeros((val_length, len(target_cols)))

    model.eval()
    offset = 0  # first free row in the output arrays
    tk0 = tqdm(enumerate(val_loader))
    for _, batch in tk0:
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            input_segments = input_segments.to(device)
            num_features = num_features.to(device)
            cat_features = cat_features.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(input_ids = input_ids.long(),
                            labels = None,
                            attention_mask = input_masks,
                            token_type_ids = input_segments,
                            num_features = num_features,
                            cat_features = cat_features,
                            )
        else:
            (q_input_ids, q_input_masks, q_input_segments,
             a_input_ids, a_input_masks, a_input_segments,
             num_features, cat_features, labels, _) = batch
            q_input_ids = q_input_ids.to(device)
            q_input_masks = q_input_masks.to(device)
            q_input_segments = q_input_segments.to(device)
            a_input_ids = a_input_ids.to(device)
            a_input_masks = a_input_masks.to(device)
            a_input_segments = a_input_segments.to(device)
            num_features = num_features.to(device)
            cat_features = cat_features.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )

        predictions = outputs[0]
        batch_preds = predictions.detach().cpu().numpy()
        n_rows = batch_preds.shape[0]
        # reshape (instead of squeeze) keeps a one-row final batch two-dimensional.
        val_preds[offset:offset + n_rows] = batch_preds.reshape(n_rows, -1)
        original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabelled loader and return sigmoid scores.

    Args:
        model: callable network; ``outputs[0]`` is used as the logits. Assumed to
            be batch-first with one column per target -- TODO confirm against
            the model class.
        test_loader: iterable of indexable batches. With ``COMBINE_INPUT`` the
            first five entries are (input_ids, attention_mask, token_type_ids,
            num_features, cat_features); otherwise the first eight are the
            question triple, the answer triple, num_features and cat_features.
        test_length: total number of rows the loader yields.
        batch_size: kept for backward compatibility only. Rows are placed by a
            running offset taken from each batch's real length, so the result no
            longer depends on this value matching the loader's batch size.

    Returns:
        ``(test_length, len(target_cols))`` array holding the sigmoid of the logits.
    """
    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0  # first free row in test_preds
    tk0 = tqdm(enumerate(test_loader))
    for _, x_batch in tk0:
        if COMBINE_INPUT:
            with torch.no_grad():
                outputs = model(input_ids = x_batch[0].to(device),
                            labels = None,
                            attention_mask = x_batch[1].to(device),
                            token_type_ids = x_batch[2].to(device),
                            num_features = x_batch[3].to(device),
                            cat_features = x_batch[4].to(device),
                           )
        else:
            with torch.no_grad():
                outputs = model(q_input_ids = x_batch[0].to(device),
                            labels = None,
                            q_attention_mask = x_batch[1].to(device),
                            q_token_type_ids = x_batch[2].to(device),
                            a_input_ids = x_batch[3].to(device),
                            a_attention_mask = x_batch[4].to(device),
                            a_token_type_ids = x_batch[5].to(device),
                            num_features = x_batch[6].to(device),
                            cat_features = x_batch[7].to(device),
                           )
        predictions = outputs[0]
        batch_preds = predictions.detach().cpu().numpy()
        n_rows = batch_preds.shape[0]
        # reshape (instead of squeeze) keeps a one-row final batch two-dimensional.
        test_preds[offset:offset + n_rows] = batch_preds.reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add hand-crafted URL, user and length features to ``df`` in place.

    Columns written:
        netloc: the part of the URL's network location before its first dot.
        qa_same_user_page_flag: 1 when question and answer user pages are equal, else 0.
        question_title_num_words, question_body_num_words, answer_num_words:
            ``log1p`` of the count of whitespace-delimited tokens.
        question_vs_answer_length: ``log1p`` of body word count / answer word count.

    Args:
        df: frame with ``url``, ``question_user_page``, ``answer_user_page``,
            ``question_title``, ``question_body`` and ``answer`` columns.

    Returns:
        The same ``df`` object, mutated.
    """
    # Raw strings: '\S' in a normal literal is an invalid escape sequence.
    first_label = re.compile(r"^[^.]*")
    words = r'\S+'

    df['netloc'] = df['url'].apply(lambda x: first_label.findall(urlparse(x).netloc)[0])
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1

    for text_col in ('question_title', 'question_body', 'answer'):
        df[f'{text_col}_num_words'] = df[text_col].str.count(words)

    # The ratio uses the raw counts, so it must be taken before the log transform.
    # NOTE(review): an answer with no non-whitespace text divides by zero here
    # (inf/NaN). Left unchanged so the features keep their existing values.
    df['question_vs_answer_length'] = df['question_body_num_words']/df['answer_num_words']

    for count_col in ('question_title_num_words', 'question_body_num_words',
                      'answer_num_words', 'question_vs_answer_length'):
        df[count_col] = np.log1p(df[count_col])
    return df


def custom_loss(x, y):
    """Mean binary cross-entropy between raw logits ``x`` and targets ``y``."""
    return F.binary_cross_entropy_with_logits(x, y)


def get_bert_features(model, val_loader, val_length, batch_size=32):
    """Extract question and answer encoder features for every row of a loader.

    For each batch ``model.bert`` is run once on the question inputs and once on
    the answer inputs; element ``[1]`` of each result is taken and the two are
    concatenated column-wise (presumably BERT's pooled output -- confirm).

    Args:
        model: object exposing a callable ``bert`` attribute.
        val_loader: iterable of batches unpacking to (q_input_ids, q_input_masks,
            q_input_segments, a_input_ids, a_input_masks, a_input_segments,
            num_features, cat_features, _).
        val_length: total number of rows the loader yields.
        batch_size: kept for backward compatibility only. Rows are placed by a
            running offset taken from each batch's real length.

    Returns:
        ``(val_length, 768*2)`` array: question features then answer features.

    Raises:
        NotImplementedError: when ``COMBINE_INPUT`` is set. That mode was never
            implemented here and used to fail with an unbound-variable error.
    """
    if COMBINE_INPUT:
        raise NotImplementedError("get_bert_features does not support COMBINE_INPUT")

    features = np.zeros((val_length, 768*2))

    model.eval()
    offset = 0  # first free row in features
    tk0 = tqdm(enumerate(val_loader))
    for _, batch in tk0:
        # num_features / cat_features are unpacked but not needed by the encoder.
        (q_input_ids, q_input_masks, q_input_segments,
         a_input_ids, a_input_masks, a_input_segments,
         num_features, cat_features, _) = batch
        q_input_ids = q_input_ids.to(device)
        q_input_masks = q_input_masks.to(device)
        q_input_segments = q_input_segments.to(device)
        a_input_ids = a_input_ids.to(device)
        a_input_masks = a_input_masks.to(device)
        a_input_segments = a_input_segments.to(device)
        with torch.no_grad():
            q_output = model.bert(q_input_ids.long(),
                          attention_mask=q_input_masks,
                          token_type_ids=q_input_segments,
                          position_ids=None,
                          head_mask=None,
                          inputs_embeds=None,
                        )
            a_output = model.bert(a_input_ids.long(),
                          attention_mask=a_input_masks,
                          token_type_ids=a_input_segments,
                          position_ids=None,
                          head_mask=None,
                          inputs_embeds=None,
                        )
        q_feature = q_output[1].detach().cpu().numpy()
        a_feature = a_output[1].detach().cpu().numpy()
        n_rows = q_feature.shape[0]
        # reshape (instead of squeeze) keeps a one-row final batch two-dimensional.
        features[offset:offset + n_rows] = np.hstack(
            [q_feature.reshape(n_rows, -1), a_feature.reshape(n_rows, -1)])
        offset += n_rows

    return features

#===========================================================
# main
#===========================================================
#def main():
if True:
    
    # Read the competition CSVs from ROOT; missing cells become the literal string "none".
    with timer('Data Loading'):
        train = pd.read_csv(f"{ROOT}train.csv").fillna("none")
        # Target matrix: one column per entry of target_cols.
        y_train = train[target_cols].values
        # `test` and `submission` are only defined when config.test is set.
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv").fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
    
    # Derive the hand-crafted columns with add_features and pull the numeric ones out as arrays.
    with timer('Num features'):
        train = add_features(train)
        if config.test:
            test = add_features(test)
        # Names of the numeric feature columns produced by add_features.
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
        train_num = train[num_features].values
        # test_num, like test itself, only exists when config.test is set.
        if config.test:
            test_num = test[num_features].values
                
    with timer('Cat features'):
        # Ordinal-encode the categorical columns. The encoder is fitted on train
        # only; with handle_unknown='return_nan' a category unseen at fit time
        # comes back as NaN, which is mapped to 0 for the test rows below.
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])

        train_cat_df = ce_oe.transform(train[cat_features])
        for c in cat_features:
            train[c] = train_cat_df[c]
        train_cat = train_cat_df.values

        # `test` is only loaded when config.test is set. Touching it
        # unconditionally raised NameError on train-only / cv-only runs.
        if config.test:
            test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
            for c in cat_features:
                test[c] = test_cat_df[c]
            test_cat = test_cat_df.values

        # One (size, width) pair per categorical column, passed to the model as
        # cat_dims -- presumably embedding table size and embedding dimension.
        cat_dims = []
        for col in cat_features:
            dim = train[col].nunique()
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    # Assign every training row to a validation fold and persist the assignment to folds.csv.
    if config.train:
        with timer('Create folds'):
            folds = train.copy()

            # NOTE(review): the iterstrat import at the top of the notebook is commented out;
            # confirm MultilabelStratifiedKFold is brought into scope somewhere before this runs.
            # NOTE(review): random_state is passed without shuffle=True -- check how the installed
            # splitter version treats that combination.
            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                # Each row is labelled with the fold in which it is held out for validation.
                folds.loc[val_index, 'fold'] = int(fold)
            # The string literal below is a disabled GroupKFold alternative kept for reference; it is never executed.
            """
            # less gap between CV vs LB with GroupKFold
            # https://www.kaggle.com/ratthachat/quest-cv-analysis-on-different-splitting-methods
            kf = GroupKFold(n_splits=NUM_FOLDS)
            for fold, (train_index, val_index) in enumerate(kf.split(X=train.question_body, groups=train.question_body)):
                folds.loc[val_index, 'fold'] = int(fold)
            """
            folds['fold'] = folds['fold'].astype(int)
            # Only the id, the targets and the fold label are written out.
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    # Build the tokenizer and the BertConfig shared by the training, CV and inference sections.
    with timer('Prepare Bert config'):
        # Vocabulary and config are read from local dataset files, not downloaded.
        tokenizer = BertTokenizer.from_pretrained("../input/pretrained-bert-models-for-pytorch/bert-base-uncased-vocab.txt", 
                                                  do_lower_case=True)
        # Text columns handed to compute_input_arays.
        input_categories = ['question_title', 'question_body', 'answer']
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-uncased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        # One output per target column.
        bert_config.num_labels = len(target_cols)
        # bert_model, do_lower_case and output_model_file are assigned here but not read in the sections shown below.
        bert_model = 'bert-base-uncased'
        do_lower_case = 'uncased' in bert_model
        output_model_file = 'bert_pytorch.bin'
    
    # Train one model per fold and save the weights of the best-scoring epoch of each fold.
    if config.train:

        BATCH_SIZE = 8
        # DEBUG shortens every fold to a single epoch.
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):
            
            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                # Rows labelled with the current fold are held out; all others are trained on.
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")
            
                logger.info("Preparing train datasets....")
            
                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                # Position of the first 0 in each row of the first input array; rows where that
                # position is 0 get the full row width instead (presumably the unpadded length -- confirm).
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]
            
                logger.info("Preparing valid datasets....")
            
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
            
                logger.info("Preparing Dataloaders Datasets....")

                # Training batches are drawn in random order; validation keeps dataset order.
                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)
            
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
            
                # Every fold starts again from the pretrained checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()
            
                # `i` is reset when validation loss improves and incremented when the score does
                # not improve; nothing in this section reads it.
                i = 0
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                # Two parameter groups: names containing a no_decay entry get weight_decay 0.0, the rest 0.01.
                param_optimizer = list(model.named_parameters())
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=4e-5)
                #criterion = nn.BCEWithLogitsLoss()
                criterion = custom_loss
                # NOTE(review): config.warmup is passed straight through as num_warmup_steps, and the
                # notebook builds config with warmup=0.1. If that was meant as a fraction of the
                # training steps it needs converting to a step count -- confirm the intent.
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")
            
                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()
                
                    start_time   = time.time()
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    # avg_loss is logged twice, once as loss and once as train_loss.
                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    # NOTE(review): state_dict() returns references to the live parameter tensors, not a
                    # snapshot, and best_param_loss is never written to disk in this section.
                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    # The checkpoint is selected on the validation score and saved immediately.
                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

            # Runs once, after the last fold: drop the references left over from it and free memory.
            del train_df, val_df, model, optimizer, criterion, scheduler
            del valid_loader, train_loader, valid_set, train_set
            torch.cuda.empty_cache()
            gc.collect()
    
    # Out-of-fold evaluation: score each fold's held-out rows with that fold's saved weights.
    if config.cv:

        with timer('CV'):

            # Fold assignment is read from MODEL_DIR, alongside the saved weights.
            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            # results collects the out-of-fold predictions. Despite its name, `logits` collects
            # the ground-truth labels (the second value returned by predict_valid_result).
            results = np.zeros((len(train), len(target_cols)))
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):
                
                #train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index
                #train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                val_df = train.iloc[val_index]
                
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                # Position of the first 0 in each row; rows where that position is 0 get the full row width.
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                # batch_size=32 here equals the default batch_size of predict_valid_result, which is called without one.
                valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False, drop_last=False)
                
                # Build the architecture, then overwrite its weights with the fold's saved checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                # result: sigmoid predictions; logit: the matching ground-truth labels.
                result, logit = predict_valid_result(model, valid_loader, len(val_df))  
                results[val_index, :] = result
                logits[val_index, :] = logit 
            
            # Mean over target columns of the Spearman correlation between labels and predictions.
            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')
            
            # Save the out-of-fold predictions, one column per target.
            oof = pd.DataFrame()
            for i, col in enumerate(target_cols):
                oof[col] = results[:,i]
            oof.to_csv('oof.csv', index=False)
            
    # Dump encoder features of the whole training set, once per fold model.
    if config.cv:
        
        with timer('Get Bert features'):

            # The tokenised training set and its loader do not depend on the fold,
            # so they are built once here instead of being recomputed in every
            # iteration (the test-side feature extraction does the same).
            train_inputs = compute_input_arays(train, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                               num_features=train_num, cat_features=train_cat)
            # Position of the first 0 in each row; rows where that position is 0 get the full row width.
            lengths_train = np.argmax(train_inputs[0] == 0, axis=1)
            lengths_train[lengths_train == 0] = train_inputs[0].shape[1]
            train_set = QuestDataset(inputs=train_inputs, lengths=lengths_train, labels=None)
            # shuffle=False keeps the feature rows aligned with the rows of `train`.
            train_loader  = DataLoader(train_set, batch_size=32, shuffle=False)

            for fold in range(NUM_FOLDS):

                # Build the architecture, then overwrite its weights with the fold's saved checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                features = get_bert_features(model, train_loader, len(train))
                pd.DataFrame(features).to_csv(f'train_bert_features_{fold}.csv', index=False)
                
    # Predict the test set with every fold model, average, and write submission1.csv.
    if config.test:

        with timer('Inference'):

            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            # Position of the first 0 in each row; rows where that position is 0 get the full row width.
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            # batch_size=32 equals the default batch_size of predict_result; shuffle=False
            # keeps predictions aligned with the rows of `test`.
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)
            result = np.zeros((len(test), len(target_cols)))

            n_folds_used = 0
            for fold in range(NUM_FOLDS):
                # Build the architecture, then overwrite its weights with the fold's saved checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result += predict_result(model, test_loader, len(test)) 
                n_folds_used += 1
                if DEBUG:
                    break

            # Average over the folds that actually contributed. Dividing by NUM_FOLDS
            # shrank the predictions whenever DEBUG stopped the loop after one fold.
            result /= max(n_folds_used, 1)

        with timer('Create submission.csv'):
            # Fill every column from the first target onwards with the averaged predictions.
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission1.csv', index=False)
                
    # Dump test-set encoder features per fold model, then predict with the saved LightGBM models.
    if config.test:
        
        with timer('Get Bert features'):
            
            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            # Position of the first 0 in each row; rows where that position is 0 get the full row width.
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            # shuffle=False keeps the feature rows aligned with the rows of `test`.
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)

            for fold in range(NUM_FOLDS):
                # Build the architecture, then overwrite its weights with the fold's saved checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                features = get_bert_features(model, test_loader, len(test))
                pd.DataFrame(features).to_csv(f'test_bert_features_{fold}.csv', index=False)
        
        with timer('LGB'):
            
            LGB_MODEL_DICT = '../input/train-lgb-with-bert-features-xentropy/'

            folds = pd.read_csv(MODEL_DIR+'folds.csv')
            output = pd.read_csv('../input/google-quest-challenge/sample_submission.csv')

            train_cols = ['category', 'netloc', 'qa_same_user_page_flag', 'question_title_num_words',
                          'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
            cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
            drop_features = ['qa_id']

            # The model input for a fold depends only on the fold, not on the target,
            # so each fold's feature CSV is read and assembled once instead of once
            # per (target, fold) pair.
            fold_inputs = []
            for fold in range(NUM_FOLDS):
                bert_features = pd.read_csv(f'test_bert_features_{fold}.csv')
                df = pd.concat([test[train_cols], bert_features], axis=1)
                num_features = [c for c in df.columns if df.dtypes[c] != 'object']
                # The column list is assembled exactly as before (non-object columns
                # followed by the categorical names), because the saved models
                # presumably expect the layout they were trained on -- confirm
                # before de-duplicating.
                features = num_features + cat_features
                features = [c for c in features if c not in drop_features]
                fold_inputs.append(df[features])

            for i, target_col in enumerate(target_cols):

                predictions = np.zeros(len(test))

                for fold in range(NUM_FOLDS):

                    # NOTE(review): pickle is not among the imports visible at the top
                    # of the notebook -- confirm it is imported. pickle.load can run
                    # arbitrary code; only load model files from a trusted source.
                    with open(LGB_MODEL_DICT+f'{target_col}_lightgbm_fold{fold}.pkl', 'rb') as fin:
                        clf = pickle.load(fin)

                    # Each fold contributes an equal share of the final prediction.
                    predictions += clf.predict(fold_inputs[fold], num_iteration=clf.best_iteration) / NUM_FOLDS

                output[target_col] = predictions
            
            output.to_csv('submission_lgb.csv', index=False)

[Data Loading] start
[Data Loading] done in 0 s
[Num features] start
[Num features] done in 1 s
[Cat features] start
[Cat features] done in 0 s
[Prepare Bert config] start
[Prepare Bert config] done in 0 s
[Inference] start


[(60, 30), (6, 3), (3, 2)]


15it [00:17,  1.15s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
[Inference] done in 116 s
[Create submission.csv] start
[Create submission.csv] done in 0 s
[Get Bert features] start
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
[Get Bert features] done in 115 s
[LGB] start
[LGB] done in 170 s


In [4]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain attribute bag for one run's hyper-parameters and stage switches.

    Every constructor argument is stored unchanged under an attribute of the
    same name; nothing is validated or converted.
    """

    def __init__(self, lr, warmup, accum_steps, epochs, seed, expname, 
                 head_tail, head, freeze, question_weight, answer_weight, fold, train, cv, test):
        # optimisation schedule
        self.lr, self.warmup, self.accum_steps, self.epochs = lr, warmup, accum_steps, epochs
        # reproducibility and naming of saved artefacts
        self.seed, self.expname = seed, expname
        # input truncation and model options
        self.head_tail, self.head, self.freeze = head_tail, head, freeze
        self.question_weight, self.answer_weight = question_weight, answer_weight
        # fold count and which pipeline stages to run
        self.fold, self.train, self.cv, self.test = fold, train, cv, test

# Settings for this run: only the test stage is enabled (train=False, cv=False, test=True).
config = PipeLineConfig(lr=1e-4, warmup=0.1, accum_steps=1, epochs=5,
                        seed=42, expname='uncased_7', head_tail=True, head=0.5, freeze=False,
                        question_weight=0., answer_weight=0., fold=5, train=False, cv=False, test=True)

DEBUG = False
# Name of the row-identifier column.
ID = 'qa_id'
# The 30 columns to predict: 21 question_* columns followed by 9 answer_* columns.
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice',
               'question_type_compare', 'question_type_consequence',
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful',
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']
NUM_FOLDS = config.fold
# Directory holding the competition CSVs.
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'
SEED = config.seed
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory the fold checkpoints and folds.csv are read from.
MODEL_DIR = '../input/googlequestchallenge-weights2/'
#MODEL_DIR = './'
# False selects the separate question / answer input branches of the model and prediction helpers.
COMBINE_INPUT = False
# Token budgets for title, question body and answer.
T_MAX_LEN = 30
Q_MAX_LEN = 479 # 382
A_MAX_LEN = 479 # 254 
# Combined sequence: three text parts plus 4 special tokens.
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4
# Separate sequences: title plus one text part plus 3 special tokens (30 + 479 + 3 = 512 each).
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3
a_max_sequence_length = T_MAX_LEN + A_MAX_LEN + 3

#===========================================================
# Model
#===========================================================
def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    
    # 350+128+30 = 508 +4 = 512
    
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))
        # Head+Tail method 
        q_len_head = round(q_new_len * config.head)
        q_len_tail = -1 * (q_new_len - q_len_head)
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        t_len_head = round(t_new_len * config.head)
        t_len_tail = -1 * (t_new_len - t_len_head)  
        #t = t[:t_new_len]
        if config.head_tail :
            q = q[:q_len_head]+q[q_len_tail:]
            a = a[:a_len_head]+a[a_len_tail:]
            #t = t[:t_len_head]+t[t_len_tail:]
            t = t[:t_new_len]
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]
    
    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenise title and question and trim them to fit one sequence.

    Nothing is trimmed when both token lists plus 3 special tokens already fit
    in ``q_max_sequence_length``. Otherwise budget left unused by a short title
    goes to the question, and budget left unused by a short question goes to
    the title. With ``config.head_tail`` each part keeps a head of
    ``round(new_len * config.head)`` tokens plus a tail making up the rest;
    without it each part is cut from the front.

    Args:
        tokenizer: object with a ``tokenize(str) -> list`` method.
        title, question: raw strings.
        q_max_sequence_length: total budget including the 3 special tokens.
        t_max_len, q_max_len: nominal per-part token budgets.

    Returns:
        ``(t, q)`` token lists.
    """
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len, q_len = len(t), len(q)

    if (t_len+q_len+3) <= q_max_sequence_length:
        return t, q

    if t_max_len > t_len:
        t_new_len = t_len
        q_max_len = q_max_len + (t_max_len - t_len)
    else:
        t_new_len = t_max_len

    if q_max_len > q_len:
        q_new_len = q_len
        t_new_len = t_max_len + (q_max_len - q_len)
    else:
        q_new_len = q_max_len

    def _head_and_tail(tokens, new_len):
        # Keep `new_len` tokens: a head of round(new_len * head), the rest from the end.
        head_len = round(new_len * config.head)
        tail_len = new_len - head_len
        # tokens[-0:] is the whole list, so an empty tail needs its own case;
        # otherwise head == 1.0 would append every token a second time.
        tail = tokens[-tail_len:] if tail_len > 0 else []
        return tokens[:head_len] + tail

    if config.head_tail:
        # Head+Tail method
        q = _head_and_tail(q, q_new_len)
        t = _head_and_tail(t, t_new_len)
    else:
        # No Head+Tail , usual processing
        q = q[:q_new_len]
        t = t[:t_new_len]

    return t, q

"""
def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):

    a = tokenizer.tokenize(answer)

    a_len = len(a)

    if (a_len+2) > a_max_sequence_length:

        a_new_len = a_max_len

        # Head+Tail method
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        if config.head_tail :
            a = a[:a_len_head]+a[a_len_tail:]
        else:
            # No Head+Tail , usual processing
            a = a[:a_new_len]

    return a
"""

def a_trim_input(tokenizer, title, answer, a_max_sequence_length, t_max_len, a_max_len):
    """Tokenize a title/answer pair and trim it to fit the answer-side input.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list`` of tokens.
        title: raw question title text.
        answer: raw answer text.
        a_max_sequence_length: total budget for the answer-side sequence,
            including the 3 special tokens that ``_convert_to_bert_inputs``
            adds around title and answer.
        t_max_len: nominal token budget for the title.
        a_max_len: nominal token budget for the answer.

    Returns:
        ``(title_tokens, answer_tokens)``. Both are returned untouched when
        the pair already fits; otherwise each is cut to its budget, and budget
        left unused by one side is handed to the other.

    Reads the module-level ``config`` (``config.head_tail``, ``config.head``)
    to decide between plain head truncation and head+tail truncation.
    """

    def _head_tail(tokens, keep):
        # Keep `keep` tokens: the first `head` plus the last `keep - head`.
        # The head size is clamped to [0, keep] so an out-of-range config.head
        # cannot exceed the budget.
        head = min(max(round(keep * config.head), 0), keep)
        tail = keep - head
        # tokens[-0:] would be the WHOLE list, so a zero-length tail must be
        # handled explicitly instead of being expressed as a negative slice.
        return tokens[:head] + (tokens[-tail:] if tail > 0 else [])

    t = tokenizer.tokenize(title)
    a = tokenizer.tokenize(answer)

    t_len = len(t)
    a_len = len(a)

    # Nothing to do when title + answer + 3 special tokens already fit.
    if (t_len + a_len + 3) <= a_max_sequence_length:
        return t, a

    if t_max_len > t_len:
        # Short title: keep it whole and give its spare budget to the answer.
        t_new_len = t_len
        a_max_len = a_max_len + (t_max_len - t_len)
    else:
        t_new_len = t_max_len

    if a_max_len > a_len:
        # Short answer: keep it whole and give its spare budget to the title.
        # Clamp to the real title length so head+tail slicing never overlaps
        # and duplicates tokens.
        a_new_len = a_len
        t_new_len = min(t_len, t_max_len + (a_max_len - a_len))
    else:
        a_new_len = a_max_len

    if config.head_tail:
        # Head+Tail method
        a = _head_tail(a, a_new_len)
        t = _head_tail(t, t_new_len)
    else:
        # No Head+Tail, usual processing
        a = a[:a_new_len]
        t = t[:t_new_len]

    return t, a


def _convert_to_bert_inputs(title_q, title_a, question, answer, tokenizer, max_sequence_length):
    """Converts tokenized input to ids, masks and segments for BERT.

    Args:
        title_q: title tokens paired with the question body (also used as the
            single title in combined mode).
        title_a: title tokens paired with the answer (dual mode only).
        question: question-body tokens.
        answer: answer tokens.
        tokenizer: passed through to ``_get_ids``.
        max_sequence_length: padded length used in combined mode.

    Returns:
        Combined mode (``COMBINE_INPUT`` truthy):
            ``[input_ids, input_masks, input_segments]``.
        Dual mode:
            ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments]``, padded
            to ``T_MAX_LEN+Q_MAX_LEN+3`` and ``T_MAX_LEN+A_MAX_LEN+3``.
    """
    if COMBINE_INPUT:
        # The original referenced an undefined name `title` here; the title
        # arrives through `title_q`.
        stoken = ["[CLS]"] + title_q + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]

        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]

    q_seq_len = T_MAX_LEN + Q_MAX_LEN + 3
    a_seq_len = T_MAX_LEN + A_MAX_LEN + 3

    # The closing token used to be spelled "[SEP" (missing bracket), which is
    # not the separator token.
    # NOTE(review): checkpoints trained before this fix saw a different final
    # token; confirm whether any need retraining.
    q_token = ["[CLS]"] + title_q + ["[SEP]"] + question + ["[SEP]"]
    q_input_ids = _get_ids(q_token, tokenizer, q_seq_len)
    q_input_masks = _get_masks(q_token, q_seq_len)
    q_input_segments = _get_segments(q_token, q_seq_len)

    a_token = ["[CLS]"] + title_a + ["[SEP]"] + answer + ["[SEP]"]
    a_input_ids = _get_ids(a_token, tokenizer, a_seq_len)
    a_input_masks = _get_masks(a_token, a_seq_len)
    a_input_segments = _get_segments(a_token, a_seq_len)

    return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features, 
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Build the list of model input tensors for every row of ``df``.

    Args:
        df: frame providing ``question_title``, ``question_body`` and
            ``answer`` (selected through ``columns``).
        columns: column names to iterate over.
        tokenizer: tokenizer forwarded to the trim/convert helpers.
        max_sequence_length: padded length of the combined-mode sequence.
        num_features: array-like of numeric features, converted to float32.
        cat_features: array-like of categorical codes, converted to int64.
        t_max_len, q_max_len, a_max_len: nominal token budgets for title,
            question body and answer.

    Returns:
        Combined mode: ``[ids, masks, segments, num_features, cat_features]``.
        Dual mode: ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments,
        num_features, cat_features]``. All entries are torch tensors.
    """

    def _as_long(rows):
        # int32 staging array, then widened to int64 for embedding lookups.
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    def _as_float(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).float()

    if COMBINE_INPUT:
        input_ids, input_masks, input_segments = [], [], []
        for _, instance in df[columns].iterrows():
            t, q, a = instance.question_title, instance.question_body, instance.answer
            t, q, a = _trim_input(tokenizer, t, q, a, max_sequence_length, t_max_len, q_max_len, a_max_len)
            # _convert_to_bert_inputs takes (title_q, title_a, question, answer, ...);
            # the original call passed only one title and therefore shifted
            # every later argument by one position.
            ids, masks, segments = _convert_to_bert_inputs(t, t, q, a, tokenizer, max_sequence_length)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)
        return [
                _as_long(input_ids),
                _as_long(input_masks),
                _as_long(input_segments),
                _as_float(num_features),
                _as_long(cat_features),
                ]

    q_input_ids, q_input_masks, q_input_segments = [], [], []
    a_input_ids, a_input_masks, a_input_segments = [], [], []
    for _, instance in df[columns].iterrows():
        t, q, a = instance.question_title, instance.question_body, instance.answer
        # NOTE(review): q_max_sequence_length / a_max_sequence_length are not
        # parameters of this function; they are presumably module-level
        # constants defined elsewhere in the notebook - confirm.
        t_q, q = q_trim_input(tokenizer, t, q, q_max_sequence_length, t_max_len, q_max_len)
        t_a, a = a_trim_input(tokenizer, t, a, a_max_sequence_length, t_max_len, a_max_len)
        q_ids, q_masks, q_segments, a_ids, a_masks, a_segments = _convert_to_bert_inputs(t_q, t_a, q, a, tokenizer, max_sequence_length)
        q_input_ids.append(q_ids)
        q_input_masks.append(q_masks)
        q_input_segments.append(q_segments)
        a_input_ids.append(a_ids)
        a_input_masks.append(a_masks)
        a_input_segments.append(a_segments)
    return [
            _as_long(q_input_ids),
            _as_long(q_input_masks),
            _as_long(q_input_segments),
            _as_long(a_input_ids),
            _as_long(a_input_masks),
            _as_long(a_input_segments),
            _as_float(num_features),
            _as_long(cat_features),
            ]


def compute_output_arrays(df, columns):
    """Return the ``columns`` of ``df`` as a NumPy array."""
    selected = df[columns]
    return np.asarray(selected)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Index-aligned view over the combined-input tensors.

        ``inputs`` holds five parallel sequences (ids, masks, segments,
        numeric features, categorical features); ``lengths`` and the optional
        ``labels`` are indexed the same way.
        """

        # Number of parallel input sequences read by __getitem__.
        _N_INPUTS = 5

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.lengths = lengths
            self.labels = labels

        def __getitem__(self, idx):
            """Return the idx-th sample as a flat tuple.

            With labels:    (ids, masks, segments, num, cat, labels, lengths)
            Without labels: (ids, masks, segments, num, cat, lengths)
            """
            features = tuple(self.inputs[slot][idx] for slot in range(self._N_INPUTS))
            sample_length = self.lengths[idx]
            if self.labels is None:
                return features + (sample_length,)
            return features + (self.labels[idx], sample_length)

        def __len__(self):
            first_input = self.inputs[0]
            return len(first_input)


    class CustomBert(BertPreTrainedModel):
        """BERT encoder with a linear head over pooled text + tabular features.

        The head consumes the concatenation of the BERT pooled output, the raw
        numeric features and one embedding per categorical feature.
        """

        def __init__(self, config, cat_dims):
            """
            Args:
                config: BERT configuration (``hidden_size``, ``num_labels``).
                cat_dims: iterable of ``(num_embeddings, embedding_dim)``
                    pairs, one per categorical feature column.
            """
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            # The "+4" is the width of the numeric feature block (num_features=4).
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Run the encoder and head.

            Returns a tuple ``(loss?, logits, *extra_bert_outputs)``; ``loss``
            is present only when ``labels`` is given.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            # One embedding lookup per categorical column, concatenated on dim 1.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # MSELoss / CrossEntropyLoss are not imported as bare names in
                # this file; reach them through `nn` to avoid a NameError.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Index-aligned view over the dual (question / answer) input tensors.

        ``inputs`` holds eight parallel sequences: question ids, masks and
        segments, answer ids, masks and segments, numeric features and
        categorical features. ``lengths`` and the optional ``labels`` are
        indexed the same way.
        """

        # Number of parallel input sequences read by __getitem__.
        _N_INPUTS = 8

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.lengths = lengths
            self.labels = labels

        def __getitem__(self, idx):
            """Return the idx-th sample as a flat tuple.

            With labels:    (q_ids, q_masks, q_segments, a_ids, a_masks,
                             a_segments, num, cat, labels, lengths)
            Without labels: the same tuple minus ``labels``.
            """
            features = tuple(self.inputs[slot][idx] for slot in range(self._N_INPUTS))
            sample_length = self.lengths[idx]
            if self.labels is None:
                return features + (sample_length,)
            return features + (self.labels[idx], sample_length)

        def __len__(self):
            first_input = self.inputs[0]
            return len(first_input)


    class CustomBert(BertPreTrainedModel):
        """Dual-input BERT: question and answer are encoded separately.

        The same ``self.bert`` encoder is applied to the question-side and the
        answer-side sequence; both pooled outputs are concatenated with the
        numeric features and categorical embeddings and fed to one linear head.
        """

        def __init__(self, config, cat_dims):
            """
            Args:
                config: BERT configuration (``hidden_size``, ``num_labels``).
                cat_dims: iterable of ``(num_embeddings, embedding_dim)``
                    pairs, one per categorical feature column.
            """
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.1)
            n_emb_out = sum([y for x, y in cat_dims])
            self.num_drop = nn.Dropout(0.1)
            self.q_dropout = nn.Dropout(0.1)
            self.a_dropout = nn.Dropout(0.1)
            # Two pooled vectors (question + answer), the categorical
            # embeddings, and "+4" for the numeric feature block.
            # Earlier experiments with separate 64-unit heads for
            # all / question / answer inputs were disabled in favour of this
            # single linear layer.
            self.classifier_final = nn.Linear(config.hidden_size*2+n_emb_out+4, self.config.num_labels)
            self.init_weights()

        def forward(
            self,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Encode both sides and apply the head.

            Returns a tuple ``(loss?, logits, *extra_q_outputs, *extra_a_outputs)``;
            ``loss`` is present only when ``labels`` is given.
            """

            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            q_pooled_output = q_outputs[1]
            q_pooled_output = self.q_dropout(q_pooled_output)

            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            a_pooled_output = a_outputs[1]
            a_pooled_output = self.a_dropout(a_pooled_output)

            # One embedding lookup per categorical column, concatenated on dim 1.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            num_features = self.num_drop(num_features)

            pooled_output = torch.cat([q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # MSELoss / CrossEntropyLoss are not imported as bare names in
                # this file; reach them through `nn` to avoid a NameError.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one training epoch with gradient accumulation.

    Args:
        model: network called with keyword inputs matching ``COMBINE_INPUT``.
        train_loader: iterable of batches whose last element is ignored
            and whose second-to-last element is the label tensor.
        optimizer: optimizer stepped every ``config.accum_steps`` batches.
        criterion: callable ``(logits, labels) -> loss``.
        scheduler: LR scheduler stepped together with the optimizer.
        config: run configuration; only ``accum_steps`` is read here.

    Returns:
        ``(avg_loss, avg_loss_1, ..., avg_loss_5)``. Only ``avg_loss`` is
        accumulated; the other five are always ``0.`` and are kept so the
        return shape stays stable for callers.
    """
    model.train()
    avg_loss = 0.
    avg_loss_1 = 0.
    avg_loss_2 = 0.
    avg_loss_3 = 0.
    avg_loss_4 = 0.
    avg_loss_5 = 0.
    n_batches = len(train_loader)
    optimizer.zero_grad()
    for idx, batch in enumerate(train_loader):
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids, input_masks, input_segments, num_features, cat_features, labels = (
                x.to(device) for x in (input_ids, input_masks, input_segments, num_features, cat_features, labels)
            )

            output_train = model(input_ids = input_ids.long(),
                             labels = None,
                             attention_mask = input_masks,
                             token_type_ids = input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
        else:
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = (
                x.to(device) for x in (q_input_ids, q_input_masks, q_input_segments,
                                       a_input_ids, a_input_masks, a_input_segments,
                                       num_features, cat_features, labels)
            )

            output_train = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
        logits = output_train[0] #output preds
        loss = criterion(logits, labels)
        # NOTE(review): the loss is not divided by accum_steps before
        # backward(), so accumulated gradients are a sum, not a mean. Left
        # unchanged because it interacts with the tuned learning rate.
        loss.backward()

        # Step on every accum_steps-th batch AND on the final batch, so the
        # gradients of trailing micro-batches are applied instead of being
        # thrown away when n_batches is not a multiple of accum_steps.
        is_last_batch = (idx + 1) == n_batches
        if (idx + 1) % config.accum_steps == 0 or is_last_batch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss += loss.item() / (n_batches*config.accum_steps)
        # Drop references to the batch tensors before loading the next batch.
        if COMBINE_INPUT:
            del input_ids, input_masks, input_segments, labels
        else:
            del q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, labels

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader`` and score it with Spearman's rho.

    Args:
        model: network called with keyword inputs matching ``COMBINE_INPUT``.
        criterion: callable ``(logits, labels) -> loss``.
        val_loader: iterable of labelled batches (last element ignored).
        val_shape: number of validation rows; sizes the result buffers.
        batch_size: kept for backward compatibility. Rows are now placed by a
            running offset, so this value no longer has to match the loader.

    Returns:
        ``(avg_val_loss, mean_spearman)`` where ``mean_spearman`` averages the
        per-target correlation with NaN correlations counted as 0.
    """
    avg_val_loss = 0.
    model.eval() # eval mode

    n_targets = len(target_cols)
    valid_preds = np.zeros((val_shape, n_targets))
    original = np.zeros((val_shape, n_targets))

    offset = 0  # first free row in valid_preds / original
    with torch.no_grad():

        for batch in val_loader:
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = (
                    x.to(device) for x in (input_ids, input_masks, input_segments, num_features, cat_features, labels)
                )

                output_val = model(input_ids = input_ids.long(),
                               labels = None,
                               attention_mask = input_masks,
                               token_type_ids = input_segments,
                               num_features = num_features,
                               cat_features = cat_features,
                              )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = (
                    x.to(device) for x in (q_input_ids, q_input_masks, q_input_segments,
                                           a_input_ids, a_input_masks, a_input_segments,
                                           num_features, cat_features, labels)
                )

                output_val = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
            logits = output_val[0] #output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # Place rows by the actual batch length rather than idx*batch_size,
            # and reshape explicitly instead of squeeze(), which would collapse
            # a batch of one row or a single target column.
            n_rows = logits.shape[0]
            valid_preds[offset:offset + n_rows] = logits.detach().cpu().numpy().reshape(n_rows, -1)
            original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

    # One spearmanr call per target, reused for the printout, the log and the score.
    per_target = [spearmanr(original[:, i], preds[:, i]) for i in range(n_targets)]

    rho_val = np.mean([result.correlation for result in per_target])
    print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

    score = 0
    for i, result in enumerate(per_target):
        logger.info(f"{i}, {result}")
        # A constant column yields NaN; count it as zero correlation.
        score += np.nan_to_num(result.correlation)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labelled loader and return sigmoid outputs with the labels.

    Args:
        model: network called with keyword inputs matching ``COMBINE_INPUT``.
        val_loader: iterable of labelled batches (last element ignored).
        val_length: number of rows; sizes the result buffers.
        batch_size: kept for backward compatibility. Rows are now placed by a
            running offset, so this value no longer has to match the loader.

    Returns:
        ``(output, original)``: sigmoid of the logits and the matching labels,
        both NumPy arrays of shape ``(val_length, len(target_cols))``.
    """
    n_targets = len(target_cols)
    val_preds = np.zeros((val_length, n_targets))
    original = np.zeros((val_length, n_targets))

    model.eval()
    offset = 0  # first free row in val_preds / original
    tk0 = tqdm(enumerate(val_loader))
    for idx, batch in tk0:
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids, input_masks, input_segments, num_features, cat_features, labels = (
                x.to(device) for x in (input_ids, input_masks, input_segments, num_features, cat_features, labels)
            )
            with torch.no_grad():
                outputs = model(input_ids = input_ids.long(),
                            labels = None,
                            attention_mask = input_masks,
                            token_type_ids = input_segments,
                            num_features = num_features,
                            cat_features = cat_features,
                            )
        else:
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = (
                x.to(device) for x in (q_input_ids, q_input_masks, q_input_segments,
                                       a_input_ids, a_input_masks, a_input_segments,
                                       num_features, cat_features, labels)
            )
            with torch.no_grad():
                outputs = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )

        predictions = outputs[0]
        # Place rows by the actual batch length rather than idx*batch_size,
        # and reshape explicitly instead of squeeze(), which would collapse
        # a batch of one row or a single target column.
        n_rows = predictions.shape[0]
        val_preds[offset:offset + n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabelled loader and return sigmoid outputs.

    Args:
        model: network called with keyword inputs matching ``COMBINE_INPUT``.
        test_loader: iterable of batches; inputs are read positionally
            (indices 0-4 in combined mode, 0-7 in dual mode).
        test_length: number of rows; sizes the result buffer.
        batch_size: kept for backward compatibility. Rows are now placed by a
            running offset, so this value no longer has to match the loader.

    Returns:
        NumPy array of shape ``(test_length, len(target_cols))`` holding the
        sigmoid of the logits.
    """
    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0  # first free row in test_preds
    tk0 = tqdm(enumerate(test_loader))
    for idx, x_batch in tk0:
        with torch.no_grad():
            if COMBINE_INPUT:
                outputs = model(input_ids = x_batch[0].to(device),
                            labels = None,
                            attention_mask = x_batch[1].to(device),
                            token_type_ids = x_batch[2].to(device),
                            num_features = x_batch[3].to(device),
                            cat_features = x_batch[4].to(device),
                           )
            else:
                outputs = model(q_input_ids = x_batch[0].to(device),
                            labels = None,
                            q_attention_mask = x_batch[1].to(device),
                            q_token_type_ids = x_batch[2].to(device),
                            a_input_ids = x_batch[3].to(device),
                            a_attention_mask = x_batch[4].to(device),
                            a_token_type_ids = x_batch[5].to(device),
                            num_features = x_batch[6].to(device),
                            cat_features = x_batch[7].to(device),
                           )
        predictions = outputs[0]
        # Place rows by the actual batch length rather than idx*batch_size,
        # and reshape explicitly instead of squeeze(), which would collapse
        # a batch of one row or a single target column.
        n_rows = predictions.shape[0]
        test_preds[offset:offset + n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add URL, same-user and word-count features to ``df`` in place.

    New / overwritten columns:
        netloc: host part of ``url`` up to (not including) the first dot.
        qa_same_user_page_flag: 1 when question and answer user pages match.
        question_title_num_words, question_body_num_words, answer_num_words:
            ``log1p`` of the count of whitespace-separated words.
        question_vs_answer_length: ``log1p`` of body words / answer words,
            computed from the raw counts before they are log-transformed.

    Returns:
        The same ``df`` object, mutated.
    """
    first_label = re.compile(r"^[^.]*")
    # The pattern can match an empty string, so findall always has element 0.
    df['netloc'] = df['url'].apply(lambda x: re.findall(first_label, urlparse(x).netloc)[0])
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1

    # Raw string: '\S' in a normal literal is an invalid escape sequence.
    word_pattern = r'\S+'
    df['question_title_num_words'] = df['question_title'].str.count(word_pattern)
    df['question_body_num_words'] = df['question_body'].str.count(word_pattern)
    df['answer_num_words'] = df['answer'].str.count(word_pattern)

    # NOTE(review): an answer with zero words makes this ratio inf/NaN;
    # presumably upstream fillna("none") prevents that - confirm.
    df['question_vs_answer_length'] = df['question_body_num_words']/df['answer_num_words']

    for col in ('question_title_num_words', 'question_body_num_words',
                'answer_num_words', 'question_vs_answer_length'):
        df[col] = np.log1p(df[col])
    return df


def custom_loss(logits, labels):
    """Binary cross-entropy with logits over all target columns.

    Variants explored earlier and left disabled: an equal-weight split of the
    loss at column 21, a version skipping column 19, and a per-column sum
    weighted by ``loss_sample_weights``.
    """
    bce_with_logits = nn.BCEWithLogitsLoss()
    return bce_with_logits(logits, labels)


#===========================================================
# main
#===========================================================
#def main():
if True:
    # End-to-end driver: load data, build features, then (depending on the
    # `config.train` / `config.cv` / `config.test` switches) train per-fold
    # models, score out-of-fold predictions, and/or write a test submission.
    
    with timer('Data Loading'):
        train = pd.read_csv(f"{ROOT}train.csv").fillna("none")
        y_train = train[target_cols].values
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv").fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
    
    with timer('Num features'):
        train = add_features(train)
        if config.test:
            test = add_features(test)
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
        train_num = train[num_features].values
        if config.test:
            test_num = test[num_features].values
                
    with timer('Cat features'):
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])
        train_cat_df = ce_oe.transform(train[cat_features])
        train_cat = train_cat_df.values
        # `test` only exists when config.test is set; encoding it unconditionally
        # raised NameError on train-only / cv-only runs.
        if config.test:
            # categories the encoder did not see come back as NaN -> reserved index 0
            test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
            test_cat = test_cat_df.values
        cat_dims = []
        for col in cat_features:
            dim = train[col].nunique()
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    if config.train:
        with timer('Create folds'):
            folds = train.copy()

            # NOTE(review): MultilabelStratifiedKFold comes from `iterstrat`, whose
            # import is commented out at the top of the file - restore it before
            # running with config.train=True.
            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                folds.loc[val_index, 'fold'] = int(fold)
            """
            # less gap between CV vs LB with GroupKFold
            # https://www.kaggle.com/ratthachat/quest-cv-analysis-on-different-splitting-methods
            kf = GroupKFold(n_splits=NUM_FOLDS)
            for fold, (train_index, val_index) in enumerate(kf.split(X=train.question_body, groups=train.question_body)):
                folds.loc[val_index, 'fold'] = int(fold)
            """
            folds['fold'] = folds['fold'].astype(int)
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    with timer('Prepare Bert config'):
        tokenizer = BertTokenizer.from_pretrained("../input/pretrained-bert-models-for-pytorch/bert-base-uncased-vocab.txt", 
                                                  do_lower_case=True)
        input_categories = ['question_title', 'question_body', 'answer']
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-uncased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        bert_config.num_labels = len(target_cols)
        bert_model = 'bert-base-uncased'
        do_lower_case = 'uncased' in bert_model
        output_model_file = 'bert_pytorch.bin'
    
    if config.train:

        BATCH_SIZE = 8
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):
            
            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")
            
                logger.info("Preparing train datasets....")
            
                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                # index of the first padding id (0); rows with no padding get the full width
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]
            
                logger.info("Preparing valid datasets....")
            
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
            
                logger.info("Preparing Dataloaders Datasets....")

                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)
            
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
            
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()
            
                i = 0
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                param_optimizer = list(model.named_parameters())
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=4e-5)
                #criterion = nn.BCEWithLogitsLoss()
                criterion = custom_loss
                # NOTE(review): config.warmup (0.1 in the config cell) looks like a ratio but is
                # passed as a step count here - confirm before retraining; left as-is so that
                # existing checkpoints stay reproducible.
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")
            
                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()
                
                    start_time   = time.time()
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

                # Release this fold's model and loaders before the next fold allocates its
                # own; previously this cleanup ran only once, after the whole fold loop.
                del train_df, val_df, model, optimizer, criterion, scheduler
                del valid_loader, train_loader, valid_set, train_set
                del best_param_loss, best_param_score
                torch.cuda.empty_cache()
                gc.collect()
    
    if config.cv:

        with timer('CV'):

            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            results = np.zeros((len(train), len(target_cols)))
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):
                
                #train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index
                #train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                val_df = train.iloc[val_index]
                
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False, drop_last=False)
                
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                # map_location lets GPU-saved checkpoints load when `device` is the CPU
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt', map_location=device))
                result, logit = predict_valid_result(model, valid_loader, len(val_df))  
                results[val_index, :] = result
                logits[val_index, :] = logit 
            
            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')

            oof = pd.DataFrame()
            for i, col in enumerate(target_cols):
                oof[col] = results[:,i]
            oof.to_csv(f'oof_{config.expname}.csv', index=False)
    
    if config.test:

        with timer('Inference'):

            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)
            result = np.zeros((len(test), len(target_cols)))

            n_models = 0  # number of fold models actually accumulated into `result`
            for fold in range(NUM_FOLDS):
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                # map_location lets GPU-saved checkpoints load when `device` is the CPU
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt', map_location=device))
                result += predict_result(model, test_loader, len(test)) 
                n_models += 1
                if DEBUG:
                    break
                    
            # Average over the models that really ran: DEBUG stops after one fold, and
            # dividing by NUM_FOLDS there scaled the predictions down by NUM_FOLDS.
            result /= max(n_models, 1)

        with timer('Create submission.csv'):
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission2.csv', index=False)

[Data Loading] start
[Data Loading] done in 0 s
[Num features] start
[Num features] done in 1 s
[Cat features] start
[Cat features] done in 0 s
[Prepare Bert config] start
[Prepare Bert config] done in 0 s
[Inference] start


[(60, 30), (6, 3), (3, 2)]


15it [00:16,  1.10s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
[Inference] done in 106 s
[Create submission.csv] start
[Create submission.csv] done in 0 s


In [5]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain attribute bag holding every tunable of the pipeline.

    Each constructor argument is stored unchanged under an attribute of the
    same name:

      * ``lr``, ``warmup``, ``accum_steps``, ``epochs`` - optimisation settings.
      * ``seed`` - random seed; ``expname`` - tag used in checkpoint file names.
      * ``head_tail``, ``head`` - switch and head fraction for head+tail trimming.
      * ``freeze`` - whether the BERT encoder weights are frozen during training.
      * ``question_weight``, ``answer_weight`` - stored only; no reader is
        visible in this part of the notebook.
      * ``fold`` - number of folds.
      * ``train``, ``cv``, ``test`` - which pipeline stages to execute.
    """

    def __init__(self, lr, warmup, accum_steps, epochs, seed, expname, 
                 head_tail, head, freeze, question_weight, answer_weight, fold, train, cv, test):
        # optimisation schedule
        self.lr, self.warmup = lr, warmup
        self.accum_steps, self.epochs = accum_steps, epochs
        # reproducibility / experiment naming
        self.seed, self.expname = seed, expname
        # input trimming and model options
        self.head_tail, self.head, self.freeze = head_tail, head, freeze
        self.question_weight, self.answer_weight = question_weight, answer_weight
        # cross-validation layout and stage switches
        self.fold = fold
        self.train, self.cv, self.test = train, cv, test

# Settings for this run: inference only, using the 'uncased_7' checkpoints.
config = PipeLineConfig(
    lr=5e-5,
    warmup=0.1,
    accum_steps=1,
    epochs=6,
    seed=42,
    expname='uncased_7',
    head_tail=True,
    head=0.5,
    freeze=False,
    question_weight=0.,
    answer_weight=0.,
    fold=5,
    train=False,
    cv=False,
    test=True,
)

DEBUG = False
ID = 'qa_id'

# Target columns, in the order the model emits them (question_* first, then answer_*).
target_cols = [
    'question_asker_intent_understanding',
    'question_body_critical',
    'question_conversational',
    'question_expect_short_answer',
    'question_fact_seeking',
    'question_has_commonly_accepted_answer',
    'question_interestingness_others',
    'question_interestingness_self',
    'question_multi_intent',
    'question_not_really_a_question',
    'question_opinion_seeking',
    'question_type_choice',
    'question_type_compare',
    'question_type_consequence',
    'question_type_definition',
    'question_type_entity',
    'question_type_instructions',
    'question_type_procedure',
    'question_type_reason_explanation',
    'question_type_spelling',
    'question_well_written',
    'answer_helpful',
    'answer_level_of_information',
    'answer_plausible',
    'answer_relevance',
    'answer_satisfaction',
    'answer_type_instructions',
    'answer_type_procedure',
    'answer_type_reason_explanation',
    'answer_well_written',
]

# Cross-validation layout and data locations.
NUM_FOLDS = config.fold
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'

# Seed every RNG before anything stochastic runs, then pick the compute device.
SEED = config.seed
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Directory holding the per-fold checkpoints (and folds.csv) read by the cv/test stages.
MODEL_DIR = '../input/googlequestchallenge-weights3/'
#MODEL_DIR = './'

# False -> the model also receives separate question-only and answer-only sequences.
COMBINE_INPUT = False

# Token budgets for title / question body / answer.
T_MAX_LEN = 30
Q_MAX_LEN = 164
A_MAX_LEN = 254
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4  # + 4 special tokens
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3            # + 3 special tokens
a_max_sequence_length = A_MAX_LEN + 2                        # + 2 special tokens

#===========================================================
# Input preprocessing, Dataset and Model
#===========================================================
def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    
    # 350+128+30 = 508 +4 = 512
    
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))
        # Head+Tail method 
        q_len_head = round(q_new_len * config.head)
        q_len_tail = -1 * (q_new_len - q_len_head)
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        t_len_head = round(t_new_len * config.head)
        t_len_tail = -1 * (t_new_len - t_len_head)  
        #t = t[:t_new_len]
        if config.head_tail :
            q = q[:q_len_head]+q[q_len_tail:]
            a = a[:a_len_head]+a[a_len_tail:]
            #t = t[:t_len_head]+t[t_len_tail:]
            t = t[:t_new_len]
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]
    
    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenize title and question and trim them for the question-only sequence.

    When ``len(t) + len(q) + 3`` (3 special tokens) fits in
    ``q_max_sequence_length`` both token lists are returned untouched.
    Otherwise budget unused by one field is given to the other and each field
    is cut to its budget, keeping head+tail (``config.head`` is the head
    fraction) when ``config.head_tail`` is set, else only the first tokens.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list``.
        title, question: raw texts.
        q_max_sequence_length: total budget including the 3 special tokens.
        t_max_len, q_max_len: nominal per-field token budgets.

    Returns:
        Tuple ``(t, q)`` of token lists.
    """

    def _keep(tokens, budget):
        # Clamp so head and tail slices can never overlap and duplicate tokens.
        budget = min(budget, len(tokens))
        if not config.head_tail:
            return tokens[:budget]
        n_head = min(max(round(budget * config.head), 0), budget)
        n_tail = budget - n_head
        # explicit start index: tokens[-0:] would return the whole list, which
        # happened whenever rounding pushed the head up to the full budget
        tail = tokens[len(tokens) - n_tail:] if n_tail > 0 else []
        return tokens[:n_head] + tail

    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len = len(t)
    q_len = len(q)

    if (t_len+q_len+3) > q_max_sequence_length:

        # Short title: its unused budget extends the question budget.
        if t_max_len > t_len:
            t_new_len = t_len
            q_max_len = q_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        # Short question: its unused budget extends the title budget.
        if q_max_len > q_len:
            q_new_len = q_len
            t_new_len = t_max_len + (q_max_len - q_len)
        else:
            q_new_len = q_max_len

        q = _keep(q, q_new_len)
        t = _keep(t, t_new_len)

    return t, q


def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):
    """Tokenize the answer and trim it for the answer-only sequence.

    When ``len(a) + 2`` (2 special tokens) fits in ``a_max_sequence_length``
    the full token list is returned. Otherwise the answer is cut to
    ``a_max_len`` tokens, keeping head+tail (``config.head`` is the head
    fraction) when ``config.head_tail`` is set, else only the first tokens.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list``.
        answer: raw answer text.
        a_max_sequence_length: total budget including the 2 special tokens.
        a_max_len: token budget for the answer itself.

    Returns:
        List of answer tokens.
    """
    a = tokenizer.tokenize(answer)

    if (len(a) + 2) <= a_max_sequence_length:
        return a

    # Clamp so head and tail slices can never overlap and duplicate tokens.
    budget = min(a_max_len, len(a))
    if not config.head_tail:
        # No Head+Tail, usual processing
        return a[:budget]

    n_head = min(max(round(budget * config.head), 0), budget)
    n_tail = budget - n_head
    # explicit start index: a[-0:] would return the whole list
    tail = a[len(a) - n_tail:] if n_tail > 0 else []
    return a[:n_head] + tail


def _convert_to_bert_inputs(title, question, answer, tokenizer, max_sequence_length, combine=True):
    """Converts tokenized input to ids, masks and segments for BERT"""
    if combine:
        stoken = ["[CLS]"] + title + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title  + question  + answer + ["[SEP]"]
    
        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]
    else:
        q_token = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP"]
        q_input_ids = _get_ids(q_token, tokenizer, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_masks = _get_masks(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_segments = _get_segments(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        
        a_token = ["[CLS]"] + answer + ["[SEP]"]
        a_input_ids = _get_ids(a_token, tokenizer, A_MAX_LEN+2)
        a_input_masks = _get_masks(a_token, A_MAX_LEN+2)
        a_input_segments = _get_segments(a_token, A_MAX_LEN+2)

        return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features, 
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Encode every row of ``df`` into the tensors the model consumes.

    Each row always yields the combined title+question+answer sequence
    (ids, masks, segments). When ``COMBINE_INPUT`` is false the row also
    yields a question-only and an answer-only sequence, each with its own
    ids, masks and segments.

    Args:
        df: dataframe with ``question_title``, ``question_body``, ``answer``.
        columns: column names selected from ``df`` before iterating.
        tokenizer: tokenizer handed to the trimming/conversion helpers.
        max_sequence_length: padded length of the combined sequence.
        num_features: per-row numeric features (converted to float32).
        cat_features: per-row categorical codes (converted to int64).
        t_max_len, q_max_len, a_max_len: per-field token budgets.

    Returns:
        List of tensors: 3 (combined) or 9 (combined + question + answer)
        int64 tensors, followed by the float numeric-feature tensor and the
        int64 categorical-feature tensor.
    """

    def _as_long(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    def _as_float(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).float()

    # one accumulator per output stream: ids / masks / segments (x3 when split)
    streams = [[] for _ in range(3 if COMBINE_INPUT else 9)]

    for _, row in df[columns].iterrows():
        title, body, answer = row.question_title, row.question_body, row.answer

        # combined sequence
        t, q, a = _trim_input(tokenizer, title, body, answer, max_sequence_length, t_max_len, q_max_len, a_max_len)
        if COMBINE_INPUT:
            encoded = list(_convert_to_bert_inputs(t, q, a, tokenizer, max_sequence_length))
        else:
            encoded = list(_convert_to_bert_inputs(t, q, a, tokenizer, max_sequence_length, combine=True))
            # separate question and answer sequences, trimmed with their own budgets
            t, q = q_trim_input(tokenizer, title, body, q_max_sequence_length, t_max_len, q_max_len)
            a = a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len)
            encoded += list(_convert_to_bert_inputs(t, q, a, tokenizer, max_sequence_length, combine=False))

        for stream, values in zip(streams, encoded):
            stream.append(values)

    tensors = [_as_long(stream) for stream in streams]
    tensors.append(_as_float(num_features))
    tensors.append(_as_long(cat_features))
    return tensors


def compute_output_arrays(df, columns):
    """Return the selected target ``columns`` of ``df`` as a NumPy array."""
    targets = df[columns]
    return np.asarray(targets)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the combined-input encoding (5 input streams per sample).

        ``inputs`` is indexed positionally: ids, masks, segments, numeric
        features, categorical features. ``labels`` is optional and, when
        given, is inserted before ``lengths`` in every returned sample.
        """

        N_STREAMS = 5

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            sample = tuple(self.inputs[k][idx] for k in range(self.N_STREAMS))
            if self.labels is not None: # targets
                sample += (self.labels[idx],)
            return sample + (self.lengths[idx],)

        def __len__(self):
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """BERT encoder plus categorical embeddings and numeric features.

        The pooled BERT output of the combined sequence is concatenated with
        the raw numeric features and the embedded categorical features, then
        projected to ``config.num_labels`` logits by one linear layer.

        Args:
            config: BERT configuration (``num_labels`` and ``hidden_size`` are read).
            cat_dims: iterable of ``(num_embeddings, embedding_dim)`` pairs,
                one per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(logits, ...)``, or ``(loss, logits, ...)`` when ``labels`` is given.

            ``cat_features`` is indexed as ``cat_features[:, j]`` for the j-th
            embedding table; ``num_features`` is concatenated as-is along dim 1.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are reached through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are never imported in this file.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the split-input encoding (11 input streams per sample).

        ``inputs`` is indexed positionally: combined ids/masks/segments,
        question ids/masks/segments, answer ids/masks/segments, numeric
        features, categorical features. ``labels`` is optional and, when
        given, is inserted before ``lengths`` in every returned sample.
        """

        N_STREAMS = 11

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            sample = tuple(self.inputs[k][idx] for k in range(self.N_STREAMS))
            if self.labels is not None: # targets
                sample += (self.labels[idx],)
            return sample + (self.lengths[idx],)

        def __len__(self):
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """Shared BERT encoder run over three views of each sample.

        The same ``self.bert`` encodes the combined sequence, the
        question-only sequence and the answer-only sequence. The three pooled
        outputs are concatenated with the raw numeric features and the
        embedded categorical features, then projected to
        ``config.num_labels`` logits by one linear layer.

        Args:
            config: BERT configuration (``num_labels`` and ``hidden_size`` are read).
            cat_dims: iterable of ``(num_embeddings, embedding_dim)`` pairs,
                one per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            self.q_dropout = nn.Dropout(0.2)
            self.a_dropout = nn.Dropout(0.2)
            self.classifier_final = nn.Linear(config.hidden_size*3+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(logits, ...)``, or ``(loss, logits, ...)`` when ``labels`` is given.

            ``cat_features`` is indexed as ``cat_features[:, j]`` for the j-th
            embedding table; ``num_features`` is concatenated as-is along dim 1.
            The extra tuple entries come from the question and answer passes only.
            """

            # combined title + question + answer view
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            # question-only view
            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            q_pooled_output = q_outputs[1]
            q_pooled_output = self.q_dropout(q_pooled_output)

            # answer-only view
            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            a_pooled_output = a_outputs[1]
            a_pooled_output = self.a_dropout(a_pooled_output)

            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            concat_output = torch.cat([pooled_output, q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(concat_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are reached through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are never imported in this file.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one training epoch with gradient accumulation.

    Each batch is moved to the global ``device``, pushed through ``model``
    (single combined input when the global ``COMBINE_INPUT`` is true, otherwise
    separate question / answer / combined inputs) and scored with
    ``criterion(logits, labels)``. The optimizer and scheduler advance once
    every ``config.accum_steps`` batches.

    Returns:
        A 6-tuple ``(avg_loss, avg_loss_1, ..., avg_loss_5)``. Only
        ``avg_loss`` is ever updated; the other five entries are always ``0.``.

    NOTE(review): the loss is back-propagated unscaled, while the reported
    average is divided by ``accum_steps``, and gradients of a trailing
    incomplete accumulation group are never applied. Confirm this is intended
    when ``accum_steps > 1``.
    """
    model.train()
    avg_loss = 0.
    # Placeholders kept so the return arity matches what the caller unpacks.
    avg_loss_1 = avg_loss_2 = avg_loss_3 = avg_loss_4 = avg_loss_5 = 0.

    optimizer.zero_grad()
    for step, batch in enumerate(train_loader, start=1):
        # The last item of every batch is not used for training.
        *tensors, _ = batch
        tensors = [t.to(device) for t in tensors]

        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels = tensors
            model_inputs = dict(
                input_ids=input_ids.long(),
                labels=None,
                attention_mask=input_masks,
                token_type_ids=input_segments,
                num_features=num_features,
                cat_features=cat_features,
            )
        else:
            (input_ids, input_masks, input_segments,
             q_input_ids, q_input_masks, q_input_segments,
             a_input_ids, a_input_masks, a_input_segments,
             num_features, cat_features, labels) = tensors
            model_inputs = dict(
                q_input_ids=q_input_ids.long(),
                labels=None,
                q_attention_mask=q_input_masks,
                q_token_type_ids=q_input_segments,
                a_input_ids=a_input_ids.long(),
                a_attention_mask=a_input_masks,
                a_token_type_ids=a_input_segments,
                input_ids=input_ids.long(),
                attention_mask=input_masks,
                token_type_ids=input_segments,
                num_features=num_features,
                cat_features=cat_features,
            )

        output_train = model(**model_inputs)
        logits = output_train[0]  # output preds
        loss = criterion(logits, labels)
        loss.backward()

        # Apply the accumulated gradients every `accum_steps` batches.
        if step % config.accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss += loss.item() / (len(train_loader) * config.accum_steps)

        # Drop references to the batch tensors before the next batch is loaded.
        del tensors, model_inputs, labels
        if COMBINE_INPUT:
            del input_ids, input_masks, input_segments
        else:
            del q_input_ids, q_input_masks, q_input_segments
            del a_input_ids, a_input_masks, a_input_segments

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader`` and report loss and Spearman score.

    Args:
        model: network called with keyword inputs; ``model(...)[0]`` is taken
            as the logits.
        criterion: callable ``criterion(logits, labels)`` returning a scalar
            tensor.
        val_loader: iterable of batches laid out as in ``train_model``.
        val_shape: total number of validation rows.
        batch_size: retained for backward compatibility and ignored; rows are
            written at a running offset, so the result no longer depends on
            this value matching the loader's batch size.

    Returns:
        ``(avg_val_loss, score)`` where ``score`` is the mean over target
        columns of the Spearman correlation between labels and sigmoid
        predictions, with undefined (NaN) correlations counted as 0.
    """
    avg_val_loss = 0.
    model.eval()  # eval mode

    n_targets = len(target_cols)
    valid_preds = np.zeros((val_shape, n_targets))
    original = np.zeros((val_shape, n_targets))

    offset = 0
    with torch.no_grad():
        for batch in val_loader:
            # The last item of every batch is not used for evaluation.
            *tensors, _ = batch
            tensors = [t.to(device) for t in tensors]

            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels = tensors
                output_val = model(input_ids=input_ids.long(),
                                   labels=None,
                                   attention_mask=input_masks,
                                   token_type_ids=input_segments,
                                   num_features=num_features,
                                   cat_features=cat_features,
                                   )
            else:
                (input_ids, input_masks, input_segments,
                 q_input_ids, q_input_masks, q_input_segments,
                 a_input_ids, a_input_masks, a_input_segments,
                 num_features, cat_features, labels) = tensors
                output_val = model(q_input_ids=q_input_ids.long(),
                                   labels=None,
                                   q_attention_mask=q_input_masks,
                                   q_token_type_ids=q_input_segments,
                                   a_input_ids=a_input_ids.long(),
                                   a_attention_mask=a_input_masks,
                                   a_token_type_ids=a_input_segments,
                                   input_ids=input_ids.long(),
                                   attention_mask=input_masks,
                                   token_type_ids=input_segments,
                                   num_features=num_features,
                                   cat_features=cat_features,
                                   )
            logits = output_val[0]  # output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # Write this batch at the running row offset; reshape (rather than
            # squeeze) keeps the (rows, targets) layout for 1-row batches too.
            n_rows = logits.shape[0]
            valid_preds[offset:offset + n_rows] = logits.detach().cpu().numpy().reshape(n_rows, -1)
            original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

    # One Spearman correlation per target column, reused for both metrics.
    rhos = [spearmanr(original[:, i], preds[:, i]).correlation for i in range(preds.shape[1])]

    # The printed value is the plain mean (NaN if any column is constant).
    rho_val = np.mean(rhos)
    print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

    # The returned score treats undefined correlations as 0.
    score = sum(np.nan_to_num(rho) for rho in rhos)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labelled loader and return predictions with their labels.

    Args:
        model: network called with keyword inputs; ``model(...)[0]`` is taken
            as the logits.
        val_loader: iterable of batches laid out as in ``train_model``.
        val_length: total number of rows produced by ``val_loader``.
        batch_size: retained for backward compatibility and ignored; rows are
            written at a running offset instead of ``idx * batch_size``.

    Returns:
        ``(output, original)``: sigmoid of the logits and the labels, both
        arrays of shape ``(val_length, len(target_cols))``.
    """
    n_targets = len(target_cols)
    val_preds = np.zeros((val_length, n_targets))
    original = np.zeros((val_length, n_targets))

    model.eval()
    offset = 0
    for batch in tqdm(val_loader):
        # The last item of every batch is not used for prediction.
        *tensors, _ = batch
        tensors = [t.to(device) for t in tensors]

        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels = tensors
            with torch.no_grad():
                outputs = model(input_ids=input_ids.long(),
                                labels=None,
                                attention_mask=input_masks,
                                token_type_ids=input_segments,
                                num_features=num_features,
                                cat_features=cat_features,
                                )
        else:
            (input_ids, input_masks, input_segments,
             q_input_ids, q_input_masks, q_input_segments,
             a_input_ids, a_input_masks, a_input_segments,
             num_features, cat_features, labels) = tensors
            with torch.no_grad():
                outputs = model(q_input_ids=q_input_ids.long(),
                                labels=None,
                                q_attention_mask=q_input_masks,
                                q_token_type_ids=q_input_segments,
                                a_input_ids=a_input_ids.long(),
                                a_attention_mask=a_input_masks,
                                a_token_type_ids=a_input_segments,
                                input_ids=input_ids.long(),
                                attention_mask=input_masks,
                                token_type_ids=input_segments,
                                num_features=num_features,
                                cat_features=cat_features,
                                )

        predictions = outputs[0]
        # Running offset keeps rows aligned whatever the loader's batch size;
        # reshape (rather than squeeze) preserves the (rows, targets) layout.
        n_rows = predictions.shape[0]
        val_preds[offset:offset + n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabelled loader and return sigmoid probabilities.

    Batches are indexed positionally: items 0-2 are the combined-input ids,
    mask and segments. With the global ``COMBINE_INPUT`` true, items 3 and 4
    are the numeric and categorical features. Otherwise items 3-5 and 6-8 are
    the question and answer inputs and items 9 and 10 the features.

    Args:
        model: network called with keyword inputs; ``model(...)[0]`` is taken
            as the logits.
        test_loader: iterable of batches as described above.
        test_length: total number of rows produced by ``test_loader``.
        batch_size: retained for backward compatibility and ignored; rows are
            written at a running offset instead of ``idx * batch_size``.

    Returns:
        Array of shape ``(test_length, len(target_cols))``.
    """
    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0
    for x_batch in tqdm(test_loader):
        with torch.no_grad():
            if COMBINE_INPUT:
                # Ids are cast with .long(), matching the train/validation paths.
                outputs = model(input_ids=x_batch[0].to(device).long(),
                                labels=None,
                                attention_mask=x_batch[1].to(device),
                                token_type_ids=x_batch[2].to(device),
                                num_features=x_batch[3].to(device),
                                cat_features=x_batch[4].to(device),
                                )
            else:
                outputs = model(q_input_ids=x_batch[3].to(device).long(),
                                labels=None,
                                q_attention_mask=x_batch[4].to(device),
                                q_token_type_ids=x_batch[5].to(device),
                                a_input_ids=x_batch[6].to(device).long(),
                                a_attention_mask=x_batch[7].to(device),
                                a_token_type_ids=x_batch[8].to(device),
                                input_ids=x_batch[0].to(device).long(),
                                attention_mask=x_batch[1].to(device),
                                token_type_ids=x_batch[2].to(device),
                                num_features=x_batch[9].to(device),
                                cat_features=x_batch[10].to(device),
                                )
        predictions = outputs[0]
        # Running offset keeps rows aligned whatever the loader's batch size;
        # reshape (rather than squeeze) preserves the (rows, targets) layout.
        n_rows = predictions.shape[0]
        test_preds[offset:offset + n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add hand-crafted columns to ``df`` in place and return it.

    Columns added:
        netloc: part of the URL host before the first dot.
        qa_same_user_page_flag: 1 when the question and answer user pages are
            equal, else 0.
        question_title_num_words, question_body_num_words, answer_num_words:
            ``log1p`` of the number of whitespace-separated tokens.
        question_vs_answer_length: ``log1p`` of body words / answer words.

    Expects the columns ``url``, ``question_user_page``, ``answer_user_page``,
    ``question_title``, ``question_body`` and ``answer``.

    NOTE(review): an answer with no non-whitespace token makes the ratio a
    division by zero, which is not special-cased here.
    """
    find = re.compile(r"^[^.]*")
    df['netloc'] = df['url'].apply(lambda x: re.findall(find, urlparse(x).netloc)[0])
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1

    # Raw-string pattern: '\S' in a normal string literal is an invalid escape.
    word = r'\S+'
    title_words = df['question_title'].str.count(word)
    body_words = df['question_body'].str.count(word)
    answer_words = df['answer'].str.count(word)

    # The ratio uses the raw counts; every stored feature is log1p-compressed.
    df['question_title_num_words'] = np.log1p(title_words)
    df['question_body_num_words'] = np.log1p(body_words)
    df['answer_num_words'] = np.log1p(answer_words)
    df['question_vs_answer_length'] = np.log1p(body_words / answer_words)
    return df


def custom_loss(x, y):
    """Training criterion: mean binary cross-entropy between logits ``x`` and targets ``y``.

    A blended Pearson-correlation term was tried here earlier and is disabled;
    only the BCE-with-logits term is returned.
    """
    return F.binary_cross_entropy_with_logits(x, y)

#===========================================================
# main
#===========================================================
#def main():
if True:
    
    # Read the competition CSVs from ROOT; NaN cells are replaced with the string "none".
    with timer('Data Loading'):
        train = pd.read_csv(f"{ROOT}train.csv").fillna("none")
        # Target matrix: one column per entry of target_cols.
        y_train = train[target_cols].values
        # The test set and the submission template are only read for inference runs.
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv").fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
    
    # Add the hand-crafted columns (add_features) and extract the numeric ones as arrays.
    with timer('Num features'):
        train = add_features(train)
        if config.test:
            test = add_features(test)
        # NOTE: this list of column names shares its name with the `num_features`
        # keyword argument used when calling the model.
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
        train_num = train[num_features].values
        if config.test:
            test_num = test[num_features].values
                
    with timer('Cat features'):
        # Categorical columns are ordinal-encoded, fitted on the training set only.
        # (A one-hot encoding variant was tried here earlier and dropped.)
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])
        train_cat_df = ce_oe.transform(train[cat_features])
        train_cat = train_cat_df.values

        # `test` only exists when config.test is set (see 'Data Loading'), so
        # the test transform is guarded the same way as the numeric features.
        if config.test:
            # Categories unseen during fit come back as NaN and are mapped to code 0.
            test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
            test_cat = test_cat_df.values

        # (vocabulary size, embedding size) per categorical column.
        cat_dims = []
        for col in cat_features:
            dim = train[col].nunique()
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    if config.train:
        # Assign every training row to a validation fold and persist the
        # assignment to folds.csv (ID column, targets and fold number).
        with timer('Create folds'):
            folds = train.copy()

            # NOTE(review): the import of MultilabelStratifiedKFold is commented
            # out in the import cell at the top of the file; confirm it is
            # imported elsewhere before enabling config.train.
            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                # Each row is labelled with the fold in which it is held out.
                folds.loc[val_index, 'fold'] = int(fold)
            """
            # less gap between CV vs LB with GroupKFold
            # https://www.kaggle.com/ratthachat/quest-cv-analysis-on-different-splitting-methods
            kf = GroupKFold(n_splits=NUM_FOLDS)
            for fold, (train_index, val_index) in enumerate(kf.split(X=train.question_body, groups=train.question_body)):
                folds.loc[val_index, 'fold'] = int(fold)
            """
            folds['fold'] = folds['fold'].astype(int)
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    # Build the tokenizer and the BERT configuration shared by training, CV and inference.
    with timer('Prepare Bert config'):
        tokenizer = BertTokenizer.from_pretrained("../input/pretrained-bert-models-for-pytorch/bert-base-uncased-vocab.txt", 
                                                  do_lower_case=True)
        # Text columns handed to compute_input_arays.
        input_categories = ['question_title', 'question_body', 'answer']
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-uncased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        # One output per target column.
        bert_config.num_labels = len(target_cols)
        # NOTE(review): bert_model, do_lower_case and output_model_file are not
        # referenced again in the visible part of this cell.
        bert_model = 'bert-base-uncased'
        do_lower_case = 'uncased' in bert_model
        output_model_file = 'bert_pytorch.bin'
    
    if config.train:

        BATCH_SIZE = 4
        # DEBUG runs a single epoch per fold.
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):

            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                # Rows of the current fold are held out for validation.
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")

                logger.info("Preparing train datasets....")

                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                # Length = position of the first 0 id; rows without any 0 get the full width.
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]

                logger.info("Preparing valid datasets....")

                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]

                logger.info("Preparing Dataloaders Datasets....")

                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)

                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()

                i = 0  # epochs since the validation score last improved
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                # Biases and LayerNorm parameters are excluded from weight decay.
                param_optimizer = list(model.named_parameters())
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                criterion = custom_loss
                # NOTE(review): config.warmup is passed as a number of warmup
                # steps; if it is configured as a fraction (e.g. 0.1) the
                # warmup is effectively skipped. Confirm the intent.
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")

                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()

                    start_time   = time.time()
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    # The checkpoint with the best validation score is what
                    # gets written to disk for this fold.
                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

                # Release this fold's model, optimizer state and state-dict
                # references before the next fold builds a new model, so two
                # models are never resident on the device at once.
                del train_df, val_df, model, optimizer, criterion, scheduler
                del valid_loader, train_loader, valid_set, train_set
                del best_param_loss, best_param_score
                gc.collect()
                torch.cuda.empty_cache()
    
    if config.cv:

        # Out-of-fold evaluation: each saved fold model predicts the rows it
        # was validated on, and the predictions are scored together.
        with timer('CV'):

            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            results = np.zeros((len(train), len(target_cols)))
            # Despite its name, `logits` stores the ground-truth labels that
            # predict_valid_result returns as its second value.
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):

                val_index = folds[folds.fold == fold].index
                val_df = train.iloc[val_index]

                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                # Length = position of the first 0 id; rows without any 0 get the full width.
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False, drop_last=False)

                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result, logit = predict_valid_result(model, valid_loader, len(val_df))  
                results[val_index, :] = result
                logits[val_index, :] = logit 

            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')

            oof = pd.DataFrame()
            for i, col in enumerate(target_cols):
                oof[col] = results[:,i]
            # f-string so the experiment name is interpolated into the file name.
            oof.to_csv(f'oof_{config.expname}.csv', index=False)
    
    if config.test:

        # Average the test predictions of the per-fold checkpoints.
        with timer('Inference'):

            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            # Length = position of the first 0 id; rows without any 0 get the full width.
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)
            result = np.zeros((len(test), len(target_cols)))

            n_models = 0  # checkpoints actually included in the sum
            for fold in range(NUM_FOLDS):
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result += predict_result(model, test_loader, len(test)) 
                n_models += 1
                if DEBUG:
                    break

            # Divide by the number of models that contributed: a DEBUG run
            # stops after the first fold and must not be divided by NUM_FOLDS.
            result /= n_models

        with timer('Create submission.csv'):
            # Every column from the first target onwards is overwritten with the predictions.
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission3.csv', index=False)

[Data Loading] start
[Data Loading] done in 0 s
[Num features] start
[Num features] done in 1 s
[Cat features] start
[Cat features] done in 0 s
[Prepare Bert config] start
[Prepare Bert config] done in 0 s
[Inference] start


[(60, 30), (6, 3), (3, 2)]


15it [00:14,  1.07it/s]
15it [00:14,  1.07it/s]
15it [00:14,  1.07it/s]
15it [00:14,  1.07it/s]
15it [00:14,  1.07it/s]
[Inference] done in 100 s
[Create submission.csv] start
[Create submission.csv] done in 0 s


In [6]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain holder for the experiment settings.

    Every constructor argument is stored unchanged on the instance under an
    attribute of the same name.
    """

    def __init__(
        self,
        lr,
        warmup,
        accum_steps,
        epochs,
        seed,
        expname,
        head_tail,
        head,
        freeze,
        question_weight,
        answer_weight,
        fold,
        train,
        cv,
        test,
    ):
        # Optimisation schedule
        self.lr, self.warmup = lr, warmup
        self.accum_steps, self.epochs = accum_steps, epochs
        # Run identification
        self.seed, self.expname = seed, expname
        # Token trimming
        self.head_tail, self.head = head_tail, head
        # Model options
        self.freeze = freeze
        self.question_weight, self.answer_weight = question_weight, answer_weight
        # Number of folds and which pipeline stages to execute
        self.fold = fold
        self.train, self.cv, self.test = train, cv, test

# Settings for this run: only the inference stage is enabled
# (train=False, cv=False, test=True).
config = PipeLineConfig(lr=5e-5, warmup=0.1, accum_steps=1, epochs=6,
                        seed=42, expname='cased_1', head_tail=True, head=0.5, freeze=False,
                        question_weight=0., answer_weight=0., fold=5, train=False, cv=False, test=True)

# DEBUG shortens the run: one training epoch, and inference stops after the first fold.
DEBUG = False
# Name of the row-identifier column written to folds.csv.
ID = 'qa_id'
# The 30 target columns; their order fixes the column order of all prediction arrays.
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice',
               'question_type_compare', 'question_type_consequence',
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful',
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']
NUM_FOLDS = config.fold
# Directory holding train.csv, test.csv and sample_submission.csv.
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'
SEED = config.seed
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory the per-fold checkpoints (and folds.csv) are read from.
MODEL_DIR = '../input/googlequestchallenge-weights4/'
#MODEL_DIR = './'
# False: the model is fed separate question and answer inputs in addition to the combined one.
COMBINE_INPUT = False
# Token budgets for title, question body and answer.
T_MAX_LEN = 30
Q_MAX_LEN = 479 # 382
A_MAX_LEN = 510 # 254 
# Combined layout: the three budgets plus 4 special tokens.
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4
# Question-only and answer-only layouts; the +3 / +2 are presumably their
# special tokens (the builders are outside this view; confirm).
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3
a_max_sequence_length = A_MAX_LEN + 2

#===========================================================
# Model
#===========================================================
def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    
    # 350+128+30 = 508 +4 = 512
    
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))
        # Head+Tail method 
        q_len_head = round(q_new_len * config.head)
        q_len_tail = -1 * (q_new_len - q_len_head)
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        t_len_head = round(t_new_len * config.head)
        t_len_tail = -1 * (t_new_len - t_len_head)  
        #t = t[:t_new_len]
        if config.head_tail :
            q = q[:q_len_head]+q[q_len_tail:]
            a = a[:a_len_head]+a[a_len_tail:]
            #t = t[:t_len_head]+t[t_len_tail:]
            t = t[:t_new_len]
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]
    
    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenize title and question and trim them for the question-side sequence.

    Nothing is trimmed when both token lists plus 3 reserved positions fit in
    ``q_max_sequence_length``.  Otherwise the title and the question get the
    budgets ``t_max_len`` / ``q_max_len``; budget one of them does not need is
    handed to the other.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list`` of tokens.
        title, question: raw text.
        q_max_sequence_length: total length including the 3 reserved positions.
        t_max_len, q_max_len: per-part token budgets.

    Returns:
        Tuple ``(t, q)`` of token lists.

    When trimming happens and ``config.head_tail`` is truthy, both parts keep
    a head of ``round(budget * config.head)`` tokens plus the remaining budget
    taken from the end.  ``config`` is the notebook-level configuration object.
    """

    def _keep_head_tail(tokens, keep_len):
        # Keep the first `head_len` tokens and the last `tail_len` tokens.
        head_len = round(keep_len * config.head)
        tail_len = keep_len - head_len
        # `tokens[-0:]` is the *whole* list, so an empty tail must be explicit;
        # otherwise the result would be longer than `keep_len`.
        tail = tokens[-tail_len:] if tail_len > 0 else []
        return tokens[:head_len] + tail

    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len = len(t)
    q_len = len(q)

    if (t_len + q_len + 3) > q_max_sequence_length:

        # A short title donates its unused budget to the question.
        if t_max_len > t_len:
            t_new_len = t_len
            q_max_len = q_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        # A short question donates its unused budget to the title.
        if q_max_len > q_len:
            q_new_len = q_len
            t_new_len = t_max_len + (q_max_len - q_len)
        else:
            q_new_len = q_max_len

        if config.head_tail:
            # Head+Tail method
            q = _keep_head_tail(q, q_new_len)
            t = _keep_head_tail(t, t_new_len)
        else:
            # No Head+Tail, usual processing
            q = q[:q_new_len]
            t = t[:t_new_len]

    return t, q


def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):
    """Tokenize the answer and trim it for the answer-side sequence.

    Nothing is trimmed when the tokens plus 2 reserved positions fit in
    ``a_max_sequence_length``.  Otherwise ``a_max_len`` tokens are kept.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list`` of tokens.
        answer: raw answer text.
        a_max_sequence_length: total length including the 2 reserved positions.
        a_max_len: number of answer tokens to keep when trimming.

    Returns:
        List of answer tokens.

    When trimming happens and ``config.head_tail`` is truthy, a head of
    ``round(a_max_len * config.head)`` tokens is kept plus the remaining budget
    taken from the end; otherwise the first ``a_max_len`` tokens are kept.
    ``config`` is the notebook-level configuration object.
    """
    a = tokenizer.tokenize(answer)

    if (len(a) + 2) > a_max_sequence_length:

        a_new_len = a_max_len

        if config.head_tail:
            # Head+Tail method
            head_len = round(a_new_len * config.head)
            tail_len = a_new_len - head_len
            # `a[-0:]` is the *whole* list, so an empty tail must be explicit;
            # otherwise the result would be longer than `a_new_len`.
            tail = a[-tail_len:] if tail_len > 0 else []
            a = a[:head_len] + tail
        else:
            # No Head+Tail, usual processing
            a = a[:a_new_len]

    return a


def _convert_to_bert_inputs(title, question, answer, tokenizer, max_sequence_length):
    """Converts tokenized input to ids, masks and segments for BERT.

    Args:
        title, question, answer: token lists (already trimmed by the caller).
        tokenizer: passed through to ``_get_ids``.
        max_sequence_length: padded length of the combined sequence; only used
            when ``COMBINE_INPUT`` is truthy.

    Returns:
        * ``COMBINE_INPUT`` truthy: ``[input_ids, input_masks, input_segments]``
          for ``[CLS] title [QBODY] question [ANS] answer [SEP]``.
        * otherwise: ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments]``
          for ``[CLS] title [SEP] question [SEP]`` (length
          ``T_MAX_LEN+Q_MAX_LEN+3``) and ``[CLS] answer [SEP]`` (length
          ``A_MAX_LEN+2``).
    """
    if COMBINE_INPUT:
        # Single sequence with dedicated separator tokens between the parts.
        stoken = ["[CLS]"] + title + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]

        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]

    # Two separate sequences: (title + question) and (answer).
    q_seq_len = T_MAX_LEN + Q_MAX_LEN + 3
    a_seq_len = A_MAX_LEN + 2

    # The closing token used to be spelled "[SEP" (missing bracket), which is
    # not the special separator token.
    q_token = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP]"]
    q_input_ids = _get_ids(q_token, tokenizer, q_seq_len)
    q_input_masks = _get_masks(q_token, q_seq_len)
    q_input_segments = _get_segments(q_token, q_seq_len)

    a_token = ["[CLS]"] + answer + ["[SEP]"]
    a_input_ids = _get_ids(a_token, tokenizer, a_seq_len)
    a_input_masks = _get_masks(a_token, a_seq_len)
    a_input_segments = _get_segments(a_token, a_seq_len)

    return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features,
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Build the list of model input tensors for every row of ``df[columns]``.

    Each row's ``question_title``, ``question_body`` and ``answer`` are trimmed
    and converted with ``_convert_to_bert_inputs``.  The per-row results are
    stacked into ``int64`` tensors, followed by ``num_features`` as a float
    tensor and ``cat_features`` as an ``int64`` tensor.

    Returns:
        * ``COMBINE_INPUT`` truthy: ``[ids, masks, segments, num, cat]``
        * otherwise: ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, num, cat]``
    """

    def _to_long(rows):
        # Same conversion chain for every integer input: int32 array -> int64 tensor.
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    encoded_rows = []
    if COMBINE_INPUT:
        n_streams = 3
        for _, row in df[columns].iterrows():
            title, body, answer = row.question_title, row.question_body, row.answer
            title, body, answer = _trim_input(
                tokenizer, title, body, answer, max_sequence_length, t_max_len, q_max_len, a_max_len
            )
            encoded_rows.append(
                _convert_to_bert_inputs(title, body, answer, tokenizer, max_sequence_length)
            )
    else:
        n_streams = 6
        for _, row in df[columns].iterrows():
            title, body, answer = row.question_title, row.question_body, row.answer
            # NOTE(review): q_max_sequence_length / a_max_sequence_length are not
            # parameters of this function; presumably notebook-level globals - confirm.
            title, body = q_trim_input(tokenizer, title, body, q_max_sequence_length, t_max_len, q_max_len)
            answer = a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len)
            encoded_rows.append(
                _convert_to_bert_inputs(title, body, answer, tokenizer, max_sequence_length)
            )

    # Regroup row-wise results into one list per input stream, then tensorise.
    tensors = [
        _to_long([encoded[k] for encoded in encoded_rows]) for k in range(n_streams)
    ]
    tensors.append(torch.from_numpy(np.asarray(num_features, dtype=np.float32)).float())
    tensors.append(_to_long(cat_features))
    return tensors


def compute_output_arrays(df, columns):
    """Return the ``columns`` of ``df`` as a NumPy array."""
    selected = df[columns]
    return np.asarray(selected)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the combined-input tensors.

        ``inputs`` is indexed as ``inputs[k][idx]`` for ``k`` in 0..4
        (ids, masks, segments, numeric features, categorical features).
        Each item is that 5-tuple followed by the label (only when ``labels``
        was given) and finally the entry of ``lengths`` for the row.
        """

        N_INPUTS = 5  # number of per-row input streams read from `inputs`

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels  # stays None for unlabeled (test) data
            self.lengths = lengths

        def __getitem__(self, idx):
            row = tuple(self.inputs[k][idx] for k in range(self.N_INPUTS))
            row_length = self.lengths[idx]
            if self.labels is None:
                return row + (row_length,)
            # targets go between the features and the length
            return row + (self.labels[idx], row_length)

        def __len__(self):
            first_stream = self.inputs[0]
            return len(first_stream)


    class CustomBert(BertPreTrainedModel):
        """BERT encoder with categorical embeddings and numeric features.

        The second element of the BERT outputs (``pooled_output``) is
        concatenated with the numeric features and the embedded categorical
        features, then fed to one linear layer that produces
        ``config.num_labels`` logits.

        Args:
            config: BERT configuration (``hidden_size`` and ``num_labels`` are read).
            cat_dims: iterable of ``(num_embeddings, embedding_dim)`` pairs, one
                per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            # One embedding table per categorical column.
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            # +4: the head is sized for exactly 4 numeric features.
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss), logits, (hidden_states), (attentions)``.

            ``cat_features`` is indexed as ``cat_features[:, j]`` for the j-th
            embedding table.  ``loss`` is only prepended when ``labels`` is
            given: MSE when ``num_labels == 1``, cross-entropy otherwise.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # Loss classes are taken from `nn`; the bare names MSELoss /
                # CrossEntropyLoss are not imported by this notebook's header.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the split (question / answer) input tensors.

        ``inputs`` is indexed as ``inputs[k][idx]`` for ``k`` in 0..7
        (question ids, masks, segments; answer ids, masks, segments; numeric
        features; categorical features).  Each item is that 8-tuple followed
        by the label (only when ``labels`` was given) and finally the entry of
        ``lengths`` for the row.
        """

        N_INPUTS = 8  # number of per-row input streams read from `inputs`

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels  # stays None for unlabeled (test) data
            self.lengths = lengths

        def __getitem__(self, idx):
            row = tuple(self.inputs[k][idx] for k in range(self.N_INPUTS))
            row_length = self.lengths[idx]
            if self.labels is None:
                return row + (row_length,)
            # targets go between the features and the length
            return row + (self.labels[idx], row_length)

        def __len__(self):
            first_stream = self.inputs[0]
            return len(first_stream)


    class CustomBert(BertPreTrainedModel):
        """Shared BERT encoder applied separately to question and answer inputs.

        The same ``self.bert`` encodes the question sequence and the answer
        sequence.  Both ``pooled_output`` vectors (second element of the BERT
        outputs) are concatenated with the numeric features and the embedded
        categorical features, then fed to one linear layer that produces
        ``config.num_labels`` logits.

        Args:
            config: BERT configuration (``hidden_size`` and ``num_labels`` are read).
            cat_dims: iterable of ``(num_embeddings, embedding_dim)`` pairs, one
                per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            # One embedding table per categorical column.
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.1)
            n_emb_out = sum([y for x, y in cat_dims])
            self.num_drop = nn.Dropout(0.1)
            self.q_dropout = nn.Dropout(0.1)
            self.a_dropout = nn.Dropout(0.1)
            # hidden_size*2: question + answer pooled vectors; +4: exactly 4 numeric features.
            self.classifier_final = nn.Linear(config.hidden_size*2+n_emb_out+4, self.config.num_labels)
            self.init_weights()

        def forward(
            self,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss), logits, (hidden_states), (attentions)``.

            ``position_ids``, ``head_mask`` and ``inputs_embeds`` are forwarded
            unchanged to both encoder calls.  ``cat_features`` is indexed as
            ``cat_features[:, j]`` for the j-th embedding table.  ``loss`` is
            only prepended when ``labels`` is given: MSE when
            ``num_labels == 1``, cross-entropy otherwise.
            """

            # Question side (title + question sequence).
            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            q_pooled_output = q_outputs[1]
            q_pooled_output = self.q_dropout(q_pooled_output)

            # Answer side, encoded with the same BERT weights.
            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            a_pooled_output = a_outputs[1]
            a_pooled_output = self.a_dropout(a_pooled_output)

            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            num_features = self.num_drop(num_features)

            pooled_output = torch.cat([q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # Loss classes are taken from `nn`; the bare names MSELoss /
                # CrossEntropyLoss are not imported by this notebook's header.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one pass over ``train_loader`` with gradient accumulation.

    Gradients are accumulated on every batch; the optimizer and the scheduler
    step (and gradients are reset) each time ``config.accum_steps`` batches
    have been processed.

    Returns:
        A 6-tuple.  The first value is the sum over batches of
        ``loss / (len(train_loader) * config.accum_steps)``; the other five
        are never updated here and are always ``0.``.
    """
    model.train()

    running_loss = 0.
    # Placeholders: returned untouched so the result keeps its six values.
    aux_1 = aux_2 = aux_3 = aux_4 = aux_5 = 0.

    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        if COMBINE_INPUT:
            # Last batch element (lengths) is not needed for training.
            ids, masks, segments, nums, cats, labels, _ = batch
            ids, masks, segments, nums, cats, labels = (
                tensor.to(device) for tensor in (ids, masks, segments, nums, cats, labels)
            )
            output_train = model(
                input_ids=ids.long(),
                labels=None,
                attention_mask=masks,
                token_type_ids=segments,
                num_features=nums,
                cat_features=cats,
            )
        else:
            # Last batch element (lengths) is not needed for training.
            q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, nums, cats, labels, _ = batch
            q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, nums, cats, labels = (
                tensor.to(device)
                for tensor in (q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, nums, cats, labels)
            )
            output_train = model(
                q_input_ids=q_ids.long(),
                labels=None,
                q_attention_mask=q_masks,
                q_token_type_ids=q_segments,
                a_input_ids=a_ids.long(),
                a_attention_mask=a_masks,
                a_token_type_ids=a_segments,
                num_features=nums,
                cat_features=cats,
            )

        batch_logits = output_train[0]  # output preds
        loss = criterion(batch_logits, labels)
        loss.backward()

        accumulation_complete = (step + 1) % config.accum_steps == 0
        if accumulation_complete:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        running_loss += loss.item() / (len(train_loader)*config.accum_steps)

        # Drop references to the device tensors of this batch.
        if COMBINE_INPUT:
            del ids, masks, segments, labels
        else:
            del q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, labels

    torch.cuda.empty_cache()
    gc.collect()
    return running_loss, aux_1, aux_2, aux_3, aux_4, aux_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader`` and score it with Spearman's rho.

    Args:
        model: model called with the same keyword arguments as in training.
        criterion: callable ``criterion(logits, labels)`` returning a scalar tensor.
        val_loader: iterable of labeled batches (last element of each batch is ignored).
        val_shape: number of validation rows; sizes the prediction buffers.
        batch_size: kept for backward compatibility and no longer used - rows
            are placed with a running offset taken from the actual batch
            sizes, so a value that disagrees with the loader cannot misplace
            predictions any more.

    Returns:
        ``(avg_val_loss, score)`` where ``avg_val_loss`` is the mean batch loss
        and ``score`` is the mean over ``target_cols`` of the Spearman
        correlation between labels and sigmoid(logits), with NaN correlations
        counted as 0.
    """
    avg_val_loss = 0.
    model.eval() # eval mode

    n_targets = len(target_cols)
    valid_preds = np.zeros((val_shape, n_targets))
    original = np.zeros((val_shape, n_targets))

    offset = 0  # index of the first row of the current batch
    with torch.no_grad():

        for batch in val_loader:
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(input_ids = input_ids.long(),
                               labels = None,
                               attention_mask = input_masks,
                               token_type_ids = input_segments,
                               num_features = num_features,
                               cat_features = cat_features,
                              )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
            logits = output_val[0] #output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # Explicit (rows, targets) reshape instead of squeeze(): squeeze()
            # would also drop the target axis when there is a single target.
            batch_preds = logits.detach().cpu().numpy()
            n_rows = batch_preds.shape[0]
            valid_preds[offset:offset + n_rows] = batch_preds.reshape(n_rows, -1)
            original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

    # One Spearman computation per target column, reused for print, log and score.
    column_results = [spearmanr(original[:, i], preds[:, i]) for i in range(preds.shape[1])]

    rho_val = np.mean([result.correlation for result in column_results])
    print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

    score = 0
    for i in range(len(target_cols)):
        logger.info(f"{i}, {column_results[i]}")
        # NaN correlations are counted as 0 in the score.
        score += np.nan_to_num(column_results[i].correlation)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labeled loader and return predictions with their labels.

    Args:
        model: model called with the same keyword arguments as in training.
        val_loader: iterable of labeled batches (last element of each batch is ignored).
        val_length: number of rows; sizes the output buffers.
        batch_size: kept for backward compatibility and no longer used - rows
            are placed with a running offset taken from the actual batch
            sizes, so a value that disagrees with the loader cannot misplace
            predictions any more.

    Returns:
        ``(output, original)``: sigmoid of the logits and the labels, both
        arrays of shape ``(val_length, len(target_cols))``.
    """
    n_targets = len(target_cols)
    val_preds = np.zeros((val_length, n_targets))
    original = np.zeros((val_length, n_targets))

    model.eval()
    offset = 0  # index of the first row of the current batch
    tk0 = tqdm(val_loader)
    for batch in tk0:
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
            with torch.no_grad():
                outputs = model(input_ids = input_ids.long(),
                            labels = None,
                            attention_mask = input_masks,
                            token_type_ids = input_segments,
                            num_features = num_features,
                            cat_features = cat_features,
                            )
        else:
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
            with torch.no_grad():
                outputs = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )

        # Explicit (rows, targets) reshape instead of squeeze(): squeeze()
        # would also drop the target axis when there is a single target.
        batch_preds = outputs[0].detach().cpu().numpy()
        n_rows = batch_preds.shape[0]
        val_preds[offset:offset + n_rows] = batch_preds.reshape(n_rows, -1)
        original[offset:offset + n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabeled loader and return sigmoid-activated outputs.

    Args:
        model: model called with the same keyword arguments as in training.
        test_loader: iterable of batches; elements 0..4 (combined input) or
            0..7 (split input) of each batch are passed to the model.
        test_length: number of rows; sizes the output buffer.
        batch_size: kept for backward compatibility and no longer used - rows
            are placed with a running offset taken from the actual batch
            sizes, so a value that disagrees with the loader cannot misplace
            predictions any more.

    Returns:
        Array of shape ``(test_length, len(target_cols))`` holding sigmoid(logits).
    """
    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0  # index of the first row of the current batch
    tk0 = tqdm(test_loader)
    for x_batch in tk0:
        if COMBINE_INPUT:
            with torch.no_grad():
                outputs = model(input_ids = x_batch[0].to(device),
                            labels = None,
                            attention_mask = x_batch[1].to(device),
                            token_type_ids = x_batch[2].to(device),
                            num_features = x_batch[3].to(device),
                            cat_features = x_batch[4].to(device),
                           )
        else:
            with torch.no_grad():
                outputs = model(q_input_ids = x_batch[0].to(device),
                            labels = None,
                            q_attention_mask = x_batch[1].to(device),
                            q_token_type_ids = x_batch[2].to(device),
                            a_input_ids = x_batch[3].to(device),
                            a_attention_mask = x_batch[4].to(device),
                            a_token_type_ids = x_batch[5].to(device),
                            num_features = x_batch[6].to(device),
                            cat_features = x_batch[7].to(device),
                           )

        # Explicit (rows, targets) reshape instead of squeeze(): squeeze()
        # would also drop the target axis when there is a single target.
        batch_preds = outputs[0].detach().cpu().numpy()
        n_rows = batch_preds.shape[0]
        test_preds[offset:offset + n_rows] = batch_preds.reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add categorical helper columns and log-scaled word-count features.

    The frame is modified in place and also returned.  Columns written:

    * ``netloc``: host of ``url`` up to (not including) the first dot.
    * ``qa_same_user_page_flag``: 1 when ``question_user_page`` equals
      ``answer_user_page``, else 0.
    * ``question_title_num_words``, ``question_body_num_words``,
      ``answer_num_words``: ``log1p`` of the number of whitespace-separated
      tokens.
    * ``question_vs_answer_length``: ``log1p`` of the raw question-body /
      answer word-count ratio.

    NOTE(review): an answer without any non-whitespace token gives a zero
    denominator for the ratio; that case is not handled here.
    """
    # Everything before the first "." - the pattern always matches (possibly "").
    host_prefix = re.compile(r"^[^.]*")
    # Raw string: "\S" in a plain string literal is an invalid escape sequence.
    word_pattern = r'\S+'

    df['netloc'] = df['url'].apply(lambda x: host_prefix.findall(urlparse(x).netloc)[0])
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1

    # Raw word counts first; the ratio below must use them before log scaling.
    df['question_title_num_words'] = df['question_title'].str.count(word_pattern)
    df['question_body_num_words'] = df['question_body'].str.count(word_pattern)
    df['answer_num_words'] = df['answer'].str.count(word_pattern)
    df['question_vs_answer_length'] = df['question_body_num_words']/df['answer_num_words']

    for col in ('question_title_num_words', 'question_body_num_words',
                'answer_num_words', 'question_vs_answer_length'):
        df[col] = np.log1p(df[col])
    return df


def custom_loss(logits, labels):
    """Binary cross-entropy with logits over all targets at once.

    A fresh ``nn.BCEWithLogitsLoss`` with default settings is built on every
    call and applied to ``(logits, labels)``; the resulting scalar tensor is
    returned.
    """
    bce_with_logits = nn.BCEWithLogitsLoss()
    return bce_with_logits(logits, labels)


#===========================================================
# main
#===========================================================
#def main():
if True:
    # Pipeline driver. Stages are switched by config.train / config.cv /
    # config.test; every stage is wrapped in timer() so its duration is logged.

    with timer('Data Loading'):
        # Missing values are replaced by the literal string "none".
        train = pd.read_csv(f"{ROOT}train.csv").fillna("none")
        y_train = train[target_cols].values
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv").fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")

    with timer('Num features'):
        train = add_features(train)
        if config.test:
            test = add_features(test)
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
        train_num = train[num_features].values
        if config.test:
            test_num = test[num_features].values

    with timer('Cat features'):
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        # Disabled one-hot variant, kept for reference:
        # ce_ohe = ce.OneHotEncoder(cols=features, handle_unknown='impute')
        # ce_ohe.fit(train[features])
        # #train_ohe = ce_ohe.transform(train[features]).values
        # #test_ohe = ce_ohe.transform(test[features]).values
        # train_ohe = pd.concat([_train_ohe, ce_ohe.transform(train[features])], axis=1).values
        # test_ohe = pd.concat([_test_ohe, ce_ohe.transform(test[features])], axis=1).values
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])
        train_cat_df = ce_oe.transform(train[cat_features])
        train_cat = train_cat_df.values
        if config.test:
            # `test` only exists when config.test is set; previously this ran
            # unconditionally and raised NameError in train/cv-only runs.
            # Categories unseen in train come back as NaN and are mapped to 0.
            test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
            test_cat = test_cat_df.values
        #cat_df = pd.concat([train_cat_df, test_cat_df])
        # One (num_embeddings, embedding_dim) pair per categorical column.
        cat_dims = []
        for col in cat_features:
            #print(cat_df[col].unique())
            #dim = cat_df[col].nunique()
            dim = train[col].nunique()
            #cat_dims.append((dim, dim//2+1))
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    if config.train:
        with timer('Create folds'):
            folds = train.copy()

            # NOTE(review): the MultilabelStratifiedKFold import appears to be
            # commented out in the notebook's first cell - confirm it is in
            # scope before enabling config.train. Also confirm that passing
            # random_state without shuffle=True is accepted by the installed version.
            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                folds.loc[val_index, 'fold'] = int(fold)
            # Disabled alternative:
            # less gap between CV vs LB with GroupKFold
            # https://www.kaggle.com/ratthachat/quest-cv-analysis-on-different-splitting-methods
            # kf = GroupKFold(n_splits=NUM_FOLDS)
            # for fold, (train_index, val_index) in enumerate(kf.split(X=train.question_body, groups=train.question_body)):
            #     folds.loc[val_index, 'fold'] = int(fold)
            folds['fold'] = folds['fold'].astype(int)
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    with timer('Prepare Bert config'):
        # NOTE(review): a *cased* vocab is loaded with do_lower_case=True, and
        # `do_lower_case` below is True for any name containing 'cased'
        # (including 'uncased'). Left unchanged because it determines the
        # inputs the saved checkpoints expect - confirm intent.
        tokenizer = BertTokenizer.from_pretrained("../input/pretrained-bert-models-for-pytorch/bert-base-cased-vocab.txt", 
                                                  do_lower_case=True)
        input_categories = ['question_title', 'question_body', 'answer']
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-cased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        bert_config.num_labels = len(target_cols)
        bert_model = 'bert-base-cased'
        do_lower_case = 'cased' in bert_model
        output_model_file = 'bert_pytorch.bin'

    if config.train:

        BATCH_SIZE = 8
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):

            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")

                logger.info("Preparing train datasets....")

                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                # Position of the first 0 id per row; rows without any 0
                # (argmax returns 0) are treated as full length.
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]

                logger.info("Preparing valid datasets....")

                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]

                logger.info("Preparing Dataloaders Datasets....")

                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)

                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-cased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()

                i = 0
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                param_optimizer = list(model.named_parameters())
                # Biases and LayerNorm parameters are excluded from weight decay.
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=4e-5)
                #criterion = nn.BCEWithLogitsLoss()
                criterion = custom_loss
                # NOTE(review): config.warmup (0.1 in this notebook) is passed as
                # a *step count*; if it was meant as a fraction of training it
                # should be multiplied by the total step count. Left unchanged
                # so the training recipe is not altered silently - confirm.
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")

                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()

                    start_time   = time.time()
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    # Only the best-score weights are written to disk.
                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

                # Release this fold's objects before the next fold allocates its
                # own. This cleanup used to sit after the fold loop, so the old
                # optimiser state stayed alive while the next model was loaded.
                del train_df, val_df, model, optimizer, criterion, scheduler
                del valid_loader, train_loader, valid_set, train_set
                torch.cuda.empty_cache()
                gc.collect()

    if config.cv:

        with timer('CV'):

            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            results = np.zeros((len(train), len(target_cols)))
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):

                #train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index
                #train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                val_df = train.iloc[val_index]

                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False, drop_last=False)

                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-cased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                # Out-of-fold outputs are written back at the rows' original positions.
                result, logit = predict_valid_result(model, valid_loader, len(val_df))  
                results[val_index, :] = result
                logits[val_index, :] = logit 

            # Mean over target columns of the per-column Spearman correlation.
            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')

            oof = pd.DataFrame()
            for i, col in enumerate(target_cols):
                oof[col] = results[:,i]
            oof.to_csv('oof.csv', index=False)

    if config.test:

        with timer('Inference'):

            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)
            result = np.zeros((len(test), len(target_cols)))

            # Count the fold models that actually contributed: with DEBUG the
            # loop stops after one fold, and dividing by NUM_FOLDS would then
            # scale the predictions down.
            n_models = 0
            for fold in range(NUM_FOLDS):
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-cased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result += predict_result(model, test_loader, len(test)) 
                n_models += 1
                if DEBUG:
                    break

            result /= max(n_models, 1)

        with timer('Create submission.csv'):
            # Fills every column from the first target column onwards.
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission4.csv', index=False)

[Data Loading] start
[Data Loading] done in 0 s
[Num features] start
[Num features] done in 1 s
[Cat features] start
[Cat features] done in 0 s
[Prepare Bert config] start
[Prepare Bert config] done in 0 s
[Inference] start


[(60, 30), (6, 3), (3, 2)]


15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
[Inference] done in 106 s
[Create submission.csv] start
[Create submission.csv] done in 0 s


In [7]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain container for the run's hyper-parameters and stage switches.

    Every constructor argument is stored unchanged on the instance under
    the same name; no validation or conversion is performed.
    """

    def __init__(self, lr, warmup, accum_steps, epochs, seed, expname, 
                 head_tail, head, freeze, question_weight, answer_weight, fold, train, cv, test):
        # Optimisation schedule.
        self.lr, self.warmup = lr, warmup
        self.accum_steps, self.epochs = accum_steps, epochs
        # Reproducibility and experiment naming.
        self.seed, self.expname = seed, expname
        # Text truncation (head+tail) and encoder freezing.
        self.head_tail, self.head, self.freeze = head_tail, head, freeze
        # Loss weights for the question / answer targets.
        self.question_weight, self.answer_weight = question_weight, answer_weight
        # Number of folds and which pipeline stages to run.
        self.fold = fold
        self.train, self.cv, self.test = train, cv, test

# Run configuration for this notebook: inference only (train=False, cv=False,
# test=True). `expname` is part of the checkpoint file names built by the
# driver cell ('best_param_score_{expname}_{fold}.pt').
# NOTE(review): `warmup` is handed to the LR scheduler as `num_warmup_steps`
# by the driver cell, so 0.1 is read as a step count there - confirm whether
# a fraction of total steps was intended.
config = PipeLineConfig(lr=1e-4, warmup=0.1, accum_steps=1, epochs=6,
                        seed=42, expname='uncased_8', head_tail=True, head=0.5, freeze=False,
                        question_weight=0., answer_weight=0., fold=5, train=False, cv=False, test=True)

# When True the driver trains for a single epoch and predicts with the first
# fold's model only.
DEBUG = False
# Name of the row-id column written to folds.csv.
ID = 'qa_id'
# Target columns, in the order used for labels and model outputs.
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice',
               'question_type_compare', 'question_type_consequence',
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful',
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']
NUM_FOLDS = config.fold
# Directory holding train.csv / test.csv / sample_submission.csv.
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'
SEED = config.seed
# seed_everything is not defined in this cell - presumably it seeds the RNGs
# in use; verify against its definition.
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory from which folds.csv and the per-fold checkpoints are loaded.
MODEL_DIR = '../input/googlequestchallenge-weights5/'
#MODEL_DIR = './'
# False selects the two-input layout (title+question and title+answer encoded
# separately); True selects the single concatenated sequence.
COMBINE_INPUT = False
# Per-field token budgets used when trimming.
T_MAX_LEN = 30
Q_MAX_LEN = 479 # 382
A_MAX_LEN = 479 # 254 
# Combined layout: title + question + answer + 4 special tokens (= 992).
# NOTE(review): this is only consumed when COMBINE_INPUT is True; confirm it
# fits the encoder's maximum position count before enabling that mode.
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4
# Two-input layout: title + body + 3 special tokens (= 512 each).
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3
a_max_sequence_length = T_MAX_LEN + A_MAX_LEN + 3

#===========================================================
# Model
#===========================================================
def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    
    # 350+128+30 = 508 +4 = 512
    
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))
        # Head+Tail method 
        q_len_head = round(q_new_len * config.head)
        q_len_tail = -1 * (q_new_len - q_len_head)
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        t_len_head = round(t_new_len * config.head)
        t_len_tail = -1 * (t_new_len - t_len_head)  
        #t = t[:t_new_len]
        if config.head_tail :
            q = q[:q_len_head]+q[q_len_tail:]
            a = a[:a_len_head]+a[a_len_tail:]
            #t = t[:t_len_head]+t[t_len_tail:]
            t = t[:t_new_len]
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]
    
    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenize a title/question pair and trim it to fit one sequence.

    The pair is later wrapped in 3 special tokens, hence the ``+3`` in the
    length check. When the pair already fits, it is returned untouched.
    Otherwise budget left unused by a short title is given to the question
    and vice versa.

    With ``config.head_tail`` set, each field keeps the first
    ``round(budget * config.head)`` tokens plus the last remaining ones;
    otherwise each field is cut from the front.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list``.
        title, question: raw strings.
        q_max_sequence_length: total length including the 3 special tokens.
        t_max_len, q_max_len: default per-field budgets.

    Returns:
        ``(title_tokens, question_tokens)``.
    """
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len = len(t)
    q_len = len(q)

    if (t_len+q_len+3) > q_max_sequence_length:

        if t_max_len > t_len:
            t_new_len = t_len
            q_max_len = q_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        if q_max_len > q_len:
            q_new_len = q_len
            t_new_len = t_max_len + (q_max_len - q_len)
        else:
            q_new_len = q_max_len

        if config.head_tail:
            # Head+Tail method
            q_head = round(q_new_len * config.head)
            q_tail = q_new_len - q_head
            t_head = round(t_new_len * config.head)
            t_tail = t_new_len - t_head
            # A zero-length tail must contribute nothing: x[-0:] is the whole
            # list, which used to duplicate tokens (e.g. config.head == 1.0).
            q = q[:q_head] + (q[-q_tail:] if q_tail > 0 else [])
            t = t[:t_head] + (t[-t_tail:] if t_tail > 0 else [])
        else:
            # No Head+Tail , usual processing
            q = q[:q_new_len]
            t = t[:t_new_len]

    return t, q

"""
def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):

    a = tokenizer.tokenize(answer)

    a_len = len(a)

    if (a_len+2) > a_max_sequence_length:

        a_new_len = a_max_len

        # Head+Tail method
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        if config.head_tail :
            a = a[:a_len_head]+a[a_len_tail:]
        else:
            # No Head+Tail , usual processing
            a = a[:a_new_len]

    return a
"""

def a_trim_input(tokenizer, title, answer, a_max_sequence_length, t_max_len, a_max_len):
    """Tokenize a title/answer pair and trim it to fit one sequence.

    The pair is later wrapped in 3 special tokens, hence the ``+3`` in the
    length check. When the pair already fits, it is returned untouched.
    Otherwise budget left unused by a short title is given to the answer
    and vice versa.

    With ``config.head_tail`` set, each field keeps the first
    ``round(budget * config.head)`` tokens plus the last remaining ones;
    otherwise each field is cut from the front.

    Args:
        tokenizer: object exposing ``tokenize(str) -> list``.
        title, answer: raw strings.
        a_max_sequence_length: total length including the 3 special tokens.
        t_max_len, a_max_len: default per-field budgets.

    Returns:
        ``(title_tokens, answer_tokens)``.
    """
    t = tokenizer.tokenize(title)
    a = tokenizer.tokenize(answer)

    t_len = len(t)
    a_len = len(a)

    if (t_len+a_len+3) > a_max_sequence_length:

        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        if a_max_len > a_len:
            a_new_len = a_len
            t_new_len = t_max_len + (a_max_len - a_len)
        else:
            a_new_len = a_max_len

        if config.head_tail:
            # Head+Tail method
            a_head = round(a_new_len * config.head)
            a_tail = a_new_len - a_head
            t_head = round(t_new_len * config.head)
            t_tail = t_new_len - t_head
            # A zero-length tail must contribute nothing: x[-0:] is the whole
            # list, which used to duplicate tokens (e.g. config.head == 1.0).
            a = a[:a_head] + (a[-a_tail:] if a_tail > 0 else [])
            t = t[:t_head] + (t[-t_tail:] if t_tail > 0 else [])
        else:
            # No Head+Tail , usual processing
            a = a[:a_new_len]
            t = t[:t_new_len]

    return t, a


def _convert_to_bert_inputs(title_q, title_a, question, answer, tokenizer, max_sequence_length):
    """Converts tokenized input to ids, masks and segments for BERT"""
    if COMBINE_INPUT:
        stoken = ["[CLS]"] + title + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP]"] + answer + ["[SEP]"]
        #stoken = ["[CLS]"] + title  + question  + answer + ["[SEP]"]
    
        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]
    else:
        q_token = ["[CLS]"] + title_q + ["[SEP]"] + question + ["[SEP"]
        q_input_ids = _get_ids(q_token, tokenizer, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_masks = _get_masks(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        q_input_segments = _get_segments(q_token, T_MAX_LEN+Q_MAX_LEN+3)
        
        #a_token = ["[CLS]"] + answer + ["[SEP]"]
        #a_input_ids = _get_ids(a_token, tokenizer, A_MAX_LEN+2)
        #a_input_masks = _get_masks(a_token, A_MAX_LEN+2)
        #a_input_segments = _get_segments(a_token, A_MAX_LEN+2)
        a_token = ["[CLS]"] + title_a + ["[SEP]"] + answer + ["[SEP"]
        a_input_ids = _get_ids(a_token, tokenizer, T_MAX_LEN+A_MAX_LEN+3)
        a_input_masks = _get_masks(a_token, T_MAX_LEN+A_MAX_LEN+3)
        a_input_segments = _get_segments(a_token, T_MAX_LEN+A_MAX_LEN+3)
        
        return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features, 
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Turn a dataframe of text rows into the model's input tensors.

    Each row of ``df[columns]`` must expose ``question_title``,
    ``question_body`` and ``answer``. The text is tokenized, trimmed and
    converted to ids / masks / segment ids; ``num_features`` and
    ``cat_features`` are appended as float and long tensors.

    Args:
        df: source dataframe.
        columns: columns to select from ``df`` before iterating.
        tokenizer: tokenizer handed to the trimming/conversion helpers.
        max_sequence_length: padded length for the combined layout.
        num_features: array-like of numeric features, one row per sample.
        cat_features: array-like of integer category codes, one row per sample.
        t_max_len, q_max_len, a_max_len: per-field token budgets.

    Returns:
        Combined layout (``COMBINE_INPUT`` true): ``[ids, masks, segments,
        num_features, cat_features]``.
        Two-input layout: ``[q_ids, q_masks, q_segments, a_ids, a_masks,
        a_segments, num_features, cat_features]``.
    """
    def _as_long(rows):
        # Same conversion as before: int32 array first, then a long tensor.
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    extra = [
        torch.from_numpy(np.asarray(num_features, dtype=np.float32)).float(),
        _as_long(cat_features),
    ]

    if COMBINE_INPUT:
        input_ids, input_masks, input_segments = [], [], []
        for instance in df[columns].itertuples(index=False):
            t, q, a = instance.question_title, instance.question_body, instance.answer
            t, q, a = _trim_input(tokenizer, t, q, a, max_sequence_length, t_max_len, q_max_len, a_max_len)
            # _convert_to_bert_inputs takes two title arguments; the call used
            # to pass only one and failed with a TypeError. The combined
            # layout has a single title, so it is supplied for both.
            ids, masks, segments = _convert_to_bert_inputs(t, t, q, a, tokenizer, max_sequence_length)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)
        return [_as_long(input_ids), _as_long(input_masks), _as_long(input_segments)] + extra
    else:
        q_input_ids, q_input_masks, q_input_segments = [], [], []
        a_input_ids, a_input_masks, a_input_segments = [], [], []
        for instance in df[columns].itertuples(index=False):
            t, q, a = instance.question_title, instance.question_body, instance.answer
            # The title is trimmed independently for each of the two inputs.
            t_q, q = q_trim_input(tokenizer, t, q, q_max_sequence_length, t_max_len, q_max_len)
            #a = a_trim_input(tokenizer, a, a_max_sequence_length, a_max_len)
            t_a, a = a_trim_input(tokenizer, t, a, a_max_sequence_length, t_max_len, a_max_len)
            q_ids, q_masks, q_segments, a_ids, a_masks, a_segments = _convert_to_bert_inputs(t_q, t_a, q, a, tokenizer, max_sequence_length)
            q_input_ids.append(q_ids)
            q_input_masks.append(q_masks)
            q_input_segments.append(q_segments)
            a_input_ids.append(a_ids)
            a_input_masks.append(a_masks)
            a_input_segments.append(a_segments)
        return [
                _as_long(q_input_ids),
                _as_long(q_input_masks),
                _as_long(q_input_segments),
                _as_long(a_input_ids),
                _as_long(a_input_masks),
                _as_long(a_input_segments),
                ] + extra


def compute_output_arrays(df, columns):
    """Return the ``columns`` of ``df`` as a NumPy array (one row per sample)."""
    selected = df[columns]
    return np.asarray(selected)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset for the combined layout.

        ``inputs`` holds five parallel indexables (ids, masks, segments,
        numeric features, categorical features). An item is the tuple of
        their ``idx``-th entries followed by ``labels[idx]`` (only when
        labels were given) and ``lengths[idx]``.
        """

        # Number of parallel input collections read per item.
        _N_FIELDS = 5

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            fields = tuple(self.inputs[k][idx] for k in range(self._N_FIELDS))
            length = self.lengths[idx]
            if self.labels is None:
                return fields + (length,)
            # With targets: labels come before the length.
            return fields + (self.labels[idx], length)

        def __len__(self):
            first_field = self.inputs[0]
            return len(first_field)


    class CustomBert(BertPreTrainedModel):
        """BERT encoder with a linear head over pooled text + tabular features.

        Combined-layout variant: one text sequence is encoded, and its pooled
        output is concatenated with the numeric features and the embedded
        categorical features before the final linear layer.
        """

        def __init__(self, config, cat_dims):
            """
            Args:
                config: BERT configuration; ``num_labels`` sets the output width.
                cat_dims: iterable of ``(num_embeddings, embedding_dim)``
                    pairs, one per categorical feature column.
            """
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            # One embedding table per categorical column.
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss), logits, (hidden_states), (attentions)``.

            ``num_features`` must have 4 columns (the head's input width is
            fixed at construction) and ``cat_features`` one integer column
            per entry of ``cat_dims``. The loss element is present only when
            ``labels`` is given.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            # Column j of cat_features indexes embedding table j.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes were referenced without the `nn.` prefix,
                # which raised NameError whenever labels were passed.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset for the two-input layout.

        ``inputs`` holds eight parallel indexables (question ids / masks /
        segments, answer ids / masks / segments, numeric features,
        categorical features). An item is the tuple of their ``idx``-th
        entries followed by ``labels[idx]`` (only when labels were given)
        and ``lengths[idx]``.
        """

        # Number of parallel input collections read per item.
        _N_FIELDS = 8

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            fields = tuple(self.inputs[k][idx] for k in range(self._N_FIELDS))
            length = self.lengths[idx]
            if self.labels is None:
                return fields + (length,)
            # With targets: labels come before the length.
            return fields + (self.labels[idx], length)

        def __len__(self):
            first_field = self.inputs[0]
            return len(first_field)


    class CustomBert(BertPreTrainedModel):
        """BERT encoder applied separately to question and answer, plus tabular features.

        The same ``BertModel`` encodes the question sequence and the answer
        sequence. Element 1 of each encoder output (used here as the pooled
        representation) is concatenated with the numeric features and the
        categorical embeddings, and a single linear layer maps that vector to
        ``config.num_labels`` logits.

        Args:
            config: BERT configuration; ``num_labels`` and ``hidden_size`` are read.
            cat_dims: iterable of ``(num_categories, embedding_dim)`` pairs, one per
                categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            # One embedding table per categorical column.
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.1)
            n_emb_out = sum([y for x, y in cat_dims])
            self.num_drop = nn.Dropout(0.1)
            self.q_dropout = nn.Dropout(0.1)
            self.a_dropout = nn.Dropout(0.1)
            # Input width: two pooled vectors + all categorical embeddings + 4 numeric features.
            self.classifier_final = nn.Linear(config.hidden_size*2+n_emb_out+4, self.config.num_labels)
            self.init_weights()

        def forward(
            self,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Compute logits (and optionally a loss) for a batch.

            ``position_ids``, ``head_mask`` and ``inputs_embeds`` are forwarded
            unchanged to both encoder calls.

            Returns:
                Tuple ``(loss), logits, (extra question outputs), (extra answer outputs)``.
                ``loss`` is present only when ``labels`` is given: MSE when
                ``num_labels == 1``, cross-entropy otherwise.
            """

            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            q_pooled_output = self.q_dropout(q_outputs[1])

            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )
            a_pooled_output = self.a_dropout(a_outputs[1])

            # Column j of cat_features is looked up in the j-th embedding table.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            num_features = self.num_drop(num_features)

            pooled_output = torch.cat([q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # Losses are referenced through ``nn``: the bare MSELoss / CrossEntropyLoss
                # names are not imported by the notebook's import cell.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one pass over ``train_loader`` with gradient accumulation.

    The batch layout depends on the global ``COMBINE_INPUT``: 7 elements
    (ids, masks, segments, numeric, categorical, labels, lengths) when combined,
    10 elements (question triple, answer triple, numeric, categorical, labels,
    lengths) otherwise. The trailing lengths element is ignored.

    The optimizer and scheduler step every ``config.accum_steps`` batches.

    Returns:
        6-tuple: the accumulated loss followed by five placeholders that are
        always ``0.`` (kept so callers can unpack six values).
    """
    model.train()
    epoch_loss = 0.
    # Placeholders: never updated, only returned.
    aux_loss_1 = 0.
    aux_loss_2 = 0.
    aux_loss_3 = 0.
    aux_loss_4 = 0.
    aux_loss_5 = 0.

    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        if COMBINE_INPUT:
            ids, masks, segments, num_features, cat_features, labels, _ = batch
            ids, masks, segments = ids.to(device), masks.to(device), segments.to(device)
            num_features, cat_features = num_features.to(device), cat_features.to(device)
            labels = labels.to(device)
            model_inputs = {
                'input_ids': ids.long(),
                'labels': None,
                'attention_mask': masks,
                'token_type_ids': segments,
                'num_features': num_features,
                'cat_features': cat_features,
            }
        else:
            (q_ids, q_masks, q_segments,
             a_ids, a_masks, a_segments,
             num_features, cat_features, labels, _) = batch
            q_ids, q_masks, q_segments = q_ids.to(device), q_masks.to(device), q_segments.to(device)
            a_ids, a_masks, a_segments = a_ids.to(device), a_masks.to(device), a_segments.to(device)
            num_features, cat_features = num_features.to(device), cat_features.to(device)
            labels = labels.to(device)
            model_inputs = {
                'q_input_ids': q_ids.long(),
                'labels': None,
                'q_attention_mask': q_masks,
                'q_token_type_ids': q_segments,
                'a_input_ids': a_ids.long(),
                'a_attention_mask': a_masks,
                'a_token_type_ids': a_segments,
                'num_features': num_features,
                'cat_features': cat_features,
            }

        logits = model(**model_inputs)[0]  # first output element holds the predictions
        batch_loss = criterion(logits, labels)
        batch_loss.backward()

        if (step + 1) % config.accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # NOTE(review): the unscaled batch loss is divided by accum_steps here, so the
        # reported value shrinks when accum_steps > 1 — kept as in the original.
        epoch_loss += batch_loss.item() / (len(train_loader)*config.accum_steps)

        # Drop references to the batch tensors before the next iteration.
        if COMBINE_INPUT:
            del ids, masks, segments, labels, model_inputs
        else:
            del q_ids, q_masks, q_segments, a_ids, a_masks, a_segments, labels, model_inputs

    torch.cuda.empty_cache()
    gc.collect()
    return epoch_loss, aux_loss_1, aux_loss_2, aux_loss_3, aux_loss_4, aux_loss_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader`` and report loss and Spearman score.

    Batches follow the layout selected by the global ``COMBINE_INPUT`` (labels
    second to last, a trailing element that is ignored).

    Args:
        model: network to evaluate; switched to eval mode here.
        criterion: callable ``criterion(logits, labels)`` returning a scalar tensor.
        val_loader: iterable yielding validation batches in dataset order.
        val_shape: total number of validation rows (rows of the result buffers).
        batch_size: kept for backward compatibility. Rows are now placed using the
            actual size of each batch, so a value that differs from the loader's
            batch size can no longer misplace predictions.

    Returns:
        ``(avg_val_loss, score)`` where ``score`` is the mean over target columns
        of the Spearman correlation with NaN treated as 0.
    """

    avg_val_loss = 0.
    model.eval() # eval mode

    valid_preds = np.zeros((val_shape, len(target_cols)))
    original = np.zeros((val_shape, len(target_cols)))

    offset = 0  # first free row of the result buffers
    with torch.no_grad():

        for batch in val_loader:
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(input_ids = input_ids.long(),
                               labels = None,
                               attention_mask = input_masks,
                               token_type_ids = input_segments,
                               num_features = num_features,
                               cat_features = cat_features,
                              )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
            logits = output_val[0] #output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # reshape (not squeeze) keeps the (rows, targets) layout even for a
            # single-row batch or a single target column.
            n_rows = logits.shape[0]
            valid_preds[offset : offset+n_rows] = logits.detach().cpu().numpy().reshape(n_rows, -1)
            original[offset : offset+n_rows]    = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

        preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

        # One Spearman computation per column, reused for the mean, the log line and the score.
        per_target = [spearmanr(original[:, i], preds[:, i]) for i in range(preds.shape[1])]

        rho_val = np.mean([res.correlation for res in per_target])
        print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

        score = 0
        for i in range(len(target_cols)):
            logger.info(f"{i}, {per_target[i]}")
            # A constant column gives a NaN correlation; count it as 0.
            score += np.nan_to_num(per_target[i].correlation)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labeled loader and return sigmoid outputs with their labels.

    Batches follow the layout selected by the global ``COMBINE_INPUT`` (labels
    second to last, a trailing element that is ignored).

    Args:
        model: network to run; switched to eval mode here.
        val_loader: iterable yielding labeled batches in dataset order.
        val_length: total number of rows the loader yields.
        batch_size: kept for backward compatibility. Rows are placed using the
            actual size of each batch, so this value no longer has to match the
            loader's batch size.

    Returns:
        ``(output, original)``: sigmoid of the logits and the labels, both numpy
        arrays of shape ``(val_length, len(target_cols))``.
    """

    val_preds = np.zeros((val_length, len(target_cols)))
    original = np.zeros((val_length, len(target_cols)))

    model.eval()
    offset = 0  # first free row of the result buffers
    tk0 = tqdm(enumerate(val_loader))
    with torch.no_grad():
        for _, batch in tk0:
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
                outputs = model(input_ids = input_ids.long(),
                            labels = None,
                            attention_mask = input_masks,
                            token_type_ids = input_segments,
                            num_features = num_features,
                            cat_features = cat_features,
                            )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
                outputs = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )

            predictions = outputs[0]
            # reshape (not squeeze) keeps the (rows, targets) layout for any batch size.
            n_rows = predictions.shape[0]
            val_preds[offset : offset+n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
            original[offset : offset+n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabeled loader and return sigmoid outputs.

    Batch elements are read positionally: indices 0-4 (ids, masks, segments,
    numeric, categorical) when the global ``COMBINE_INPUT`` is true, indices 0-7
    (question triple, answer triple, numeric, categorical) otherwise. Any
    trailing elements are ignored.

    Args:
        model: network to run; switched to eval mode here.
        test_loader: iterable yielding unlabeled batches in dataset order.
        test_length: total number of rows the loader yields.
        batch_size: kept for backward compatibility. Rows are placed using the
            actual size of each batch, so this value no longer has to match the
            loader's batch size.

    Returns:
        numpy array of shape ``(test_length, len(target_cols))`` with the sigmoid
        of the logits.
    """

    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0  # first free row of the result buffer
    tk0 = tqdm(enumerate(test_loader))
    with torch.no_grad():
        for _, x_batch in tk0:
            # Token ids are cast with .long(), as the training / validation loops do.
            if COMBINE_INPUT:
                outputs = model(input_ids = x_batch[0].to(device).long(),
                            labels = None,
                            attention_mask = x_batch[1].to(device),
                            token_type_ids = x_batch[2].to(device),
                            num_features = x_batch[3].to(device),
                            cat_features = x_batch[4].to(device),
                           )
            else:
                outputs = model(q_input_ids = x_batch[0].to(device).long(),
                            labels = None,
                            q_attention_mask = x_batch[1].to(device),
                            q_token_type_ids = x_batch[2].to(device),
                            a_input_ids = x_batch[3].to(device).long(),
                            a_attention_mask = x_batch[4].to(device),
                            a_token_type_ids = x_batch[5].to(device),
                            num_features = x_batch[6].to(device),
                            cat_features = x_batch[7].to(device),
                           )
            predictions = outputs[0]
            # reshape (not squeeze) keeps the (rows, targets) layout for any batch size.
            n_rows = predictions.shape[0]
            test_preds[offset : offset+n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add hand-crafted columns to ``df`` in place and return it.

    Columns read: ``url``, ``question_user_page``, ``answer_user_page``,
    ``question_title``, ``question_body``, ``answer``.

    Columns written:
        netloc: part of the URL host before the first dot.
        qa_same_user_page_flag: 1 when question and answer user pages are equal, else 0.
        question_title_num_words, question_body_num_words, answer_num_words:
            ``log1p`` of the count of whitespace-delimited tokens.
        question_vs_answer_length: ``log1p`` of (question body words / answer words),
            computed from the raw counts.

    Returns:
        The same DataFrame object that was passed in.
    """
    # The pattern can match the empty string, so match() always succeeds.
    first_label = re.compile(r"^[^.]*")
    df['netloc'] = df['url'].apply(lambda x: first_label.match(urlparse(x).netloc).group(0))
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1

    # Raw strings: '\S' in a normal string literal is an invalid escape sequence.
    word_pattern = r'\S+'
    df['question_title_num_words'] = df['question_title'].str.count(word_pattern)
    df['question_body_num_words'] = df['question_body'].str.count(word_pattern)
    df['answer_num_words'] = df['answer'].str.count(word_pattern)

    # The ratio must use the raw counts, so it is computed before the log transform.
    # NOTE(review): an answer with zero words gives inf / NaN here — unchanged from the original.
    df['question_vs_answer_length'] = df['question_body_num_words']/df['answer_num_words']

    for col in ('question_title_num_words', 'question_body_num_words',
                'answer_num_words', 'question_vs_answer_length'):
        df[col] = np.log1p(df[col])
    return df


def custom_loss(logits, labels):
    """Training criterion: BCE-with-logits over all outputs with default settings.

    Earlier experiments (separate question / answer terms over columns ``:21`` and
    ``21:``, skipping column 19, per-column weights) are not active; every column
    contributes through one ``nn.BCEWithLogitsLoss`` call.
    """
    bce_with_logits = nn.BCEWithLogitsLoss()
    return bce_with_logits(logits, labels)


#===========================================================
# main
#===========================================================
#def main():
if True:

    with timer('Data Loading'):
        # Every missing cell becomes the literal string "none".
        train = pd.read_csv(f"{ROOT}train.csv")
        train = train.fillna("none")

        # These three target columns are trained on a log1p scale.
        for c in ('question_not_really_a_question', 'question_type_consequence', 'question_type_spelling'):
            train[c] = np.log1p(train[c].values)

        y_train = train[target_cols].values

        # Test data and the submission template are loaded only when inference is enabled.
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv")
            test = test.fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
    
    with timer('Num features'):
        # Numeric columns produced by add_features, in the order fed to the model.
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']

        train = add_features(train)
        train_num = train[num_features].values

        # The test frame exists only when inference is enabled.
        if config.test:
            test = add_features(test)
            test_num = test[num_features].values
                
    with timer('Cat features'):
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        # Ordinal codes are fitted on train only; unseen test categories come back as NaN.
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])
        train_cat_df = ce_oe.transform(train[cat_features])
        train_cat = train_cat_df.values
        # `test` is defined only when config.test is set; touching it unconditionally
        # raised NameError for train-only / CV-only runs.
        if config.test:
            # NaN (unknown category) is mapped to index 0, the slot reserved below.
            test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
            test_cat = test_cat_df.values
        # (num_embeddings, embedding_dim) per categorical column, consumed by CustomBert.
        cat_dims = []
        for col in cat_features:
            dim = train[col].nunique()
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    # Assign every training row to a validation fold and persist the assignment.
    if config.train:
        with timer('Create folds'):
            folds = train.copy()

            # NOTE(review): the MultilabelStratifiedKFold import is commented out in the
            # notebook's import cell — confirm it is brought into scope elsewhere before
            # enabling config.train.
            # NOTE(review): random_state is passed without shuffle=True; presumably it has
            # no effect on the split (newer scikit-learn rejects this combination) — verify.
            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                # Each row is labeled with the fold in which it is held out.
                folds.loc[val_index, 'fold'] = int(fold)
            # The string literal below is a disabled alternative (GroupKFold on question_body).
            """
            # less gap between CV vs LB with GroupKFold
            # https://www.kaggle.com/ratthachat/quest-cv-analysis-on-different-splitting-methods
            kf = GroupKFold(n_splits=NUM_FOLDS)
            for fold, (train_index, val_index) in enumerate(kf.split(X=train.question_body, groups=train.question_body)):
                folds.loc[val_index, 'fold'] = int(fold)
            """
            folds['fold'] = folds['fold'].astype(int)
            # Only the id, the targets and the fold label are written to folds.csv.
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    with timer('Prepare Bert config'):
        # Model identity and the casing rule derived from its name.
        bert_model = 'bert-base-uncased'
        do_lower_case = 'uncased' in bert_model
        output_model_file = 'bert_pytorch.bin'

        # Text columns that are tokenised into model inputs.
        input_categories = ['question_title', 'question_body', 'answer']

        # Vocabulary and architecture config are read from local files.
        tokenizer = BertTokenizer.from_pretrained(
            "../input/pretrained-bert-models-for-pytorch/bert-base-uncased-vocab.txt",
            do_lower_case=do_lower_case,
        )
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-uncased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        # One output per target column.
        bert_config.num_labels = len(target_cols)
    
    # Train one model per fold, checkpointing the weights with the best validation score.
    if config.train:

        BATCH_SIZE = 8
        # DEBUG shortens each fold to a single epoch.
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):
            
            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                # Rows labeled with the current fold are held out; all others are trained on.
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")
            
                logger.info("Preparing train datasets....")
            
                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                # Position of the first 0 in each row of the first input array; rows where
                # argmax returns 0 are assigned the full row width instead.
                # presumably 0 is the padding id here — verify against compute_input_arays
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]
            
                logger.info("Preparing valid datasets....")
            
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
            
                logger.info("Preparing Dataloaders Datasets....")

                # Training batches are drawn in random order; validation keeps dataset order
                # so predictions can be written back by position.
                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)
            
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
            
                # A fresh model is created from the pretrained directory for every fold.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()
            
                # `i` counts epochs since the last score improvement; nothing in this block
                # reads it, so no early stopping actually happens.
                i = 0
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                # Two parameter groups: biases and LayerNorm parameters get no weight decay.
                param_optimizer = list(model.named_parameters())
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=4e-5)
                #criterion = nn.BCEWithLogitsLoss()
                criterion = custom_loss
                # NOTE(review): config.warmup is passed as a number of warmup *steps*; if it
                # holds a fraction such as 0.1 this presumably yields almost no warmup — confirm intent.
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")
            
                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()
                
                    start_time   = time.time()
                    # Only the first returned loss is populated; the five others are placeholders.
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    # `loss` and `train_loss` in this message both print avg_loss.
                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    # NOTE(review): state_dict() is stored without copying, so it presumably
                    # keeps tracking later weight updates; best_param_loss is never saved in this block.
                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    # The checkpoint is written immediately whenever the validation score improves.
                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

            # Cleanup runs once, after the last fold has finished.
            del train_df, val_df, model, optimizer, criterion, scheduler
            del valid_loader, train_loader, valid_set, train_set
            torch.cuda.empty_cache()
            gc.collect()
    
    # Out-of-fold evaluation: each saved fold checkpoint predicts its own held-out rows.
    if config.cv:

        with timer('CV'):

            # One batch size for both the loader and the prediction helper, so the two
            # can never disagree.
            CV_BATCH_SIZE = 32

            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            results = np.zeros((len(train), len(target_cols)))
            # NOTE: despite its name, `logits` holds the ground-truth labels returned by
            # predict_valid_result.
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):

                val_index = folds[folds.fold == fold].index
                val_df = train.iloc[val_index]

                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                # First position holding 0 in each row; rows where argmax returns 0 get the full width.
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=CV_BATCH_SIZE, shuffle=False, drop_last=False)

                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt', map_location=device))
                result, logit = predict_valid_result(model, valid_loader, len(val_df), batch_size=CV_BATCH_SIZE)
                results[val_index, :] = result
                logits[val_index, :] = logit 

            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')

            # Out-of-fold predictions, one column per target.
            oof = pd.DataFrame(results, columns=target_cols)
            oof.to_csv(f'oof_{config.expname}.csv', index=False)
    
    # Test-time inference: average the predictions of the per-fold checkpoints.
    if config.test:

        with timer('Inference'):

            test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                              num_features=test_num, cat_features=test_cat)
            # First position holding 0 in each row; rows where argmax returns 0 get the full width.
            lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)
            result = np.zeros((len(test), len(target_cols)))

            n_folds_used = 0  # number of checkpoints actually added into `result`
            for fold in range(NUM_FOLDS):
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt', map_location=device))
                result += predict_result(model, test_loader, len(test)) 
                n_folds_used += 1
                if DEBUG:
                    break

            # Divide by the folds that contributed: DEBUG stops after one fold, and
            # dividing by NUM_FOLDS would then shrink every prediction.
            result /= n_folds_used

        with timer('Create submission.csv'):
            # Fills every column from the first target column to the last with the averaged predictions.
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission5.csv', index=False)

[Data Loading] start
[Data Loading] done in 0 s
[Num features] start
[Num features] done in 1 s
[Cat features] start
[Cat features] done in 0 s
[Prepare Bert config] start
[Prepare Bert config] done in 0 s
[Inference] start


[(60, 30), (6, 3), (3, 2)]


15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
[Inference] done in 108 s
[Create submission.csv] start
[Create submission.csv] done in 0 s


In [8]:
#===========================================================
# Config
#===========================================================
class PipeLineConfig:
    """Plain attribute container for the run's hyper-parameters and stage switches.

    Every constructor argument is stored unchanged under an attribute of the
    same name.
    """

    def __init__(self, lr, warmup, accum_steps, epochs, seed, expname, 
                 head_tail, head, freeze, question_weight, answer_weight, fold, train, cv, test):
        # (attribute name, value) pairs, assigned in declaration order.
        settings = (
            ('lr', lr),
            ('warmup', warmup),
            ('accum_steps', accum_steps),
            ('epochs', epochs),
            ('seed', seed),
            ('expname', expname),
            ('head_tail', head_tail),
            ('head', head),
            ('freeze', freeze),
            ('question_weight', question_weight),
            ('answer_weight', answer_weight),
            ('fold', fold),
            ('train', train),
            ('cv', cv),
            ('test', test),
        )
        for attr_name, attr_value in settings:
            setattr(self, attr_name, attr_value)

# Run configuration for this cell: training and CV are switched off, test inference is on.
config = PipeLineConfig(lr=1e-4, warmup=0.1, accum_steps=1, epochs=6,
                        seed=42, expname='uncased_8', head_tail=True, head=0.5, freeze=False,
                        question_weight=0., answer_weight=0., fold=5, train=False, cv=False, test=True)

DEBUG = False
# Name of the row-identifier column.
ID = 'qa_id'
# Target columns: 21 question_* entries followed by 9 answer_* entries (30 in total).
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice',
               'question_type_compare', 'question_type_consequence',
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful',
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']
NUM_FOLDS = config.fold
# Directory prefix; file names are appended directly, so the trailing slash is required.
ROOT = '../input/google-quest-challenge/'
#ROOT = '../input/'
SEED = config.seed
# seed_everything is not defined in this cell — presumably a helper from an earlier cell.
seed_everything(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory prefix for saved fold assignments and checkpoints (trailing slash required).
MODEL_DIR = '../input/googlequestchallenge-weights7/'
#MODEL_DIR = './'
# False: question and answer are passed to the model as two separate sequences.
COMBINE_INPUT = False
# Token budgets for title, question body and answer.
T_MAX_LEN = 30
Q_MAX_LEN = 479 # 382
A_MAX_LEN = 479 # 254 
# 30 + 479 + 479 + 4 = 992.
# NOTE(review): this is above the 512-token total mentioned in _trim_input's comment;
# presumably the per-side limits below are what apply when COMBINE_INPUT is False — verify.
MAX_SEQUENCE_LENGTH = T_MAX_LEN + Q_MAX_LEN + A_MAX_LEN + 4
# 30 + 479 + 3 = 512 for each of the question-side and answer-side sequences.
q_max_sequence_length = T_MAX_LEN + Q_MAX_LEN + 3
a_max_sequence_length = T_MAX_LEN + A_MAX_LEN + 3

#===========================================================
# Model
#===========================================================
def _get_masks(tokens, max_seq_length):
    """Mask for padding"""
    if len(tokens)>max_seq_length:
        print(f'len(tokens): {len(tokens)}')
        print(f'max_seq_length: {max_seq_length}')
        raise IndexError("Token length more than max seq length!")
    return [1]*len(tokens) + [0] * (max_seq_length - len(tokens))


def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    
    if len(tokens) > max_seq_length:
        raise IndexError("Token length more than max seq length!")
        
    segments = []
    first_sep = True
    current_segment_id = 0
    
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))


def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids


def _trim_input(tokenizer, title, question, answer, max_sequence_length, t_max_len, q_max_len, a_max_len):
    """Tokenize title/question/answer and trim them to fit one combined input.

    The three token lists plus 4 special tokens must fit in
    ``max_sequence_length``. When they do not, the budget a short field leaves
    unused is handed to the other fields, and each over-long field is cut either
    with the head+tail strategy (``config.head_tail``; the first
    ``config.head`` fraction of the budget from the start, the rest from the
    end) or by plain head truncation. The title is always head-truncated.

    Returns:
        (title_tokens, question_tokens, answer_tokens)

    Raises:
        ValueError: if the redistributed budgets do not add up to
            ``max_sequence_length``.
    """
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)

    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:

        # A short title donates its unused budget, split between answer and question.
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len

        # A short answer (or question) donates its unused budget to the other field.
        if a_max_len > a_len:
            a_new_len = a_len
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len

        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d"%(max_sequence_length, (t_new_len + a_new_len + q_new_len + 4)))

        # A budget can exceed the real token count (e.g. when the title is the
        # over-long field). Clamp it so head and tail slices never overlap and
        # duplicate tokens.
        q_new_len = min(q_new_len, q_len)
        a_new_len = min(a_new_len, a_len)

        if config.head_tail:
            # Head+Tail method
            q_len_head = min(round(q_new_len * config.head), q_new_len)
            a_len_head = min(round(a_new_len * config.head), a_new_len)
            # The tail start is an absolute index: a negative-index slice such as
            # x[-0:] would return the whole list when the tail length is 0.
            q = q[:q_len_head] + q[q_len - (q_new_len - q_len_head):]
            a = a[:a_len_head] + a[a_len - (a_new_len - a_len_head):]
            t = t[:t_new_len]
        else:
            # No Head+Tail, usual processing
            q = q[:q_new_len]
            a = a[:a_new_len]
            t = t[:t_new_len]

    return t, q, a


def q_trim_input(tokenizer, title, question, q_max_sequence_length, t_max_len, q_max_len):
    """Tokenize title and question body and trim them to fit one BERT input.

    The two token lists plus 3 special tokens must fit in
    ``q_max_sequence_length``. When they do not, the budget a short field leaves
    unused is handed to the other field, and both fields are cut either with the
    head+tail strategy (``config.head_tail``; the first ``config.head`` fraction
    of the budget from the start, the rest from the end) or by plain head
    truncation.

    Returns:
        (title_tokens, question_tokens)
    """
    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)

    t_len = len(t)
    q_len = len(q)

    if (t_len+q_len+3) > q_max_sequence_length:

        # A short title donates its unused budget to the question.
        if t_max_len > t_len:
            t_new_len = t_len
            q_max_len = q_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        # A short question donates its unused budget to the title.
        if q_max_len > q_len:
            q_new_len = q_len
            t_new_len = t_max_len + (q_max_len - q_len)
        else:
            q_new_len = q_max_len

        # Never budget more tokens than exist, so head and tail slices cannot overlap.
        t_new_len = min(t_new_len, t_len)
        q_new_len = min(q_new_len, q_len)

        if config.head_tail:
            # Head+Tail method
            q_len_head = min(round(q_new_len * config.head), q_new_len)
            t_len_head = min(round(t_new_len * config.head), t_new_len)
            # The tail start is an absolute index: the previous negative-index form
            # became x[-0:] (the whole list) whenever the tail length was 0,
            # e.g. with config.head == 1.0.
            q = q[:q_len_head] + q[q_len - (q_new_len - q_len_head):]
            t = t[:t_len_head] + t[t_len - (t_new_len - t_len_head):]
        else:
            # No Head+Tail, usual processing
            q = q[:q_new_len]
            t = t[:t_new_len]

    return t, q

"""
def a_trim_input(tokenizer, answer, a_max_sequence_length, a_max_len):

    a = tokenizer.tokenize(answer)

    a_len = len(a)

    if (a_len+2) > a_max_sequence_length:

        a_new_len = a_max_len

        # Head+Tail method
        a_len_head = round(a_new_len * config.head)
        a_len_tail = -1 * (a_new_len - a_len_head)
        if config.head_tail :
            a = a[:a_len_head]+a[a_len_tail:]
        else:
            # No Head+Tail , usual processing
            a = a[:a_new_len]

    return a
"""

def a_trim_input(tokenizer, title, answer, a_max_sequence_length, t_max_len, a_max_len):
    """Tokenize title and answer and trim them to fit one BERT input.

    The two token lists plus 3 special tokens must fit in
    ``a_max_sequence_length``. When they do not, the budget a short field leaves
    unused is handed to the other field, and both fields are cut either with the
    head+tail strategy (``config.head_tail``; the first ``config.head`` fraction
    of the budget from the start, the rest from the end) or by plain head
    truncation.

    Returns:
        (title_tokens, answer_tokens)
    """
    t = tokenizer.tokenize(title)
    a = tokenizer.tokenize(answer)

    t_len = len(t)
    a_len = len(a)

    if (t_len+a_len+3) > a_max_sequence_length:

        # A short title donates its unused budget to the answer.
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + (t_max_len - t_len)
        else:
            t_new_len = t_max_len

        # A short answer donates its unused budget to the title.
        if a_max_len > a_len:
            a_new_len = a_len
            t_new_len = t_max_len + (a_max_len - a_len)
        else:
            a_new_len = a_max_len

        # Never budget more tokens than exist, so head and tail slices cannot overlap.
        t_new_len = min(t_new_len, t_len)
        a_new_len = min(a_new_len, a_len)

        if config.head_tail:
            # Head+Tail method
            a_len_head = min(round(a_new_len * config.head), a_new_len)
            t_len_head = min(round(t_new_len * config.head), t_new_len)
            # The tail start is an absolute index: the previous negative-index form
            # became x[-0:] (the whole list) whenever the tail length was 0,
            # e.g. with config.head == 1.0.
            a = a[:a_len_head] + a[a_len - (a_new_len - a_len_head):]
            t = t[:t_len_head] + t[t_len - (t_new_len - t_len_head):]
        else:
            # No Head+Tail, usual processing
            a = a[:a_new_len]
            t = t[:t_new_len]

    return t, a


def _convert_to_bert_inputs(title_q, title_a, question, answer, tokenizer, max_sequence_length):
    """Convert already-tokenized fields to ids, masks and segments for BERT.

    Args:
        title_q: title tokens trimmed for the question input (also used as the
            title in the combined input).
        title_a: title tokens trimmed for the answer input (unused when
            ``COMBINE_INPUT`` is true).
        question, answer: trimmed token lists.
        tokenizer: object providing ``convert_tokens_to_ids``.
        max_sequence_length: padded length of the combined input; the split
            inputs are padded to ``T_MAX_LEN + Q_MAX_LEN + 3`` and
            ``T_MAX_LEN + A_MAX_LEN + 3`` instead.

    Returns:
        ``[ids, masks, segments]`` when ``COMBINE_INPUT`` is true, otherwise
        ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments]``.
    """
    if COMBINE_INPUT:
        # This branch used to reference an undefined name `title`; the combined
        # input uses the title tokens passed as `title_q`.
        stoken = ["[CLS]"] + title_q + ["[QBODY]"] + question + ["[ANS]"] + answer + ["[SEP]"]

        input_ids = _get_ids(stoken, tokenizer, max_sequence_length)
        input_masks = _get_masks(stoken, max_sequence_length)
        input_segments = _get_segments(stoken, max_sequence_length)

        return [input_ids, input_masks, input_segments]

    # NOTE(review): the closing token below is spelled "[SEP" (no closing bracket),
    # so it is not the real separator token. It is kept unchanged on purpose:
    # weights produced with this exact token stream would see different inputs
    # if it were corrected. Fix it together with a retrain.
    q_len = T_MAX_LEN+Q_MAX_LEN+3
    q_token = ["[CLS]"] + title_q + ["[SEP]"] + question + ["[SEP"]
    q_input_ids = _get_ids(q_token, tokenizer, q_len)
    q_input_masks = _get_masks(q_token, q_len)
    q_input_segments = _get_segments(q_token, q_len)

    a_len = T_MAX_LEN+A_MAX_LEN+3
    a_token = ["[CLS]"] + title_a + ["[SEP]"] + answer + ["[SEP"]
    a_input_ids = _get_ids(a_token, tokenizer, a_len)
    a_input_masks = _get_masks(a_token, a_len)
    a_input_segments = _get_segments(a_token, a_len)

    return [q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length, num_features, cat_features, 
                        t_max_len=T_MAX_LEN, q_max_len=Q_MAX_LEN, a_max_len=A_MAX_LEN):
    """Build the model input tensors for every row of ``df[columns]``.

    Each row must expose ``question_title``, ``question_body`` and ``answer``.
    The text is tokenized, trimmed and converted to ids/masks/segments; the
    numeric and categorical feature arrays are appended as the last two tensors.

    Returns:
        ``COMBINE_INPUT`` true: ``[ids, masks, segments, num_features, cat_features]``;
        otherwise ``[q_ids, q_masks, q_segments, a_ids, a_masks, a_segments,
        num_features, cat_features]``. Integer tensors are int64, ``num_features``
        is float32.
    """

    def _as_long(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.int32)).long()

    def _as_float(rows):
        return torch.from_numpy(np.asarray(rows, dtype=np.float32)).float()

    if COMBINE_INPUT:
        input_ids, input_masks, input_segments = [], [], []
        for _, instance in df[columns].iterrows():
            t, q, a = instance.question_title, instance.question_body, instance.answer
            t, q, a = _trim_input(tokenizer, t, q, a, max_sequence_length, t_max_len, q_max_len, a_max_len)
            # _convert_to_bert_inputs takes two title arguments; the call used to
            # pass only one, shifting every later argument and raising TypeError.
            ids, masks, segments = _convert_to_bert_inputs(t, t, q, a, tokenizer, max_sequence_length)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)
        return [
                _as_long(input_ids),
                _as_long(input_masks),
                _as_long(input_segments),
                _as_float(num_features),
                _as_long(cat_features),
                ]
    else:
        q_input_ids, q_input_masks, q_input_segments = [], [], []
        a_input_ids, a_input_masks, a_input_segments = [], [], []
        for _, instance in df[columns].iterrows():
            t, q, a = instance.question_title, instance.question_body, instance.answer
            # The sequence limits come from the module-level q_/a_max_sequence_length.
            t_q, q = q_trim_input(tokenizer, t, q, q_max_sequence_length, t_max_len, q_max_len)
            t_a, a = a_trim_input(tokenizer, t, a, a_max_sequence_length, t_max_len, a_max_len)
            q_ids, q_masks, q_segments, a_ids, a_masks, a_segments = _convert_to_bert_inputs(t_q, t_a, q, a, tokenizer, max_sequence_length)
            q_input_ids.append(q_ids)
            q_input_masks.append(q_masks)
            q_input_segments.append(q_segments)
            a_input_ids.append(a_ids)
            a_input_masks.append(a_masks)
            a_input_segments.append(a_segments)
        return [
                _as_long(q_input_ids),
                _as_long(q_input_masks),
                _as_long(q_input_segments),
                _as_long(a_input_ids),
                _as_long(a_input_masks),
                _as_long(a_input_segments),
                _as_float(num_features),
                _as_long(cat_features),
                ]


def compute_output_arrays(df, columns):
    """Return the ``columns`` selection of ``df`` as a NumPy array."""
    selected = df[columns]
    return np.asarray(selected)


if COMBINE_INPUT:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the combined-input tensors.

        ``inputs`` holds five parallel sequences (ids, masks, segments, numeric
        features, categorical features); ``lengths`` and the optional ``labels``
        are indexed in parallel with them.
        """

        # Number of parallel sequences read from ``inputs``.
        N_INPUTS = 5

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            """Return ``(*inputs, [labels], lengths)`` for sample ``idx``."""
            features = tuple(self.inputs[k][idx] for k in range(self.N_INPUTS))
            lengths = self.lengths[idx]
            if self.labels is None:
                return features + (lengths,)
            return features + (self.labels[idx], lengths)

        def __len__(self):
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """BERT encoder over the combined input, with a linear head on top.

        The head consumes the pooled BERT output concatenated with the numeric
        features and the embedded categorical features.

        Args:
            config: BERT configuration; ``config.num_labels`` sets the output width.
            cat_dims: iterable of ``(num_categories, embedding_dim)`` pairs, one
                per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.2)
            n_emb_out = sum([y for x, y in cat_dims])
            self.dropout = nn.Dropout(0.2)
            # +4: the head expects exactly 4 numeric features.
            self.classifier_final = nn.Linear(config.hidden_size+n_emb_out+4, self.config.num_labels)  # num_features=4

            self.init_weights()

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss?, logits, *extra_bert_outputs)``.

            ``loss`` is only present when ``labels`` is given: MSE when
            ``num_labels == 1``, cross-entropy otherwise.
            """

            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)

            # One embedding table per categorical column, concatenated feature-wise.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            pooled_output = torch.cat([pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are referenced through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are not imported in this file.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)

else:

    class QuestDataset(torch.utils.data.Dataset):
        """Dataset over the split question/answer input tensors.

        ``inputs`` holds eight parallel sequences (question ids/masks/segments,
        answer ids/masks/segments, numeric features, categorical features);
        ``lengths`` and the optional ``labels`` are indexed in parallel with them.
        """

        # Number of parallel sequences read from ``inputs``.
        N_INPUTS = 8

        def __init__(self, inputs, lengths, labels = None):
            self.inputs = inputs
            self.labels = labels
            self.lengths = lengths

        def __getitem__(self, idx):
            """Return ``(*inputs, [labels], lengths)`` for sample ``idx``."""
            features = tuple(self.inputs[k][idx] for k in range(self.N_INPUTS))
            lengths = self.lengths[idx]
            if self.labels is None:
                return features + (lengths,)
            return features + (self.labels[idx], lengths)

        def __len__(self):
            return len(self.inputs[0])


    class CustomBert(BertPreTrainedModel):
        """One shared BERT encoder run on the question and the answer inputs.

        The two pooled outputs are concatenated with the numeric features and
        the embedded categorical features and fed to a single linear head.

        Args:
            config: BERT configuration; ``config.num_labels`` sets the output width.
            cat_dims: iterable of ``(num_categories, embedding_dim)`` pairs, one
                per categorical feature column.
        """

        def __init__(self, config, cat_dims):
            super(CustomBert, self).__init__(config)
            self.num_labels = config.num_labels
            self.bert = BertModel(config)
            self.embeddings = nn.ModuleList([
                nn.Embedding(x, y) for x, y in cat_dims
            ])
            self.emb_drop = nn.Dropout(0.4)
            n_emb_out = sum([y for x, y in cat_dims])
            self.num_drop = nn.Dropout(0.4)
            self.q_dropout = nn.Dropout(0.4)
            self.a_dropout = nn.Dropout(0.4)
            # hidden_size*2: question and answer pooled outputs side by side;
            # +4: the head expects exactly 4 numeric features.
            self.classifier_final = nn.Linear(config.hidden_size*2+n_emb_out+4, self.config.num_labels)
            self.init_weights()

        def forward(
            self,
            q_input_ids=None,
            q_attention_mask=None,
            q_token_type_ids=None,
            a_input_ids=None,
            a_attention_mask=None,
            a_token_type_ids=None,
            num_features=None,
            cat_features=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
        ):
            """Return ``(loss?, logits, *extra_question_outputs, *extra_answer_outputs)``.

            ``loss`` is only present when ``labels`` is given: MSE when
            ``num_labels == 1``, cross-entropy otherwise. ``position_ids``,
            ``head_mask`` and ``inputs_embeds`` are passed unchanged to both
            encoder calls.
            """

            q_outputs = self.bert(
                q_input_ids,
                attention_mask=q_attention_mask,
                token_type_ids=q_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            q_pooled_output = q_outputs[1]
            q_pooled_output = self.q_dropout(q_pooled_output)

            a_outputs = self.bert(
                a_input_ids,
                attention_mask=a_attention_mask,
                token_type_ids=a_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
            )

            a_pooled_output = a_outputs[1]
            a_pooled_output = self.a_dropout(a_pooled_output)

            # One embedding table per categorical column, concatenated feature-wise.
            emb = [
                emb_layer(cat_features[:, j]) for j, emb_layer in enumerate(self.embeddings)
            ]
            emb = self.emb_drop(torch.cat(emb, 1))

            num_features = self.num_drop(num_features)

            pooled_output = torch.cat([q_pooled_output, a_pooled_output, num_features, emb], 1)
            logits = self.classifier_final(pooled_output)

            outputs = (logits,) + q_outputs[2:] + a_outputs[2:]  # add hidden states and attention if they are here
            if labels is not None:
                # The loss classes are referenced through `nn`: the bare names
                # MSELoss / CrossEntropyLoss are not imported in this file.
                if self.num_labels == 1:
                    #  We are doing regression
                    loss_fct = nn.MSELoss()
                    loss = loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss_fct = nn.CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss,) + outputs

            return outputs  # (loss), logits, (hidden_states), (attentions)


def train_model(model, train_loader, optimizer, criterion, scheduler, config):
    """Run one training epoch and return the accumulated loss values.

    Each batch is moved to ``device``, passed through the model without labels,
    and scored with ``criterion``. The optimizer and scheduler step every
    ``config.accum_steps`` batches; gradients of trailing batches that do not
    complete an accumulation window are not applied here.

    Returns:
        ``(avg_loss, avg_loss_1, ..., avg_loss_5)``. Only ``avg_loss`` is
        updated; the other five stay at 0 and are kept for the callers'
        unpacking.
    """
    model.train()
    avg_loss = 0.
    avg_loss_1 = 0.
    avg_loss_2 = 0.
    avg_loss_3 = 0.
    avg_loss_4 = 0.
    avg_loss_5 = 0.

    optimizer.zero_grad()
    for step, batch in enumerate(train_loader):
        if COMBINE_INPUT:
            # The last batch element (lengths) is not used for training.
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            input_segments = input_segments.to(device)
            num_features = num_features.to(device)
            cat_features = cat_features.to(device)
            labels = labels.to(device)

            model_kwargs = dict(
                input_ids=input_ids.long(),
                labels=None,
                attention_mask=input_masks,
                token_type_ids=input_segments,
                num_features=num_features,
                cat_features=cat_features,
            )
        else:
            (q_input_ids, q_input_masks, q_input_segments,
             a_input_ids, a_input_masks, a_input_segments,
             num_features, cat_features, labels, _) = batch
            q_input_ids = q_input_ids.to(device)
            q_input_masks = q_input_masks.to(device)
            q_input_segments = q_input_segments.to(device)
            a_input_ids = a_input_ids.to(device)
            a_input_masks = a_input_masks.to(device)
            a_input_segments = a_input_segments.to(device)
            num_features = num_features.to(device)
            cat_features = cat_features.to(device)
            labels = labels.to(device)

            model_kwargs = dict(
                q_input_ids=q_input_ids.long(),
                labels=None,
                q_attention_mask=q_input_masks,
                q_token_type_ids=q_input_segments,
                a_input_ids=a_input_ids.long(),
                a_attention_mask=a_input_masks,
                a_token_type_ids=a_input_segments,
                num_features=num_features,
                cat_features=cat_features,
            )

        output_train = model(**model_kwargs)
        logits = output_train[0]  # output preds
        loss = criterion(logits, labels)
        loss.backward()

        window_complete = (step + 1) % config.accum_steps == 0
        if window_complete:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        avg_loss += loss.item() / (len(train_loader)*config.accum_steps)

        # Drop batch references so the tensors can be freed before the next batch.
        del model_kwargs
        if COMBINE_INPUT:
            del input_ids, input_masks, input_segments, labels
        else:
            del q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, labels

    torch.cuda.empty_cache()
    gc.collect()
    return avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5


def val_model(model, criterion, val_loader, val_shape, batch_size=8):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        model: model returning logits as the first element of its output.
        criterion: loss applied to ``(logits, labels)``.
        val_loader: loader yielding labelled batches in a fixed order.
        val_shape: total number of validation rows.
        batch_size: kept for backward compatibility and no longer used. Rows are
            written at a running offset, so the result does not depend on the
            caller repeating the loader's batch size.

    Returns:
        ``(avg_val_loss, score)`` where ``score`` is the mean per-target
        Spearman correlation of the sigmoid predictions, with undefined (NaN)
        correlations counted as 0.
    """
    avg_val_loss = 0.
    model.eval() # eval mode

    valid_preds = np.zeros((val_shape, len(target_cols)))
    original = np.zeros((val_shape, len(target_cols)))

    offset = 0
    with torch.no_grad():

        for batch in val_loader:
            if COMBINE_INPUT:
                input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
                input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(input_ids = input_ids.long(),
                               labels = None,
                               attention_mask = input_masks,
                               token_type_ids = input_segments,
                               num_features = num_features,
                               cat_features = cat_features,
                              )
            else:
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
                q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)

                output_val = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )
            logits = output_val[0] #output preds

            avg_val_loss += criterion(logits, labels).item() / len(val_loader)

            # Place this batch by its real size, not by the batch_size argument.
            n_rows = labels.size(0)
            valid_preds[offset : offset+n_rows] = logits.detach().cpu().numpy().reshape(n_rows, -1)
            original[offset : offset+n_rows]    = labels.detach().cpu().numpy().reshape(n_rows, -1)
            offset += n_rows

    preds = torch.sigmoid(torch.tensor(valid_preds)).numpy()

    # One Spearman computation per target, reused for the printout, the log and the score.
    results = [spearmanr(original[:, i], preds[:, i]) for i in range(preds.shape[1])]

    rho_val = np.mean([res.correlation for res in results])
    print('\r val_spearman-rho: %s' % (str(round(rho_val, 5))), end = 100*' '+'\n')

    score = 0
    for i in range(len(target_cols)):
        logger.info(f"{i}, {results[i]}")
        score += np.nan_to_num(results[i].correlation)

    return avg_val_loss, score/len(target_cols)


def predict_valid_result(model, val_loader, val_length, batch_size=32):
    """Predict on a labelled loader and return predictions with their labels.

    Args:
        model: model returning logits as the first element of its output.
        val_loader: loader yielding labelled batches in a fixed order.
        val_length: total number of rows.
        batch_size: kept for backward compatibility and no longer used. Rows are
            written at a running offset, so the result does not depend on the
            caller repeating the loader's batch size.

    Returns:
        ``(output, original)``: sigmoid predictions and the labels, both of
        shape ``(val_length, len(target_cols))``.
    """
    val_preds = np.zeros((val_length, len(target_cols)))
    original = np.zeros((val_length, len(target_cols)))

    model.eval()
    offset = 0
    for batch in tqdm(val_loader):
        if COMBINE_INPUT:
            input_ids, input_masks, input_segments, num_features, cat_features, labels, _ = batch
            input_ids, input_masks, input_segments, num_features, cat_features, labels = input_ids.to(device), input_masks.to(device), input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
            with torch.no_grad():
                outputs = model(input_ids = input_ids.long(),
                            labels = None,
                            attention_mask = input_masks,
                            token_type_ids = input_segments,
                            num_features = num_features,
                            cat_features = cat_features,
                            )
        else:
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels, _ = batch
            q_input_ids, q_input_masks, q_input_segments, a_input_ids, a_input_masks, a_input_segments, num_features, cat_features, labels = q_input_ids.to(device), q_input_masks.to(device), q_input_segments.to(device), a_input_ids.to(device), a_input_masks.to(device), a_input_segments.to(device), num_features.to(device), cat_features.to(device), labels.to(device)
            with torch.no_grad():
                outputs = model(q_input_ids = q_input_ids.long(),
                             labels = None,
                             q_attention_mask = q_input_masks,
                             q_token_type_ids = q_input_segments,
                             a_input_ids = a_input_ids.long(),
                             a_attention_mask = a_input_masks,
                             a_token_type_ids = a_input_segments,
                             num_features = num_features,
                             cat_features = cat_features,
                            )

        predictions = outputs[0]

        # Place this batch by its real size, not by the batch_size argument.
        n_rows = labels.size(0)
        val_preds[offset : offset+n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        original[offset : offset+n_rows] = labels.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(val_preds)).numpy()
    return output, original


def predict_result(model, test_loader, test_length, batch_size=32):
    """Predict on an unlabelled loader.

    Args:
        model: model returning logits as the first element of its output.
        test_loader: loader yielding batches in a fixed order; the leading
            elements of each batch are the model inputs.
        test_length: total number of rows.
        batch_size: kept for backward compatibility and no longer used. Rows are
            written at a running offset, so the result does not depend on the
            caller repeating the loader's batch size.

    Returns:
        Sigmoid predictions of shape ``(test_length, len(target_cols))``.
    """
    test_preds = np.zeros((test_length, len(target_cols)))

    model.eval()
    offset = 0
    for x_batch in tqdm(test_loader):
        if COMBINE_INPUT:
            with torch.no_grad():
                outputs = model(input_ids = x_batch[0].to(device),
                            labels = None,
                            attention_mask = x_batch[1].to(device),
                            token_type_ids = x_batch[2].to(device),
                            num_features = x_batch[3].to(device),
                            cat_features = x_batch[4].to(device),
                           )
        else:
            with torch.no_grad():
                outputs = model(q_input_ids = x_batch[0].to(device),
                            labels = None,
                            q_attention_mask = x_batch[1].to(device),
                            q_token_type_ids = x_batch[2].to(device),
                            a_input_ids = x_batch[3].to(device),
                            a_attention_mask = x_batch[4].to(device),
                            a_token_type_ids = x_batch[5].to(device),
                            num_features = x_batch[6].to(device),
                            cat_features = x_batch[7].to(device),
                           )
        predictions = outputs[0]

        # Place this batch by its real size, not by the batch_size argument.
        n_rows = predictions.shape[0]
        test_preds[offset : offset+n_rows] = predictions.detach().cpu().numpy().reshape(n_rows, -1)
        offset += n_rows

    output = torch.sigmoid(torch.tensor(test_preds)).numpy()
    return output


def add_features(df):
    """Add the hand-crafted categorical and numeric feature columns to ``df``.

    Columns are added to ``df`` in place and ``df`` is also returned:

    * ``netloc``: the part of the URL host before its first dot.
    * ``qa_same_user_page_flag``: 1 when question and answer user pages match.
    * ``question_title_num_words``, ``question_body_num_words``,
      ``answer_num_words``: ``log1p`` of the whitespace-separated word counts.
    * ``question_vs_answer_length``: ``log1p`` of question words / answer words.
    """
    # Raw strings: '\S' in a normal string literal is an invalid escape sequence.
    host_prefix = re.compile(r"^[^.]*")
    word = r'\S+'

    df['netloc'] = df['url'].apply(lambda x: re.findall(host_prefix, urlparse(x).netloc)[0])
    df['qa_same_user_page_flag'] = (df['question_user_page']==df['answer_user_page'])*1
    df['question_title_num_words'] = df['question_title'].str.count(word)
    df['question_body_num_words'] = df['question_body'].str.count(word)
    df['answer_num_words'] = df['answer'].str.count(word)
    # An empty or whitespace-only answer has 0 words; dividing by it would put inf
    # into the feature matrix, so the denominator is floored at 1.
    df['question_vs_answer_length'] = df['question_body_num_words']/df['answer_num_words'].clip(lower=1)
    df['question_title_num_words'] = np.log1p(df['question_title_num_words'])
    df['question_body_num_words'] = np.log1p(df['question_body_num_words'])
    df['answer_num_words'] = np.log1p(df['answer_num_words'])
    df['question_vs_answer_length'] = np.log1p(df['question_vs_answer_length'])
    return df


def custom_loss(logits, labels):
    """Training loss: binary cross-entropy with logits over all targets.

    The result is the default (mean) reduction of ``nn.BCEWithLogitsLoss``
    applied to the full ``logits`` / ``labels`` tensors.
    """
    criterion = nn.BCEWithLogitsLoss()
    return criterion(logits, labels)


def min_max(x, axis=None):
    """Rescale ``x`` to the [0, 1] range along ``axis`` (whole array when None).

    Computes ``(x - min) / (max - min)`` with the reductions kept broadcastable.
    """
    lowest = x.min(axis=axis, keepdims=True)
    value_range = x.max(axis=axis, keepdims=True) - lowest
    return (x - lowest) / value_range

#===========================================================
# main
#===========================================================
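#def main():
# NOTE: the data-loading, feature, fold-creation, BERT-config, training and CV
# stages below are wrapped in a triple-quoted string and therefore do not
# execute; only the `config.test` inference stage at the end of this block runs.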
if True:
    """
    with timer('Data Loading'):
        train = pd.read_csv(f"{ROOT}train.csv").fillna("none")
        for c in target_cols:
            #train.loc[train[c]==0, c] = 0.2
            train[c] = train[c].rank(method='dense').values
            train[c] = train[c]/train[c].max()
        y_train = train[target_cols].values
        if config.test:
            test = pd.read_csv(f"{ROOT}test.csv").fillna("none")
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
    
    with timer('Num features'):
        train = add_features(train)
        if config.test:
            test = add_features(test)
        num_features = ['question_title_num_words', 'question_body_num_words', 'answer_num_words', 'question_vs_answer_length']
        train_num = train[num_features].values
        if config.test:
            test_num = test[num_features].values
                
    with timer('Cat features'):
        cat_features = ['netloc', 'category', 'qa_same_user_page_flag']
        ce_oe = ce.OrdinalEncoder(cols=cat_features, handle_unknown='return_nan')
        ce_oe.fit(train[cat_features])
        train_cat_df = ce_oe.transform(train[cat_features])
        test_cat_df = ce_oe.transform(test[cat_features]).fillna(0).astype(int)
        train_cat = train_cat_df.values
        test_cat = test_cat_df.values
        cat_dims = []
        for col in cat_features:
            dim = train[col].nunique()
            cat_dims.append((dim+1, dim//2+1)) # for unknown=0
        print(cat_dims)

    if config.train:
        with timer('Create folds'):
            folds = train.copy()

            kf = MultilabelStratifiedKFold(n_splits=NUM_FOLDS, random_state=SEED)
            for fold, (train_index, val_index) in enumerate(kf.split(train.values, y_train)):
                folds.loc[val_index, 'fold'] = int(fold)
                
            folds['fold'] = folds['fold'].astype(int)
            save_cols = [ID] + target_cols + ['fold']
            folds[save_cols].to_csv('folds.csv', index=None)

    with timer('Prepare Bert config'):
        tokenizer = BertTokenizer.from_pretrained("../input/pretrained-bert-models-for-pytorch/bert-base-uncased-vocab.txt", 
                                                  do_lower_case=True)
        input_categories = ['question_title', 'question_body', 'answer']
        bert_model_config = '../input/pretrained-bert-models-for-pytorch/bert-base-uncased/bert_config.json'
        bert_config = BertConfig.from_json_file(bert_model_config)
        bert_config.num_labels = len(target_cols)
        bert_model = 'bert-base-uncased'
        do_lower_case = 'uncased' in bert_model
        output_model_file = 'bert_pytorch.bin'
    
    if config.train:

        BATCH_SIZE = 8
        if DEBUG:
            epochs = 1
        else:
            epochs = config.epochs
        ACCUM_STEPS = config.accum_steps

        with timer('Train Bert'):
            
            for fold in range(NUM_FOLDS):

                logger.info(f"Current Fold: {fold}")
                train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index

                train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                logger.info(f"Train Shapes: {train_df.shape}")
                logger.info(f"Valid Shapes: {val_df.shape}")
            
                logger.info("Preparing train datasets....")
            
                inputs_train = compute_input_arays(train_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[train_index], cat_features=train_cat[train_index])
                outputs_train = compute_output_arrays(train_df, columns=target_cols)
                outputs_train = torch.tensor(outputs_train, dtype=torch.float32)
                lengths_train = np.argmax(inputs_train[0]==0, axis=1)
                lengths_train[lengths_train==0] = inputs_train[0].shape[1]
            
                logger.info("Preparing valid datasets....")
            
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
            
                logger.info("Preparing Dataloaders Datasets....")

                train_set = QuestDataset(inputs=inputs_train, lengths=lengths_train, labels=outputs_train)
                train_sampler = RandomSampler(train_set)
                train_loader = DataLoader(train_set, batch_size=BATCH_SIZE,sampler=train_sampler)
            
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
            
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                torch.cuda.empty_cache()
                if config.freeze : ## This is basically using out of the box bert model while training only the classifier head with our data . 
                    for param in model.bert.parameters():
                        param.requires_grad = False
                model.train()
            
                i = 0
                best_avg_loss = 100.0
                best_score = -1.
                best_param_loss = None
                best_param_score = None
                param_optimizer = list(model.named_parameters())
                no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [
                    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
                    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                    ]        
                optimizer = AdamW(optimizer_grouped_parameters, lr=config.lr, eps=4e-5)
                #optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, eps=4e-5)
                #criterion = nn.BCEWithLogitsLoss()
                criterion = custom_loss
                scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup, num_training_steps=epochs*len(train_loader)//ACCUM_STEPS)
                logger.info("Training....")
            
                for epoch in tqdm(range(epochs)):

                    torch.cuda.empty_cache()
                
                    start_time   = time.time()
                    avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5 = train_model(model, train_loader, optimizer, criterion, scheduler, config)
                    avg_val_loss, score = val_model(model, criterion, valid_loader, val_shape=val_df.shape[0], batch_size=BATCH_SIZE)
                    elapsed_time = time.time() - start_time

                    logger.info('Epoch {}/{} \t loss={:.4f} \t val_loss={:.4f} \t train_loss={:.4f} \t train_loss_1={:.4f} \t train_loss_2={:.4f} \t train_loss_3={:.4f} \t train_loss_4={:.4f}  \t train_loss_5={:.4f} \t score={:.6f} \t time={:.2f}s'.format(
                        epoch+1, epochs, avg_loss, avg_val_loss, avg_loss, avg_loss_1, avg_loss_2, avg_loss_3, avg_loss_4, avg_loss_5, score, elapsed_time))

                    if best_avg_loss > avg_val_loss:
                        i = 0
                        best_avg_loss = avg_val_loss 
                        best_param_loss = model.state_dict()

                    if best_score < score:
                        best_score = score
                        best_param_score = model.state_dict()
                        logger.info('best_param_score_{}_{}.pt'.format(config.expname ,fold))
                        torch.save(best_param_score, 'best_param_score_{}_{}.pt'.format(config.expname, fold))
                    else:
                        i += 1

            del train_df, val_df, model, optimizer, criterion, scheduler
            del valid_loader, train_loader, valid_set, train_set
            torch.cuda.empty_cache()
            gc.collect()
    
    if config.cv:

        with timer('CV'):

            folds = pd.read_csv(f'{MODEL_DIR}folds.csv')
            results = np.zeros((len(train), len(target_cols)))
            logits = np.zeros((len(train), len(target_cols)))

            for fold in range(NUM_FOLDS):
                
                #train_index = folds[folds.fold != fold].index
                val_index = folds[folds.fold == fold].index
                #train_df, val_df = train.iloc[train_index], train.iloc[val_index]
                val_df = train.iloc[val_index]
                
                inputs_valid = compute_input_arays(val_df, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
                                                   num_features=train_num[val_index], cat_features=train_cat[val_index])
                outputs_valid = compute_output_arrays(val_df, columns = target_cols)
                outputs_valid = torch.tensor(outputs_valid, dtype=torch.float32)
                lengths_valid = np.argmax(inputs_valid[0] == 0, axis=1)
                lengths_valid[lengths_valid == 0] = inputs_valid[0].shape[1]
                valid_set = QuestDataset(inputs=inputs_valid, lengths=lengths_valid, labels=outputs_valid)
                valid_loader = DataLoader(valid_set, batch_size=32, shuffle=False, drop_last=False)
                
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result, logit = predict_valid_result(model, valid_loader, len(val_df))  
                results[val_index, :] = result
                logits[val_index, :] = logit 
            
            rho_val = np.mean([spearmanr(logits[:,i], results[:,i]).correlation for i in range(results.shape[1])])
            logger.info(f'CV spearman-rho: {round(rho_val, 5)}')

            oof = pd.DataFrame()
            for i, col in enumerate(target_cols):
                oof[col] = results[:,i]
            oof.to_csv(f'oof_{config.expname}.csv', index=False)
    """
    if config.test:

        with timer('Inference'):
            #===================================
            # same as before config
            #===================================
            # NOTE: `test_loader` (and `test`, `submission`, `bert_config`, `cat_dims`)
            # are not created in the live part of this cell; they are presumably still
            # in scope from an earlier cell. The original construction of the loader
            # is kept below, commented out, for reference.
            #test_inputs = compute_input_arays(test, input_categories, tokenizer, max_sequence_length=MAX_SEQUENCE_LENGTH, 
            #                                  num_features=test_num, cat_features=test_cat)
            #lengths_test = np.argmax(test_inputs[0] == 0, axis=1)
            #lengths_test[lengths_test == 0] = test_inputs[0].shape[1]
            #test_set = QuestDataset(inputs=test_inputs, lengths=lengths_test, labels=None)
            #test_loader  = DataLoader(test_set, batch_size=32, shuffle=False)

            # Running sum of the per-fold test predictions.
            result = np.zeros((len(test), len(target_cols)))
            n_folds_used = 0

            for fold in range(NUM_FOLDS):
                # Fresh model per fold, then overwrite its weights with that fold's checkpoint.
                model = CustomBert.from_pretrained('../input/pretrained-bert-models-for-pytorch/bert-base-uncased/', config=bert_config, cat_dims=cat_dims)
                model.zero_grad()
                model.to(device)
                model.load_state_dict(torch.load(f'{MODEL_DIR}best_param_score_{config.expname}_{fold}.pt'))
                result += predict_result(model, test_loader, len(test)) 
                n_folds_used += 1
                if DEBUG:
                    # Debug runs only evaluate the first fold.
                    break

            # Average over the folds that actually contributed. Dividing by NUM_FOLDS
            # would shrink the predictions when DEBUG stops the loop after one fold.
            result /= n_folds_used

        with timer('Create submission.csv'):
            # Fill every column from the first target column onwards with the averaged predictions.
            submission.loc[:, 'question_asker_intent_understanding':] = result
            submission.to_csv('submission7.csv', index=False)

[Inference] start
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.10s/it]
15it [00:16,  1.09s/it]
15it [00:16,  1.09s/it]
[Inference] done in 99 s
[Create submission.csv] start
[Create submission.csv] done in 0 s


# Ensemble

In [9]:
if True:
    if True:
        
        with timer('Ensemble'):
            
            # Out-of-fold (OOF) predictions of every model that takes part in the blend.
            # oof6 and oof_spell are commented out and therefore not loaded.
            oof1 = pd.read_csv('../input/googlequestchallenge-weights1/oof.csv')
            oof2 = pd.read_csv('../input/googlequestchallenge-weights2/oof.csv')
            oof3 = pd.read_csv('../input/googlequestchallenge-weights3/oof.csv')
            oof4 = pd.read_csv('../input/googlequestchallenge-weights4/oof.csv')
            oof5 = pd.read_csv('../input/googlequestchallenge-weights5/oof.csv')
            #oof6 = pd.read_csv('../input/googlequestchallenge-weights6/oof.csv')
            oof7 = pd.read_csv('../input/googlequestchallenge-weights7/oof.csv')
            oof_lgb = pd.read_csv('../input/train-lgb-with-bert-features-xentropy/lgb_oof.csv')
            #oof_spell = pd.read_csv('../input/googlequestchallenge-spell/oof.csv')
            
            # Per-model CV score: Spearman correlation between the `train` targets and
            # the OOF predictions, averaged over `target_cols`. One loop replaces the
            # copy-pasted blocks and gives every model the same log label format.
            for oof_label, oof_frame in [('oof1', oof1), ('oof2', oof2), ('oof3', oof3),
                                         ('oof4', oof4), ('oof5', oof5),
                                         #('oof6', oof6),
                                         ('oof7', oof7), ('oof_lgb', oof_lgb)]:
                rho_val = np.mean([spearmanr(train.loc[:,i], oof_frame.loc[:,i]).correlation for i in target_cols])
                print(f'{oof_label} CV spearman-rho: {round(rho_val, 5)}')
            #rho_val = spearmanr(train['question_type_spelling'], oof_spell['question_type_spelling']).correlation
            #print(f'oof_spell question_type_spelling spearman-rho2: {round(rho_val, 5)}')
            
            # Per-target blend weights. Each list holds one weight per model, in the
            # order in which the blend code in this cell applies them:
            #   [oof1/sub1, oof2/sub2, oof3/sub3, oof4/sub4, oof5/sub5, oof7/sub7, oof_lgb/sub_lgb]
            # Entries written as `#'target': [...]` are commented-out alternative weight sets.
            # NOTE(review): some targets share an identical weight vector
            # (question_fact_seeking == answer_helpful,
            #  question_type_choice == question_type_instructions,
            #  question_interestingness_others == question_multi_intent,
            #  answer_satisfaction == answer_well_written) — confirm this is intended.
            weight_dict = {#'question_asker_intent_understanding': [0.14423382227545373, 0.26403362581825757, 0.21342671281564143, 0.07873806029749497, 0.06163222126953327, 0.19133353793688565, 0.04660201958673338], 
                           'question_asker_intent_understanding': [0., 0.40328523, 0.15753677,
                                                                   0., 0.13640727, 0.22840674, 0.07436399],
                           'question_body_critical': [0., 0.06178358, 0.0292628, 
                                                      0., 0.06584192, 0.45798585, 0.38512585], 
                           #'question_conversational': [0.10164748898770576, 0.18405234477802096, 0.16808611052544156, 0.106469745260419, 0.35601581737694876, 0.07568642836573576, 0.008042064705728246], 
                           'question_conversational': [0.08811661, 0.24527189, 0.19405812,
                                                      0.05633958, 0.33125105, 0.08496275, 0],
                           'question_expect_short_answer': [0.11924842, 0.19755374, 0.,
                                                            0.2523071,  0.10030623, 0.33058451, 0], 
                           'question_fact_seeking': [0.013202065769539835, 0.26707698812113057, 0.0317382709215506, 0.21667859598041936, 0.004755702783864884, 0.32050538984849064, 0.14604298657500414], 
                           #'question_has_commonly_accepted_answer': [0.13854796335510416, 0.17812783621630807, 0.07317431478650147, 0.12849381596205142, 0.22461229668945618, 0.22545398495625127, 0.031589788034327416],
                           'question_has_commonly_accepted_answer': [0.14306743324595933, 0.18393841165175798, 0.07556127959242885,
                                                                     0.13268531700139297, 0.23193920707789695, 0.232808351430564, 0],
                           'question_interestingness_others': [0., 0.1564567, 0.07727025,
                                                               0.20987616, 0., 0.32132959, 0.2350673], 
                           'question_interestingness_self': [0.0764107430765866, 0.1250758634055338, 0.11271321495327034, 0.03549773316001633, 0.12661862801239315, 0.27569063228935353, 0.24799318510284624],
                           #'question_multi_intent': [0.07242022218179081, 0.07402376048693433, 0.05455913079557168, 0.2697231170434426, 0.0812918639251851, 0.23737074429530128, 0.21061116127177426],
                           'question_multi_intent': [0., 0.1564567, 0.07727025,
                                                     0.20987616, 0., 0.32132959, 0.2350673],
                           'question_not_really_a_question': [0.22173249960012137, 0.17070331144640122, 0.03800925013012738, 0.12092933866649534, 0.23763807386006638, 0.15542059707171976, 0.05556692922506842], 
                           #'question_opinion_seeking': [0.07579225676041464, 0.175130003847942, 0.15289957108380914, 0.31238211059467763, 0.06223106101140025, 0.1232044425977646, 0.09836055410399193],
                           'question_opinion_seeking': [0.12353181, 0.1277901, 0.18055271,
                                                       0.28815662, 0.02141769, 0.25855108, 0],
                           'question_type_choice': [0.13678921128070468, 0.2087929293477967, 0.20735486917872442, 0.1897499074181206, 0.008347306586780264, 0.16089103619963713, 0.08807473998823628],
                           #'question_type_compare': [0.026828916116782557, 0.2799801833284125, 0.13942302336501752, 0.06988229992716304, 0.1530419814053994, 0.08227884557595866, 0.24856475028126637],
                           'question_type_compare': [0.05881545, 0.45905453, 0.14891039,
                                                    0.08329721, 0.1998714, 0.05005102, 0],
                           'question_type_consequence': [5.55111512e-17, 2.72271947e-01, 5.92000967e-01,
                                                         5.55111512e-17, 5.55111512e-17, 1.35727085e-01, 0],
                           #'question_type_definition': [0.029721437749985875, 0.17635250438115008, 0.4397785287791846, 0.013185584358467851, 0.12560976132864476, 0.030970915950460667, 0.18438126745210612], 
                           'question_type_definition': [1.22679480e-02, 3.47201249e-01, 4.05141867e-01,
                                                       1.09169328e-01, 1.26092856e-01, 1.26751081e-04, 0],
                           'question_type_entity': [0.0418549535997241, 0.32667789354357957, 0.10596940155493283, 0.05927695959413755, 0.1422348839544507, 0.035617908903690404, 0.2883679988494848], 
                           'question_type_instructions': [0.13678921128070468, 0.2087929293477967, 0.20735486917872442, 0.1897499074181206, 0.008347306586780264, 0.16089103619963713, 0.08807473998823628],
                           #'question_type_procedure': [0.21076627571101972, 0.17070172694839036, 0.06696570361809562, 0.2038405617405412, 0.08418389425529142, 0.2520156600987512, 0.011526177627910493],
                           'question_type_procedure': [0.00000000e+00, 3.30383233e-01, 4.16333634e-17,
                                                       1.83175459e-01, 0.00000000e+00, 4.86441308e-01, 0.00000000e+00],
                           'question_type_reason_explanation': [0.22229041, 0.06843367, 0.21228687,
                                                                0.14428088, 0.06463155, 0.28807661, 0],
                           'question_type_spelling': [0.19286676590112609, 0.04242283968733379, 0.20523158390448215, 0.11685920202552116, 0.2257236270657778, 0.0008294038366749291, 0.21606657757908418], 
                           'question_well_written': [0.00000000e+00, 2.70205701e-01, 5.55111512e-17,
                                                     5.55111512e-17, 0.00000000e+00, 4.13528567e-01, 3.16265732e-01],
                           'answer_helpful': [0.013202065769539835, 0.26707698812113057, 0.0317382709215506, 0.21667859598041936, 0.004755702783864884, 0.32050538984849064, 0.14604298657500414], 
                           'answer_level_of_information': [0.12867320477143743, 0.11123532800986519, 0.05661504097529127, 0.1587144287756197, 0.002635259339544268, 0.29671865993536584, 0.24540807819287627], 
                           #'answer_plausible': [0.14423382227545373, 0.26403362581825757, 0.21342671281564143, 0.07873806029749497, 0.06163222126953327, 0.19133353793688565, 0.04660201958673338],
                           'answer_plausible': [2.07685008e-01, 1.17305096e-01, 3.25729756e-01,
                                                1.11022302e-16, 0.00000000e+00, 3.49280140e-01, 0],
                           'answer_relevance': [0.27761237, 0.151558, 0.13026388, 
                                                0., 0.11851019, 0.32205556, 0], 
                           'answer_satisfaction': [0.09088489, 0., 0.31911212,
                                                   0., 0.29274372, 0.19064694, 0.10661232], 
                           #'answer_type_instructions': [0.13678921128070468, 0.2087929293477967, 0.20735486917872442, 0.1897499074181206, 0.008347306586780264, 0.16089103619963713, 0.08807473998823628], 
                           'answer_type_instructions': [2.35330575e-01, 8.71562313e-02, 2.59208132e-01, 
                                                        2.19053323e-01, 2.22044605e-16, 1.99251739e-01, 0],
                           #'answer_type_procedure': [0.3358637794463721, 0.010657239458323304, 0.08600332206081915, 0.25234219821192166, 0.04010421246858745, 0.2530397629419472, 0.021989485412029098], 
                           'answer_type_procedure': [0.25387628, 0.12267789, 0.14188704, 
                                                    0.3124254, 0.0196214, 0.149512, 0],
                           #'answer_type_reason_explanation': [0.08347908484814268, 0.1597305268386166, 0.14420297234016713, 0.17020547001124728, 0.07517657152542827, 0.19206020951638134, 0.17514516492001675],
                           'answer_type_reason_explanation': [0.25138575, 0.14860905, 0.10729707,
                                                             0.22231021, 0.04555059, 0.22484732, 0],
                           'answer_well_written': [0.09088489, 0., 0.31911212,
                                                   0., 0.29274372, 0.19064694, 0.10661232]}
            
            # Blend the OOF predictions target by target. The frames are listed in
            # the same order as the entries of each weight vector in `weight_dict`.
            oof_frames = (oof1, oof2, oof3, oof4, oof5, oof7, oof_lgb)
            oof = pd.DataFrame()
            for c in target_cols:
                weight = weight_dict[c]
                # Accumulate left to right: weight[0]*oof1 + weight[1]*oof2 + ...
                blend_col = weight[0] * oof_frames[0][c]
                for model_idx in range(1, len(oof_frames)):
                    blend_col = blend_col + weight[model_idx] * oof_frames[model_idx][c]
                oof[c] = blend_col
            #oof['question_type_spelling'] = oof_spell['question_type_spelling']

            # CV score of the blend: mean Spearman correlation over all targets.
            per_target_rho = [spearmanr(train.loc[:,i], oof.loc[:,i]).correlation for i in target_cols]
            rho_val = np.mean(per_target_rho)
            print(f'CV spearman-rho: {round(rho_val, 5)}')
            oof.to_csv('oof.csv', index=False)
            
            # Test-set predictions written by the individual model sections.
            # sub6 and sub_spell are commented out and therefore not loaded.
            sub1 = pd.read_csv('submission1.csv')
            sub2 = pd.read_csv('submission2.csv')
            sub3 = pd.read_csv('submission3.csv')
            sub4 = pd.read_csv('submission4.csv')
            sub5 = pd.read_csv('submission5.csv')
            #sub6 = pd.read_csv('submission6.csv')
            sub7 = pd.read_csv('submission7.csv')
            sub_lgb = pd.read_csv('submission_lgb.csv')
            #sub_spell = pd.read_csv('submission_spell.csv')
            
            # Apply the per-target weights to the test predictions; the frames are
            # listed in the same order as the entries of each weight vector.
            sub_frames = (sub1, sub2, sub3, sub4, sub5, sub7, sub_lgb)
            submission = pd.read_csv(f"{ROOT}sample_submission.csv")
            for c in target_cols:
                weight = weight_dict[c]
                # Accumulate left to right: weight[0]*sub1 + weight[1]*sub2 + ...
                blend_col = weight[0] * sub_frames[0][c]
                for model_idx in range(1, len(sub_frames)):
                    blend_col = blend_col + weight[model_idx] * sub_frames[model_idx][c]
                submission[c] = blend_col
            #submission['question_type_spelling'] = sub_spell['question_type_spelling']

[Ensemble] start


oof1 CV spearman-rho1: 0.39754
oof2 CV spearman-rho2: 0.39892
oof3 CV spearman-rho2: 0.3944
oof4 CV spearman-rho2: 0.38585
oof5 CV spearman-rho2: 0.39704
oof7 CV spearman-rho2: 0.39365
oof_lgb CV spearman-rho2: 0.40056
CV spearman-rho: 0.42139


[Ensemble] done in 2 s


In [10]:
if True:
    if True:
        
        with timer('Post process'):
            
            #oof = pd.read_csv('oof.csv')
            sub = submission.copy()
            rho_val = np.mean([spearmanr(train[c], oof[c]).correlation for c in target_cols])
            print(f'rho_val: {rho_val}')
            
            # min_thredhols_samples = 100
            col_threshold1 = {#'question_asker_intent_understanding': [0], 
                              'question_asker_intent_understanding': [0.7798771982510944, 0.7920811109335039, 0.8557452293617155, 0.8782904388166711, 0.902854965507151, 0.9097216513183621, 0.9235148682909096, 0.9377637285504432, 0.9451765158821579, 0.9588248023972451, 0.9653175323909384, 0.9762548886633277, 0.9796716368024583, 0.9819606762586988],
                              'question_body_critical': [0.1495922097481511, 0.21278660744188876, 0.23628865235084048, 0.2800908098730057, 0.40498203159739254, 0.46787137110401567, 0.525492636979446, 0.6413243238866531, 0.6708672354281013, 0.7447139221830796, 0.7863341761178992, 0.8420947603145688, 0.8868824178173214, 0.9299641280918173, 0.9545921786495899], 
                              #'question_conversational': [0.15730855304042254, 0.20423813988542985, 0.2536424055146332, 0.3373688504302154, 0.5221058062019643, 0.740816428624697], 
                              'question_conversational': [0.1525069766760582, 0.21973333527958397,
                                                          0.2501086280984717, 0.29798695846565126,
                                                          0.38246143730815246, 0.5331792015729804,
                                                          0.7177321753573026, 0.7506112118301018],
                              'question_expect_short_answer': [0.2611144109126207, 0.38349116914923986, 0.4656468242600945, 0.5218841911294037, 0.6563936376568509, 0.799586370168953, 0.8348646742971397, 0.8443924042003939, 0.8855802026356525, 0.9345125552918279, 0.9497116097962152, 0.9683335385975709],
                              'question_fact_seeking': [0.3031777585182752, 0.38923299438492953, 0.5141984122575314, 0.5956082724766223, 0.6716870783315136, 0.8723047191802196, 0.9417597045808908, 0.9561960519695563, 0.9626713507691174, 0.9746346737387603], 
                              'question_interestingness_others': [0.44431287572719097, 0.4982307314319283, 0.5144904023708171, 0.5416725846068623, 0.6036645848631093, 0.6294690584877651, 0.7197414180193262, 0.7456757663769422, 0.7497264557912451, 0.7573632343682353, 0.7821305934341662],
                              'question_interestingness_self': [0.42498338810662023, 0.4703352513449347, 0.5266966945624868, 0.5614019215799894, 0.6286043708485128, 0.6696881575056536, 0.6974721297586757, 0.8208525195006722], 
                              #'question_multi_intent': [0.2002102301339863, 0.24164563322001426, 0.27837575301138995, 0.32605090736484393, 0.44463146451870367, 0.6347091999212645, 0.7574897330637935, 0.8343621386815206],
                              'question_multi_intent': [0.21574519340526252, 0.24455823906844346, 0.2882639741661787, 0.3175744065853631, 0.3728078616670225, 0.437773812330742, 0.45636612809188054, 0.6131775748570878, 0.6520070969720705, 0.7848995633232405, 0.8452294317089482, 0.9186552533350643, 0.9344598516526684, 0.9452438987246035],
                              'question_not_really_a_question': [0.03939258439095743, 0.04976097847575716, 0.09891222789537282],
                              #'question_opinion_seeking': [0.11192085372578475, 0.19644661036700625, 0.25217730232909336, 0.4722478013564555, 0.5385586884122188, 0.6146146571699068, 0.7088051861050262, 0.7601569084293851, 0.8710683264407086, 0.9234604611980788], 
                              'question_opinion_seeking': [0.13289964719018157, 0.20618607043554066,
                                                           0.2604280423363276, 0.3069630972895149, 
                                                           0.5348804402114498, 0.5987618913550097, 
                                                           0.7052586456484844, 0.7422339214461927, 
                                                           0.7846790466669848, 0.8743318962006583, 
                                                           0.9201453647137279, 0.939505738787939, 
                                                           0.9623826709445682],
                              'question_type_choice': [0.16189209863769405, 0.18041577490958538, 0.20375853233399585, 0.28264195949237053, 0.3174840123555569, 0.43230607635119594, 0.5614214858568691, 0.6778646218309891, 0.7760264291410569, 0.8869531382380472, 0.9136651272394836, 0.9545537136899828], 
                              #'question_type_compare': [0.20054487492113204, 0.3470358764633537], 
                              'question_type_compare': [0.18990925922981028, 0.40123795613313257, 
                                                        0.5589969999807299, 0.6779192540256538, 
                                                        0.742416114956296, 0.7936005539683534, 
                                                        0.8421920381249596], 
                              'question_type_consequence': [0.06038017543220174, 0.0702390460947302, 0.09607382454868457, 0.20571409772171417, 0.21407418656760632],
                              #'question_type_definition': [0.15763240556375235, 0.3526881294414419],
                              'question_type_definition': [0.11180949041294634, 0.24956385868209166, 
                                                           0.32807333040120734, 0.41919767202577446,
                                                           0.7795623708837238, 0.8149852210754654], 
                              'question_type_entity': [0.18025714657180678, 0.29494450802633304, 0.47341091797993157, 0.7023626874921312, 0.9073733560868775], 
                              'question_type_instructions': [0.11545000730019972, 0.19799267408587765, 0.23550907733286972, 0.3017076630685107, 0.40247920086274774, 0.4489292237364733, 0.5159060765432031, 0.6665511093010561, 0.9636617010717549], 
                              #'question_type_procedure': [0.06609997032258343, 0.09008359578030727, 0.11837768309646264, 0.14599606547596422, 0.16901712218515597, 0.3758457628027094, 0.4414868926563922], 
                              'question_type_procedure': [0.11004606575356937, 0.13441472465425375, 0.17708316018945558, 0.20675465774674012, 0.2309187642730335, 0.33345429362849754, 0.38374262465375975, 0.4168174490935639, 0.4389489980778297, 0.4846348516432625, 0.5280430388808683, 0.5770837436375043],
                              'question_type_reason_explanation': [0.18568368940670632, 0.24868573236528568, 0.2954520477866478, 0.34726564658051806, 0.46702325384626997, 0.5946945524603777, 0.6698542085890401, 0.753313352465967, 0.7811815673471862, 0.8526310272505137, 0.90231676523946, 0.9217352725526369, 0.9601912064297389, 0.9765686258376298, 0.9830687754020904],
                              #'question_type_spelling': [0.22019172620024594], 
                              'question_type_spelling': [0.0029089999],
                              'question_well_written': [0.4207639178763126, 0.5610948503866817, 0.6662563857158016, 0.737045736721135, 0.7998574481149376, 0.8442311848896967, 0.8976912762705032, 0.9206109761152341, 0.9551348125781118, 0.9578814586064871, 0.9623856082601339, 0.9712609208706265, 0.975746144221896], 
                              'answer_helpful': [0.6984034089413959, 0.7983687269189049, 0.8655643088812476, 0.9001466801353707, 0.935729499047471, 0.9579834480656675, 0.9755939591623526, 0.9777704009081776, 0.9810771630408655],
                              'answer_level_of_information': [0.3715275553789707, 0.396338749814461, 0.4587861594679614, 0.5148828192403583, 0.5315244715903833, 0.5594341091161936, 0.6084747296859907, 0.6336569576359905, 0.6595364221929101, 0.6846156176541431, 0.7047383742278317, 0.7167593572520277, 0.7455446587073238, 0.7817465842259476, 0.7983501095328426, 0.8119815931181447],
                              #'answer_plausible': [0.8970473473995253, 0.9319313682149294, 0.9607598067122134, 0.9720762804218548, 0.9774262965010801, 0.9794692016815529, 0.988353238268427], 
                              'answer_plausible': [0.9094755568746322, 0.928591183031122, 0.9473102584235416, 0.9628147848474388, 0.9784702155497013, 0.9843023378554436, 0.9896587797725087, 0.9900177500223624],
                              'answer_relevance': [0.8948966045153495, 0.9435168086799155, 0.9545568302156637, 0.9622565427863874, 0.9830454076528758, 0.9911405860892273, 0.9921045218942421, 0.9923499406262755],
                              'answer_satisfaction': [0.6293044939394735, 0.7121716066041376, 0.7561107046712046, 0.7914277468387129, 0.8298867650435837, 0.883850436430629, 0.9117851387565115, 0.9210891741477153, 0.9378713261848577, 0.9508567331490615, 0.9585898967280264],
                              #'answer_type_instructions': [0.10378756640878524, 0.12650578354436098, 0.1835660532428279, 0.2536069520137287, 0.30462117785830223, 0.444751757369504, 0.510977014118346, 0.5584100618542558, 0.6760724267021406, 0.9563513617530714], 
                              'answer_type_instructions': [0.12914539018869778, 0.2547374021123783, 0.3004133928631577, 0.3430771537748245, 0.47344102727934695, 0.5123180067788267, 0.6986051519696712, 0.9490846566348909, 0.9611641358159742, 0.9634507955866556, 0.9643617783194274],
                              #'answer_type_procedure': [0.07283263806783744, 0.08673782158466102, 0.10797326502772303, 0.14305042199760273, 0.18652053202933983, 0.2589035024488969, 0.3142348956297047, 0.3884851946263015], 
                              'answer_type_procedure': [0.05262500290391218, 0.068643509793028, 
                                                        0.0875365114214299, 0.12193730591853633, 
                                                        0.12859167991175083, 0.1749735190696107, 
                                                        0.23268544708591288, 0.2853473228116521, 
                                                        0.3081688498590655, 0.3726692764096011, 
                                                        0.43835658963783947],
                              #'answer_type_reason_explanation': [0.0719726353157682, 0.09497584034183688, 0.1922202264108104, 0.2927100843316804, 0.32682952727302356, 0.4546942172344034, 0.5492092734450725, 0.6414310974814582, 0.7712054310472717, 0.8032054448895015, 0.8963335925276644, 0.9275577809848018, 0.9604266579895407, 0.98558127380008], 
                              'answer_type_reason_explanation': [0.08911567475013671, 0.19586031922677827,
                                                                 0.3244037433600754, 0.40260246629932767, 
                                                                 0.49025563870497624, 0.5843397430488557, 
                                                                 0.6629220619382283, 0.7781651886175698, 
                                                                 0.8225671776234381, 0.8974398999014135, 
                                                                 0.9245137338400247, 0.9565124730882079, 
                                                                 0.9669367106467958],
                              'answer_well_written': [0.7790048186124875, 0.816189790201333, 0.8647582261050355, 0.8993758617055834, 0.9161059950859819, 0.9291888306671663, 0.9405892016200885, 0.9427343677958661, 0.9636116603983861]}

            # Quantise the predictions column by column. For every column listed in
            # col_threshold1, values at or below the first threshold become 0 and
            # values inside a (previous, current] interval are snapped down to the
            # previous threshold; values above the last threshold are left as they are.
            # The Spearman correlation of the OOF predictions against train is printed
            # before and after the snapping.
            for col, thresholds in col_threshold1.items():
                print('')
                print(col)
                score = spearmanr(train[col], oof[col]).correlation
                print(score)
                for i, upper in enumerate(thresholds):
                    # The same rule is applied to the OOF frame and to the submission.
                    for frame in (oof, submission):
                        current = frame[col]
                        if i == 0:
                            frame.loc[current <= upper, col] = 0
                        else:
                            lower = thresholds[i - 1]
                            in_bucket = (current > lower) & (current <= upper)
                            frame.loc[in_bucket, col] = lower
                score = spearmanr(train[col], oof[col]).correlation
                print(score)
            
            # Columns that only get an upper cut-off: predictions at or above the
            # first listed threshold are pushed to exactly 1.
            col_threshold2 = {
                'question_has_commonly_accepted_answer': [0.7842707733624578],
            }

            for col, thresholds in col_threshold2.items():
                print('')
                print(col)
                score = spearmanr(train[col], oof[col]).correlation
                print(score)
                # Only the first threshold of each list is ever applied.
                if thresholds:
                    cutoff = thresholds[0]
                    for frame in (oof, submission):
                        frame.loc[frame[col] >= cutoff, col] = 1
                score = spearmanr(train[col], oof[col]).correlation
                print(score)
                
            # Mean Spearman correlation between train and the (now post-processed)
            # OOF predictions, averaged over every target column.
            rho_per_target = [
                spearmanr(train[c], oof[c]).correlation
                for c in target_cols
            ]
            rho_val = np.mean(rho_per_target)
            print(rho_val)
            
            ### To Avoid Error
            # If the post-processed spelling column collapsed to a single value, put
            # some variation back: rows holding the largest and the second-largest
            # value of the same column in `sub` are overwritten with 0.66666667 and
            # 0.33333333 respectively.
            # NOTE(review): `sub` is defined outside this section - presumably the
            # predictions before thresholding; confirm.
            if submission['question_type_spelling'].nunique() == 1:
                spelling = sub['question_type_spelling']
                v1 = spelling.max()
                v2 = spelling[spelling != v1].max()
                print(v1, v2)
                index1 = spelling[spelling == v1].index
                index2 = spelling[spelling == v2].index
                print(index1, index2)
                submission.loc[index1, 'question_type_spelling'] = 0.66666667
                submission.loc[index2, 'question_type_spelling'] = 0.33333333
            
            def min_max(x, axis=None):
                """Rescale ``x`` linearly so its minimum maps to 0 and its maximum to 1.

                Args:
                    x: numpy array to rescale.
                    axis: axis (or None for the whole array) along which the minimum
                        and maximum are taken; forwarded to ``x.min`` / ``x.max``.

                Returns:
                    Array of the same shape as ``x``. Where the minimum equals the
                    maximum (constant input) the result is 0 rather than the NaN
                    that a 0/0 division would produce.
                """
                _min = x.min(axis=axis, keepdims=True)
                _max = x.max(axis=axis, keepdims=True)
                span = _max - _min
                # A constant slice has zero range; divide by 1 there so the output
                # is 0 instead of NaN (the numerator is already 0 in that case).
                safe_span = np.where(span == 0, 1, span)
                result = (x - _min) / safe_span
                return result
    
            # 0 ~ 1
            # Rescale every target column with min_max, then move values sitting
            # exactly on 0 or 1 just inside the interval before writing the CSV.
            # The former literal 0.9999999999999999999 is not representable as a
            # double and evaluates to exactly 1.0, so that assignment left the 1s
            # unchanged; the largest value below 1 in the column's own float dtype
            # is used instead.
            for col in target_cols:
                scaled = min_max(submission[col].values, axis=None)
                submission[col] = scaled
                submission.loc[submission[col]==0, col] = 0.0000000000000000001
                # NOTE(review): assumes min_max returns a floating-point array.
                just_below_one = np.nextafter(scaled.dtype.type(1), scaled.dtype.type(0))
                submission.loc[submission[col]==1, col] = just_below_one

            submission.to_csv('submission.csv', index=False, float_format='%.20f')

[Post process] start


rho_val: 0.4213930504075034

question_asker_intent_understanding
0.39130295236866125
0.3960832308562305

question_body_critical
0.6790887363057522
0.682422715502982

question_conversational
0.4180530698616003
0.5173198563036796

question_expect_short_answer
0.3188913523811039
0.3229965222021559

question_fact_seeking
0.3820288083881838
0.3895198034976137

question_interestingness_others
0.3731809977675753
0.37909850585099436

question_interestingness_self
0.5080258622066897
0.5260439360987756

question_multi_intent
0.6088383936031573
0.620502359544521

question_not_really_a_question
0.08303998324509837
0.109441811844358

question_opinion_seeking
0.5005355833347301
0.5036638496515978

question_type_choice
0.7587243059572895
0.7823180189901875

question_type_compare
0.3600865812222029
0.542400043533202

question_type_consequence
0.17839223325338696
0.23579115389524735

question_type_definition
0.3570649987747841
0.6260806771190092

question_type_entity
0.4611885869484818
0.63016525290569

[Post process] done in 2 s
