# Training and testing Scaleformer

In [None]:
import os
import random
import warnings

import torch

try:
    from scaleformer import BytePairEncoder
    from scaleformer import Transformer
    from scaleformer import strings_to_tensor
    from scaleformer import train_loop
    from scaleformer import plot_loss
except ImportError:
    import sys
    sys.path.insert(0, '..')
    from scaleformer import BytePairEncoder
    from scaleformer import Transformer
    from scaleformer import strings_to_tensor
    from scaleformer import train_loop
    from scaleformer import plot_loss

# Seed Python's RNG (used by the train/validation shuffle via random.shuffle)
# and PyTorch's RNG (presumably used for model weight initialisation) so runs
# are reproducible.
random.seed(42)
torch.manual_seed(42)
# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
# Report the visible NVIDIA GPUs (IPython `!` shell escape); only attempted
# when PyTorch can see a CUDA device.
if torch.cuda.is_available():
    !nvidia-smi

## Loading data

In [None]:
# Read the tab-separated sentence-pair file into two parallel, lower-cased
# sequences: `en` (first column) and `fr` (second column).
with open("../data/sentence_pairs.txt", encoding="utf-8") as fp:
    data = fp.read().split("\n")

pairs = []
for lineno, line in enumerate(data, start=1):
    if not line.strip():
        # Skip blank / whitespace-only lines: a single-field row would make
        # `zip(*pairs)` truncate every pair to one column.
        continue
    fields = line.lower().split("\t")
    if len(fields) < 2:
        raise ValueError(
            f"sentence_pairs.txt line {lineno}: expected '<en>\\t<fr>'")
    # Ignore any extra trailing columns so the unpack below always yields
    # exactly two sequences.
    pairs.append(fields[:2])

if not pairs:
    raise ValueError("sentence_pairs.txt contains no sentence pairs")
en, fr = zip(*pairs)

## Training the input tokenizer (English)

In [None]:
# Load the cached input-side tokenizer, or train a new one on `en` and cache it.
try:
    tokenizer_in = BytePairEncoder.load("tokenizer/tokenizer_in.json")
except Exception:
    # No usable saved tokenizer (missing or unreadable file): train from
    # scratch. A bare `except:` would also swallow KeyboardInterrupt.
    tokenizer_in = BytePairEncoder()
    # Return value of `train` kept for inspection (presumably the learned
    # subwords — TODO confirm); only defined when training actually ran.
    subwords_en = tokenizer_in.train(en, min_frequency=1.0e-07,
                                     max_tokens=5000, prune=True)
    # Make sure the cache directory exists before writing into it.
    os.makedirs("tokenizer", exist_ok=True)
    tokenizer_in.save("tokenizer/tokenizer_in.json", overwrite=True)

## Training the target tokenizer (French)

In [None]:
# Load the cached target-side tokenizer, or train a new one on `fr` and cache it.
try:
    tokenizer_out = BytePairEncoder.load("tokenizer/tokenizer_out.json")
except Exception:
    # No usable saved tokenizer (missing or unreadable file): train from
    # scratch. A bare `except:` would also swallow KeyboardInterrupt.
    tokenizer_out = BytePairEncoder()
    # Return value of `train` kept for inspection (presumably the learned
    # subwords — TODO confirm); only defined when training actually ran.
    subwords_fr = tokenizer_out.train(fr, min_frequency=1.0e-07,
                                      max_tokens=5000, prune=True)
    # Make sure the cache directory exists before writing into it.
    os.makedirs("tokenizer", exist_ok=True)
    tokenizer_out.save("tokenizer/tokenizer_out.json", overwrite=True)

## Converting the dataset to tensors

In [None]:
# Load the cached train/validation tensors, or build and cache them.
try:
    # Reusing the cached split keeps the train/validation partition stable
    # across notebook restarts.
    x_train = torch.load("models/x_train.pty")
    y_train = torch.load("models/y_train.pty")
    x_val = torch.load("models/x_val.pty")
    y_val = torch.load("models/y_val.pty")
except Exception:
    # Cache missing or unreadable: tokenize both sides, shuffle the row
    # indices with Python's global RNG, and split 80/20.
    x = strings_to_tensor(en, tokenizer_in)
    y = strings_to_tensor(fr, tokenizer_out)

    indexes = list(range(len(x)))
    random.shuffle(indexes)

    lim = int(round(0.8 * len(x)))
    i_train, i_val = indexes[:lim], indexes[lim:]
    # Same index lists for x and y keep source/target rows aligned.
    x_train, y_train = x[i_train], y[i_train]
    x_val, y_val = x[i_val], y[i_val]

    # Make sure the cache directory exists before writing into it.
    os.makedirs("models", exist_ok=True)
    torch.save(x_train, "models/x_train.pty")
    torch.save(y_train, "models/y_train.pty")
    torch.save(x_val, "models/x_val.pty")
    torch.save(y_val, "models/y_val.pty")

## Training the model

In [None]:
data_train = (x_train, y_train)
data_valid = (x_val, y_val)

# Train on the first GPU when one is available; otherwise fall back to the
# CPU instead of crashing on a hard-coded "cuda:0".
# NOTE(review): assumes train_loop moves each batch to the model's device —
# TODO confirm.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    # Release cached, unused GPU memory left over from earlier runs.
    torch.cuda.empty_cache()

model = Transformer(tokenizer_in, tokenizer_out, n_stages=6,
                    projection_dim=64, n_heads=4, dropout=0.0,
                    scalable=True)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-03)

# Capture (rather than print) every warning raised during training; they
# stay available in `w` afterwards.
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    rets = train_loop(model, optimizer, data_train, data_valid,
                      n_epochs=1000, patience=100, batch_size=10)
    train_losses, val_losses, best_epoch = rets

# Whole-object pickles of the model and optimizer.
os.makedirs("models", exist_ok=True)
torch.save(model, "models/model.pty")
torch.save(optimizer, "models/optimizer.pty")

## Display results

In [None]:
# Plot the training and validation losses returned by `train_loop`;
# `best_epoch` is presumably highlighted on the plot — TODO confirm in plot_loss.
plot_loss(train_losses, val_losses, best_epoch)

## Using the trained model for inference

In [None]:
# Reload the saved model onto the CPU and translate a sample sentence.
# map_location="cpu" lets a model saved from the GPU load on a CPU-only
# machine; weights_only=False is required on torch >= 2.6 to unpickle a whole
# nn.Module (arbitrary pickle: only use on trusted, locally produced files).
new_model = torch.load("models/model.pty", map_location="cpu",
                       weights_only=False)
# The training corpus was lower-cased before tokenization, so lower-case the
# query the same way to stay in the tokenizer's vocabulary.
new_model.predict("Tom is gone".lower())