This notebook is a modified version of [the original inference code](https://www.kaggle.com/code/kurokurob/12th-place-solution-original/notebook).  
I fixed the code that removes an inappropriate word from the end of each 'predictionstring'.

In [None]:
import os
import pickle
import gc
gc.enable()

import sys
sys.path.append("../input/pythonbox")

from box import Box
from tqdm import tqdm
import copy

import random
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import (
    Dataset, DataLoader,
)
from joblib import Parallel, delayed
from transformers import AutoConfig, AutoModel, AutoTokenizer

import nltk
from nltk.corpus import stopwords
from string import punctuation

In [None]:
# Seed every random number generator used by this notebook.
def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices).

    cuDNN is put in deterministic mode and its auto-tuner (``benchmark``) is
    switched off: benchmark mode may select different, non-deterministic
    kernels from run to run, which defeats ``deterministic = True``. It would
    also re-tune for every new input shape, and the batches built by
    ``Collate`` are padded to a different length each time.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(2022)

In [None]:
# Token-level BIO labels: a "B-"/"I-" pair for each discourse type, then "O".
# "PAD" (-100) marks positions to ignore; it is not one of the model outputs.
_DISCOURSE_TYPES = (
    "Lead",
    "Position",
    "Evidence",
    "Claim",
    "Concluding Statement",
    "Counterclaim",
    "Rebuttal",
)
target_id_map = {
    f"{prefix}-{name}": 2 * type_no + prefix_no
    for type_no, name in enumerate(_DISCOURSE_TYPES)
    for prefix_no, prefix in enumerate(("B", "I"))
}
target_id_map["O"] = 14
target_id_map["PAD"] = -100

# Reverse lookup: label id -> label string.
id_target_map = {label_id: label for label, label_id in target_id_map.items()}


class _BaseArgs:
    """Settings shared by every model configuration below."""

    input_path = "../input/feedback-prize-2021/"
    output = "."
    max_len = 4096


class args1(_BaseArgs):
    model = "../input/longformerlarge"
    model_weight = '../input/exp001'
    batch_size = 8
    use_folds = [1, 2, 3]


class args2(_BaseArgs):
    model = '../input/funneltransformers/large'
    model_weight = '../input/exp051'
    batch_size = 8
    use_folds = [1, 2, 3]


class args3(_BaseArgs):
    model = "../input/deberta/large"
    model_weight = '../input/exp056'
    batch_size = 2
    use_folds = [1, 2, 3]


class args4(_BaseArgs):
    model = "../input/deberta/large"
    model_weight = '../input/exp063'
    batch_size = 2
    use_folds = [1, 2]


class args5(_BaseArgs):
    model = "../input/deberta-xlarge"
    model_weight = '../input/exp066'
    batch_size = 2
    use_folds = [1, 2, 3, 4, 5]


# Ensemble weight of each model configuration (the weights sum to 1.0).
# The first entry is the ensemble base whose tokenization the others are
# projected onto.
configs = {args1: 0.175, args2: 0.175, args3: 0.15, args4: 0.15, args5: 0.35}

# Number of classes the models output: every label except "PAD".
NUM_LABELS = len(target_id_map) - 1

In [None]:
class FeedbackDataset(Dataset):
    """Map-style dataset over pre-tokenized essays, used for inference.

    Every sample must carry an ``input_ids`` list (tokens without special
    tokens). Items are plain Python lists; padding and tensor conversion are
    left to the collate function.
    """

    def __init__(self, samples, max_len, tokenizer):
        self.samples = samples
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.length = len(samples)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return ``{"ids": [CLS] + tokens + [SEP], "mask": all ones}``.

        The CLS + tokens prefix is cut to ``max_len - 1`` entries so that the
        sequence, including the SEP token, is at most ``max_len`` long.
        """
        tokens = self.samples[idx]["input_ids"]
        ids = ([self.tokenizer.cls_token_id] + tokens)[: self.max_len - 1]
        ids.append(self.tokenizer.sep_token_id)
        return {"ids": ids, "mask": [1] * len(ids)}

In [None]:
class Collate:
    """Pad a batch of dataset items to its longest member and build tensors.

    Padding goes on the side given by ``tokenizer.padding_side``. Ids are
    padded with ``tokenizer.pad_token_id`` and masks with 0.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        ids = [item["ids"] for item in batch]
        masks = [item["mask"] for item in batch]
        width = max(len(seq) for seq in ids)
        pad_id = self.tokenizer.pad_token_id
        pad_right = self.tokenizer.padding_side == "right"

        def pad(seq, value):
            filler = [value] * (width - len(seq))
            return seq + filler if pad_right else filler + seq

        return {
            "ids": torch.tensor([pad(seq, pad_id) for seq in ids], dtype=torch.long),
            "mask": torch.tensor([pad(seq, 0) for seq in masks], dtype=torch.long),
        }

