## Setup

### Requirements and setup

##### Set up PyTorch and Apex

In [0]:
# Pinned PyTorch version; pip output is appended to tmp.log instead of being shown in the cell.
!pip install torch==1.4.0 &>> tmp.log

In [0]:
# Pinned torchvision version to go with the torch pin; output appended to tmp.log.
!pip install torchvision==0.5.0 &>> tmp.log

In [0]:
%%writefile apex_setup.sh
# Clone NVIDIA Apex and pip-install it from source (no extra build options given).
# This script is executed with `sh`. In a POSIX sh such as dash, bash's "&>> file"
# parses as "run in background" followed by an empty redirection, which would
# background the clone and let `cd apex` race it. Use portable redirection instead.
git clone https://github.com/NVIDIA/apex >> tmp.log 2>&1
cd apex
pip install -v --no-cache-dir ./ >> tmp.log 2>&1

In [0]:
# Run the Apex install script written by the previous cell; its output is appended to tmp.log.
!sh apex_setup.sh &>> tmp.log

In [0]:
import torch

# Pick the GPU when one is visible; later cells move tensors to this global `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))

# Apex provides `amp` (mixed precision), which the training code below calls directly.
try:
    from apex import amp
    APEX_AVAILABLE = True
    print("Apex enabled!")
except ImportError:
    # ImportError also covers ModuleNotFoundError; a partially built Apex install can
    # fail with a plain ImportError, which the narrower except clause let through.
    APEX_AVAILABLE = False
    print("Apex not available")

##### Installing dependencies

In [0]:
# category_encoders provides the OneHotEncoder used for the host/category features.
# NOTE(review): unpinned — consider pinning a version for reproducibility.
!pip install category_encoders &>> tmp.log

In [0]:
!pip install transformers &>> tmp.log

In [0]:
# NOTE(review): torchcontrib is not imported in the visible notebook cells; this install may be removable.
!pip install torchcontrib &>> tmp.log

### Setting GDrive connection



In [0]:
# Mount Google Drive so the dataset, BERT weights and checkpoints under PATH are reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import sys
import os

# Project root on the mounted Drive (data/, bert/, model/ and the `utils` package live here).
PATH = '/content/drive/My Drive/Colab Notebooks/google-quest-challenge/'
MODEL_NAME = 'pytorch-bert'
# Make the project's `utils` package importable; guarded so re-running the cell
# does not keep appending duplicate entries to sys.path.
if PATH not in sys.path:
    sys.path.append(PATH)

In [0]:
import numpy as np
import random

# Trade cuDNN autotuning speed for run-to-run determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


def set_seeds(seed):
    """Seed the torch (CPU and every CUDA device), numpy and `random` RNGs.

    Args:
        seed: integer seed applied to all generators.
    """
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Only inherited by child processes started later; hash randomization of the
    # already-running interpreter is not affected.
    os.environ["PYTHONHASHSEED"] = str(seed)


SEED = 21937
set_seeds(SEED)

## Dataset

In [0]:
import pandas as pd

# Competition CSVs stored under PATH/data.
train_dset = pd.read_csv(PATH + "data/train.csv")
test_dset = pd.read_csv(PATH + "data/test.csv")
submi_dset = pd.read_csv(PATH + 'data/sample_submission.csv')

# Raw text fields, categorical fields, and user/URL fields dropped before feature extraction.
free_text_columns = ['question_title', 'question_body', 'answer']
category_columns = ['host', 'category']
discard_columns = [
    'question_user_name', 'question_user_page',
    'answer_user_name', 'answer_user_page',
    'url',
]

# The 30 regression targets, in the order used for model outputs.
target_columns = [
    'question_asker_intent_understanding', 'question_body_critical', 'question_conversational',
    'question_expect_short_answer', 'question_fact_seeking', 'question_has_commonly_accepted_answer',
    'question_interestingness_others', 'question_interestingness_self', 'question_multi_intent',
    'question_not_really_a_question', 'question_opinion_seeking', 'question_type_choice',
    'question_type_compare', 'question_type_consequence', 'question_type_definition',
    'question_type_entity', 'question_type_instructions', 'question_type_procedure',
    'question_type_reason_explanation', 'question_type_spelling', 'question_well_written',
    'answer_helpful', 'answer_level_of_information', 'answer_plausible', 'answer_relevance',
    'answer_satisfaction', 'answer_type_instructions', 'answer_type_procedure',
    'answer_type_reason_explanation', 'answer_well_written',
]

train_dset = train_dset.drop(columns=discard_columns)

test_ids = test_dset.index
test_dset = test_dset.drop(columns=discard_columns)

In [0]:
# Flip to True for a quick smoke run on the first 20 rows of every frame.
test_run = False
if test_run:
    train_dset, test_dset, submi_dset = (frame[0:20] for frame in (train_dset, test_dset, submi_dset))

