# import libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
import gc
from typing import Tuple, Dict
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import datasets
import json
import shutil
from collections import Counter
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm
import torchsummary
import Levenshtein
import numpy as np
# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
# (To force CPU execution, assign DEVICE = torch.device('cpu') instead.)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

# Initialize Tokenizer

In [None]:
import re
from collections import Counter

class UpgradeTokenizer:
    """Word-level tokenizer that splits on whitespace and selected punctuation.

    Ids 0-4 are reserved for ``[PAD]``, ``[UNK]``, ``[CLS]``, ``[SEP]`` and
    ``[MASK]``; corpus words receive consecutive ids after them.

    Args:
        max_vocab_size: Upper bound on ``len(self.vocab)`` (special tokens
            included) when growing the vocabulary with :meth:`build_vocab`.
        punctuations: Characters emitted as standalone tokens. Defaults to
            ``['.', ',', '!', '?', ':', ';']``.
    """

    DEFAULT_PUNCTUATIONS = ('.', ',', '!', '?', ':', ';')

    def __init__(self, max_vocab_size, punctuations=None):
        self.vocab = {'[PAD]': 0, '[UNK]': 1, '[CLS]': 2, '[SEP]': 3, '[MASK]': 4}
        self.mask_token = '[MASK]'
        self.max_vocab_size = max_vocab_size
        # Fresh list per instance, so no mutable default object is shared between tokenizers.
        if punctuations is None:
            punctuations = self.DEFAULT_PUNCTUATIONS
        self.punctuations = list(punctuations)
        # Cached id -> word table used by convert_tokens_to_string.
        self._id_to_word = None
        self._id_to_word_src = None
        self._id_to_word_len = None

    def custom_tokenize(self, text):
        """Split *text* into words and standalone punctuation marks.

        A word is a maximal run of characters that are neither whitespace nor
        one of ``self.punctuations``; each punctuation mark is its own token.
        """
        escaped = [re.escape(p) for p in self.punctuations]
        word = r"[^\s" + ''.join(escaped) + r"]+"
        # Joining the alternatives (instead of a trailing "|") avoids an empty
        # alternative that would match "" when no punctuation is configured.
        pattern = '|'.join([word, *escaped])
        return re.findall(pattern, text)

    def build_vocab(self, corpus):
        """Add the most frequent words of *corpus* (an iterable of strings).

        New words get consecutive ids until ``len(self.vocab)`` reaches
        ``max_vocab_size``. Words already present (including the special
        tokens) keep their id, so no id is ever reassigned or duplicated, even
        when this is called repeatedly.
        """
        word_counts = Counter(
            word for sentence in corpus for word in self.custom_tokenize(sentence)
        )
        remaining = self.max_vocab_size - len(self.vocab)
        next_id = max(self.vocab.values(), default=-1) + 1
        for word, _ in word_counts.most_common():
            if remaining <= 0:
                break
            if word in self.vocab:
                continue
            self.vocab[word] = next_id
            next_id += 1
            remaining -= 1

    def tokenize(self, text):
        """Map *text* to a list of ids; unknown words map to ``[UNK]``."""
        return [self.vocab.get(word, self.vocab['[UNK]']) for word in self.custom_tokenize(text)]

    def _id_to_word_table(self):
        """Return an id -> word dict for the current vocab.

        The table is cached and rebuilt when ``self.vocab`` is replaced by a
        different object or changes size. In-place edits that keep the size
        are not detected.
        """
        vocab = self.vocab
        if (getattr(self, '_id_to_word', None) is None
                or getattr(self, '_id_to_word_src', None) is not vocab
                or getattr(self, '_id_to_word_len', None) != len(vocab)):
            table = {}
            for word, idx in vocab.items():
                table.setdefault(idx, word)  # first word wins for a duplicated id
            self._id_to_word = table
            self._id_to_word_src = vocab
            self._id_to_word_len = len(vocab)
        return self._id_to_word

    def convert_tokens_to_string(self, tokens):
        """Decode ids back to words and join them with single spaces.

        *tokens* may be any iterable of integers, including a 1-D tensor or
        array; each element is converted with ``int()``.

        Raises:
            ValueError: if an id does not appear in the vocabulary.
        """
        table = self._id_to_word_table()
        words = []
        for token in tokens:
            token_id = int(token)
            try:
                words.append(table[token_id])
            except KeyError:
                raise ValueError(f"{token_id} is not a token id in the vocabulary") from None
        return " ".join(words)