In [None]:
class FeedbackModel(nn.Module):
    """Transformer encoder plus a linear token-classification head.

    The encoder is created from the pretrained *config* only
    (``AutoModel.from_config``); trained weights are loaded afterwards with
    ``load_state_dict``. ``forward`` returns per-token class probabilities
    (softmax over the last dimension), not logits.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        config = AutoConfig.from_pretrained(model_name)
        # NOTE(review): "add_pooling_layer" is normally a model constructor
        # argument; set on the config it may have no effect — confirm.
        config.update({
            "output_hidden_states": True,
            "hidden_dropout_prob": 0.1,
            "layer_norm_eps": 1e-7,
            "add_pooling_layer": False,
        })
        # Keep the attribute names ``transformer`` and ``output``: they are the
        # state-dict key prefixes that load_state_dict matches against the
        # saved fold checkpoints.
        self.transformer = AutoModel.from_config(config)
        self.output = nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, ids, mask):
        hidden = self.transformer(ids, mask).last_hidden_state
        return torch.softmax(self.output(hidden), dim=-1)

In [None]:
def split_sample(sample, max_len, stride = 128):
    """Cut a tokenized sample into windows that fit the model input.

    A window holds at most ``max_len - 2`` tokens, leaving room for CLS and
    SEP. Samples that already fit are returned as ``[sample]`` (the same
    object). Otherwise windows start every ``stride`` tokens, and as many are
    made as are needed to reach the last token. Each window dict carries
    ``id``, ``input_ids``, ``text`` and ``offset_mapping``.

    NOTE(review): with ``stride < max_len - 2`` the windows overlap, and
    ``collate_samples`` would then concatenate duplicated tokens. ``make_pred``
    calls this with ``stride = max_len - 2`` (no overlap).
    """
    # Largest number of real tokens once the special tokens are added.
    window = max_len - 2
    n_tokens = len(sample['input_ids'])

    if n_tokens <= window:
        return [sample]

    # At least two windows are needed in this branch, hence ceil(...) + 1.
    n_windows = math.ceil((n_tokens - window) / stride) + 1
    starts = [k * stride for k in range(n_windows)]
    return [
        {
            'id': sample['id'],
            'input_ids': sample['input_ids'][s: s + window],
            'text': sample['text'],
            'offset_mapping': sample['offset_mapping'][s: s + window],
        }
        for s in starts
    ]


    

In [None]:
def collate_samples(all_test_samples, raw_preds):
    """Attach per-token predictions to each sample and merge split chunks.

    Args:
        all_test_samples: samples in the order they were fed to the
            (unshuffled) dataloader. Each needs ``id``, ``input_ids`` and
            ``offset_mapping``; a ``raw_pred`` entry is added to each in place.
        raw_preds: per-batch prediction arrays. Iterating a batch yields one
            ``(padded_len, n_labels)`` array per sample whose row 0 is the CLS
            position.

    Returns:
        One sample per ``id`` (chunks of the same essay concatenated in
        order), sorted by ``id``. Each ``raw_pred`` has exactly one row per
        entry of ``offset_mapping``.

    Raises:
        ValueError: if the number of predictions differs from the number of
            samples.
    """
    preds_list = [pred for batch in raw_preds for pred in batch]
    if len(preds_list) != len(all_test_samples):
        raise ValueError(
            f"got {len(preds_list)} predictions for {len(all_test_samples)} samples"
        )

    for sample, pred in zip(all_test_samples, preds_list):
        n_tokens = len(sample['offset_mapping'])
        # A sample with n_tokens tokens occupies n_tokens + 2 positions
        # (CLS + tokens + SEP) unless FeedbackDataset truncated it. In that
        # case it is the widest sample of its batch, so only len(pred) - 2
        # rows hold real tokens. Comparing against len(pred) itself (as
        # before) missed truncations by one or two tokens.
        n_available = len(pred) - 2
        if n_available < n_tokens:
            # Keep the predicted rows. Give each dropped token a row that
            # strongly favours the last label ("O", id 14 of the 15 classes)
            # so it decodes as "O".
            pad = np.zeros((n_tokens - n_available, pred.shape[-1]))
            pad[:, -1] = 100.
            raw_pred = np.vstack([pred[1: 1 + n_available], pad])
        else:
            # Skip the CLS row; trailing SEP / padding rows are dropped.
            raw_pred = pred[1: 1 + n_tokens]
        sample["raw_pred"] = raw_pred

    # Concatenate chunks of the same essay (from split_sample) in order.
    merged = {}
    for sample in all_test_samples:
        key = sample['id']
        if key not in merged:
            merged[key] = sample
            continue
        first = merged[key]
        first['input_ids'] += sample['input_ids']
        first['offset_mapping'] += sample['offset_mapping']
        first['raw_pred'] = np.vstack([first['raw_pred'], sample['raw_pred']])

    return sorted(merged.values(), key=lambda s: s['id'])


In [None]:
def _prepare_test_data_helper(args, tokenizer, ids):
    test_samples = []
    for idx in ids:
        filename = os.path.join(args.input_path, "test", idx + ".txt")
        with open(filename, "r") as f:
            text = f.read()

        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        input_ids = encoded_text["input_ids"]
        offset_mapping = encoded_text["offset_mapping"]

        sample = {
            "id": idx,
            "input_ids": input_ids,
            "text": text,
            "offset_mapping": offset_mapping,
        }

        test_samples.append(sample)
    return test_samples


def prepare_test_data(df, tokenizer, args):
    """Tokenize every test essay listed in ``df["id"]``.

    The unique ids are split into 4 chunks and handed to
    ``_prepare_test_data_helper`` in 4 worker processes. The per-chunk results
    are concatenated in chunk order.
    """
    id_chunks = np.array_split(df["id"].unique(), 4)
    jobs = (delayed(_prepare_test_data_helper)(args, tokenizer, chunk) for chunk in id_chunks)
    per_chunk = Parallel(n_jobs=4, backend="multiprocessing")(jobs)
    return [sample for chunk_samples in per_chunk for sample in chunk_samples]

In [None]:
def make_pred(df, args, perfome_split=False):
    """Predict per-token class probabilities for every test essay with one config.

    Every fold in ``args.use_folds`` is run over the whole test set, and the
    per-fold softmax outputs are averaged.

    Args:
        df: DataFrame whose ``id`` column lists the essays to predict.
        args: one of the ``argsN`` config classes.
        perfome_split: when False, essays longer than the model input are
            truncated by ``FeedbackDataset`` (``collate_samples`` pads the
            missing predictions). Any other value cuts them into consecutive,
            non-overlapping windows that ``collate_samples`` stitches back.

    Returns:
        ``(all_test_samples, raw_preds)``: one sample per essay, sorted by id,
        each with a ``raw_pred`` array of per-token probabilities; and the
        per-batch fold-averaged prediction arrays.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    collate = Collate(tokenizer=tokenizer)
    test_samples = prepare_test_data(df, tokenizer, args)

    if perfome_split == False:  # noqa: E712 - any other value selects splitting
        samples = test_samples
    else:
        # Windows of max_len - 2 tokens with stride equal to the window size.
        samples = [
            piece
            for sample in test_samples
            for piece in split_sample(sample, max_len=args.max_len, stride=args.max_len - 2)
        ]
    del test_samples
    gc.collect()

    # Shallow copy; collate_samples later attaches predictions to these dicts.
    all_test_samples = list(samples)
    test_dataloader = DataLoader(
        dataset=FeedbackDataset(samples, args.max_len, tokenizer),
        batch_size=args.batch_size,
        shuffle=False,  # prediction order must match all_test_samples
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        collate_fn=collate,
    )
    del samples
    gc.collect()

    n_folds = len(args.use_folds)
    raw_preds = []
    for fold_no, fold in enumerate(args.use_folds):
        print(f'{fold}fold')

        model = FeedbackModel(model_name=args.model, num_labels=len(target_id_map) - 1)
        model.to('cuda')
        model_weight = f'{args.model_weight}/{fold}fold_best_metrics.ckpt'
        print(model_weight)
        model.load_state_dict(torch.load(model_weight))
        model.eval()

        # Mixed-precision (autocast) inference without autograd.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
            for batch_no, batch in enumerate(test_dataloader):
                probs = model(batch['ids'].to('cuda'), batch['mask'].to('cuda'))
                probs = probs.to('cpu').detach().numpy() / n_folds
                # Build the fold average batch by batch.
                if fold_no == 0:
                    raw_preds.append(probs)
                else:
                    raw_preds[batch_no] += probs

        del model
        gc.collect()

    all_test_samples = collate_samples(all_test_samples, raw_preds)
    del test_dataloader
    gc.collect()
    return all_test_samples, raw_preds


