In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# This cell does not list them; walk that directory (e.g. with os.walk) to see every file under it.

import os
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Kaggle input/output locations.
TRAIN_PATH = '/kaggle/input/feedback-prize-2021/train'
TEST_PATH = '/kaggle/input/feedback-prize-2021/test'
TRAIN_LABEL = '/kaggle/input/feedback-prize-2021/train.csv'
SUBMISSION = '/kaggle/working/submission.csv'
# Cache directory handed to `from_pretrained` when LONGBERT is enabled. The model
# referenced this name without it ever being defined; None selects the library default.
MODEL_CACHE_DIR = None
# Discourse type ids; 'non' (8) is used for text outside any labelled discourse.
LABEL_2_ID = {'PAD':0, 'Claim': 1, 'Evidence': 2, 'Position': 3,
              'Concluding Statement': 4, 'Lead': 5, 'Counterclaim': 6, 'Rebuttal': 7, 'non': 8}
# BIO tags: B{t}/I{t} for discourse type id t (1-7), 'O' for unlabelled text.
LABEL_BIO = {'PAD':0, 'B1': 1, 'I1': 2, 'B2': 3, 'I2': 4, 'B3': 5, 'I3': 6, 'B4': 7, 'I4': 8, 'B5': 9, 'I5': 10,
             'B6': 11, 'I6': 12, 'B7': 13, 'I7': 14, 'O': 15}
# Per-token boundary tags (B / E / O plus padding) and the B/O-only variant.
BOUNDARY_LABEL = {'PAD':0, 'B': 1, 'E': 2, 'O': 3}
BOUNDARY_LABEL_UNIDIRECTION = {'PAD':0, 'B': 1, 'O': 3}
TEST_SIZE = 0
DEV_SIZE = 0.1
MAX_LEN = 512                 # maximum number of tokens per sliding window
NUM_LABELS = len(LABEL_BIO)   # derived so it cannot drift from the BIO tag set (16)
BATCH_SIZE = 6
LEARNING_RATE = 5e-5
NUM_EPOCH = 1
LSTM_HIDDEN = 100
BIAFFINE_DROPOUT = 0.5
BASELINE = False   # True: single linear BIO head instead of the multi-task heads
LONGBERT = False   # True: allenai/longformer-base-4096 instead of the local RoBERTa-base


In [None]:
from unicodedata import bidirectional
import transformers
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import profiler
from transformers import AutoModel
from typing import Optional

class TModel(nn.Module):
    """Pretrained transformer encoder with either a linear BIO head (``BASELINE``)
    or a multi-task head made of boundary, type and segment branches.

    Args:
        config: transformers model config passed to ``AutoModel.from_pretrained``;
            ``config.hidden_size`` sizes every head.
    """

    def __init__(self, config):
        super().__init__()
        if LONGBERT:
            self.transformer = AutoModel.from_pretrained(
                pretrained_model_name_or_path="allenai/longformer-base-4096", cache_dir=MODEL_CACHE_DIR, config=config)
        else:
            self.transformer = AutoModel.from_pretrained(
                pretrained_model_name_or_path="/kaggle/input/robertabase/", config=config)
        self.dropout = nn.Dropout()
        self.relu = nn.ReLU(True)
        if BASELINE:
            self.ner = nn.Linear(config.hidden_size, len(LABEL_BIO))
        else:
            # Boundary branch: BiLSTM encoder -> fixed trigram pooling -> unidirectional
            # decoder, scored pairwise against the encoder states by a biaffine layer.
            self.boundary_encoder = nn.LSTM(bidirectional=True, input_size=config.hidden_size, hidden_size=LSTM_HIDDEN, batch_first=True)
            self.boundary_decoder = nn.LSTM(bidirectional=False, input_size=LSTM_HIDDEN*2, hidden_size=LSTM_HIDDEN, batch_first=True)
            self.boundary_biaffine = BoundaryBiaffine(LSTM_HIDDEN, LSTM_HIDDEN*2, len(BOUNDARY_LABEL_UNIDIRECTION))
            self.boundary_final0 = nn.Linear(config.hidden_size, LSTM_HIDDEN*2)
            self.boundary_final1 = nn.Linear(LSTM_HIDDEN*2, LSTM_HIDDEN*2)
            self.boundary_fc = nn.Linear(LSTM_HIDDEN*2, len(BOUNDARY_LABEL))

            # No *2 since the boundary decoder can only be unidirectional.
            # NOTE(review): seg_final0/seg_final1 are not used by forward(); they are kept so
            # existing checkpoints keep loading with the same parameter names.
            self.seg_final0 = nn.Linear(config.hidden_size, LSTM_HIDDEN)
            self.seg_final1 = nn.Linear(LSTM_HIDDEN, LSTM_HIDDEN)
            self.boundary = nn.ModuleList([self.boundary_encoder, self.boundary_decoder, self.boundary_biaffine, self.boundary_final0, self.boundary_final1, self.boundary_fc])
            # Type branch: BiLSTM over the encoder output, classifying discourse type per token.
            self.type_lstm = nn.LSTM(bidirectional=True, input_size=config.hidden_size, hidden_size=LSTM_HIDDEN, batch_first=True)
            self.type_final0 = nn.Linear(config.hidden_size, LSTM_HIDDEN*2)
            self.type_final1 = nn.Linear(LSTM_HIDDEN*2, LSTM_HIDDEN*2)
            self.type_fc = nn.Linear(LSTM_HIDDEN*2, len(LABEL_2_ID))
            self.type_predict = nn.ModuleList([self.type_lstm, self.type_final0, self.type_final1, self.type_fc])
            # Final BIO head over [encoder (hidden) | boundary (2H) | type (2H) | segment (H)].
            self.ner_final = nn.Linear(LSTM_HIDDEN*5+config.hidden_size, len(LABEL_BIO))
            self.ner = nn.ModuleList([self.seg_final0, self.seg_final1, self.ner_final])
        # Fixed all-ones kernel (padding=1, no bias): every output channel is the sum of
        # all input channels over a 3-token window. It is not trained.
        self.get_trigram = nn.Conv1d(LSTM_HIDDEN*2, LSTM_HIDDEN*2, 3, padding=1, bias=False)
        self.get_trigram.weight = torch.nn.Parameter(torch.ones([LSTM_HIDDEN*2, LSTM_HIDDEN*2, 3]), requires_grad=False)
        # Call the method; assigning to `requires_grad_` would shadow it with a bool.
        self.get_trigram.requires_grad_(False)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                ):
        """Run the encoder and the prediction heads.

        All arguments are forwarded unchanged to the underlying transformer.

        Returns:
            If ``BASELINE``: the BIO logits from ``self.ner``.
            Otherwise: a tuple ``(ner_logits, boundary_logits, type_logits, seg_matrix)``.
            ``seg_matrix`` is the pairwise biaffine score between decoder and encoder
            states. Presumably its shape is (batch, seq, seq, len(BOUNDARY_LABEL_UNIDIRECTION)),
            assuming the encoder output is (batch, seq, hidden) — TODO confirm.
        """
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        if BASELINE:
            return self.ner(sequence_output)

        boundary_hidden = self.boundary_encoder(sequence_output)[0]
        # Conv1d expects channels first, so transpose around the trigram pooling.
        seg_result = self.get_trigram(boundary_hidden.transpose(1, 2)).transpose(1, 2)
        seg_result = self.boundary_decoder(seg_result)[0]
        seg_matrix = self.boundary_biaffine(seg_result, boundary_hidden)
        # Gate each branch's hidden states with logsigmoid(W0 * encoder + W1 * branch).
        boundary_result = F.logsigmoid(self.boundary_final0(sequence_output)+self.boundary_final1(boundary_hidden)).mul(boundary_hidden)
        type_hidden = self.type_lstm(sequence_output)[0]
        type_result = F.logsigmoid(self.type_final0(sequence_output)+self.type_final1(type_hidden)).mul(type_hidden)
        # NOTE(review): a gated segment feature (seg_final0/seg_final1) used to be computed
        # here, but it was overwritten immediately and never used, so it was dropped. The
        # ungated seg_result is what feeds ner_final. Confirm whether the gated version
        # was meant to be concatenated instead.
        ner_result = self.ner_final(torch.cat([sequence_output, boundary_result, type_result, seg_result], dim=-1))
        return ner_result, self.boundary_fc(boundary_hidden), self.type_fc(type_hidden), seg_matrix


