In [508]:
import sys
sys.path.append('..')

from pathlib import Path
import re
import json
from prompter.utilities import flatten, get_chunks, pad_sequences, dump_object, load_object, dump_json
from sklearn.model_selection import train_test_split
import numpy as np

from tqdm import tqdm_notebook

from transformers import DistilBertTokenizer, DistilBertForTokenClassification, AdamW, BertTokenizer, BertForTokenClassification


import torch
from torch.utils.data import DataLoader, Dataset

from kaggle_google_qa_labeling.LengthSortSampler import LengthSortSampler
from kaggle_google_qa_labeling.dataset.cross_dataset_utilities import pad_sequences

from ipdb import set_trace


class TokenClassificationDataset(Dataset):
    """Dataset of token-id sequences with optional per-token label sequences.

    ``X`` is an indexable collection of token-id sequences. ``Y`` is either
    ``None`` (inference: items are 1-tuples ``(x,)``) or an indexable
    collection aligned with ``X`` (training: items are ``(x, y)`` pairs).
    """

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __getitem__(self, ind):
        """Return ``(x, y)`` when labels are present, otherwise ``(x,)``."""
        if self.Y is not None:
            return self.X[ind], self.Y[ind]
        return (self.X[ind],)

    def __len__(self):
        return len(self.X)

    @staticmethod
    def get_collate_fn(max_len, pad_id, label_pad_id=None):
        """Build a collate function that pads a batch to a common length.

        Args:
            max_len: upper bound for the padded sequence length.
            pad_id: value used to pad the token-id sequences.
            label_pad_id: value used to pad the label sequences. Defaults to
                ``pad_id`` (the historical behaviour). Passing a dedicated
                value keeps label padding independent from the vocabulary id
                of the padding token.

        Returns:
            A callable mapping a list of dataset items to ``(X,)`` or
            ``(X, Y)``, where each element is a ``torch.LongTensor``.
        """
        if label_pad_id is None:
            label_pad_id = pad_id

        def collate_fn(data):
            # Transpose [(x, y), ...] into ([x, ...], [y, ...]); a single
            # column means the dataset carries no labels.
            columns = list(zip(*data))
            X = columns[0]
            Y = columns[1] if len(columns) == 2 else None

            # Pad to the longest sequence of the batch, capped by max_len.
            seq_len = min(max_len, max(len(x) for x in X))

            X = torch.LongTensor(pad_sequences(X, seq_len, 'post', pad_id))
            if Y is None:
                return (X,)

            Y = torch.LongTensor(pad_sequences(Y, seq_len, 'post', label_pad_id))
            return (X, Y)

        return collate_fn

    def get_data_loader(self, bs, max_len, pad_id, drop_last, use_length_sampler=True,
                        label_pad_id=None):
        """Create a ``DataLoader`` over this dataset.

        Args:
            bs: batch size.
            max_len: maximum padded sequence length (see ``get_collate_fn``).
            pad_id: padding value for the token ids.
            drop_last: forwarded to ``DataLoader``.
            use_length_sampler: when true, a ``LengthSortSampler`` built from
                ``self.X`` is used as the sampler; otherwise no sampler is set.
            label_pad_id: padding value for the labels; defaults to ``pad_id``.
        """
        sampler_ = LengthSortSampler(self.X, bs=bs) if use_length_sampler else None

        return DataLoader(
            self,
            batch_size=bs,
            collate_fn=self.get_collate_fn(max_len, pad_id, label_pad_id),
            sampler=sampler_,
            drop_last=drop_last,
        )

In [507]:
# --- Pretrained model selection --------------------------------------------
BERT_NAME = 'bert-base-cased'
TOKENIZER_CLS = BertTokenizer  # alternative: DistilBertTokenizer
MODEL_CLS = BertForTokenClassification  # alternative: DistilBertForTokenClassification

# --- Input data, feature caches and output directory ------------------------
FILE_PATH = '../data/codereview.stackexchange.com/0.jsonl'
X_CACHE_FILE = '../data/codereview.stackexchange.com/X.pkl'
Y_CACHE_FILE = '../data/codereview.stackexchange.com/Y.pkl'
# Output directory is derived from the model name with dashes replaced by underscores.
OUT_DIR = Path(f"../data/ner/code/{BERT_NAME.replace('-', '_')}")

# --- Special tokens and labels ----------------------------------------------
CODE_TOKEN = '[CODE]'
additional_special_tokens = ['\n', CODE_TOKEN]
NULL_LABEL = 'null'  # label assigned to every span that carries no annotation

