In [None]:
import torch
import torch.optim as optim
from model.transformer import Transformer
from utils.checkpoint import Checkpoint
from utils.clock import Clock
from utils.evaluator import Evaluator
from utils.search import Beam, Greedy
from utils.train import retrain
from utils.quantize import quantize_model
from utils.functional import load_module, save_model, parameter_count, model_size, graph, parse_config

In [None]:
# Parse the experiment configuration file; every hyperparameter used by the
# cells below is read out of the returned mapping by key.
config = parse_config("experiment/config.txt", verbose=True)
# Forwarded as `maxlen` to Transformer, Beam and Greedy.
maxlen = config["maxlen"]
# Settings forwarded to the Transformer constructor.
dm = config["dm"]
dk = config["dk"]
dv = config["dv"]
nhead = config["nhead"]
layers = config["layers"]
dff = config["dff"]
bias = config["bias"]
dropout = config["dropout"]
eps = config["eps"]
scale = config["scale"]
# Settings forwarded to the Beam / Greedy search objects.
beam_width = config["beam_width"]
alpha = config["alpha"]
search_eps = config["search_eps"]
fast = config["fast"]
# Settings forwarded to the Evaluator.
goal_bleu = config["goal_bleu"]
corpus_level = config["corpus_level"]
# Training-loop settings. `warmups` and `clip` are forwarded to retrain().
# NOTE(review): the retrain() call in this notebook passes a literal epochs=3,
# so `epochs` read here appears unused - confirm that is intentional.
warmups = config["warmups"]
epochs = config["epochs"]
clip = config["clip"]
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device}")

In [None]:
# Load the tokenizer and the two data loaders from their files under experiment/.
tokenizer = load_module(path="experiment/tokenizer.pt", verbose=True)
dataloader = load_module(path="experiment/dataloader.pt", verbose=True)
testloader = load_module(path="experiment/testloader.pt", verbose=True)
# Look up the special tokens in the tokenizer's "source" module, in the order
# <sos>, <eos>, <pad>.
# NOTE(review): sos/eos are later handed to the Beam/Greedy decoders -
# presumably the special tokens are shared with the target side; confirm.
sos, eos, pad = (
    tokenizer.getitem(token, module="source")
    for token in ("<sos>", "<eos>", "<pad>")
)
# vocab_size() is unpacked as a (source, target) pair.
source_vocab, target_vocab = tokenizer.vocab_size()
print(f"Number of input tokens: {source_vocab}\nNumber of output tokens: {target_vocab}")

In [None]:
# Build the Transformer from the configured hyperparameters; `pad` is supplied
# as the padding id.
model = Transformer(
    source_vocab,
    target_vocab,
    maxlen,
    pad_id=pad,
    dm=dm,
    dk=dk,
    dv=dv,
    nhead=nhead,
    layers=layers,
    dff=dff,
    bias=bias,
    dropout=dropout,
    eps=eps,
    scale=scale,
)
# Move the model to the target device before the optimizer captures its parameters.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
# Bundle model, optimizer and scheduler, then load the saved checkpoint file.
# NOTE(review): presumably this restores the state of all three objects, which
# would supersede the lr=1e-5 above - confirm in utils.checkpoint.
checkpoint = Checkpoint(model, optimizer, scheduler)
checkpoint.load_checkpoint("experiment/checkpoint.pt", verbose=True, device=device)
# Two decoding strategies sharing the same sos/eos/maxlen/alpha/eps settings.
# Only `beam` is handed to the Evaluator below; `greedy` is not used in the
# cells visible here.
beam = Beam(sos, eos, maxlen, width=beam_width, alpha=alpha, eps=search_eps, fast=fast)
greedy = Greedy(sos, eos, maxlen, alpha=alpha, eps=search_eps)
evaluator = Evaluator(testloader, tokenizer, beam, goal_bleu=goal_bleu, corpus_level=corpus_level)
# Seed the clock with the "duration" entry read from the checkpoint.
clock = Clock(checkpoint["duration"])
# Summary line (spelling of "Parameters" corrected from the original message).
print(f"Number of Trainable Parameters: {parameter_count(model):.1f}M\nSize of Model: {model_size(model):.1f}MB")

In [None]:
# Continue training from the loaded checkpoint and collect the three metric
# histories returned by retrain().
# NOTE(review): epochs is a literal 3 here even though `epochs` is read from
# the config earlier in the notebook - confirm whether the configured value
# should be used instead.
losses, test_losses, bleus = retrain(
    dataloader,
    checkpoint,
    evaluator,
    clock,
    epochs=3,
    warmups=warmups,
    clip=clip,
    verbose=True,
    log="experiment/log.txt",
    device=device,
)

In [None]:
# Plot the metric histories returned by retrain().
# NOTE(review): presumably the figure is written to `path` - confirm in utils.functional.
graph(
    losses,
    test_losses,
    bleus,
    path="experiment/metrics.jpg",
)

In [None]:
# Derive two reduced-precision variants of the retrained model. inplace=False
# is passed so that `model` itself is expected to stay untouched.
# NOTE(review): `model` may still be on the CUDA device and in training mode at
# this point - verify that quantize_model handles (or does not require) a
# CPU / eval-mode model.
model_int8 = quantize_model(
    model,
    dtype=torch.qint8,
    inplace=False,
)
model_float16 = quantize_model(
    model,
    dtype=torch.float16,
    inplace=False,
)

In [None]:
# Persist the full-precision model and both quantized variants, in that order.
for trained_variant, destination in (
    (model, "experiment/model.pt"),
    (model_int8, "experiment/model_int8.pt"),
    (model_float16, "experiment/model_float16.pt"),
):
    save_model(trained_variant, path=destination, verbose=True)