In [None]:
def _char_level_transfer(text, src_offsets, src_probs, dst_offsets, n_dst):
    """Project token probabilities from one tokenization onto another via characters.

    Each source token's probability row is added to every character it
    covers. Each of the ``n_dst`` destination tokens then takes the mean over
    its own characters.
    NOTE(review): a destination token with an empty (start == end) offset
    averages an empty slice, which yields NaN and a RuntimeWarning.
    """
    char_probs = np.zeros((len(text), NUM_LABELS), np.float32)
    for t, (start, end) in enumerate(src_offsets):
        char_probs[start:end] += src_probs[t]

    token_probs = np.zeros((n_dst, NUM_LABELS), np.float32)
    for t, (start, end) in enumerate(dst_offsets):
        token_probs[t] = char_probs[start:end].mean(0)
    return token_probs


df = pd.read_csv(os.path.join("../input/feedback-prize-2021/", "sample_submission.csv"))

for i, (args, weight) in enumerate(configs.items()):
    all_test_samples, raw_preds = make_pred(df, args, perfome_split=False)

    if i == 0:
        # The first model is the base: its tokens are the grid every model's
        # predictions are accumulated on, starting from zero.
        ens_samples = copy.deepcopy(all_test_samples)
        for ens in ens_samples:
            ens['raw_pred'] = np.zeros((len(ens['input_ids']), NUM_LABELS))

    # Character-level ensembling: route this model's token probabilities
    # through the characters of the text onto the base tokens.
    for ens, sample in tqdm(zip(ens_samples, all_test_samples), total=len(all_test_samples)):
        ens['raw_pred'] += weight * _char_level_transfer(
            ens['text'],
            sample['offset_mapping'],
            sample['raw_pred'],
            ens['offset_mapping'],
            len(ens['input_ids']),
        )

    del all_test_samples, raw_preds
    gc.collect()


