In [1]:
import numpy as np
import pandas as pd

from tqdm import tqdm, tqdm_notebook

from scipy import stats
from sklearn.model_selection import GroupKFold

import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data
from transformers import *

import os
import re
import math
import random
from matplotlib import pyplot as plt
import warnings
from math import floor, ceil

warnings.filterwarnings('ignore')
device = torch.device('cuda')
#device = torch.device('cpu')
torch.backends.cudnn.benchmark=True

%matplotlib inline



In [2]:
output_dir="./token_model_config/"

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

def spearman_corr(y_true, y_pred):
    if np.ndim(y_pred) == 2:
        corr = np.nan_to_num([stats.spearmanr(y_true[:, i], y_pred[:, i])[0] for i in range(y_true.shape[1])]).mean()
    else:
        corr = stats.spearmanr(y_true, y_pred)[0]
    return corr
  
def calc_each_spearman(valid_y, valid_pred):
    lst = []
    for idx in range(30):
        spearman = spearman_corr(valid_y[:,idx], valid_pred[:,idx])
        lst.append(spearman)
    df = pd.DataFrame(lst).T
    df.columns = class_names
    return df

In [4]:
def _get_segments(tokens, max_seq_length):
    """Segments: 0 for the first sequence, 1 for the second"""
    if len(tokens)>max_seq_length:
        raise IndexError("Token length more than max seq length!")
    segments = []
    first_sep = True
    current_segment_id = 0
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            if first_sep:
                first_sep = False
                current_segment_id = 1#新增 
            else:
                current_segment_id = 1
    return segments + [0] * (max_seq_length - len(tokens))

def _get_ids(tokens, tokenizer, max_seq_length):
    """Token ids from Tokenizer vocab"""
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = token_ids + [0] * (max_seq_length-len(token_ids))
    return input_ids

def _trim_input(title, question, answer, max_sequence_length=512-1, 
                t_max_len=70-1, q_max_len=219, a_max_len=219):#???

    t = tokenizer.tokenize(title)
    q = tokenizer.tokenize(question)
    a = tokenizer.tokenize(answer)
    
    t_len = len(t)
    q_len = len(q)
    a_len = len(a)

    if (t_len+q_len+a_len+4) > max_sequence_length:
        
        if t_max_len > t_len:
            t_new_len = t_len
            a_max_len = a_max_len + floor((t_max_len - t_len)/2)
            q_max_len = q_max_len + ceil((t_max_len - t_len)/2)
        else:
            t_new_len = t_max_len
      
        if a_max_len > a_len:
            a_new_len = a_len 
            q_new_len = q_max_len + (a_max_len - a_len)
        elif q_max_len > q_len:
            a_new_len = a_max_len + (q_max_len - q_len)
            q_new_len = q_len
        else:
            a_new_len = a_max_len
            q_new_len = q_max_len
            
            
        if t_new_len+a_new_len+q_new_len+4 != max_sequence_length:
            raise ValueError("New sequence length should be %d, but is %d" 
                             % (max_sequence_length, (t_new_len+a_new_len+q_new_len+4)))

        t = t[:t_new_len]
        q = norm_token_length(q, q_new_len)
        a = norm_token_length(a, a_new_len)
    
    return t, q, a

def norm_token_length(tokens, l):
    if len(tokens) > l:
        head = l//2
        tail = l - head
        return tokens[:head] + tokens[-tail:]
    else:
        return tokens[:l]

def _convert_to_bert_inputs(title, question, answer, cate, max_sequence_length=512):
    """Converts tokenized input to ids, masks and segments for BERT"""
    #stoken = ["[CLS]"] + [cate] + title + ["[SEP]"] + question + ["[SEP]"] + answer + ["[SEP]"]
    stoken_1 = ["[CLS]"] + [cate] + title + ["[SEP]"] + question + ["[SEP]"]
    stoken_2 = ["[CLS]"] + title + ["[SEP]"] + answer + ["[SEP]"]
    input_ids = _get_ids(stoken_1, tokenizer, max_sequence_length)
    input_ids_2 = _get_ids(stoken_2, tokenizer, max_sequence_length)
    input_segments = _get_segments(stoken_1, max_sequence_length)
    input_segments_2 = _get_segments(stoken_2, max_sequence_length)
    
    #return [input_ids, input_segments]
    return input_ids, input_segments,input_ids_2,input_segments_2

