In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import pandas as pd
import numpy as np
import torch
from torch import nn
from transformers import BertModel, AdamW, BertTokenizer, RobertaTokenizer, RobertaModel, AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler
from sklearn.model_selection import train_test_split
import random
import os
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup
import torch.nn.functional as F
from torch.nn import Parameter
import math
from torch.optim import Adam
from sklearn.model_selection import KFold
import urllib.request
from typing import List
from functools import partial
import torchmetrics
# import torch_xla
# import torch_xla.core.xla_model as xm

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the read-only Kaggle input mount.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
seed_num = 22
# Seed every RNG the notebook touches: Python, NumPy, PyTorch CPU and all CUDA devices.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed_num)
# 5-fold cross-validation splitter with a seeded shuffle, so fold membership is reproducible.
kf = KFold(n_splits=5, random_state=seed_num, shuffle=True)

In [None]:
# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
# (A TPU variant would use torch_xla's xm.xla_device() instead.)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

In [None]:
# Training rows come from the add-train dataset (instead of the competition's
# train_data.csv); test rows and the submission template come from the competition data.
train, test, submission = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/kaggle/input/add-train/update_train.csv',
        '/kaggle/input/kor-nli/dacon/open/test_data.csv',
        '/kaggle/input/kor-nli/dacon/open/sample_submission.csv',
    )
)

In [None]:
class SNLIDataset(Dataset):
    """Premise/hypothesis pair dataset producing ``[CLS] premise [SEP] hypothesis [SEP]`` ids.

    Each item is a 1-D LongTensor of token ids plus a 1-element LongTensor with the
    sequence length; in training mode a 1-element LongTensor with the class id is
    returned between them: ``(input_ids, label, length)`` vs ``(input_ids, length)``.

    Args:
        data: DataFrame with ``premise`` and ``hypothesis`` columns (and ``label`` when
            ``is_train``). Rows are addressed positionally, so any index works.
        is_train: whether to read and return the ``label`` column.
        model_name: Hugging Face checkpoint whose tokenizer encodes the text.
        max_length: maximum total sequence length, including the three special tokens.
    """

    # Special-token ids assumed for klue/roberta-large ([CLS]=0, [SEP]=2). The collate
    # functions search input_ids for 2 to locate the premise/hypothesis boundary.
    CLS_ID = 0
    SEP_ID = 2

    def __init__(self, data, is_train=True, model_name="klue/roberta-large", max_length=70):
        super().__init__()
        self.max_length = max_length
        self.label_dict = {"entailment" : 0, "contradiction" : 1, "neutral" : 2}
        self.data = data
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.is_train = is_train

    def __len__(self):
        return len(self.data)

    def _truncate_pair(self, premise_ids, hypothesis_ids):
        """Trim the longer segment one token at a time until the pair fits (longest-first).

        Unlike cutting the concatenated sequence, this never drops the middle
        separator or the whole hypothesis when the premise is long.
        """
        # room left after [CLS], the middle [SEP] and the final [SEP]
        budget = max(self.max_length - 3, 0)
        premise_ids, hypothesis_ids = list(premise_ids), list(hypothesis_ids)
        while len(premise_ids) + len(hypothesis_ids) > budget:
            if len(premise_ids) >= len(hypothesis_ids):
                premise_ids.pop()
            else:
                hypothesis_ids.pop()
        return premise_ids, hypothesis_ids

    def __getitem__(self, idx):
        # Positional access: KFold / sampler indices are 0..len-1 whatever the frame's index is.
        row = self.data.iloc[idx]
        premise_ids = self.tokenizer.encode(row['premise'], add_special_tokens=False)
        hypothesis_ids = self.tokenizer.encode(row['hypothesis'], add_special_tokens=False)
        premise_ids, hypothesis_ids = self._truncate_pair(premise_ids, hypothesis_ids)
        token_ids = [self.CLS_ID] + premise_ids + [self.SEP_ID] + hypothesis_ids + [self.SEP_ID]
        input_ids = torch.LongTensor(token_ids)
        length = torch.LongTensor([len(token_ids)])
        if not self.is_train:
            return input_ids, length
        label = torch.LongTensor([self.label_dict[row['label']]])
        return input_ids, label, length

In [None]:
# Mini-batch size shared by the cross-validation loaders and the test loader.
batch_size: int = 32

In [None]:
train_dataset = SNLIDataset(train, is_train=True)  # labelled training split