In [None]:
# Decode the ensembled probabilities: the arg-max label and its score for
# every base token.
for sample in ens_samples:
    probs = sample['raw_pred']
    sample["preds"] = [id_target_map[label_id] for label_id in np.argmax(probs, axis=1)]
    sample["pred_scores"] = np.max(probs, axis=1)

In [None]:
# remove_words_ids (below) is compared against the token ids of ens_samples,
# which use the tokenization of the first model in `configs` (the ensemble
# base). Load that model's tokenizer instead of repeating its path, so the two
# cannot drift apart.
tokenizer = AutoTokenizer.from_pretrained(next(iter(configs)).model)

# Words whose first sub-token id, when it is the last token of a predicted
# span, makes the post-processing loop relabel that token as "O".
# Order and duplicates are kept exactly as in the original list.
remove_words = [
    "I", "You", "She", "she", "He", "he", "the", "The", "This", "That",
    "however", "However", "Although", "although", "Because", "Some", "Not", "There", "To", "(",
    "[", "A", "In", "On", "At", "With", "Electoral", "If", "It", "They",
    "Venus", "When", "You", "College", "Students", "So", "A", "We", "Many", "As",
    "With", "People", "These", "That", "Another", "And", "Also", "Action", "Facial", "President",
    "After", "Online", "Teachers", "Why", "Being", "Cars", "Asking", "Just", "According", "Its",
    "System", "Since", "Sometimes", "Cowboys", "Summer", "Do", "No", "How", "Do", "No",
    "How", "Paris", "Have", "Well", "Which", "Like", "Without", "American", "Mars.", "Earth.",
    "America", "School", "She", "Everyone", "Maybe", "Therefore,", "Distance", "Mona", "An", "That's",
    "So", "Those", "Imagine", "Or", "Venus,", "College.", "2000", "Seeking", "Using", "Is",
    "Your", "Getting", "Google", "Every", "Viking", "Cell", "Cowboy", "First", "Schools", "Our",
    "From", "Making", "Lastly", "During", "Each", "Exploring", "Finally", "Other", "Source", "Instead",
    "Their", "Technology", "Policy", "Kids", "Driving", "Are", "But,", "Texting", "Furthermore,", "Community",
    "Once", "Would", "Car", "More", "Overall,", "Challenge", "Thats", "College,", "Taking", "Who",
    "New", "Of", "Yes", "Though", "Extracurricular", "Limiting", "War", "Learning", "Global", "Think",
    "First", "Drivers", "Doing", "House", "Whether", "Due", "-", "Over", "Well", "Therefore",
    "Student", "Secondly", "Although", "Congress", "Say", "Advice", "Phones", "Others", "Going", "Yes",
    "Indefensible", "Life", "Surveyor", "Scientists", "Should", "Bush", "Democratic", "Voters", "Giving", "National",
    "Allowing", "Paragraph", "Talking", "Middle", "Knowing", "Sports", 'Venus"', "Last", "Snake", "Attending",
    "Additionally,", "River", "Despite", "Someone", "Studies", "Humans", "Such", "Plain", "Vice", "Safety",
    "Plus", "English", "Second", "Overall", "Can", "Working", "Only", "Monday", "Did", "Under",
    "Here", "Multiple", "Before", "Five", "Popular", "Through", "Defense", "Almost", "Parents", "Both",
    "Especially", "Will", "Driveless", "Different", "Goes", "Helping", "Along", "Let", "Today", "Everybody",
    "Next", "Throughout", "Second", "His", "Today", "High", "Red", "Sometimes", "Take", "Office",
    "Lastly", "Scientist", "Something", "Teacher", "Nobody", "Moreover", "Representatives", "Traffic", "Please", "Education",
    "Day", "Teens", "Either", "Electors", "Less", "Hearing", "Island", "Any", "Public", "Things",
    "Often", "District", "Thus", "Keeping", "Teenagers", "Projects", "Citizens", "Yet", "Based", "Changing",
    "Perhaps", "Studying", "Everyday", "Besides", "Her", "Back", "State", "Children", "Make", "Where",
    "Science", "Human", "Seeing", "Unfortunately,", "End", "Camera", "Firstly,", "Whenever", "See", "Looking",
    "Swing", "Instead,", "Distracted", "Does", "Colleges", "Thank", "Transportation", "Representatives,", "Letting", "Smog",
    "Home", "Personally", "Wyoming", "Video", "Depending", "Meaning", "Presidential", "Finding", "Theres", "Us",
    "System", "Social", "Smart", "Thousands", 'Coming"', "System.", "Dear", "Personally", "Lots", "Space",
    "Two", "Later", "Nations", 'Smile"', "Relief", "Pacific", "Candidates", "Always", "Reasons", "Computers",
    "Everything", "Participating", "Joining", "Forcing", "About", "Ever", "Duffer", "Whatever", "Could", "Time",
    "Third", "Adding", "Bill", "Usually", "Anyone", "Me", "Nothing", "Stress", "Remember", "Be",
    "Voting", "Unlike", "Majority", "Good", "Ultimately", "Ask", "Agency", "Consider", "Putting", "Opinions",
    "Protection", "Hopefully", "Bullying", "Classes", "Out", "Pollution", "Listening", "Activities", "Face,", "May",
    "Anything", "Reading", "Recently,", "Companies", "Considering", "Receiving", "Rather", "Walking", "Given", "Certain",
    "Playing", "Clearly", "Help", "Unless", "Sure,", "Congestion", "Staying", "Statistics", "Trying", "Wyoming,",
    "Until", "Service", "Keep", "Within", "Free", "Evening", "Three", "Author", "Theory", "Nearly",
    "Avoid", "Trust", "Mostly", "Coming", "Advanced", "Never", "Honestly", "Millions", "Constitution.", "Conclusion",
    "Last", "Reason", "Emotions", "Accidents", "Feedback", "Exploration", "Creating", "Soon", "Too", "Use",
    "Maine", "Look", "Program", "Congressional", "West", "Constitution", "Elections", "Also,", "People", "Well",
    "Driverless", "All", "Even", "Program", "One", "Having", "Congress", "Which", "My", "By",
    "Most", "On", "Why", "Planet", "Outcome", "Driving", "Like", "Online", "Do", "Phones",
    "Distance", "Asking", "Cell", "Just", "Service", "Voting", "Ocean", "Since", "Learning", "Etc",
    "Everyone", "Limiting", "Every", "Bush", "Getting", "House", "Without", "Times", "Due", "Are",
    "Multiple", "Looking", "Seagoing", "Giving", "Taking", 'Coming"', "Problems", "Imagine", "Exploring", "Bullying",
    "Texting", "Avoid", "Studying", "Knowing", "Perhaps", "Idea", "Road", "Finding", "Constisution", "Viking",
    "Studies", "Sounds", "Nation", "Making", "Process", "Let's", "Respect", "Statistics", "Day", "Kids",
    "Would", "Education", "Help", "Ravines", "Tragic", "Fun", "Works", "Band", "Bad", "Money",
    "Surveyor", 'Illions."', "Quizizz", "Minimizing", "Humanity", "Nearly", "Hardly", "--", "Market", "Orbiters",
    "Trouble", "Healthy", "Unfair,", "Fast", "Threes", "Passing", "Baseball", "Unfortunatley,", "Smile", "Life",
    "Polish", "Clothes,", "Mean",
]