def convert_row(row):
    #c = f"[{row['category'].lower()}]"
    c = f"[{row['category']}]"
    t, q, a = row["question_title"], row["question_body"], row["answer"]
    t, q, a = _trim_input(t, q, a)
    #ids, segments = _convert_to_bert_inputs(t, q, a, c)
    ids, segments, ids2, segments2 = _convert_to_bert_inputs(t, q, a, c)
    #total_input=[np.array([[ids, segments]]),np.array([[ids2, segments2]])]
    # total_input=[]
    # print(np.array([[ids, segments]]).shape)
    # total_input.append(np.array([[ids, segments]]))
    # total_input.append(np.array([[ids2, segments2]]))
    # return total_input
    return np.array([[ids, segments, ids2, segments2]])

In [5]:


train = pd.read_csv('../input/google-quest-challenge/train.csv').fillna(' ')
sub = pd.read_csv('../input/google-quest-challenge/sample_submission.csv').fillna(' ')

#model_class, tokenizer_class = transformer_models_dict[pretrained_weights]
#tokenizer = tokenizer_class.from_pretrained(pretrained_weights)

categories = train["category"].unique().tolist()
categories = [f"[{c}]" for c in categories]
#tokenizer.add_tokens(categories)#??

#tokenizer.added_tokens_encoder#??

In [6]:
model="bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer.add_tokens(categories)


loading configuration file config.json from cache at C:\Users\Lab000/.cache\huggingface\hub\models--bert-base-cased\snapshots\5532cc56f74641d4bb33641f5c76a55d11f846e0\config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.26.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

loading file vocab.txt from cache at C:\Users\Lab000/.cache\huggingface\hub\models--bert-base-cased\snapshots\5532cc56f74641d4bb33641f5c76a55d11f846e0\vocab.txt
l

5

In [7]:
out_path="./bert_based_tokenizer"
tokenizer.save_pretrained(output_dir+"tokenizer")

tokenizer config file saved in ./token_model_config/tokenizer\tokenizer_config.json
Special tokens file saved in ./token_model_config/tokenizer\special_tokens_map.json


('./token_model_config/tokenizer\\tokenizer_config.json',
 './token_model_config/tokenizer\\special_tokens_map.json',
 './token_model_config/tokenizer\\vocab.txt',
 './token_model_config/tokenizer\\added_tokens.json',
 './token_model_config/tokenizer\\tokenizer.json')

In [8]:
%%time
X = train.apply(convert_row, axis=1).values
X = np.vstack(X).reshape((len(X), 2048))
assert X.shape == (6079, 2048)

Token indices sequence length is longer than the specified maximum sequence length for this model (649 > 512). Running this sequence through the model will result in indexing errors


Wall time: 7.26 s


In [9]:
X.shape

(6079, 2048)

In [10]:
X[0][:512]

