# Online Tokenizer

> The title says it all: tokenize and label the essays on the fly.

In [255]:
import os
import os.path as osp
import sys

import re
import argparse

import numpy as np
import pandas as pd
from collections import deque

import torch

from module.utils import get_data_files
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForTokenClassification

sys.path.insert(0, './codes/new_transformers_branch/transformers/src')

In [2]:
def _str2bool(value):
    """Parse a command-line boolean.

    `argparse` with `type=bool` turns every non-empty string (including
    "False") into True, so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_config(argv=None):
    """Build the experiment configuration namespace.

    Args:
        argv (optional[list[str]]): command-line style arguments to parse.
            Defaults to an empty list so the notebook always gets the default
            configuration regardless of the kernel's own `sys.argv`.

    Returns:
        argparse.Namespace: parsed options plus the derived fields `device`
        (a `torch.device`), `rank` and `ddp` (and `world_size` when
        distributed training is requested through `--local_rank`).
    """
    parser = argparse.ArgumentParser(description="use huggingface models")
    parser.add_argument("--dataset_path", default='../../feedback-prize-2021', type=str)
    parser.add_argument("--save_path", default='result', type=str)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--min_len", default=0, type=int)
    parser.add_argument("--use_groupped_weights", default=False, type=_str2bool)
    # NOTE(review): integer option with a boolean default (False == 0) — kept as is.
    parser.add_argument("--global_attn", default=False, type=int)
    parser.add_argument("--epochs", default=9, type=int)
    parser.add_argument("--batch_size", default=4, type=int)
    parser.add_argument("--grad_acc_steps", default=2, type=int)
    parser.add_argument("--grad_checkpt", default=True, type=_str2bool)
    parser.add_argument("--data_prefix", default='', type=str)
    parser.add_argument("--max_grad_norm", default=10.0, type=float)
    parser.add_argument("--start_eval_at", default=0, type=int)
    parser.add_argument("--weight_decay", default=1e-2, type=float)
    parser.add_argument("--weights_pow", default=0.1, type=float)
    parser.add_argument("--dataset_version", default=2, type=int)
    parser.add_argument("--decay_bias", default=False, type=_str2bool)
    parser.add_argument("--val_fold", default=0, type=int)
    parser.add_argument("--num_worker", default=8, type=int)
    parser.add_argument("--local_rank", type=int, default=-1, help="do not modify!")
    parser.add_argument("--device", type=int, default=0, help="select the gpu device to train")

    # logging
    parser.add_argument("--wandb_user", default='ducky', type=str)
    parser.add_argument("--wandb_project", default='feedback_deberta_large', type=str)
    parser.add_argument("--wandb_comment", default="", type=str, help="comment will be added at the back of wandb project name")
    parser.add_argument("--print_acc", default=500, type=int, help="print accuracy of each class every `print_acc` steps")

    # optimizer
    parser.add_argument("--label_smoothing", default=0.1, type=float)
    parser.add_argument("--rce_weight", default=0.1, type=float)
    parser.add_argument("--ce_weight", default=0.9, type=float)
    parser.add_argument("--nesterov", default=True, type=_str2bool, help="use nesterov for SGD")
    parser.add_argument("--momentum", default=0.9, type=float, help="momentum for SGD")

    # scheduler
    parser.add_argument("--lr", default=3e-5, type=float)
    parser.add_argument("--min_lr", default=1e-6, type=float)
    parser.add_argument("--warmup_steps", default=500, type=int)
    parser.add_argument("--gamma", default=0.8, type=float, help="gamma for cosine annealing warmup restart scheduler")
    parser.add_argument("--cycle_mult", default=1.0, type=float, help="cycle length adjustment for cosine annealing warmup restart scheduler")

    # model related arguments
    parser.add_argument("--model", default="microsoft/deberta-v3-large", type=str)
    parser.add_argument("--cnn1d", default=False, type=_str2bool)
    parser.add_argument("--extra_dense", default=False, type=_str2bool)
    parser.add_argument("--dropout_ratio", default=0.0, type=float)

    # swa
    parser.add_argument("--swa", action="store_true", help="use stochastic weight averaging")
    parser.add_argument("--swa_update_per_epoch", default=3, type=int)

    args = parser.parse_args([] if argv is None else list(argv))

    if args.local_rank != -1:
        print('[ DDP ] local rank', args.local_rank)

        # fail with a readable message before touching the device
        assert torch.cuda.device_count() > args.local_rank, 'insufficient CUDA devices for DDP command'

        torch.cuda.set_device(args.local_rank)
        # `dist` was never imported in this notebook; use the fully qualified module
        torch.distributed.init_process_group(backend='nccl')
        args.device = torch.device("cuda", args.local_rank)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

        # checking settings for distributed training
        assert args.batch_size % args.world_size == 0, f'--batch_size {args.batch_size} must be multiple of world size'

        args.ddp = True
    else:
        args.device = torch.device("cuda", args.device)
        args.rank = -1
        args.ddp = False

    return args

In [3]:
args = get_config()
all_texts, token_weights, data, csv, train_ids, val_ids, train_text_ids, val_text_ids = get_data_files(args)

In [4]:
text_id = train_text_ids[0]
text_id

'B72D0B4875B4'

In [5]:
text = all_texts[text_id]

## DebertaV3 Tokenizer

In [42]:
from new_transformers import DebertaV2TokenizerFast
from transformers import AutoTokenizer

In [13]:
tokenizer = DebertaV2TokenizerFast.from_pretrained('microsoft/deberta-v3-large')
tokenizer.model_max_length = 2048

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [44]:
auto_tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-large')
auto_tokenizer.model_max_length = 2048

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [22]:
tokenizer(['a', 'b'])

{'input_ids': [[1, 266, 2], [1, 2165, 2]], 'token_type_ids': [[0, 0, 0], [0, 0, 0]], 'attention_mask': [[1, 1, 1], [1, 1, 1]]}

In [24]:
tokenizer('a\n')

{'input_ids': [1, 266, 507, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [45]:
auto_tokenizer('a\n')

{'input_ids': [1, 266, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}

In [39]:
tokenizer('a\n')

{'input_ids': [1, 266, 507, 2], 'token_type_ids': [0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1]}

In [25]:
tokenizer('\n')

{'input_ids': [1, 507, 2], 'token_type_ids': [0, 0, 0], 'attention_mask': [1, 1, 1]}

In [46]:
decoded_text = tokenizer.decode([12, 507, 16])
decoded_text

'<0x08> <0x0C>'

In [49]:
decoded_text

'<0x08> <0x0C>'

### newline (\n) is removed by DebertaV3 Tokenizer

In [None]:
def fix_text(x):
    """Return `x` with every newline replaced by the placeholder character '‽'."""
    return x.replace('\n', '‽')

# `f` was never opened in this notebook; read the essay from the already
# loaded `all_texts` mapping instead (same source as the other cells).
text = fix_text(all_texts[text_id].strip())

In [None]:
# Tokenize the newline-substituted text and keep the character offsets of every token.
tokenizer_outs = tokenizer(text, return_offsets_mapping=True)

# token replacement ‽ -> [MASK]
# NOTE(review): 126861 / 128000 are presumably the vocabulary ids of '‽' and
# '[MASK]' for this tokenizer — confirm (e.g. with `convert_tokens_to_ids`).
tokenizer_outs['input_ids'] = [x if x != 126861 else 128000 for x in tokenizer_outs['input_ids']]

### What about the prediction string, where whitespace (`' '`) matters?

In [None]:
# Scratch version of the prediction-string computation.
# NOTE(review): `discourse_start`, `discourse_end` and `full_text` are not
# defined anywhere in this notebook — this cell cannot run on a fresh kernel.
char_start = discourse_start
char_end = discourse_end
# index of the first word of the span = number of whitespace-separated words before it
word_start = len(full_text[:char_start].split())
word_end = word_start + len(full_text[char_start:char_end].split())
# never point past the last word of the text
word_end = min( word_end, len(full_text.split()) )
predictionstring = " ".join( [str(x) for x in range(word_start,word_end)] )

In [25]:
'  a'.split(' ')

['', '', 'a']

### Is the tokenizer's `model_max_length` a problem?

> this is weird 

In [52]:
tokenizer = DebertaV2TokenizerFast.from_pretrained('microsoft/deberta-v3-large')
tokenizer.model_max_length = 5

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [54]:
tokenizer.tokenize('do really tokenizer cut in length 5?')

Token indices sequence length is longer than the specified maximum sequence length for this model (9 > 5). Running this sequence through the model will result in indexing errors


['▁do', '▁really', '▁token', 'izer', '▁cut', '▁in', '▁length', '▁5', '?']

In [None]:
# `A-Z`, not `A-z`: the latter also matches the punctuation between 'Z' and 'a'
# in ASCII ([ \ ] ^ _ `).
regexp = re.compile('[0-9a-zA-Z]')

## Dataset
> everything is controlled by `text_id`

1. text is preprocessed (text -> list)
    - strip
    - newline replacing
    - **change text to list**
2. augmentation & noise injection (list -> text)
    - one char changing
    - one char removing
    - [text augmentation](https://www.kaggle.com/c/feedback-prize-2021/discussion/295277)
    - **change list to text**
3. calculate entity boundaries
    - use boundary supported by train.csv
    - alternatively could calculate entity boundary with noise elimination
4. tokenize the text
    - token id
    - mask
    - offset
5. calculate the label by using `entity boundary` and `offset`

In [240]:
# Display colour (hex) for every discourse type; the key order also fixes the
# label ids below.
colors = {
            'Lead': '#8000ff',
            'Position': '#2b7ff6',
            'Evidence': '#2adddd',
            'Claim': '#80ffb4',
            'Concluding Statement': '#d4dd80',  # leading '#' was missing
            'Counterclaim': '#ff8042',
            'Rebuttal': '#ff0000'
         }
# Each category gets an odd id (1, 3, ..., 13); the following even id is used
# by the token labeling for the continuation tokens of the same entity.
cat2id = dict(zip(colors, range(1, 2 * len(colors), 2)))
cat2id

{'Lead': 1,
 'Position': 3,
 'Evidence': 5,
 'Claim': 7,
 'Concluding Statement': 9,
 'Counterclaim': 11,
 'Rebuttal': 13}

In [56]:
text_id

'B72D0B4875B4'

In [195]:
text = all_texts[text_id]

In [214]:
text_df = csv.query('id == @text_id')

### Text -> List

> clean the text_df to match the text file's content also

In [197]:
text_id

'B72D0B4875B4'

In [213]:
def text2list(text, text_df, clean_text_df=True):
    """Split an essay into an ordered list of [category, chunk] pairs.

    The list form is what the augmentation / noise-injection steps work on.

    I'm working now quark! -> [[Lead, "I'm working"],
                               [None, " "],
                               [Claim, "now quark!"]]

    Args:
        text (str): full text of one essay
        text_df (pandas.DataFrame): annotation rows of that essay; read through
                                    `discourse_start`, `discourse_end` and
                                    `discourse_type`
        clean_text_df (bool): overwrite `discourse_text` with the exact slice
                              of `text`, because train.csv and the
                              "{text_id}.txt" files do not always agree

    Returns:
        text_list (list): [category, chunk] pairs covering the whole text,
                          un-annotated stretches carry the category "None"
        text_df (pandas.DataFrame): copy of the input frame (cleaned if requested)
    """
    text_df = text_df.copy()

    segments = []
    cursor = 0  # end offset of the last consumed entity
    for record in text_df.itertuples():
        begin = int(record.discourse_start)
        finish = int(record.discourse_end)

        # anything between the previous entity (or the text start) and this one
        if cursor != begin:
            segments.append(["None", text[cursor:begin]])

        segments.append([record.discourse_type, text[begin:finish]])
        cursor = finish

    # trailing text after the last entity
    total = len(text)
    if cursor != total:
        segments.append(["None", text[cursor:total]])

    if clean_text_df:
        text_df.loc[text_df.index, 'discourse_text'] = [
            chunk for label, chunk in segments if label != 'None'
        ]

    return segments, text_df

In [221]:
text_list, clean_text_df = text2list(text, text_df, clean_text_df=True)

In [216]:
text_df.iloc[2].discourse_text

"It is unfair because the people's votes might be overuled "

In [222]:
clean_text_df.iloc[2].discourse_text

"It is unfair because the people's votes might be overuled,"

### List -> Text

In [223]:
def list2text(text_list, text_df, return_df=False):
    """Merge a [category, chunk] list back into one text.

    Used after data augmentation and noise injection.

    [[Lead, "I'm working"],
     [None, " "],
     [Claim, "now quark!"]]
    -> I'm working now quark!

    Args:
        text_list (list): [category, chunk] pairs in reading order
        text_df (pandas.DataFrame): the dataframe of the text; only its first
                                    `id` value is read when `return_df` is True
        return_df (bool): rebuild the dataframe (offsets, discourse text and
                          prediction string) from `text_list`; needed when the
                          augmentation changed the text enough to move the
                          entity boundaries

    Returns:
        text (str): merged text from text_list
        text_df (pandas.DataFrame): a copy of the input frame when `return_df`
                                    is False, otherwise the rebuilt frame
    """
    text_df = text_df.copy()

    # convert to text (plain join also works for an empty list)
    text = ''.join(chunk for _, chunk in text_list)

    if not return_df:
        return text, text_df

    # convert to text_df
    text_id = text_df.id.iloc[0]
    total_words = len(text.split())

    last_position = 0
    discourses = []
    for discourse_type, discourse_text in text_list:
        text_len = len(discourse_text)
        if discourse_type != "None":
            discourse_start = last_position
            discourse_end = last_position + text_len

            # prediction string = word indices of the span, computed against
            # the merged text (not against a notebook-level variable)
            word_start = len(text[:discourse_start].split())
            word_end = min(word_start + len(discourse_text.split()), total_words)

            discourses.append({'id': text_id,
                               'discourse_start': discourse_start,
                               'discourse_end': discourse_end,
                               'discourse_text': discourse_text,
                               'discourse_type': discourse_type,
                               'predictionstring': " ".join(str(x) for x in range(word_start, word_end))})

        last_position += text_len

    # explicit columns keep the frame well-formed even without any entity
    columns = ['id', 'discourse_start', 'discourse_end',
               'discourse_text', 'discourse_type', 'predictionstring']
    text_df = pd.DataFrame(discourses, columns=columns)

    return text, text_df

In [224]:
def calculate_predictionstring(row, full_text=None):
    """Recalculate the prediction string for one entity of the augmented text.

    reference - https://www.kaggle.com/c/feedback-prize-2021/discussion/297591

    Args:
        row: object exposing `discourse_start` and `discourse_end`
             (e.g. a DataFrame row); float offsets are accepted
        full_text (optional[str]): text the offsets refer to. When omitted the
                                   notebook-level `text` variable is used, which
                                   is the original behaviour.

    Returns:
        str: space separated word indices covered by the entity
    """
    if full_text is None:
        full_text = text  # notebook-level fallback (hidden state — prefer passing it)

    # offsets loaded from csv are floats; slicing needs integers
    start = int(row.discourse_start)
    end = int(row.discourse_end)

    word_start = len(full_text[:start].split())
    word_end = word_start + len(full_text[start:end].split())
    word_end = min(word_end, len(full_text.split()))

    predictionstring = " ".join([str(x) for x in range(word_start, word_end)])

    return predictionstring

In [226]:
new_text, new_text_df = list2text(text_list, clean_text_df, return_df=False)

### Tokenizer

In [248]:
tokenizer = DebertaV2TokenizerFast.from_pretrained('microsoft/deberta-v3-large')
tokenizer.model_max_length = 2048

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [251]:
tokenizer_outs = tokenizer(new_text, return_offsets_mapping=True)

### Entity Boundary Calculation
- with `LRU cache`: the text is preserved except for minor changes that don't need to be considered
- without `LRU cache`: the text changes a lot, so the boundaries need to be recalculated

In [227]:
# Matches one ASCII letter or digit. `A-Z`, not `A-z`: the latter also matches
# the punctuation between 'Z' and 'a' in ASCII ([ \ ] ^ _ `).
alphanumeric_re = re.compile('[0-9a-zA-Z]')

In [None]:
from functools import lru_cache

# TODO: the cached (per text_id) entity-boundary lookup is not implemented yet.
# The stray `j` and the indented placeholder body that used to be here were a
# syntax error, which kept the import above from ever running.

In [238]:
def get_entity_boundary(text, text_df):
    """Locate every entity of `text_df` inside `text` (built by sergei chudov).

    Rows are processed in order; each entity is searched for only after the end
    of the previously found one, so the rows must be in reading order.

    Args:
        text (str): the text the entities are searched in
        text_df (pandas.DataFrame): rows exposing `discourse_text` and
                                    `discourse_type`

    Returns:
        list: one `(start, end, discourse_type)` tuple per row, character
              offsets into `text`

    Raises:
        StopIteration: when `alphanumeric_re` finds no match in a discourse text
        ValueError: when a (trimmed) discourse text is not found in `text`
                    after the current pointer

    Depends on the notebook-level `alphanumeric_re` pattern.
    """
    ent_boundaries = []
    pointer = 0
    
    for row in text_df.itertuples():
        entity_text = row.discourse_text

        # regex to find text start with alphanumeric (a-zA-Z0-9)
        # i.e. drop everything before the first character matched by `alphanumeric_re`
        entity_text = entity_text[next(alphanumeric_re.finditer(entity_text)).start():]
        
        # if the first character length is 1, then check the previous text chunk
        # (a one-character first word may be the tail of a word cut by the annotation)
        if len(entity_text.split()[0]) == 1 and pointer != 0:
            entity_start_ix = text[pointer:].index(entity_text)
            prev_text = text[:pointer + entity_start_ix]
            
            # current text is not the beginning and the previous text last char is alphanumeric
            # NOTE: the check used is `isalpha()`, so a preceding digit does not count
            if pointer + entity_start_ix > 0 and prev_text[-1].isalpha():
                cut_word_chunk_size = len(prev_text.split()[-1])
                
                # if the previous text last word length is not 1
                # then skip the first character and restart at the next regex match
                if cut_word_chunk_size > 1:
                    entity_text = entity_text[next(alphanumeric_re.finditer(entity_text[1:])).start() + 1:]

        # position of the (possibly trimmed) entity, searched from the pointer onwards
        offset = text[pointer:].index(entity_text)
        starts_at = offset + pointer
        ent_boundaries.append((starts_at, starts_at + len(entity_text), row.discourse_type))
        # the next entity is searched after the end of this one
        pointer = starts_at + len(entity_text)
            
    return ent_boundaries

In [263]:
@lru_cache(maxsize=5)
def test(text_id):
    """Scratch check of `lru_cache`: prints `text_id` whenever the body actually runs."""
    print(text_id)

NameError: name 'lru_cache' is not defined

In [259]:
entity_boundary = get_entity_boundary(new_text, new_text_df)
entity_boundary

[(0, 92, 'Lead'),
 (93, 130, 'Position'),
 (131, 189, 'Claim'),
 (190, 222, 'Claim'),
 (227, 284, 'Claim'),
 (285, 356, 'Claim'),
 (357, 690, 'Evidence'),
 (691, 783, 'Claim'),
 (783, 1287, 'Evidence'),
 (1288, 1445, 'Claim'),
 (1446, 1540, 'Counterclaim'),
 (1541, 1572, 'Rebuttal'),
 (1573, 1902, 'Evidence'),
 (1903, 2117, 'Concluding Statement')]

In [260]:
original_entity_boundary = new_text_df[['discourse_start', 'discourse_end', 'discourse_type']].values
original_entity_boundary

array([[0.0, 92.0, 'Lead'],
       [93.0, 130.0, 'Position'],
       [131.0, 189.0, 'Claim'],
       [190.0, 222.0, 'Claim'],
       [226.0, 284.0, 'Claim'],
       [285.0, 356.0, 'Claim'],
       [357.0, 690.0, 'Evidence'],
       [691.0, 783.0, 'Claim'],
       [783.0, 1287.0, 'Evidence'],
       [1288.0, 1445.0, 'Claim'],
       [1446.0, 1540.0, 'Counterclaim'],
       [1541.0, 1572.0, 'Rebuttal'],
       [1573.0, 1902.0, 'Evidence'],
       [1903.0, 2117.0, 'Concluding Statement']], dtype=object)

### Token Labeling

In [257]:
def token_labeling(tokenizer_outs, ent_boundaries, cat2id):
    """Assign a class id to every token from the character-level entity boundaries.

    The first token of an entity gets `cat2id[type]`, the following tokens of
    the same entity get `cat2id[type] + 1`, every other token gets 0.

    Args:
        tokenizer_outs (dict): tokenizer output holding `input_ids` and
                               `offset_mapping` ((start, end) character offsets
                               per token)
        ent_boundaries (iterable): `(start, end, discourse_type)` triples in
                                   reading order
        cat2id (dict): discourse type -> id of the entity's first token

    Returns:
        numpy.ndarray: int64 vector with one label per token

    NOTE(review): the sanity `assert` below reads the notebook-level `text`
    variable, not an argument — confirm it holds the text that was tokenized.
    """
    input_ids = tokenizer_outs['input_ids']
    offsets = tokenizer_outs['offset_mapping']
    targets = np.zeros(len(input_ids), dtype='i8')

    # flatten the entities into an ordered queue of start / end events
    all_boundaries = deque()
    for ent_boundary in ent_boundaries:
        discourse_type = ent_boundary[-1]
        for position, boundary_type in zip(ent_boundary[:2], ('start', 'end')):
            all_boundaries.append((position, discourse_type, boundary_type))

    # no entity at all: every token is "outside" (previously an IndexError)
    if not all_boundaries:
        return targets

    current_target = 0
    for token_ix in range(len(input_ids)):
        token_start_ix, token_end_ix = offsets[token_ix]

        cur_pos, cur_cat_type, cur_bound_type = all_boundaries[0]

        # same truth table as the original `a and b or c`, with explicit grouping;
        # tokens with end offset 0 (special tokens) never close an entity
        reaches_end = token_end_ix != 0 and cur_bound_type == 'end' and token_end_ix >= cur_pos
        enters_start = cur_bound_type == 'start' and token_end_ix > cur_pos
        if not (reaches_end or enters_start):
            targets[token_ix] = current_target
            continue

        if len(all_boundaries) > 1:
            next_pos, next_dis_type, _ = all_boundaries[1]

        if cur_bound_type == 'start':
            # token map {'Lead': 1, 'Position': 3, ..., 'Rebuttal': 13}
            current_target = cat2id[cur_cat_type]
            targets[token_ix] = current_target

            if token_end_ix == next_pos:
                # single-token entity: consume its end boundary as well
                current_target = 0
                all_boundaries.popleft()
            else:
                current_target += 1
        else:
            # If there is more entity left to consider and current is already on the next pos
            if len(all_boundaries) > 1 and token_end_ix > next_pos:

                # can this actually happen?
                if token_start_ix >= next_pos:
                    assert text[cur_pos - 1] == '¨'

                # the token closes the current entity AND opens the next one, so it
                # must carry the NEXT entity's class (the previous entity's class
                # was used here before, mislabeling the whole following entity)
                all_boundaries.popleft()
                current_target = cat2id[next_dis_type]
                targets[token_ix] = current_target
                current_target += 1
            else:
                if token_start_ix >= cur_pos:
                    current_target = 0

                targets[token_ix] = current_target
                current_target = 0

        all_boundaries.popleft()
        if not all_boundaries:
            break

    return targets

In [268]:
tokenizer_outs

{'input_ids': [1, 1258, 355, 428, 272, 262, 11992, 1575, 269, 298, 3092, 260, 273, 418, 280, 297, 428, 306, 280, 368, 1299, 260, 279, 11992, 1575, 269, 379, 10425, 260, 325, 269, 10425, 401, 262, 355, 280, 268, 5125, 520, 282, 360, 6013, 569, 261, 262, 11992, 1575, 269, 16348, 261, 263, 306, 372, 298, 794, 356, 1251, 264, 262, 355, 280, 268, 4713, 260, 1244, 265, 305, 261, 267, 347, 1281, 262, 1123, 1647, 520, 298, 282, 265, 356, 772, 260, 5216, 261, 262, 1123, 1647, 702, 280, 297, 912, 267, 262, 1129, 270, 1574, 278, 74392, 742, 270, 262, 11992, 1575, 260, 471, 337, 273, 849, 264, 1647, 270, 266, 35018, 1574, 304, 262, 11992, 1575, 3110, 322, 5147, 270, 262, 59483, 273, 338, 286, 10487, 312, 326, 446, 264, 1647, 260, 11727, 261, 262, 11992, 1575, 280, 268, 1647, 85399, 268, 262, 1123, 1647, 260, 344, 908, 264, 360, 92364, 262, 1123, 1647, 261, 262, 355, 277, 262, 11992, 1575, 281, 16348, 260, 369, 262, 355, 328, 281, 277, 262, 11992, 1575, 281, 16348, 393, 261, 306, 520, 298, 413, 355

In [258]:
token_labeling(tokenizer_outs, entity_boundary, cat2id)

array([ 0,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  3,  4,  4,  4,  4,  4,  4,  7,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  8,  8,  8,  8,  8,  0,
        7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  5,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  5,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8

In [283]:
token = token_labeling(tokenizer_outs, original_entity_boundary, cat2id)
token

array([ 0,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  3,  4,  4,  4,  4,  4,  4,  7,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  8,  8,  8,  8,  8,  0,
        7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  7,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  5,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  5,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        8,  8,  8,  8,  8

### One hot encoding

In [274]:
def make_one_hot(indices, num_labels):
    """One-hot encode a vector of class indices.

    Args:
        indices (array-like): class index per position (list, tuple or ndarray)
        num_labels (int): number of classes, i.e. width of the result

    Returns:
        numpy.ndarray: float64 array of shape (len(indices), num_labels) with a
                       single 1 per row
    """
    # accept plain sequences too, not only ndarrays
    indices = np.asarray(indices).astype('i4')
    array = np.zeros((len(indices), num_labels))
    array[np.arange(len(indices)), indices] = 1
    return array

In [284]:
token = make_one_hot(token, 15)
token.shape

(467, 15)

In [285]:
len(token)

467

In [290]:
label = np.zeros((2048, 15), dtype='f4')
label[:len(token)] = token

In [291]:
label.shape

(2048, 15)

### Building Dataset

In [5]:
class FeedbackDataset(torch.utils.data.Dataset):
    """Online dataset keyed by `text_id` (work in progress).

    Args:
        text_ids (sequence): ids of the essays served by this dataset
        csv (pandas.DataFrame): annotation table, filtered by its `id` column
        all_texts (mapping): text_id -> essay text
        token_weights: per-class token weights, stored as is
        class_names (optional): class names, stored as is (the original code
                                read an undefined global of the same name)
    """

    def __init__(
        self, text_ids, csv, all_texts, token_weights, class_names=None
    ):
        self.csv = csv
        self.all_texts = all_texts
        self.text_ids = text_ids
        self.class_names = class_names
        self.token_weights = token_weights

        # raw string: "\s" in a normal literal is an invalid escape sequence
        self.space_regex = re.compile(r"[\s\n]")

        # store original data as desired dictionary format
        self.initialize_data_dict()

    def initialize_data_dict(self, preprocess=False):
        """Save the original data in a dictionary keyed by text_id.

        {
         ...
         text_id: {'text_list': text_list, 'text_df': text_df},
         text_id: {'text_list': text_list, 'text_df': text_df},
         ...
         }

        1. text_list: [category, chunk] pairs produced by `text2list`
        2. text_df: the dataframe returned by `text2list` for that text

        Args:
            preprocess (bool): reserved for strip / newline replacement.
                NOTE(review): not applied yet — stripping would shift the
                character offsets stored in the csv.
        """
        self.original_data = {}
        for text_id in self.text_ids:
            # load original data
            text = self.all_texts[text_id]
            text_df = self.csv.query('id == @text_id').reset_index(drop=True).copy()

            # `text2list` returns (text_list, text_df); keep them apart
            text_list, text_df = text2list(text, text_df)

            # convert to the dictionary format
            self.original_data[text_id] = {'text_list': text_list, 'text_df': text_df}

    def noise_injection(self, text):
        """Placeholder for noise injection (not implemented)."""
        ...

    def preprocess_text(self, text):
        """Strip the text and replace every newline with '‽'."""
        text = text.strip()

        # newline is removed from debertav3 tokenizer
        text = text.replace('\n', '‽')

        return text

    def forward(self, idx):
        """Unfinished sample accessor: only looks up the text id and returns None."""
        text_id = self.text_ids[idx]

    def __len__(self):
        """Number of essays served by the dataset."""
        return len(self.text_ids)
    

In [None]:
    def __getitem__(self, idx):
        """Return one validation sample as an 8-tuple.

        NOTE(review): this method reads `self.ids`, `self.val_text_ids` and
        `self.data`, which `FeedbackDataset` above does not define — it
        presumably belongs to the validation dataset class; confirm.

        Returns:
            tuple: (tokens, attention_mask, cbio_labels, class_weight,
                    token_bounds, gt_dict, index_map, num_tokens)
        """
        i = self.ids[idx]

        # load text data & text dataframe
        text_id = self.val_text_ids[idx]
        text = self.all_texts[text_id]
        sample_df = self.csv.query("id == @text_id")

        # load ground truth prediction string for f1macro metric
        # gt_dict: class index (1..7) -> list of (x[0], x[1]) pairs taken from
        # `split_predstring` (defined elsewhere) applied to each predictionstring
        gt_dict = {}
        for class_i in range(1, 8):
            class_name = self.class_names[class_i]
            class_df = sample_df.query("discourse_type == @class_name")
            if len(class_df):
                gt_dict[class_i] = [
                    (x[0], x[1])
                    for x in class_df.predictionstring.map(split_predstring)
                ]

        # load valid data
        tokens = self.data["tokens"][i]
        attention_mask = self.data["attention_masks"][i]
        num_tokens = self.data["num_tokens"][i, 0]
        token_bounds = self.data["token_offsets"][i]
        cbio_labels = self.data["cbio_labels"][i]

        # class weight per token
        class_weight = np.zeros_like(attention_mask)
        argmax_labels = cbio_labels.argmax(-1)

        for class_i in range(1, 15):
            class_weight[argmax_labels == class_i] = self.token_weights[class_i]

        # class 0 gets its weight only before index `num_tokens - 1`;
        # the first position is always zeroed (written last on purpose)
        class_none_index = argmax_labels == 0
        class_none_index[num_tokens - 1 :] = False
        class_weight[class_none_index] = self.token_weights[0]
        class_weight[0] = 0

        # index_map: for every character, starting at the first non-whitespace
        # one, the index of the whitespace-separated word it belongs to
        # (whitespace characters keep the index of the preceding word)
        index_map = []
        current_word = 0
        blank = False
        for char_ix in range(text.index(text.strip()[0]), len(text)):
            if self.space_regex.match(text[char_ix]) is not None:
                blank = True
            elif blank:
                current_word += 1
                blank = False
            index_map.append(current_word)

        return (
            tokens,
            attention_mask,
            cbio_labels,
            class_weight,
            token_bounds,
            gt_dict,
            index_map,
            num_tokens,
        )


# The very first training batch is padded to the full length (see train_collate_fn).
first_batch = True


def train_collate_fn(ins):
    """Collate training samples into a tuple of batched tensors.

    Every field except the last one of each sample is an array; the last field
    is the sample's token count. Arrays are cut to the longest token count of
    the batch rounded up to a multiple of 8 — except for the first batch seen
    by this process, which is cut at 2048.

    Args:
        ins (list[tuple]): samples as produced by the training dataset

    Returns:
        tuple[torch.Tensor]: one tensor per array field, batch axis first
    """
    global first_batch

    if not first_batch:
        longest = max(sample[-1] for sample in ins)
        max_len = (longest + 7) // 8 * 8
    else:
        max_len = 2048
        first_batch = False

    batched = []
    for field in range(len(ins[0]) - 1):
        rows = [sample[field][None, :max_len] for sample in ins]
        batched.append(torch.from_numpy(np.concatenate(rows)))

    return tuple(batched)


def val_collate_fn(ins):
    """Collate validation samples.

    The last three fields of each sample are kept per sample (two as python
    lists, the token counts as a numpy array); every field before them is an
    array that is cut to the longest token count of the batch, rounded up to a
    multiple of 8, and batched into a tensor.

    Args:
        ins (list[tuple]): samples as produced by the validation dataset

    Returns:
        tuple: batched tensors followed by the two per-sample lists and the
               array of token counts
    """
    longest = max(sample[-1] for sample in ins)
    max_len = (longest + 7) // 8 * 8

    batched = []
    for field in range(len(ins[0]) - 3):
        rows = [sample[field][None, :max_len] for sample in ins]
        batched.append(torch.from_numpy(np.concatenate(rows)))

    third_last = [sample[-3] for sample in ins]
    second_last = [sample[-2] for sample in ins]
    token_counts = np.array([sample[-1] for sample in ins])

    return tuple(batched) + (third_last, second_last, token_counts)


def get_dataloader(
    args,
    train_ids,
    val_ids,
    data,
    csv,
    all_texts,
    val_text_ids,
    class_names,
    token_weights,
):
    """Build the training and validation data loaders.

    `TrainDataset`, `ValDataset` and `DataLoader` are expected to be available
    in the notebook namespace (they are not defined in this cell).

    Args:
        args: configuration; `label_smoothing`, `data_prefix`, `batch_size`
              and `num_worker` are read
        train_ids, val_ids: sample ids handed to the respective dataset
        data: pre-tokenized data handed to both datasets
        csv, all_texts, val_text_ids, class_names: handed to the validation dataset
        token_weights: handed to both datasets

    Returns:
        tuple: (train_dataloader, val_dataloader)
    """
    train_dataset = TrainDataset(
        train_ids, data, args.label_smoothing, token_weights, args.data_prefix
    )
    val_dataset = ValDataset(
        val_ids, data, csv, all_texts, val_text_ids, class_names, token_weights
    )

    # training: shuffled, worker count taken from the configuration
    train_options = dict(
        collate_fn=train_collate_fn,
        batch_size=args.batch_size,
        num_workers=args.num_worker,
        shuffle=True,
    )
    # validation: fixed 8 workers that are kept alive between epochs
    val_options = dict(
        collate_fn=val_collate_fn,
        batch_size=args.batch_size,
        num_workers=8,
        persistent_workers=True,
    )

    train_dataloader = DataLoader(train_dataset, **train_options)
    val_dataloader = DataLoader(val_dataset, **val_options)

    return train_dataloader, val_dataloader