### Feature engineering

In [0]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from category_encoders.one_hot import OneHotEncoder
from utils.nlp_tools import NlpUtils
from nltk.corpus import stopwords
from nltk import (sent_tokenize,
                  word_tokenize,
                  pos_tag)
from collections import defaultdict
import pandas as pd
import re


# Project text helper used by preprocess_text.
nlp_utils = NlpUtils()
# English stopword vocabulary from NLTK, as a set for O(1) membership tests.
STOPWORDS_SET = {word for word in stopwords.words('english')}
# Symbols counted by punctuation_count_calc. Each entry is inserted into a regex
# character class, so '(', '|', '[' and '{' are stored backslash-escaped. Raw strings
# give exactly the same values as before without the invalid-escape-sequence warning.
PUNCTUATION_SET = {';', ':', ',', '.', '!', '?', '\n', '\r', '-', r'\(', ')', '`', '$', '<', '>', '=',
                   '+', '_', '&', "'", '"', r'\|', '#', '%', '*', r'\[', ']', r'\{', '}'}
# Lower-case words counted by qwords_count_calc.
QUESTION_WORDS = {'who', 'what', 'why', 'how', 'where', 'when', 'with', 'whose', 'whom', 'if', 'or'}


def preprocess_text(text):
    """Run `nlp_utils.nlp_text` on every entry of `text`, expanding each result into a row."""
    def _to_series(document):
        return pd.Series(nlp_utils.nlp_text(document))

    return text.apply(_to_series)


def oh_encoder() -> Pipeline:
    """Build a verbose one-step Pipeline wrapping a category_encoders OneHotEncoder.

    `drop_invariant=True` drops output columns that have zero variance.
    """
    steps = [('OHE', OneHotEncoder(drop_invariant=True))]
    return Pipeline(steps, verbose=True)


def custom_transformer(method) -> Pipeline:
    """Wrap `method` in a FunctionTransformer inside a verbose Pipeline.

    Used as a decorator: the decorated name becomes the Pipeline, not the function.
    """
    wrapped = FunctionTransformer(method, validate=False)
    return Pipeline(steps=[('Custom Function', wrapped)], verbose=True)


@custom_transformer
def char_count(text: pd.Series) -> pd.DataFrame:
    """Character length of every entry in `text`, as a one-column DataFrame.

    The decorator turns the module-level name into a FunctionTransformer Pipeline.
    """
    return pd.DataFrame(text.apply(len))


@custom_transformer
def word_count(text: pd.Series) -> pd.DataFrame:
    """Number of NLTK word tokens in every entry of `text`, as a one-column DataFrame."""
    return pd.DataFrame(text.apply(lambda row: len(word_tokenize(row))))


@custom_transformer
def unique_word_count(text: pd.Series) -> pd.DataFrame:
    """Number of distinct NLTK word tokens (case-sensitive) per entry, as a one-column DataFrame."""
    return pd.DataFrame(text.apply(lambda row: len(set(word_tokenize(row)))))


@custom_transformer
def sentence_count(text: pd.Series) -> pd.DataFrame:
    """Number of NLTK sentences in every entry of `text`, as a one-column DataFrame."""
    return pd.DataFrame(text.apply(lambda row: len(sent_tokenize(row))))


def stopword_ratio_calc(text: str) -> float:
    """Share of word tokens (lower-cased) that are English stopwords; 0.0 for token-less text."""
    tokens = word_tokenize(text)
    if not tokens:
        return 0.0
    stop_hits = sum(token.lower() in STOPWORDS_SET for token in tokens)
    return stop_hits / len(tokens)


@custom_transformer
def stopword_ratio(text: pd.Series) -> pd.DataFrame:
    """One-column DataFrame with stopword_ratio_calc applied to every entry."""
    ratios = text.apply(stopword_ratio_calc)
    return pd.DataFrame(ratios)


def uppercase_ratio_calc(text: str) -> (float, float):
    """Return (share of tokens whose first char is uppercase, share of characters in A-Z).

    Both ratios are 0.0 when the text produces no word tokens.
    """
    tokens = word_tokenize(text)
    if not tokens or not text:
        return 0.0, 0.0
    capitalised_tokens = sum(1 for token in tokens if token[0].isupper())
    ascii_capitals = len(re.findall(r'[A-Z]', text))
    return capitalised_tokens / len(tokens), ascii_capitals / len(text)


@custom_transformer
def uppercase_ratio(text: pd.Series) -> pd.DataFrame:
    """Two-column DataFrame per raw text entry: (ratio of tokens starting uppercase,
    ratio of A-Z characters), computed by uppercase_ratio_calc."""
    # Each row of the one-column frame is label-indexed; `.iloc[0]` takes its single
    # value positionally (plain `row[0]` positional fallback is deprecated in pandas).
    return text.to_frame().apply(lambda row: uppercase_ratio_calc(row.iloc[0]), result_type='expand', axis=1)