In [None]:
def collate_to_max_length(batch: List[List[torch.Tensor]], max_len: int = None, fill_values: List[float] = None,
                          sep_token_id: int = 2) -> \
    List[torch.Tensor]:
    """Pad a batch to its longest sequence and enumerate candidate spans with validity masks.

    Args:
        batch: samples, each a sequence of 1-D tensors. Field 0 must be the token ids
            (containing ``sep_token_id`` between the two sentences) and the LAST field
            must be a 1-element tensor holding the unpadded sequence length. Both
            ``(input_ids, label, length)`` and ``(input_ids, length)`` samples work.
        max_len: if given, pad every field to exactly this length (asserts nothing is longer).
        fill_values: padding value per field; defaults to 0.0 for every field.
        sep_token_id: id of the separator token marking the sentence boundary.
    Returns:
        ``[*padded_fields, start_indexs, end_indexs, span_masks]``:

        - each padded field is ``(batch, max_length_of_field)``;
        - ``start_indexs`` / ``end_indexs`` are ``(num_spans,)`` and list every ``(i, j)``
          with ``1 <= i <= j <= L - 2``, where ``L`` is the padded length of field 0;
        - ``span_masks`` is a LongTensor ``(batch, num_spans)`` holding 0 for spans that
          lie inside the real tokens and wholly on one side of the separator, and
          1000000 otherwise.
    """
    # [batch, num_fields]
    lengths = np.array([[len(field_data) for field_data in sample] for sample in batch])
    batch_size, num_fields = lengths.shape
    fill_values = fill_values or [0.0] * num_fields
    # [num_fields]
    max_lengths = lengths.max(axis=0)
    if max_len:
        assert max_lengths.max() <= max_len
        max_lengths = np.ones_like(max_lengths) * max_len

    output = [torch.full([batch_size, int(max_lengths[field_idx])],
                         fill_value=fill_values[field_idx],
                         dtype=batch[0][field_idx].dtype)
              for field_idx in range(num_fields)]
    for sample_idx, sample in enumerate(batch):
        for field_idx, data in enumerate(sample):
            output[field_idx][sample_idx, :data.shape[0]] = data

    # Every span (i, j), i <= j, over positions 1 .. L-2 of the padded sequence
    # (i.e. excluding the first and last padded positions).
    max_sentence_length = int(max_lengths[0])
    span_pairs = [(i, j)
                  for i in range(1, max_sentence_length - 1)
                  for j in range(i, max_sentence_length - 1)]
    start_indexs = torch.LongTensor([i for i, _ in span_pairs])
    end_indexs = torch.LongTensor([j for _, j in span_pairs])

    # Vectorised validity mask: (1, num_spans) span bounds broadcast against
    # per-sample (batch, 1) lengths and separator positions.
    seq_lengths = torch.LongTensor([int(sample[-1]) for sample in batch]).unsqueeze(1)
    middle_indexs = torch.LongTensor(
        [sample[0].tolist().index(sep_token_id) for sample in batch]).unsqueeze(1)
    starts = start_indexs.unsqueeze(0)
    ends = end_indexs.unsqueeze(0)
    last_real = seq_lengths - 2  # last position before the final separator
    inside = (starts >= 1) & (starts <= last_real) & (ends >= 1) & (ends <= last_real)
    one_sentence = (starts > middle_indexs) | (ends < middle_indexs)
    span_masks = (~(inside & one_sentence)).long() * 1_000_000

    output.append(start_indexs)
    output.append(end_indexs)
    output.append(span_masks)
    return output  # (*fields, start_indexs, end_indexs, span_masks)

