In [1]:
import torch 
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from collections import Counter
# Trade a little float32 matmul precision for speed; "high" is the middle
# setting between "highest" (full float32) and "medium".
torch.set_float32_matmul_precision("high")

In [2]:
from BPE_Encoder import  token2id, train_dataset, val_dataset
from Transformer import MLLM

In [None]:
# Training tokens: load the saved tensor onto the CPU and convert it to a
# plain Python list so it can be sliced into n-grams.
# NOTE(review): torch.load unpickles the file, so only point it at trusted files.
train_ids = torch.load(
    "E:/Projects/Seq2Seq/data/train_texttokens.pt",
    map_location="cpu",
)
train_ids_stream = train_ids.tolist()
print(len(train_ids_stream))

# Validation tokens: same treatment as the training tokens.
val_ids = torch.load(
    "E:/Projects/Seq2Seq/data/val_texttokens.pt",
    map_location="cpu",
)
val_ids_stream = val_ids.tolist()
print(len(val_ids_stream))

def overlap_memorize(train_ids, val_ids, n=10):
    """Return the fraction of validation n-grams that also occur in training.

    Both sequences are cut into overlapping windows of ``n`` consecutive
    items. Matching is done on multisets: an n-gram that appears ``k`` times
    in ``val_ids`` and ``m`` times in ``train_ids`` contributes ``min(k, m)``
    to the overlap count.

    Args:
        train_ids: Sliceable sequence of hashable items (e.g. a list of ints).
        val_ids: Sliceable sequence of hashable items, compared against
            ``train_ids``.
        n: Number of consecutive items in each n-gram.

    Returns:
        The overlap count divided by the total number of validation n-grams,
        as a float. Returns 0.0 when ``val_ids`` yields no n-grams at all
        (i.e. it is shorter than ``n``).
    """
    val_ngrams = Counter(
        tuple(val_ids[i : i + n])
        for i in range(len(val_ids) - n + 1)
    )
    total_val = sum(val_ngrams.values())
    if total_val == 0:
        # Nothing to compare against; avoid dividing by zero.
        return 0.0

    # Only n-grams that exist in the validation set can contribute to the
    # intersection, so skip the rest while counting. The result is the same
    # as counting every training n-gram, but uses far less memory.
    train_windows = (
        tuple(train_ids[i : i + n])
        for i in range(len(train_ids) - n + 1)
    )
    train_ngrams = Counter(
        window for window in train_windows if window in val_ngrams
    )

    # Counter & Counter keeps the minimum count of each shared key.
    overlap = sum((train_ngrams & val_ngrams).values())
    return overlap / total_val

# Report the train/validation overlap rate at several n-gram lengths.
for n in (5, 10, 15):
    rate = overlap_memorize(
        train_ids=train_ids_stream,
        val_ids=val_ids_stream,
        n=n,
    )
    print(f"{n}-gram overlap: {rate:.4%}")


2851449
480708
5-gram overlap: 18.2913%
10-gram overlap: 0.1065%
15-gram overlap: 0.0015%


In [4]:
# Hyperparameters for this training run.
# NOTE(review): presumably the winner of an earlier search — confirm the source.
best_parameters = {
    # optimiser settings
    'lr': 0.0008846917512346465,
    'weight_decay': 0.00029475907910799643,
    # regularisation
    'dropout_percentage': 0.0010306170289606452,
    # model size
    'heads': 4,
    'num_layers': 6,
    'dim': 64,
    # loss / schedule
    'label_smoothing': 0.019019051540349757,
    'pct_start': 0.35980068988151803,
    # feed-forward block
    'activation': 'gelu',
    'ffn_internal': 4,
}

In [5]:
# Build the model from the chosen hyperparameters.
model = MLLM(
    vocab=len(token2id),
    dim=best_parameters["dim"],
    pad_idx=token2id["<pad>"],
    max_pos=512,
    # Width handed to each head: the model dim split evenly across heads.
    QKV_dim=best_parameters["dim"] // best_parameters["heads"],
    heads=best_parameters["heads"],
    num_layers=best_parameters["num_layers"],
    dropout_percentage=best_parameters["dropout_percentage"],
    learning_rate=best_parameters["lr"],
    wd=best_parameters["weight_decay"],
    ls=best_parameters["label_smoothing"],
    pct_start=best_parameters['pct_start'],
    act=best_parameters['activation'],
    ffn_internal=best_parameters['ffn_internal'],
)