def punctuation_count_calc(text: str, punctuation=None) -> int:
    """Count occurrences of punctuation symbols in `text`.

    Args:
        text: raw string.
        punctuation: iterable of symbols, each already escaped for use inside a regex
            character class; defaults to the global PUNCTUATION_SET.
    Returns:
        Total number of matches across all symbols.
    """
    if punctuation is None:
        punctuation = PUNCTUATION_SET
    return sum(len(re.findall('[{}]'.format(symbol), text)) for symbol in punctuation)


@custom_transformer
def punctuation_count(text: pd.Series) -> pd.DataFrame:
    """One-column DataFrame with punctuation_count_calc applied to every entry."""
    counts = text.apply(punctuation_count_calc)
    return pd.DataFrame(counts)


def qwords_count_calc(text: str, question_words=None) -> int:
    """Count whole-word, case-insensitive occurrences of question words in `text`.

    Args:
        text: raw string.
        question_words: lower-case words to count; defaults to the global QUESTION_WORDS.
    Returns:
        Total number of matches.
    """
    if question_words is None:
        question_words = QUESTION_WORDS
    lowered = text.lower()
    # Match each word between word boundaries. Wrapping the word in a character class
    # ('[who]') would count every single 'w', 'h' or 'o' character instead.
    return sum(len(re.findall(r'\b{}\b'.format(re.escape(word)), lowered))
               for word in question_words)


@custom_transformer
def qwords_count(text: pd.Series) -> pd.DataFrame:
    """One-column DataFrame with qwords_count_calc applied to every entry."""
    counts = text.apply(qwords_count_calc)
    return pd.DataFrame(counts)


@custom_transformer
def number_count(text: pd.Series) -> pd.DataFrame:
    """One-column DataFrame with the number of ASCII digit characters in every entry."""
    digit_counts = text.map(lambda entry: len(re.findall(r'[0-9]', entry)))
    return pd.DataFrame(digit_counts)

In [0]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler, MinMaxScaler


# ColumnTransformer clones its transformers when fitting, so one encoder template
# can be shared by both categorical columns.
ohe = oh_encoder()

# Hand-crafted text statistics. They are disabled for this run (previously kept as
# commented-out entries); set USE_TEXT_STATS = True to add them back in their original order.
USE_TEXT_STATS = False
TEXT_COLUMN_PREFIX = {'question_title': 'qt', 'question_body': 'qb', 'answer': 'a'}
ALL_TEXT = ('question_title', 'question_body', 'answer')
TEXT_STAT_FEATURES = [
    ('char_count', char_count, ALL_TEXT),
    ('word_count', word_count, ALL_TEXT),
    ('unique_word_count', unique_word_count, ALL_TEXT),
    ('sentence_count', sentence_count, ALL_TEXT),
    ('stopword_ratio', stopword_ratio, ALL_TEXT),
    ('uppercase_ratio', uppercase_ratio, ALL_TEXT),
    ('punctuation_count', punctuation_count, ALL_TEXT),
    ('qwords_count', qwords_count, ('question_title', 'question_body')),
    ('number_count', number_count, ALL_TEXT),
]

# Names follow the "<prefix>_<feature>" scheme, e.g. 'qt_char_count'.
text_stat_transformers = [
    (f'{TEXT_COLUMN_PREFIX[column]}_{feature_name}', transformer, column)
    for feature_name, transformer, columns in TEXT_STAT_FEATURES
    for column in columns
] if USE_TEXT_STATS else []

preprocess = ColumnTransformer(text_stat_transformers + [
    ('host_ohe', ohe, 'host'),
    ('category_ohe', ohe, 'category'),
    ])

In [0]:
# Fit the metadata encoders on the training rows with the target columns removed.
# NOTE(review): the test-set metadata is not transformed anywhere in this notebook.
train_metadata_feat = preprocess.fit_transform(train_dset.drop(columns=target_columns))

In [0]:
# Rescale every metadata feature to [-1, 1], then standardise to zero mean / unit variance.
# NOTE(review): StandardScaler cancels any per-feature affine rescaling, so the MinMax step
# appears redundant; kept as-is.
whitening_preprocess = Pipeline(steps=[
    ('Normalizer', MinMaxScaler(feature_range=(-1, 1))),
    ('Standarization', StandardScaler()),
])

train_metadata_feat = whitening_preprocess.fit_transform(train_metadata_feat)

### BERT

In [0]:
from sklearn.model_selection import GroupKFold
from math import floor, ceil
from tqdm import tqdm_notebook as tqdm
import transformers


BERT_PATH = PATH+'bert/bert/'
# https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_bert.py#L34
tokenizer = transformers.BertTokenizer.from_pretrained(BERT_PATH+'bert-base-uncased-vocab.txt')