In [None]:
class SICModel(nn.Module):
    """Span Information Collecting layer.

    For each candidate span (i, j) over the encoder states h it builds
    tanh(W1 h_i + W2 h_j + (W3 h_i - W3 h_j) + (W4 h_i) * (W4 h_j)).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Four independent square projections, created in this order so parameter
        # initialisation is unchanged under a fixed seed.
        self.W_1 = nn.Linear(hidden_size, hidden_size)
        self.W_2 = nn.Linear(hidden_size, hidden_size)
        self.W_3 = nn.Linear(hidden_size, hidden_size)
        self.W_4 = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, start_indexs, end_indexs):
        """Return span vectors gathered at (start, end) positions along dim 1 of hidden_states.

        start_indexs / end_indexs are 1-D index tensors of equal length; the result has
        one entry per span along dim 1 and the projected hidden size along the last dim.
        """
        def gather(projected, positions):
            # rows at `positions` along the sequence axis (dim 1)
            return torch.index_select(projected, 1, positions)

        proj_1, proj_2, proj_3, proj_4 = (
            layer(hidden_states) for layer in (self.W_1, self.W_2, self.W_3, self.W_4)
        )
        start_term = gather(proj_1, start_indexs)
        end_term = gather(proj_2, end_indexs)
        diff_term = gather(proj_3, start_indexs) - gather(proj_3, end_indexs)
        prod_term = torch.mul(gather(proj_4, start_indexs), gather(proj_4, end_indexs))
        return torch.tanh(start_term + end_term + diff_term + prod_term)

In [None]:
class InterpretationModel(nn.Module):
    """Attention pooling over span representations.

    Scores every span with a linear layer, subtracts ``span_masks`` from the scores,
    softmax-normalises them per example (dim 1) and returns the weighted sum of the
    span vectors together with the weights.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # one scalar relevance score per span vector
        self.h_t = nn.Linear(hidden_size, 1)

    def forward(self, h_ij, span_masks):
        """
        Args:
            h_ij: span vectors; dim 1 indexes spans and the last dim is hidden_size.
            span_masks: penalties broadcastable to the (batch, span) scores; presumably
                0 for allowed spans and a large value (1e6 from the collate functions)
                for disallowed ones, which drives their softmax weight to ~0.
        Returns:
            (H, a_ij): pooled vector per example and the (batch, span) attention weights.
        """
        scores = self.h_t(h_ij).squeeze(-1) - span_masks
        weights = torch.softmax(scores, dim=1)
        pooled = torch.sum(weights.unsqueeze(-1) * h_ij, dim=1)
        return pooled, weights

In [None]:
class ExplainableModel(nn.Module):
    """Span-based self-explaining NLI classifier.

    Pipeline: pretrained encoder -> SICModel (one vector per candidate span) ->
    InterpretationModel (masked attention pooling over spans) -> linear classifier.

    Args:
        model_name: Hugging Face checkpoint loaded with ``AutoModel.from_pretrained``.
        num_labels: number of output classes.
    """

    # Padding id written by the collate functions (fill_values[0] == 1); positions
    # holding it are excluded from the encoder's attention.
    PAD_TOKEN_ID = 1

    def __init__(self, model_name="klue/roberta-large", num_labels=3):
        super().__init__()
        self.intermediate = AutoModel.from_pretrained(model_name)
        # Size the span layers from the encoder config instead of hard-coding 1024
        # (the value previously hard-coded for klue/roberta-large), so other checkpoints work.
        hidden_size = self.intermediate.config.hidden_size
        self.span_info_collect = SICModel(hidden_size)
        self.interpretation = InterpretationModel(hidden_size)
        self.output = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, start_indexs, end_indexs, span_masks):
        """Return ``(logits, a_ij)``: class scores and the span attention weights."""
        attention_mask = (input_ids != self.PAD_TOKEN_ID).long()
        # encoder output object; its last_hidden_state feeds the span layer
        encoded = self.intermediate(input_ids, attention_mask=attention_mask)
        h_ij = self.span_info_collect(encoded.last_hidden_state, start_indexs, end_indexs)
        H, a_ij = self.interpretation(h_ij, span_masks)
        return self.output(H), a_ij

In [None]:
# Classification loss over the 3 NLI classes, and passes over each fold's training split.
criterion = nn.CrossEntropyLoss()
epochs = 3

In [None]:
def cal_accuracy(X, Y):
    """Return the fraction of rows whose arg-max logit equals the target label.

    Args:
        X: logits tensor of shape (batch, num_classes), on any device.
        Y: targets as a tensor or array-like (the training loop passes a (batch, 1)
            NumPy array); flattened before comparison.
    Returns:
        Batch accuracy as a Python float in [0, 1] (0.0 for an empty batch), so the
        caller can sum it across batches and format it with ``{:.4f}``.
    """
    # softmax is monotonic, so arg-max of the raw logits gives the same prediction
    predicted = torch.argmax(X.detach(), dim=-1).cpu().view(-1)
    if torch.is_tensor(Y):
        target = Y.detach().cpu().view(-1)
    else:
        target = torch.as_tensor(np.asarray(Y)).view(-1)
    if predicted.numel() == 0:
        return 0.0
    return (predicted == target.to(predicted.dtype)).float().mean().item()