array([  101, 28996,  1327,  1821,   146,  3196,  1165,  1606,  4973,
       11182,  1939,  1104,   170, 23639,  2180, 11039,   136,   102,
        1258,  1773,  1213,  1114, 23639,  2180,  6427,  1113,   118,
        1103,   118, 10928,   113,  2373,   131, 11802, 11039,   117,
        1231,  1964,   119, 11039,  5378,  1113,   170,  2632, 11039,
         117, 14403,  4973, 11182,   114,   117,   146,  1156,  1176,
        1106,  1243,  1748,  1114,  1142,   119,  1109,  2645,  1114,
        1103,  4884,   146,  1215,  1110,  1115,  2817,  1110,  9506,
        1105, 22769,  1654,  1110, 20405,  1120,  1436,   119,  1188,
        2609,  1139, 18011,  1106,  1253,  5174,   113,  2373,   131,
        2044,  9895,   114,  1986,   117,  1112,  3450,  1110,  8320,
         117,   146,  1328,  1106,  1129,  1682,  1106,  5211,  1686,
        9895,   119,   146,  2059,  1115,  1111,  1142,   117, 12365,
       14467,  6697,  1105,  1383,  8637, 22769,  1209,  1129,  1104,
        1632,  1494,

In [11]:
X[0][512:1024]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [12]:
X[0][1024:1536]

array([  101,  1327,  1821,   146,  3196,  1165,  1606,  4973, 11182,
        1939,  1104,   170, 23639,  2180, 11039,   136,   102,   146,
        1198,  1400,  4973, 11182,   117,  1177,  1303,   112,   188,
        1103, 19244,   119,   119,   119,   119,  1184,  1821,   146,
        3196,  1165,  1606, 11182,   119,   119,   119,   136,   138,
        1304,  5602,  2971,  1104,  1609,   106,  3561, 11811,  4253,
        1115,  2462,  1121,  1103,  1322,  1104,  1103, 11039,  1106,
        1103, 15228,  1169,  2195,  1240,  1609,  1317,  6260,   119,
       16544,  1114,  1103,  1864,  1115,  1128,   112,  1325,  1932,
        5211,  2141,  1205,   118,  5363,  1106,  1444,  1106,  2773,
        1240, 11533,  9627,   119,  1109,  1864,  1103, 23639,  2180,
         112,   188,  1132,  1932,  1737,  1304,  1304,  4295,   117,
        1780,   146,  2059,  1115,  3102,   118,  2363,  6262,   123,
         119,   129,  1110,  3155,  1106,  1129,  2385,  4295,   119,
        1109, 18737,

In [13]:
X[0][1536:]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [14]:
class_names = list(sub.columns[1:])
y = train[class_names].values#(6079,30)

lst = []
for idx in range(30):
    t = pd.DataFrame(y[:,idx])[0]
    # print(len(t))
    #print(1-t.value_counts())
    #print(1-t.value_counts()/len(t))
    w_df = (1-t.value_counts()/len(t)).reset_index()
    #print(w_df)
    w_dic = {row["index"]: row[0] for _, row in w_df.iterrows()}
    # print("=======================")
    # print(w_dic)
    w = t.map(w_dic).values#(6079,)
    #print(w.shape)
    lst.append(w)
# print(lst)
# print("==================================")
weights = np.vstack(lst).T#轉置, shape:(6079,30)

import copy
y_true = copy.deepcopy(y)
y = np.hstack([y, weights])#(6079,60)


In [15]:
X[:, 512:].min()

0

In [16]:
def custom_loss(data, targets):
    
    mse = nn.MSELoss(reduction="none")(data[:,:30].sigmoid(), targets[:,:30])
    bce = nn.BCEWithLogitsLoss(reduction='none')(data[:,:30], targets[:,:30])#??
    w =  targets[:,30:]
    #loss = (mse*w).sum() + bce.sum()
    loss = (mse).sum()+ bce.sum()
    return loss

class CustomBert(nn.Module):
    def __init__(self, model,config_path=None):
        super(CustomBert, self).__init__()
        self.config = AutoConfig.from_pretrained(model) 
        
        self.config.num_labels = 30
        self.config.output_hidden_states = True
        #self.n_use_layer = 4 原本
        self.n_use_layer = 2
        self.double_bert= 2
        self.n_labels = self.config.num_labels
        #self.config.save_pretrained("bert_based_config")
        self.config.save_pretrained(output_dir+"config")
        #self.bert = BertModel(config)
        self.bert=AutoModel.from_pretrained(model, config=self.config)
        #self.bert.save_pretrained('bert_based_model')
        self.bert.save_pretrained(output_dir+"model")
        # self.dense1 = nn.Linear(self.config.hidden_size*self.n_use_layer, self.config.hidden_size*self.n_use_layer)
        # self.dense2 = nn.Linear(self.config.hidden_size*self.n_use_layer, self.config.hidden_size*self.n_use_layer)
        # self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        # self.classifier = nn.Linear(self.config.hidden_size*self.n_use_layer, self.config.num_labels)

        self.dense1 = nn.Linear(self.double_bert*self.config.hidden_size*self.n_use_layer, self.double_bert*self.config.hidden_size*self.n_use_layer)
        self.dense2 = nn.Linear(self.double_bert*self.config.hidden_size*self.n_use_layer, self.double_bert*self.config.hidden_size*self.n_use_layer)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.double_bert*self.config.hidden_size*self.n_use_layer, self.config.num_labels)

        #self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None
                ,input_ids2=None, attention_mask2=None, token_type_ids2=None,position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        # outputs = self.bert(input_ids,
        #                     attention_mask=attention_mask,
        #                     token_type_ids=token_type_ids,
        #                     position_ids=position_ids,
        #                     head_mask=head_mask,
        #                     inputs_embeds=inputs_embeds)
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids
                            )
        outputs2 = self.bert(input_ids2,
                            attention_mask=attention_mask2,
                            token_type_ids=token_type_ids2
                            )
        
        #print(outputs[2][-1].shape)
        #pooled_output = torch.cat([outputs[2][-1*(i+1)][:,0] for i in range(self.n_use_layer)], dim=1)#把倒數最後4個layer的cls output concat在一起，把4個(8,768) concat，變成(8,3072) #原本
        pooled_output = torch.cat([outputs[2][-1*(i+1)][:,0] for i in range(self.n_use_layer)], dim=1)#把倒數最後2個layer的cls output concat在一起,把2個(8,768) concat，變成(8,1536)
        pooled_output2 = torch.cat([outputs2[2][-1*(i+1)][:,0] for i in range(self.n_use_layer)], dim=1)#同上
        double_pooled_output=torch.cat([pooled_output,pooled_output2],dim=1)#(8,3072)
        
        # pooled_output = self.dense1(pooled_output)
        # pooled_output = self.dense2(pooled_output)
        # pooled_output = self.dropout(pooled_output)
        # logits = self.classifier(pooled_output)

        double_pooled_output = self.dense1(double_pooled_output)
        double_pooled_output = self.dense2(double_pooled_output)
        double_pooled_output = self.dropout(double_pooled_output)
        logits = self.classifier(double_pooled_output)

        outputs = (logits,) + outputs[2:]
        return outputs

In [17]:
model=CustomBert(model)

loading configuration file config.json from cache at C:\Users\Lab000/.cache\huggingface\hub\models--bert-base-cased\snapshots\5532cc56f74641d4bb33641f5c76a55d11f846e0\config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.26.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

Configuration saved in ./token_model_config/config\config.json
loading weights file pytorch_model.bin from cache at C:\Users\Lab000/.cache\huggingface\hub\models-

In [18]:
param_optimizer = list(model.named_parameters())
#print(param_optimizer[0])
print("==============================")
print(param_optimizer[0][0])
print("==============================")
print(param_optimizer[0][1])

bert.embeddings.word_embeddings.weight
Parameter containing:
tensor([[-0.0005, -0.0416,  0.0131,  ..., -0.0039, -0.0335,  0.0150],
        [ 0.0169, -0.0311,  0.0042,  ..., -0.0147, -0.0356, -0.0036],
        [-0.0006, -0.0267,  0.0080,  ..., -0.0100, -0.0331, -0.0165],
        ...,
        [-0.0064,  0.0166, -0.0204,  ..., -0.0418, -0.0492,  0.0042],
        [-0.0048, -0.0027, -0.0290,  ..., -0.0512,  0.0045, -0.0118],
        [ 0.0313, -0.0297, -0.0230,  ..., -0.0145, -0.0525,  0.0284]],
       requires_grad=True)


In [19]:
len(param_optimizer)

205

In [20]:

N_FOLD=10
N_BERT_LABEL = 30
SEED = 42
BS = 8
# parameter
n_epoch = 3
learning_rate = 5e-5
max_grad_norm = 1.0

gkf = GroupKFold(n_splits=N_FOLD).split(X=train["question_body"], groups=train["question_body"])#??

spearman_scores = []
best_spearman_lst = []
losses_lst = []
epoch_spearman_lst = []
lr_lst_lst = []
each_speaman_dfs = []
#model=CustomBert(model)
#torch.save(model.config, 'config.pth')
for fold, (train_idx, valid_idx) in enumerate(gkf):#??
  # print(train_idx)
  # print("=======================")
  # print(valid_idx)
  # if fold in [0,1,2,3,4,5,6,7]:
  #   continue

  seed_everything(SEED)

  # Load Model
#   config = BertConfig.from_pretrained(pretrained_weights)
#   model = CustomBert.from_pretrained(pretrained_weights, config=config)

  model = model.to(device)
  model.bert.resize_token_embeddings(len(tokenizer))#??
  model = model.train()
  
  # optimizer setting
  param_optimizer = list(model.named_parameters())
  no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
  optimizer_grouped_parameters = []
  max_lrs = []
  for param in param_optimizer:
    if any(n in param[0] for n in no_decay):#weight_decay
      weight_decay = 0.0
    else:
      weight_decay = 0.1
    if param[0].find("bert.encoder.layer") != -1:
      
      n_diff_last = 11 - int(param[0].split(".")[3])
      lr = learning_rate*0.9**n_diff_last
    elif "embeddings" in param[0]:
      lr = learning_rate*0.9**11
    else:
      lr = learning_rate
    max_lrs.append(lr)
    d = {"params": param[1], "weight_decay": weight_decay}
    optimizer_grouped_parameters.append(d)
  optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=True)

  # print(train_idx)
  # print("===========================")
  # print(valid_idx)
  
  # train valid split
  train_x = X[train_idx]
  valid_x = X[valid_idx]
  train_y = y[train_idx]
  valid_y = y[valid_idx]
  
  # set loader  
  train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_x, dtype=torch.long), torch.tensor(train_y, dtype=torch.float))
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BS, shuffle=True)
  valid_dataset = torch.utils.data.TensorDataset(torch.tensor(valid_x, dtype=torch.long), torch.tensor(valid_y, dtype=torch.float))
  valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BS, shuffle=False)

  # set schedueler
  num_training_steps = len(train_loader)*n_epoch
  scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lrs, total_steps=num_training_steps)

  model.zero_grad()
  optimizer.zero_grad()
   
  best_spearman = 0
  losses = []
  epoch_spearman = []
  lr_lst = []
  for epoch in range(n_epoch):
    lr = np.array([param_group["lr"] for param_group in optimizer.param_groups]).mean()
    tk0 = tqdm_notebook(enumerate(train_loader), total=len(train_loader), leave=False)
    for i, (x_batch, y_batch) in tk0:
      input_ids = x_batch[:, :512]
      token_ids = x_batch[:, 512:1024]
      input_ids2 = x_batch[:, 1024:1536]
      token_ids2 = x_batch[:, 1536:]
      #print((input_ids > 0))
      #print(token_ids.max())
      #mask=(input_ids > 0).type(torch.uint8)
      #print(mask) 
      y_pred = model(input_ids.to(device), attention_mask=(input_ids > 0).to(device), token_type_ids=token_ids.to(device),
                     input_ids2=input_ids2.to(device),attention_mask2=(input_ids2 > 0).to(device),token_type_ids2=token_ids2.to(device))
      #y_pred = model(input_ids.to(device), attention_mask=mask.to(device), token_type_ids=token_ids.to(device))
      loss = custom_loss(y_pred[0], y_batch.to(device))
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 
      optimizer.step()
      optimizer.zero_grad()
      scheduler.step()
      lr_lst.append(np.array([param_group["lr"] for param_group in optimizer.param_groups]).mean())
      losses.append(float(loss))

    # epoch validation
    for param in model.parameters():
      param.requires_grad=False
    model.eval()

    lst = []
    sum_loss = 0
    for i, (x_batch, y_batch)  in enumerate(valid_loader):
      input_ids = x_batch[:, :512]
      token_ids = x_batch[:, 512:1024]
      input_ids2 = x_batch[:, 1024:1536]
      token_ids2 = x_batch[:, 1536:]

      
      with torch.no_grad():
        #y_pred = model(input_ids.to(device), attention_mask=(input_ids > 0).to(device), token_type_ids=token_ids.to(device))
        y_pred = model(input_ids.to(device), attention_mask=(input_ids > 0).to(device), token_type_ids=token_ids.to(device),
                     input_ids2=input_ids2.to(device),attention_mask2=(input_ids2 > 0).to(device),token_type_ids2=token_ids2.to(device))
        loss = custom_loss(y_pred[0], y_batch.to(device))

      lst.append(y_pred[0].sigmoid().cpu().squeeze().numpy())
      sum_loss += loss.cpu().squeeze().numpy()
    valid_pred = np.vstack(lst)#(608,30)
    
    ave_loss = sum_loss/len(valid_loader)

    spearman_score = spearman_corr(valid_y[:,:N_BERT_LABEL], valid_pred)  
    epoch_spearman.append(spearman_score)
    
    for param in model.parameters():
      param.requires_grad=True
    model.train()
    model_name="double-bert-based-case"
    # print(f"{model}_f{fold}_best")
    # print("=======================================")
    if best_spearman <= spearman_score:
      #torch.save(model.state_dict(), f"{model_name}_f{fold}_best")
      torch.save(model.state_dict(), f"./DoubleBertBasedCase/bce_no_opt_binning/train/{model_name}_f{fold}_best")
      best_spearman = spearman_score
      each_speaman_df = calc_each_spearman(valid_y[:,:N_BERT_LABEL], valid_pred)
      display(each_speaman_df)

    print(f"fold-{fold} epoch {epoch}: {spearman_score} / loss avg: {ave_loss}")
    
  best_spearman_lst.append(best_spearman)
  losses_lst.append(losses)
  epoch_spearman_lst.append(epoch_spearman)
  lr_lst_lst.append(lr_lst)
  each_speaman_dfs.append(each_speaman_df)

  torch.cuda.empty_cache()

  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.322458,0.676385,0.429384,0.251573,0.373057,0.372886,0.237701,0.443761,0.540277,0.089258,...,0.451083,0.1955,0.424136,0.147473,0.180255,0.245593,0.772346,0.352799,0.667963,0.132365