MAX_SEQUENCE_LENGTH = 512
FIRST_SENT_MAX_LENGTH = 50
# A pair is encoded as [CLS] first [SEP] second [SEP]: three special tokens.
# This constant was previously used below without ever being defined (NameError).
DUAL_SENT_SPECIAL_TOKENS = 3
# Token budget for the second sentence when the first one uses its full allowance.
SECOND_SENT_MIN_LENGTH = MAX_SEQUENCE_LENGTH - (FIRST_SENT_MAX_LENGTH + DUAL_SENT_SPECIAL_TOKENS)

Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated



#### BERT preprocessing tools

In [0]:
def bert_dual_sentence_preprocess(first_sent, second_sent, tokenizer):
    """Encode a sentence pair as padded BERT inputs.

    Layout: [CLS] first [SEP] second [SEP] followed by padding up to MAX_SEQUENCE_LENGTH.
    The first sentence is cut to FIRST_SENT_MAX_LENGTH tokens; whatever part of that
    allowance it does not use is added to the second sentence's SECOND_SENT_MIN_LENGTH budget.

    Returns:
        [input_ids, input_masks, input_segments]. Masks are 1 for real/special tokens and
        0 for padding. Segments are 0 for [CLS], the first sentence and the first [SEP],
        and 1 for the second sentence and the final [SEP].
    """
    first_tokens = tokenizer.tokenize(first_sent)
    second_tokens = tokenizer.tokenize(second_sent)
    first_ids = tokenizer.convert_tokens_to_ids(first_tokens)
    second_ids = tokenizer.convert_tokens_to_ids(second_tokens)

    spare_first_budget = max(FIRST_SENT_MAX_LENGTH - len(first_ids), 0)
    second_budget = SECOND_SENT_MIN_LENGTH + spare_first_budget

    first_kept = first_ids[0:FIRST_SENT_MAX_LENGTH]
    second_kept = second_ids[0:second_budget]

    ids = ([tokenizer.cls_token_id] + first_kept + [tokenizer.sep_token_id]
           + second_kept + [tokenizer.sep_token_id])
    masks = [1] * (len(first_kept) + len(second_kept) + 3)
    segments = [0] * (len(first_kept) + 2) + [1] * (len(second_kept) + 1)

    pad_len = MAX_SEQUENCE_LENGTH - len(ids)
    ids = ids + [tokenizer.pad_token_id] * pad_len
    masks = masks + [0] * pad_len
    segments = segments + [0] * pad_len

    return [ids, masks, segments]


def compute_input_arays(df, columns, tokenizer, max_sequence_length):
    """Encode two text columns of `df` row by row into BERT input tensors.

    Args:
        df: DataFrame holding the text.
        columns: [first_column, second_column] forming each sentence pair.
        tokenizer: BERT tokenizer passed to bert_dual_sentence_preprocess.
        max_sequence_length: not used; the encoder pads and truncates to the global
            MAX_SEQUENCE_LENGTH. Kept for call compatibility.
    Returns:
        [input_ids, input_masks, input_segments] as torch.long tensors with one row per df row.
    """
    first_col, second_col = columns[0], columns[1]
    input_ids, input_masks, input_segments = [], [], []
    # Zipping the two columns avoids building a Series per row (iterrows); passing
    # total= gives the progress bar a real length.
    pairs = zip(df[first_col], df[second_col])
    for first_sent, second_sent in tqdm(pairs, total=len(df)):
        ids, masks, segments = bert_dual_sentence_preprocess(first_sent, second_sent, tokenizer)
        input_ids.append(ids)
        input_masks.append(masks)
        input_segments.append(segments)

    return [torch.tensor(input_ids, dtype=torch.long),
            torch.tensor(input_masks, dtype=torch.long),
            torch.tensor(input_segments, dtype=torch.long)]


def compute_output_arrays(df, columns):
    """Return the `columns` of `df` as a float32 tensor (rows x len(columns))."""
    target_matrix = df[columns].values
    return torch.tensor(target_matrix, dtype=torch.float32)

#### Text preprocessing to BERT inputs

In [0]:
import gc 

FOLDS = 10
# Group by question body so answers to the same question never straddle train/valid.
# Materialised as a list: a bare split() generator is exhausted after one pass, so the
# training cell could not be re-run (and train_dset is deleted below).
gkf = list(GroupKFold(n_splits=FOLDS).split(X=train_dset.question_body, groups=train_dset.question_body))

train_targets = compute_output_arrays(train_dset, target_columns)

train_body_inputs = compute_input_arays(train_dset, ['question_title', 'question_body'], tokenizer, MAX_SEQUENCE_LENGTH)
train_answer_inputs = compute_input_arays(train_dset, ['question_title', 'answer'], tokenizer, MAX_SEQUENCE_LENGTH)