# --- Sequence length limits (in tokens, special tokens included) ------------
MAX_LEN = 512
MIN_LEN = 5

In [238]:
# Load the pretrained tokenizer, then register the extra special tokens
# (newline and the code placeholder) so they are kept as single tokens.
tokenizer = TOKENIZER_CLS.from_pretrained(BERT_NAME)
special_tokens_map = {'additional_special_tokens': additional_special_tokens}
# Kept as the last expression so the cell displays the call's return value.
tokenizer.add_special_tokens(special_tokens_map)

I0117 13:55:30.214243 140260791818048 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/a.karnachev/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
I0117 13:55:30.244993 140260791818048 tokenization_utils.py:552] Adding 
 to the vocabulary
I0117 13:55:30.246401 140260791818048 tokenization_utils.py:552] Adding [CODE] to the vocabulary
I0117 13:55:30.247126 140260791818048 tokenization_utils.py:629] Assigning ['\n', '[CODE]'] to the additional_special_tokens key of the tokenizer


2

In [26]:
texts = []
labels = []

# FILE_PATH is read as JSON Lines: one JSON object per line with a 'text'
# field and an optional 'labels' field. Each label is indexed below as
# label[0], label[1] (character offsets into the text) and label[2] (name).
# The encoding is explicit so parsing does not depend on the platform default.
with open(FILE_PATH, encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        d = json.loads(line)
        texts.append(d['text'])
        # A missing or null 'labels' field is treated as "no annotations".
        labels.append(d.get('labels') or [])

# Label vocabulary: sorted distinct label names, with the null label at index 0.
id2label = sorted({label[2] for labels_ in labels for label in labels_})
id2label.insert(0, NULL_LABEL)
all_texts = texts

In [5]:
# Build (or load from cache) the token-id sequences X and the aligned
# per-token label-id sequences Y.
try:
    X = load_object(X_CACHE_FILE)
    Y = load_object(Y_CACHE_FILE)
except Exception:
    # Cache missing or unreadable: rebuild it from the raw texts.
    # `except Exception` (not a bare `except`) so that KeyboardInterrupt /
    # SystemExit are not mistaken for a cache miss.

    # Name -> id lookup built once; setdefault keeps the first index of a
    # name, matching what id2label.index() returns.
    label2id = {}
    for i, name in enumerate(id2label):
        label2id.setdefault(name, i)
    null_id = label2id[NULL_LABEL]

    X = []
    Y = []
    pb = tqdm_notebook(zip(texts, labels), total=len(texts))
    for text, labels_ in pb:
        # Turn the annotated spans into a contiguous cover of the text by
        # inserting NULL_LABEL spans into every gap. Spans must be ordered
        # and non-overlapping, otherwise ValueError is raised.
        all_labels = []
        for label in labels_:
            prev_end = all_labels[-1][1] if all_labels else 0
            if prev_end > label[0]:
                raise ValueError('Unhandled case')
            if prev_end < label[0]:
                all_labels.append([prev_end, label[0], NULL_LABEL])
            all_labels.append(label)

        # Trailing unannotated text after the last span.
        # NOTE(review): a text with no labels at all yields no spans and
        # therefore no training example - confirm this is intended.
        if len(labels_) and labels_[-1][1] != len(text):
            all_labels.append([labels_[-1][1], len(text), NULL_LABEL])

        # Tokenize span by span so every token inherits its span's label.
        x = []
        y = []
        for label in all_labels:
            text_chunk = text[label[0]:label[1]]
            x_ = tokenizer.encode(text_chunk, add_special_tokens=False)
            x.extend(x_)
            y.extend([label2id[label[2]]] * len(x_))

        # Split into windows that leave room for the two special tokens
        # added by build_inputs_with_special_tokens; those get the null label.
        for x_, y_ in zip(get_chunks(x, MAX_LEN - 2), get_chunks(y, MAX_LEN - 2)):
            x_ = tokenizer.build_inputs_with_special_tokens(x_)
            y_ = [null_id] + list(y_) + [null_id]
            assert len(x_) == len(y_)

            if MAX_LEN >= len(x_) >= MIN_LEN:
                X.append(x_)
                Y.append(y_)

    dump_object(X, X_CACHE_FILE)
    dump_object(Y, Y_CACHE_FILE)

In [6]:
# Hold out 2.5% of the sequences for validation. A fixed random_state makes
# the split (and therefore the reported validation loss) reproducible.
X_train, X_valid, Y_train, Y_valid = train_test_split(X, Y, test_size=0.025, random_state=42)

In [7]:
# Training hyper-parameters.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
LR = 2.5e-5
BS = 8
N_EPOCHS = 1
EVAL_EACH = 2000  # run validation every EVAL_EACH optimizer steps

ds_train = TokenClassificationDataset(X_train, Y_train)
dl_train = ds_train.get_data_loader(BS, MAX_LEN, tokenizer.pad_token_id, drop_last=False, use_length_sampler=True)
ds_valid = TokenClassificationDataset(X_valid, Y_valid)
dl_valid = ds_valid.get_data_loader(BS, MAX_LEN, tokenizer.pad_token_id, drop_last=False, use_length_sampler=True)

# The classifier head size follows the label vocabulary instead of a
# hard-coded 2, so it stays correct if more label names appear in the data.
model = MODEL_CLS.from_pretrained(BERT_NAME, num_labels=len(id2label)).to(DEVICE)
# The tokenizer gained extra special tokens; size the embedding matrix from
# the tokenizer itself rather than re-deriving the count by hand.
model.resize_token_embeddings(len(tokenizer))
optimizer = AdamW(model.parameters(), lr=LR)

I0116 18:52:18.100585 140260791818048 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/a.karnachev/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.d7a3af18ce3a2ab7c0f48f04dc8daff45ed9a3ed333b9e9a79d012a0dedf87a6
I0116 18:52:18.101765 140260791818048 configuration_utils.py:199] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": tru

