In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os
import numpy as np
from tqdm import tqdm
import llist
from utils import *
from hydra import compose, initialize
from omegaconf import OmegaConf
import hydra
import torch
from tqdm import tqdm
import pickle
from utils import (
    init_wandb, 
    set_deterministic, 
    get_dataloaders, 
    get_dataset, 
    get_device, 
    init_tokenizers,
    init_model
)
from omegaconf import OmegaConf
from accelerate import Accelerator
import hashlib

In [None]:
# os.chdir('../notebooks')
# Make ../src the working directory; relative paths used in later cells (e.g. '../src/data', '../images/...') resolve from here.
os.chdir('../src') # assuming that `jupyter notebook` is running in `notebooks/` folder
!pwd

In [None]:
# `plt` is first needed here; importing it in this cell keeps the notebook
# runnable top-to-bottom on a fresh kernel.
import matplotlib.pyplot as plt

# Render all figure text through LaTeX (needs a working LaTeX installation).
plt.rcParams.update({"text.usetex": True})

In [None]:
# Clear any previously initialised global Hydra instance so this cell can be re-run,
# then initialise Hydra with the project's config directory.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../src/conf", job_name="test_app")

In [None]:
# Compose the configuration named "main"; `cfg` is used by most cells below.
cfg = compose(config_name="main")

In [None]:
# Show the resolved configuration, seed the RNGs, then load the dataset and tokenizers.
print(f"Hydra configuration:\n{OmegaConf.to_yaml(cfg)}")
set_deterministic(cfg.seed)
dataset = get_dataset(cfg)
# Later cells refer to `tokenizer_l1` and `tokenizer_l2`, so both names must be
# spelled exactly like this.
tokenizer_l1, tokenizer_l2 = init_tokenizers(cfg, dataset)

### scheduler

In [None]:
# Build the model via the project helper; the last argument 'cpu' is presumably the target device — confirm against utils.init_model.
model = init_model(cfg,tokenizer_l1, tokenizer_l2, 'cpu')

In [None]:
# Instantiate the configured optimizer over the model parameters, passing lr=0.3 explicitly for this experiment.
optimizer = hydra.utils.instantiate(cfg.optimizer,model.parameters(),lr=0.3)

In [None]:
# Instantiate the configured learning-rate scheduler around the optimizer created above.
scheduler = hydra.utils.instantiate(cfg.scheduler,optimizer)

In [None]:
# Record the learning rate the optimizer uses at every step of a 10000-step dry run.
# `get_last_lr()` is the accessor meant for reading the current rate;
# `get_lr()` is intended to be called from inside `scheduler.step()`.
lrs = []
for i in range(10000):
    lrs.extend(scheduler.get_last_lr())
    # No gradients are computed here; the step is only issued so the
    # optimizer/scheduler call order matches a real training loop.
    optimizer.step()
    scheduler.step()

In [None]:
# Display the '_step_count' entry of the scheduler's state dict after the dry run.
scheduler.state_dict()['_step_count']

In [None]:
# Plot the learning rates recorded in `lrs` (first 10000 entries) against their index.
fig, ax = plt.subplots(figsize=(4,3),dpi=200)

ax.plot(np.array(lrs[:10000]))
ax.set_xlabel('step')
ax.set_ylabel('lr')
# fig.savefig('../images/learning_rate.svg', transparent=True, pad_inches=0, bbox_inches='tight')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def f(x, d_model=512, warm_up_steps=4000):
    """Warm-up / inverse-square-root learning-rate schedule.

    Returns ``d_model**-0.5 * min(x**-0.5, x * warm_up_steps**-1.5)``
    evaluated element-wise.

    Args:
        x: Step number(s); scalar or NumPy array.
        d_model: Scale parameter of the schedule.
        warm_up_steps: Step at which the two branches of the minimum meet.

    Returns:
        The schedule value(s), with the same shape as ``x``.
    """
    decay_branch = np.power(x, -0.5)
    warmup_branch = x * np.power(warm_up_steps, -1.5)
    scale = np.power(d_model, -0.5)
    return scale * np.minimum(decay_branch, warmup_branch)
    
