In [1]:
import os
import gc
import sys
import json
import time
import torch
import joblib
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
import plotly.express as px
import matplotlib.pyplot as plt

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
pd.set_option('display.float_format', lambda x: '%.4f' % x)

# Params

In [2]:
data_path = Path(r"/database/kaggle/Shovel Ready/data")
os.listdir(data_path)

['persuade_corpus.csv']

In [3]:
df = pd.read_csv(data_path/'persuade_corpus.csv')
df.shape

  df = pd.read_csv(data_path/'persuade_corpus.csv')


(285391, 30)

In [4]:
df.sample(5)

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2
172067,2171003706,7AE9C2E0846E,train,We should only have to make a c in or classes ...,3,1616793752264.0,0,77,We should only have to make a c in or classes ...,Position,Position 1,Adequate,,,,NCES,Independent,,Grades for extracurricular activities,Your principal is considering changing school ...,M,8.0,No,White,Not economically disadvantaged,Not identified as having disability,255,0,,
104882,AAAVUP14319000151770,69CD60305927,test,sciencetists say its worth to dicover even mor...,2,1617385529242.0,242,249,becuase,Unannotated,Unannotated 2,,,,,Indiana,Text dependent,"""The Challenge of Exploring Venus""",Exploring Venus,"In ""The Challenge of Exploring Venus,"" the aut...",M,10.0,Yes,Hispanic/Latino,Economically disadvantaged,Identified as having disability,225,1,,
31014,5076976,7EC6A6C5BF49,test,Many people use cars to get around every day. ...,3,1623109489459.0,457,460,and,Unannotated,Unannotated 2,,,,,Florida,Text dependent,"Source 1: ""In German Suburb, Life Goes On With...",Car-free cities,Write an explanatory essay to inform fellow ci...,M,10.0,No,Hispanic/Latino,,,390,0,,
110712,AAAVUP14319000027265,4DCE9956E3DE,train,The author states that going to Venus would be...,5,1617212012423.0,2014,2158,"Furthermore, the rest of the article including...",Claim,Claim 4,Adequate,1617211931343.0,The author states that going to Venus would be...,Position,Indiana,Text dependent,"""The Challenge of Exploring Venus""",Exploring Venus,"In ""The Challenge of Exploring Venus,"" the aut...",M,10.0,No,White,Not economically disadvantaged,Not identified as having disability,611,0,,
186750,AAAOPP13416000016509,5560F74D34AC,train,Cars. A large leap for man into technology. Bu...,3,1621006292387.0,1972,2651,They aren't totally driverless which means som...,Concluding Statement,Concluding Statement 1,Effective,,,,Indiana,Text dependent,"""Driverless Cars are Coming""",Driverless cars,"In the article “Driverless Cars are Coming,” t...",M,10.0,No,White,Not economically disadvantaged,Not identified as having disability,491,0,,


In [5]:
df['discourse_effectiveness'] = df['discourse_effectiveness'].fillna('NoEffectiveness')

In [6]:
# pip install spacy

In [7]:
import spacy
from spacy import displacy
from pylab import cm, matplotlib
import os

colors = {
            'Lead': '#8000ff',
            'Position': '#2b7ff6',
            'Evidence': '#2adddd',
            'Claim': '#80ffb4',
            'Concluding Statement': 'd4dd80',
            'Counterclaim': '#ff8042',
            'Rebuttal': '#ff0000'
         }

colors_effectiveness = {
            'Adequate': '#8000ff',
            'Effective': '#2b7ff6',
            'Ineffective': '#2adddd',
         }

def visualize(idx,train):
    
#     print(train[train['essay_id_comp'] == idx].assignment.values[0])
    ents = []
    for i, row in train[train['essay_id_comp'] == idx].iterrows():
        ents.append({
                        'start': int(row['discourse_start']), 
                         'end': int(row['discourse_end']), 
                         'label': str(row['discourse_type']) #+ ' - ' + str(row['discourse_effectiveness'])
                    })

    data = train[train['essay_id_comp'] == idx].full_text.values[0]

    doc2 = {
        "text": data,
        "ents": ents,
        "title": idx
    }

    options = {"ents": train.discourse_type.unique().tolist(), "colors": colors}
    displacy.render(doc2, style="ent", options=options, manual=True, jupyter=True)
    
def visualize_effectiveness(idx,train):
    
#     print(train[train['essay_id_comp'] == idx].assignment.values[0])
    ents = []
    for i, row in train[train['essay_id_comp'] == idx].iterrows():
        ents.append({
                        'start': int(row['discourse_start']), 
                         'end': int(row['discourse_end']), 
                         'label': str(row['discourse_effectiveness']) #+ ' - ' + str(row['discourse_effectiveness'])
                    })

    data = train[train['essay_id_comp'] == idx].full_text.str.strip().values[0]

    doc2 = {
        "text": data,
        "ents": ents,
        "title": idx
    }

    options = {"ents": train.discourse_effectiveness.unique().tolist(), "colors": colors_effectiveness}
    displacy.render(doc2, style="ent", options=options, manual=True, jupyter=True)

In [8]:
idx = random.choice(df.essay_id_comp.unique())
idx

'CC3B51667B02'

In [165]:
idx = '000A58BC095E'

In [166]:
visualize(idx,df)

In [11]:
visualize_effectiveness(idx,df)

In [167]:
df[df['essay_id_comp'] == "000A58BC095E"]

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2
147709,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891041401.0,0,24,"To Whom It May Concern,\n\n",Unannotated,Unannotated 1,,,,,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,,
147710,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891041400.0,25,255,"Community Service, It helps other with less th...",Lead,Lead 1,Adequate,,,,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147711,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891032621.0,256,376,I believe that we should have community servic...,Position,Position 1,Adequate,,,,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147712,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614897263598.0,377,458,Personally i enjoys helping the community and ...,Evidence,Evidence 1,Adequate,1614897345160.0,But some people don't have the time to.,Claim,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147713,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614897345160.0,459,498,But some people don't have the time to.,Claim,Claim 1,Adequate,1614891032621.0,I believe that we should have community servic...,Position,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147714,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614897353386.0,499,533,If it isn't sports or school work.,Evidence,Evidence 2,Ineffective,1614897345160.0,But some people don't have the time to.,Claim,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147715,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891077726.0,534,613,Community service is one of the things that lo...,Claim,Claim 2,Adequate,1614891032621.0,I believe that we should have community servic...,Position,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147716,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891086207.0,614,724,And also it may lead to good things like a bet...,Claim,Claim 3,Adequate,1614891032621.0,I believe that we should have community servic...,Position,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147717,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891162215.0,725,878,Community service doesn't only have to be in y...,Evidence,Evidence 3,Adequate,,,,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public
147718,2091005889,000A58BC095E,test,"To Whom It May Concern,\n\nCommunity Service, ...",2,1614891136008.0,879,1295,To help the community should be a choice but i...,Concluding Statement,Concluding Statement 1,Effective,,,,NCES,Independent,,Community service,Some of your friends perform community service...,M,8.0,No,Black/African American,Not economically disadvantaged,Not identified as having disability,249,1,Public,Public


In [122]:
for i, row in df[df['essay_id'] == idx].iterrows():
    print(row['discourse_text'])
    print('---------------------------\n')

Dear Principal,


---------------------------

I think you should use Policy 1 and allow students to bring their phones to school and use them during lunch periods and other free times, as long as the phones are turned off during class time. 
---------------------------

It would give students a way to communicate to each other, but also allow them to focus on their work during class. 
---------------------------

If any student is caught using their phones during class, 
---------------------------

---------------------------

It also gives students a way to communicate with their parents if anything goes wrong at school. 
---------------------------

For example if a student forgot something at home and they need their parent to drop it off at their school, they could use their cell phones to contact them. 
---------------------------

Or, if a student isn't feeling well and they need to go home, they can contact their parent to bring them home, with permission from the health teach

# Dataset

In [13]:
import torch
import random
import numpy as np
import pandas as pd
from pathlib import Path

from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModel, AutoConfig

%env TOKENIZERS_PARALLELISM = true


import warnings
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 500)
warnings.simplefilter("ignore")

env: TOKENIZERS_PARALLELISM=true


In [14]:
import re
from difflib import SequenceMatcher

import codecs
import os
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from text_unidecode import unidecode
from tqdm.notebook import tqdm

def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
    return error.object[error.start : error.end].encode("utf-8"), error.end


def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
    return error.object[error.start : error.end].decode("cp1252"), error.end


# Register the encoding and decoding error handlers for `utf-8` and `cp1252`.
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)



def resolve_encodings_and_normalize(text: str) -> str:
    """Resolve the encoding problems and normalize the abnormal characters."""
    text = (
        text.encode("raw_unicode_escape")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
        .encode("cp1252", errors="replace_encoding_with_utf8")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
    )
    text = unidecode(text)
    return text


def clean_text(text):
    text = resolve_encodings_and_normalize(text)
    text = text.replace(u'\xa0', u' ')
    text = text.replace(u'\x85', u'\n')
    text = text.strip()
    return text

In [15]:
def get_text_start_end(txt,s,search_from=0):
    txt = txt[int(search_from):]
    try:
        idx = txt.find(s)
        if idx>=0:
            st=idx
            ed = st+len(s)
        else:
            raise ValueError('Error')
    except:                
        res = [(m.start(0), m.end(0)) for m in re.finditer(s, txt)]
        if len(res):
            st,ed = res[0][0],res[0][1]
        else:
            m = SequenceMatcher(None, s,txt).get_opcodes()
            for tag,i1,i2,j1,j2 in m:
                if tag=='replace':
                    s = s[:i1]+txt[j1:j2]+s[i2:]
                if tag=="delete":
                    s = s[:i1]+s[i2:]
            
            res = [(m.start(0), m.end(0)) for m in re.finditer(s,txt)]
            if len(res):
                st,ed = res[0][0],res[0][1]
            else:
                idx = txt.find(s)
                if idx>=0:
                    st=idx
                    ed = st+len(s)
                else:
                    st,ed = 0,0
    return st+search_from,ed+search_from

def get_start_end(col):
    def search_start_end(row):
        txt = row.full_text
        search_from = row.previous_discourse_end
        s = row[col]
        # print(search_from)
        return get_text_start_end(txt,s,search_from)
    return search_start_end

In [16]:
def add_text_to_df(test_df):
    test_df['discourse_text'] = test_df['discourse_text'].transform(clean_text)
    test_df['discourse_text'] = test_df['discourse_text'].str.strip()

    test_df['previous_discourse_end'] = 0
    test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
    test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
    test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])
    test_df['previous_discourse_end'] = test_df.groupby("essay_id")['discourse_end'].transform(lambda x:x.shift(1).fillna(0)).astype(int)
    test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
    test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0]) #+ test_df['previous_discourse_end']
    test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1]) #+ test_df['previous_discourse_end']
    return test_df

In [151]:
test_df = df[df['essay_id_comp'] == '0000D23A521A']
test_df.shape

(8, 30)

In [152]:
test_df.head(2)

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2
140907,AAAOPP13416000075452,0000D23A521A,train,"Some people belive that the so called ""face"" o...",3,1617734767734.0,0,169,"Some people belive that the so called ""face"" o...",Position,Position 1,Adequate,,,,Indiana,Text dependent,"""Unmasking the Face on Mars""",The Face on Mars,You have read the article 'Unmasking the Face ...,M,8.0,No,White,Not economically disadvantaged,Not identified as having disability,248,0,,
140908,AAAOPP13416000075452,0000D23A521A,train,"Some people belive that the so called ""face"" o...",3,1617734782429.0,170,356,"It was not created by aliens, and there is no ...",Evidence,Evidence 1,Adequate,1617734767734.0,"Some people belive that the so called ""face"" o...",Position,Indiana,Text dependent,"""Unmasking the Face on Mars""",The Face on Mars,You have read the article 'Unmasking the Face ...,M,8.0,No,White,Not economically disadvantaged,Not identified as having disability,248,0,,


In [153]:
test_df['full_text'] = test_df['full_text'].transform(clean_text)
test_df['discourse_text'] = test_df['discourse_text'].transform(clean_text)

test_df['previous_discourse_end'] = 0
test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])
test_df['previous_discourse_end'] = test_df.groupby("essay_id")['discourse_end'].transform(lambda x:x.shift(1).fillna(0)).astype(int)
test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0]) #+ test_df['previous_discourse_end']
test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1]) #+ test_df['previous_discourse_end']

In [154]:
test_df.head(2)

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2,previous_discourse_end,st_ed
140907,AAAOPP13416000075452,0000D23A521A,train,"Some people belive that the so called ""face"" o...",3,1617734767734.0,0,169,"Some people belive that the so called ""face"" o...",Position,Position 1,Adequate,,,,Indiana,Text dependent,"""Unmasking the Face on Mars""",The Face on Mars,You have read the article 'Unmasking the Face ...,M,8.0,No,White,Not economically disadvantaged,Not identified as having disability,248,0,,,0,"(0, 169)"
140908,AAAOPP13416000075452,0000D23A521A,train,"Some people belive that the so called ""face"" o...",3,1617734782429.0,170,356,"It was not created by aliens, and there is no ...",Evidence,Evidence 1,Adequate,1617734767734.0,"Some people belive that the so called ""face"" o...",Position,Indiana,Text dependent,"""Unmasking the Face on Mars""",The Face on Mars,You have read the article 'Unmasking the Face ...,M,8.0,No,White,Not economically disadvantaged,Not identified as having disability,248,0,,,169,"(170, 356)"


In [155]:
visualize(idx,test_df)

In [156]:
visualize(idx,df)