In [None]:
# Tokenizer instance; its default vocab is swapped for the JSON vocabulary loaded in the next cell.
tokenizer = UpgradeTokenizer(max_vocab_size=40_000)

In [None]:
# Pre-built vocabulary (word -> id mapping) that replaces the tokenizer's default vocab.
vocab_file = '/home/luqiao/project/data/vocab40000-update.json'
# JSON text is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(vocab_file, 'r', encoding='utf-8') as f:
    VOCAB = json.load(f)

tokenizer.vocab = VOCAB

In [None]:
# Show the first 101 vocabulary entries as (position in dict order, word).
# For a JSON-loaded vocab the position is not necessarily the token id.
for position, word in zip(range(101), tokenizer.vocab):
    print(position, word)


# Initialize Dataset

In [None]:
class OpenWebTextDataset(Dataset):
    """Masked-language-model dataset over the ``"text"`` column of a split.

    Each item is ``[CLS] + tokens + [SEP]`` (ids 2 and 3). The token body is
    truncated so the whole sequence, special tokens included, fits in
    ``max_seq_len``. With probability ``mask_probability`` (default 0.15),
    each ordinary token is replaced by ``[MASK]`` (id 4).

    Args:
        tokenizer: Object with a ``tokenize(text) -> list[int]`` method.
        dataset: Mapping from partition name to a split supporting
            ``split[:num_entries]["text"]``. This presumably is a Hugging Face
            ``DatasetDict``; confirm against the caller.
        partition: Split name to read.
        num_entries: Number of leading rows to keep.
    """

    PAD_ID, UNK_ID, CLS_ID, SEP_ID, MASK_ID = 0, 1, 2, 3, 4
    # Special tokens are never selected for masking.
    SPECIAL_IDS = frozenset({PAD_ID, UNK_ID, CLS_ID, SEP_ID, MASK_ID})

    def __init__(self, tokenizer, dataset, partition, num_entries):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.data = self.dataset[partition][:num_entries]["text"]
        self.max_seq_len = 128

    def __len__(self):
        return len(self.data)

    def mask_tokens(self, tokenized_lines, mask_probability=0.15):
        """Randomly replace ordinary tokens with ``[MASK]``.

        Returns ``(inputs, labels)``. ``labels`` holds the original id at masked
        positions and 0 elsewhere; 0 is the id the loss is told to ignore.
        """
        inputs, labels = [], []
        for token in tokenized_lines:
            if token not in self.SPECIAL_IDS and random.random() < mask_probability:
                inputs.append(self.MASK_ID)
                labels.append(token)
            else:
                inputs.append(token)
                labels.append(self.PAD_ID)
        return inputs, labels

    def __getitem__(self, idx):
        """Return ``(inputs, labels)`` LongTensors for row *idx*."""
        body = self.tokenizer.tokenize(self.data[idx])
        # Truncate the body *before* adding [CLS]/[SEP] so [SEP] is never cut off.
        body = body[:max(self.max_seq_len - 2, 0)]
        encoded = [self.CLS_ID, *body, self.SEP_ID]
        inputs, labels = self.mask_tokens(encoded)
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

    def collate_fn(self, batch, max_seq_len=128):
        """Right-pad a list of ``(inputs, labels)`` pairs with 0.

        Returns ``(inputs, labels, input_lengths, label_lengths)``.
        ``max_seq_len`` is accepted for API compatibility but unused, because
        items are already truncated in ``__getitem__``.
        """
        batch_x = [x for x, _ in batch]
        batch_y = [y for _, y in batch]
        lengths_x = torch.tensor([len(x) for x in batch_x])
        lengths_y = torch.tensor([len(y) for y in batch_y])
        return (
            pad_sequence(batch_x, batch_first=True),
            pad_sequence(batch_y, batch_first=True),
            lengths_x,
            lengths_y,
        )