fold-0 epoch 0: 0.3831563310266756 / loss avg: 101.01835441589355


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.320946,0.693226,0.473345,0.249375,0.389262,0.414767,0.296838,0.444004,0.625918,0.085325,...,0.459529,0.222439,0.438847,0.128741,0.145673,0.336104,0.781062,0.348938,0.682836,0.16655


fold-0 epoch 1: 0.4107282618495351 / loss avg: 96.75173046714382


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.339792,0.710979,0.472709,0.290799,0.396883,0.426899,0.298854,0.460809,0.633582,0.081204,...,0.468497,0.239686,0.474852,0.171034,0.186672,0.341407,0.788535,0.353267,0.696803,0.182691


fold-0 epoch 2: 0.420120804333323 / loss avg: 96.35554052654065


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.522171,0.754335,0.396839,0.457342,0.521601,0.531249,0.377484,0.499775,0.657903,0.124731,...,0.575238,0.309214,0.342442,0.264653,0.266966,0.399445,0.843009,0.394039,0.785899,0.240289


fold-1 epoch 0: 0.47692393453946785 / loss avg: 89.85160205238743


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.527194,0.758953,0.408697,0.464222,0.52399,0.548916,0.392563,0.545375,0.663215,0.114488,...,0.616701,0.34407,0.481261,0.285743,0.300912,0.479274,0.844764,0.405833,0.80273,0.337355