# Intent: remove the capitalized forms of these words (as produced by str.capitalize()).

def _first_token_id(word):
    """Return the id of the first sub-token of *word* (no special tokens added).

    NOTE(review): only the first sub-token is kept, so a word that splits into
    several pieces matches any span that ends on its first piece. The word is
    also tokenized without a leading space — confirm this is the intended
    variant for the tokenizer in use.
    """
    return tokenizer(word, add_special_tokens=False)["input_ids"][0]


# One id per entry of remove_words (duplicate words give duplicate ids).
remove_words_ids = list(map(_first_token_id, remove_words))

In [None]:
def _trim_trailing_remove_words(preds, pred_scores, input_ids, remove_ids):
    """Relabel a span's last token as "O" when its token id is in *remove_ids*.

    A span is a non-"O" label followed by the run of ``I-<same type>`` labels
    after it. When the span's final token id is in ``remove_ids``, that token
    becomes "O" with score 1.0. ``preds`` and ``pred_scores`` are modified in
    place. Every span is checked, including one that runs to the final token
    of the essay; the previous loop silently skipped that case.
    """
    n_tokens = len(preds)
    idx = 0
    while idx < n_tokens:
        if preds[idx] == "O":
            idx += 1
            continue
        continuation = f"I-{preds[idx][2:]}"
        idx += 1
        while idx < n_tokens and preds[idx] == continuation:
            idx += 1
        # idx now points just past the span, so idx - 1 is its last token.
        if input_ids[idx - 1] in remove_ids:
            preds[idx - 1] = "O"
            pred_scores[idx - 1] = 1.