In [8]:
# Fine-tuning loop: one optimizer step per batch, with a full pass over the
# validation loader every EVAL_EACH steps. The latest train / validation
# losses are shown as the postfix of the epoch progress bar.
train_loss = np.inf
valid_loss = np.inf

pb_train = tqdm_notebook(range(N_EPOCHS))
global_step = 0

for i_epoch in pb_train:

    pb_epoch = tqdm_notebook(dl_train, total=len(dl_train), leave=False)

    # Batch tensors get their own names so the dataset lists X / Y built in
    # the earlier cells are not overwritten (keeps those cells re-runnable).
    for x_batch, y_batch in pb_epoch:
        global_step += 1
        x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)
        # Re-entered on every step because validation switches to eval mode.
        model.train()

        # The loss is the first model output when labels are supplied; it is
        # indexed the same way the inference cell indexes its output.
        # NOTE(review): no attention_mask is passed, so padded positions are
        # presumably attended to and counted in the loss - confirm.
        loss = model(x_batch, labels=y_batch)[0]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss = loss.item()

        if global_step % EVAL_EACH == 0:
            model.eval()
            with torch.no_grad():
                pb_valid = tqdm_notebook(dl_valid, total=len(dl_valid), leave=False)
                # Mean of the per-batch losses over the validation loader.
                valid_loss = 0
                for x_valid, y_valid in pb_valid:
                    x_valid, y_valid = x_valid.to(DEVICE), y_valid.to(DEVICE)
                    batch_loss = model(x_valid, labels=y_valid)[0]
                    valid_loss += batch_loss.item() / len(dl_valid)

                print(valid_loss)

        postfix = {
            'Loss/Train': train_loss,
            'Loss/Valid': valid_loss
        }

        pb_train.set_postfix(**postfix)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=48195), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.04420299648771983


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.03590635688558015


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.033086047384037025


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.036807890264530276


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.029497885725769495


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.02985235114036917


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.027491709922784963


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.02793760213078471


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.036582024700603934


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.026125104571878208


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.025008777470397998


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.024771829709598302


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.02446234573752673


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.024639467752386894


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.023726099196339764


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.022806220411981003


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.028000366281590208


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.021351472664975085


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.023555013762551938


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.021270747564921977


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.022244172803813668


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.020891827352462342


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.02023444999708757


HBox(children=(IntProgress(value=0, max=1236), HTML(value='')))

0.02141906311688885



In [509]:
# Persist everything needed to reuse the tagger: the pickled model, a small
# JSON description of how it was built, and the tokenizer files.
OUT_DIR.mkdir(parents=True, exist_ok=True)

CONFIG = dict(
    bert_name=BERT_NAME,
    tokenizer_cls=TOKENIZER_CLS.__name__,
    model_cls=MODEL_CLS.__name__,
    token=CODE_TOKEN,
    max_len=MAX_LEN,
    min_len=MIN_LEN,
)

