# no_unk15 training
Conditionable decoder conditioned on the target language; the UNK language is excluded from training.

In [None]:
import os
import sys
import random
import warnings
from tokenizers import Tokenizer

from pathlib import Path

import torch as torch
from IPython.display import SVG, display

from torch.utils.data import DataLoader

### Enable Hot Reload

In [None]:
# Load IPython's autoreload extension; mode 2 reloads modules before each cell
# executes, so edits to the project's Python files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

### Edit Python path
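Append the repository root (three levels up), the directory two levels up (presumably `models`), and its `scripts` subfolder to `sys.path`, then change the working directory to the repository root.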

In [None]:
b_paths = [os.path.abspath(os.path.join('..', '..', '..')), os.path.abspath(os.path.join('..', '..')), os.path.abspath(os.path.join('..', '..', 'scripts'))]
for b_path in b_paths:
    if b_path not in sys.path:
        sys.path.append(b_path)

BASE_DIR = Path(os.getcwd()).parent.parent.parent.resolve()
%cd $BASE_DIR

### Ignore Warnings

In [None]:
warnings.filterwarnings('ignore')

### Import Helpers

In [None]:
from models.scripts.transformer.MultiLang import ConditionableTransformer
from models.scripts.transformer.utils import strokes_to_svg, seed_all, preprocess_with_lang
from models.scripts.generate_dataset import WordDatasetGenerator, WordGenerator
from models.scripts.defaults import Languages

### Configuration Settings

In [None]:
# Run identifier; passed as the first argument to ConditionableTransformer below.
VERSION: str = "no_unk15"
# Seed handed to seed_all for reproducibility.
SEED: int = 2021
# Samples per batch for both the training and validation DataLoaders.
BATCH_SIZE: int = 256
# Forwarded to WordDatasetGenerator as `expr_mode`; its meaning is defined there.
EXPR_MODE: str = 'all'

In [None]:
# Seed the random number generators through the project helper `seed_all`
# (which RNGs it covers is defined in models.scripts.transformer.utils — presumably
# Python `random` and torch; verify there).
seed_all(SEED) # Reproducibility

### Create Vocabulary and dataset

In [None]:
# Tokenizer definition, relative to the repository root (the working directory after %cd).
TOKENIZER_FILE = str(Path("word_sources") / "tokenizer-big_multi-normalized.json")

In [None]:
use_cache = True  # load the per-language datasets from their cache files instead of regenerating them
VOCAB = Tokenizer.from_file(TOKENIZER_FILE)
# One beginning-of-sequence token per language (e.g. "<bos_en>"); each language's
# id is later passed as `bos` when its data is preprocessed.
bos_tokens = [f'<bos_{lang.name.lower()}>' for lang in Languages]
VOCAB.add_special_tokens(bos_tokens)

### Visualize tokenizer

In [None]:
# Print the tokenizer object's string representation for a quick sanity check.
print(VOCAB)

In [None]:
# List every token string in the vocabulary (added special tokens included), alphabetically.
print(sorted(VOCAB.get_vocab().keys()))

In [None]:
# Show the ids of the padding token and of each language's BOS token;
# token_to_id yields None for a token that is not in the vocabulary.
for special_token in ("<pad>", "<bos_en>", "<bos_de>", "<bos_fr>", "<bos_it>", "<bos_unk>"):
    print(VOCAB.token_to_id(special_token))

In [None]:
# Vocabulary size including added special tokens; used as the model's n_tokens.
N_TOKENS = VOCAB.get_vocab_size()
print(f"Number of Tokens: {N_TOKENS}", end="\n\n")

In [None]:
# Fraction of each language's words held back from the initial stroke generation
# and handed to `add_training_words` instead (only used when regenerating).
BRUSH_SPLIT = 0.15

train_sets = []
valid_sets = []
test_sets = []

for lang in Languages:
    if lang == Languages.UNK:  # this experiment trains without the UNK language
        continue
    lower_name = lang.name.lower()
    cache_name = f"words_stroke_{lower_name}_full"

    if use_cache:  # Generate from cache file
        d_gen = WordDatasetGenerator(vocab=VOCAB,
                                     expr_mode=EXPR_MODE,
                                     fname=cache_name)
    else:  # Generate from scratch and cache (if regenerated, results could change)
        news_commentary_path = os.path.join(BASE_DIR, "word_sources", f"news-commentary-v14.{lower_name}")
        words = WordGenerator().generate_from_file(news_commentary_path, words_only=False)
        split_at = int(len(words) * (1 - BRUSH_SPLIT))  # split point shared by both slices
        d_gen = WordDatasetGenerator(vocab=VOCAB,
                                     expr_mode=EXPR_MODE,
                                     words=words[:split_at],
                                     extended_dataset=False,
                                     fname=cache_name)
        d_gen.generate()
        d_gen.add_training_words(words[split_at:])

    # Both paths end by reading the (train, valid, test) splits back from the cache.
    train, valid, test = d_gen.generate_from_cache()
    train_sets.append(train)
    valid_sets.append(valid)
    test_sets.append(test)

# One entry per non-UNK language, in Languages iteration order (later cells index by position).
assert len(train_sets) == len(valid_sets) == len(test_sets) == len(Languages) - 1, \
    "expected exactly one dataset per non-UNK language"

### Create Dataset for PyTorch

In [None]:
preprocessed_trains = []
preprocessed_valids = []
# Same order as the dataset cell, so position i lines up with train_sets/valid_sets.
active_languages = [l for l in Languages if l != Languages.UNK]
for i, lang in enumerate(active_languages):
    lower_name = lang.name.lower()
    # Rebuilt only for its cache-file stem and split lengths; the last instance
    # stays bound to `d_gen` and is used by the inspection cell.
    d_gen = WordDatasetGenerator(vocab=VOCAB,
                                 expr_mode=EXPR_MODE,
                                 fname=f"words_stroke_{lower_name}_full")
    cache_dir = d_gen.fname + "_lang"
    bos_id = VOCAB.token_to_id(f'<bos_{lower_name}>')  # language-specific start token
    for split, source, sink in (("train", train_sets[i], preprocessed_trains),
                                ("valid", valid_sets[i], preprocessed_valids)):
        sink.extend(preprocess_with_lang(source, VOCAB, os.path.join(cache_dir, f"{split}.pt"),
                                         total_len=d_gen.get_learning_set_length(split),
                                         bos=bos_id))

# Shuffle only the mixed-language training data; validation order stays fixed.
train_set = DataLoader(preprocessed_trains, batch_size=BATCH_SIZE, shuffle=True)
valid_set = DataLoader(preprocessed_valids, batch_size=BATCH_SIZE, shuffle=False)

### Inspect Generated Data

In [None]:
# Get random index
# Take the first validation batch and pick one random sample from it to visualize.
x_dummy, y_dummy = next(iter(valid_set))
print(x_dummy.shape)
ind = random.randrange(y_dummy.shape[0])
print("Index:", ind)

print()
print("X Shape:", x_dummy[ind].shape)
print("Y Shape:", y_dummy[ind].shape)
print()
# Decode the target ids once, keeping special tokens (skip_special_tokens=False).
label = VOCAB.decode(y_dummy[ind].tolist(), False)
print("Label:", label)
# Readable form: drop the spaces between decoded tokens, turn the "Ġ" marker
# (presumably the byte-level word-boundary symbol) back into a space, then remove
# trailing "<pad>" tokens. Note str.rstrip("<pad>") would strip any trailing
# '<', 'p', 'a', 'd', '>' characters instead, e.g. "salad" -> "sal", "<eos>" -> "<eos".
readable_label = label.replace(" ", "").replace("Ġ", " ")
while readable_label.endswith("<pad>"):
    readable_label = readable_label[:-len("<pad>")]
print("Readable Label:", readable_label)

# Render the stroke sequence as an SVG. Padding value and <bos>/<eos> ids are passed
# to strokes_to_svg so it can tell them apart from real points.
# NOTE(review): '<bos>'/'<eos>' are looked up by name; token_to_id returns None if the
# tokenizer lacks them — confirm they exist alongside the per-language <bos_xx> tokens.
svg_str = strokes_to_svg(x_dummy[ind], {'height': 100, 'width': 100}, d_gen.padding_value,
                         VOCAB.token_to_id('<bos>'), VOCAB.token_to_id('<eos>'))
display(SVG(data=svg_str))

print()
print(f'X[{ind}]:', x_dummy[ind])
print()

# Report the first row of the sample whose every entry equals the EOS index.
eos_tensor = torch.zeros(x_dummy[ind].size(-1)) + d_gen.eos_idx
for i, row in enumerate(x_dummy[ind]):
    if torch.all(row.eq(eos_tensor)):
        print("EOS is in position:", i)
        break

### Model Hyper-parameters/Create Transformer Model

In [None]:
def _count_trainable(module):
    """Return the number of parameter elements in `module` that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


model = ConditionableTransformer(VERSION, VOCAB, conv_layer_name='en-en11', encoder_name='en-en11', n_tokens=N_TOKENS, hid_dim=256, dec_heads=8, dec_layers=6, dec_pf_dim=256*3)
model.save_hyperparameters_to_json()
model.count_parameters()
# Freeze everything, then unfreeze only the decoder and the encoder-to-decoder
# projection; the pre-encoder and encoder (presumably reused from the 'en-en11'
# model) stay fixed.
model.requires_grad_(False)
model.decoder.requires_grad_(True)
model.enc_to_dec_proj.requires_grad_(True)
print(f"Convolution trainable parameters: {_count_trainable(model.preencoder):,}.")
print(f"Encoder trainable parameters: {_count_trainable(model.encoder):,}.")
print(f"Decoder trainable parameters: {_count_trainable(model.decoder):,}.")
# enc_to_dec_proj is trainable as well, so also report the model-wide total.
print(f"Total trainable parameters: {_count_trainable(model):,}.")
print("\n\n\n", model)
model.to(model.device)

### Train on the four non-UNK languages (EN, DE, FR, IT)

In [None]:
LEARNING_RATE = 7e-4
# Adam over all parameters; the frozen ones presumably never get a .grad, and Adam
# skips parameters whose gradient is None.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# StepLR multiplies the learning rate by gamma=0.5 every 300 calls to scheduler.step()
# (how often train_loop steps the scheduler is not visible here).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=300, gamma=0.5)

In [None]:
# Train on the mixed-language loaders; resume=False presumably starts from scratch
# rather than from a saved checkpoint.
N_EPOCHS = 4000
model.train_loop(
    resume=False,
    train_set=train_set,
    valid_set=valid_set,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=N_EPOCHS,
)

### Plot Training Logs

In [None]:
# Plot the logs recorded during train_loop (contents defined by ConditionableTransformer.plot_training).
model.plot_training()