fold-1 epoch 1: 0.49846773816482554 / loss avg: 88.81205859937165


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.521871,0.764865,0.414189,0.478131,0.522742,0.562158,0.428582,0.53733,0.674798,0.118213,...,0.608145,0.316748,0.476593,0.263839,0.293461,0.454862,0.854783,0.415288,0.807537,0.347734


fold-1 epoch 2: 0.5009761464362502 / loss avg: 88.22977086117393


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.498675,0.750161,0.437135,0.683871,0.727936,0.639874,0.421408,0.544243,0.760823,0.145686,...,0.687872,0.485082,0.539465,0.382508,0.390958,0.570414,0.890147,0.615778,0.878783,0.377621


fold-2 epoch 0: 0.5603699780296563 / loss avg: 81.79672261288292


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.515091,0.753285,0.437227,0.681301,0.732559,0.632875,0.480699,0.597424,0.750033,0.146652,...,0.678602,0.479176,0.575449,0.356897,0.419952,0.577559,0.885068,0.612297,0.891297,0.362543


fold-2 epoch 1: 0.565403816370125 / loss avg: 80.53574767865632


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.531226,0.76559,0.441978,0.700365,0.750058,0.654956,0.494056,0.604145,0.768949,0.149814,...,0.701996,0.512258,0.594683,0.388845,0.436201,0.580064,0.893217,0.615321,0.900671,0.401198