# Evaluate the schedule `f` with its default parameters on steps 1..8000 and plot it.
xs = np.linspace(1,8000, 8000)
ys = f(xs)

plt.plot(xs,ys)
plt.xlabel("step")
plt.ylabel("lr")

### merging lists of strings benchmark

In [None]:
import time
import itertools
import random
import string

n_iterations = 1000  # Timed repetitions per strategy
list_size = 1000  # Each list will have 1000 elements
n_lists = 10  # Number of lists to merge
string_size = 100  # Characters per generated string

# Generate the lists (all three strategies merge the same inner lists)
alphabet = string.ascii_uppercase + string.digits
to_merge_1 = [[''.join(random.choices(alphabet, k=string_size)) for _ in range(list_size)] for _ in range(n_lists)]
to_merge_2 = list(to_merge_1)
to_merge_3 = list(to_merge_1)


def _merge_extend(lists):
    """Strategy 1: grow one list with list.extend."""
    merged = []
    for lst in lists:
        merged.extend(lst)
    return merged


def _merge_iadd(lists):
    """Strategy 2: grow one list with the += operator."""
    merged = []
    for lst in lists:
        merged += lst
    return merged


def _merge_chain(lists):
    """Strategy 3: flatten in one go with itertools.chain.from_iterable."""
    return list(itertools.chain.from_iterable(lists))


def _time_merge(merge_fn, lists, repeats):
    """Return (merged list, mean seconds per call) for `merge_fn` over `repeats` timed calls."""
    merged = merge_fn(lists)  # untimed warm-up; also provides the result
    start = time.perf_counter()
    for _ in range(repeats):
        merge_fn(lists)
    return merged, (time.perf_counter() - start) / repeats


# time_N is the mean wall-clock time of a single merge, averaged over
# n_iterations runs, so the three numbers are directly comparable.
big_list_1, time_1 = _time_merge(_merge_extend, to_merge_1, n_iterations)
big_list_2, time_2 = _time_merge(_merge_iadd, to_merge_2, n_iterations)
big_list_3, time_3 = _time_merge(_merge_chain, to_merge_3, n_iterations)

time_1, time_2, time_3

### merging pairs of elements in a list

In [None]:
def f1(el1, el2, new_tok, l=None):
    """Replace non-overlapping adjacent (el1, el2) pairs in `l` with `new_tok`.

    Scans left to right building a new list; a matched pair is consumed
    as a whole, so its second element cannot start another match.

    Args:
        el1: First element of the pair to merge.
        el2: Second element of the pair to merge.
        new_tok: Value written in place of each matched pair.
        l: Input sequence (indexable, supports len()).

    Returns:
        A new list with the pairs merged; `l` is not modified.
    """
    merged = []
    size = len(l)
    pos = 0
    while pos < size - 1:
        is_pair = l[pos] == el1 and l[pos + 1] == el2
        merged.append(new_tok if is_pair else l[pos])
        pos += 2 if is_pair else 1
    # The final element is still pending unless it was consumed by a merge.
    if pos < size:
        merged.append(l[-1])
    return merged

def f2(el1, el2, new_tok,ll=None):
    """Merge adjacent (el1, el2) pairs of a linked list in place.

    The list is walked through its node links instead of an iterator, so
    the result does not depend on how `iternodes()` behaves when nodes are
    removed during iteration.

    Args:
        el1: First element of the pair to merge.
        el2: Second element of the pair to merge.
        new_tok: Value stored in the first node of each matched pair.
        ll: An `llist.dllist`; it is modified in place.

    Returns:
        The same list object `ll`.
    """
    node = ll.first
    while node is not None and node.next is not None:
        if node.value == el1 and node.next.value == el2:
            node.value = new_tok
            ll.remove(node.next)
        # After a merge this skips past the merged node, so pairs never overlap.
        node = node.next
    return ll