In [None]:
# Load the filtered dataset from its on-disk `datasets` directory.
# NOTE(review): the training cells below read pre-tokenized .npy files rather than this
# object — confirm it is still needed.
OpenwebDataset = datasets.load_from_disk("/home/luqiao/project/data/filtered_dataset")

In [None]:
# Bare expression: the notebook renders the dataset summary as this cell's output.
OpenwebDataset

In [None]:
# Load the pre-tokenized training split.
# NOTE(review): presumably an object array of per-document token-id arrays (the
# language-model DataLoader concatenates its elements) — confirm.
# allow_pickle=True unpickles the file's contents, so only load files you trust.
dataset = np.load(
    '/home/luqiao/project/data/tokenized_dataset_train.npy',
    allow_pickle=True,
)

print(dataset[:3])

In [None]:
# Pre-tokenized validation split (same format as the training file; trusted local pickle data).
val_dataset = np.load(
    '/home/luqiao/project/data/tokenized_dataset_val.npy',
    allow_pickle=True,
)
print(val_dataset[:3])


In [None]:
# Data-loading / training hyper-parameters, also written to ./config.json for the record.
config = {
    'batch_size': 96,
    'epochs': 30,
    'learning_rate': 3e-4,
    'weight_decay': 5e-3,
    'tf_ratio': 1.0,
    'patience': 1,
}

with open('./config.json', 'w') as config_file:
    json.dump(config, config_file, indent=4)

# Initialize Dataloader

