# conv3 training

In [None]:
import os
import sys
import random
import warnings
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers import normalizers
from tokenizers import pre_tokenizers
from tokenizers.normalizers import NFD, StripAccents
from tokenizers.pre_tokenizers import ByteLevel, Digits
from tokenizers.trainers import BpeTrainer
import torch.nn as nn

from pathlib import Path

import torch as torch
from IPython.display import SVG, display

from torch.utils.data import DataLoader

### Enable Hot Reload

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before executing code, so edits to the project helpers imported below are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Edit Python path
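Add the project directories (two and three levels above the notebook, plus `../../scripts`) to Python's `sys.path`, then change the working directory to `BASE_DIR`, three levels above the notebook.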

In [None]:
b_paths = [os.path.abspath(os.path.join('..', '..', '..')), os.path.abspath(os.path.join('..', '..')), os.path.abspath(os.path.join('..', '..', 'scripts'))]
for b_path in b_paths:
    if b_path not in sys.path:
        sys.path.append(b_path)

BASE_DIR = Path(os.getcwd()).parent.parent.parent.resolve()
%cd $BASE_DIR

### Ignore Warnings

In [None]:
# Install a catch-all "ignore" filter for every Warning category for the rest
# of the session (presumably to keep notebook output readable; remove this
# line when debugging).
warnings.simplefilter('ignore')

### Import Helpers

In [None]:
from models.scripts.transformer.PreEncoders import Conv2DTransformer
from models.scripts.transformer.utils import strokes_to_svg, preprocess_dataset, seed_all, pad_collate_fn, tensor_to_word
from models.scripts.generate_dataset import WordDatasetGenerator, WordGenerator

### Configuration Settings

In [None]:
# Experiment-wide settings used by the cells below.
VERSION: str = "conv3"    # version tag passed to Conv2DTransformer
SEED: int = 2021          # seed handed to seed_all
BATCH_SIZE: int = 256     # batch size of the train/valid DataLoaders
EXPR_MODE: str = 'all'    # forwarded as WordDatasetGenerator(expr_mode=...)

In [None]:
# Seed the random number generators via the project helper so runs are
# reproducible (which RNGs get seeded is decided inside seed_all, not here).
seed_all(SEED) # Reproducibility

### Create Vocabulary and Dataset

In [None]:
# Saved BPE tokenizer; relative, so it resolves against the working directory
# (BASE_DIR after the path-setup cell).
TOKENIZER_FILE = str(Path("word_sources") / "tokenizer-news-normalized.json")

In [None]:
# Toggle between reusing cached artefacts and rebuilding them.
#   True  -> load the saved BPE tokenizer and the cached stroke dataset.
#   False -> retrain the tokenizer and regenerate the dataset from the raw
#            news-commentary corpus (if regenerated, results could change).
use_cache = True

if not use_cache:
    # Read the raw corpus from BASE_DIR/word_sources.
    news_commentary_path = os.path.join(BASE_DIR, "word_sources", "news-commentary-v14.en")
    words = WordGenerator().generate_from_file(news_commentary_path, words_only=False)

    # BPE tokenizer: NFD + accent stripping, numbers split from other
    # characters, then byte-level pre-tokenisation.
    VOCAB = Tokenizer(BPE())
    VOCAB.normalizer = normalizers.Sequence([NFD(), StripAccents()])
    VOCAB.pre_tokenizer = pre_tokenizers.Sequence([Digits(), ByteLevel()])
    trainer = BpeTrainer(
        vocab_size=2000,
        min_frequency=25,
        show_progress=True,
        special_tokens=['<unk>', '<pad>', '<bos>', '<eos>'],
    )
    VOCAB.train_from_iterator(words, trainer)

    # Written as "<TOKENIZER_FILE>_new" so the existing tokenizer is not
    # overwritten; note the cached branch still loads TOKENIZER_FILE.
    VOCAB.save(f"{TOKENIZER_FILE}_new")

    # The first (1 - BRUSH_SPLIT) share of the words builds the dataset; the
    # remaining tail is registered through add_training_words.
    BRUSH_SPLIT = 0.15
    split_at = int(len(words) * (1 - BRUSH_SPLIT))
    d_gen = WordDatasetGenerator(vocab=VOCAB,
                                 expr_mode=EXPR_MODE,
                                 words=words[:split_at],
                                 extended_dataset=False)
    d_gen.generate()
    d_gen.add_training_words(words[split_at:])
else:
    VOCAB = Tokenizer.from_file(TOKENIZER_FILE)
    d_gen = WordDatasetGenerator(vocab=VOCAB,
                                 expr_mode=EXPR_MODE,
                                 fname="words_stroke_100_155805")

# Both paths finish by loading the train/valid/test splits from d_gen's cache.
train, valid, test = d_gen.generate_from_cache()

### Visualize Tokenizer

In [None]:
# Show the tokenizer's string representation.
print(str(VOCAB))

In [None]:
# List every token string in the vocabulary, alphabetically.
all_tokens = sorted(VOCAB.get_vocab().keys())
print(all_tokens)

In [None]:
# Id assigned to the beginning-of-sequence special token.
bos_id = VOCAB.token_to_id("<bos>")
print(bos_id)

In [None]:
# Vocabulary size; used as the model's output dimension (n_tokens).
N_TOKENS = VOCAB.get_vocab_size()
print(f"Number of Tokens: {N_TOKENS}\n")

### Create Dataset for PyTorch

In [None]:
def _make_loader(split, samples, shuffle):
    """Preprocess one dataset split and wrap it in a batched DataLoader.

    `preprocess_dataset` receives the path "<d_gen.fname>-bpe/<split>.pt"
    (presumably its on-disk cache) and the split length reported by d_gen;
    batches are assembled with `pad_collate_fn`.
    """
    cache_file = os.path.join(d_gen.fname+"-bpe", f"{split}.pt")
    split_len = d_gen.get_learning_set_length(split)
    dataset = preprocess_dataset(samples, VOCAB, cache_file, total_len=split_len)
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=pad_collate_fn)


# Only the training split is shuffled.
train_set = _make_loader("train", train, shuffle=True)
valid_set = _make_loader("valid", valid, shuffle=False)

### Inspect Generated Data

In [None]:
# Get random index
x_dummy, y_dummy = next(iter(valid_set)) # Create dummy for visualization
ind = random.choice(range(y_dummy.shape[0]))
print("Index:", ind)

print()
print("X Shape:", x_dummy[ind].shape)
# Show actual expr for first tensor
print("Y Shape:", y_dummy[ind].shape)
print()
print("Label:", VOCAB.decode(y_dummy[ind].tolist(), False))

# Get length of subplot depending on granularity (exclude bos/eos for strokes)
svg_str = strokes_to_svg(x_dummy[ind], {'height':100, 'width':100}, d_gen.padding_value, VOCAB.token_to_id('<bos>'), VOCAB.token_to_id('<eos>'))
display(SVG(data = svg_str))


print()
print(f'X[{ind}]:', x_dummy[ind])
print()

eos_tensor = torch.zeros(x_dummy[ind].size(-1)) + d_gen.eos_idx


for i, row in enumerate(x_dummy[ind]):
    if torch.all(row.eq(eos_tensor)):
        print("EOS is in position:", i)
        break

### Model Hyper-parameters/Create Transformer Model

In [None]:
def _n_trainable(module):
    """Return the number of parameter elements in *module* that require grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Build the Transformer with a 4-layer convolutional pre-encoder, persist its
# hyper-parameters, and report trainable parameter counts per sub-module.
model = Conv2DTransformer(VERSION, VOCAB, n_tokens=N_TOKENS, encoder_name='bpe2', decoder_name='bpe2', n_conv_layers=4)
model.save_hyperparameters_to_json()
model.count_parameters()
for label, part in (("Convolution", model.preencoder),
                    ("Encoder", model.encoder),
                    ("Decoder", model.decoder)):
    print(f"{label} trainable parameters: {_n_trainable(part):,}.")
print("\n\n\n", model)
# Left as the cell's last expression so the notebook echoes the moved model.
model.to(model.device)

### Training conv layers and caching

In [None]:
# Optimisation setup for the first training stage.
LEARNING_RATE = 5e-4
pad_id = VOCAB.token_to_id('<pad>')
# NOTE(review): `criterion` is never passed to model.train_loop; confirm it is
# still needed.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate every 40 scheduler steps (presumably epochs).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=40, gamma=0.5)

In [None]:
# Stage 1: 120 epochs without resuming from a previous run.
stage1_kwargs = dict(
    resume=False,
    train_set=train_set,
    valid_set=valid_set,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=120,
)
model.train_loop(**stage1_kwargs)

### Plot Training Logs

In [None]:
# Plot the training curves of the first stage (presumably read from
# model.log_path -- see Conv2DTransformer.plot_training).
model.plot_training()

### Fine-Tune the Whole Transformer

In [None]:
# Fine-tune the whole Transformer, starting from the best stage-1 checkpoint
# (presumably what load_best_version restores).
model.load_best_version()
# (Re-)enable gradients for encoder and decoder so all sub-modules train.
model.encoder.requires_grad_(True)
model.decoder.requires_grad_(True)

# Store the fine-tuned run under "_ft" names so stage-1 artefacts are kept.
# os.path.splitext preserves whatever extension the paths already carry; the
# previous [:-3] / [:-4] slicing assumed exactly ".pt" / ".log" and mangled
# any other suffix.
# NOTE(review): re-running this cell appends "_ft" a second time.
model.name += "_ft"
model.save_hyperparameters_to_json()
bm_root, bm_ext = os.path.splitext(model.bm_path)
model.bm_path = bm_root + "_ft" + bm_ext
log_root, log_ext = os.path.splitext(model.log_path)
model.log_path = log_root + "_ft" + log_ext

# Lower learning rate for fine-tuning; halved every 300 scheduler steps.
LEARNING_RATE = 1.5e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=300, gamma=0.5)

In [None]:
# Stage 2: fine-tune every sub-module for up to 4000 epochs, not resuming.
stage2_kwargs = dict(
    resume=False,
    train_set=train_set,
    valid_set=valid_set,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=4000,
)
model.train_loop(**stage2_kwargs)

### Plot Training Logs

In [None]:
# Plot the fine-tuning curves (presumably read from model.log_path, which the
# fine-tuning setup cell gave an "_ft" suffix).
model.plot_training()