fold-2 epoch 2: 0.577286566094623 / loss avg: 78.72932444120708


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.547828,0.77929,0.473996,0.818171,0.80717,0.77664,0.446182,0.661124,0.8114,0.165628,...,0.747098,0.587138,0.557744,0.479222,0.454697,0.593017,0.921939,0.673482,0.930514,0.405032


fold-3 epoch 0: 0.6161719556608595 / loss avg: 72.87208501916183


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.54917,0.827166,0.484232,0.839937,0.812326,0.77451,0.472869,0.673923,0.81941,0.165249,...,0.756868,0.584216,0.559888,0.465641,0.463002,0.617918,0.920456,0.679537,0.931559,0.451374


fold-3 epoch 1: 0.6241305365074591 / loss avg: 71.8088398983604


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.548891,0.833737,0.481758,0.848643,0.820589,0.777862,0.491567,0.686898,0.826339,0.167048,...,0.76096,0.612244,0.580325,0.481965,0.469769,0.615533,0.921829,0.685389,0.936728,0.479117


fold-3 epoch 2: 0.6314364568472732 / loss avg: 70.54375392512272


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.627458,0.820898,0.503906,0.84283,0.860955,0.790132,0.583494,0.710836,0.836494,0.195238,...,0.767177,0.589717,0.465196,0.506407,0.463601,0.630853,0.938399,0.733008,0.945091,0.530627


