In [None]:
import sys
sys.path.insert(0, 'feedback_debertav2_fasttokenizer')
import torch as t
from glob import glob
import pandas as pd
import dill as pickle
import re
import numpy as np
from tqdm import tqdm
import spacy
from spacy import displacy
from transformers import DebertaV2TokenizerFast
import h5py

In [None]:
# DeBERTa-v2 xlarge fast tokenizer. Its length limit is raised to 2048 tokens,
# the same per-essay width used for the HDF5 datasets.
tokenizer = DebertaV2TokenizerFast.from_pretrained(
    'microsoft/deberta-v2-xlarge'
)
tokenizer.model_max_length = 2048

In [None]:
# Load the discourse annotations. One annotation's text is patched so it matches the
# essay file: 'florida' becomes 'LOCATION_NAME' (presumably the essay text is anonymised).
data = pd.read_csv('train.csv')
_florida_row = data.discourse_id == 1623258656795.0
data.loc[_florida_row, 'discourse_text'] = (
    data.loc[_florida_row, 'discourse_text']
    .map(lambda s: s.replace('florida', 'LOCATION_NAME'))
    .values
)

In [None]:
# Class names indexed by class id; id 0 is used for tokens outside any discourse span.
label_names = [
    'None',
    'Lead',
    'Position',
    'Evidence',
    'Claim',
    'Concluding Statement',
    'Counterclaim',
    'Rebuttal',
]

In [None]:
# displaCy highlight colour per discourse type. Values are inserted into CSS, so hex
# colours need the leading '#'. Insertion order matters: `token_maps` is built from it.
colors = {
    'Lead': '#8000ff',
    'Position': '#2b7ff6',
    'Evidence': '#2adddd',
    'Claim': '#80ffb4',
    'Concluding Statement': '#d4dd80',
    'Counterclaim': '#ff8042',
    'Rebuttal': '#ff0000',
}
# displaCy render options: which entity labels to show and their colours.
options = {"ents": list(colors.keys()), "colors": colors}

In [None]:
# Target id for the first token of each discourse type: 1, 3, 5, ... in `colors` order.
token_maps = {label: 2 * rank + 1 for rank, label in enumerate(colors)}

In [None]:
# Display the discourse-type -> begin-target-id mapping (continuation tokens use id + 1).
token_maps

In [None]:
def show_sample(ID, punct_re=None):
    """Render one training essay with its discourse spans highlighted by displaCy.

    Reads ``train/{ID}.txt`` and takes that essay's rows of ``data`` in
    ``discourse_start`` order. It locates each stripped ``discourse_text`` in the
    raw text, searching after the previous span, and renders the spans as manual
    displaCy entities using ``options``.

    Args:
        ID: essay id. It must match the ``id`` column of ``data`` and a file
            name under ``train/``.
        punct_re: optional compiled regex that is matched against single
            characters. If a span's first character matches, the highlight starts
            one character later. If the character right after a span matches, the
            highlight is extended by one. When omitted, a notebook-global
            ``punkt_re`` is used if defined; otherwise any character that is
            neither a word character nor whitespace counts.

    Raises:
        ValueError: if a discourse text is not found after the previous span.
    """
    if punct_re is None:
        # `punkt_re` was referenced but never defined in this notebook, so every call
        # failed with NameError. A global one is still honoured if it exists.
        punct_re = globals().get('punkt_re') or re.compile(r'[^\w\s]')
    with open(f'train/{ID}.txt') as f:
        org_text = f.read()
    dsc = data.loc[data.id==ID].sort_values('discourse_start')
    ents = []
    pointer = 0
    for _, row in dsc.iterrows():
        entity = row.discourse_text.strip()
        # Search only after the previous span so repeated phrases are matched in order.
        starts_at = org_text[pointer:].index(entity) + pointer
        entity_end = starts_at + len(entity)
        end = 1 if entity_end < len(org_text) and punct_re.match(org_text[entity_end]) is not None else 0
        start = 1 if punct_re.match(org_text[starts_at]) is not None else 0
        ents.append({'start': starts_at + start, 'end': entity_end + end, 'label': row.discourse_type})
        pointer = entity_end
    displacy.render({'text': org_text, 'ents': ents, 'title': ID}, style="ent",
                    options=options, manual=True, jupyter=True)

In [None]:
# Matches one ASCII letter or digit. Used to trim leading non-alphanumeric characters
# from discourse texts. (The range `A-z` would also match [ \ ] ^ _ and the backtick.)
regexp = re.compile('[0-9a-zA-Z]')