class BoundarySeg(nn.Module):
    """Aggregate boundary hidden states over candidate spans, weighted by span scores.

    For every start position ``j`` the output is
    ``sum over i >= j of cat(h_i, h_j) * span_adjacency[:, j, i]``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, span_adjacency, bound_hidden):
        """
        Args:
            span_adjacency: pairwise span scores indexed as ``[batch, j, i]``. Each slice
                ``span_adjacency[:, j, i]`` must broadcast against a ``(batch, 2*D)``
                tensor, e.g. shape ``(batch, L, L, 1)``. NOTE(review): confirm the
                intended shape.
            bound_hidden: boundary states of shape ``(batch, L, D)``.

        Returns:
            Tensor of shape ``(batch, L, 2*D)``.
        """
        # Iterate over the actual sequence length. Using a global maximum length here
        # used to index past the end of shorter inputs. The leftover profiler (which
        # printed a table on every call) has been removed.
        seq_len = bound_hidden.size(1)
        if seq_len == 0:
            return bound_hidden.new_zeros(bound_hidden.size(0), 0, 2 * bound_hidden.size(2))
        temp = []
        for j in range(seq_len):
            j_sum = []
            for i in range(j, seq_len):
                result = torch.cat([bound_hidden[:, i], bound_hidden[:, j]], 1)
                result = result * span_adjacency[:, j, i]
                j_sum.append(result)
            temp.append(torch.sum(torch.stack(j_sum, dim=0), dim=0))
        return torch.stack(temp, 1)


class PairwiseBilinear(nn.Module):
    """ A bilinear module that deals with broadcasting for efficient memory usage.
    Input: tensors of sizes (N x L1 x D1) and (N x L2 x D2)
    Output: tensor of size (N x L1 x L2 x O)

    Args:
        input1_size: D1, feature size of the first input.
        input2_size: D2, feature size of the second input.
        output_size: O, number of output channels per (L1, L2) pair.
        bias: add a learned per-channel bias if True, the constant 0 otherwise.
    """
    def __init__(self, input1_size, input2_size, output_size, bias=True):
        super().__init__()
        self.input1_size = input1_size
        self.input2_size = input2_size
        self.output_size = output_size
        # Zero-initialised (D1, D2, O) weight.
        self.weight = nn.Parameter(torch.zeros(input1_size, input2_size, output_size), requires_grad=True)
        if bias:
            self.bias = nn.Parameter(torch.zeros(output_size), requires_grad=True)
        else:
            self.bias = 0

    def forward(self, input1, input2):
        n_batch, len1 = input1.size(0), input1.size(1)
        len2, feat2 = input2.size(1), input2.size(2)

        # ((N*L1) x D1) @ (D1 x (D2*O)) -> (N*L1) x (D2*O)
        flat1 = input1.contiguous().view(-1, input1.size(-1))
        projected = flat1.mm(self.weight.view(-1, self.input2_size * self.output_size))

        # Reinterpret each row's D2*O values as O x D2 and contract with input2 over D2:
        # (N x (L1*O) x D2) bmm (N x D2 x L2) -> N x (L1*O) x L2.
        # The weight's (D2, O) memory order is read as (O, D2). Because the weight is
        # learned, this only permutes the parameterisation; keep it for checkpoint compatibility.
        scores = projected.view(n_batch, len1 * self.output_size, feat2).bmm(input2.transpose(1, 2))

        # N x (L1*O) x L2 -> N x L1 x L2 x O, then add the per-channel bias.
        scores = scores.view(n_batch, len1, self.output_size, len2).transpose(2, 3).contiguous()
        return scores + self.bias

class BoundaryBiaffine(nn.Module):
    """Biaffine scorer between two sequences.

    score = bilinear(input1, input2) + U(input1) + V(input2), broadcast over all pairs.
    Unlike the standard pairwise biaffine, the linear term U sees only input1 and V
    sees only input2.
    """
    def __init__(self, input1_size, input2_size, output_size):
        super().__init__()
        self.W_bilin = PairwiseBilinear(input1_size, input2_size, output_size)
        self.U = nn.Linear(input1_size, output_size)
        self.V = nn.Linear(input2_size, output_size)

    def forward(self, input1, input2):
        # Pairwise bilinear scores; presumably (N, L1, L2, O) for (N, L, D) inputs.
        pairwise = self.W_bilin(input1, input2)
        # Linear term of input1 (d_j), broadcast along the input2 axis.
        left_term = self.U(input1).unsqueeze(2)
        # Linear term of input2 (h_i^Bdy), broadcast along the input1 axis.
        right_term = self.V(input2).unsqueeze(1)
        return pairwise + left_term + right_term

class FocalLoss(torch.nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.
    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[torch.Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.
        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        Raises:
            ValueError: if ``reduction`` is not one of the supported values.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')

        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

        self.log_softmax = torch.nn.LogSoftmax(-1)
        self.nll_loss = torch.nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss.

        Returns:
            A scalar tensor for 'mean'/'sum', or a 1-D tensor with one loss per
            non-ignored element for 'none'. When every label equals ``ignore_index``,
            the result is a zero scalar (an empty 1-D tensor for 'none'). That tensor is
            derived from ``x``, so ``.item()`` and ``.backward()`` still work.
        """
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,); reshape also handles non-contiguous y
            y = y.reshape(-1)

        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        x = x[unignored_mask]
        if len(y) == 0:
            # x is now (0, C). Summing it keeps the device, dtype and autograd graph,
            # whereas a bare Python float would break .item() / .backward() in callers.
            empty = x.sum(dim=-1)
            return empty if self.reduction == 'none' else empty.sum()

        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = self.log_softmax(x)
        ce = self.nll_loss(log_p, y)

        # Log-probability of the true class in each row. gather stays on x's device,
        # unlike indexing with a CPU arange.
        log_pt = log_p.gather(1, y.unsqueeze(1)).squeeze(1)

        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma

        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss

In [None]:
from posixpath import split
import pandas as pd
import nltk
import numpy as np
import torch
from torch.utils.data import TensorDataset, Dataset

import re
import itertools
import six

# Heuristic sentence-edge detectors used during tokenization.
# re_bos: the text starts (after an optional whitespace char and one optional
# non-word char) with a capitalised word or the pronoun "I".
re_bos = re.compile(
    r'^\s?\W?(?:(?:[A-Z]{1}[a-z]+)|(?:I))\s?[a-z]*'
)
# re_eos: the text ends with '.', '?' or '!', optionally followed by a single quote,
# a double quote and trailing whitespace.
re_eos = re.compile(
    r'[?\.!]\'?\"?\s*$'
)


def _tokenize_unlabelled_span(text: str, tokenizer, prev_eos: bool, prev_shift: int, prev_label: int):
    """Tokenize text lying outside any labelled discourse, sentence by sentence.

    Args:
        text: the span's text (may be empty, in which case nothing is emitted).
        tokenizer: tokenizer whose ``tokenizer(sent)['input_ids']`` is assumed to start
            with a BOS/CLS id and end with an EOS/SEP id — TODO confirm for non-RoBERTa models.
        prev_eos: whether the text emitted just before this span ended a sentence.
        prev_shift: surplus word count carried over from the previous discourse.
        prev_label: type id of the previous discourse. The first sentence inherits it
            when that sentence has exactly ``prev_shift`` words.

    Returns:
        (ids, labels, prev_eos): the token ids, one type id per token (8, i.e. 'non',
        unless the carry-over applies), and the updated end-of-sentence flag.
    """
    ids = []
    labels = []
    for i, sent in enumerate(nltk.tokenize.sent_tokenize(text)):
        sent_ids = tokenizer(sent)['input_ids']
        # Drop the leading special id when this fragment does not look like a sentence
        # start and the previous text did not end a sentence.
        if (not re_bos.search(sent)) and (not prev_eos):
            sent_ids = sent_ids[1:]
        # Drop the trailing special id when the fragment has no sentence-final punctuation.
        if not re_eos.search(sent):
            sent_ids = sent_ids[:-1]
            prev_eos = False
        else:
            prev_eos = True
        if i == 0 and prev_shift == len(sent.split(' ')):
            labels.extend([prev_label]*len(sent_ids))
        else:
            labels.extend([8]*len(sent_ids))
        ids.extend(sent_ids)
    return ids, labels, prev_eos