fold-4 epoch 0: 0.6505641209536304 / loss avg: 68.5566675286544


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.658525,0.818581,0.503112,0.863748,0.870029,0.793855,0.553033,0.73534,0.845173,0.187698,...,0.778627,0.601062,0.47749,0.504916,0.45766,0.654959,0.936635,0.731722,0.944619,0.571621


fold-4 epoch 1: 0.6554093181057126 / loss avg: 67.51166674965306


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.662056,0.832168,0.50994,0.870475,0.877588,0.80217,0.568861,0.730975,0.850841,0.194173,...,0.791069,0.600098,0.504238,0.512302,0.476305,0.657521,0.943163,0.737047,0.946786,0.572747


fold-4 epoch 2: 0.6616238452922529 / loss avg: 66.0219316482544


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.667096,0.893992,0.499725,0.907993,0.852256,0.795593,0.684168,0.748601,0.861444,0.13933,...,0.839804,0.652549,0.538442,0.581555,0.540112,0.667305,0.945648,0.791513,0.949986,0.624833


fold-5 epoch 0: 0.6782562279501617 / loss avg: 65.96560834583484


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.680928,0.88804,0.499711,0.903596,0.862848,0.793051,0.676651,0.743541,0.863796,0.138866,...,0.843004,0.682095,0.570312,0.591374,0.525364,0.731103,0.94626,0.796636,0.957611,0.650789


fold-5 epoch 1: 0.6844862704702019 / loss avg: 64.70622328708046


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.687225,0.897933,0.503375,0.919765,0.86483,0.795243,0.700812,0.776654,0.868058,0.139214,...,0.851102,0.692154,0.585022,0.607059,0.538219,0.737555,0.948296,0.799355,0.958806,0.661489


fold-5 epoch 2: 0.6912381243654105 / loss avg: 63.63266156849108


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.712075,0.874934,0.630329,0.911476,0.886945,0.851319,0.676226,0.795521,0.88644,0.15487,...,0.85904,0.767425,0.655833,0.658051,0.615763,0.802546,0.948206,0.790636,0.956736,0.704295