# Dataset

In [22]:
import re
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

from tqdm.auto import tqdm


LABEL2TYPE = ('Lead', 'Position', 'Claim', 'Counterclaim', 'Rebuttal',
              'Evidence', 'Concluding Statement')
TYPE2LABEL = {t: l for l, t in enumerate(LABEL2TYPE)}


LABEL2EFFEC = ('Adequate', 'Effective', 'Ineffective')
EFFEC2LABEL = {t: l for l, t in enumerate(LABEL2EFFEC)}

## =============================================================================== ##
class FeedbackDataset(Dataset):
    def __init__(self,
                 df,
                 tokenizer,
                 mask_prob=0.0,
                 mask_ratio=0.0,
                 ):
        self.df = self.prepare_df(df)
        self.samples = list(self.df.groupby('essay_id_comp'))
        self.tokenizer = tokenizer
        
        print(f'Loaded {len(self)} samples.')

        assert 0 <= mask_prob <= 1
        assert 0 <= mask_ratio <= 1
        self.mask_prob = mask_prob
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        text_id, df = self.samples[index]
        text = df['full_text'].values[0]

        tokens = self.tokenizer(text, return_offsets_mapping=True)
        input_ids = torch.LongTensor(tokens['input_ids'])
        attention_mask = torch.LongTensor(tokens['attention_mask'])
        offset_mapping = np.array(tokens['offset_mapping'])
        offset_mapping = self.strip_offset_mapping(text, offset_mapping)
        num_tokens = len(input_ids)

        # token slices of words
        woff = self.get_word_offsets(text)
        toff = offset_mapping
        wx1, wx2 = woff.T
        tx1, tx2 = toff.T
        ix1 = np.maximum(wx1[..., None], tx1[None, ...])
        ix2 = np.minimum(wx2[..., None], tx2[None, ...])
        ux1 = np.minimum(wx1[..., None], tx1[None, ...])
        ux2 = np.maximum(wx2[..., None], tx2[None, ...])
        ious = (ix2 - ix1).clip(min=0) / (ux2 - ux1 + 1e-12)
        assert (ious > 0).any(-1).all()

        word_boxes = []
        for row in ious:
            inds = row.nonzero()[0]
            word_boxes.append([inds[0], 0, inds[-1] + 1, 1])
        word_boxes = torch.FloatTensor(word_boxes)

        # word slices of ground truth spans
        gt_spans = []
        for _, row in df.iterrows():
            
            word_start = len(row['full_text'][:row['discourse_start']].split())
            word_end = word_start + len(row['full_text'][row['discourse_start']:row['discourse_end']].split())
            word_end = min(word_end, len(row['full_text'].split()))
            
            span_label = TYPE2LABEL[row['discourse_type']]
            span_effectiveness = EFFEC2LABEL[row['discourse_effectiveness']]
            gt_spans.append([word_start,word_end,span_label,span_effectiveness])
            
        gt_spans = torch.LongTensor(gt_spans)

        # random mask augmentation
        if np.random.random() < self.mask_prob:
            all_inds = np.arange(1, len(input_ids) - 1)
            n_mask = max(int(len(all_inds) * self.mask_ratio), 1)
            np.random.shuffle(all_inds)
            mask_inds = all_inds[:n_mask]
            input_ids[mask_inds] = self.tokenizer.mask_token_id

        return dict(text=text,
                    text_id=text_id,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    word_boxes=word_boxes,
                    gt_spans=gt_spans)
    
    def prepare_df(self,test_df):
        test_df['full_text'] = test_df['full_text'].transform(clean_text)        
        test_df['discourse_text'] = test_df['discourse_text'].transform(clean_text)
        test_df['previous_discourse_end'] = 0
        test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
        test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
        test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])
        
        test_df['previous_discourse_end'] = test_df.groupby("essay_id_comp")['discourse_end'].transform(lambda x:x.shift(1).fillna(0)).astype(int)
        test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
        
        test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0]) #+ test_df['previous_discourse_end']
        test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1]) #+ test_df['previous_discourse_end']
        return test_df
    
    def strip_offset_mapping(self, text, offset_mapping):
        ret = []
        for start, end in offset_mapping:
            match = list(re.finditer('\S+', text[start:end]))
            if len(match) == 0:
                ret.append((start, end))
            else:
                span_start, span_end = match[0].span()
                ret.append((start + span_start, start + span_end))
        return np.array(ret)

    def get_word_offsets(self, text):
        matches = re.finditer("\S+", text)
        spans = []
        words = []
        for match in matches:
            span = match.span()
            word = match.group()
            spans.append(span)
            words.append(word)
        assert tuple(words) == tuple(text.split())
        return np.array(spans)

In [57]:
class CustomCollator(object):
    def __init__(self, tokenizer, model):
        self.pad_token_id = tokenizer.pad_token_id
        if hasattr(model.config, 'attention_window'):
            # For longformer
            # https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/longformer/modeling_longformer.py#L1548
            self.attention_window = (model.config.attention_window
                                     if isinstance(
                                         model.config.attention_window, int)
                                     else max(model.config.attention_window))
        else:
            self.attention_window = None

    def __call__(self, samples):
        batch_size = len(samples)
        assert batch_size == 1, f'Only batch_size=1 supported, got batch_size={batch_size}.'

        sample = samples[0]

        max_seq_length = len(sample['input_ids'])
        if self.attention_window is not None:
            attention_window = self.attention_window
            padded_length = (attention_window -
                             max_seq_length % attention_window
                             ) % attention_window + max_seq_length
        else:
            padded_length = max_seq_length

        input_shape = (1, padded_length)
        input_ids = torch.full(input_shape,
                               self.pad_token_id,
                               dtype=torch.long)
        attention_mask = torch.zeros(input_shape, dtype=torch.long)

        seq_length = len(sample['input_ids'])
        input_ids[0, :seq_length] = sample['input_ids']
        attention_mask[0, :seq_length] = sample['attention_mask']

        text_id = sample['text_id']
        text = sample['text']
        word_boxes = sample['word_boxes']
        gt_spans = sample['gt_spans']

        return dict(text_id=text_id,
                    text=text,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    word_boxes=word_boxes,
                    gt_spans=gt_spans)

In [24]:
len(LABEL2TYPE)

7

In [218]:
df.head(2)

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2
0,5408891152126,423A1CA112E2,train,Phones\n\nModern humans today are always on th...,3,1622627660525.0,0,7,Phones\n\n,Unannotated,Unannotated 1,NoEffectiveness,,,,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,M,,,Black/African American,,,378,1,,
1,5408891152126,423A1CA112E2,train,Phones\n\nModern humans today are always on th...,3,1622627660524.0,8,229,Modern humans today are always on their phone....,Lead,Lead 1,Adequate,,,,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,M,,,Black/African American,,,378,1,,


In [219]:
df_train = df[df.competition_set=='train']
df_train = df_train[df_train.discourse_effectiveness.isin(LABEL2EFFEC)]
df_train.shape

(144289, 30)

In [220]:
df.test_split_feedback_2.value_counts()

test_split_feedback_2
Private    13868
Public      9122
Name: count, dtype: int64

In [227]:
df_test = df[df.test_split_feedback_2=='Public']
df_test = df_test[df_test.discourse_effectiveness.isin(LABEL2EFFEC)]
df_test.shape

(9122, 30)

In [228]:
df_test.discourse_effectiveness.unique()

array(['Effective', 'Adequate', 'Ineffective'], dtype=object)

In [229]:
df_test

Unnamed: 0,essay_id,essay_id_comp,competition_set,full_text,holistic_essay_score,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,discourse_effectiveness,hierarchical_id,hierarchical_text,hierarchical_label,provider,task,source_text,prompt_name,assignment,gender,grade_level,ell_status,race_ethnicity,economically_disadvantaged,student_disability_status,essay_word_count,in_feedback2.0,test_split_feedback_1,test_split_feedback_2
439,5947221152126,E4FA060FF3E3,test,Don't Touch the Screen\n\nAs we drove to the f...,6,1622402403413.0000,24,741,As we drove to the final basketball game of th...,Lead,Lead 1,Effective,,,,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,F,,,White,,,614,1,Public,Public
440,5947221152126,E4FA060FF3E3,test,Don't Touch the Screen\n\nAs we drove to the f...,6,1622402291297.0000,742,865,"As a result of the risk to the public, drivers...",Position,Position 1,Effective,,,,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,F,,,White,,,614,1,Public,Public
441,5947221152126,E4FA060FF3E3,test,Don't Touch the Screen\n\nAs we drove to the f...,6,1622402314102.0000,866,1199,Operating a vehicle requires attention and foc...,Evidence,Evidence 1,Effective,1622402309853.0000,Texting while driving is particularly harmful...,Claim,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,F,,,White,,,614,1,Public,Public
442,5947221152126,E4FA060FF3E3,test,Don't Touch the Screen\n\nAs we drove to the f...,6,1622402309853.0000,1200,1305,Texting while driving is particularly harmful ...,Claim,Claim 1,Effective,1622402291297.0000,"As a result of the risk to the public, driver...",Position,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,F,,,White,,,614,1,Public,Public
443,5947221152126,E4FA060FF3E3,test,Don't Touch the Screen\n\nAs we drove to the f...,6,1622402321238.0000,1306,1574,According to the National Highway Traffic Safe...,Evidence,Evidence 2,Effective,1622402309853.0000,Texting while driving is particularly harmful...,Claim,Georgia Virtual,Independent,,Phones and driving,Today the majority of humans own and operate c...,F,,,White,,,614,1,Public,Public
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
278118,AAAXMP138200001471962810_OR,9921880270E3,test,Have you every though about seeking advice for...,3,1617637084641.0000,931,1118,Having multiple opinions form different people...,Claim,Claim 5,Effective,1617636580975.0000,Seeking advice is a great thing to do and is v...,Position,Virginia,Independent,,Seeking multiple opinions,"When people ask for advice, they sometimes tal...",F,8.0000,Yes,Black/African American,Not economically disadvantaged,Not identified as having disability,521,1,Public,Public
278119,AAAXMP138200001471962810_OR,9921880270E3,test,Have you every though about seeking advice for...,3,1617637150089.0000,1119,2005,Seeing the pro and cons of one persons advice ...,Evidence,Evidence 2,Adequate,1617637084641.0000,Having multiple opinions form different people...,Claim,Virginia,Independent,,Seeking multiple opinions,"When people ask for advice, they sometimes tal...",F,8.0000,Yes,Black/African American,Not economically disadvantaged,Not identified as having disability,521,1,Public,Public
278121,AAAXMP138200001471962810_OR,9921880270E3,test,Have you every though about seeking advice for...,3,1617637255255.0000,2015,2080,having multiple opinion on a single topic make...,Claim,Claim 6,Adequate,1617636580975.0000,Seeking advice is a great thing to do and is v...,Position,Virginia,Independent,,Seeking multiple opinions,"When people ask for advice, they sometimes tal...",F,8.0000,Yes,Black/African American,Not economically disadvantaged,Not identified as having disability,521,1,Public,Public
278122,AAAXMP138200001471962810_OR,9921880270E3,test,Have you every though about seeking advice for...,3,1617637333546.0000,2081,2442,Being confident on what your doing helps peopl...,Evidence,Evidence 3,Ineffective,1617637255255.0000,having multiple opinion on a single topic make...,Claim,Virginia,Independent,,Seeking multiple opinions,"When people ask for advice, they sometimes tal...",F,8.0000,Yes,Black/African American,Not economically disadvantaged,Not identified as having disability,521,1,Public,Public