def f3(el1, el2, new_tok,l=None):
    """Replace non-overlapping adjacent (el1, el2) pairs in `l` with `new_tok`.

    Iterates over consecutive pairs of `l`. The element consumed as the
    second half of a match is skipped *before* any new match is tested, so
    overlapping candidates (possible when el1 == el2) are merged only once.

    Args:
        el1: First element of the pair to merge.
        el2: Second element of the pair to merge.
        new_tok: Value written in place of each matched pair.
        l: Input sequence (sliceable, supports len()); None is treated as empty.

    Returns:
        A new list with the pairs merged; `l` is not modified.
    """
    if l is None or len(l) == 0:
        return []
    new_l = []
    skip = False
    for t1, t2 in zip(l, l[1:]):
        if skip:
            # t1 was already consumed as the second half of the previous match.
            skip = False
        elif t1 == el1 and t2 == el2:
            new_l.append(new_tok)
            skip = True
        else:
            new_l.append(t1)
    # The last element never appears as t1; emit it unless a merge consumed it.
    if not skip:
        new_l.append(l[-1])
    return new_l

In [None]:
# Randomised equivalence check: f1, f2 and f3 must agree on every input.
for trial in range(10000):
    l = random.choices(range(1,3),k=100)
    l1 = f1(1,2,3,list(l))
    l2 = f2(1,2,3,llist.dllist(list(l)))
    l3 = f3(1,2,3,list(l))
    # Copy the linked list once so elements can be compared without
    # repeatedly indexing into it.
    l2_items = list(l2)
    # Compare lengths first, outside any element loop, so a mismatch is
    # reported even when one of the results is empty.
    if not (len(l1) == len(l2_items) == len(l3)):
        problem = 'len mismatch'
    else:
        problem = next(
            (f"el mismatch at {i}"
             for i, (a, b, c) in enumerate(zip(l1, l2_items, l3))
             if a != b or a != c),
            None,
        )
    if problem is not None:
        print(problem)
        print(l)
        print(l1)
        print(l2)
        print(l3)

In [None]:
# Time the three pair-merging implementations on random inputs of length n.
# Note that each timed expression also includes generating the random input
# (and, for f2, building the dllist).
import random
n=100000
%timeit f1(1,2,3,random.choices(range(1,3),k=n))
%timeit f2(1,2,3,llist.dllist(random.choices(range(1,3),k=n)))
%timeit f3(1,2,3,random.choices(range(1,3),k=n))

# Data

In [None]:
# Language pair in "<l1>-<l2>" form; l1 and l2 are the two language codes.
# NOTE: this rebinds l1/l2, which earlier cells used for lists.
lang = 'cs-en'
l1, l2 = lang.split('-')
l1, l2

In [None]:
from datasets import load_dataset

# Load the "wmt14" dataset for the chosen language pair, caching under ../src/data.
# NOTE: this rebinds `dataset`, replacing the object returned by get_dataset(cfg) earlier.
dataset = load_dataset("wmt14", lang,cache_dir='../src/data')

In [None]:
# Display the loaded dataset object.
dataset

In [None]:
# Token count of every l1-language training sentence, as encoded by tokenizer_l1.
lengths = []
for e in dataset['train']['translation']:
    # Sentences of 10000 characters or more are skipped and never tokenized.
    if len(e[l1]) < 10000:
        lengths.append(len(tokenizer_l1.encode(e[l1])))

In [None]:
# Histogram of the token lengths with a logarithmic count axis; the figure is also saved as SVG.
fig, ax = plt.subplots(figsize=(4,3),dpi=200)

ax.hist(lengths, bins=100)
ax.set_yscale('log')
ax.set_ylabel('counts')
ax.set_xlabel('length (token)')
# NOTE(review): the vocab size in the title is a hard-coded literal, not read from the tokenizer or cfg.
ax.set_title(f'sentence lengths (lang={l1}, vocab_size={32000})')
fig.savefig('../images/sentence_lengths_hist.svg', transparent=True, pad_inches=0, bbox_inches='tight')

