# Imports and Setup

In [1]:
!git clone https://github.com/Hari31416/transformer_from_scratch.git
!cp -r ./transformer_from_scratch/* ./
!pip install -q evaluate=='0.4.1'

Cloning into 'transformer_from_scratch'...
remote: Enumerating objects: 74, done.
remote: Counting objects: 100% (74/74), done.
remote: Compressing objects: 100% (49/49), done.
remote: Total 74 (delta 40), reused 58 (delta 24), pack-reused 0 (from 0)
Receiving objects: 100% (74/74), 4.28 MiB | 22.71 MiB/s, done.
Resolving deltas: 100% (40/40), done.


In [2]:
import torch
import torch.nn as nn
import evaluate

from tqdm.auto import tqdm

T = torch.Tensor
M = nn.Module


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from src.train_utils import *
from src.transformer import *

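For logging to Weights & Biases (wandb), store your own API key as a Kaggle secret named `WANDB_API_KEY`. Set `LOG_TO_WANDB = False` to train without logging.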

In [3]:
LOG_TO_WANDB = True
if LOG_TO_WANDB:
    from kaggle_secrets import UserSecretsClient
    import wandb

    user_secrets = UserSecretsClient()
    WANDB_API_KEY = user_secrets.get_secret("WANDB_API_KEY")

    text = f"""machine api.wandb.ai
        login user
        password {WANDB_API_KEY}
        """
    # wandb saves credentials at /root/.netrc
    with open("/root/.netrc", "w") as f:
        f.write(text)

    wandb.init(project="Transformer_From_Scratch", name="Run 8")
else:
    wandb = None

[34m[1mwandb[0m: Currently logged in as: [33mhari31416[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.18.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.7
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240921_133322-avz5jaov[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mRun 13[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/hari31416/Transformer_From_Scratch[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/hari31416/Transformer_From_Scratch/runs/avz5jaov[0m


# Configs

We start by creating the dataset configuration and then loading the dataset. See the [lyrics](https://github.com/Hari31416/transformer_from_scratch/blob/main/notebooks/lyrics.ipynb) notebook for details on how the dataset was created.

In [4]:
dataset_config = GenerationDatasetConfig(
    **{
        "dataset_path": "data/songs_section_wise.json",
        "tokenizer_path": "data/tokenizer_eng_lyrics.json",
        "max_len": 80,  # long enough for more than 90% of the samples
        "device": device,
    }
)
dataset: GenerationDataset = dataset_config.load_object(GenerationDataset)
dataloader = dataset.get_dataloader(32, shuffle=True)
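
The config caps sequences at `max_len = 80` tokens, chosen so that more than 90% of the data fits. A minimal sketch of how such a cutoff can be picked from tokenized lengths; the helper `smallest_covering_len` and the `lengths` list are hypothetical toy examples, not part of the repository or the lyrics dataset:

```python
def smallest_covering_len(lengths: list[int], coverage_pct: int) -> int:
    """Smallest max_len such that at least coverage_pct % of the
    sequences have length <= max_len (i.e. are not truncated)."""
    if not lengths or not 0 < coverage_pct <= 100:
        raise ValueError("need non-empty lengths and 0 < coverage_pct <= 100")
    ordered = sorted(lengths)
    # Number of sequences that must fit, using integer ceiling division
    needed = -(-coverage_pct * len(ordered) // 100)
    # The needed-th shortest sequence sets the cutoff
    return ordered[needed - 1]


# Hypothetical token counts for ten lyric sections
lengths = [12, 35, 40, 41, 55, 60, 62, 70, 78, 120]
print(smallest_covering_len(lengths, 90))  # 78
```

Integer percentages are used so the ceiling is exact and not affected by floating-point rounding.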

Since we want a model that generates the next word in a sequence, we create a decoder-only transformer. The same could be achieved with a full encoder-decoder model, but we use a decoder-only model for simplicity and because we do not expect the encoder to bring a significant performance gain.
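
In a decoder-only model, each position may attend only to itself and earlier positions, and the training target is the input shifted left by one token. A minimal pure-Python sketch of both ideas; the token ids are made up, and `src.transformer` may build its mask differently (for example as a boolean tensor):

```python
def causal_mask(size: int) -> list[list[bool]]:
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(size)] for i in range(size)]


def shift_for_next_token(token_ids: list[int]) -> tuple[list[int], list[int]]:
    """Split a sequence into model inputs and next-token targets."""
    # The model sees tokens 0..n-2 and must predict tokens 1..n-1
    return token_ids[:-1], token_ids[1:]


ids = [1, 57, 903, 12, 2]  # hypothetical ids, e.g. start token, 3 words, end token
inputs, targets = shift_for_next_token(ids)
print(inputs)   # [1, 57, 903, 12]
print(targets)  # [57, 903, 12, 2]
for row in causal_mask(len(inputs)):
    print("".join("1" if allowed else "0" for allowed in row))
# 1000
# 1100
# 1110
# 1111
```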

In [5]:
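model = DecoderOnlyTransformer(
    target_vocab_size=dataset.tokenizer.get_vocab_size(),
    N=6,  # number of stacked decoder layers (N in the original Transformer paper)
)
model.param_count() / 1e6  # parameter count, in millions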

27.618603
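
`param_count()` reports about 27.6 million parameters. To see where numbers like this come from, here is a rough count for a single decoder block under assumed sizes. The widths `d_model=512` and `d_ff=2048` are the base sizes from the original Transformer paper, not values read from `src.transformer`, and the block is assumed to hold masked self-attention (four biased projections), a two-layer feed-forward network and two LayerNorms, with no cross-attention:

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def decoder_block_params(d_model: int, d_ff: int) -> int:
    # Q, K, V and output projections of multi-head self-attention
    attention = 4 * linear_params(d_model, d_model)
    # Position-wise feed-forward network: d_model -> d_ff -> d_model
    feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    # Two LayerNorms, each with a scale and a shift vector of size d_model
    layer_norms = 2 * 2 * d_model
    return attention + feed_forward + layer_norms


per_block = decoder_block_params(d_model=512, d_ff=2048)
print(per_block)      # 3152384
print(6 * per_block)  # 18914304
```

If the real model matched these assumptions, six blocks would account for roughly 18.9M parameters, with the remainder in layers whose size grows with the vocabulary, such as the token embedding and the output projection.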

Next, we create the configuration for the trainer `TransformerTrainerForGeneration`.

In [6]:
tokenizer = dataset.tokenizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [7]:
trainer_config = TransformerTrainerForGenerationConfig(
    **{
        "model": model,
        "optimizer": optimizer,
        "criterion": bce_crit,
        "tokenizer": tokenizer,
        "max_len": dataset_config.max_len,
        "device": device,
        "wandb": wandb,
        "wandb_log_freq": 100,
        "scheduler": None,
    }
)
trainer: TransformerTrainerForGeneration = trainer_config.load_object(TransformerTrainerForGeneration)

# Training and Inference

## Training

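With everything configured, training takes a single call. We train for 60 epochs, logging every 100 batch steps, and then save the model weights.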

In [8]:
losses, train_states = trainer.train(
    data_loader=dataloader,
    num_epochs=60,
    log_freq=100,
    eval_loader=None,
)

torch.save(model.state_dict(), "trained_transformer_for_generation.pth")

  0%|          | 0/111 [00:00<?, ?it/s]

Batch Step: 1, Loss*100: 0.000413, Tokens / Sec: 384.216980, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000174, Tokens / Sec: 11000.087891, Learning Rate: 0.0001
Epoch: 0, Loss: 0.000002


Batch Step: 1, Loss*100: 0.000150, Tokens / Sec: 22109.179688, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000128, Tokens / Sec: 10810.259766, Learning Rate: 0.0001
Epoch: 1, Loss: 0.000002


Batch Step: 1, Loss*100: 0.000136, Tokens / Sec: 21875.916016, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000148, Tokens / Sec: 11001.439453, Learning Rate: 0.0001
Epoch: 2, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000123, Tokens / Sec: 22374.937500, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000118, Tokens / Sec: 11792.797852, Learning Rate: 0.0001
Epoch: 3, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000144, Tokens / Sec: 19969.238281, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000103, Tokens / Sec: 11646.263672, Learning Rate: 0.0001
Epoch: 4, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000126, Tokens / Sec: 22302.226562, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000113, Tokens / Sec: 14282.322266, Learning Rate: 0.0001
Epoch: 5, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000122, Tokens / Sec: 19037.113281, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000102, Tokens / Sec: 14211.031250, Learning Rate: 0.0001
Epoch: 6, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000117, Tokens / Sec: 20326.312500, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000108, Tokens / Sec: 13609.169922, Learning Rate: 0.0001
Epoch: 7, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000092, Tokens / Sec: 21859.085938, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000094, Tokens / Sec: 12741.222656, Learning Rate: 0.0001
Epoch: 8, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000097, Tokens / Sec: 23316.660156, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000074, Tokens / Sec: 14838.389648, Learning Rate: 0.0001
Epoch: 9, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000086, Tokens / Sec: 22601.355469, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000093, Tokens / Sec: 15749.674805, Learning Rate: 0.0001
Epoch: 10, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000081, Tokens / Sec: 22315.916016, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000095, Tokens / Sec: 14603.101562, Learning Rate: 0.0001
Epoch: 11, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000068, Tokens / Sec: 25108.597656, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000084, Tokens / Sec: 14410.793945, Learning Rate: 0.0001
Epoch: 12, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000071, Tokens / Sec: 22219.585938, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000065, Tokens / Sec: 13234.324219, Learning Rate: 0.0001
Epoch: 13, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000081, Tokens / Sec: 22387.507812, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000060, Tokens / Sec: 14513.944336, Learning Rate: 0.0001
Epoch: 14, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000067, Tokens / Sec: 22540.964844, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000070, Tokens / Sec: 13599.527344, Learning Rate: 0.0001
Epoch: 15, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000054, Tokens / Sec: 21472.183594, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000066, Tokens / Sec: 14390.550781, Learning Rate: 0.0001
Epoch: 16, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000052, Tokens / Sec: 21628.625000, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000047, Tokens / Sec: 13667.739258, Learning Rate: 0.0001
Epoch: 17, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000056, Tokens / Sec: 20808.033203, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000059, Tokens / Sec: 14773.893555, Learning Rate: 0.0001
Epoch: 18, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000060, Tokens / Sec: 21610.896484, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000049, Tokens / Sec: 13023.195312, Learning Rate: 0.0001
Epoch: 19, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000059, Tokens / Sec: 20346.105469, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000050, Tokens / Sec: 14186.909180, Learning Rate: 0.0001
Epoch: 20, Loss: 0.000001


Batch Step: 1, Loss*100: 0.000054, Tokens / Sec: 21200.613281, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000043, Tokens / Sec: 14035.930664, Learning Rate: 0.0001
Epoch: 21, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000044, Tokens / Sec: 23892.876953, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000043, Tokens / Sec: 13297.981445, Learning Rate: 0.0001
Epoch: 22, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000052, Tokens / Sec: 20476.173828, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000032, Tokens / Sec: 14531.953125, Learning Rate: 0.0001
Epoch: 23, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000036, Tokens / Sec: 22709.546875, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000038, Tokens / Sec: 14862.273438, Learning Rate: 0.0001
Epoch: 24, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000041, Tokens / Sec: 21870.134766, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000044, Tokens / Sec: 13893.931641, Learning Rate: 0.0001
Epoch: 25, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000035, Tokens / Sec: 21591.443359, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000032, Tokens / Sec: 13682.052734, Learning Rate: 0.0001
Epoch: 26, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000038, Tokens / Sec: 22718.361328, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000028, Tokens / Sec: 15037.572266, Learning Rate: 0.0001
Epoch: 27, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000048, Tokens / Sec: 20128.986328, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000027, Tokens / Sec: 14794.699219, Learning Rate: 0.0001
Epoch: 28, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000030, Tokens / Sec: 22654.865234, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000029, Tokens / Sec: 14403.745117, Learning Rate: 0.0001
Epoch: 29, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000035, Tokens / Sec: 23658.253906, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000028, Tokens / Sec: 14552.604492, Learning Rate: 0.0001
Epoch: 30, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000032, Tokens / Sec: 22295.126953, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000027, Tokens / Sec: 14590.297852, Learning Rate: 0.0001
Epoch: 31, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000032, Tokens / Sec: 23036.716797, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000024, Tokens / Sec: 15152.138672, Learning Rate: 0.0001
Epoch: 32, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000023, Tokens / Sec: 21043.539062, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000026, Tokens / Sec: 14711.955078, Learning Rate: 0.0001
Epoch: 33, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000025, Tokens / Sec: 21302.412109, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000023, Tokens / Sec: 14288.700195, Learning Rate: 0.0001
Epoch: 34, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000027, Tokens / Sec: 20591.160156, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000023, Tokens / Sec: 15561.219727, Learning Rate: 0.0001
Epoch: 35, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000016, Tokens / Sec: 21485.078125, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000028, Tokens / Sec: 14236.280273, Learning Rate: 0.0001
Epoch: 36, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000017, Tokens / Sec: 21136.886719, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000022, Tokens / Sec: 14581.454102, Learning Rate: 0.0001
Epoch: 37, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000016, Tokens / Sec: 22660.441406, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000022, Tokens / Sec: 15050.761719, Learning Rate: 0.0001
Epoch: 38, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000021, Tokens / Sec: 17770.914062, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000023, Tokens / Sec: 15311.076172, Learning Rate: 0.0001
Epoch: 39, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000013, Tokens / Sec: 20806.029297, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000014, Tokens / Sec: 14534.612305, Learning Rate: 0.0001
Epoch: 40, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000013, Tokens / Sec: 23598.939453, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000020, Tokens / Sec: 14412.848633, Learning Rate: 0.0001
Epoch: 41, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000016, Tokens / Sec: 21944.558594, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000012, Tokens / Sec: 15157.783203, Learning Rate: 0.0001
Epoch: 42, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000012, Tokens / Sec: 21837.107422, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000016, Tokens / Sec: 14433.512695, Learning Rate: 0.0001
Epoch: 43, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000013, Tokens / Sec: 22890.539062, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000016, Tokens / Sec: 14842.159180, Learning Rate: 0.0001
Epoch: 44, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000012, Tokens / Sec: 23004.132812, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000015, Tokens / Sec: 14564.815430, Learning Rate: 0.0001
Epoch: 45, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000010, Tokens / Sec: 20546.445312, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000014, Tokens / Sec: 14807.321289, Learning Rate: 0.0001
Epoch: 46, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000018, Tokens / Sec: 21279.101562, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 15064.448242, Learning Rate: 0.0001
Epoch: 47, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000011, Tokens / Sec: 24614.132812, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 14865.102539, Learning Rate: 0.0001
Epoch: 48, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000014, Tokens / Sec: 22375.927734, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000013, Tokens / Sec: 15009.714844, Learning Rate: 0.0001
Epoch: 49, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000009, Tokens / Sec: 24621.667969, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000008, Tokens / Sec: 15165.880859, Learning Rate: 0.0001
Epoch: 50, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000009, Tokens / Sec: 23066.228516, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000013, Tokens / Sec: 14917.144531, Learning Rate: 0.0001
Epoch: 51, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000014, Tokens / Sec: 19475.230469, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 15147.902344, Learning Rate: 0.0001
Epoch: 52, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000010, Tokens / Sec: 22095.896484, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000016, Tokens / Sec: 13978.357422, Learning Rate: 0.0001
Epoch: 53, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000009, Tokens / Sec: 22434.189453, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000013, Tokens / Sec: 14377.298828, Learning Rate: 0.0001
Epoch: 54, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000011, Tokens / Sec: 22190.962891, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000008, Tokens / Sec: 14383.401367, Learning Rate: 0.0001
Epoch: 55, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000008, Tokens / Sec: 20900.460938, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 14221.770508, Learning Rate: 0.0001
Epoch: 56, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000011, Tokens / Sec: 18870.968750, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 14890.936523, Learning Rate: 0.0001
Epoch: 57, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000008, Tokens / Sec: 21913.277344, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000011, Tokens / Sec: 14864.454102, Learning Rate: 0.0001
Epoch: 58, Loss: 0.000000


Batch Step: 1, Loss*100: 0.000009, Tokens / Sec: 21415.460938, Learning Rate: 0.0001
Batch Step: 101, Loss*100: 0.000008, Tokens / Sec: 15222.160156, Learning Rate: 0.0001
Epoch: 59, Loss: 0.000000
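
The progress bars show 111 batches per epoch. Assuming the first argument of `get_dataloader` is the batch size (32) and that the loader keeps a final partial batch (PyTorch's `DataLoader` default, `drop_last=False`), a quick calculation bounds the number of training samples and gives the total number of optimizer steps, assuming one step per batch:

```python
import math

BATCH_SIZE = 32          # assumed meaning of get_dataloader(32, shuffle=True)
BATCHES_PER_EPOCH = 111  # from the tqdm progress bar
NUM_EPOCHS = 60          # from trainer.train(num_epochs=60)

# ceil(n / 32) == 111 holds for every n in [min_samples, max_samples]
min_samples = (BATCHES_PER_EPOCH - 1) * BATCH_SIZE + 1
max_samples = BATCHES_PER_EPOCH * BATCH_SIZE
total_steps = BATCHES_PER_EPOCH * NUM_EPOCHS

print(min_samples, max_samples)  # 3521 3552
print(total_steps)               # 6660

# Sanity check: both bounds really produce 111 batches
assert math.ceil(min_samples / BATCH_SIZE) == BATCHES_PER_EPOCH
assert math.ceil(max_samples / BATCH_SIZE) == BATCHES_PER_EPOCH
```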


## Inference

Let us see how the model performs when asked to continue a seed text. We used the following seed texts:

```python
[
    "salt air and the rust on your door",
    "no other sadness in the world would do",
    "One slip and falling back into the hedge maze\nOh what a way to die",
    "It's me, hi, I'm the problem, it's me\nAt tea time, everybody agrees",
    "there i was again tonight forcing laughter faking smiles same old tired lonely place",
    "i said remember this moment in the back of my mind",
    "I was supposed to be sent away\nBut they forgot to come and get me",
    "we could leave the christmas lights up 'til january",
    "life was a willow and it bent right to your wind",
]
```

Some of these texts are in the training dataset while some are not.
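
Generation from a seed text is autoregressive: the model predicts one token, appends it, and repeats until it emits an end-of-sequence token or the sequence reaches `max_len`. A minimal greedy-decoding sketch in pure Python; `next_token_scores` stands in for the model's forward pass and the toy bigram table is invented for illustration, so this is not the trainer's actual implementation (which may, for instance, sample instead of taking the highest-scoring token):

```python
from typing import Callable, Dict, List

EOS = "<eos>"  # hypothetical end-of-sequence token


def greedy_generate(
    seed: List[str],
    next_token_scores: Callable[[List[str]], Dict[str, float]],
    max_len: int,
) -> List[str]:
    """Repeatedly append the highest-scoring next token."""
    tokens = list(seed)
    while len(tokens) < max_len:
        scores = next_token_scores(tokens)
        best = max(scores, key=scores.get)  # greedy choice
        if best == EOS:
            break
        tokens.append(best)
    return tokens


# Toy "model": scores depend only on the last token (a bigram table)
BIGRAMS = {
    "life": {"was": 0.9, EOS: 0.1},
    "was": {"a": 0.8, EOS: 0.2},
    "a": {"willow": 0.7, EOS: 0.3},
}


def toy_scores(tokens: List[str]) -> Dict[str, float]:
    # Unknown last token: the toy model ends the sequence
    return BIGRAMS.get(tokens[-1], {EOS: 1.0})


print(greedy_generate(["life"], toy_scores, max_len=80))
# ['life', 'was', 'a', 'willow']
```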

In [9]:
trainer.log_sample_to_wandb(None)

{'generations': ['salt air and the rust on your door i never needed anything '
                 'more whispers of " are you sure? " " " never have i ever '
                 'before "',
                 'no other sadness in the world would do',
                 'one slip and falling back into the hedge maze oh what a way '
                 'to die',
                 "it ' s me, hi, i ' m the problem, it ' s me at tea time, "
                 'everybody agrees',
                 'there i was again tonight forcing laughter faking smiles '
                 'same old tired lonely place walls of insincerity, shifting '
                 'eyes and vacancy vanished when i saw your face all i can say '
                 'is, it was enchanting to meet you',
                 'i said remember this moment in the back of my mind the night '
                 'we stood with our shaking hands tied and i had the crowds in '
                 'stands went wild we were the kings and the queens you traded '
 

The generations make it easy to tell which seed texts were in the training dataset and which were not. The last text, `life was a willow and it bent right to your wind`, was in the training set, and the model completes it. Texts such as `I was supposed to be sent away\nBut they forgot to come and get me` were not in the training set, yet the model is still able to complete them.
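
A direct way to check which seed texts occur in the training data is a normalized substring search. The sketch below assumes the training sections are available as a plain list of strings (how `data/songs_section_wise.json` is structured is not shown here, so the `sections` list is hypothetical). It lowercases and collapses whitespace, mirroring the lowercased, single-line generations above, but leaves punctuation untouched, so a real check may need extra normalization:

```python
import re
from typing import Iterable


def normalize(text: str) -> str:
    """Lowercase and collapse all whitespace (including newlines) to single spaces."""
    return re.sub(r"\s+", " ", text.lower()).strip()


def seen_in_training(seed: str, sections: Iterable[str]) -> bool:
    """True if the normalized seed occurs verbatim inside any training section."""
    needle = normalize(seed)
    return any(needle in normalize(section) for section in sections)


# Hypothetical training sections built from text quoted in this notebook
sections = [
    "Life was a willow and it bent right to your wind",
    "salt air and the rust on your door\ni never needed anything more",
]
print(seen_in_training("life was a willow and it bent right to your wind", sections))  # True
print(seen_in_training("I was supposed to be sent away", sections))  # False
```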