In [None]:
# 5-fold cross-validation: for each fold, train a fresh ExplainableModel and save the
# whole model to /kaggle/working/model{fold + 1} whenever its validation accuracy improves.
# (Author's note: best result so far 0.87.)

for fold,(train_idx,valid_idx) in enumerate(kf.split(train_dataset)):
    # Both loaders read the same dataset; the samplers restrict each one to its fold's indices.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_idx)
    # fill_values pads input_ids with 1 and the label / length fields with 0.
    train_dataLoader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_subsampler, collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]))
    valid_dataLoader = DataLoader(train_dataset, batch_size=batch_size, sampler=valid_subsampler, collate_fn=partial(collate_to_max_length, fill_values=[1, 0, 0]))
    best_acc = 0
    # NOTE(review): the previous fold's model and optimizer stay referenced until the
    # reassignments below complete, so peak GPU memory briefly holds two models; consider
    # `del model, optimizer` + torch.cuda.empty_cache() at the end of each fold.
    model = ExplainableModel().to(device)
    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, betas=(0.9, 0.98), lr=2e-5, eps=1e-8)
    model.zero_grad()
    print(f'------------fold no---------{fold + 1}----------------------')
    for epoch_i in range(0, epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0
        train_accuracy = 0
        nb_train_steps = 0
        for batch in tqdm(train_dataLoader):
            batch = tuple(t.to(device) for t in batch)
            # Collated batch layout: (input_ids, labels, length, start_indexs, end_indexs, span_masks)
            sen, label, length, start, end, span = batch
            outputs, a_ij = model(sen, start, end, span)
            # labels are collated as (batch, 1); CrossEntropyLoss expects (batch,)
            y = label.view(-1)
            ce_loss = criterion(outputs, y)
            # The span-attention regulariser (a_ij ** 2) is disabled; the loss is plain CE.
            loss = ce_loss
            total_loss += loss.item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            logits = outputs
            # NOTE(review): the result is summed and formatted with {:.4f} below, so
            # cal_accuracy must return a plain number — confirm.
            tmp_train_accuracy = cal_accuracy(logits, label.to('cpu').numpy())
            train_accuracy += tmp_train_accuracy
            nb_train_steps += 1
        avg_train_loss = total_loss / len(train_dataLoader)
        print("")
        print(epoch_i + 1, "  Average training loss: {0:.4f}".format(avg_train_loss))
        print("  Accuracy: {0:.4f}".format(train_accuracy/(nb_train_steps)))
        # ---- validation pass (no gradients) ----
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        valid_loss = 0
        for batch in tqdm(valid_dataLoader):
            batch = tuple(t.to(device) for t in batch)
            sen, label, length, start, end, span = batch
            with torch.no_grad():     
                outputs, a_ij = model(sen, start, end, span)
            y = label.view(-1)
            ce_loss = criterion(outputs, y)
            loss = ce_loss
            valid_loss += ce_loss.item()
            logits = outputs
            tmp_eval_accuracy = cal_accuracy(logits, label.to('cpu').numpy())
            eval_accuracy += tmp_eval_accuracy
            nb_eval_steps += 1
        avg_valid_loss = valid_loss / len(valid_dataLoader)
        # NOTE(review): mean of per-batch accuracies, so a smaller final batch is over-weighted.
        valid_accuracy = eval_accuracy/(nb_eval_steps)
        # Checkpoint selection is by validation accuracy (not loss).
        if best_acc < valid_accuracy:
            best_acc = valid_accuracy
            torch.save(model, f'/kaggle/working/model{fold + 1}')
            print(f'model{fold + 1} saved')
        print(epoch_i + 1, "  Average valid loss: {0:.4f}".format(avg_valid_loss))
        print("  Accuracy: {0:.4f}".format(valid_accuracy))

In [None]:
test_dataset = SNLIDataset(test, is_train=False)  # unlabelled: items are (input_ids, length)

In [None]:
def collate_test(batch: List[List[torch.Tensor]], max_len: int = None, fill_values: List[float] = None) -> \
    List[torch.Tensor]:
    """Collate unlabelled ``(input_ids, length)`` samples for inference.

    Pads every field to the longest entry in the batch (or to ``max_len`` when given)
    and appends the candidate-span tensors consumed by ExplainableModel.

    Args:
        batch: list of ``(input_ids, length)`` pairs of 1-D tensors; ``input_ids`` must
            contain token id 2 (the separator) and ``length`` holds a single value.
        max_len: optional fixed pad width; asserts that no field is longer.
        fill_values: per-field padding values (0.0 for every field when omitted).
    Returns:
        ``[input_ids, length, start_indexs, end_indexs, span_masks]``: the padded fields
        as ``(batch, width)`` tensors, the flat span start / end positions, and a
        ``(batch, num_spans)`` LongTensor that is 0 for spans inside the real tokens of a
        single sentence and 1000000 otherwise.
    """
    field_lengths = np.array([[len(field) for field in sample] for sample in batch])
    n_samples, n_fields = field_lengths.shape
    fill_values = fill_values or [0.0] * n_fields
    widths = field_lengths.max(axis=0)
    if max_len:
        assert widths.max() <= max_len
        widths = np.ones_like(widths) * max_len

    padded = []
    for k in range(n_fields):
        column = torch.full([n_samples, widths[k]], fill_value=fill_values[k], dtype=batch[0][k].dtype)
        for row, sample in enumerate(batch):
            values = sample[k]
            column[row][: values.shape[0]] = values
        padded.append(column)

    # Every (start, end) pair with 1 <= start <= end <= width - 2.
    seq_width = widths[0]
    spans = [(s, e) for s in range(1, seq_width - 1) for e in range(s, seq_width - 1)]

    span_masks = []
    for input_ids, length in batch:
        sep_pos = input_ids.tolist().index(2)
        last = length.item() - 2
        span_masks.append([
            0 if (1 <= s <= last and 1 <= e <= last and (s > sep_pos or e < sep_pos)) else 1e6
            for s, e in spans
        ])

    padded.append(torch.LongTensor([s for s, _ in spans]))
    padded.append(torch.LongTensor([e for _, e in spans]))
    padded.append(torch.LongTensor(span_masks))
    return padded  # (input_ids, length, start_indexs, end_indexs, span_masks)

In [None]:
# No sampler / shuffle: batches follow test-set order, so predictions line up with submission rows.
test_collate = partial(collate_test, fill_values=[1, 0])  # pad input_ids with 1, length with 0
test_dataLoader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=test_collate)