In [26]:
model_name = 'microsoft/deberta-v3-large'
tokenizer = AutoTokenizer.from_pretrained(model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [230]:
ds = FeedbackDataset(df_test,
                 tokenizer)

Loaded 1115 samples.


In [232]:
from torch.utils.data import DataLoader
collator = CustomCollator(tokenizer,net)
dl = DataLoader(ds,batch_size=1,collate_fn=collator)

In [233]:
for b in dl:
    print(b["gt_spans"][:,3])

tensor([0, 0, 0, 0, 2, 0, 0, 0, 1])
tensor([2, 0, 0, 2, 0])
tensor([2, 0, 2, 0, 0, 0])
tensor([0, 0, 2, 0, 2, 0, 0])
tensor([2, 2, 0, 0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 2, 2, 0, 2, 2])
tensor([0, 0, 0, 2, 0, 2, 0, 0, 0])
tensor([0, 0, 0, 0, 2, 0, 0, 0, 0])
tensor([2, 2, 2, 2, 2, 2])
tensor([0, 0, 2, 0])
tensor([2, 2, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 2, 2, 0, 0])
tensor([1, 1, 1, 0, 1, 0, 1, 1, 0])
tensor([1, 1, 0, 1, 0, 1, 1, 0, 1])
tensor([0, 0, 0, 0, 0, 0, 2, 1])
tensor([0, 0, 0, 0, 2, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 1])
tensor([2, 2, 2])
tensor([0, 0, 2, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 2])
tensor([1, 0, 1, 1, 1, 1, 0])
tensor([1, 0, 1, 1, 1, 1, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 0])
tensor([1, 1, 1, 0, 1, 1, 1, 1, 0, 1])
tensor([1, 1, 1, 0, 1, 1, 1, 0])
tensor([1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1])
tensor([0, 0, 2, 2, 1, 2, 0])
tensor([0, 0, 0, 0, 0, 0,

tensor([1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1])
tensor([0, 0, 0, 2, 0, 2, 0, 2, 0])
tensor([0, 2, 0, 0])
tensor([0, 1, 0, 0, 0, 0, 2, 0, 2, 2, 0, 0])
tensor([1, 1, 1, 0])
tensor([1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 2, 0, 2])
tensor([1, 0, 0, 1, 1, 0, 2, 0, 0, 0])
tensor([2, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0])
tensor([2, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 2])
tensor([0, 2, 2, 2])
tensor([1, 0, 0, 2, 0, 2, 0, 1, 0])
tensor([0, 0, 0, 0, 0, 0, 2])
tensor([0, 2, 0, 0, 0])
tensor([0, 2, 1, 0])
tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 1, 0, 0])
tensor([2, 2, 2])
tensor([2, 0, 1, 0, 1, 1, 0, 0, 0, 0])
tensor([2, 2, 2, 2])
tensor([0, 0, 2, 0, 2, 2, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 2])
tensor([2, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0])
tensor([2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 2, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 0, 0,

tensor([1, 1, 1, 0, 0, 1, 0, 1])
tensor([1, 0, 0, 1, 1, 1, 1, 1])
tensor([0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 0, 0, 2, 2, 0])
tensor([0, 0, 0, 0, 2, 2, 2])
tensor([0, 0, 2, 0, 0, 0])
tensor([0, 0, 2, 0, 0, 0, 0])
tensor([0, 0, 2, 0, 0, 2, 0])
tensor([0, 2, 0, 2, 0, 2, 2, 0, 0])
tensor([0, 1, 1, 0, 1, 0, 1, 1, 1])
tensor([0, 0, 0, 2, 2, 2, 2])
tensor([0, 2, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 2, 0, 0, 0, 0, 0, 0])
tensor([1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 1, 0, 1, 1, 2, 2, 0])
tensor([0, 2, 0, 0, 0, 0, 2, 0])
tensor([0, 0, 0, 2, 2, 0])
tensor([0, 2, 0, 0, 0, 2, 0, 0, 0])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 2, 0, 0, 2, 2, 0])
tensor([0, 0, 1, 1, 0, 2, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 1])
tensor([2, 2, 2, 0, 2])
tensor([0, 0, 0, 2, 0])
tensor([0, 0, 0, 2, 0, 0, 0, 2, 0, 2, 2])
tensor([2, 2, 0, 0])
tensor([0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0,

tensor([0, 0, 2, 2, 0])
tensor([1, 1, 1, 1, 1, 0])
tensor([0, 1, 1, 0, 1, 1, 1, 1, 1])
tensor([2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([2, 0, 0, 0])
tensor([1, 1, 1, 1])
tensor([0, 0, 0, 0, 2, 0, 0, 0, 0, 0])
tensor([1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 2, 2])
tensor([0, 2, 0, 0, 0, 2, 0])
tensor([0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 2, 0, 0, 0, 2])
tensor([0, 0, 2, 0, 0])
tensor([1, 1, 0, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 2, 2, 0, 0, 0])
tensor([0, 1, 1, 1, 1, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 1, 2])
tensor([0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0])
tensor([2, 2, 2])
tensor([0, 0, 0, 0, 0, 0, 0, 2, 0])
tensor([0, 2, 2])
tensor([2])
tensor([2])
tensor([2])
tensor([0, 2, 0, 0, 0, 0])
tensor([0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 2, 2, 0, 0])
tensor([2, 0, 0, 2, 0, 0, 0, 0, 0, 0])
tensor([2, 0, 0, 1, 0, 0, 0, 0, 0, 0])
tensor([2])
tensor([0, 2, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0])
tens

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0])
tensor([0, 0, 1, 1, 1, 1, 1, 1, 0])
tensor([1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1])
tensor([0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1])
tensor([0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 0, 0])
tensor([0, 2, 0, 0, 0])
tensor([0, 2, 0, 0, 0, 1, 1, 1, 0, 1])
tensor([0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 2])
tensor([0, 0, 2, 2, 0, 2, 0, 0, 2, 2, 0, 0, 0, 2, 0, 0, 2])
tensor([2])
tensor([2, 0, 0, 1, 0, 1, 0])
tensor([0, 0, 0, 2, 0, 0, 0, 2])
tensor([2, 0, 2, 2, 0, 2, 2, 0])
tensor([0, 1, 0, 0, 2, 0, 1, 1, 0, 1, 0, 0, 0])
tensor([0, 0, 0, 0, 2, 2, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0])
tensor([0, 0, 0, 0, 0, 2, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 2, 0])
tensor([0, 0, 0, 2, 0, 0, 0])
tensor([0, 0, 0, 0, 2])
tensor([0, 0, 0, 0, 0, 0, 0, 2, 0])
tensor([0, 0, 2])
tensor([0, 0, 2, 2, 0, 0, 0, 0, 0])
tensor([2, 1, 0, 0, 0, 0, 1, 0])
tensor([0, 2, 0])
tensor([0, 0, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0])
tensor([0,

In [60]:
b['input_ids'].shape

torch.Size([1, 296])

In [28]:
visualize(ds[0]['text_id'],ds.df)

In [29]:
i = 0
ds[0]['text'].split()[ds[0]["gt_spans"][i][:2][0]:ds[0]["gt_spans"][i][:2][1]]

['Some',
 'people',
 'belive',
 'that',
 'the',
 'so',
 'called',
 '"face"',
 'on',
 'mars',
 'was',
 'created',
 'by',
 'life',
 'on',
 'mars.',
 'This',
 'is',
 'not',
 'the',
 'case.',
 'The',
 'face',
 'on',
 'Mars',
 'is',
 'a',
 'naturally',
 'occuring',
 'land',
 'form',
 'called',
 'a',
 'mesa.']

In [51]:
def to_gpu(data):
    if isinstance(data, dict):
        return {k: to_gpu(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_gpu(v) for v in data]
    elif isinstance(data, torch.Tensor):
        return data.cuda()
    else:
        return data


def to_np(t):
    if isinstance(t, torch.Tensor):
        return t.data.cpu().numpy()
    else:
        return t


In [148]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
import torch.utils.checkpoint


from torchvision.ops import roi_align, nms

def aggregate_tokens_to_words(feat, word_boxes):
    feat = feat.permute(0, 2, 1).unsqueeze(2)
    output = roi_align(feat, [word_boxes], 1, aligned=True)
    return output.squeeze(-1).squeeze(-1)


def span_nms(start, end, score, nms_thr=0.5):
    boxes = torch.stack(
        [
            start,
            torch.zeros_like(start),
            end,
            torch.ones_like(start),
        ],
        dim=1,
    ).float()
    keep = nms(boxes, score, nms_thr)
    return keep


@torch.no_grad()
def build_target(gt_spans, obj_pred, reg_pred, cls_pred,eff_pred,dynamic_positive=False):

    obj_target = torch.zeros_like(obj_pred)
    reg_target = torch.zeros_like(reg_pred)
    cls_target = torch.zeros_like(cls_pred)
    eff_target = torch.zeros_like(eff_pred)

    # first token as positive
    pos_loc = gt_spans[:, 0]
    obj_target[pos_loc] = 1
    reg_target[pos_loc, 0] = gt_spans[:, 0].float()
    reg_target[pos_loc, 1] = gt_spans[:, 1].float()
    cls_target[pos_loc, gt_spans[:, 2]] = 1
    eff_target[pos_loc, gt_spans[:, 3]] = 1

    # dynamically assign one more positive
    if dynamic_positive:
        cls_prob = (obj_pred.sigmoid().unsqueeze(1) *
                    cls_pred.sigmoid()).sqrt()

        eff_prob = (obj_pred.sigmoid().unsqueeze(1) *
                    eff_pred.sigmoid()).sqrt()

        for start, end, label,eff_label in gt_spans:
            _cls_prob = cls_prob[start:end]
            _cls_gt = _cls_prob.new_full((_cls_prob.size(0), ),
                                         label,
                                         dtype=torch.long)
            _cls_gt = F.one_hot(
                _cls_gt, num_classes=_cls_prob.size(1)).type_as(_cls_prob)
            cls_cost = F.binary_cross_entropy(_cls_prob,
                                              _cls_gt,
                                              reduction='none').sum(-1)

            _eff_prob = eff_prob[start:end]
            _eff_gt = _eff_prob.new_full((_eff_prob.size(0), ),
                                         eff_label,
                                         dtype=torch.long)
            _eff_gt = F.one_hot(
                _eff_gt, num_classes=_eff_prob.size(1)).type_as(_eff_prob)
            eff_cost = F.binary_cross_entropy(_eff_prob,
                                              _eff_gt,
                                              reduction='none').sum(-1)


            _reg_pred = reg_pred[start:end].exp()
            _reg_loc = torch.arange(_reg_pred.size(0),
                                    device=_reg_pred.device)

            px1 = _reg_loc - _reg_pred[:, 0]
            px2 = _reg_loc + _reg_pred[:, 1]
            ix1 = torch.max(px1, _reg_loc[0])
            ix2 = torch.min(px2, _reg_loc[-1])
            ux1 = torch.min(px1, _reg_loc[0])
            ux2 = torch.max(px2, _reg_loc[-1])
            inter = (ix2 - ix1).clamp(min=0)
            union = (ux2 - ux1).clamp(min=0) + 1e-12
            iou = inter / union
            iou_cost = -torch.log(iou + 1e-12)

            cost = cls_cost + eff_cost + iou_cost

            pos_ind = start + cost.argmin()

            obj_target[pos_ind] = 1
            reg_target[pos_ind, 0] = start
            reg_target[pos_ind, 1] = end
            cls_target[pos_ind, label] = 1
            eff_target[pos_ind, eff_label] = 1

        pos_loc = (obj_target == 1).nonzero().flatten()
        
    return obj_target, reg_target,cls_target,eff_target,pos_loc

class FeedbackModel(nn.Module):
    def __init__(self,
                 model_name,
                 num_labels,
                 config_path=None,
                 pretrained_path = None,
                 use_dropout=False,
                 use_gradient_checkpointing = False
                 ):
        super().__init__()
        self.pretrained_path = pretrained_path
        self.config = AutoConfig.from_pretrained(model_name, output_hidden_states=True) if not config_path else torch.load(config_path)

        self.use_dropout = use_dropout
        if not self.use_dropout:
            self.config.update(
                                {
                                    "hidden_dropout_prob": 0.0,
                                    "attention_probs_dropout_prob": 0.0,
                                }
                                    )

        self.backbone = AutoModel.from_pretrained(model_name,config=self.config) if not config_path else AutoModel.from_config(self.config)
        
        if use_gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()

        
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.fc = nn.Linear(self.config.hidden_size, num_labels)
        
        self.fc.bias.data[0].fill_(bias_init_with_prob(0.02))
        self.fc.bias.data[3:-3].fill_(bias_init_with_prob(1 / 7))
        self.fc.bias.data[-3:].fill_(bias_init_with_prob(1 / 3))


    def forward(self,b):
        x = self.backbone(b["input_ids"].reshape(1,-1),b["attention_mask"].reshape(1,-1)).last_hidden_state
        x = self.dropout(x)
        x = self.fc(x)
#         print(x.shape)
        x = aggregate_tokens_to_words(x, b['word_boxes'])
#         print(x.shape)
        obj_pred = x[..., 0]
        reg_pred = x[..., 1:3]
        tpe_pred = x[..., 3:-3]
        eff_pred = x[..., -3:]
        
        return obj_pred, reg_pred, tpe_pred,eff_pred
    
def predict(model,data, test_score_thr=0.5):
    
    obj_pred, reg_pred, cls_pred,eff_pred = model(data)
    obj_pred = obj_pred.sigmoid()
    reg_pred = reg_pred.exp()
    cls_pred = cls_pred.sigmoid()
    eff_pred = eff_pred.sigmoid()

    obj_scores = obj_pred
    cls_scores, cls_labels = cls_pred.max(-1)
    eff_scores, eff_labels = eff_pred.max(-1)
    
    pr_scores = (obj_scores * cls_scores)**0.5
    pr_eff_score = (obj_scores * eff_scores)**0.5
    eff_pred_score = (obj_scores.reshape(-1,1) * eff_pred)**0.5
    pos_inds = pr_scores > test_score_thr

    if pos_inds.sum() == 0:
        return None

    pr_score,pr_label,pr_eff = pr_scores[pos_inds], cls_labels[pos_inds], eff_labels[pos_inds]
    pos_loc = pos_inds.nonzero().flatten()
    start = pos_loc - reg_pred[pos_inds, 0]
    end = pos_loc + reg_pred[pos_inds, 1]

    min_idx, max_idx = 0, obj_pred.numel() - 1
    start = start.clamp(min=min_idx, max=max_idx).round().long()
    end = end.clamp(min=min_idx, max=max_idx).round().long()

    # nms
    keep = span_nms(start, end, pr_score)
    start = start[keep]
    end = end[keep]
    pr_score = pr_score[keep]
    pr_label = pr_label[keep]
    pr_eff = pr_eff[keep]
    pr_eff_score = pr_eff_score[keep]
    eff_pred_score = eff_pred_score[keep]
    
    res = dict(id=data['text_id'],
                start=to_np(start),
                end=to_np(end),
                score_discourse_type=to_np(pr_score),
                discourse_type=to_np(pr_label),
                discourse_effectiveness = to_np(pr_eff),
                score_discourse_effectiveness = to_np(pr_eff_score),
                score_discourse_effectiveness_0 = to_np(eff_pred_score[:,0]),
                score_discourse_effectiveness_1 = to_np(eff_pred_score[:,1]),
                score_discourse_effectiveness_2 = to_np(eff_pred_score[:,2]),
               )
    res = pd.DataFrame(res).sort_values('start').reset_index(drop=True)
    res['predictionstring'] = res.apply(get_pred(' '),axis=1)
    return res

In [138]:
1+1

2

In [35]:
net = FeedbackModel(model_name,1+2+7+3)

Some weights of the model checkpoint at microsoft/deberta-v3-large were not used when initializing DebertaV2Model: ['lm_predictions.lm_head.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.dense.weight', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'mask_predictions.LayerNorm.bias', 'mask_predictions.classifier.weight', 'lm_predictions.lm_head.dense.bias', 'mask_predictions.dense.weight', 'mask_predictions.classifier.bias', 'mask_predictions.dense.bias']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [36]:
rd = random.choice(range(len(ds)))
obj_pred, reg_pred, tpe_pred,eff_pred = net(ds[rd])

torch.Size([1, 469, 13])
torch.Size([400, 13])


In [37]:
obj_pred.shape , reg_pred.shape , tpe_pred.shape ,eff_pred.shape 

(torch.Size([400]),
 torch.Size([400, 2]),
 torch.Size([400, 7]),
 torch.Size([400, 3]))

In [40]:
obj_target, reg_target, cls_target,eff_target, pos_loc = build_target(
            ds[rd]['gt_spans'], obj_pred, reg_pred, tpe_pred,eff_pred)

In [43]:
obj_target.shape, reg_target.shape, cls_target.shape,eff_target.shape,pos_loc.shape

(torch.Size([400]),
 torch.Size([400, 2]),
 torch.Size([400, 7]),
 torch.Size([400, 3]),
 torch.Size([9]))

In [47]:
import torch.nn.functional as F

def get_losses(obj_pred, reg_pred, cls_pred,eff_pred, obj_target, reg_target,
               cls_target,eff_target, pos_loc):
    
    num_total_samples = pos_loc.numel()
    assert num_total_samples > 0
    reg_pred = reg_pred[pos_loc].exp()
    reg_target = reg_target[pos_loc]
    px1 = pos_loc - reg_pred[:, 0]
    px2 = pos_loc + reg_pred[:, 1]
    gx1 = reg_target[:, 0]
    gx2 = reg_target[:, 1]
    ix1 = torch.max(px1, gx1)
    ix2 = torch.min(px2, gx2)
    ux1 = torch.min(px1, gx1)
    ux2 = torch.max(px2, gx2)
    inter = (ix2 - ix1).clamp(min=0)
    union = (ux2 - ux1).clamp(min=0) + 1e-12
    iou = inter / union

    reg_loss = -iou.log().sum() / num_total_samples
    
    cls_loss = F.binary_cross_entropy_with_logits(
        cls_pred[pos_loc],
        cls_target[pos_loc] * iou.detach().reshape(-1, 1),
        reduction='sum') / num_total_samples
    
    eff_loss = F.binary_cross_entropy_with_logits(
        eff_pred[pos_loc],
        eff_target[pos_loc] * iou.detach().reshape(-1, 1),
        reduction='sum') / num_total_samples
    
    obj_loss = F.binary_cross_entropy_with_logits(
        obj_pred, obj_target, reduction='sum') / num_total_samples
    
    return obj_loss, reg_loss, cls_loss,eff_loss

In [48]:
obj_loss, reg_loss, cls_loss,eff_loss = get_losses(obj_pred, reg_pred,
                                                       tpe_pred,eff_pred, obj_target,
                                                       reg_target, cls_target,eff_target,
                                                       pos_loc)

In [50]:
obj_loss, reg_loss, cls_loss,eff_loss

(tensor(34.3420, grad_fn=<DivBackward0>),
 tensor(2.9669, grad_fn=<DivBackward0>),
 tensor(5.7624, grad_fn=<DivBackward0>),
 tensor(2.1812, grad_fn=<DivBackward0>))

In [77]:
losses = []

for i in range(3):
    loss = obj_loss + reg_loss + cls_loss + eff_loss
    losses.append([[loss.item(),obj_loss.item(),reg_loss.item(),cls_loss.item(),eff_loss.item()]])

In [78]:
losses

[[[45.25254821777344,
   34.3420295715332,
   2.966911554336548,
   5.762429714202881,
   2.181180000305176]],
 [[45.25254821777344,
   34.3420295715332,
   2.966911554336548,
   5.762429714202881,
   2.181180000305176]],
 [[45.25254821777344,
   34.3420295715332,
   2.966911554336548,
   5.762429714202881,
   2.181180000305176]]]

In [79]:
np.concatenate(losses,axis=0).shape

(3, 5)

In [80]:
np.concatenate(losses,axis=0).mean(0).shape

(5,)

In [149]:
res = predict(net,ds[rd], test_score_thr=0.5)

torch.Size([1, 469, 13])
torch.Size([400, 13])


In [151]:
res

Unnamed: 0,id,start,end,score_discourse_type,discourse_type,discourse_effectiveness,score_discourse_effectiveness,score_discourse_effectiveness_0,score_discourse_effectiveness_1,score_discourse_effectiveness_2,predictionstring
0,963859B3739B,0,2,0.5462,5,1,0.4449,0.2879,0.4449,0.4072,0 1
1,963859B3739B,0,1,0.5177,1,2,0.4795,0.3176,0.4785,0.4795,0
2,963859B3739B,1,3,0.6212,2,1,0.666,0.5041,0.666,0.525,1 2
3,963859B3739B,2,4,0.5943,5,1,0.5568,0.3504,0.5568,0.473,2 3
4,963859B3739B,3,5,0.6293,1,1,0.5495,0.4351,0.5495,0.4993,3 4
5,963859B3739B,5,7,0.6291,5,1,0.44,0.2866,0.44,0.4341,5 6
6,963859B3739B,7,9,0.5378,1,1,0.5421,0.4168,0.5421,0.5143,7 8
7,963859B3739B,8,10,0.6147,1,2,0.545,0.3812,0.545,0.496,8 9
8,963859B3739B,9,12,0.6767,5,1,0.5551,0.2783,0.5551,0.4507,9 10 11
9,963859B3739B,11,13,0.602,4,1,0.6605,0.3783,0.6605,0.4939,11 12


In [88]:
def get_pred(col):
    def row_wise(row):
        return " ".join([str(x) for x in range(row.start,row.end)])
    return row_wise

In [111]:
gt = pd.DataFrame({
              "id":ds[rd]['text_id'],
              "start":ds[rd]["gt_spans"][:,0],
              "end":ds[rd]["gt_spans"][:,1],
              "discourse_type":ds[rd]["gt_spans"][:,2],
              "discourse_effectiveness":ds[rd]["gt_spans"][:,3],
             })
gt['predictionstring'] = gt.apply(get_pred('za'),axis=1)

In [129]:
gt

Unnamed: 0,id,start,end,discourse_type,discourse_effectiveness,predictionstring
0,963859B3739B,0,9,1,2,0 1 2 3 4 5 6 7 8
1,963859B3739B,16,31,2,2,16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
2,963859B3739B,32,42,2,0,32 33 34 35 36 37 38 39 40 41
3,963859B3739B,46,52,2,0,46 47 48 49 50 51
4,963859B3739B,63,76,2,0,63 64 65 66 67 68 69 70 71 72 73 74 75
5,963859B3739B,76,93,5,0,76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9...
6,963859B3739B,93,102,2,0,93 94 95 96 97 98 99 100 101
7,963859B3739B,158,391,5,0,158 159 160 161 162 163 164 165 166 167 168 16...
8,963859B3739B,391,400,6,0,391 392 393 394 395 396 397 398 399


In [113]:
res_df.columns = ['id', 'start', 'end', 'score', 'discourse_type', 'discourse_effectiveness', 'predictionstring']

In [114]:
res_df

Unnamed: 0,id,start,end,score,discourse_type,discourse_effectiveness,predictionstring
0,963859B3739B,257,259,0.7544,1,1,257 258
1,963859B3739B,111,113,0.743,1,1,111 112
2,963859B3739B,262,265,0.7394,1,1,262 263 264
3,963859B3739B,93,96,0.7365,1,1,93 94 95
4,963859B3739B,151,153,0.7351,1,1,151 152
5,963859B3739B,123,125,0.7343,1,1,123 124
6,963859B3739B,181,184,0.731,1,1,181 182 183
7,963859B3739B,274,277,0.725,1,1,274 275 276
8,963859B3739B,36,39,0.7208,0,1,36 37 38
9,963859B3739B,21,23,0.7198,4,1,21 22


In [197]:
import pandas as pd
import numpy as np

def _tp_score_effectiveness(col=""):
    def row_wise(row):
        gt_eff = row.discourse_effectiveness # ('Adequate' = 0, 'Effective' = 1, 'Ineffective' = 2)
        return row[f'score_discourse_effectiveness_{gt_eff}']
    return row_wise

def calc_overlap(set_pred, set_gt,threshold=0.5):
    """
    Calculates if the overlap between prediction and
    ground truth is enough fora potential True positive
    """
    # Length of each and intersection
    try:
        len_gt = len(set_gt)
        len_pred = len(set_pred)
        inter = len(set_gt & set_pred)
        overlap_1 = inter / len_gt
        overlap_2 = inter/ len_pred
        return overlap_1 >= threshold and overlap_2 >= threshold,max(overlap_1,overlap_2)
    except:  # at least one of the input is NaN
        return False,0

def score_feedback_comp_micro(pred_df, gt_df, discourse_type,threshold=0.5,weight_tp_segment=0.5):
    """
    A function that scores for the kaggle
        Student Writing Competition
        
    Uses the steps in the evaluation page here:
        https://www.kaggle.com/c/feedback-prize-2021/overview/evaluation
    """
    gt_df = gt_df.loc[gt_df['discourse_type'] == discourse_type, 
                      ['id', 'predictionstring','discourse_effectiveness']].reset_index(drop=True)
    pred_df = pred_df.loc[pred_df['discourse_type'] == discourse_type,
                      ['id', 'predictionstring','score_discourse_effectiveness_0',
                                               'score_discourse_effectiveness_1',
                                               'score_discourse_effectiveness_2']].reset_index(drop=True)
    pred_df['pred_id'] = pred_df.index
    gt_df['gt_id'] = gt_df.index
    pred_df['predictionstring'] = [set(pred.split(' ')) for pred in pred_df['predictionstring']]
    gt_df['predictionstring'] = [set(pred.split(' ')) for pred in gt_df['predictionstring']]
    
    # Step 1. all ground truths and predictions for a given class are compared.
    joined = pred_df.merge(gt_df,
                           left_on='id',
                           right_on='id',
                           how='outer',
                           suffixes=('_pred','_gt')
                          )
    
    overlaps = [calc_overlap(*args,threshold=threshold) for args in zip(joined.predictionstring_pred, 
                                                     joined.predictionstring_gt)]
    
    joined['overlaps'] = np.asarray([x[0] for x in overlaps])*1
    joined['overlaps_scores'] = np.asarray([x[1] for x in overlaps])*1
    joined['effectiveness_TP_score'] = joined.apply(_tp_score_effectiveness(' '),axis=1)
    joined['1_effectiveness_TP_score'] = 1-joined['effectiveness_TP_score']
    # 2. If the overlap between the ground truth and prediction is >= 0.5, 
    # and the overlap between the prediction and the ground truth >= 0.5,
    # the prediction is a match and considered a true positive.
    # If multiple matches exist, the match with the highest pair of overlaps is taken.
    # we don't need to compute the match to compute the score
    
    joined = joined.sort_values(["overlaps",'overlaps_scores'],ascending=False).reset_index(drop=True).groupby('gt_id').head(1)
    
    TP = joined[joined.overlaps==1]['gt_id'].nunique()
    
    TP_weighted = weight_tp_segment*TP + (1-weight_tp_segment)*(joined[joined.overlaps==1]['effectiveness_TP_score'].sum())
    
    # 3. Any unmatched ground truths are false negatives
    # and any unmatched predictions are false positives.
    TPandFP = len(pred_df)
    TPandFN = len(gt_df)
    
    TPandFP_weighted = TP_weighted + (joined[joined.overlaps==0]['1_effectiveness_TP_score'].sum())
    
    #calc microf1
    f1_score_fb1 = 2*TP / (TPandFP + TPandFN)
    new_f1_score = 2*TP_weighted / (TPandFP_weighted + TPandFN)
    return f1_score_fb1,new_f1_score

def score_feedback_comp(pred_df, gt_df,threshold=0.5, weight_tp_segment=0.5,return_class_scores=False):
    class_scores_fb1 = {}
    new_class_scores = {}
    for discourse_type in gt_df.discourse_type.unique():
        s_fb1,s = score_feedback_comp_micro(pred_df, gt_df, discourse_type,threshold,weight_tp_segment)
        class_scores_fb1[discourse_type] = s_fb1
        new_class_scores[discourse_type] = s
    
    f1_fb1 = np.mean([v for v in class_scores_fb1.values()])
    new_f1 = np.mean([v for v in new_class_scores.values()])
    if return_class_scores:
        return f1_fb1,class_scores_fb1,new_f1, new_class_scores
    return f1_fb1,new_f1

In [202]:
gt = pd.DataFrame({'id':'963859B3739B',
                   "discourse_type":[1, 2, 2, 2, 2, 5, 2, 5, 6],
                   "discourse_effectiveness":[2,2,0,0,0,0,0,0,0],
                   "predictionstring":['0 1 2 3 4 5 6 7 8',
                                       '16 17 18 19 20 21 22 23 24 25 26 27 28 29 30',
                                       '32 33 34 35 36 37 38 39 40 41', '46 47 48 49 50 51',
                                       '63 64 65 66 67 68 69 70 71 72 73 74 75',
                                       '76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92',
                                       '93 94 95 96 97 98 99 100 101',
                                       '158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390',
                                       '391 392 393 394 395 396 397 398 399']
                  })
gt

Unnamed: 0,id,discourse_type,discourse_effectiveness,predictionstring
0,963859B3739B,1,2,0 1 2 3 4 5 6 7 8
1,963859B3739B,2,2,16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
2,963859B3739B,2,0,32 33 34 35 36 37 38 39 40 41
3,963859B3739B,2,0,46 47 48 49 50 51
4,963859B3739B,2,0,63 64 65 66 67 68 69 70 71 72 73 74 75
5,963859B3739B,5,0,76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9...
6,963859B3739B,2,0,93 94 95 96 97 98 99 100 101
7,963859B3739B,5,0,158 159 160 161 162 163 164 165 166 167 168 16...
8,963859B3739B,6,0,391 392 393 394 395 396 397 398 399


In [207]:
pred_df = gt.copy()
pred_df[['score_discourse_effectiveness_0',
         'score_discourse_effectiveness_1',
         'score_discourse_effectiveness_2']] = 1

In [208]:
pred_df

Unnamed: 0,id,discourse_type,discourse_effectiveness,predictionstring,score_discourse_effectiveness_0,score_discourse_effectiveness_1,score_discourse_effectiveness_2
0,963859B3739B,1,2,0 1 2 3 4 5 6 7 8,1,1,1
1,963859B3739B,2,2,16 17 18 19 20 21 22 23 24 25 26 27 28 29 30,1,1,1
2,963859B3739B,2,0,32 33 34 35 36 37 38 39 40 41,1,1,1
3,963859B3739B,2,0,46 47 48 49 50 51,1,1,1
4,963859B3739B,2,0,63 64 65 66 67 68 69 70 71 72 73 74 75,1,1,1
5,963859B3739B,5,0,76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9...,1,1,1
6,963859B3739B,2,0,93 94 95 96 97 98 99 100 101,1,1,1
7,963859B3739B,5,0,158 159 160 161 162 163 164 165 166 167 168 16...,1,1,1
8,963859B3739B,6,0,391 392 393 394 395 396 397 398 399,1,1,1


In [210]:
f1_fb1 , new_f1 = score_feedback_comp(pred_df, gt,threshold=0.5, weight_tp_segment=0.5, return_class_scores=False)

In [212]:
f1_fb1 , new_f1

(1.0, 1.0)

In [192]:
res

Unnamed: 0,id,start,end,score_discourse_type,discourse_type,discourse_effectiveness,score_discourse_effectiveness,score_discourse_effectiveness_0,score_discourse_effectiveness_1,score_discourse_effectiveness_2,predictionstring
0,963859B3739B,0,2,0.5462,5,1,0.4449,1,1,1,0 1
1,963859B3739B,0,1,0.5177,1,2,0.4795,1,1,1,0
2,963859B3739B,1,3,0.6212,2,1,0.666,1,1,1,1 2
3,963859B3739B,2,4,0.5943,5,1,0.5568,1,1,1,2 3
4,963859B3739B,3,5,0.6293,1,1,0.5495,1,1,1,3 4
5,963859B3739B,5,7,0.6291,5,1,0.44,1,1,1,5 6
6,963859B3739B,7,9,0.5378,1,1,0.5421,1,1,1,7 8
7,963859B3739B,8,10,0.6147,1,2,0.545,1,1,1,8 9
8,963859B3739B,9,12,0.6767,5,1,0.5551,1,1,1,9 10 11
9,963859B3739B,11,13,0.602,4,1,0.6605,1,1,1,11 12


In [186]:
res[['score_discourse_effectiveness_0','score_discourse_effectiveness_1','score_discourse_effectiveness_2']] = 1

In [94]:
visualize(ds[rd]['text_id'],ds.df)

In [None]:
class TextSpanDetector(BaseModule):
    def __init__(self,
                 arch,
                 num_classes=7,
                 dynamic_positive=False,
                 with_cp=False,
                 local_files_only=True,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.num_classes = num_classes
        self.dynamic_positive = dynamic_positive
        self.model = AutoModelForTokenClassification.from_pretrained(
            arch,
            num_labels=1 + 2 + num_classes,
            local_files_only=local_files_only)
        if with_cp:
            self.model.gradient_checkpointing_enable()

        self.tokenizer = AutoTokenizer.from_pretrained(
            arch, local_files_only=local_files_only)

        # init bias
        self.model.classifier.bias.data[0].fill_(bias_init_with_prob(0.02))
        self.model.classifier.bias.data[3:].fill_(
            bias_init_with_prob(1 / self.num_classes))

    def forward_logits(self, data):
        batch_size = data['input_ids'].size(0)
        assert batch_size == 1, f'Only batch_size=1 supported, got batch_size={batch_size}.'
        outputs = self.model(input_ids=data['input_ids'],
                             attention_mask=data['attention_mask'])
        logits = outputs['logits']
        logits = aggregate_tokens_to_words(logits, data['word_boxes'])
        assert logits.size(0) == data['text'].split().__len__()

        obj_pred = logits[..., 0]
        reg_pred = logits[..., 1:3]
        cls_pred = logits[..., 3:]
        return obj_pred, reg_pred, cls_pred

    def predict(self, data, test_score_thr):
        data = to_gpu(data)
        obj_pred, reg_pred, cls_pred = self.forward_logits(data)
        obj_pred = obj_pred.sigmoid()
        reg_pred = reg_pred.exp()
        cls_pred = cls_pred.sigmoid()

        obj_scores = obj_pred
        cls_scores, cls_labels = cls_pred.max(-1)
        pr_scores = (obj_scores * cls_scores)**0.5
        pos_inds = pr_scores > test_score_thr

        if pos_inds.sum() == 0:
            return None

        pr_score, pr_label = pr_scores[pos_inds], cls_labels[pos_inds]
        pos_loc = pos_inds.nonzero().flatten()
        start = pos_loc - reg_pred[pos_inds, 0]
        end = pos_loc + reg_pred[pos_inds, 1]

        min_idx, max_idx = 0, obj_pred.numel() - 1
        start = start.clamp(min=min_idx, max=max_idx).round().long()
        end = end.clamp(min=min_idx, max=max_idx).round().long()

        # nms
        keep = span_nms(start, end, pr_score)
        start = start[keep]
        end = end[keep]
        pr_score = pr_score[keep]
        pr_label = pr_label[keep]

        return dict(text_id=data['text_id'],
                    start=to_np(start),
                    end=to_np(end),
                    score=to_np(pr_score),
                    label=to_np(pr_label))

    def train_step(self, data, optimizer, **kwargs):
        data = to_gpu(data)
        obj_pred, reg_pred, cls_pred = self.forward_logits(data)
        obj_target, reg_target, cls_target, pos_loc = self.build_target(
            data['gt_spans'], obj_pred, reg_pred, cls_pred)

        obj_loss, reg_loss, cls_loss = self.get_losses(obj_pred, reg_pred,
                                                       cls_pred, obj_target,
                                                       reg_target, cls_target,
                                                       pos_loc)
        loss = obj_loss + reg_loss + cls_loss
        log_vars = dict(
            obj_loss=obj_loss.item(),
            reg_loss=reg_loss.item(),
            cls_loss=cls_loss.item(),
            loss=loss.item(),
        )
        outputs = dict(loss=loss, log_vars=log_vars, num_samples=1)

        return outputs

    def get_losses(self, obj_pred, reg_pred, cls_pred, obj_target, reg_target,
                   cls_target, pos_loc):
        num_total_samples = pos_loc.numel()
        assert num_total_samples > 0
        reg_pred = reg_pred[pos_loc].exp()
        reg_target = reg_target[pos_loc]
        px1 = pos_loc - reg_pred[:, 0]
        px2 = pos_loc + reg_pred[:, 1]
        gx1 = reg_target[:, 0]
        gx2 = reg_target[:, 1]
        ix1 = torch.max(px1, gx1)
        ix2 = torch.min(px2, gx2)
        ux1 = torch.min(px1, gx1)
        ux2 = torch.max(px2, gx2)
        inter = (ix2 - ix1).clamp(min=0)
        union = (ux2 - ux1).clamp(min=0) + 1e-12
        iou = inter / union

        reg_loss = -iou.log().sum() / num_total_samples
        cls_loss = F.binary_cross_entropy_with_logits(
            cls_pred[pos_loc],
            cls_target[pos_loc] * iou.detach().reshape(-1, 1),
            reduction='sum') / num_total_samples
        obj_loss = F.binary_cross_entropy_with_logits(
            obj_pred, obj_target, reduction='sum') / num_total_samples
        return obj_loss, reg_loss, cls_loss

    @torch.no_grad()
    def build_target(self, gt_spans, obj_pred, reg_pred, cls_pred,eff_pred):
        
        obj_target = torch.zeros_like(obj_pred)
        reg_target = torch.zeros_like(reg_pred)
        cls_target = torch.zeros_like(cls_pred)
        eff_target = torch.zeros_like(eff_pred)
        
        # first token as positive
        pos_loc = gt_spans[:, 0]
        obj_target[pos_loc] = 1
        reg_target[pos_loc, 0] = gt_spans[:, 0].float()
        reg_target[pos_loc, 1] = gt_spans[:, 1].float()
        cls_target[pos_loc, gt_spans[:, 2]] = 1
        eff_target[pos_loc, gt_spans[:, 3]] = 1
        
        # dynamically assign one more positive
        if self.dynamic_positive:
            cls_prob = (obj_pred.sigmoid().unsqueeze(1) *
                        cls_pred.sigmoid()).sqrt()
            
            eff_prob = (obj_pred.sigmoid().unsqueeze(1) *
                        eff_pred.sigmoid()).sqrt()
            
            for start, end, label,eff_label in gt_spans:
                _cls_prob = cls_prob[start:end]
                _cls_gt = _cls_prob.new_full((_cls_prob.size(0), ),
                                             label,
                                             dtype=torch.long)
                _cls_gt = F.one_hot(
                    _cls_gt, num_classes=_cls_prob.size(1)).type_as(_cls_prob)
                cls_cost = F.binary_cross_entropy(_cls_prob,
                                                  _cls_gt,
                                                  reduction='none').sum(-1)
                
                _eff_prob = eff_prob[start:end]
                _eff_gt = _eff_prob.new_full((_eff_prob.size(0), ),
                                             eff_label,
                                             dtype=torch.long)
                _eff_gt = F.one_hot(
                    _eff_gt, num_classes=_eff_prob.size(1)).type_as(_eff_prob)
                eff_cost = F.binary_cross_entropy(_eff_prob,
                                                  _eff_gt,
                                                  reduction='none').sum(-1)
                
                
                _reg_pred = reg_pred[start:end].exp()
                _reg_loc = torch.arange(_reg_pred.size(0),
                                        device=_reg_pred.device)
                
                px1 = _reg_loc - _reg_pred[:, 0]
                px2 = _reg_loc + _reg_pred[:, 1]
                ix1 = torch.max(px1, _reg_loc[0])
                ix2 = torch.min(px2, _reg_loc[-1])
                ux1 = torch.min(px1, _reg_loc[0])
                ux2 = torch.max(px2, _reg_loc[-1])
                inter = (ix2 - ix1).clamp(min=0)
                union = (ux2 - ux1).clamp(min=0) + 1e-12
                iou = inter / union
                iou_cost = -torch.log(iou + 1e-12)
                
                cost = cls_cost + eff_cost + iou_cost

                pos_ind = start + cost.argmin()
                
                obj_target[pos_ind] = 1
                reg_target[pos_ind, 0] = start
                reg_target[pos_ind, 1] = end
                cls_target[pos_ind, label] = 1
                eff_target[pos_ind, eff_label] = 1
                
            pos_loc = (obj_target == 1).nonzero().flatten()
        return obj_target, reg_target, cls_target,eff_target,pos_loc


In [None]:
# ====================================================
# Metric functions
# ====================================================

def calc_overlap(row):
    """
    Calculates the overlap between prediction and
    ground truth and overlap percentages used for determining
    true positives.
    """
    set_pred = set(row.predictionstring_pred.split(" "))
    set_gt = set(row.predictionstring_gt.split(" "))
    # Length of each and intersection
    len_gt = len(set_gt)
    len_pred = len(set_pred)
    inter = len(set_gt.intersection(set_pred))
    overlap_1 = inter / len_gt
    overlap_2 = inter / len_pred
    return [overlap_1, overlap_2]


def score_feedback_comp_micro(pred_df, gt_df):
    """
    A function that scores for the kaggle
        Student Writing Competition

    Uses the steps in the evaluation page here:
        https://www.kaggle.com/c/feedback-prize-2021/overview/evaluation
    """
    gt_df = gt_df[["id", "discourse_type", "predictionstring", 'Ineffective', 'Adequate', 'Effective']].reset_index(drop=True).copy()
    pred_df = pred_df[["id", "class", "predictionstring", 'Ineffective', 'Adequate', 'Effective']].reset_index(drop=True).copy()
    pred_df["pred_id"] = pred_df.index
    gt_df["gt_id"] = gt_df.index
    # Step 1. all ground truths and predictions for a given class are compared.
    joined = pred_df.merge(
        gt_df,
        left_on=["id", "class"],
        right_on=["id", "discourse_type"],
        how="outer",
        suffixes=("_pred", "_gt"),
    )
    
    joined["predictionstring_gt"] = joined["predictionstring_gt"].fillna(" ")
    joined["predictionstring_pred"] = joined["predictionstring_pred"].fillna(" ")

    joined["overlaps"] = joined.apply(calc_overlap, axis=1)

    joined.to_csv("new_joined.csv", index=False)

    
    # 2. If the overlap between the ground truth and prediction is >= 0.5,
    # and the overlap between the prediction and the ground truth >= 0.5,
    # the prediction is a match and considered a true positive.
    # If multiple matches exist, the match with the highest pair of overlaps is taken.
    joined["overlap1"] = joined["overlaps"].apply(lambda x: eval(str(x))[0])
    joined["overlap2"] = joined["overlaps"].apply(lambda x: eval(str(x))[1])

    joined["potential_TP"] = (joined["overlap1"] >= 0.5) & (joined["overlap2"] >= 0.5)
    # joined["potential_TP"] = (joined["overlap1"] >= 0.75) & (joined["overlap2"] >= 0.75)
    joined["max_overlap"] = joined[["overlap1", "overlap2"]].max(axis=1)
        
    tp_pred_ids = (
        joined.query("potential_TP")
        .sort_values("max_overlap", ascending=False)
        .groupby(["id", "predictionstring_gt"])
        .first()["pred_id"]
        .values
    )

    # 3. Any unmatched ground truths are false negatives
    # and any unmatched predictions are false positives.
    fp_pred_ids = [p for p in joined["pred_id"].unique() if p not in tp_pred_ids]
    matched_gt_ids = joined.query("potential_TP")["gt_id"].unique()
    unmatched_gt_ids = [c for c in joined["gt_id"].unique() if c not in matched_gt_ids]
    
    # Get numbers of each type
    TP = len(tp_pred_ids)
    FP = len(fp_pred_ids)
    FN = len(unmatched_gt_ids)
    # calc microf1
    my_f1_score = TP / (TP + 0.5 * (FP + FN))
    new_joined = joined.loc[(joined.pred_id.isin(tp_pred_ids) & joined.gt_id.isin(matched_gt_ids))]    
    
    fb2_score = 1.0
    if len(new_joined) > 0:
        fb2_score = np.mean([
            metrics.log_loss(new_joined['Ineffective_gt'].values, new_joined['Ineffective_pred'].values),
            metrics.log_loss(new_joined['Adequate_gt'].values, new_joined['Adequate_pred'].values),
            metrics.log_loss(new_joined['Effective_gt'].values, new_joined['Effective_pred'].values)
        ])
    
        return (my_f1_score, fb2_score)
    return (my_f1_score, fb2_score)


def score_feedback_comp(pred_df, gt_df, return_class_scores=False):
    class_scores = {}
    pred_df = pred_df[["id", "class", "predictionstring", 'Ineffective', 'Adequate', 'Effective']].reset_index(drop=True).copy()
    for discourse_type, gt_subset in gt_df.groupby("discourse_type"):
        pred_subset = pred_df.loc[pred_df["class"] == discourse_type].reset_index(drop=True).copy()
        class_score = score_feedback_comp_micro(pred_subset, gt_subset)
        class_scores[discourse_type] = class_score
        
    fb1 = np.mean([v[0] for v in class_scores.values()])
    fb2 = np.mean([v[1] for v in class_scores.values()])
    
    overall_f1 = 0.5*fb1 + 0.5*(1-fb2)
    
    if return_class_scores:
        return overall_f1, fb1, fb2, class_scores
    return overall_f1, fb1, fb2

In [8]:
DATA_PATH = Path(r"/database/kaggle/Shovel Ready/data")
CHECKPOINT_PATH = Path(r"/database/kaggle/Shovel Ready/checkpoint")

# Dataset

In [9]:
import torch
import random
import numpy as np
import pandas as pd
from pathlib import Path

from transformers import DataCollatorWithPadding
from transformers import AutoTokenizer, AutoModel, AutoConfig

%env TOKENIZERS_PARALLELISM = true


import warnings
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 500)
warnings.simplefilter("ignore")

env: TOKENIZERS_PARALLELISM=true


In [10]:
import re
from difflib import SequenceMatcher

import torch
import codecs
import os
from collections import Counter
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from text_unidecode import unidecode
from tqdm.notebook import tqdm

def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
    return error.object[error.start : error.end].encode("utf-8"), error.end


def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
    return error.object[error.start : error.end].decode("cp1252"), error.end


# Register the encoding and decoding error handlers for `utf-8` and `cp1252`.
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)



def resolve_encodings_and_normalize(text: str) -> str:
    """Resolve the encoding problems and normalize the abnormal characters."""
    text = (
        text.encode("raw_unicode_escape")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
        .encode("cp1252", errors="replace_encoding_with_utf8")
        .decode("utf-8", errors="replace_decoding_with_cp1252")
    )
    text = unidecode(text)
    return text


def clean_text(text):
    text = resolve_encodings_and_normalize(text)
    text = text.replace(u'\xa0', u' ')
    text = text.replace(u'\x85', u'\n')
    text = text.strip()
    return text

def to_gpu(data,device):
    if isinstance(data, dict):
        return {k: to_gpu(v,device) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_gpu(v,device) for v in data]
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    else:
        return data
    
def to_np(data):
    if isinstance(data, dict):
        return {k: to_np(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [to_np(v) for v in data]
    elif isinstance(data, torch.Tensor):
        return data.cpu().numpy()
    else:
        return data
# def to_np(t):
#     if isinstance(t, torch.Tensor):
#         return t.data.cpu().numpy()
#     else:
#         return t

def get_text_start_end(txt,s,search_from=0):
    txt = txt[int(search_from):]
    try:
        idx = txt.find(s)
        if idx>=0:
            st=idx
            ed = st+len(s)
        else:
            raise ValueError('Error')
    except:                
        res = [(m.start(0), m.end(0)) for m in re.finditer(s, txt)]
        if len(res):
            st,ed = res[0][0],res[0][1]
        else:
            m = SequenceMatcher(None, s,txt).get_opcodes()
            for tag,i1,i2,j1,j2 in m:
                if tag=='replace':
                    s = s[:i1]+txt[j1:j2]+s[i2:]
                if tag=="delete":
                    s = s[:i1]+s[i2:]
            
            res = [(m.start(0), m.end(0)) for m in re.finditer(s,txt)]
            if len(res):
                st,ed = res[0][0],res[0][1]
            else:
                idx = txt.find(s)
                if idx>=0:
                    st=idx
                    ed = st+len(s)
                else:
                    st,ed = 0,0
    return st+search_from,ed+search_from

def get_start_end(col):
    def search_start_end(row):
        txt = row.full_text
        search_from = row.previous_discourse_end
        s = row[col]
        # print(search_from)
        return get_text_start_end(txt,s,search_from)
    return search_start_end

In [11]:
import re
import torch
import random
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
# from data.data_utils import clean_text,get_start_end

from tqdm.auto import tqdm


LABEL2TYPE = ('Lead', 'Position', 'Claim', 'Counterclaim', 'Rebuttal',
              'Evidence', 'Concluding Statement')
TYPE2LABEL = {t: l for l, t in enumerate(LABEL2TYPE)}


LABEL2EFFEC = ('Adequate', 'Effective', 'Ineffective')
EFFEC2LABEL = {t: l for l, t in enumerate(LABEL2EFFEC)}

## =============================================================================== ##
class FeedbackDataset(Dataset):
    def __init__(self,
                 df,
                 tokenizer,
                 mask_prob=0.0,
                 mask_ratio=0.0,
                 ):
        self.df = self.prepare_df(df)
        self.samples = list(self.df.groupby('essay_id_comp'))
        self.tokenizer = tokenizer
        
        print(f'Loaded {len(self)} samples.')

        assert 0 <= mask_prob <= 1
        assert 0 <= mask_ratio <= 1
        self.mask_prob = mask_prob
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        text_id, df = self.samples[index]
        text = df['full_text'].values[0]

        tokens = self.tokenizer(text, return_offsets_mapping=True)
        input_ids = torch.LongTensor(tokens['input_ids'])
        attention_mask = torch.LongTensor(tokens['attention_mask'])
        offset_mapping = np.array(tokens['offset_mapping'])
        offset_mapping = self.strip_offset_mapping(text, offset_mapping)
        num_tokens = len(input_ids)

        # token slices of words
        woff = self.get_word_offsets(text)
        toff = offset_mapping
        wx1, wx2 = woff.T
        tx1, tx2 = toff.T
        ix1 = np.maximum(wx1[..., None], tx1[None, ...])
        ix2 = np.minimum(wx2[..., None], tx2[None, ...])
        ux1 = np.minimum(wx1[..., None], tx1[None, ...])
        ux2 = np.maximum(wx2[..., None], tx2[None, ...])
        ious = (ix2 - ix1).clip(min=0) / (ux2 - ux1 + 1e-12)
        assert (ious > 0).any(-1).all()

        word_boxes = []
        for row in ious:
            inds = row.nonzero()[0]
            word_boxes.append([inds[0], 0, inds[-1] + 1, 1])
        word_boxes = torch.FloatTensor(word_boxes)

        # word slices of ground truth spans
        gt_spans = []
        for _, row in df.iterrows():
            
            word_start = len(row['full_text'][:row['discourse_start']].split())
            word_end = word_start + len(row['full_text'][row['discourse_start']:row['discourse_end']].split())
            word_end = min(word_end, len(row['full_text'].split()))
            
            span_label = TYPE2LABEL[row['discourse_type']]
            span_effectiveness = EFFEC2LABEL[row['discourse_effectiveness']]
            gt_spans.append([word_start,word_end,span_label,span_effectiveness])
            
        gt_spans = torch.LongTensor(gt_spans)

        # random mask augmentation
        if np.random.random() < self.mask_prob:
            all_inds = np.arange(1, len(input_ids) - 1)
            n_mask = max(int(len(all_inds) * self.mask_ratio), 1)
            np.random.shuffle(all_inds)
            mask_inds = all_inds[:n_mask]
            input_ids[mask_inds] = self.tokenizer.mask_token_id

        return dict(text=text,
                    text_id=text_id,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    word_boxes=word_boxes,
                    gt_spans=gt_spans)
    
    def prepare_df(self,test_df):
        test_df['full_text'] = test_df['full_text'].transform(clean_text)        
        test_df['discourse_text'] = test_df['discourse_text'].transform(clean_text)
        test_df['previous_discourse_end'] = 0
        test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
        test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0])
        test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1])
        
        test_df['previous_discourse_end'] = test_df.groupby("essay_id_comp")['discourse_end'].transform(lambda x:x.shift(1).fillna(0)).astype(int)
        test_df['st_ed'] = test_df.apply(get_start_end('discourse_text'),axis=1)
        
        test_df['discourse_start'] = test_df['st_ed'].transform(lambda x:x[0]) #+ test_df['previous_discourse_end']
        test_df['discourse_end'] = test_df['st_ed'].transform(lambda x:x[1]) #+ test_df['previous_discourse_end']
        return test_df
    
    def strip_offset_mapping(self, text, offset_mapping):
        ret = []
        for start, end in offset_mapping:
            match = list(re.finditer('\S+', text[start:end]))
            if len(match) == 0:
                ret.append((start, end))
            else:
                span_start, span_end = match[0].span()
                ret.append((start + span_start, start + span_end))
        return np.array(ret)

    def get_word_offsets(self, text):
        matches = re.finditer("\S+", text)
        spans = []
        words = []
        for match in matches:
            span = match.span()
            word = match.group()
            spans.append(span)
            words.append(word)
        assert tuple(words) == tuple(text.split())
        return np.array(spans)
    
## =============================================================================== ##
class CustomCollator(object):
    def __init__(self, tokenizer, model):
        self.pad_token_id = tokenizer.pad_token_id
        if hasattr(model.config, 'attention_window'):
            # For longformer
            # https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/longformer/modeling_longformer.py#L1548
            self.attention_window = (model.config.attention_window
                                     if isinstance(
                                         model.config.attention_window, int)
                                     else max(model.config.attention_window))
        else:
            self.attention_window = None

    def __call__(self, samples):
        batch_size = len(samples)
        assert batch_size == 1, f'Only batch_size=1 supported, got batch_size={batch_size}.'

        sample = samples[0]

        max_seq_length = len(sample['input_ids'])
        if self.attention_window is not None:
            attention_window = self.attention_window
            padded_length = (attention_window -
                             max_seq_length % attention_window
                             ) % attention_window + max_seq_length
        else:
            padded_length = max_seq_length

        input_shape = (1, padded_length)
        input_ids = torch.full(input_shape,
                               self.pad_token_id,
                               dtype=torch.long)
        attention_mask = torch.zeros(input_shape, dtype=torch.long)

        seq_length = len(sample['input_ids'])
        input_ids[0, :seq_length] = sample['input_ids']
        attention_mask[0, :seq_length] = sample['attention_mask']

        text_id = sample['text_id']
        text = sample['text']
        word_boxes = sample['word_boxes']
        gt_spans = sample['gt_spans']

        return dict(text_id=text_id,
                    text=text,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    word_boxes=word_boxes,
                    gt_spans=gt_spans)

In [12]:
model_name = 'microsoft/deberta-large'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [13]:
df = pd.read_csv(data_path/'persuade_corpus.csv')
LABEL2EFFEC = ('Adequate', 'Effective', 'Ineffective')
EFFEC2LABEL = {t: l for l, t in enumerate(LABEL2EFFEC)}

df_train = df[df.competition_set=='train']
df_train = df_train[df_train.discourse_effectiveness.isin(LABEL2EFFEC)]
df_train['fold'] = 1

df_valid = df[df.test_split_feedback_2=='Public']
df_valid = df_valid[df_valid.discourse_effectiveness.isin(LABEL2EFFEC)]
df_valid['fold'] = 0

In [14]:
ds = FeedbackDataset(df_train,
                 tokenizer)

Loaded 15594 samples.


In [15]:
from torch.utils.data import DataLoader
from transformers import AutoModelForTokenClassification, AutoTokenizer

In [16]:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel
import torch.utils.checkpoint
import torch.nn.functional as F
import gc

# from mmcv.cnn import bias_init_with_prob
from torchvision.ops import roi_align, nms

def aggregate_tokens_to_words(feat, word_boxes):
    feat = feat.permute(0, 2, 1).unsqueeze(2)
    output = roi_align(feat, [word_boxes], 1, aligned=True)
    return output.squeeze(-1).squeeze(-1)


def span_nms(start, end, score, nms_thr=0.5):
    boxes = torch.stack(
        [
            start,
            torch.zeros_like(start),
            end,
            torch.ones_like(start),
        ],
        dim=1,
    ).float()
    keep = nms(boxes, score, nms_thr)
    return keep

class FeedbackModel(nn.Module):
    def __init__(self,
                 model_name,
                 config_path = None
                 ):
        super().__init__()
#         self.pretrained_path = pretrained_path
        self.config = AutoConfig.from_pretrained(model_name, output_hidden_states=True) if not config_path else torch.load(config_path)

#         self.use_dropout = use_dropout
#         if not self.use_dropout:
#             self.config.update(
#                                 {
#                                     "hidden_dropout_prob": 0.0,
#                                     "attention_probs_dropout_prob": 0.0,
#                                 }
#                                     )

        self.model =AutoModelForTokenClassification.from_pretrained(
            model_name,
            num_labels=1 + 2 + 7,
            local_files_only=False)
        
#         if use_gradient_checkpointing:
#             self.backbone.gradient_checkpointing_enable()

        
#         self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
#         self.fc = nn.Linear(self.config.hidden_size, 1+2+num_label_discourse_type+num_label_effectiveness)
        
        # self.fc.bias.data[0].fill_(bias_init_with_prob(0.02))
        # self.fc.bias.data[3:-3].fill_(bias_init_with_prob(1 / num_label_discourse_type))
        # self.fc.bias.data[-3:].fill_(bias_init_with_prob(1 / num_label_effectiveness))


    def forward(self,b):
        x = self.model(b["input_ids"],b["attention_mask"])['logits']
#         x = self.dropout(x)
#         x = self.fc(x)
        x = aggregate_tokens_to_words(x, b['word_boxes'])
        obj_pred = x[..., 0]
        reg_pred = x[..., 1:3]
        type_pred = x[..., 3:]
        eff_pred = x[..., -3:]
        return obj_pred, reg_pred, type_pred,eff_pred

In [17]:
net = FeedbackModel(model_name)

Some weights of the model checkpoint at microsoft/deberta-large were not used when initializing DebertaForTokenClassification: ['deberta.embeddings.position_embeddings.weight', 'lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.bias']
- This IS expected if you are initializing DebertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DebertaForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-large and are newly initi

In [18]:
os.listdir(CHECKPOINT_PATH)

['deberta_large_fold0.pth',
 'fold',
 'fprize_microsoft_deberta-large_fold7_epoch_02_iov_v2_val_0.7041_20220311101354.pth',
 'deberta_large_fold0.pth.zip']

In [19]:
pretrained_path = CHECKPOINT_PATH/"deberta_large_fold0.pth"

In [20]:
net.load_state_dict(torch.load(pretrained_path, map_location=lambda storage, loc: storage)['state_dict'])

<All keys matched successfully>

In [21]:
collator = CustomCollator(tokenizer,net)
dl = DataLoader(ds,batch_size=1,collate_fn=collator,  drop_last =  False,
  num_workers =  16,
  pin_memory =  False,
  shuffle =  False)

In [22]:
# # ----------------- One Step --------------------- #
@torch.no_grad()
def build_target(gt_spans, obj_pred, reg_pred, cls_pred,eff_pred,dynamic_positive=False):

    obj_target = torch.zeros_like(obj_pred)
    reg_target = torch.zeros_like(reg_pred).float()
    cls_target = torch.zeros_like(cls_pred)
    eff_target = torch.zeros_like(eff_pred)

    # first token as positive
    pos_loc = gt_spans[:, 0]
    obj_target[pos_loc] = 1
    reg_target[pos_loc, 0] = gt_spans[:, 0].float()
    reg_target[pos_loc, 1] = gt_spans[:, 1].float()
    cls_target[pos_loc, gt_spans[:, 2]] = 1
    eff_target[pos_loc, gt_spans[:, 3]] = 1

    # dynamically assign one more positive
    if dynamic_positive:
        cls_prob = (obj_pred.sigmoid().unsqueeze(1) *
                    cls_pred.sigmoid()).sqrt()

        eff_prob = (obj_pred.sigmoid().unsqueeze(1) *
                    eff_pred.sigmoid()).sqrt()

        for start, end, label,eff_label in gt_spans:
            _cls_prob = cls_prob[start:end]
            _cls_gt = _cls_prob.new_full((_cls_prob.size(0), ),
                                         label,
                                         dtype=torch.long)
            _cls_gt = F.one_hot(
                _cls_gt, num_classes=_cls_prob.size(1)).type_as(_cls_prob)
            cls_cost = F.binary_cross_entropy(_cls_prob,
                                              _cls_gt,
                                              reduction='none').sum(-1)

            _eff_prob = eff_prob[start:end]
            _eff_gt = _eff_prob.new_full((_eff_prob.size(0), ),
                                         eff_label,
                                         dtype=torch.long)
            _eff_gt = F.one_hot(
                _eff_gt, num_classes=_eff_prob.size(1)).type_as(_eff_prob)
            eff_cost = F.binary_cross_entropy(_eff_prob,
                                              _eff_gt,
                                              reduction='none').sum(-1)


            _reg_pred = reg_pred[start:end].exp()
            _reg_loc = torch.arange(_reg_pred.size(0),
                                    device=_reg_pred.device)

            px1 = _reg_loc - _reg_pred[:, 0]
            px2 = _reg_loc + _reg_pred[:, 1]
            ix1 = torch.max(px1, _reg_loc[0])
            ix2 = torch.min(px2, _reg_loc[-1])
            ux1 = torch.min(px1, _reg_loc[0])
            ux2 = torch.max(px2, _reg_loc[-1])
            inter = (ix2 - ix1).clamp(min=0)
            union = (ux2 - ux1).clamp(min=0) + 1e-12
            iou = inter / union
            iou_cost = -torch.log(iou + 1e-12)

            cost = cls_cost + eff_cost + iou_cost

            pos_ind = start + cost.argmin()

            obj_target[pos_ind] = 1
            reg_target[pos_ind, 0] = start
            reg_target[pos_ind, 1] = end
            cls_target[pos_ind, label] = 1
            eff_target[pos_ind, eff_label] = 1

        pos_loc = (obj_target == 1).nonzero().flatten()
        
    return obj_target, reg_target,cls_target,eff_target,pos_loc

In [23]:
import pandas as pd
import numpy as np

def _tp_score_effectiveness(col=""):
    def row_wise(row):
        gt_eff = row.discourse_effectiveness # ('Adequate' = 0, 'Effective' = 1, 'Ineffective' = 2)
        return row[f'score_discourse_effectiveness_{gt_eff}']
    return row_wise

def calc_overlap(set_pred, set_gt,threshold=0.5):
    """
    Calculates if the overlap between prediction and
    ground truth is enough fora potential True positive
    """
    # Length of each and intersection
    try:
        len_gt = len(set_gt)
        len_pred = len(set_pred)
        inter = len(set_gt & set_pred)
        overlap_1 = inter / len_gt
        overlap_2 = inter/ len_pred
        return overlap_1 >= threshold and overlap_2 >= threshold,max(overlap_1,overlap_2)
    except:  # at least one of the input is NaN
        return False,0

def score_feedback_comp_micro(pred_df, gt_df, discourse_type,threshold=0.5,weight_tp_segment=0.5):
    """
    A function that scores for the kaggle
        Student Writing Competition
        
    Uses the steps in the evaluation page here:
        https://www.kaggle.com/c/feedback-prize-2021/overview/evaluation
    """
    gt_df = gt_df.loc[gt_df['discourse_type'] == discourse_type, 
                      ['id', 'predictionstring','discourse_effectiveness']].reset_index(drop=True)
    pred_df = pred_df.loc[pred_df['discourse_type'] == discourse_type,
                      ['id', 'predictionstring','score_discourse_effectiveness_0',
                                               'score_discourse_effectiveness_1',
                                               'score_discourse_effectiveness_2']].reset_index(drop=True)
    pred_df['pred_id'] = pred_df.index
    gt_df['gt_id'] = gt_df.index
    pred_df['predictionstring'] = [set(pred.split(' ')) for pred in pred_df['predictionstring']]
    gt_df['predictionstring'] = [set(pred.split(' ')) for pred in gt_df['predictionstring']]
    
    # Step 1. all ground truths and predictions for a given class are compared.
    joined = pred_df.merge(gt_df,
                           left_on='id',
                           right_on='id',
                           how='outer',
                           suffixes=('_pred','_gt')
                          )
    
    overlaps = [calc_overlap(*args,threshold=threshold) for args in zip(joined.predictionstring_pred, 
                                                     joined.predictionstring_gt)]
    
    joined['overlaps'] = np.asarray([x[0] for x in overlaps])*1
    joined['overlaps_scores'] = np.asarray([x[1] for x in overlaps])*1
    joined["discourse_effectiveness"] = joined["discourse_effectiveness"].fillna(0).astype(int) 
    joined['effectiveness_TP_score'] = joined.apply(_tp_score_effectiveness(' '),axis=1)
    joined['1_effectiveness_TP_score'] = 1-joined['effectiveness_TP_score']
    # 2. If the overlap between the ground truth and prediction is >= 0.5, 
    # and the overlap between the prediction and the ground truth >= 0.5,
    # the prediction is a match and considered a true positive.
    # If multiple matches exist, the match with the highest pair of overlaps is taken.
    # we don't need to compute the match to compute the score
    
    joined = joined.sort_values(["overlaps",'overlaps_scores'],ascending=False).reset_index(drop=True).groupby('gt_id').head(1)
    
    TP = joined[joined.overlaps==1]['gt_id'].nunique()
    
    TP_weighted = weight_tp_segment*TP + (1-weight_tp_segment)*(joined[joined.overlaps==1]['effectiveness_TP_score'].sum())
    
    # 3. Any unmatched ground truths are false negatives
    # and any unmatched predictions are false positives.
    TPandFP = len(pred_df)
    TPandFN = len(gt_df)
    
    TPandFP_weighted = TP_weighted + (joined[joined.overlaps==0]['1_effectiveness_TP_score'].sum())
    
    #calc microf1
    f1_score_fb1 = 2*TP / (TPandFP + TPandFN)
    new_f1_score = 2*TP_weighted / (TPandFP_weighted + TPandFN)
    return f1_score_fb1,new_f1_score

def score_feedback_comp(pred_df, gt_df,threshold=0.5, weight_tp_segment=0.5,return_class_scores=False):
    class_scores_fb1 = {}
    new_class_scores = {}
    for discourse_type in gt_df.discourse_type.unique():
        s_fb1,s = score_feedback_comp_micro(pred_df, gt_df, discourse_type,threshold,weight_tp_segment)
        class_scores_fb1[discourse_type] = s_fb1
        new_class_scores[discourse_type] = s
    
    f1_fb1 = np.mean([v for v in class_scores_fb1.values()])
    new_f1 = np.mean([v for v in new_class_scores.values()])
    if return_class_scores:
        return f1_fb1,class_scores_fb1,new_f1, new_class_scores
    return f1_fb1,new_f1

In [24]:
def get_pred(col):
    def row_wise(row):
        return " ".join([str(x) for x in range(row.start,row.end)])
    return row_wise

def predict(model,data, test_score_thr=0.5):
    
    obj_pred, reg_pred, cls_pred,eff_pred = model(data)
    obj_pred = obj_pred.sigmoid()
    reg_pred = reg_pred.exp()
    cls_pred = cls_pred.sigmoid()
    eff_pred = eff_pred.sigmoid()

    obj_scores = obj_pred
    cls_scores, cls_labels = cls_pred.max(-1)
    eff_scores, eff_labels = eff_pred.max(-1)
    
    pr_scores = (obj_scores * cls_scores)**0.5
    pr_eff_score = (obj_scores * eff_scores)**0.5
    eff_pred_score = (obj_scores.reshape(-1,1) * eff_pred)**0.5
    pos_inds = pr_scores > test_score_thr

    if pos_inds.sum() == 0:
        return None

    pr_score,pr_label,pr_eff = pr_scores[pos_inds], cls_labels[pos_inds], eff_labels[pos_inds]
    pos_loc = pos_inds.nonzero().flatten()
    start = pos_loc - reg_pred[pos_inds, 0]
    end = pos_loc + reg_pred[pos_inds, 1]

    min_idx, max_idx = 0, obj_pred.numel() - 1
    start = start.clamp(min=min_idx, max=max_idx).round().long()
    end = end.clamp(min=min_idx, max=max_idx).round().long()

    # nms
    keep = span_nms(start, end, pr_score)
    start = start[keep]
    end = end[keep]
    pr_score = pr_score[keep]
    pr_label = pr_label[keep]
    pr_eff = pr_eff[keep]
    pr_eff_score = pr_eff_score[keep]
    eff_pred_score = eff_pred_score[keep]
    
    res = dict(id=data['text_id'],
                start=to_np(start),
                end=to_np(end),
                score_discourse_type=to_np(pr_score),
                discourse_type=to_np(pr_label),
                discourse_effectiveness = to_np(pr_eff),
                score_discourse_effectiveness = to_np(pr_eff_score),
                score_discourse_effectiveness_0 = to_np(eff_pred_score[:,0]),
                score_discourse_effectiveness_1 = to_np(eff_pred_score[:,1]),
                score_discourse_effectiveness_2 = to_np(eff_pred_score[:,2]),
               )
    res = pd.DataFrame(res).sort_values('start').reset_index(drop=True)
    res['predictionstring'] = res.apply(get_pred(' '),axis=1)
    return res

# # ----------------- One Step --------------------- #
def get_losses(obj_pred, reg_pred, cls_pred,eff_pred, obj_target, reg_target,
               cls_target,eff_target, pos_loc):
    
    num_total_samples = pos_loc.numel()
    assert num_total_samples > 0
    reg_pred = reg_pred[pos_loc].exp()
    reg_target = reg_target[pos_loc]
    px1 = pos_loc - reg_pred[:, 0]
    px2 = pos_loc + reg_pred[:, 1]
    gx1 = reg_target[:, 0]
    gx2 = reg_target[:, 1]
    ix1 = torch.max(px1, gx1)
    ix2 = torch.min(px2, gx2)
    ux1 = torch.min(px1, gx1)
    ux2 = torch.max(px2, gx2)
    inter = (ix2 - ix1).clamp(min=0)
    union = (ux2 - ux1).clamp(min=0) + 1e-12
    iou = inter / union

    reg_loss = -iou.log().sum() / num_total_samples
    
    cls_loss = F.binary_cross_entropy_with_logits(
        cls_pred[pos_loc],
        cls_target[pos_loc] * iou.detach().reshape(-1, 1),
        reduction='sum') / num_total_samples
    
    eff_loss = F.binary_cross_entropy_with_logits(
        eff_pred[pos_loc],
        eff_target[pos_loc] * iou.detach().reshape(-1, 1),
        reduction='sum') / num_total_samples
    
    obj_loss = F.binary_cross_entropy_with_logits(
        obj_pred, obj_target, reduction='sum') / num_total_samples
    
    return obj_loss, reg_loss, cls_loss,eff_loss

# # ----------------- One Step --------------------- #
def training_step(args,model,data):
    model.train()
    device = model.backbone.device
    data = to_gpu(data, device)

    if args.trainer['use_amp']:
        with amp.autocast(args.trainer['use_amp']):
            obj_pred, reg_pred, tpe_pred,eff_pred = model(data)

            obj_target, reg_target, cls_target,eff_target, pos_loc = build_target(
            data['gt_spans'], obj_pred, reg_pred, tpe_pred,eff_pred)
            obj_loss, reg_loss, cls_loss,eff_loss = get_losses(obj_pred, reg_pred,
                                                       tpe_pred,eff_pred, obj_target,
                                                       reg_target, cls_target,eff_target,
                                                       pos_loc)

            loss = obj_loss + reg_loss + cls_loss + eff_loss
            log_vars = dict(
                train_obj_loss=obj_loss.item(),
                train_reg_loss=reg_loss.item(),
                train_cls_loss=cls_loss.item(),
                train_eff_loss=eff_loss.item(),
                train_loss=loss.item(),
            )

    else:
        obj_pred, reg_pred, tpe_pred,eff_pred = model(data)

        obj_target, reg_target, cls_target,eff_target, pos_loc = build_target(
        data['gt_spans'], obj_pred, reg_pred, tpe_pred,eff_pred)
        obj_loss, reg_loss, cls_loss,eff_loss = get_losses(obj_pred, reg_pred,
                                                    tpe_pred,eff_pred, obj_target,
                                                    reg_target, cls_target,eff_target,
                                                    pos_loc)

        loss = obj_loss + reg_loss + cls_loss + eff_loss
        log_vars = dict(
            train_obj_loss=obj_loss.item(),
            train_reg_loss=reg_loss.item(),
            train_cls_loss=cls_loss.item(),
            train_eff_loss=eff_loss.item(),
            train_loss=loss.item(),
        )
    
    return loss,log_vars


def evaluation_step(model,val_loader):

    device = model.model.device
    model.eval()

    losses = []
    gt_df = []
    pred_df = []
    with torch.no_grad():
        for data in tqdm(val_loader):
            data = to_gpu(data, device)
            pred_df.append(predict(model,data, test_score_thr=0.5))
            obj_pred, reg_pred, tpe_pred,eff_pred = model(data)
            

            obj_target, reg_target, cls_target,eff_target, pos_loc = build_target(
            data['gt_spans'], obj_pred, reg_pred, tpe_pred,eff_pred)
            obj_loss, reg_loss, cls_loss,eff_loss = get_losses(obj_pred, reg_pred,
                                                       tpe_pred,eff_pred, obj_target,
                                                       reg_target, cls_target,eff_target,
                                                       pos_loc)

            loss = obj_loss + reg_loss + cls_loss + eff_loss
            losses.append([[loss.item(),obj_loss.item(),reg_loss.item(),cls_loss.item(),eff_loss.item()]])
            
#             pred_df.append(predict(model,data, test_score_thr=0.33))

            data = to_np(data)
            # print(data)
            gt = pd.DataFrame({
              "id":data['text_id'],
              "start":data["gt_spans"][:,0],
              "end":data["gt_spans"][:,1],
              "discourse_type":data["gt_spans"][:,2],
              "discourse_effectiveness":data["gt_spans"][:,3],
                            })
            gt['predictionstring'] = gt.apply(get_pred(' '),axis=1)
            gt_df.append(gt)
            

    pred_df = pd.concat(pred_df,axis=0).reset_index(drop=True)
    gt_df = pd.concat(gt_df,axis=0).reset_index(drop=True)

    macro_f1,new_macro_f1 = score_feedback_comp(pred_df, gt_df,threshold=0.5, weight_tp_segment=0.5,return_class_scores=False)

    losses = np.concatenate(losses,axis=0).mean(0)

    log_vars = dict(
        valid_obj_loss=losses[1],
        valid_reg_loss=losses[2],
        valid_cls_loss=losses[3],
        valid_eff_loss=losses[4],
        valid_loss=losses[0],
        f1_score_fb1 = macro_f1,
        f1_score_new = new_macro_f1
    )
    
    return log_vars,pred_df,gt_df

In [25]:
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [26]:
net = net.to(device)

In [27]:
log_vars,pred_df,gt_df = evaluation_step(net,dl)

Token indices sequence length is longer than the specified maximum sequence length for this model (621 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (895 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1510 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (725 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (559 > 512). Running this sequence through the model will result in indexing errors


  0%|          | 0/15594 [00:01<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (576 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (600 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1154 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (579 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (549 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for th

In [29]:
log_vars

{'valid_obj_loss': 1.0515049243699657,
 'valid_reg_loss': 0.20314175415860805,
 'valid_cls_loss': 0.5793302156335959,
 'valid_eff_loss': 7.261822584027952,
 'valid_loss': 9.0957994823106,
 'f1_score_fb1': 0.7995783940930086,
 'f1_score_new': 0.5430791422738409}

In [135]:
log_vars

{'valid_obj_loss': 1.449174536249262,
 'valid_reg_loss': 0.37150918686670575,
 'valid_cls_loss': 0.9040437645709033,
 'valid_eff_loss': 5.820261013935499,
 'valid_loss': 8.544988494710537,
 'f1_score_fb1': 0.43916997421873677,
 'f1_score_new': 0.47009926344471326}

In [140]:
log_vars

{'valid_obj_loss': 1.449174536249262,
 'valid_reg_loss': 0.37150918686670575,
 'valid_cls_loss': 0.9040437645709033,
 'valid_eff_loss': 5.820261013935499,
 'valid_loss': 8.544988494710537,
 'f1_score_fb1': 0.6636342906166302,
 'f1_score_new': 0.44404925987166305}

In [168]:
gt_df[gt_df.id=="000A58BC095E"]

Unnamed: 0,id,start,end,discourse_type,discourse_effectiveness,predictionstring
0,000A58BC095E,5,48,0,0,5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 ...
1,000A58BC095E,48,73,1,0,48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 6...
2,000A58BC095E,73,89,5,0,73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
3,000A58BC095E,89,97,2,0,89 90 91 92 93 94 95 96
4,000A58BC095E,97,104,5,2,97 98 99 100 101 102 103
5,000A58BC095E,104,118,2,0,104 105 106 107 108 109 110 111 112 113 114 11...
6,000A58BC095E,118,139,2,0,118 119 120 121 122 123 124 125 126 127 128 12...
7,000A58BC095E,139,165,5,0,139 140 141 142 143 144 145 146 147 148 149 15...
8,000A58BC095E,165,247,6,1,165 166 167 168 169 170 171 172 173 174 175 17...


In [169]:
pred_df[pred_df.id=="000A58BC095E"]

Unnamed: 0,id,start,end,score_discourse_type,discourse_type,discourse_effectiveness,score_discourse_effectiveness,score_discourse_effectiveness_0,score_discourse_effectiveness_1,score_discourse_effectiveness_2,predictionstring
0,000A58BC095E,5,52,0.9631,0,1,0.0044,0.0015,0.0044,0.002,5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 ...
1,000A58BC095E,48,73,0.963,1,2,0.0007,0.0004,0.0007,0.0006,48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 6...
2,000A58BC095E,73,90,0.7617,2,1,0.001,0.0004,0.001,0.0008,73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 8...
3,000A58BC095E,89,98,0.5033,3,1,0.0007,0.0004,0.0007,0.0006,89 90 91 92 93 94 95 96 97
4,000A58BC095E,97,107,0.5646,5,1,0.0247,0.0041,0.0247,0.018,97 98 99 100 101 102 103 104 105 106
5,000A58BC095E,104,118,0.867,2,1,0.0037,0.0023,0.0037,0.0016,104 105 106 107 108 109 110 111 112 113 114 11...
6,000A58BC095E,118,145,0.6162,5,1,0.0102,0.0014,0.0102,0.0021,118 119 120 121 122 123 124 125 126 127 128 12...
7,000A58BC095E,139,167,0.7116,2,1,0.0006,0.0002,0.0006,0.0001,139 140 141 142 143 144 145 146 147 148 149 15...
8,000A58BC095E,165,209,0.6191,6,2,0.0004,0.0002,0.0004,0.0001,165 166 167 168 169 170 171 172 173 174 175 17...
9,000A58BC095E,165,187,0.5022,2,2,0.0003,0.0002,0.0003,0.0001,165 166 167 168 169 170 171 172 173 174 175 17...


In [None]:
if not os.path.exists("../input/train_cache.csv"):
        df = pd.read_csv(f"{cfg.dataset.base_dir}/persuade_corpus.csv")
        df['id'] = df['essay_id']
        df_train = df.loc[df['competition_set'] == "train"].reset_index(drop=True)
        df_train = df_train.loc[df_train['discourse_type'] != "Unannotated"].reset_index(drop=True)
        
        
        df_test = df.loc[df['competition_set'] == "test"].reset_index(drop=True)
        df_test = df_test.loc[df_test['discourse_type'] != "Unannotated"]
        df_test = df_test[~df_test.discourse_effectiveness.isna()].reset_index(drop=True)
        df_test['id'] = df_test['essay_id'] 
        df_test['Ineffective'] = df_test.discourse_effectiveness.apply(lambda x: 1 if x == "Ineffective" else 0)
        df_test['Adequate'] = df_test.discourse_effectiveness.apply(lambda x: 1 if x == "Adequate" else 0)
        df_test['Effective'] = df_test.discourse_effectiveness.apply(lambda x: 1 if x == "Effective" else 0)

        
        df_test_pri = df_test.loc[df_test['test_split_feedback_1'] == "Private"].reset_index(drop=True)
        df_test = df_test.loc[df_test['test_split_feedback_1'] == "Public"].reset_index(drop=True)


        fb1 = df_train.loc[df_train['in_feedback2.0'] == 0].reset_index(drop=True)
        fb1_effective = fb1.loc[fb1['discourse_effectiveness'] == "Effective"].reset_index(drop=True)
        fb1_adequate = fb1.loc[fb1['discourse_effectiveness'] == "Adequate"].reset_index(drop=True)
        fb2 = df_train.loc[df_train['in_feedback2.0'] == 1].reset_index(drop=True)

        df = pd.concat([fb1_adequate, fb1_effective, fb2]).reset_index(drop=True)

        data = []
        for essay_id in tqdm(df.essay_id.unique(), total=len(df.essay_id.unique())):
            temp_df = df[df.essay_id == essay_id].reset_index(drop=True)
            res = utils.add_discourse_start_end(temp_df, datatype="train")
            data.append(res)

        df_train = pd.concat(data).reset_index(drop=True)

        data_test = []
        for essay_id in tqdm(df_test.essay_id.unique(), total=len(df_test.essay_id.unique())):
            temp_df = df_test[df_test.essay_id == essay_id].reset_index(drop=True)
            res = utils.add_discourse_start_end(temp_df, datatype="train")
            data_test.append(res)

        df_test = pd.concat(data_test).reset_index(drop=True)  
        df_train = df_train[df_train.discourse_start != -1].reset_index(drop=True)
        df_test = df_test[df_test.discourse_start != -1].reset_index(drop=True)
        
        df_train.to_csv("../input/train_cache.csv", index=False)
        df_test.to_csv("../input/test_cache.csv", index=False)
    else:
        df_train = pd.read_csv("../input/train_cache.csv")
        df_test = pd.read_csv("../input/test_cache.csv")