In [None]:
# Horizontal box plot of the token lengths, with the frame and y ticks removed.
fig, ax = plt.subplots(figsize=(10,1))
ax.boxplot(lengths, vert=False,widths=1)
ax.spines[:].set_visible(False)
ax.set_yticks([])

In [None]:
# Print selected quantiles of the token lengths; the cell's displayed value is their mean.
arr = np.array(lengths)
# qs = 1 - np.logspace(-4,0,10)
qs = [0.25,0.50,0.75,0.99,0.999,0.9999]
for q in qs:
    print(f"quantile {q:.4f}: {np.quantile(arr,q)}")
arr.mean()

In [None]:
# Fraction of entries in `arr` that are strictly smaller than 64.
(arr < 64).mean()

## hashing

In [None]:
import pickle
import hashlib


# NOTE(review): `_hash` is not defined in any visible cell, and `from utils import *`
# does not bring in underscore-prefixed names unless utils.__all__ lists them —
# confirm where this function comes from; the meaning of the arguments is not visible here.
_hash(dataset,1000,0.01,42)

In [None]:
# UTF-8 byte length of every sentence (all languages, all partitions).
# NOTE: this rebinds `lengths` and `arr` from the token-length analysis above.
lengths = []
for partition in dataset.values():
    for example in tqdm(partition):
        for l_example in example['translation'].values():
            s = l_example.encode('utf-8')
            lengths.append(len(s))
arr = np.array(lengths)
# Displayed value: mean, median, 90th and 99th percentile of the byte lengths.
arr.mean(), np.quantile(arr,0.50), np.quantile(arr,0.90), np.quantile(arr,0.99)

## Tokenization

### character

In [None]:
# Collect every distinct character that occurs in any translation string of any partition,
# then turn the set into a sorted list.
chars = set()
for i in dataset:
    for r in dataset[i]:
        for t in r['translation']:
            chars.update(r['translation'][t])
chars = sorted(chars)

In [None]:
# Character-level vocabulary: stoi maps character -> id, itos maps id -> character.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(x):
    """Return the list of ids for the characters of `x`; unknown characters get the id of "a"."""
    return [stoi.get(c, stoi["a"]) for c in x]


def decode(x):
    """Return the string formed by the characters for the ids in `x`."""
    return ''.join(itos[c] for c in x)

In [None]:
import pandas as pd

In [None]:
# Show the first 100 vocabulary entries as five transposed 20-entry tables (ids as columns).
for i in range(5):
    display(pd.DataFrame(itos.values(),itos.keys()).iloc[i*20:(i+1)*20].T)
print("...")

### bpe

In [None]:
# Take the first training pair, join both language versions with a newline,
# and print the text as tokenized by each tokenizer (colorize_tokens comes from utils).
example = dataset['train']['translation'][0]
example = example[cfg.data.l1] + "\n" + example[cfg.data.l2]
print(f"{tokenizer_l1.lang} tok:")
print(colorize_tokens(tokenizer_l1.encode(example), tokenizer_l1))
print(f"{tokenizer_l2.lang} tok:")
print(colorize_tokens(tokenizer_l2.encode(example), tokenizer_l2))

In [None]:
from utils import PersistentRandom

In [None]:
# Create the project's PersistentRandom helper (argument presumably a seed — confirm in utils)
# and draw one value; `pr` is also read as a global by init() below.
pr = PersistentRandom(5)
pr.rand()