# Both callbacks watch the same logged metric, "Val_Loss", where lower is better.
# Stop once the metric has gone 5 checks without improving.
early_stop = EarlyStopping(
    monitor="Val_Loss",
    mode="min",
    patience=5,
    verbose=True,
)
# Keep only the single best checkpoint, saved under the name "best".
checkpoint = ModelCheckpoint(
    filename="best",
    monitor="Val_Loss",
    mode="min",
    save_top_k=1,
)

# GPU trainer with 16-bit mixed precision and gradient clipping at 1.0.
# max_epochs is only an upper bound; the early-stopping callback can end the
# run sooner.
trainer = pl.Trainer(
    accelerator="gpu",
    precision="16-mixed",
    max_epochs=100,
    gradient_clip_val=1.0,
    callbacks=[checkpoint, early_stop],
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Run training with validation.
# NOTE(review): despite their names, train_dataset / val_dataset are passed
# in the dataloader slots of fit() — confirm they are DataLoaders.
trainer.fit(
    model,
    train_dataloaders=train_dataset,
    val_dataloaders=val_dataset,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name       | Type             | Params | Mode 
--------------------------------------------------------
0 | act        | GELU             | 0      | train
1 | embedtoken | Embedding        | 1.3 M  | train
2 | embedpos   | Embedding        | 32.8 K | train
3 | dropout    | Dropout          | 0      | train
4 | layers     | ModuleList       | 299 K  | train
5 | model_head | Linear           | 1.3 M  | train
6 | loss       | CrossEntropyLoss | 0      | train
--------------------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.458     Total estimated model params size (MB)
67        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved. New best score: 13.732


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 6.721 >= min_delta = 0.0. New best score: 7.011


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.693 >= min_delta = 0.0. New best score: 6.318


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.331 >= min_delta = 0.0. New best score: 5.988


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.218 >= min_delta = 0.0. New best score: 5.770


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.166 >= min_delta = 0.0. New best score: 5.603


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.215 >= min_delta = 0.0. New best score: 5.388


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.138 >= min_delta = 0.0. New best score: 5.250


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.199 >= min_delta = 0.0. New best score: 5.051


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.147 >= min_delta = 0.0. New best score: 4.904


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.042 >= min_delta = 0.0. New best score: 4.863


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.110 >= min_delta = 0.0. New best score: 4.752


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.060 >= min_delta = 0.0. New best score: 4.693


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.021 >= min_delta = 0.0. New best score: 4.672


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.014 >= min_delta = 0.0. New best score: 4.658


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.001 >= min_delta = 0.0. New best score: 4.657


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.030 >= min_delta = 0.0. New best score: 4.628


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.006 >= min_delta = 0.0. New best score: 4.622


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.003 >= min_delta = 0.0. New best score: 4.619


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.095 >= min_delta = 0.0. New best score: 4.524


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.003 >= min_delta = 0.0. New best score: 4.521


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.017 >= min_delta = 0.0. New best score: 4.504


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.049 >= min_delta = 0.0. New best score: 4.455


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.032 >= min_delta = 0.0. New best score: 4.423


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.030 >= min_delta = 0.0. New best score: 4.393


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.036 >= min_delta = 0.0. New best score: 4.357


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.000 >= min_delta = 0.0. New best score: 4.357


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.015 >= min_delta = 0.0. New best score: 4.342


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.017 >= min_delta = 0.0. New best score: 4.326


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.028 >= min_delta = 0.0. New best score: 4.298


Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.041 >= min_delta = 0.0. New best score: 4.258


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.027 >= min_delta = 0.0. New best score: 4.230


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.018 >= min_delta = 0.0. New best score: 4.213


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.020 >= min_delta = 0.0. New best score: 4.193


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.000 >= min_delta = 0.0. New best score: 4.192


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.011 >= min_delta = 0.0. New best score: 4.181


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric Val_Loss improved by 0.008 >= min_delta = 0.0. New best score: 4.172


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric Val_Loss did not improve in the last 5 records. Best score: 4.172. Signaling Trainer to stop.