In [None]:
def make_more_targets(targets):
    """Derive per-token class ids and boundary-tag encodings from target ids.

    ``targets`` uses the id scheme built from ``token_maps``:
    - 0 for tokens outside any discourse span;
    - an odd id ``2k - 1`` for the first token of a span of class ``k``;
    - the even id ``2k`` for that span's following tokens.

    Two neighbouring tokens are "linked" when the second continues the first. That
    is either the same even id repeated, or an odd id followed by its even
    successor. Consecutive 0 tokens are linked too. Positions 0 and ``n - 1`` never
    receive links.

    Args:
        targets: 1-D integer sequence of per-token target ids.

    Returns:
        class_index: float32 array (n,), class id ``(t + 1) // 2`` per token.
        bi: float32 (n, 2) one-hot. [1, 0] means the token is not linked to its
            predecessor; [0, 1] means it is.
        bio: ``bi`` with the rows where ``targets == 0`` zeroed.
        bies: float32 (n, 4) one-hot over [begin, inside, end, single].
        bieso: ``bies`` with the rows where ``targets == 0`` zeroed.
    """
    targets = np.asarray(targets)
    num_positions = len(targets)
    # Column 0: linked to the previous token. Column 1: linked to the next token.
    linkage = np.zeros((num_positions, 2), 'f4')
    current_target = -2  # sentinel that matches no real id, so position 1 never links back
    for ix in range(1, num_positions - 1):
        if ((current_target % 2 == 0 and current_target == targets[ix])
                or (targets[ix] == current_target + 1 and current_target % 2 == 1)):
            linkage[ix - 1, 1] = 1
            linkage[ix, 0] = 1
        current_target = targets[ix]
    # Class id per token: 0 -> 0 and {2k-1, 2k} -> k. It does not depend on the loop,
    # so compute it once; this also fills it for inputs of length <= 2.
    class_index = ((targets + 1) // 2).astype('f4')
    # 0: no links (single), 1: next only (begin), 2: previous only (end), 3: both (inside).
    link_sums = (linkage * np.array([2, 1])).sum(-1).astype('i4')
    bi = np.zeros((num_positions, 2), 'f4')
    bi[link_sums < 2, 0] = 1
    bi[link_sums >= 2, 1] = 1
    bio = np.array(bi)
    bio[targets == 0] = 0
    bies = np.zeros((num_positions, 4), 'f4')
    bies[:, :2] = bi
    bies[link_sums == 0] = (0, 0, 0, 1)
    bies[link_sums == 2] = (0, 0, 1, 0)
    bieso = np.array(bies)
    bieso[targets == 0] = 0
    return class_index, bi, bio, bies, bieso

In [None]:
def combine_labels(class_index, bi, bio, bies, bieso):
    """Fold class ids and boundary tags into single combined label ids per token.

    Returns a tuple ``(combined_bi, combined_bio, combined_bies, combined_bieso)``:
    - combined_bi: ``2 * class + bi[:, 0]`` (16 values for 8 classes);
    - combined_bio: 0 for class 0, else ``2 * (class - 1) + {1: B, 2: I}``;
    - combined_bies: ``4 * class + {0: B, 1: I, 2: E, 3: S}``;
    - combined_bieso: 0 for class 0, else ``4 * (class - 1) + {1..4}``.

    ``class_index`` is assumed to be a 1-D numpy array. The tag arrays are one-hot
    rows aligned with it.
    """
    bies_codes = np.arange(4)
    bieso_codes = bies_codes + 1
    bio_codes = np.arange(1, 3)
    entity_rows = np.flatnonzero(class_index != 0)
    shifted_class = class_index[entity_rows] - 1

    merged_bi = bi[:, 0] + 2 * class_index
    merged_bies = bies.dot(bies_codes) + 4 * class_index

    merged_bieso = class_index.copy()
    merged_bieso[entity_rows] = bieso[entity_rows].dot(bieso_codes) + 4 * shifted_class

    merged_bio = class_index.copy()
    merged_bio[entity_rows] = bio[entity_rows].dot(bio_codes) + 2 * shifted_class

    return merged_bi, merged_bio, merged_bies, merged_bieso

In [None]:
import h5py

In [None]:
# Create the output HDF5 file, truncating any existing one.
data_file = h5py.File('debertav2_data.h5py', mode='w')

In [None]:
# One dataset row per training essay, plus one spare final row.
num_texts = 1 + len(glob('train/*.txt'))

In [None]:
# All datasets share a (num_texts, 2048) leading shape: one row per essay plus the
# spare final row, with up to 2048 token positions per row. Any extra dims are the
# per-token label width.
def _token_dataset(name, dtype, *per_token_shape):
    return data_file.create_dataset(name, (num_texts, 2048, *per_token_shape), dtype)

tokens_dataset = _token_dataset('tokens', 'i8')
attention_masks_dataset = _token_dataset('attention_masks', 'f4')
token_offsets_dataset = _token_dataset('token_offsets', 'i4', 2)
class_labels_dataset = _token_dataset('class_labels', 'f4', 8)
# NOTE(review): one count per essay would suffice, but each row is 2048 wide;
# readers take column 0.
num_tokens_dataset = _token_dataset('num_tokens', 'i4')

bi_labels_dataset = _token_dataset('bi_labels', 'f4', 2)
bio_labels_dataset = _token_dataset('bio_labels', 'f4', 2)
bies_labels_dataset = _token_dataset('bies_labels', 'f4', 4)
bieso_labels_dataset = _token_dataset('bieso_labels', 'f4', 4)

cbi_labels_dataset = _token_dataset('cbi_labels', 'f4', 16)
cbio_labels_dataset = _token_dataset('cbio_labels', 'f4', 15)
cbies_labels_dataset = _token_dataset('cbies_labels', 'f4', 32)
cbieso_labels_dataset = _token_dataset('cbieso_labels', 'f4', 29)


In [None]:
# Explicitly zero the last (spare) row of every dataset. The preprocessing loop
# writes rows 0 .. num_texts - 2 only.
_all_datasets = (
    tokens_dataset, attention_masks_dataset, token_offsets_dataset,
    class_labels_dataset, num_tokens_dataset,
    bi_labels_dataset, bio_labels_dataset, bies_labels_dataset, bieso_labels_dataset,
    cbi_labels_dataset, cbio_labels_dataset, cbies_labels_dataset, cbieso_labels_dataset,
)
for h5_dataset in tqdm(_all_datasets):
    h5_dataset[-1] = 0
    


In [None]:
def make_one_hot(indices, num_labels):
    """Return a float64 one-hot matrix of shape (len(indices), num_labels).

    ``indices`` is expected to be a 1-D numpy array of label ids. Values are cast
    to int32 before use, so floats are truncated.
    """
    return np.eye(num_labels)[indices.astype('i4')]

In [None]:
def fix_text(x):
    """Replace every newline in ``x`` with the '‽' placeholder character."""
    return x.replace('\n', '‽')

In [None]:
from collections import Counter

In [None]:
# (row index, essay id) pairs for essays where a discourse span appears to start
# in the middle of a word; filled by the preprocessing loop.
broken_indices = list()

In [None]:
# For every training essay:
# - tokenize it;
# - convert its annotated discourse spans into per-token target ids (see `token_maps`);
# - derive the additional tag encodings;
# - write everything into row `filename_ix` of the HDF5 datasets.
# Row order follows the (unsorted) order returned by glob('train/*.txt').
# NOTE(review): `id_to_ix_map` is initialised here but never filled by this loop;
# a later cell rebuilds it from glob().
id_to_ix_map = {}
for filename_ix, filename in tqdm(enumerate(glob('train/*.txt')), total = num_texts-1):
    # Essay id = file name without directory and extension (assumes '/' separators).
    ID = filename.split('/')[-1].split('.')[0]
    with open(filename) as f:
        text = fix_text(f.read().strip())
    tokenizer_outs = tokenizer(text, return_offsets_mapping=True)
    # NOTE(review): remaps token id 126599 to 128000. The meaning of these ids is not
    # visible here (possibly related to the '‽' newline placeholder) — confirm.
    tokenizer_outs['input_ids'] = [x if x != 126599 else 128000 for x in tokenizer_outs['input_ids']]
    text_data = data.loc[data.id==ID].sort_values('discourse_start')
    # Character spans (start, end, discourse_type) of each discourse element in `text`.
    ent_boundaries = []
    # Search position in `text`. Spans are located left to right, so a repeated phrase
    # matches the occurrence after the previous span.
    pointer = 0
    for row_id, row in text_data.iterrows():
        entity_text = fix_text(row.discourse_text.strip())
        # Drop leading characters up to the first `regexp` match (letter/digit).
        # NOTE(review): raises StopIteration if the text has no such character.
        entity_text = entity_text[next(regexp.finditer(entity_text)).start():]
        # A non-first span whose first word is a single character may begin in the
        # middle of a word; check whether the character before it in `text` is a letter.
        if len(entity_text.split()[0]) == 1 and pointer != 0:
            entity_start_ix = text[pointer:].index(entity_text)
            prev_text = text[:pointer + entity_start_ix]
            if pointer + entity_start_ix > 0 and prev_text[-1].isalpha():
                broken_indices.append((filename_ix, ID))
                print('cut entity ', filename_ix, ID)
                cut_word_chunk_size = len(prev_text.split()[-1])
                # If the word piece before the span is longer than one character, drop
                # the span's first character and restart at the next `regexp` match.
                if cut_word_chunk_size > 1:
                    entity_text = entity_text[next(regexp.finditer(entity_text[1:])).start() + 1 :]
        # Hard-coded skip for two specific annotations. Presumably this avoids matching
        # an earlier occurrence of the same text — TODO confirm.
        if row.discourse_id in (1620147556527.0, 1622983056026.0):
            pointer += 10
        # Raises ValueError if the span text is not found after `pointer`.
        offset = text[pointer:].index(entity_text)
        starts_at = offset + pointer
        ent_boundaries.append((starts_at, starts_at + len(entity_text), row.discourse_type))
        pointer = starts_at + len(entity_text)
    # Flatten the spans into an ordered event list:
    # (char_position, discourse_type, 'start' | 'end').
    all_boundaries = [(z, x[-1], t) for x in ent_boundaries for z, t in zip(x[:2], ('start', 'end'))]
    current_target = 0
    targets = np.zeros(len(tokenizer_outs['input_ids']), 'i8')
    token_positions = np.vstack(tokenizer_outs['offset_mapping']).astype('i4')
    # Walk the tokens left to right, consuming boundary events as token end offsets
    # pass them:
    # - a span's first token gets the odd id from `token_maps`;
    # - its later tokens get that id + 1;
    # - tokens outside spans keep 0.
    # NOTE(review): assumes each essay has at least one span; otherwise
    # all_boundaries[0] raises IndexError.
    for token_ix in range(len(tokenizer_outs['input_ids'])):
        token_start_ix, token_end_ix = tokenizer_outs['offset_mapping'][token_ix]
        # NOTE(review): by operator precedence this parses as
        # `(token_end_ix != 0 and <'end' test>) or <'start' test>`, so the
        # `token_end_ix != 0` guard only applies to the 'end' test.
        if token_end_ix != 0 and (all_boundaries[0][2] == 'end' and token_end_ix >= all_boundaries[0][0])\
                            or (all_boundaries[0][2] == 'start' and token_end_ix > all_boundaries[0][0]):
            if all_boundaries[0][2] == 'start':
                # The token reaches into a new span: label it with that span's begin id.
                current_target = token_maps[all_boundaries[0][1]]
                targets[token_ix] = current_target
                # If the token also ends the span, consume the 'start' event here;
                # the 'end' event is consumed by the pop after this if/else.
                if token_end_ix == all_boundaries[1][0]:
                    current_target = 0
                    all_boundaries.pop(0)
                else:
                    # Following tokens of this span carry the continuation id.
                    current_target += 1
            else:
                # The token reaches the current span's end. If it also passes the next
                # span's start, label it as the beginning of that next span.
                if len(all_boundaries) > 1 and token_end_ix > all_boundaries[1][0]:
                    if token_start_ix >= all_boundaries[1][0]:
                        assert text[all_boundaries[0][0] - 1] == '¨'
                    all_boundaries.pop(0)
                    current_target = token_maps[all_boundaries[0][1]]
                    targets[token_ix] = current_target
                    current_target += 1
                else:
                    # A token starting at or after the end position lies outside the span.
                    if token_start_ix >= all_boundaries[0][0]:
                        current_target = 0
                    targets[token_ix] = current_target
                    current_target = 0
            all_boundaries.pop(0)
            # No events left: the remaining tokens keep target 0.
            if not all_boundaries:
                break
        else:
            targets[token_ix] = current_target
    # Derive the extra encodings. Sanity check: the combined BIO ids must reproduce
    # `targets`, ignoring the first and last positions.
    class_index, bi, bio, bies, bieso = make_more_targets(targets)
    combined_bi, combined_bio, combined_bies, combined_bieso = combine_labels(class_index, bi, bio, bies, bieso)
    assert (combined_bio[1:-1] == targets[1:-1]).all()
    num_tokens = len(targets)
    
    # Write row `filename_ix`: real values in the first `num_tokens` positions, then
    # padding (pad token id or zeros).
    # NOTE(review): the tokenizer call does not truncate, so an essay with more than
    # 2048 tokens presumably fails on these writes — confirm.
    tokens_dataset[filename_ix, :num_tokens] = tokenizer_outs['input_ids']
    tokens_dataset[filename_ix, num_tokens:] = tokenizer.pad_token_id
    attention_masks_dataset[filename_ix, :num_tokens] = tokenizer_outs['attention_mask']
    attention_masks_dataset[filename_ix, num_tokens:] = 0
    token_offsets_dataset[filename_ix, :num_tokens] = token_positions
    token_offsets_dataset[filename_ix, num_tokens:] = 0
    class_labels_dataset[filename_ix, :num_tokens] = make_one_hot(class_index, 8)
    class_labels_dataset[filename_ix, num_tokens:] = 0
    # The scalar count fills the whole 2048-wide row; readers take column 0.
    num_tokens_dataset[filename_ix] = num_tokens
    bi_labels_dataset[filename_ix, :num_tokens] = bi
    bi_labels_dataset[filename_ix, num_tokens:] = 0
    bio_labels_dataset[filename_ix, :num_tokens] = bio
    bio_labels_dataset[filename_ix, num_tokens:] = 0
    bies_labels_dataset[filename_ix, :num_tokens] = bies
    bies_labels_dataset[filename_ix, num_tokens:] = 0
    bieso_labels_dataset[filename_ix, :num_tokens] = bieso
    bieso_labels_dataset[filename_ix, num_tokens:] = 0
    cbi_labels_dataset[filename_ix, :num_tokens] = make_one_hot(combined_bi, 16)
    cbi_labels_dataset[filename_ix, num_tokens:] = 0
    cbio_labels_dataset[filename_ix, :num_tokens] = make_one_hot(combined_bio, 15)
    cbio_labels_dataset[filename_ix, num_tokens:] = 0
    cbies_labels_dataset[filename_ix, :num_tokens] = make_one_hot(combined_bies, 32)
    cbies_labels_dataset[filename_ix, num_tokens:] = 0
    cbieso_labels_dataset[filename_ix, :num_tokens] = make_one_hot(combined_bieso, 29)
    cbieso_labels_dataset[filename_ix, num_tokens:] = 0
                                                                

In [None]:
# Close the HDF5 file written above so it can be reopened by the cells below.
data_file.close()

In [None]:
import dill as pickle

# Map each training file path, as returned by glob (e.g. 'train/<ID>.txt', not the
# bare ID), to its HDF5 row index. This relies on glob() returning the same order as
# in the preprocessing loop.
train_paths = glob('train/*.txt')
id_to_ix_map = dict(zip(train_paths, range(len(train_paths))))
with open('id_to_ix_map.pickle', 'wb') as out_f:
    pickle.dump(id_to_ix_map, out_f)

In [None]:
# Count how many tokens carry each of the 15 combined-BIO classes across all essays.
# The trailing spare row is skipped.
# Open explicitly read-only: the implicit default mode differs across h5py versions
# (it was 'a', read/write, before h5py 3).
data_file = h5py.File('debertav2_data.h5py', 'r')
labels = data_file['cbio_labels']
token_nums = data_file['num_tokens']
# num_tokens repeats each essay's count across its row, so read column 0 once rather
# than issuing one HDF5 read per essay.
tokens_per_text = token_nums[:, 0]
token_counts = np.zeros(15)
for ix in tqdm(range(len(labels) - 1)):
    class_ids, counts = np.unique(labels[ix, :tokens_per_text[ix]].argmax(-1), return_counts=True)
    # class_ids are unique, so fancy-index += accumulates correctly.
    token_counts[class_ids] += counts

In [None]:
# NOTE(review): the same array is stored twice in the pickled tuple. Consumers
# presumably unpack two values — confirm whether the second was meant to differ.
with open('token_counts.pickle', 'wb') as out_f:
    pickle.dump((token_counts,) * 2, out_f)

In [None]:
# Release the HDF5 file handle opened for counting.
data_file.close()

In [None]:
#bert_normalizer.normalize_str, ftfy