In [None]:
def init(dataset, fraction=0.01):
    """Build the initial byte-level corpus for BPE training.

    Every example of every partition is visited; an example is kept when
    the global `pr.rand()` draw is below `fraction`. For each kept example
    the UTF-8 bytes of all its translations are appended to the corpus,
    each followed by a single space byte.

    The result is returned only after *all* partitions were processed
    (returning from inside the partition loop would silently ignore every
    partition after the first).

    Args:
        dataset: Mapping of partition name -> iterable of examples, each
            with a 'translation' dict of strings.
        fraction: Threshold compared against `pr.rand()` to sample examples.

    Returns:
        (l, t): `l` is the list of byte values of the sampled text, `t` the
        set of distinct byte values seen.
    """
    l = list()
    t = set()
    space = ' '.encode('utf-8')
    for key, partition in dataset.items():
        progress = tqdm(partition, total=len(partition), desc=f"bpe: consolidating {key} partition")
        for example in progress:
            if pr.rand() < fraction:
                for l_example in example['translation'].values():
                    s = l_example.encode('utf-8')
                    t.update(s)
                    # Iterating bytes yields ints, so this appends byte values.
                    l.extend(s)
                    l.extend(space)
    return l,t
l,t = init(dataset)

In [None]:
# Number of distinct byte values and total corpus length.
len(t), len(l)

In [None]:
from collections import Counter
def get_pair_counts(token_list):
    """Count how often each pair of adjacent tokens occurs in `token_list`.

    Args:
        token_list: Sliceable sequence of hashable tokens.

    Returns:
        A Counter keyed by (left, right) token tuples.
    """
    adjacent_pairs = zip(token_list, token_list[1:])
    return Counter(adjacent_pairs)

In [None]:
def merge(lst, pair, new_elem):
    """Replace non-overlapping adjacent occurrences of `pair` in `lst` with `new_elem`.

    Args:
        lst: Input sequence (indexable, supports len()).
        pair: Two-element tuple (first, second) to look for.
        new_elem: Value written in place of each matched pair.

    Returns:
        A new list with the pairs merged, scanning left to right; `lst` is
        not modified.
    """
    first, second = pair
    size = len(lst)
    merged = []
    pos = 0
    while pos < size - 1:
        matched = lst[pos] == first and lst[pos + 1] == second
        merged.append(new_elem if matched else lst[pos])
        pos += 2 if matched else 1
    # One trailing element remains unless the last pair was merged.
    if pos < size:
        merged.append(lst[pos])
    return merged

In [None]:
# Naive BPE training: repeatedly merge the most frequent adjacent pair of the corpus `l`
# into a fresh token id, starting at 256, for (max_tokens - len(t)) rounds.
# NOTE: the cell rebinds `l` on every round, so re-running it continues from the
# already-merged corpus rather than starting over.
max_tokens = 400
new_tok = 256
# Every byte value seen in the corpus starts out as its own single-byte token.
token_dict = {key: bytes([key]) for key in t}
merges = dict()
for i in tqdm(range(max_tokens - len(t))):
    counts = get_pair_counts(l)
    # most_common(1)[0] raises IndexError if the corpus has fewer than two tokens left.
    to_merge = counts.most_common(1)[0][0]
    l = merge(l, to_merge, new_tok)
    # The new token's bytes are the concatenation of the bytes of its two parts.
    token_dict[new_tok] = token_dict[to_merge[0]] + token_dict[to_merge[1]]
    # merges maps pair -> new id; ids grow in merge order, which encode() relies on.
    merges[to_merge] = new_tok
    new_tok += 1



In [None]:
# Vocabulary size and corpus length after the merges.
len(token_dict), len(l)

In [None]:
# Assign ids 252-255 to the special tokens; encode()/decode() below hard-code the same ids.
token_dict[252] = b'<unk>'
token_dict[253] = b'<s/>'
token_dict[254] = b'<pad>'
token_dict[255] = b'<s>'

In [None]:
# Display the full id -> bytes vocabulary.
token_dict