metadata_feat_len = train_metadata_feat.shape[1]
# Order: body ids/masks/segments, answer ids/masks/segments, metadata — the order bert_dataset indexes.
train_inputs = train_body_inputs + train_answer_inputs + [torch.tensor(train_metadata_feat, dtype=torch.float16)]

del tokenizer, train_dset, test_dset, train_body_inputs, train_answer_inputs
gc.collect()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




22

### BERT Dataset tools

In [0]:
def bert_dataset(inputs, idx=None, targets=None):
    """Build a TensorDataset from the seven model input tensors, optionally with targets.

    Args:
        inputs: sequence whose first seven entries are body ids/masks/segments,
            answer ids/masks/segments and metadata features.
        idx: row indices to select; None or empty selects every row.
        targets: target tensor; None or empty omits it from the dataset.
    Returns:
        torch.utils.data.TensorDataset with 7 tensors, or 8 when targets are given.
    """
    # Local import: the notebook's `from torch.utils import data` lives in a later cell.
    from torch.utils.data import TensorDataset

    use_idx = idx is not None and len(idx) > 0
    use_targets = targets is not None and len(targets) > 0
    selector = idx if use_idx else slice(None)

    tensors = [inputs[i][selector] for i in range(7)]
    if use_targets:
        tensors.append(targets[selector])
    return TensorDataset(*tensors)


def compute_spearmanr(preds, trues):
    """Column-wise Spearman correlation between predictions and targets.

    Tiny Gaussian noise (std 1e-7) is added to each prediction column so that constant
    columns do not yield an undefined correlation.

    Returns:
        (mean rho, list of per-column rho).
    """
    per_column = [
        spearmanr(truth, pred + np.random.normal(0, 1e-7, pred.shape[0])).correlation
        for truth, pred in zip(trues.T, preds.T)
    ]
    return np.mean(per_column), per_column

## Model

### Custom BERT

In [0]:
import multiprocessing, glob
import torch.nn.functional as F
import time

from scipy.stats import spearmanr
from torch import nn
from torch.utils import data
from torch.utils.data import DataLoader, Dataset,RandomSampler, SequentialSampler
from transformers import (
    BertTokenizer, BertModel, BertConfig,
    WEIGHTS_NAME, CONFIG_NAME, AdamW, get_linear_schedule_with_warmup, 
    get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
    )
from transformers.modeling_bert import BertPreTrainedModel
from utils.sodeep import sodeep


class CustomBERTBaseUncased(nn.Module):
    """Two-encoder BERT regressor with a metadata side input.

    `body_bert` encodes the (title, body) pair and `answer_bert` the (title, answer) pair.
    For each encoder the last four hidden-state tensors are dropped out and concatenated on
    the feature axis. Both encoders' results are concatenated and averaged over dim 1, joined
    with linearly projected metadata features, and fed through a two-layer SELU MLP
    that produces `output_len` logits.
    """

    def __init__(self, bert_model_path: str, bert_cfg_path: str, dropout: float, output_len: int, metadata_feat_len: int):
        super(CustomBERTBaseUncased, self).__init__()

        hidden_bert_size = 768*4*2 # hidden size * last 4 hidden layers * 2 encoders
        bert_metadata_size = hidden_bert_size + metadata_feat_len
        hidden_layer1_size = int(hidden_bert_size/4) + int(metadata_feat_len/2)
        hidden_layer2_size = int(hidden_layer1_size/4)

        self.body_bert = transformers.BertModel.from_pretrained(bert_model_path, config=bert_cfg_path)
        self.answer_bert = transformers.BertModel.from_pretrained(bert_model_path, config=bert_cfg_path)

        # Plain element-wise dropout (p=0.5) applied to each selected hidden layer.
        self.multi_sample_dropout = nn.Dropout(0.5)

        self.metadata_layer = nn.Linear(metadata_feat_len, metadata_feat_len)

        self.mlp_dense_1 = nn.Linear(bert_metadata_size, hidden_layer1_size)
        self.mlp_hidden_1 = nn.SELU()

        self.mlp_dense_2 = nn.Linear(hidden_layer1_size, hidden_layer2_size)
        self.mlp_hidden_2 = nn.SELU()

        self.mlp_drop = nn.Dropout(dropout)
        self.mlp_linear = nn.Linear(hidden_layer2_size, output_len)


    def forward(self, data):
        """Compute logits for one batch.

        Args:
            data: batch from a bert_dataset DataLoader. Indices 0-2 are body ids/masks/segments,
                3-5 answer ids/masks/segments, 6 metadata features; further entries
                (e.g. targets) are ignored. Tensors are moved to the global `device`.
        Returns:
            Logits, one column per target.
        """
        body_ids = data[0].to(device, dtype=torch.long)
        body_masks = data[1].to(device, dtype=torch.long)
        body_segments = data[2].to(device, dtype=torch.long)
        answer_ids = data[3].to(device, dtype=torch.long)
        answer_masks = data[4].to(device, dtype=torch.long)
        answer_segments = data[5].to(device, dtype=torch.long)
        metadata_inputs = data[6].to(device, dtype=torch.float16)

        body_bert_output = self.body_bert(body_ids, attention_mask=body_masks, token_type_ids=body_segments)
        answer_bert_output = self.answer_bert(answer_ids, attention_mask=answer_masks, token_type_ids=answer_segments)

        # Element 2 of the BertModel output holds all hidden states; this presumably needs
        # output_hidden_states=True in the config JSON — TODO confirm.
        body_bert_output = body_bert_output[2][-4:]
        answer_bert_output = answer_bert_output[2][-4:]

        body_bert_heads = torch.cat(([self.multi_sample_dropout(layer) for layer in body_bert_output]), 2)
        answer_bert_heads = torch.cat(([self.multi_sample_dropout(layer) for layer in answer_bert_output]), 2)
        concat_bert_heads = torch.cat((body_bert_heads, answer_bert_heads), 2)

        # Mean over dim 1 (presumably the token axis); padded positions are not masked out.
        bert_layer = torch.mean(concat_bert_heads, 1)

        metadata_layer = self.metadata_layer(metadata_inputs)
        concat_bert_metadata = torch.cat((bert_layer, metadata_layer), 1)

        mlp_layer_1 = self.mlp_dense_1(concat_bert_metadata)
        mlp_hidden_1 = self.mlp_hidden_1(mlp_layer_1)

        mlp_layer_2 = self.mlp_dense_2(mlp_hidden_1)
        mlp_hidden_2 = self.mlp_hidden_2(mlp_layer_2)

        mlp_drop = self.mlp_drop(mlp_hidden_2)

        # The head must consume the dropped-out activations; feeding mlp_hidden_2 here
        # made the `dropout` constructor argument a no-op.
        mlp_linear = self.mlp_linear(mlp_drop)
        return mlp_linear