def preprocessing_train(labels: pd.DataFrame, raw_text: str, tokenizer) -> "tuple[list]":
    """
    Tokenization for training. Insert [NP] tokens at new paragraph

    Args:
        labels: the DataFrame containing label information, one row per discourse with
            'predictionstring', 'discourse_text' and 'discourse_type' columns.
        raw_text: the raw text input as string
        tokenizer: tokenizer used for ids and for ``convert_ids_to_tokens``.

    Returns:
        new_segements: list of encoded token-id lists, one per segment (discourse or
            unlabelled gap), in document order
        seg_labels: per-token type ids, aligned with new_segements
        subword_mask: for every token, the index of the ' '-split word it belongs to
            (-1 for '<s>' / '</s>'), for post-processing
        err: True if a merged subword run exceeded 50 characters, which suggests the
            token-to-word alignment went wrong
    """
    err = False
    new_segements = []
    seg_labels = []
    subword_mask = []
    prev_end = -1
    # Defined up front so that a document with no labelled discourse is treated as
    # entirely unlabelled text, instead of raising NameError after the loop.
    end = -1
    prev_shift = 0
    prev_label = -1
    prev_eos = True
    raw_text = raw_text.replace('\xa0', ' ')
    raw_text = raw_text.replace('Â', '')
    # Newline runs become ' [NP]' glued to the following word. This presumably keeps
    # word indices aligned with 'predictionstring' — TODO confirm. After this step no
    # '\n' is left in any element of `splitted`.
    splitted = re.sub('\n+', ' [NP]', raw_text).split(' ')

    for _, segment in labels.iterrows():
        positions = [int(e) for e in segment['predictionstring'].split(' ')]
        start = positions[0]
        end = positions[-1]

        # Emit any unlabelled text between the previous discourse end and this start,
        # or before the first discourse.
        if prev_end < start or (prev_end == -1 and start != 0):
            gap_words = splitted[: start] if prev_end == -1 else splitted[prev_end: start]
            temp_ids, temp_label, prev_eos = _tokenize_unlabelled_span(
                ' '.join(gap_words), tokenizer, prev_eos, prev_shift, prev_label)
            if len(temp_ids) != 0 and len(temp_label) != 0:
                new_segements.append(temp_ids)
                seg_labels.append(temp_label)

        # The discourse itself; '[NP]' sentences keep their special ids.
        type_id = LABEL_2_ID[segment['discourse_type']]
        temp_ids = []
        temp_label = []
        for sent in nltk.tokenize.sent_tokenize(' '.join(splitted[start:end+1])):
            seg_ids = tokenizer(sent)['input_ids']
            # Remove the leading/trailing special id when the segment does not start/end a sentence.
            if (not re_bos.search(sent) and not prev_eos) and sent != '[NP]':
                seg_ids = seg_ids[1:]
            if (not re_eos.search(sent)) and sent != '[NP]':
                seg_ids = seg_ids[:-1]
                prev_eos = False
            else:
                prev_eos = True
            temp_label.extend([type_id]*len(seg_ids))
            temp_ids.extend(seg_ids)
        if len(temp_ids) != 0 and len(temp_label) != 0:
            seg_labels.append(temp_label)
            new_segements.append(temp_ids)

        # When discourse_text has more words than predictionstring positions, remember
        # the surplus. If the next unlabelled sentence has exactly that many words, it
        # inherits this discourse's type.
        discourse_words = segment['discourse_text'].split(' ')
        if len(positions) < len(discourse_words) and segment['discourse_text'] != '':
            prev_shift = len(discourse_words) - len(positions)
            # Taken from this segment directly. The old code read the last sentence's
            # label list, which was stale or undefined when the span produced no sentences.
            prev_label = type_id
        else:
            prev_shift = 0
        prev_end = end+1

    # Any text after the last discourse end (or the whole text when there are no labels).
    if end+1 < len(splitted):
        tail_words = [e for e in splitted[end+1:] if e != '']
        if len(tail_words) > 0:
            temp_ids, temp_label, prev_eos = _tokenize_unlabelled_span(
                ' '.join(tail_words), tokenizer, prev_eos, prev_shift, prev_label)
            if len(temp_ids) != 0 and len(temp_label) != 0:
                new_segements.append(temp_ids)
                seg_labels.append(temp_label)

    tokenized = []
    for e in new_segements:
        tokenized.extend(tokenizer.convert_ids_to_tokens(e))

    # Map each subword back to the index of the ' '-split word it belongs to.
    tok_counter = 0
    hold = ''
    for tok in tokenized:
        # Special tokens get subword mask -1; '[NP]' is treated as part of a word.
        if tok in ['<s>', '</s>']:
            subword_mask.append(-1)
            continue
        # RoBERTa and Longformer tokenizers use this char to mark the start of a new word.
        if tok.startswith('Ġ'):
            tok = tok[1:]
        if splitted[tok_counter] == tok:
            # The subword is the whole word.
            subword_mask.append(tok_counter)
            tok_counter+=1
            hold = ''
        else:
            # Accumulate subwords until they spell the word (e.g. "Abc"+"def"+"gh");
            # each of them gets the same word position.
            hold+=tok
            subword_mask.append(tok_counter)
            if splitted[tok_counter] == hold:
                hold = ''
                tok_counter+=1
        # A merged run longer than 50 chars most likely means the alignment broke.
        if len(hold)>50:
            err = True
    assert len(subword_mask) == len(list(itertools.chain.from_iterable(new_segements))) == len(list(itertools.chain.from_iterable(seg_labels))), "Length of ids/labels/subword_mask mismatch"
    return new_segements, seg_labels, subword_mask, err


def preprocessing_test(raw_text: str, tokenizer) -> "tuple[list]":
    """
    Tokenize an essay for inference (no ground truth).

    Non-breaking spaces become plain spaces, and every newline run becomes a standalone
    ' [NP] ' token. Each sentence is then tokenized separately, keeping its own special
    tokens. Downstream decoding must account for the [NP] words.

    Returns:
        ids: concatenated token ids of all sentences
        subword_mask: for each id, the index of the ' '-split word it belongs to, or -1
            for '<s>' / '</s>'
        err: True if a merged subword run exceeded 50 characters, which suggests the
            token-to-word alignment went wrong
    """
    text = re.sub('\n+', ' [NP] ', raw_text.replace('\xa0', ' '))

    ids = []
    for sentence in nltk.tokenize.sent_tokenize(text):
        ids += tokenizer(sentence)['input_ids']

    # `text` has no '\n' left, so it can be split directly.
    words = text.split(' ')
    subword_mask = []
    err = False
    word_idx = 0
    pending = ''
    for piece in tokenizer.convert_ids_to_tokens(ids):
        if piece in ['<s>', '</s>']:
            subword_mask.append(-1)
            continue
        # RoBERTa / Longformer mark a word start with this character.
        if piece.startswith('Ġ'):
            piece = piece[1:]
        # Every subword is tagged with the word it falls in, before any advance.
        subword_mask.append(word_idx)
        if words[word_idx] == piece:
            word_idx += 1
            pending = ''
        else:
            # Glue subwords together until they spell out the current word.
            pending += piece
            if words[word_idx] == pending:
                pending = ''
                word_idx += 1
        if len(pending) > 50:
            err = True
    return ids, subword_mask, err