model_path = OUT_DIR / 'model.pth'
description_path = OUT_DIR / 'description.json'

torch.save(model, model_path)
dump_json(CONFIG, description_path)
# Last expression: the cell displays the paths written by the tokenizer.
tokenizer.save_pretrained(OUT_DIR)

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


('../data/ner/code/bert_base_cased/vocab.txt',
 '../data/ner/code/bert_base_cased/special_tokens_map.json',
 '../data/ner/code/bert_base_cased/added_tokens.json')

In [510]:
# Reload the pickled model saved under OUT_DIR.
# NOTE: torch.load unpickles arbitrary objects - only load files you trust.
# map_location places the weights on the configured DEVICE regardless of the
# device they were saved from, and .eval() disables dropout for inference
# (it returns the module, so the assignment keeps the cell output empty).
model = torch.load(OUT_DIR / 'model.pth', map_location=DEVICE).eval()

In [511]:
# Run the tagger over the first 1000 texts and collect, per text, the
# concatenated token ids (X_merged) and per-token logits (Y_merged).
texts = all_texts[:1000]
slices = []  # slices[i] selects the chunks of X that belong to texts[i]
X = []
for text in texts:
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    # Window the text so each chunk fits MAX_LEN once special tokens are added.
    chunks = [tokenizer.build_inputs_with_special_tokens(chunk)
              for chunk in get_chunks(token_ids, MAX_LEN - 2)]

    # Chunks of a text are stored contiguously, starting at the current end of X.
    slices.append(slice(len(X), len(X) + len(chunks)))
    X.extend(chunks)

ds = TokenClassificationDataset(X, None)
dl = ds.get_data_loader(64, MAX_LEN, tokenizer.pad_token_id, drop_last=False, use_length_sampler=True)

Y = []
X = []
pb = tqdm_notebook(dl, total=len(dl))
# The sampler reorders the chunks; argsort of its index order is used below
# to put the predictions back into the original chunk order.
backsort_inds = np.argsort(dl.sampler.inds)

# Inference must run in eval mode, otherwise dropout stays active and the
# predictions are noisy and non-deterministic.
model.eval()
with torch.no_grad():
    for X_ in pb:
        X_ = X_[0].to(DEVICE)
        Y_ = model(X_)[0].detach().cpu().numpy()
        X_ = X_.detach().cpu().numpy()
        Y.extend(Y_)
        X.extend(X_)

Y = [Y[i] for i in backsort_inds]
X = [X[i] for i in backsort_inds]

# Stitch the chunks of every text back together: logits are stacked along the
# token axis, token ids are concatenated. Rows keep their batch padding.
Y_merged = []
X_merged = []
for slice_ in slices:
    Y_ = np.vstack(Y[slice_])
    X_ = np.hstack(X[slice_])
    Y_merged.append(Y_)
    X_merged.append(X_)