In [None]:
class DataLoaderForLanguageModeling(torch.utils.data.DataLoader):
    """Yield next-token language-modeling batches from token-id documents.

    Each epoch proceeds as follows:
      1. The documents are optionally shuffled and concatenated into one stream.
      2. The stream is right-padded with 0 to a multiple of ``batch_size``.
      3. It is split into ``batch_size`` contiguous rows of equal width.
      4. Each batch is a window of up to ``seq_len`` columns across all rows.

    The target for every position is the next token in the stream. Row ends
    therefore continue into the next row; the final position gets 0, which
    ``CrossEntropyLoss(ignore_index=0)`` skips.

    Args:
        dataset: Sequence of 1-D token-id sequences (e.g. an object ndarray).
        batch_size: Number of parallel rows per batch.
        num_workers: Forwarded to ``DataLoader``. The custom ``__iter__``
            below does not use worker processes.
        seq_len: Window width (time steps) per batch.
        shuffle: Shuffle document order at the start of every epoch. The
            caller's ``dataset`` is never modified.
        drop_last: Drop the final window when it is narrower than ``seq_len``.
    """

    def __init__(self, dataset, batch_size, num_workers, seq_len=128, shuffle=True, drop_last=False):
        super().__init__(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
        )
        self.shuffle = shuffle
        self.seq_len = seq_len
        # Total number of real tokens across all documents.
        self.l = int(sum(len(doc) for doc in dataset))
        self.num_batches = self.__len__()

    def _num_columns(self):
        """Width of each row once the stream is padded to a multiple of batch_size."""
        return -(-self.l // self.batch_size)  # ceil division

    def __len__(self):
        full_windows, remainder = divmod(self._num_columns(), self.seq_len)
        if remainder and not self.drop_last:
            return full_windows + 1
        return full_windows

    def __iter__(self):
        if self.shuffle:
            # Permute indices instead of shuffling the caller's array in place.
            order = np.random.permutation(len(self.dataset))
            docs = [self.dataset[i] for i in order]
        else:
            docs = list(self.dataset)
        stream = np.concatenate(docs) if docs else np.zeros(0, dtype=np.int64)
        padded = np.pad(stream, (0, -len(stream) % self.batch_size), mode='constant')
        # Next token in the *stream* (not rolled within a row); the last one gets 0.
        next_tokens = np.append(padded[1:], 0)

        inputs_2d = padded.reshape(self.batch_size, -1)
        targets_2d = next_tokens.reshape(self.batch_size, -1)

        for batch_idx in range(self.num_batches):
            start = batch_idx * self.seq_len
            stop = start + self.seq_len  # slicing clips the final, narrower window
            yield (
                torch.tensor(inputs_2d[:, start:stop], dtype=torch.long),
                torch.tensor(targets_2d[:, start:stop], dtype=torch.long),
            )

In [None]:
# Some sanity checks

# Training loader: seq_len-wide windows (default 128) over the shuffled, concatenated training stream.
dl = DataLoaderForLanguageModeling(
    dataset=dataset,
    batch_size=config["batch_size"],
    num_workers=16,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Validation loader, configured like the training loader but over val_dataset.
dl_val = DataLoaderForLanguageModeling(
    dataset=val_dataset,
    batch_size=config["batch_size"],
    num_workers=16,
    shuffle=True,
    drop_last=True,
)

# Sanity check on the *training* loader: batch shapes, then one decoded row of a fresh batch.
inputs, targets = next(iter(dl))
print(inputs.shape, targets.shape)

x, y = next(iter(dl))
print(x)
print("x: ", tokenizer.convert_tokens_to_string(x[0, :]))
print("y: ", tokenizer.convert_tokens_to_string(y[0, :]))

# Encoder-Decoder Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query, key, value: Tensors whose last two dims are (seq_len, d_k);
            leading dims (e.g. batch, heads) must broadcast.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` receive no attention.
        dropout: Optional callable (e.g. ``nn.Dropout``) applied to the
            attention probabilities.

    Returns:
        ``(output, attn)``: the attended values and the attention probabilities.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # Most negative finite value of the dtype: also safe in fp16/bf16, unlike -1e9.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    output = torch.matmul(attn, value)
    return output, attn


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with learned Q/K/V/output projections.

    Args:
        d_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention probabilities.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.2):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.h = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries of shape (bs, len_q, d_model).
            k, v: Keys/values of shape (bs, len_kv, d_model).
            attn_mask: Optional mask broadcastable to (bs, h, len_q, len_kv);
                0 entries are blocked.

        Returns:
            ``(output, attn)``: output of shape (bs, len_q, d_model) and
            attention probabilities of shape (bs, h, len_q, len_kv). Callers
            unpack both values.
        """
        bs = q.size(0)

        def split_heads(x, proj):
            # (bs, len, d_model) -> (bs, h, len, d_k)
            return proj(x).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        query = split_heads(q, self.q_linear)
        key = split_heads(k, self.k_linear)
        value = split_heads(v, self.v_linear)

        context, attn = scaled_dot_product_attention(query, key, value, attn_mask, dropout=self.dropout)

        # (bs, h, len_q, d_k) -> (bs, len_q, d_model)
        concat = context.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat), attn

In [None]:
import math

class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, where ``w_i = 10000^(-2i/d)``.

    Args:
        projection_size: Embedding width ``d`` (odd widths are supported).
        max_seq_len: Longest sequence the table covers.

    The table is registered as the buffer ``pos_encode`` with shape
    (1, max_seq_len, d). It is therefore saved in the state dict and moved by
    ``module.to(device)``; no device has to be chosen here.
    """

    def __init__(self, projection_size, max_seq_len=128):
        super().__init__()
        position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
        denominator = torch.exp(
            torch.arange(0, projection_size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / projection_size)
        )
        angles = position * denominator  # (max_seq_len, ceil(d / 2))

        pe = torch.zeros(max_seq_len, projection_size)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : projection_size // 2])

        self.register_buffer('pos_encode', pe.unsqueeze(0))

    @property
    def pe(self):
        """Backward-compatible alias for the ``pos_encode`` buffer."""
        return self.pos_encode

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is batch-first: (batch, seq_len, d).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pos_encode.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pos_encode.size(1)}"
            )
        return x + self.pos_encode[:, :seq_len]
    


class TransformerEncoder(torch.nn.Module):
    """One post-norm Transformer encoder block: self-attention, then a feed-forward network.

    Each sub-layer is followed by a residual connection and a LayerNorm.

    Args:
        projection_size: Model width (input and output feature size).
        hidden_size: Hidden width of the feed-forward network.
        num_heads: Number of attention heads.
        dropout: Dropout probability for the attention module and the
            feed-forward network. Its default of 0.2 matches the values that
            were previously hard-coded.
    """

    def __init__(self, projection_size, hidden_size, num_heads, dropout=0.2):
        super().__init__()

        # Extra key/value/query projections applied before MultiHeadAttention
        # (which has its own); kept so existing checkpoints keep loading.
        self.KW = torch.nn.Linear(projection_size, projection_size)
        self.VW = torch.nn.Linear(projection_size, projection_size)
        self.QW = torch.nn.Linear(projection_size, projection_size)
        self.attention = MultiHeadAttention(projection_size, num_heads, dropout=dropout)

        # LayerNorms (the bn* names are kept for checkpoint compatibility).
        self.bn1 = torch.nn.LayerNorm(projection_size)
        self.bn2 = torch.nn.LayerNorm(projection_size)

        # Position-wise feed-forward network.
        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(projection_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, projection_size),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, projection_size).

        Returns:
            ``(out, attention_weights)``: the block output (same shape as ``x``)
            and the second value returned by the attention module.
        """
        attended, self_attention_weights = self.attention(self.QW(x), self.KW(x), self.VW(x))
        # Residual + LayerNorm around self-attention.
        out1 = self.bn1(attended + x)
        # Residual + LayerNorm around the feed-forward network.
        out2 = self.bn2(self.MLP(out1) + out1)
        return out2, self_attention_weights

In [None]:
class Encoder(torch.nn.Module):
    """Token encoder: embedding plus sinusoidal positions, then a stack of Transformer blocks.

    Args:
        input_size: Vocabulary size of the embedding table.
        encoder_embedding_size: Model width.
        encoder_hidden_size: Feed-forward width inside each block.
        n_heads: Attention heads per block.
        tf_blocks: Number of stacked ``TransformerEncoder`` blocks.
    """

    def __init__(self,
                 input_size,
                 encoder_embedding_size,
                 encoder_hidden_size,
                 n_heads,
                 tf_blocks,):
        super().__init__()
        width = encoder_embedding_size
        self.embedding_size = width
        self.embedding = nn.Embedding(input_size, width)
        self.positional_encoding = PositionalEncoding(width)
        blocks = [TransformerEncoder(width, encoder_hidden_size, n_heads) for _ in range(tf_blocks)]
        self.transformer_encoder = torch.nn.ModuleList(blocks)
        self.layer_norm = nn.LayerNorm(width)
        # NOTE(review): nn.Dropout() uses its default p=0.5, much higher than the 0.2
        # used inside the blocks — confirm that is intended.
        self.droupout = nn.Dropout()

    def forward(self, x):
        """Encode token ids ``x`` into features of width ``embedding_size``.

        ``x`` is batch-first (batch, seq_len), because the positional encoding
        indexes dim 1.
        """
        hidden = self.layer_norm(self.droupout(self.positional_encoding(self.embedding(x))))
        for block in self.transformer_encoder:
            hidden = block(hidden)[0]  # blocks return (output, attention weights)
        return hidden

In [None]:
class DecoderWithAttention(nn.Module):
    """Single-layer Transformer-style decoder over encoder outputs.

    Training and greedy decoding share one pipeline:
    embedding -> sinusoidal positions -> dropout -> LayerNorm
    -> causal self-attention -> cross-attention over ``encoder_outputs``
    -> feed-forward + residual -> LayerNorm -> linear to ``output_size`` logits.

    Args:
        input_size: Vocabulary size of the decoder embedding.
        decoder_embedding_size: Model width (divisible by ``num_heads``).
        hidden_size: Feed-forward hidden width.
        num_layers: Accepted for interface compatibility; a single layer is built.
        num_heads: Heads for both attention modules.
        output_size: Number of logits per position.
        dropout: Dropout probability applied after the positional encoding.
    """

    def __init__(self, input_size, decoder_embedding_size, hidden_size, num_layers, num_heads, output_size, dropout):
        super(DecoderWithAttention, self).__init__()
        self.embedding = nn.Embedding(input_size, decoder_embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.positional_encoding = PositionalEncoding(decoder_embedding_size)
        self.layer_norm = nn.LayerNorm(decoder_embedding_size, eps=1e-6)
        # MultiHeadAttention's width parameter is `d_model` (`embedding_size=` raised TypeError).
        self.slf_attn = MultiHeadAttention(d_model=decoder_embedding_size, num_heads=num_heads)
        self.enc_attn = MultiHeadAttention(d_model=decoder_embedding_size, num_heads=num_heads)

        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(decoder_embedding_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, decoder_embedding_size),
        )
        self.linear = nn.Linear(decoder_embedding_size, output_size)
        self.layer_norm2 = nn.LayerNorm(decoder_embedding_size)

    @staticmethod
    def _causal_mask(length, device):
        """(length, length) lower-triangular mask: position t attends only to positions <= t."""
        return torch.tril(torch.ones(length, length, dtype=torch.bool, device=device))

    def _logits(self, tokens, encoder_outputs):
        """Run the decoder on token ids (batch, T); return logits (batch, T, output_size)."""
        hidden = self.positional_encoding(self.embedding(tokens))
        hidden = self.layer_norm(self.dropout(hidden))
        mask = self._causal_mask(tokens.size(1), tokens.device)
        dec_output, _ = self.slf_attn(hidden, hidden, hidden, attn_mask=mask)
        dec_output, _ = self.enc_attn(dec_output, encoder_outputs, encoder_outputs)
        # Feed-forward with a residual connection, then LayerNorm.
        out = self.layer_norm2(dec_output + self.MLP(dec_output))
        return self.linear(out)

    def forward(self, encoder_outputs, y=None, MAX_LENGTH=21):
        """Return logits for given decoder inputs, or greedily decode when ``y`` is None.

        Args:
            encoder_outputs: Encoder features, batch-first (batch, src_len, width).
            y: Optional decoder input ids (batch, T). When given, the result is
                (batch, T, output_size); position t only attends to y[:, :t+1].
            MAX_LENGTH: Number of greedy steps when ``y`` is None.

        Returns:
            For greedy decoding, logits of shape (batch, MAX_LENGTH, output_size),
            one step per generated token. Decoding starts from the ``[CLS]`` id
            read from the global ``tokenizer``.
        """
        if y is not None:
            return self._logits(y, encoder_outputs)

        batch_size = encoder_outputs.size(0)
        generated = torch.full(
            (batch_size, 1), tokenizer.vocab['[CLS]'],
            dtype=torch.long, device=encoder_outputs.device,
        )
        step_logits = []
        for _ in range(MAX_LENGTH):
            # Re-run the full prefix; the causal mask keeps this consistent with training.
            last = self._logits(generated, encoder_outputs)[:, -1, :]
            step_logits.append(last)
            generated = torch.cat([generated, last.argmax(dim=-1, keepdim=True)], dim=1)
        return torch.stack(step_logits, dim=1)


In [None]:
class Seq2SeqWithAttention(nn.Module):
    """Encoder-decoder wrapper: ``Encoder`` over the source, ``DecoderWithAttention`` on top.

    ``input_size`` is used as the vocabulary size of both the encoder and the
    decoder embeddings.
    """

    def __init__(self, input_size, encoder_embedding_size, encoder_hidden_size, encoder_n_heads, tf_blocks,
                 decoder_embedding_size, decoder_hidden_size, num_layers, num_heads, output_size, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(
            input_size,
            encoder_embedding_size,
            encoder_hidden_size,
            n_heads=encoder_n_heads,
            tf_blocks=tf_blocks,
        )
        self.decoder = DecoderWithAttention(
            input_size,
            decoder_embedding_size,
            decoder_hidden_size,
            num_layers,
            num_heads,
            output_size,
            dropout,
        )

    def forward(self, source, target=None):
        """Return decoder logits; ``target=None`` lets the decoder decode greedily."""
        return self.decoder(self.encoder(source), target)


# Set Model Config

In [None]:
# Model hyper-parameters, also written to ./model_config.json.
# NOTE(review): batch_size here (128) differs from the loaders' config["batch_size"] (96),
# and DecoderWithAttention does not use num_layers — confirm both are intended.
model_config = {
    'batch_size': 128,
    'epochs': 30,
    'learning_rate': 3e-4,
    'weight_decay': 5e-3,
    'encoder_embedding_size': 256,
    'encoder_hidden_size': 256,
    'encoder_n_heads': 8,
    'tf_blocks': 6,
    'vocab_size': 40000,
    'decoder_embedding_size': 256,
    'decoder_hidden_size': 256,
    'num_layers': 3,
    'num_heads': 4,
    'tf_ratio': 1.0,
    'patience': 1,
}

with open('./model_config.json', 'w') as config_file:
    json.dump(model_config, config_file, indent=4)

In [None]:
# Build the encoder-decoder from model_config (one output logit per vocabulary entry)
# and move it to the selected device; Module.to returns the module itself.
model = Seq2SeqWithAttention(
    input_size=model_config["vocab_size"],
    encoder_embedding_size=model_config["encoder_embedding_size"],
    encoder_hidden_size=model_config["encoder_hidden_size"],
    encoder_n_heads=model_config["encoder_n_heads"],
    tf_blocks=model_config["tf_blocks"],
    decoder_embedding_size=model_config["decoder_embedding_size"],
    decoder_hidden_size=model_config["decoder_hidden_size"],
    num_layers=model_config["num_layers"],
    num_heads=model_config["num_heads"],
    output_size=model_config["vocab_size"],
).to(DEVICE)
print(model)

In [None]:
# import torchsummary
# x_sample    = torch.rand(64, 128).long()
# x_sample = x_sample.to(DEVICE)
# y_sample = torch.rand(64, 128).long()
# y_sample = y_sample.to(DEVICE)
# # print(x_sample, y_sample)
# torchsummary.summary(model, x_sample, y_sample)
# del x_sample, y_sample

In [None]:
# x_sample    = torch.rand(64, 128).long()
# torchsummary.summary(model, x_sample.to(DEVICE))
# del x_sample

# Define functions

In [None]:
def calc_edit_distance(predictions, y, tokenizer, vocab=VOCAB, print_example=True):
    """Mean character-level Levenshtein distance between decoded predictions and references.

    Each row of ``predictions`` and ``y`` is decoded with
    ``tokenizer.convert_tokens_to_string``. The edit distance between the two
    strings is then averaged over the batch.

    Args:
        predictions: Predicted token ids, 2-D (batch, pred_len).
        y: Reference token ids, 2-D (batch, ref_len); the length may differ
            from ``pred_len``.
        tokenizer: Object providing ``convert_tokens_to_string``.
        vocab: Unused; kept for backward compatibility.
        print_example: If true, print only the first reference/prediction pair.
            Previously every row was printed unconditionally.

    Returns:
        The mean distance as a float, or 0.0 for an empty batch.
    """
    batch_size = predictions.shape[0]
    if batch_size == 0:
        return 0.0

    total = 0
    for batch_idx in range(batch_size):
        y_string = tokenizer.convert_tokens_to_string(y[batch_idx])
        pred_string = tokenizer.convert_tokens_to_string(predictions[batch_idx])
        if print_example and batch_idx == 0:
            print("\nGround Truth : ", y_string)
            print("Prediction   : ", pred_string)
        total += Levenshtein.distance(pred_string, y_string)

    return total / batch_size

In [None]:
print("Before optimizer:", next(model.parameters()).device)
# Token id 0 ([PAD]) is excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
print("Before optimizer:", next(model.parameters()).device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
print("After optimizer:", next(model.parameters()).device)
# Halve the learning rate once the monitored value fails to improve by a relative 0.1%
# for more than one step. NOTE(review): scheduler.step(...) is not called in the visible
# training loop — confirm.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=1,
    threshold=1e-3,
)

In [None]:
def calculate_loss(criterion, out, target):
    """Apply ``criterion`` to per-position logits and class-index targets.

    Args:
        criterion: Loss taking (N, C) logits and (N,) class indices, for
            example ``nn.CrossEntropyLoss``.
        out: Logits of shape (..., C); all leading dims are flattened.
        target: Class indices whose element count equals the product of
            ``out``'s leading dims.

    Returns:
        Whatever ``criterion`` returns for the flattened inputs.
    """
    # reshape (not view): also works for non-contiguous logits, e.g. after a transpose.
    logits = out.reshape(-1, out.size(-1))
    targets = target.reshape(-1)
    return criterion(logits, targets)

In [None]:

def train(model, dataloader, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    For each ``(src, trg)`` batch, the model is called as ``model(src, trg)``
    and the loss against ``trg`` is computed with ``calculate_loss``.
    Gradients are clipped to total norm ``clip`` before each optimizer step.

    NOTE(review): the decoder is given ``trg`` as input and is scored against
    the same ``trg``. Confirm that targets should not be shifted relative to
    the decoder input.

    Returns:
        Mean loss over the processed batches, or ``nan`` if the loader yields nothing.
    """
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train')

    try:
        for i, (src, trg) in enumerate(dataloader):
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)

            optimizer.zero_grad()
            output = model(src, trg)
            loss = calculate_loss(criterion, output, trg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches = i + 1
            batch_bar.set_postfix(
                loss="{:.04f}".format(epoch_loss / num_batches),
                lr="{:.04f}".format(float(optimizer.param_groups[0]['lr'])))
            batch_bar.update()
            # No per-step torch.cuda.empty_cache(). It does not make memory available
            # to PyTorch itself, and calling it every step forces the caching
            # allocator to re-acquire memory on the next step.
    finally:
        batch_bar.close()

    return epoch_loss / num_batches if num_batches else float('nan')



In [None]:
def validate(model, dataloader, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(mean_loss, mean_edit_distance)``:

        * ``mean_loss`` uses the same objective as ``train``:
          ``calculate_loss(criterion, model(src, trg), trg)``. It is computed
          so that best-checkpoint selection on it is meaningful.
        * ``mean_edit_distance`` compares greedy decoding ``model(src)``
          against ``trg`` with ``calc_edit_distance``.

        Both values are ``nan`` if the loader yields nothing.
    """
    model.eval()
    epoch_loss = 0.0
    running_dist = 0.0
    num_batches = 0
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc="Val")

    try:
        with torch.inference_mode():
            for i, (src, trg) in enumerate(dataloader):
                src = src.to(DEVICE)
                trg = trg.to(DEVICE)

                # Loss with the decoder fed trg, mirroring train().
                loss = calculate_loss(criterion, model(src, trg), trg)
                epoch_loss += loss.item()

                # Greedy decoding (no decoder input); argmax of logits == argmax of log-probs.
                pred_tokens = model(src).argmax(dim=-1)
                running_dist += calc_edit_distance(
                    predictions=pred_tokens, y=trg, tokenizer=tokenizer, print_example=(i == 0))

                num_batches = i + 1
                batch_bar.set_postfix(
                    loss="{:.04f}".format(epoch_loss / num_batches),
                    dist="{:.04f}".format(running_dist / num_batches))
                batch_bar.update()
    finally:
        batch_bar.close()

    if num_batches == 0:
        return float('nan'), float('nan')
    return epoch_loss / num_batches, running_dist / num_batches


In [None]:
# Collect unreachable Python objects (and the tensors they still reference) first,
# so that empty_cache() can then release their now-unoccupied cached GPU blocks.
gc.collect()
torch.cuda.empty_cache()

# Experiments

In [None]:
N_EPOCHS = 10
CLIP = 1  # max total gradient norm passed to train()

best_valid_loss = float('inf')

# NOTE(review): the ReduceLROnPlateau `scheduler` created earlier is never stepped in this
# loop — confirm whether scheduler.step(valid_loss) was intended.
for epoch in range(N_EPOCHS):
    train_loss = train(model, dl, optimizer, criterion, CLIP)
    valid_loss, dist = validate(model, dl_val, criterion)

    # Checkpoint the weights whenever validation loss reaches a new minimum.
    if not valid_loss >= best_valid_loss:
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), './best-model.pth')

    report = (
        f'Epoch: {epoch+1:02}',
        f'\tTrain Loss: {train_loss:.5f}',
        f'\t Val. Loss: {valid_loss:.5f}',
        f'\t Distance: {dist:.5f}',
    )
    print('\n'.join(report))