class SlidingWindowFeature():
    """Per-window slices of a training document's token-level sequences.

    ``sliding_window`` is a list of ``[start, end)`` token ranges. When it is ``None``
    the whole document forms a single window: the original sequences are stored
    unsliced, and ``sliding_window`` becomes ``[[0, len(input_ids)]]``.
    """

    def __init__(self, doc_id, input_ids, labels_type, labels_bio, labels_boundary, subword_masks, cls_pos, sliding_window, tokenizer=None) -> None:
        self.doc_id = doc_id
        self.tokenizer = tokenizer
        self.cls_pos = cls_pos
        if sliding_window is None:
            self.sliding_window = [[0, len(input_ids)]]

            def windows_of(seq):
                # Single window: keep the caller's sequence object as-is.
                return [seq]
        else:
            self.sliding_window = sliding_window

            def windows_of(seq):
                return [seq[lo:hi] for lo, hi in sliding_window]

        self.input_ids = windows_of(input_ids)
        self.subword_masks = windows_of(subword_masks)
        self.labels_type = windows_of(labels_type)
        self.labels_bio = windows_of(labels_bio)
        self.labels_boundary = windows_of(labels_boundary)
        self.num_windows = len(self.sliding_window)


class SlidingWindowFeatureTest():
    """Per-window slices of an inference document's token ids and subword masks.

    ``sliding_window`` is a list of ``[start, end)`` token ranges. ``None`` means a
    single window covering the whole document: the original sequences are kept
    unsliced, and ``sliding_window`` becomes ``[[0, len(input_ids)]]``.
    """

    def __init__(self, doc_id, input_ids, subword_masks, cls_pos, sliding_window, tokenizer=None) -> None:
        self.doc_id = doc_id
        self.tokenizer = tokenizer
        self.cls_pos = cls_pos
        if sliding_window is None:
            self.sliding_window = [[0, len(input_ids)]]
            self.input_ids = [input_ids]
            self.subword_masks = [subword_masks]
        else:
            self.sliding_window = sliding_window
            self.input_ids = [input_ids[lo:hi] for lo, hi in sliding_window]
            self.subword_masks = [subword_masks[lo:hi] for lo, hi in sliding_window]
        self.num_windows = len(self.sliding_window)