In [None]:
# Load each fold's best checkpoint and collect its per-example test logits.
folds = kf.get_n_splits()  # same fold count the CV loop trained (5)
pred = []
for i in range(folds) :
    checkpoint_path = f'/kaggle/working/model{i + 1}'
    # These files are whole pickled modules written by this notebook's torch.save(model),
    # so full unpickling is required (torch >= 2.6 defaults to weights_only=True).
    # Never point this at untrusted files: unpickling can execute arbitrary code.
    try:
        model = torch.load(checkpoint_path, map_location=device, weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` argument and always fully unpickles
        model = torch.load(checkpoint_path, map_location=device)
    model.eval()
    result = []
    for batch in tqdm(test_dataLoader):
        batch = tuple(t.to(device) for t in batch)
        sen, length, start, end, span = batch
        with torch.no_grad():
            outputs, a_ij = model(sen, start, end, span)
        # keep logits on the CPU so GPU memory does not grow with the test-set size
        result.extend(outputs.cpu())
    pred.append(result)

In [None]:
# Logit-sum ensemble: zip(*pred) yields one tuple of per-fold logits per test example,
# for any number of folds (previously exactly 5 were hard-coded).
output = [int(torch.argmax(sum(fold_logits))) for fold_logits in zip(*pred)]

In [None]:
import datetime

dt_now = datetime.datetime.now()
print(dt_now)

# Keep only the calendar date (YYYY-MM-DD) for the submission file name.
fname = dt_now.date().isoformat()

In [None]:
label_dict = {"entailment" : 0, "contradiction" : 1, "neutral" : 2}
# Invert by value rather than by key position, so decoding stays correct even if the
# dict literal is reordered (and the key list is not rebuilt for every row).
id_to_label = {label_id: label_name for label_name, label_id in label_dict.items()}
out = [id_to_label[label_id] for label_id in output]

submission["label"] = out

In [None]:
# Sanity-check the decoded labels, then preview the first rows instead of the whole frame.
assert submission["label"].isin(["entailment", "contradiction", "neutral"]).all(), "unexpected label in submission"
submission.head()

In [None]:
# Write the dated submission file, without the DataFrame index column.
csv_path = '/kaggle/working/' + fname + '_1' + '.csv'
submission.to_csv(csv_path, index=False)