W0118 16:28:45.702174 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1000 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:45.737090 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (2325 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:45.805295 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (7181 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:45.819558 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1407 > 512). Running this sequence through the model will result in indexing errors
W011

W0118 16:28:46.405602 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (705 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:46.428546 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (542 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:46.438917 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1175 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:46.445019 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (621 > 512). Running this sequence through the model will result in indexing errors
W0118 1

W0118 16:28:46.962106 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1173 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:46.980705 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (907 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:47.006146 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1345 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:47.011683 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (639 > 512). Running this sequence through the model will result in indexing errors
W0118 

W0118 16:28:47.516906 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (960 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:47.522979 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (536 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:47.537938 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1211 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:47.549564 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1119 > 512). Running this sequence through the model will result in indexing errors
W0118 

W0118 16:28:48.024177 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (993 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.031269 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (635 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.051512 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (761 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.060549 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (888 > 512). Running this sequence through the model will result in indexing errors
W0118 16

W0118 16:28:48.509440 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (944 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.533666 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1537 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.548558 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1487 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:48.557250 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (687 > 512). Running this sequence through the model will result in indexing errors
W0118 

W0118 16:28:49.025918 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1240 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.031396 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (541 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.053860 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (2306 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.066177 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (935 > 512). Running this sequence through the model will result in indexing errors
W0118 

W0118 16:28:49.585990 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (769 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.593049 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (673 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.611292 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (673 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:49.628307 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1540 > 512). Running this sequence through the model will result in indexing errors
W0118 1

W0118 16:28:50.096844 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (617 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.108105 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (949 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.142491 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (3098 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.151402 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (817 > 512). Running this sequence through the model will result in indexing errors
W0118 1

W0118 16:28:50.760030 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (652 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.769803 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (990 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.796392 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (3144 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:50.805233 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (833 > 512). Running this sequence through the model will result in indexing errors
W0118 1

W0118 16:28:51.286988 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (704 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.296529 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (862 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.307572 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (607 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.330126 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (970 > 512). Running this sequence through the model will result in indexing errors
W0118 16

W0118 16:28:51.877121 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (786 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.885451 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (835 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.894610 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (1041 > 512). Running this sequence through the model will result in indexing errors
W0118 16:28:51.929615 140260791818048 tokenization_utils.py:1084] Token indices sequence length is longer than the specified maximum sequence length for this model (3041 > 512). Running this sequence through the model will result in indexing errors
W0118 

HBox(children=(IntProgress(value=0, max=28), HTML(value='')))




In [523]:
def decode_code_spans(token_ids, logits):
    """Decode every run of tokens predicted as label 1 back into a string.

    Args:
        token_ids: 1-D array of token ids for one text.
        logits: 2-D array of per-token scores, one row per token; the
            predicted label of a token is the argmax of its row.

    Returns:
        List of decoded strings, one per maximal run of label-1 tokens.

    Raises:
        ValueError: if run starts and ends cannot be paired up.
    """
    if len(logits) == 0:
        return []

    y = logits.argmax(1)
    # Force the first and last positions to label 0 so that every run that
    # starts also ends inside the sequence.
    y[0] = 0
    y[-1] = 0

    diffs = np.diff(y)
    starts = list(np.where(diffs == 1)[0] + 1)
    ends = list(np.where(diffs == -1)[0] + 1)
    if len(starts) - len(ends) == 1:
        ends.append(len(diffs))
    elif len(starts) != len(ends):
        raise ValueError('Unhandled case')

    return [
        tokenizer.decode(token_ids[start:end], clean_up_tokenization_spaces=False)
        for start, end in zip(starts, ends)
    ]


def replace_code_spans(text, decoded, token=CODE_TOKEN):
    """Replace every occurrence of the decoded spans in ``text`` with ``token``.

    Matching ignores spaces: both the text and the decoded spans are compared
    with all ' ' characters removed, and matches are mapped back to character
    positions of the original text. Adjacent or space-separated placeholders
    are collapsed into a single ``' token '`` and the result is stripped.

    Args:
        text: original text.
        decoded: strings returned by ``decode_code_spans`` for this text.
        token: placeholder inserted instead of each matched region.

    Returns:
        The text with matched regions replaced, or ``text`` unchanged when
        nothing matched.
    """
    if len(decoded) == 0:
        return text

    spaceless_text = text.replace(' ', '')
    # space_inds[k] is the index in `text` of the k-th non-space character.
    space_inds = [i for i, ch in enumerate(text) if ch != ' ']
    rep_mask = np.zeros(len(text), dtype=bool)

    for d in decoded:
        d = d.replace(' ', '')
        if not d:
            # An empty pattern matches at every position and would either
            # mask the whole text or index past the end of space_inds.
            continue
        for m in re.finditer(re.escape(d), spaceless_text):
            first, last = m.span()
            rep_mask[space_inds[first]:space_inds[last - 1] + 1] = True

    # Testing the mask itself (not the sum of the matched indices) so that a
    # match covering only index 0 is not mistaken for "no match".
    if not rep_mask.any():
        return text

    # Maximal runs of masked characters as [start, stop) pairs.
    edges = np.diff(np.concatenate(([0], rep_mask.astype(np.int8), [0])))
    run_starts = np.where(edges == 1)[0]
    run_stops = np.where(edges == -1)[0]

    pieces = []
    cursor = 0
    for start, stop in zip(run_starts, run_stops):
        pieces.append(text[cursor:start])
        pieces.append(token)
        cursor = stop
    pieces.append(text[cursor:])
    final_text = ''.join(pieces)

    # Collapse runs of placeholders (optionally separated by spaces) into one.
    # The result must be assigned back - re.sub does not modify its input.
    pat = f'( *{re.escape(token)} *)+'
    return re.sub(pat, lambda _: f' {token} ', final_text).strip()


# Decoded code spans per text, then the texts with those spans replaced.
all_decoded = [decode_code_spans(x, y) for x, y in zip(X_merged, Y_merged)]
final_texts = [replace_code_spans(text, decoded) for decoded, text in zip(all_decoded, texts)]