class DocFeature():
    """Tokenized representation of one essay together with its sliding windows.

    For ``'train'`` it builds token ids, per-token type labels, BIO labels, boundary
    labels and training windows from the ground-truth label rows. For ``'test'`` it
    builds only token ids, subword masks and inference windows.

    Args:
        doc_id: essay id.
        raw_text: full essay text.
        train_or_test: ``'train'`` or ``'test'``; anything else raises ``NameError``.
        seg_labels: label rows for this essay (train only).
        tokenizer: tokenizer used for encoding; must expose ``cls_token_id``.
    """

    def __init__(self, doc_id: str, raw_text: str, train_or_test: str, seg_labels=None, tokenizer=None) -> None:
        self.doc_id = doc_id
        self.tokenizer = tokenizer
        if train_or_test == 'train':
            self.input_ids, self.seg_labels, self.subword_masks, self.err = preprocessing_train(
                labels=seg_labels, raw_text=raw_text, tokenizer=tokenizer)
            # Per-segment BIO tags, flattened to one tag per token.
            self.labels_bio = [self.convert_label_to_bio(label, len(
                seg)) for seg, label in zip(self.input_ids, self.seg_labels)]
            self.labels_bio = list(
                itertools.chain.from_iterable(self.labels_bio))
            self.labels = list(itertools.chain.from_iterable(self.seg_labels))
            self.input_ids = list(
                itertools.chain.from_iterable(self.input_ids))
            self.boundary_pos = self.get_boundary_pos()
            self.cls_pos = [index for index, element in enumerate(
                self.input_ids) if element == tokenizer.cls_token_id]
            self.count = self.get_sent_level_label()
            self.boundary_label = self.convert_label_to_bound()
            self.sliding_window = self.create_sliding_window_train()
        elif train_or_test == 'test':
            self.input_ids, self.subword_masks, self.err = preprocessing_test(
                raw_text, tokenizer=tokenizer)
            self.cls_pos = [index for index, element in enumerate(
                self.input_ids) if element == tokenizer.cls_token_id]
            self.sliding_window = self.create_sliding_window_test()
        else:
            raise NameError('Should be either train/test')

    def convert_label_to_bio(self, label, seq_len):
        """BIO ids for one segment: ``B{t}`` followed by ``I{t}`` for discourse type
        ``t``, or all ``O`` when the segment is unlabelled (type 8)."""
        if label[0] != 8:
            temp = [LABEL_BIO[f'I{label[0]}']]*seq_len
            temp[0] = LABEL_BIO[f'B{label[0]}']
        else:
            temp = [LABEL_BIO['O']]*seq_len
        return temp

    def convert_label_to_bound(self):
        """Per-token boundary ids from ``self.labels_bio``.

        1 (B) at every BIO begin tag, with the token just before it overwritten to
        2 (E); 0 for PAD; 3 (O) otherwise.
        """
        bound = []
        for i, e in enumerate(self.labels_bio):
            if e in [1, 3, 5, 7, 9, 11, 13]:
                bound.append(1)
                if i > 0:
                    bound[-2] = 2
            elif e == 0:
                bound.append(0)
            else:
                bound.append(3)
        return bound

    def get_sent_level_label(self):
        """Count the spans between consecutive CLS positions that contain at least two
        distinct discourse types, ignoring type 8."""
        prev_cls = 0
        labels = list(itertools.chain.from_iterable(self.seg_labels))
        count = 0
        for pos in self.cls_pos:
            distinct = set(labels[prev_cls:pos])
            if (8 in distinct and len(distinct) > 2) or (8 not in distinct and len(distinct) > 1):
                count += 1
            prev_cls = pos
        return count

    def get_boundary_pos(self):
        """Cumulative token offsets at which each segment ends."""
        return list(itertools.accumulate(len(seg) for seg in self.seg_labels))

    def create_sliding_window_train(self):
        """Split a training document into windows for the model.

        Documents of at most MAX_LEN tokens get a single window. Otherwise window edges
        are placed preferably at positions that are both a segment boundary and a CLS
        position, with plain CLS positions as a fallback.
        """
        if len(self.input_ids) <= MAX_LEN:
            return SlidingWindowFeature(doc_id=self.doc_id, input_ids=self.input_ids, labels_type=self.labels, labels_bio=self.labels_bio,
                                        labels_boundary=self.boundary_label, subword_masks=self.subword_masks, cls_pos=self.cls_pos, sliding_window=None)
        # Candidate cut points: segment boundaries that coincide with a CLS token (a
        # window may only start at a CLS token), plus the document end.
        bound_cls_pos = list(
            set(self.boundary_pos).intersection(set(self.cls_pos)))
        bound_cls_pos.append(len(self.input_ids))
        bound_cls_pos.sort()
        slice_pos_list = []
        slice_start = 0
        slice_end = -1
        # When the last candidate boundary is below MAX_LEN, slice there.
        # NOTE(review): bound_cls_pos always contains len(self.input_ids) > MAX_LEN at this
        # point, so this branch appears unreachable.
        if max(bound_cls_pos) < MAX_LEN:
            slice_pos_list.append([0, bound_cls_pos[-1]])
            if len(bound_cls_pos) > 1:
                # Was `append([x], n)`, which raises TypeError; append a [start, end] pair.
                slice_pos_list.append(
                    [bound_cls_pos[-2], len(self.input_ids)])
        for pos in bound_cls_pos:
            if (pos - slice_start) > MAX_LEN:
                # Two adjacent candidates are further apart than MAX_LEN, or the first
                # candidate is already beyond MAX_LEN: cut at a plain CLS position.
                if (slice_end == -1 or slice_end == slice_start) or (bound_cls_pos.index(slice_end) == 0):
                    prev = 0
                    # NOTE(review): compares the absolute position idx with MAX_LEN rather
                    # than idx - slice_start; confirm this is intended.
                    for i, idx in enumerate(self.cls_pos[self.cls_pos.index(slice_start):]):
                        if idx > MAX_LEN:
                            break
                        prev = idx
                    slice_end = prev
                    slice_pos_list.append([slice_start, slice_end])
                    # NOTE(review): i is relative to the sliced cls_pos list but indexes the
                    # full list here. Negative values wrap around instead of raising IndexError.
                    try:
                        slice_start = self.cls_pos[i-2]
                    except IndexError:
                        slice_start = self.cls_pos[i-1]
                # Normal case: the n-th candidate is more than MAX_LEN past slice_start,
                # so the (n-1)-th candidate becomes slice_end.
                else:
                    # If the (n-1)-th candidate is itself too far from slice_start, move
                    # slice_start to the first CLS within MAX_LEN of it.
                    if slice_end-slice_start > MAX_LEN:
                        for idx in self.cls_pos[self.cls_pos.index(slice_start):]:
                            if slice_end-idx <= MAX_LEN:
                                break
                        slice_start = idx
                    # If the window would be short (< 150 tokens), extend slice_end to a CLS
                    # before the next candidate. Skip this at the document end, where no
                    # candidate follows.
                    if slice_end-slice_start < 150 and pos != bound_cls_pos[-1]:
                        next_start = bound_cls_pos[bound_cls_pos.index(
                            slice_end)+1]
                        candidate_end = self.cls_pos[self.cls_pos.index(
                            slice_end): self.cls_pos.index(next_start)]
                        if next_start - slice_start <= MAX_LEN:
                            # `len - len // 3` ran past the end for fewer than 3
                            # candidates; use the last candidate in that case.
                            if len(candidate_end) > 2:
                                slice_end = candidate_end[len(
                                    candidate_end)-len(candidate_end) // 3]
                            else:
                                slice_end = candidate_end[-1]
                        else:
                            for idx in candidate_end:
                                if idx-slice_start > MAX_LEN:
                                    break
                                slice_end = idx
                        slice_pos_list.append([slice_start, slice_end])
                        slice_start = candidate_end[len(
                            candidate_end) // 3]
                    else:
                        se_index = bound_cls_pos.index(slice_end)
                        slice_pos_list.append([slice_start, slice_end])
                        slice_start = bound_cls_pos[:se_index][-1]
            slice_end = pos
        # Ensure the last window reaches the end of the document.
        if slice_pos_list[-1][1] != len(self.input_ids):
            if slice_start != slice_pos_list[-1][0]:
                if len(self.input_ids) - slice_start:
                    for idx in self.cls_pos[self.cls_pos.index(slice_start):]:
                        if slice_end-idx <= MAX_LEN:
                            break
                    slice_start = idx
                slice_pos_list.append([slice_start, len(self.input_ids)])
            else:
                for idx in self.cls_pos[self.cls_pos.index(slice_start):]:
                    if slice_end-idx <= MAX_LEN:
                        break
                slice_start = idx
                slice_pos_list.append([slice_start, len(self.input_ids)])
        return SlidingWindowFeature(doc_id=self.doc_id, input_ids=self.input_ids, labels_type=self.labels, labels_bio=self.labels_bio,
                                    labels_boundary=self.boundary_label, subword_masks=self.subword_masks, cls_pos=self.cls_pos, sliding_window=slice_pos_list)

    def create_sliding_window_test(self):
        """Split an inference document into overlapping windows cut at CLS positions."""
        if len(self.input_ids) <= MAX_LEN:
            return SlidingWindowFeatureTest(doc_id=self.doc_id, input_ids=self.input_ids, subword_masks=self.subword_masks, cls_pos=self.cls_pos, sliding_window=None)
        slice_pos_list = []
        slice_start = 0
        slice_end = -1
        if len(self.cls_pos) == 1:
            # NOTE(review): only the first MAX_LEN tokens get a window here; the rest of
            # the document receives no prediction.
            return SlidingWindowFeatureTest(doc_id=self.doc_id, input_ids=self.input_ids, subword_masks=self.subword_masks, cls_pos=self.cls_pos, sliding_window=[[0, MAX_LEN]])
        if max(self.cls_pos) <= MAX_LEN:
            slice_pos_list.append([0, self.cls_pos[-1]])
            try:
                slice_pos_list.append(
                    [self.cls_pos[-4], len(self.input_ids)])
            except IndexError:
                slice_pos_list.append(
                    [self.cls_pos[-2], len(self.input_ids)])
        else:
            for i, pos in enumerate(self.cls_pos):
                if (pos-slice_start) > MAX_LEN:
                    slice_pos_list.append([slice_start, slice_end])
                    # The next window starts about two-thirds of the way through the
                    # sentences of the window just closed, so windows overlap.
                    se_index = self.cls_pos.index(slice_end)
                    ss_index = self.cls_pos.index(slice_start)
                    temp = self.cls_pos[ss_index:se_index]
                    if len(temp) > 2:
                        slice_start = temp[len(temp) - (len(temp)//3)]
                    else:
                        slice_start = temp[-1]
                slice_end = pos
                if i == len(self.cls_pos)-1:
                    slice_pos_list.append([slice_start, slice_end])
                    se_index = self.cls_pos.index(slice_end)
                    ss_index = self.cls_pos.index(slice_start)
                    temp = self.cls_pos[ss_index:se_index]
                    if len(temp) > 2:
                        slice_start = temp[len(temp) - (len(temp)//3)]
                    else:
                        slice_start = temp[-1]
            if slice_pos_list[-1][1] != len(self.input_ids):
                if slice_start != slice_pos_list[-1][0]:
                    if len(self.input_ids) - slice_start:
                        # NOTE(review): slice_end is the last CLS position, not the document
                        # end, so the final window may slightly exceed MAX_LEN.
                        for idx in self.cls_pos[self.cls_pos.index(slice_start):]:
                            if slice_end-idx <= MAX_LEN:
                                break
                        slice_start = idx
                    slice_pos_list.append(
                        [slice_start, len(self.input_ids)])
        return SlidingWindowFeatureTest(doc_id=self.doc_id, input_ids=self.input_ids, subword_masks=self.subword_masks, cls_pos=self.cls_pos, sliding_window=slice_pos_list)


def pad_sequences(sequences, maxlen=None, dtype='int32',
                  padding='pre', truncating='pre', value=0.):
    """Pad and/or truncate a list of sequences into one dense NumPy array.

    Same contract as Keras' ``pad_sequences``.

    Args:
        sequences: list (anything with ``__len__``) of sequences. Elements may be
            multi-dimensional; their trailing dimensions must match those of the
            first non-empty sequence.
        maxlen: output length. ``None`` means the longest input length (0 when
            ``sequences`` is empty).
        dtype: NumPy dtype of the output array.
        padding: ``'pre'`` or ``'post'`` - put ``value`` before or after each sequence.
        truncating: ``'pre'`` or ``'post'`` - drop elements from the start or the end
            of sequences longer than ``maxlen``.
        value: fill value for padded positions.

    Returns:
        ``np.ndarray`` of shape ``(len(sequences), maxlen) + sample_shape``.

    Raises:
        ValueError: non-iterable input or elements, unknown ``padding`` /
            ``truncating``, a sample whose trailing shape differs, or a string
            ``value`` combined with a non-string, non-object ``dtype``.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    num_samples = len(sequences)

    lengths = []
    sample_shape = ()
    flag = True

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    for x in sequences:
        try:
            lengths.append(len(x))
            if flag and len(x):
                sample_shape = np.asarray(x).shape[1:]
                flag = False
        except TypeError:
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))

    if maxlen is None:
        # default=0 keeps an empty `sequences` from crashing like np.max([]) would.
        maxlen = max(lengths, default=0)

    # np.str_ covers all unicode strings (np.unicode_ was only an alias, removed in NumPy 2.0).
    is_dtype_str = np.issubdtype(dtype, np.str_)
    # Python 3 only: `str` is what six.string_types resolved to (six is not imported here).
    if isinstance(value, str) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            # s[-0:] would be the whole sequence, so maxlen == 0 needs its own case.
            trunc = s[-maxlen:] if maxlen > 0 else s[:0]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


def create_tensor_ds(features: "list[DocFeature]") -> TensorDataset:
    """Pad whole-document features to MAX_LEN and pack them into a TensorDataset.

    Every per-token sequence is right-padded with 0 (the PAD id of the label maps)
    and right-truncated to MAX_LEN. The attention mask is 1 for each real token
    and 0 for padding.

    Returns:
        TensorDataset of (input_ids, labels_type, labels_bio, labels_boundary,
        attention_masks, subword_masks), each a LongTensor, presumably of shape
        [len(features), MAX_LEN] for 1-D per-token lists.
    """
    def _pad_to_tensor(sequences):
        # Right-pad / right-truncate to MAX_LEN with 0, then convert to a LongTensor.
        padded = pad_sequences(sequences, maxlen=MAX_LEN, value=0, padding="post",
                               dtype="long", truncating="post")
        return torch.LongTensor(padded.tolist())

    input_ids = _pad_to_tensor([feat.input_ids for feat in features])
    labels_type = _pad_to_tensor([feat.labels for feat in features])
    labels_bio = _pad_to_tensor([feat.labels_bio for feat in features])
    labels_boundary = _pad_to_tensor([feat.boundary_label for feat in features])
    attention_masks = _pad_to_tensor([[1] * len(feat.input_ids) for feat in features])
    subword_masks = _pad_to_tensor([feat.subword_masks for feat in features])
    return TensorDataset(input_ids, labels_type, labels_bio, labels_boundary, attention_masks, subword_masks)


class SlidingWindowDataset(Dataset):
    """Map-style dataset over pre-padded sliding-window training samples.

    Indexing returns one sample as an 8-tuple, in this order:
    (input_ids, labels_type, labels_bio, labels_boundary, attention_masks,
     subword_masks, cls_pos, sliding_window_pos).
    """

    def __init__(self, input_ids: torch.Tensor, labels_type: torch.Tensor, labels_bio: torch.Tensor,
                 labels_boundary: torch.Tensor, attention_masks: torch.Tensor,
                 subword_masks: torch.Tensor, cls_pos: list, sliding_window_pos: "list[list]") -> None:
        # Model inputs.
        self.input_ids, self.attention_masks, self.subword_masks = input_ids, attention_masks, subword_masks
        # Supervision targets.
        self.labels_type, self.labels_bio, self.labels_boundary = labels_type, labels_bio, labels_boundary
        # Plain-Python bookkeeping (not tensors).
        self.cls_pos, self.sliding_window_pos = cls_pos, sliding_window_pos

    def __len__(self):
        # One sample per row of input_ids.
        return len(self.input_ids)

    def __getitem__(self, idx):
        columns = (self.input_ids, self.labels_type, self.labels_bio, self.labels_boundary,
                   self.attention_masks, self.subword_masks, self.cls_pos, self.sliding_window_pos)
        return tuple(column[idx] for column in columns)


class SlidingWindowDatasetTest(Dataset):
    """Map-style dataset over pre-padded sliding windows without labels (inference).

    Indexing returns a 5-tuple:
    (input_ids, attention_masks, subword_masks, cls_pos, sliding_window_pos).
    """

    def __init__(self, input_ids: torch.Tensor, attention_masks: torch.Tensor, subword_masks: torch.Tensor,
                 cls_pos: list, sliding_window_pos: "list[list]") -> None:
        self.input_ids, self.attention_masks, self.subword_masks = input_ids, attention_masks, subword_masks
        self.cls_pos, self.sliding_window_pos = cls_pos, sliding_window_pos

    def __len__(self):
        # Sample count == number of input_ids rows.
        return len(self.input_ids)

    def __getitem__(self, idx):
        columns = (self.input_ids, self.attention_masks, self.subword_masks,
                   self.cls_pos, self.sliding_window_pos)
        return tuple(column[idx] for column in columns)


def create_tensor_ds_sliding_window(features: "list[DocFeature]") -> "SlidingWindowDataset":
    """Flatten every sliding window of every labeled document into a SlidingWindowDataset.

    Each window becomes one sample. Per-token fields (input ids, the three label
    sequences, attention mask, subword mask) are right-padded with 0 and
    right-truncated to MAX_LEN; the attention mask is 1 for every real token.
    Documents with a single CLS position are skipped.

    Returns:
        SlidingWindowDataset where, for window ``k``, ``sliding_window_pos[k]`` is
        ``[window_span, doc_id]`` (``window_span`` being that window's own
        ``[start, end]`` pair) and ``cls_pos[k]`` is the owning document's CLS positions.
    """
    def _pad_to_tensor(sequences):
        # Right-pad / right-truncate to MAX_LEN with 0 (the PAD id of the label maps).
        padded = pad_sequences(sequences, maxlen=MAX_LEN, value=0, padding="post",
                               dtype="long", truncating="post")
        return torch.LongTensor(padded.tolist())

    input_ids, labels_type, labels_bio, labels_boundary = [], [], [], []
    attention_masks, subword_masks, cls_pos, sliding_window_pos = [], [], [], []
    for feat in features:
        # If the document contains no punctuation at all (a single CLS position)
        # there is no way to window it, so it is dropped entirely.
        if len(feat.cls_pos) == 1:
            continue
        window = feat.sliding_window
        for i in range(window.num_windows):
            input_ids.append(window.input_ids[i])
            labels_type.append(window.labels_type[i])
            labels_bio.append(window.labels_bio[i])
            labels_boundary.append(window.labels_boundary[i])
            attention_masks.append([1] * len(window.input_ids[i]))
            subword_masks.append(window.subword_masks[i])
            cls_pos.append(window.cls_pos)
            # Store only window i's [start, end] pair so each sample knows its own
            # span - the same layout the test-set builder produces.
            sliding_window_pos.append([window.sliding_window[i], feat.doc_id])
    return SlidingWindowDataset(
        _pad_to_tensor(input_ids), _pad_to_tensor(labels_type), _pad_to_tensor(labels_bio),
        _pad_to_tensor(labels_boundary), _pad_to_tensor(attention_masks),
        _pad_to_tensor(subword_masks), cls_pos, sliding_window_pos)


def create_tensor_ds_sliding_window_test(features: "list[DocFeature]") -> "SlidingWindowDatasetTest":
    """Flatten every sliding window of every test document into a SlidingWindowDatasetTest.

    Each window becomes one sample. Input ids, attention mask and subword mask are
    right-padded with 0 and right-truncated to MAX_LEN; the attention mask is 1 for
    every real token. Documents with a single CLS position are skipped.

    Returns:
        SlidingWindowDatasetTest where, for window ``k``, ``sliding_window_pos[k]`` is
        ``[window_span, doc_id]`` and ``cls_pos[k]`` is the owning document's CLS positions.
    """
    def _pad_to_tensor(sequences):
        # Right-pad / right-truncate to MAX_LEN with 0, then convert to a LongTensor.
        padded = pad_sequences(sequences, maxlen=MAX_LEN, value=0, padding="post",
                               dtype="long", truncating="post")
        return torch.LongTensor(padded.tolist())

    input_ids, attention_masks, subword_masks, cls_pos, sliding_window_pos = [], [], [], [], []
    for feat in features:
        # If the document contains no punctuation at all (a single CLS position)
        # there is no way to window it, so it is dropped entirely.
        if len(feat.cls_pos) == 1:
            continue
        window = feat.sliding_window
        for i in range(window.num_windows):
            input_ids.append(window.input_ids[i])
            attention_masks.append([1] * len(window.input_ids[i]))
            subword_masks.append(window.subword_masks[i])
            cls_pos.append(window.cls_pos)
            sliding_window_pos.append([window.sliding_window[i], feat.doc_id])
    return SlidingWindowDatasetTest(_pad_to_tensor(input_ids), _pad_to_tensor(attention_masks),
                                    _pad_to_tensor(subword_masks), cls_pos, sliding_window_pos)


In [None]:
import gc
import itertools
import math
import os
import random
import re

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import classification_report
from torch import nn
from torch.cuda.amp import autocast
from torch.profiler import profile
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoConfig, AdamW

In [None]:
# Pick the tokenizer checkpoint directory for the configured backbone.
_tokenizer_dir = (
    "/kaggle/input/XXXXX/"  # placeholder: LONGBERT checkpoint not uploaded yet
    if LONGBERT
    else "/kaggle/input/robertabase/"
)
TOKENIZER = AutoTokenizer.from_pretrained(_tokenizer_dir)
# Extra special token; the model's embedding table is resized to len(TOKENIZER) later.
TOKENIZER.add_special_tokens({'additional_special_tokens': ['[NP]']})

In [None]:
def bound_to_matrix(bound: torch.Tensor, max_len: Optional[int] = None) -> torch.LongTensor:
    """Turn per-token boundary labels into a start->end span adjacency matrix.

    ``mat[b, i, j] = 1`` when token ``i`` of sequence ``b`` carries the start label
    (1, 'B' in BOUNDARY_LABEL) and ``j`` is the nearest following end label
    (2, 'E'); every other cell is 0. Several starts before the same end all link
    to that end, and a start with no later end inside the window gets no link.
    Backward (end->start) links are never written.

    Args:
        bound: [batch_size, seq_length] tensor of boundary label ids.
        max_len: side length of the output matrix; defaults to MAX_LEN. Labels at
            positions >= max_len are ignored; sequences shorter than max_len are fine.

    Returns:
        LongTensor of shape [batch_size, max_len, max_len], allocated on the CPU.
    """
    if max_len is None:
        max_len = MAX_LEN
    bs = bound.size(0)
    mat = torch.zeros([bs, max_len, max_len], dtype=torch.long)
    # .tolist() once: comparing tensor elements one by one (especially on CUDA) is very slow.
    for b, seq in enumerate(bound.tolist()):
        seq = seq[:max_len]
        # Scan right-to-left remembering the closest end seen so far: O(L) per row.
        next_end = None
        for i in range(len(seq) - 1, -1, -1):
            if seq[i] == 2:
                next_end = i
            elif seq[i] == 1 and next_end is not None:
                mat[b, i, next_end] = 1
    return mat

def _read_txt_dir(directory):
    """Read every file in ``directory``; return (ids, texts), ids being filenames with '.txt' stripped.

    Files are visited in sorted order so the lists (and anything indexed or sampled
    from them later) are stable across runs; ``os.listdir`` order is arbitrary.
    """
    doc_ids, doc_texts = [], []
    for fname in tqdm(sorted(os.listdir(directory))):
        doc_ids.append(fname.replace('.txt', ''))
        # `with` closes each file handle as soon as it has been read.
        with open(os.path.join(directory, fname), 'r', encoding='utf-8') as fh:
            doc_texts.append(fh.read())
    return doc_ids, doc_texts


all_doc_ids, all_doc_texts = _read_txt_dir(TRAIN_PATH)
test_doc_ids, test_doc_texts = _read_txt_dir(TEST_PATH)

all_labels = pd.read_csv(TRAIN_LABEL)

def del_list_idx(l, id_to_del):
    """Return ``l`` cast to int32 with the element(s) at ``id_to_del`` removed.

    ``id_to_del`` is anything ``np.delete`` accepts as an index. The result is a
    Python list whose items are NumPy int32 scalars.
    """
    kept = np.delete(np.array(l, dtype='int32'), id_to_del)
    return [*kept]


# Fixed seed so the train/dev split is reproducible across Restart & Run All.
SPLIT_SEED = 42

scope_len = len(all_doc_ids)
train_len = math.floor((1 - TEST_SIZE - DEV_SIZE) * scope_len)
dev_len = scope_len - train_len  # documents left out of train_index
scope_index = list(range(scope_len))
# Private RNG: the split no longer depends on (or disturbs) the global `random` state.
train_index = random.Random(SPLIT_SEED).sample(scope_index, k=train_len)

train_doc_ids = [all_doc_ids[i] for i in train_index]
train_doc_texts = [all_doc_texts[i] for i in train_index]

In [None]:
print('Dataset Loading...')
_train_ds_path = ('/kaggle/input/forwritingcompetition/longformer_train_ds.pt' if LONGBERT
                  else '/kaggle/input/forwritingcompetition/train_ds.pt')
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted files. Newer
# PyTorch releases default to weights_only=True, which may refuse a pickled Dataset;
# confirm against the installed version.
train_ds = torch.load(_train_ds_path)

In [None]:
print('Create Dataset')
# NOTE: despite the "dev" name these are the test-set essays (train_or_test='test').
# zip pairs ids with texts directly; the old test_doc_ids.index(...) lookup was O(n^2).
dev_features = [DocFeature(doc_id=doc_id, raw_text=text, train_or_test='test', tokenizer=TOKENIZER)
                for doc_id, text in zip(test_doc_ids, test_doc_texts)]

if LONGBERT:
    # NOTE(review): create_tensor_ds reads label fields; confirm test-mode DocFeature provides them.
    dev_ds = create_tensor_ds(dev_features)
else:
    dev_ds = create_tensor_ds_sliding_window_test(dev_features)


In [None]:
# Shuffle training batches; run inference in a fixed order so the submission is reproducible.
train_sp = RandomSampler(train_ds)
dev_sp = SequentialSampler(dev_ds)
def custom_batch_collation(x):
    """Collate by transposing batch rows into per-field lists (no tensor stacking).

    ``x`` is the list of dataset items for one batch. The result holds one list per
    field of ``x[0]``, each containing that field from every item in batch order.
    The training/inference loops call ``torch.stack`` on the fields themselves.
    """
    per_field = [[] for _ in x[0]]
    for sample in x:
        for slot, value in enumerate(sample):
            per_field[slot].append(value)
    return per_field

# Both loaders keep fields as plain lists; stacking happens in the loops below.
_loader_kwargs = dict(collate_fn=custom_batch_collation)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=train_sp, **_loader_kwargs)
dev_dl = DataLoader(dev_ds, batch_size=2, sampler=dev_sp, **_loader_kwargs)

# Backbone config from the same checkpoint directory family as the tokenizer.
_config_dir = "/kaggle/input/XXXXX/" if LONGBERT else "/kaggle/input/robertabase/"
config = AutoConfig.from_pretrained(_config_dir)
config.num_labels = NUM_LABELS

In [None]:
model = TModel(config=config).to('cuda')
# Grow the embedding table to cover special tokens added to TOKENIZER (e.g. '[NP]').
model.transformer.resize_token_embeddings(len(TOKENIZER))

# Class weights (passed as `alpha`) in LABEL_BIO order: PAD 0, each B-tag 100, each I-tag 10, O 5.
bio_cls_weights = torch.Tensor([0] + [100, 10] * 7 + [5]).cuda()
bio_loss = FocalLoss(ignore_index=0, gamma=2, alpha=bio_cls_weights)
# BOUNDARY_LABEL order: PAD, B, E, O.
boundary_alpha = torch.Tensor([0, 10, 10, 1]).cuda()
boundary_loss = FocalLoss(ignore_index=0, gamma=2, alpha=boundary_alpha)
type_loss = FocalLoss(ignore_index=0, gamma=2)
seg_loss = FocalLoss(ignore_index=0, gamma=2)

bert_param_optimizer = list(model.transformer.named_parameters())
ner_fc_param_optimizer = list(model.ner.named_parameters())
# Parameters whose name contains one of these substrings get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']


def _decay_param_groups(named_params, lr, weight_decay=0.01):
    """Split (name, parameter) pairs into [decayed group, non-decayed group] for the optimizer.

    Names matching an entry of ``no_decay`` go to the second group with
    ``weight_decay=0.0``; all other parameters use ``weight_decay``. Both groups use ``lr``.
    """
    decayed = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
    not_decayed = [p for n, p in named_params if any(nd in n for nd in no_decay)]
    return [
        {'params': decayed, 'weight_decay': weight_decay, 'lr': lr},
        {'params': not_decayed, 'weight_decay': 0.0, 'lr': lr},
    ]


# Group order: transformer, ner, then (outside BASELINE) boundary and type_predict.
named_param_sources = [bert_param_optimizer, ner_fc_param_optimizer]
if not BASELINE:
    boundary_param_optimizer = list(model.boundary.named_parameters())
    type_param_optimizer = list(model.type_predict.named_parameters())
    named_param_sources += [boundary_param_optimizer, type_param_optimizer]

optimizer_grouped_parameters = [
    group for named in named_param_sources for group in _decay_param_groups(named, LEARNING_RATE)
]
t_total = int(len(train_dl) / 1 * NUM_EPOCH)
optimizer = AdamW(params=optimizer_grouped_parameters, lr=LEARNING_RATE)


class AverageMeter(object):
    """Track the latest value and the count-weighted running mean of a metric.

    Usage:
        >>> meter = AverageMeter()
        >>> meter.update(batch_loss, n=batch_size)
        >>> meter.avg  # mean over everything seen since the last reset()
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic (val, avg, sum, count)."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighting it by ``n`` in the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
# Echo the configured number of training epochs as this cell's output.
NUM_EPOCH

In [None]:
# float16 autocast needs loss scaling, otherwise small gradients underflow to zero;
# GradScaler is the documented companion of torch.cuda.amp.autocast.
scaler = torch.cuda.amp.GradScaler()
for epoch in range(NUM_EPOCH):
    torch.cuda.empty_cache()
    model.train()
    pbar = tqdm(total=len(train_dl), desc='Train')
    train_loss = AverageMeter()
    for batch in train_dl:
        optimizer.zero_grad()
        if LONGBERT:
            input_ids, labels_type, labels_bio, labels_boundary, attention_masks, subword_masks = batch
        else:
            input_ids, labels_type, labels_bio, labels_boundary, attention_masks, subword_masks, cls_pos, sliding_window_pos = batch

        input_ids = torch.stack(input_ids).cuda()
        labels_type = torch.stack(labels_type).cuda()
        labels_bio = torch.stack(labels_bio).cuda()
        # Keep a CPU copy: bound_to_matrix walks labels element by element.
        labels_boundary_cpu = torch.stack(labels_boundary)
        labels_boundary = labels_boundary_cpu.cuda()
        attention_masks = torch.stack(attention_masks).cuda()
        subword_masks = torch.stack(subword_masks).cuda()
        active_padding_mask = attention_masks.view(-1) == 1

        if not BASELINE:
            boundary_matrix = bound_to_matrix(labels_boundary_cpu).cuda()
            # Outer product of the attention mask: cell (i, j) is valid iff both tokens are real.
            pad_matrix = attention_masks.unsqueeze(2) * attention_masks.unsqueeze(1)
            matrix_padding_mask = pad_matrix.view(-1) == 1

        with autocast():
            if BASELINE:
                ner_logits = model(input_ids=input_ids, attention_mask=attention_masks)
                ner_loss_ = bio_loss(
                    ner_logits.view(-1, len(LABEL_BIO))[active_padding_mask], labels_bio.view(-1)[active_padding_mask])
                loss = ner_loss_
            else:
                ner_logits, boundary_logits, type_logits, seg_logits = model(input_ids=input_ids, attention_mask=attention_masks)
                ner_loss_ = bio_loss(
                    ner_logits.view(-1, len(LABEL_BIO))[active_padding_mask], labels_bio.view(-1)[active_padding_mask])
                boundary_loss_ = boundary_loss(boundary_logits.view(-1, len(BOUNDARY_LABEL))[active_padding_mask], labels_boundary.view(-1)[active_padding_mask])
                type_loss_ = type_loss(type_logits.view(-1, len(LABEL_2_ID))[active_padding_mask], labels_type.view(-1)[active_padding_mask])
                seg_loss_ = seg_loss(seg_logits.view(-1, len(BOUNDARY_LABEL_UNIDIRECTION))[matrix_padding_mask], boundary_matrix.view(-1)[matrix_padding_mask])
                loss = ner_loss_ + boundary_loss_ + type_loss_ + seg_loss_

        # Backward and the optimizer step run outside autocast, per the AMP recipe.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss.update(loss.item(), n=input_ids.size(0))
        pbar.update()
        pbar.set_postfix({'loss': train_loss.avg})
    pbar.close()
    print(train_loss.avg)

In [None]:
# LABEL_BIO "B" ids -> discourse class name (I-tags and O are even/15 and never start a segment).
BIO_LABEL = {1: 'Claim', 3: 'Evidence', 5: 'Position', 7: 'Concluding Statement',
             9: 'Lead', 11: 'Counterclaim', 13: 'Rebuttal'}

_SUBMISSION_COLUMNS = ['id', 'class', 'predictionstring']


def submit_formatting(ner_logits_i, subword_masks_i, text_id):
    """Convert one window's per-token BIO predictions into submission rows.

    A token starts a new segment when its predicted id is an odd B-tag (not 15/'O'),
    its subword-mask value is not -1, and neither the tag nor the position repeats
    the previous segment start. Each segment runs from its start up to (excluding)
    the next start; the last one runs up to ``subword_masks_i.max()`` (excluded).
    If no segment starts, a single 'Evidence' row covering 0..max-1 is emitted.

    NOTE(review): presumably ``subword_masks_i`` holds word indices (-1 for tokens
    to skip); if so, the final word index is never included - confirm whether the
    end bound should be ``max() + 1``.

    Args:
        ner_logits_i: 1-D numpy array of predicted LABEL_BIO ids (argmax), real tokens only.
        subword_masks_i: 1-D numpy array aligned with ``ner_logits_i``.
        text_id: document id written to the ``id`` column.

    Returns:
        DataFrame with columns ``id``, ``class``, ``predictionstring`` (space-joined
        word indices); empty when the inputs are empty.
    """
    if len(subword_masks_i) == 0:
        return pd.DataFrame(columns=_SUBMISSION_COLUMNS)

    end_prediction = subword_masks_i.max()
    positions = []
    labels = []
    prev_e = None
    prev_position = None
    for e, start_position in zip(ner_logits_i, subword_masks_i):
        is_begin_tag = e % 2 == 1 and e != 15
        if is_begin_tag and start_position != -1 and e != prev_e and start_position != prev_position:
            positions.append(start_position)
            labels.append(BIO_LABEL[e])
            prev_e = e
            prev_position = start_position
    positions.append(end_prediction)

    def _word_span(start, stop):
        # Space-joined word indices in [start, stop).
        return " ".join(str(w) for w in range(start, stop))

    if len(positions) == 1:
        # No segment start predicted: fall back to one Evidence segment.
        records = [{'id': text_id, 'class': 'Evidence', 'predictionstring': _word_span(0, positions[-1])}]
    else:
        records = [{'id': text_id, 'class': label, 'predictionstring': _word_span(start, stop)}
                   for label, start, stop in zip(labels, positions, positions[1:])]
    # Build the frame once (DataFrame.append was removed in pandas 2.0).
    return pd.DataFrame(records, columns=_SUBMISSION_COLUMNS)

In [None]:
model.eval()
pbar = tqdm(total=len(dev_dl), desc='Eval')
prediction_frames = []
for batch in dev_dl:
    input_ids, attention_masks, subword_masks, cls_pos, sliding_window_pos = batch
    input_ids = torch.stack(input_ids).cuda()
    attention_masks = torch.stack(attention_masks).cuda()
    subword_masks = torch.stack(subword_masks).cuda()

    with torch.no_grad():
        with autocast():
            outputs = model(input_ids=input_ids, attention_mask=attention_masks)
    # Only the BIO head feeds the submission: the full model returns a 4-item sequence
    # starting with ner_logits, while BASELINE mode returns ner_logits alone.
    ner_logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    ner_preds = ner_logits.argmax(-1)

    # Format every window in the batch, whatever the batch size.
    for row in range(len(sliding_window_pos)):
        text_id = sliding_window_pos[row][1]
        real_tokens = attention_masks[row] == 1
        ner_logits_i = ner_preds[row][real_tokens].cpu().numpy()
        subword_masks_i = subword_masks[row][real_tokens].cpu().numpy()
        prediction_frames.append(submit_formatting(ner_logits_i, subword_masks_i, text_id))
    pbar.update()
pbar.close()

# Concatenate once (DataFrame.append was removed in pandas 2.0); keep the header when empty.
df_all = (pd.concat(prediction_frames, ignore_index=True) if prediction_frames
          else pd.DataFrame(columns=['id', 'class', 'predictionstring']))

In [None]:
# Display the assembled submission rows before they are written to CSV.
df_all

In [None]:
# SUBMISSION (config cell) holds the output path '/kaggle/working/submission.csv'.
df_all.to_csv(SUBMISSION, index=False)