def loss_fn(criterion, outputs, targets, columns=None):
    """Sum of per-target losses, up-weighting targets considered harder to learn.

    Args:
        criterion: loss module mapping (prediction, target) tensors of equal shape to a
            scalar, e.g. nn.BCEWithLogitsLoss.
        outputs: logits, assumed (batch, n_targets) with columns in `columns` order.
        targets: ground truth with the same shape as `outputs`.
        columns: target names in column order; defaults to the global target_columns.
    Returns:
        Scalar tensor: sum over targets of weight * criterion(column outputs, column targets).
    Raises:
        ValueError: if a weighted target name is missing from `columns`.
    """
    if columns is None:
        columns = target_columns

    most_diff_targets = ['question_not_really_a_question', 'question_type_spelling']
    diff_targets = ['answer_plausible', 'answer_well_written', 'answer_relevance', 'question_type_consequence',
                    'answer_helpful', 'question_expect_short_answer', 'answer_satisfaction']
    mild_diff_targets = ['answer_type_procedure', 'question_fact_seeking', 'question_interestingness_others',
                         'question_type_definition', 'question_type_compare', 'question_type_procedure',
                         'question_conversational', 'question_asker_intent_understanding',
                         'question_has_commonly_accepted_answer']

    weights = np.ones(len(columns))
    for unbalanced_col in most_diff_targets:
        weights[columns.index(unbalanced_col)] = 2
    for unbalanced_col in diff_targets:
        weights[columns.index(unbalanced_col)] = 1.25
    for unbalanced_col in mild_diff_targets:
        weights[columns.index(unbalanced_col)] = 1.1

    # Slice the target (column) axis. Indexing outputs[i] selects batch row i, which
    # paired sample i with the weight of target i (and failed for batches > n_targets).
    n_targets = outputs.shape[1]
    return sum(criterion(outputs[:, i], targets[:, i]) * float(weights[i]) for i in range(n_targets))


def train_loop_fn(data_loader, model, criterion, optimizer, device, scheduler=None):
    """Train `model` for one pass over `data_loader` using Apex amp loss scaling.

    Args:
        data_loader: yields bert_dataset batches whose index 7 holds the targets.
        model: CustomBERTBaseUncased (already passed through amp.initialize).
        criterion: per-column loss used by loss_fn.
        optimizer: amp-wrapped optimizer.
        device: where the targets are moved.
        scheduler: optional LR scheduler, stepped once per optimizer step.
    """
    model.train()
    train_progress = tqdm(enumerate(data_loader), total=len(data_loader))
    for batch_idx, data in train_progress:
        # Skip size-1 batches entirely (original rationale: avoid breaking BatchNorm).
        # The optimizer and scheduler must not step for a skipped batch: no fresh
        # gradients exist and `loss` would be stale or unbound.
        if len(data[0]) <= 1:
            continue

        # bert_dataset places targets after the 7 input tensors, i.e. at index 7.
        targets = data[7].to(device, dtype=torch.float32)

        optimizer.zero_grad()
        outputs = model(data)

        loss = loss_fn(criterion, outputs, targets)

        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        train_progress.set_description(f'Batch: {batch_idx} - Loss: {loss}')
    torch.cuda.empty_cache()
    gc.collect()