In [None]:
def encode(x, token_dict,truncation=True, add_special_tokens=True, padding="max_length", max_length=20):
    """Encode a string into BPE token ids using the global `merges` table.

    The text is split into UTF-8 bytes and the learned merges are applied
    greedily, always choosing the pair with the lowest merge id (i.e. the
    earliest learned merge) until no mergeable pair is left.

    Args:
        x: Text to encode.
        token_dict: Vocabulary; accepted for interface symmetry with
            decode() but not consulted (the merges come from the global
            `merges`).
        truncation: If true and `max_length` is set, cut the sequence so
            that it fits `max_length` including special tokens.
        add_special_tokens: Wrap the sequence in id 255 (start) and id 253 (end).
        padding: "max_length" pads with id 254 up to `max_length`.
        max_length: Target length, or None to disable truncation/padding.

    Returns:
        List of token ids.
    """
    encoded = list(x.encode('utf-8'))
    while len(encoded) > 1:
        pairs = get_pair_counts(encoded)
        to_merge = min(pairs, key=lambda k: merges.get(k, float('inf')))
        if to_merge not in merges:
            break
        encoded = merge(encoded, to_merge, merges[to_merge])
    if truncation and max_length is not None:
        n_special = 2 if add_special_tokens else 0
        # Clamp at 0: a negative slice end would keep almost the whole
        # sequence instead of truncating it.
        encoded = encoded[: max(0, max_length - n_special)]
    if add_special_tokens:
        encoded = [255] + encoded + [253]
    if padding == "max_length" and max_length is not None:
        # A non-positive multiplier yields an empty list, so no padding is added then.
        encoded += [254] * (max_length - len(encoded))
    return encoded

def decode(x, vocab):
    """Decode one or several token-id sequences back into text.

    Args:
        x: A torch.Tensor, a flat list of ints (one sequence), or a list of
            sequences.
        vocab: Mapping of token id -> bytes.

    Returns:
        A list with one decoded string per sequence. Special-token ids
        (252-255) are dropped, ids missing from `vocab` become the UTF-8
        replacement character, and invalid byte sequences are decoded with
        errors='replace'.
    """
    special_token_ids = (252, 253, 254, 255)
    unknown = b'\xef\xbf\xbd'
    if isinstance(x, torch.Tensor):
        x = x.tolist()
    # A flat (or empty) list is a single sequence; wrap it into a batch of one.
    is_single = isinstance(x, list) and (len(x) == 0 or isinstance(x[0], int))
    sequences = [x] if is_single else x
    decoded_text = []
    for seq in sequences:
        pieces = [vocab.get(tok, unknown) for tok in seq if tok not in special_token_ids]
        decoded_text.append(b"".join(pieces).decode('utf-8', errors='replace'))
    return decoded_text

# Round-trip demo: print the ids of a sample sentence, then decode them with one extra id (127) appended.
print(encode("Hej kámo, čo ti práši?",token_dict))
decode(encode("Hej kámo, čo ti práši?",token_dict)+[127],token_dict)

In [None]:
# Encode a single space without start/end tokens (padding to max_length still applies).
encode(' ',token_dict,add_special_tokens=False)

In [None]:
# UTF-8 bytes of the replacement character, the fallback that decode() uses for unknown ids.
'�'.encode('utf-8')

In [None]:
# List of byte values that can never appear anywhere in valid UTF-8.
# A byte counts as usable if it decodes in at least one of the contexts below.
contexts = [
    ((), ()),                    # on its own (ASCII)
    ((), (0x80,)),               # lead byte of a 2-byte sequence
    ((), (0x80, 0x80)),          # lead byte of a 3-byte sequence
    ((), (0xA0, 0x80)),          # 0xE0 only accepts a second byte >= 0xA0
    ((), (0x80, 0x80, 0x80)),    # lead byte of a 4-byte sequence
    ((), (0x90, 0x80, 0x80)),    # 0xF0 only accepts a second byte >= 0x90
    ((0xC2,), ()),               # continuation byte after a valid 2-byte lead
]
invalid_bytes = []
for i in range(256):
    for prefix, suffix in contexts:
        try:
            bytes([*prefix, i, *suffix]).decode('utf-8')
        except UnicodeDecodeError:
            # Only decoding failures count; any other exception is a bug
            # in this cell and must surface.
            continue
        break  # decoded successfully in some context -> the byte is usable
    else:
        invalid_bytes.append(i)