fold-6 epoch 0: 0.7102310139884283 / loss avg: 65.13162427199514


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.735597,0.896768,0.62991,0.915405,0.890507,0.849925,0.710239,0.823751,0.885142,0.154559,...,0.869403,0.756701,0.686513,0.661011,0.617823,0.798599,0.949463,0.794354,0.958156,0.68664


fold-6 epoch 1: 0.7148089116304832 / loss avg: 64.09487036654824


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.772146,0.900222,0.6316,0.921332,0.894785,0.853801,0.720825,0.836171,0.891061,0.154663,...,0.882228,0.767621,0.695001,0.668483,0.624709,0.816798,0.952092,0.795516,0.960649,0.710926


fold-6 epoch 2: 0.7217441686742843 / loss avg: 62.47668567456697


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.79782,0.909188,0.579854,0.926248,0.877442,0.853164,0.729295,0.829938,0.882484,0.208549,...,0.881118,0.785241,0.666853,0.682437,0.607232,0.785116,0.94986,0.789262,0.952955,0.722658


fold-7 epoch 0: 0.7281446669984867 / loss avg: 64.56481371427837


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.802708,0.911747,0.579085,0.927564,0.877221,0.852825,0.734371,0.856905,0.884805,0.208394,...,0.884999,0.780939,0.693153,0.671975,0.610572,0.796367,0.950698,0.78715,0.955794,0.726548


fold-7 epoch 1: 0.7311175591994442 / loss avg: 63.609852138318516


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.818457,0.923052,0.58225,0.932176,0.881361,0.854328,0.754483,0.866913,0.884389,0.207387,...,0.899693,0.798736,0.70241,0.675719,0.616875,0.806726,0.95185,0.791401,0.958406,0.736699


fold-7 epoch 2: 0.7368277298528465 / loss avg: 62.332378688611485


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.83182,0.910277,0.485161,0.924939,0.883162,0.814527,0.773697,0.86055,0.891669,0.19729,...,0.89442,0.772305,0.586858,0.643538,0.619896,0.820444,0.952313,0.80723,0.955432,0.768908


fold-8 epoch 0: 0.7231736338196642 / loss avg: 65.4066750375848


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.827871,0.926493,0.488498,0.927108,0.883871,0.812383,0.780701,0.883305,0.893006,0.197372,...,0.904585,0.788537,0.619769,0.648746,0.612761,0.829886,0.95506,0.805677,0.956071,0.79


fold-8 epoch 1: 0.7280378207630676 / loss avg: 63.73650917253996


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.841665,0.933246,0.488222,0.928601,0.88741,0.815574,0.793938,0.894846,0.893448,0.197372,...,0.91531,0.796798,0.626088,0.658212,0.624675,0.851751,0.956374,0.807863,0.959357,0.809319


fold-8 epoch 2: 0.7336646684985113 / loss avg: 62.39915707236842


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.851997,0.936702,0.50782,0.917183,0.872965,0.812346,0.8082,0.891667,0.847251,0.17069,...,0.913196,0.821071,0.675021,0.673102,0.623485,0.844364,0.946703,0.808509,0.953687,0.850565


fold-9 epoch 0: 0.731432947662144 / loss avg: 61.42920574389006


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.858123,0.947638,0.50881,0.917922,0.876345,0.813737,0.819497,0.910827,0.848179,0.170596,...,0.917437,0.829605,0.676507,0.672005,0.607343,0.840697,0.949325,0.809144,0.955765,0.826946


fold-9 epoch 1: 0.732647245761325 / loss avg: 60.30438844781173


  0%|          | 0/684 [00:00<?, ?it/s]

Unnamed: 0,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,question_not_really_a_question,...,question_well_written,answer_helpful,answer_level_of_information,answer_plausible,answer_relevance,answer_satisfaction,answer_type_instructions,answer_type_procedure,answer_type_reason_explanation,answer_well_written
0,0.865281,0.958894,0.509501,0.918124,0.876556,0.814844,0.829576,0.917616,0.848857,0.170785,...,0.928904,0.834651,0.695534,0.684959,0.619887,0.855583,0.950638,0.811295,0.957032,0.852147


fold-9 epoch 2: 0.7379565390130199 / loss avg: 59.05546404186048


In [21]:
len(tokenizer)

29001