def eval_loop_fn(data_loader, model, criterion, device):
    """Predict every batch of `data_loader` without gradients.

    Returns:
        (stacked model outputs, stacked targets) as numpy arrays, in loader order.
    """
    model.eval()
    collected_targets = []
    collected_outputs = []
    torch.cuda.empty_cache()
    gc.collect()
    with torch.no_grad():
        progress = tqdm(enumerate(data_loader), total=len(data_loader))
        for step, batch in progress:
            batch_targets = batch[7].to(device, dtype=torch.float32)
            batch_outputs = model(batch)
            # The loss is computed but not reported or returned.
            loss_fn(criterion, batch_outputs, batch_targets)
            collected_targets.append(batch_targets.cpu().detach().numpy())
            collected_outputs.append(batch_outputs.cpu().detach().numpy())

            progress.set_description(f'Predicting Eval Batch: {step}')
            torch.cuda.empty_cache()
            gc.collect()

    return np.vstack(collected_outputs), np.vstack(collected_targets)

### Model Training

In [0]:
import os 
EPOCHS = 4


def get_previous_train_info(fold):
    """Inspect a fold's checkpoint directory to decide how to resume training.

    Weight files are expected to be named `{MODEL_NAME}-{fold}-{epoch}.pt` with epochs
    numbered 0..epochs_run-1, as written by save_model.

    Returns:
        (weights_path, checkpoint_path, remaining_epochs, epochs_run). Both paths are None
        when nothing has been saved for this fold yet, including when its directory
        does not exist (a fresh run).
    """
    fold_dir = PATH + f'model/{MODEL_NAME}/fold{fold}/'
    if not os.path.isdir(fold_dir):
        return None, None, EPOCHS, 0

    # Count only real weight files; the optimizer checkpoint ends in '.chkpt'.
    epochs_run = sum(name.endswith('.pt') for name in os.listdir(fold_dir))
    if epochs_run == 0:
        return None, None, EPOCHS, 0

    last_epoch = epochs_run - 1
    weights_path = fold_dir + f'{MODEL_NAME}-{fold}-{last_epoch}.pt'
    checkpoint_path = fold_dir + f'{MODEL_NAME}-{fold}.chkpt'
    return weights_path, checkpoint_path, EPOCHS - epochs_run, epochs_run


def save_model(model, optimizer, fold, epoch, previous_epoch=0):
    """Save per-epoch model weights and the fold's latest optimizer/amp state.

    Args:
        model: network whose state_dict is saved under key 'model'.
        optimizer: optimizer whose state is saved with amp's loss-scaler state.
        fold: fold number, used in the directory and file names.
        epoch: epoch index within the current run.
        previous_epoch: epochs completed before this run; added to `epoch` for the file name.
    """
    fold_dir = PATH + f'model/{MODEL_NAME}/fold{fold}/'
    # torch.save does not create directories; a fresh fold would otherwise fail here.
    os.makedirs(fold_dir, exist_ok=True)

    real_epoch = epoch + previous_epoch
    torch.save({'model': model.state_dict()}, fold_dir + f'{MODEL_NAME}-{fold}-{real_epoch}.pt')
    # One rolling checkpoint per fold, overwritten every epoch.
    torch.save({'optimizer': optimizer.state_dict(), 'amp': amp.state_dict()},
               fold_dir + f'{MODEL_NAME}-{fold}.chkpt')

In [0]:
from torch.optim.lr_scheduler import CyclicLR


BATCH_SIZE = 5
ACCUM_STEPS = 1  # NOTE(review): not used below; no gradient accumulation is performed.
DEVICE = 'cuda'
LR = 3e-5

max_folds = FOLDS - 1