np.array(invalid_bytes)

# Model

In [None]:
# Instantiate model, optimizer and criterion from the Hydra config and print each of them.
# NOTE(review): `device` is not assigned in any visible earlier cell, and `tokenizer` is first
# assigned in the later "pretrained" section — on a fresh kernel this cell raises NameError.
model = hydra.utils.instantiate(cfg.model, device=device, 
                                pad_token_id=tokenizer.pad_token_id,
                                bos_token_id=tokenizer.bos_token_id,
                                eos_token_id=tokenizer.eos_token_id,)
model.to(device)
print(f"Model:\n{model}")
optimizer = hydra.utils.instantiate(cfg.optimizer, model.parameters())
print(f"Optimizer:\n{optimizer}")
criterion = hydra.utils.instantiate(cfg.criterion)
print(f"Criterion:\n{criterion}")

In [None]:
# One manual training step on the first batch, as a smoke test.
# NOTE(review): `train_loader` is only created in a later cell (get_dataloaders).
optimizer.zero_grad()
inputs, targets = next(iter(train_loader))
# Swap the first two axes of both tensors before moving them to the device
# (presumably batch-first -> sequence-first; confirm against the model).
inputs = inputs.transpose(0,1).to(device)
targets = targets.transpose(0,1).to(device)

# The model sees the targets without their last position ...
outputs = model(inputs, targets[:-1, :])
# Number of NaN entries in the output, as a sanity check.
print(outputs.isnan().sum().item())
print(outputs.shape)
print(targets.shape)
# ... and the loss compares against the targets without their first position.
loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets[1:,:].reshape(-1))
loss.backward()
optimizer.step()
loss.item()

## pretrained

In [None]:
import torch
from transformers import XLMRobertaTokenizer, XLMRobertaModel

In [None]:
# Load the pretrained XLM-RoBERTa base tokenizer and model and run one forward pass on "<mask>".
# NOTE: this rebinds `model` (and defines `tokenizer`) for all following cells.
tokenizer = XLMRobertaTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
model = XLMRobertaModel.from_pretrained("FacebookAI/xlm-roberta-base")
input_ids = torch.tensor(tokenizer.encode("<mask>")).unsqueeze(0)  # Batch size 1
outputs = model(input_ids)

In [None]:
# Decode a hand-written id sequence with special tokens skipped.
tokenizer.decode([0,1,1,1,1,200,200,2],skip_special_tokens=True)

In [None]:
# Show the ids the tokenizer produces for the literal string "<pad>".
tokenizer.encode("<pad>")

In [None]:
# Build the train/validation/test loaders with the project helper, using the pretrained tokenizer.
train_loader, val_loader, test_loader = get_dataloaders(cfg,tokenizer)

In [None]:
# Print only the first batch yielded by the training loader.
for i in train_loader:
    print(i)
    break

In [None]:
# One training epoch over `train_loader`, accumulating the summed batch loss.
epoch_loss = 0
optimizer = hydra.utils.instantiate(cfg.optimizer,params=model.parameters())
# Same criterion construction as in the "Model" section above.
criterion = hydra.utils.instantiate(cfg.criterion)
model.train()
for batch in train_loader:
    optimizer.zero_grad()
    # Each batch unpacks into (inputs, targets); there is no `.label` attribute
    # on the unpacked pair, so the loss uses `targets` directly.
    inputs, targets = batch
    predictions = model(inputs)

    # NOTE(review): whether `predictions` has the shape/type the criterion expects
    # depends on which model is bound to `model` at this point — confirm.
    loss = criterion(predictions, targets)
    # if regularizer is not None:
    #     loss += regularizer(model)
    # Gradients must be computed before optimizer.step() can change any parameter.
    loss.backward()
    # if grad_clip_threshold is not None:
    #     torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_threshold)
    optimizer.step()
    epoch_loss += loss.item()