# Set lookup: the ids are tested once per predicted span.
_remove_words_id_set = set(remove_words_ids)
for sample in tqdm(ens_samples):
    _trim_trailing_remove_words(
        sample['preds'], sample['pred_scores'], sample['input_ids'], _remove_words_id_set
    )

In [None]:
def link_evidence(oof, thresh=1, thresh2=30, *, progress=True):
    """Merge nearby "Evidence" predictions of the same essay into longer spans.

    The "Evidence" rows of each essay are visited in row order and chained
    greedily. The next prediction is appended to the span being built unless
    its first word index is more than ``thresh`` past the span's last index,
    or more than ``thresh2`` past the span's first index. Rows of every other
    class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Evidence" rows outer-merged with the other rows. Columns other than
        ``id``/``class``/``predictionstring`` are NaN on merged rows.
        Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Evidence"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_evidence', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def link_lead(oof, thresh=200, thresh2=1000, *, progress=True):
    """Merge nearby "Lead" predictions of the same essay into longer spans.

    The "Lead" rows of each essay are visited in row order and chained
    greedily. The next prediction is appended to the span being built unless
    its first word index is more than ``thresh`` past the span's last index,
    or more than ``thresh2`` past the span's first index. Rows of every other
    class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Lead" rows outer-merged with the other rows. Columns other than
        ``id``/``class``/``predictionstring`` are NaN on merged rows.
        Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Lead"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_lead', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def link_conclude(oof, thresh=100, thresh2=500, *, progress=True):
    """Merge nearby "Concluding Statement" predictions of the same essay.

    The "Concluding Statement" rows of each essay are visited in row order and
    chained greedily. The next prediction is appended to the span being built
    unless its first word index is more than ``thresh`` past the span's last
    index, or more than ``thresh2`` past the span's first index. Rows of every
    other class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Concluding Statement" rows outer-merged with the other rows. Columns
        other than ``id``/``class``/``predictionstring`` are NaN on merged
        rows. Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Concluding Statement"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_conclude', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def link_position(oof, thresh=15, thresh2=200, *, progress=True):
    """Merge nearby "Position" predictions of the same essay into longer spans.

    The "Position" rows of each essay are visited in row order and chained
    greedily. The next prediction is appended to the span being built unless
    its first word index is more than ``thresh`` past the span's last index,
    or more than ``thresh2`` past the span's first index. Rows of every other
    class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Position" rows outer-merged with the other rows. Columns other than
        ``id``/``class``/``predictionstring`` are NaN on merged rows.
        Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Position"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_position', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def link_counterclaim(oof, thresh=8, thresh2=50, *, progress=True):
    """Merge nearby "Counterclaim" predictions of the same essay into longer spans.

    The "Counterclaim" rows of each essay are visited in row order and
    chained greedily. The next prediction is appended to the span being built
    unless its first word index is more than ``thresh`` past the span's last
    index, or more than ``thresh2`` past the span's first index. Rows of every
    other class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Counterclaim" rows outer-merged with the other rows. Columns other
        than ``id``/``class``/``predictionstring`` are NaN on merged rows.
        Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Counterclaim"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_counterclaim', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def link_rebuttal(oof, thresh=1, thresh2=40, *, progress=True):
    """Merge nearby "Rebuttal" predictions of the same essay into longer spans.

    The "Rebuttal" rows of each essay are visited in row order and chained
    greedily. The next prediction is appended to the span being built unless
    its first word index is more than ``thresh`` past the span's last index,
    or more than ``thresh2`` past the span's first index. Rows of every other
    class pass through untouched.

    Args:
        oof: DataFrame with ``id``, ``class`` and ``predictionstring``
            (space-separated word indices) columns.
        thresh: largest allowed gap between consecutive predictions.
        thresh2: largest allowed distance from a span's first index to the
            first index of a prediction appended to it.
        progress: show a tqdm progress bar over essays.

    Returns:
        ``oof`` itself when it is empty. Otherwise, a new DataFrame: the merged
        "Rebuttal" rows outer-merged with the other rows. Columns other than
        ``id``/``class``/``predictionstring`` are NaN on merged rows.
        Predictions with an empty ``predictionstring`` are dropped.
    """
    if not len(oof):
        return oof

    target = "Rebuttal"
    eoof = oof[oof['class'] == target]
    neoof = oof[oof['class'] != target]
    # Target-class rows grouped per essay, keeping their row order.
    rows_by_id = {idv: group for idv, group in eoof.groupby('id', sort=False)}

    ids = oof['id'].unique()
    if progress:
        ids = tqdm(ids, desc='link_rebuttal', leave=False)

    retval = []
    for idv in ids:
        group = rows_by_id.get(idv)
        if group is None:
            continue
        span = None  # word indices of the span being built
        for predictionstring in group['predictionstring']:
            words = [int(x) for x in predictionstring.split()]
            if not words:
                continue  # no index to compare against; nothing to link
            if span is None:
                span = words
            elif words[0] > span[-1] + thresh or words[0] - span[0] > thresh2:
                retval.append((idv, target, " ".join(map(str, span))))
                span = words
            else:
                span.extend(words)
        if span is not None:
            retval.append((idv, target, " ".join(map(str, span))))

    roof = pd.DataFrame(retval, columns=['id', 'class', 'predictionstring'])
    return roof.merge(neoof, how='outer')

In [None]:
def add_token_v4(row, discourse_type, thresh_list: list):
    """Extend a prediction of ``discourse_type`` according to its length.

    ``thresh_list`` holds buckets ``{'start', 'end', 'add_num'}``. When
    ``row['class'] == discourse_type`` and the number of words in
    ``row['predictionstring']`` lies in ``[start, end]`` (inclusive) of a
    bucket, the prediction is rewritten as every word index from its first
    index through its last index plus ``add_num``.

    Args:
        row: mapping (e.g. a DataFrame row) with ``class`` and
            ``predictionstring`` entries.
        discourse_type: class whose predictions are adjusted.
        thresh_list: length buckets; they must not overlap.

    Returns:
        The (possibly extended) predictionstring.

    Raises:
        ValueError: if two buckets cover a common length.
    """
    # Sort buckets in ascending order of their lower bound.
    thresh_list = sorted(thresh_list, key=lambda x: x['start'])

    # Validate that no two [start, end] buckets overlap. After sorting, an
    # overlap exists iff a bucket starts at or before the largest end seen so
    # far (buckets with start > end cover no length and are skipped).
    max_end = float('-inf')
    for d in thresh_list:
        if d['start'] > d['end']:
            continue
        if d['start'] <= max_end:
            raise ValueError("spans are overlapping")
        max_end = max(max_end, d['end'])

    predictionstring = row['predictionstring']
    if row['class'] != discourse_type:
        return predictionstring
    tokens = predictionstring.split()
    if not tokens:
        # Nothing to extend and no first/last word index to read.
        return predictionstring
    pred_len = len(tokens)
    for d in thresh_list:
        if d['start'] <= pred_len <= d['end']:
            min_idx = int(tokens[0])
            max_idx = int(tokens[-1])
            return ' '.join(str(i) for i in range(min_idx, max_idx + d['add_num'] + 1))
    # The length falls in no bucket: return the prediction unchanged.
    return predictionstring

In [None]:
# Length buckets consumed by add_token_v4. A prediction of the matching class
# whose word count lies in [start, end] (inclusive) is rewritten to cover every
# word index from its first word through its last word + add_num. Buckets in
# one list must not overlap (add_token_v4 raises otherwise).
claim_thresh_list = [
    dict(start=1, end=5, add_num=1),
    dict(start=6, end=10, add_num=2),
    dict(start=11, end=20, add_num=4),
]

lead_thresh_list = [
    dict(start=7, end=13, add_num=6),
    dict(start=14, end=19, add_num=12),
    dict(start=20, end=30, add_num=14),
]

position_thresh_list = [
    dict(start=5, end=15, add_num=3),
    dict(start=16, end=20, add_num=2),
]

rebuttal_thresh_list = [
    dict(start=2, end=4, add_num=1),
    dict(start=5, end=13, add_num=5),
    dict(start=14, end=21, add_num=7),
    dict(start=22, end=27, add_num=8),
]

counterclaim_thresh_list = [
    dict(start=5, end=24, add_num=4),
    dict(start=25, end=37, add_num=5),
]

# Concluding Statement buckets (currently disabled):
# conclude_thresh_list = [
#     dict(start=9, end=15, add_num=2),
# ]

evidence_thresh_list = [
    dict(start=17, end=20, add_num=11),
    dict(start=21, end=23, add_num=14),
    dict(start=24, end=29, add_num=17),
    dict(start=30, end=36, add_num=20),
]

In [None]:
# Per-class filters for decoded phrases:
#   proba_thresh: minimum mean token score for a phrase to be kept;
#   min_thresh / max_thresh: inclusive bounds on its length in words.
proba_thresh = {
    "Lead": 0.604,
    "Position": 0.55,
    "Evidence": 0.6,
    "Claim": 0.525,
    "Concluding Statement": 0.62,
    "Counterclaim": 0.52,
    "Rebuttal": 0.535,
}

min_thresh = {
    "Lead": 7,
    "Position": 5,
    "Evidence": 17,
    "Claim": 1,
    "Concluding Statement": 9,
    "Counterclaim": 5,
    "Rebuttal": 4,
}

max_thresh = {
    "Lead": 300,
    "Position": 100,
    "Evidence": 1000,
    "Claim": 200,
    "Concluding Statement": 500,
    "Counterclaim": 1000,
    "Rebuttal": 1000,
}


def _decode_phrases(preds, offset_mapping, pred_scores):
    """Split a tagged token sequence into runs of one label.

    A run starts at any token and absorbs the following tokens while they are
    tagged ``I-<label>`` (or ``O`` for an ``O`` run), so a stray ``I-`` tag
    also starts a new run.

    Returns a list of ``(start_char, end_char, label, token_scores)`` tuples;
    the character span runs from the first token's start offset to the last
    token's end offset.
    """
    phrases = []
    n_tokens = len(offset_mapping)
    idx = 0
    while idx < n_tokens:
        # Read both offsets of the first token so a single-token run ends at
        # its own token instead of at a previously seen ``end``.
        start, end = offset_mapping[idx]
        label = "O" if preds[idx] == "O" else preds[idx][2:]
        matching_label = "O" if label == "O" else f"I-{label}"
        scores = [pred_scores[idx]]
        idx += 1
        while idx < n_tokens and preds[idx] == matching_label:
            _, end = offset_mapping[idx]
            scores.append(pred_scores[idx])
            idx += 1
        phrases.append((start, end, label, scores))
    return phrases


def _phrases_to_rows(sample_id, sample_text, phrases):
    """Convert labelled character spans into ``(id, class, predictionstring)`` rows.

    A phrase is kept when its label is not ``O``, its mean token score is at
    least ``proba_thresh[label]``, and its word count lies within
    ``[min_thresh[label], max_thresh[label]]``. Word indices are positions in
    ``sample_text.split()``.
    """
    n_words = len(sample_text.split())  # hoisted: constant for the whole text
    rows = []
    for start, end, label, scores in phrases:
        if label == "O":
            continue
        if sum(scores) / len(scores) >= proba_thresh[label]:
            word_start = len(sample_text[:start].split())
            word_end = min(word_start + len(sample_text[start:end].split()), n_words)
            word_ids = range(word_start, word_end)
            if min_thresh[label] <= len(word_ids) <= max_thresh[label]:
                rows.append((sample_id, label, " ".join(str(x) for x in word_ids)))
    return rows


submission_rows = []
for sample in ens_samples:
    phrases = _decode_phrases(sample["preds"], sample["offset_mapping"], sample["pred_scores"])
    submission_rows.extend(_phrases_to_rows(sample["id"], sample["text"], phrases))
# Building one frame (rather than concatenating per-sample frames) also works
# when there are no samples at all.
submission = pd.DataFrame(submission_rows, columns=["id", "class", "predictionstring"])

# Merge fragmented predictions class by class (link_claim is not applied).
for link in (link_evidence, link_lead, link_conclude, link_position,
             link_counterclaim, link_rebuttal):
    submission = link(submission)

# Extend predictions by class-specific numbers of words; the Concluding
# Statement extension is not applied.
token_extensions = [
    ("Evidence", evidence_thresh_list),
    ("Claim", claim_thresh_list),
    ("Lead", lead_thresh_list),
    ("Position", position_thresh_list),
    ("Counterclaim", counterclaim_thresh_list),
    ("Rebuttal", rebuttal_thresh_list),
]
# apply(axis=1) on an empty frame does not return a Series, so skip it then.
if len(submission):
    for discourse_type, thresh_list in token_extensions:
        submission['predictionstring'] = submission.apply(
            add_token_v4, axis=1, args=(discourse_type, thresh_list)
        )

submission.to_csv("submission.csv", index=False)

In [None]:
# Preview the first rows of the final submission frame.
submission.head(n=5)