for fold, (train_idx, valid_idx) in enumerate(gkf):
    print(f"Current Fold: {fold}")

    if fold == max_folds:
        # Last fold is a full-data run: it trains on every row, and its "validation"
        # score is computed on that same training data.
        train_set = bert_dataset(inputs=train_inputs, targets=train_targets)
        valid_set = bert_dataset(inputs=train_inputs, targets=train_targets)
    else:
        train_set = bert_dataset(inputs=train_inputs, idx=train_idx, targets=train_targets)
        valid_set = bert_dataset(inputs=train_inputs, idx=valid_idx, targets=train_targets)

    weights_path, checkpoint_path, total_epochs, previous_epochs_run = get_previous_train_info(fold)
    if total_epochs <= 0:
        # Every epoch of this fold already has saved weights.
        continue

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
    valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    set_seeds(SEED * fold)
    # Rebinds the global `device` read by CustomBERTBaseUncased.forward.
    device = DEVICE
    lr = LR
    # Size the LR schedule from the loader actually iterated. The former
    # int(len(train_idx) / BATCH_SIZE) ignored that the full-data fold trains on all rows,
    # so its linear schedule hit zero LR before the last epoch finished.
    num_train_steps = len(train_loader) * EPOCHS

    model = CustomBERTBaseUncased(bert_model_path=BERT_PATH + 'bert-base-uncased-pytorch_model.bin',
                                  bert_cfg_path=BERT_PATH + 'bert-base-uncased-config.json',
                                  dropout=0.2,
                                  output_len=len(target_columns),
                                  metadata_feat_len=metadata_feat_len)
    model.zero_grad()
    model.to(device)
    torch.cuda.empty_cache()

    # Exclude biases and LayerNorm parameters from weight decay.
    named_params = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
         # NOTE(review): 0.8 is far above the ~0.01 usually used for BERT fine-tuning — confirm.
         'weight_decay': 0.8},
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=4e-5)

    model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
                                      keep_batchnorm_fp32=True, loss_scale="dynamic")

    if weights_path:
        print(f'Training remaining {total_epochs} epochs with checkpoint {weights_path}')

        saved_model = torch.load(weights_path, map_location=lambda storage, loc: storage)
        model.load_state_dict(saved_model['model'])
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        optimizer.load_state_dict(checkpoint['optimizer'])
        amp.load_state_dict(checkpoint['amp'])
        del saved_model, checkpoint

    # NOTE(review): the scheduler is rebuilt at step 0 on every run, so a resumed fold
    # presumably restarts the linear decay rather than continuing it.
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                num_training_steps=num_train_steps)

    criterion = nn.BCEWithLogitsLoss()
    criterion.to(device)

    for epoch in range(total_epochs):
        torch.cuda.empty_cache()

        train_loop_fn(train_loader, model, criterion, optimizer, device, scheduler)

        save_model(model, optimizer, fold, epoch, previous_epoch=previous_epochs_run)

        outputs, targets = eval_loop_fn(valid_loader, model, criterion, device)

        spear, rho_cols = compute_spearmanr(outputs, targets)
        print(f'epoch = {epoch}, spearman = {spear}')
        for col_name, rho in zip(target_columns, rho_cols):
            print(col_name + " rho: " + str(rho))

    del train_set, train_loader, valid_set, valid_loader, model, optimizer, scheduler
    torch.cuda.empty_cache()
    gc.collect()

In [0]:
def compute_spearmanr_withnan(preds, trues):
    """Column-wise Spearman correlation without noise; constant columns yield nan.

    Returns:
        (mean rho, list of per-column rho); the mean is nan if any column is nan.
    """
    per_column = [spearmanr(truth, pred).correlation for truth, pred in zip(trues.T, preds.T)]
    return np.mean(per_column), per_column

def score_postprocess(outputs, targets, target_columns, p):
    """ReLU-clip three rare-positive targets, min-max normalise, and report Spearman.

    For 'question_not_really_a_question', 'question_type_consequence' and
    'question_type_spelling': if every prediction in the column is negative, the column is
    first shifted up by |p-th percentile| so that values above that percentile survive the
    ReLU. Then every column is min-max scaled to [0, 1].

    Args:
        outputs: prediction array (rows x targets); not modified.
        targets: ground-truth array of the same shape.
        target_columns: target names in column order.
        p: percentile (0-100) used for the shift.
    Returns:
        (relu_outputs, norm_out). A column that is constant after clipping becomes nan in
        norm_out (0/0), which is why the nan-tolerant Spearman helper is used.
    """
    relu_cols_idx = [target_columns.index(col) for col in
                     ['question_not_really_a_question', 'question_type_consequence', 'question_type_spelling']]

    relu_outputs = outputs.copy()
    for idx in relu_cols_idx:
        column = relu_outputs[:, idx]
        center_k = abs(np.percentile(column, p)) if column.max() < 0 else 0
        # np.maximum(x, 0) is ReLU; avoids the fragile torch.functional.F round-trip.
        relu_outputs[:, idx] = np.maximum(column + center_k, 0)

    col_min = relu_outputs.min(axis=0)
    col_range = relu_outputs.max(axis=0) - col_min
    norm_out = (relu_outputs - col_min) / col_range
    spear, rho_cols = compute_spearmanr_withnan(norm_out, targets)
    # No epoch here: the old message read a leftover global `epoch` from the training
    # cell and raised NameError on a fresh kernel.
    print(f'spearman = {spear}')
    for idx in relu_cols_idx:
        print(target_columns[idx] + " rho: " + str(rho_cols[idx]))
    return relu_outputs, norm_out