In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformer import LlamaLMHeadModel
from transformer import Config

config = Config()
llm = LlamaLMHeadModel(config)
llm = llm.to(config.device)
input_ids = torch.randint(0, config.vocab_size, (32, 128), device=config.device)
loss, logits = llm(input_ids, input_ids.clone())
print(f'Loss: {loss.item()}')


Loss: 673.03564453125


In [None]:
# The core training begins from here.

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("nanollama_tokenizer.model")
text = "The awful thing is that beauty is mysterious as well as terrible. God and the devil are fighting there and the battlefield is the heart of man. - Fyodor Dostoevsky, The Brothers Karamazov"
ids = sp.encode(text)
tokens = sp.encode(text, out_type=str)
print(ids)
print(tokens)

[75, 14926, 721, 8206, 96, 92, 10799, 96, 8591, 11, 411, 72, 696, 72, 883, 399, 41, 15929, 2701, 33, 8, 2069, 57, 146, 3833, 533, 33, 8, 7177, 96, 8, 5111, 21, 834, 15929, 15906, 1, 157, 15924, 125, 26, 177, 212, 15912, 2961, 11765, 15927, 75, 9272, 187, 15913, 8189, 80, 1265, 466]
['▁The', '▁aw', 'ful', '▁thing', '▁is', '▁that', '▁beauty', '▁is', '▁myst', 'er', 'ious', '▁as', '▁well', '▁as', '▁ter', 'rib', 'le', '.', '▁God', '▁and', '▁the', '▁dev', 'il', '▁are', '▁fighting', '▁there', '▁and', '▁the', '▁battlefield', '▁is', '▁the', '▁heart', '▁of', '▁man', '.', '▁', '―', '▁F', 'y', 'od', 'or', '▁D', 'ost', 'o', 'ev', 'sky', ',', '▁The', '▁Bro', 'ther', 's', '▁Kar', 'am', 'az', 'ov']


In [3]:
from typing import Union, List, Optional
import sentencepiece as spm
import torch

class NanoLlamaTokenizer:
    def __init__(self, model_path: str = "nanollama_tokenizer.model",
                 bos_id: int = 2, eos_id: int = 3, pad_id: int = 4):
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(model_path)
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.pad_id = pad_id

    def encode(self, text: Union[str, List[str]], add_bos: bool = True, add_eos: bool = True,
               max_length: Optional[int] = None, pad: bool = True, return_tensor: bool = False
               ) -> Union[List[int], List[List[int]], torch.Tensor]:
        def _encode_single(s: str) -> List[int]:
            ids = self.tokenizer.encode(s)
            if add_bos:
                ids = [self.bos_id] + ids
            if add_eos:
                ids = ids + [self.eos_id]
            if max_length:
                ids = ids[:max_length]
            if pad and max_length:
                ids += [self.pad_id] * (max_length - len(ids))
            return ids

        if isinstance(text, str):
            encoded = _encode_single(text)
        elif isinstance(text, list):
            encoded = [_encode_single(s) for s in text]
        else:
            raise ValueError("Input must be a string or a list of strings.")

        if return_tensor:
            if isinstance(encoded[0], list):
                return torch.tensor(encoded)
            else:
                return torch.tensor([encoded])
        return encoded

    def decode(self, ids: Union[List[int], List[List[int]], torch.Tensor]) -> Union[str, List[str]]:
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()
        if isinstance(ids[0], list):
            return [self.tokenizer.decode(id_seq) for id_seq in ids]
        else:
            return self.tokenizer.decode(ids)

    def __call__(self, text: Union[str, List[str]], **kwargs):
        return self.encode(text, **kwargs)


In [4]:
from torch.utils.data import Dataset, DataLoader
import torch

class TokenizedTextDataset(Dataset):
    def __init__(self, file_path: str, tokenizer: NanoLlamaTokenizer, block_size: int = 128):
        with open(file_path, 'r', encoding='utf-8') as f:
            self.text = f.read()
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.tokens = self.tokenizer.encode(self.text, add_bos=False, add_eos=False, pad=False)

    def __len__(self):
        return len(self.tokens) // self.block_size

    def __getitem__(self, idx):
        i = idx * self.block_size
        input_ids = self.tokens[i: i + self.block_size]
        target_ids = input_ids[1:] + [self.tokenizer.pad_id]  # Shifted for language modeling
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(target_ids, dtype=torch.long),
        }


def collate_fn(batch, pad_id=4):
    input_ids = [item['input_ids'] for item in batch]
    labels = [item['labels'] for item in batch]
    max_len = max(len(x) for x in input_ids)

    input_ids_padded = torch.stack([torch.cat([x, torch.full((max_len - len(x),), pad_id)]) for x in input_ids])
    labels_padded = torch.stack([torch.cat([x, torch.full((max_len - len(x),), -100)]) for x in labels])  # -100 is ignored in loss

    attention_mask = (input_ids_padded != pad_id).long()

    return {
        "input_ids": input_ids_padded,
        "labels": labels_padded,
        "attention_mask": attention_mask
        }

tokenizer = NanoLlamaTokenizer("nanollama_tokenizer.model")
dataset = TokenizedTextDataset("nanollama_training_corpus.txt", tokenizer, block_size=128)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

total_batches = len(dataloader)
print("Total number of batches:", total_batches)
# Get one batch from the dataloader
example_batch = next(iter(dataloader))
# Check shapes of the batch tensors
print("Batch input_ids shape:", example_batch["input_ids"].shape)
print("Batch labels shape:", example_batch["labels"].shape)
print("Batch attention_mask shape:", example_batch["attention_mask"].shape)


Total number of batches: 3064
Batch input_ids shape: torch.Size([8, 128])
Batch labels shape: torch.Size([8, 128])
Batch attention_mask shape: torch.Size([8, 128])


In [8]:
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import os

def train(model, dataloader, optimizer, device, epochs=1, save_dir="checkpoints"):
    os.makedirs(save_dir, exist_ok=True)

    model.to(device)
    model.train()

    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    global_step = 0

    for epoch in range(epochs):
        total_loss = 0

        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}")):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            with autocast():
                logits = model(input_ids)  # (B, T, vocab_size)
                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            global_step += 1

            # Log every 10 batches
            if global_step % 10 == 0:
                print(f"Step {global_step} | Loss: {loss.item():.4f}")

            # Save checkpoint every 500 batches
            if global_step % 500 == 0:
                checkpoint_path = os.path.join(save_dir, f"checkpoint_step{global_step}.pt")
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Saved checkpoint at step {global_step} to {checkpoint_path}")

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} complete | Avg Loss: {avg_loss:.4f}")

        # Save model at end of epoch
        epoch_path = os.path.join(save_dir, f"model_epoch{epoch+1}.pt")
        torch.save(model.state_dict(), epoch_path)
        print(f"Saved model at end of epoch {epoch+1} to {epoch_path}")


In [9]:
model = LlamaLMHeadModel(config)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device in usage: {device}')
tokenizer = NanoLlamaTokenizer("nanollama_tokenizer.model")
dataset = TokenizedTextDataset("nanollama_training_corpus.txt", tokenizer, block_size=128)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
optimizer = torch.optim.AdamW(model.parameters(), 
                              lr=3e-4,           # Learning rate
                              betas=(0.9, 0.95), # Momentum terms
                              eps=1e-8,          # Epsilon for numerical stability
                              weight_decay=0.1)  # Helps generalization


train(model=model,
      device=device,
      dataloader=dataloader, optimizer=optimizer, epochs=1)

Device in usage: cuda


  scaler = GradScaler()
  with autocast():
Epoch 1:   0%|          | 10/3064 [00:39<3:21:24,  3.96s/it]

Step 10 | Loss: 108.7156


Epoch 1:   1%|          | 20/3064 [01:19<3:23:23,  4.01s/it]

Step 20 | Loss: 81.2143


Epoch 1:   1%|          | 30/3064 [01:59<3:18:01,  3.92s/it]

Step 30 | Loss: 62.2528


Epoch 1:   1%|▏         | 40/3064 [02:38<3:16:13,  3.89s/it]

Step 40 | Loss: 56.6998


Epoch 1:   2%|▏         | 50/3064 [03:18<3:25:31,  4.09s/it]

Step 50 | Loss: 51.9858


Epoch 1:   2%|▏         | 60/3064 [03:59<3:22:35,  4.05s/it]

Step 60 | Loss: 44.3808


Epoch 1:   2%|▏         | 70/3064 [04:36<3:00:20,  3.61s/it]

Step 70 | Loss: 42.6090


Epoch 1:   3%|▎         | 80/3064 [05:15<3:18:48,  4.00s/it]

Step 80 | Loss: 41.3978


Epoch 1:   3%|▎         | 90/3064 [05:55<3:17:43,  3.99s/it]

Step 90 | Loss: 38.1525


Epoch 1:   3%|▎         | 100/3064 [06:36<3:19:47,  4.04s/it]

Step 100 | Loss: 36.1135


Epoch 1:   4%|▎         | 110/3064 [07:16<3:19:11,  4.05s/it]

Step 110 | Loss: 38.1725


Epoch 1:   4%|▍         | 120/3064 [07:57<3:20:33,  4.09s/it]

Step 120 | Loss: 36.2848


Epoch 1:   4%|▍         | 130/3064 [08:37<3:17:13,  4.03s/it]

Step 130 | Loss: 35.2759


Epoch 1:   5%|▍         | 140/3064 [09:18<3:14:57,  4.00s/it]

Step 140 | Loss: 31.8651


Epoch 1:   5%|▍         | 150/3064 [09:57<3:12:17,  3.96s/it]

Step 150 | Loss: 32.6662


Epoch 1:   5%|▌         | 160/3064 [10:37<3:11:22,  3.95s/it]

Step 160 | Loss: 31.7683


Epoch 1:   6%|▌         | 170/3064 [11:16<3:10:53,  3.96s/it]

Step 170 | Loss: 32.4497


Epoch 1:   6%|▌         | 180/3064 [11:56<3:10:18,  3.96s/it]

Step 180 | Loss: 35.2662


Epoch 1:   6%|▌         | 190/3064 [12:35<3:08:25,  3.93s/it]

Step 190 | Loss: 41.7816


Epoch 1:   7%|▋         | 200/3064 [13:15<3:06:20,  3.90s/it]

Step 200 | Loss: 56.4755


Epoch 1:   7%|▋         | 210/3064 [13:52<2:58:21,  3.75s/it]

Step 210 | Loss: 47.5517


Epoch 1:   7%|▋         | 220/3064 [14:29<2:55:57,  3.71s/it]

Step 220 | Loss: 46.0107


Epoch 1:   8%|▊         | 230/3064 [15:07<2:56:29,  3.74s/it]

Step 230 | Loss: 38.3056


Epoch 1:   8%|▊         | 240/3064 [15:44<2:56:39,  3.75s/it]

Step 240 | Loss: 34.4156


Epoch 1:   8%|▊         | 250/3064 [16:24<3:07:40,  4.00s/it]

Step 250 | Loss: 31.2988


Epoch 1:   8%|▊         | 260/3064 [17:04<3:07:11,  4.01s/it]

Step 260 | Loss: 33.4305


Epoch 1:   9%|▉         | 270/3064 [17:44<3:05:13,  3.98s/it]

Step 270 | Loss: 29.8799


Epoch 1:   9%|▉         | 280/3064 [18:24<3:02:54,  3.94s/it]

Step 280 | Loss: 28.6977


Epoch 1:   9%|▉         | 290/3064 [19:03<3:00:51,  3.91s/it]

Step 290 | Loss: 27.1424


Epoch 1:  10%|▉         | 300/3064 [19:42<3:03:02,  3.97s/it]

Step 300 | Loss: 25.6652


Epoch 1:  10%|█         | 310/3064 [20:21<2:55:55,  3.83s/it]

Step 310 | Loss: 26.2152


Epoch 1:  10%|█         | 320/3064 [20:58<2:49:04,  3.70s/it]

Step 320 | Loss: 25.8543


Epoch 1:  11%|█         | 330/3064 [21:37<2:58:24,  3.92s/it]

Step 330 | Loss: 24.0223


Epoch 1:  11%|█         | 340/3064 [22:18<3:03:07,  4.03s/it]

Step 340 | Loss: 22.9338


Epoch 1:  11%|█▏        | 350/3064 [22:58<3:01:50,  4.02s/it]

Step 350 | Loss: 21.6648


Epoch 1:  12%|█▏        | 360/3064 [23:38<2:58:16,  3.96s/it]

Step 360 | Loss: 21.3272


Epoch 1:  12%|█▏        | 370/3064 [24:17<2:56:21,  3.93s/it]

Step 370 | Loss: 19.4314


Epoch 1:  12%|█▏        | 380/3064 [24:58<3:03:19,  4.10s/it]

Step 380 | Loss: 21.2947


Epoch 1:  13%|█▎        | 390/3064 [25:38<2:58:09,  4.00s/it]

Step 390 | Loss: 19.7870


Epoch 1:  13%|█▎        | 400/3064 [26:18<2:57:27,  4.00s/it]

Step 400 | Loss: 18.9394


Epoch 1:  13%|█▎        | 410/3064 [26:57<2:53:29,  3.92s/it]

Step 410 | Loss: 21.1292


Epoch 1:  14%|█▎        | 420/3064 [27:37<2:54:19,  3.96s/it]

Step 420 | Loss: 18.0310


Epoch 1:  14%|█▍        | 430/3064 [28:17<2:56:30,  4.02s/it]

Step 430 | Loss: 17.7071


Epoch 1:  14%|█▍        | 440/3064 [28:57<2:57:15,  4.05s/it]

Step 440 | Loss: 16.7239


Epoch 1:  15%|█▍        | 450/3064 [29:37<2:53:18,  3.98s/it]

Step 450 | Loss: 16.3932


Epoch 1:  15%|█▌        | 460/3064 [30:18<2:57:40,  4.09s/it]

Step 460 | Loss: 15.9107


Epoch 1:  15%|█▌        | 470/3064 [30:58<2:53:50,  4.02s/it]

Step 470 | Loss: 16.1501


Epoch 1:  16%|█▌        | 480/3064 [31:37<2:46:34,  3.87s/it]

Step 480 | Loss: 14.8370


Epoch 1:  16%|█▌        | 490/3064 [32:17<2:47:48,  3.91s/it]

Step 490 | Loss: 13.9316


Epoch 1:  16%|█▋        | 499/3064 [32:53<2:50:45,  3.99s/it]

Step 500 | Loss: 14.2983


Epoch 1:  16%|█▋        | 500/3064 [32:58<3:10:38,  4.46s/it]

Saved checkpoint at step 500 to checkpoints\checkpoint_step500.pt


Epoch 1:  17%|█▋        | 510/3064 [33:38<2:47:40,  3.94s/it]

Step 510 | Loss: 14.3664


Epoch 1:  17%|█▋        | 520/3064 [34:18<2:50:38,  4.02s/it]

Step 520 | Loss: 14.3546


Epoch 1:  17%|█▋        | 530/3064 [34:58<2:47:55,  3.98s/it]

Step 530 | Loss: 14.2652


Epoch 1:  18%|█▊        | 540/3064 [35:38<2:47:38,  3.99s/it]

Step 540 | Loss: 13.5440


Epoch 1:  18%|█▊        | 550/3064 [36:17<2:45:27,  3.95s/it]

Step 550 | Loss: 12.2626


Epoch 1:  18%|█▊        | 560/3064 [36:56<2:43:20,  3.91s/it]

Step 560 | Loss: 12.2340


Epoch 1:  19%|█▊        | 570/3064 [37:36<2:42:54,  3.92s/it]

Step 570 | Loss: 14.5083


Epoch 1:  19%|█▉        | 580/3064 [38:16<2:45:34,  4.00s/it]

Step 580 | Loss: 12.8803


Epoch 1:  19%|█▉        | 590/3064 [38:55<2:43:05,  3.96s/it]

Step 590 | Loss: 12.7288


Epoch 1:  20%|█▉        | 600/3064 [39:35<2:44:23,  4.00s/it]

Step 600 | Loss: 12.0829


Epoch 1:  20%|█▉        | 610/3064 [40:15<2:41:30,  3.95s/it]

Step 610 | Loss: 13.1426


Epoch 1:  20%|██        | 620/3064 [40:54<2:41:15,  3.96s/it]

Step 620 | Loss: 12.4486


Epoch 1:  21%|██        | 630/3064 [41:34<2:41:36,  3.98s/it]

Step 630 | Loss: 12.4094


Epoch 1:  21%|██        | 640/3064 [42:14<2:42:06,  4.01s/it]

Step 640 | Loss: 11.5269


Epoch 1:  21%|██        | 650/3064 [42:53<2:36:26,  3.89s/it]

Step 650 | Loss: 11.0071


Epoch 1:  22%|██▏       | 660/3064 [43:32<2:37:18,  3.93s/it]

Step 660 | Loss: 10.8692


Epoch 1:  22%|██▏       | 670/3064 [44:12<2:36:52,  3.93s/it]

Step 670 | Loss: 11.7278


Epoch 1:  22%|██▏       | 680/3064 [44:52<2:39:36,  4.02s/it]

Step 680 | Loss: 11.1604


Epoch 1:  23%|██▎       | 690/3064 [45:32<2:39:49,  4.04s/it]

Step 690 | Loss: 10.7901


Epoch 1:  23%|██▎       | 700/3064 [46:12<2:38:32,  4.02s/it]

Step 700 | Loss: 10.4691


Epoch 1:  23%|██▎       | 710/3064 [46:52<2:37:59,  4.03s/it]

Step 710 | Loss: 9.9160


Epoch 1:  23%|██▎       | 720/3064 [47:33<2:37:37,  4.03s/it]

Step 720 | Loss: 10.8105


Epoch 1:  24%|██▍       | 730/3064 [48:13<2:36:39,  4.03s/it]

Step 730 | Loss: 9.7232


Epoch 1:  24%|██▍       | 740/3064 [48:53<2:35:09,  4.01s/it]

Step 740 | Loss: 10.1064


Epoch 1:  24%|██▍       | 750/3064 [49:33<2:33:30,  3.98s/it]

Step 750 | Loss: 10.2796


Epoch 1:  25%|██▍       | 760/3064 [50:14<2:35:49,  4.06s/it]

Step 760 | Loss: 9.7446


Epoch 1:  25%|██▌       | 770/3064 [50:54<2:34:13,  4.03s/it]

Step 770 | Loss: 9.6297


Epoch 1:  25%|██▌       | 780/3064 [51:34<2:30:56,  3.97s/it]

Step 780 | Loss: 10.7741


Epoch 1:  26%|██▌       | 790/3064 [52:13<2:29:49,  3.95s/it]

Step 790 | Loss: 8.9917


Epoch 1:  26%|██▌       | 800/3064 [52:53<2:30:52,  4.00s/it]

Step 800 | Loss: 10.5022


Epoch 1:  26%|██▋       | 810/3064 [53:33<2:28:12,  3.95s/it]

Step 810 | Loss: 10.3257


Epoch 1:  27%|██▋       | 820/3064 [54:13<2:29:56,  4.01s/it]

Step 820 | Loss: 9.7587


Epoch 1:  27%|██▋       | 830/3064 [54:52<2:24:38,  3.88s/it]

Step 830 | Loss: 9.3172


Epoch 1:  27%|██▋       | 840/3064 [55:31<2:27:14,  3.97s/it]

Step 840 | Loss: 9.3008


Epoch 1:  28%|██▊       | 850/3064 [56:12<2:28:58,  4.04s/it]

Step 850 | Loss: 9.2904


Epoch 1:  28%|██▊       | 860/3064 [56:51<2:25:50,  3.97s/it]

Step 860 | Loss: 9.2061


Epoch 1:  28%|██▊       | 870/3064 [57:31<2:24:45,  3.96s/it]

Step 870 | Loss: 9.1939


Epoch 1:  29%|██▊       | 880/3064 [58:11<2:24:20,  3.97s/it]

Step 880 | Loss: 8.9093


Epoch 1:  29%|██▉       | 890/3064 [58:51<2:26:11,  4.03s/it]

Step 890 | Loss: 8.7405


Epoch 1:  29%|██▉       | 900/3064 [59:31<2:26:10,  4.05s/it]

Step 900 | Loss: 8.6250


Epoch 1:  30%|██▉       | 910/3064 [1:00:12<2:25:35,  4.06s/it]

Step 910 | Loss: 9.0457


Epoch 1:  30%|███       | 920/3064 [1:00:52<2:22:54,  4.00s/it]

Step 920 | Loss: 8.4804


Epoch 1:  30%|███       | 930/3064 [1:01:32<2:20:11,  3.94s/it]

Step 930 | Loss: 8.6237


Epoch 1:  31%|███       | 940/3064 [1:02:12<2:21:19,  3.99s/it]

Step 940 | Loss: 8.3545


Epoch 1:  31%|███       | 950/3064 [1:02:51<2:19:15,  3.95s/it]

Step 950 | Loss: 8.5394


Epoch 1:  31%|███▏      | 960/3064 [1:03:31<2:18:59,  3.96s/it]

Step 960 | Loss: 9.4905


Epoch 1:  32%|███▏      | 970/3064 [1:04:11<2:18:22,  3.96s/it]

Step 970 | Loss: 8.9855


Epoch 1:  32%|███▏      | 980/3064 [1:04:51<2:19:22,  4.01s/it]

Step 980 | Loss: 8.4387


Epoch 1:  32%|███▏      | 990/3064 [1:05:31<2:16:12,  3.94s/it]

Step 990 | Loss: 8.7020


Epoch 1:  33%|███▎      | 999/3064 [1:06:06<2:16:17,  3.96s/it]

Step 1000 | Loss: 8.5645


Epoch 1:  33%|███▎      | 1000/3064 [1:06:11<2:24:44,  4.21s/it]

Saved checkpoint at step 1000 to checkpoints\checkpoint_step1000.pt


Epoch 1:  33%|███▎      | 1010/3064 [1:06:51<2:16:27,  3.99s/it]

Step 1010 | Loss: 8.8646


Epoch 1:  33%|███▎      | 1020/3064 [1:07:31<2:15:45,  3.99s/it]

Step 1020 | Loss: 8.1267


Epoch 1:  34%|███▎      | 1030/3064 [1:08:10<2:14:49,  3.98s/it]

Step 1030 | Loss: 8.7793


Epoch 1:  34%|███▍      | 1040/3064 [1:08:50<2:12:10,  3.92s/it]

Step 1040 | Loss: 8.1898


Epoch 1:  34%|███▍      | 1050/3064 [1:09:29<2:13:14,  3.97s/it]

Step 1050 | Loss: 8.0832


Epoch 1:  35%|███▍      | 1060/3064 [1:10:09<2:14:12,  4.02s/it]

Step 1060 | Loss: 8.8684


Epoch 1:  35%|███▍      | 1070/3064 [1:10:48<2:11:10,  3.95s/it]

Step 1070 | Loss: 8.4862


Epoch 1:  35%|███▌      | 1080/3064 [1:11:28<2:12:13,  4.00s/it]

Step 1080 | Loss: 8.9549


Epoch 1:  36%|███▌      | 1090/3064 [1:12:08<2:10:57,  3.98s/it]

Step 1090 | Loss: 7.8695


Epoch 1:  36%|███▌      | 1100/3064 [1:12:47<2:06:36,  3.87s/it]

Step 1100 | Loss: 8.1988


Epoch 1:  36%|███▌      | 1110/3064 [1:13:26<2:08:16,  3.94s/it]

Step 1110 | Loss: 7.8141


Epoch 1:  37%|███▋      | 1120/3064 [1:14:06<2:08:14,  3.96s/it]

Step 1120 | Loss: 8.1913


Epoch 1:  37%|███▋      | 1130/3064 [1:14:46<2:09:30,  4.02s/it]

Step 1130 | Loss: 8.0931


Epoch 1:  37%|███▋      | 1140/3064 [1:15:26<2:07:10,  3.97s/it]

Step 1140 | Loss: 8.3402


Epoch 1:  38%|███▊      | 1150/3064 [1:16:06<2:08:45,  4.04s/it]

Step 1150 | Loss: 7.5239


Epoch 1:  38%|███▊      | 1160/3064 [1:16:46<2:05:32,  3.96s/it]

Step 1160 | Loss: 8.3984


Epoch 1:  38%|███▊      | 1170/3064 [1:17:25<2:05:50,  3.99s/it]

Step 1170 | Loss: 8.6982


Epoch 1:  39%|███▊      | 1180/3064 [1:18:04<2:03:11,  3.92s/it]

Step 1180 | Loss: 7.7060


Epoch 1:  39%|███▉      | 1190/3064 [1:18:44<2:02:55,  3.94s/it]

Step 1190 | Loss: 8.0481


Epoch 1:  39%|███▉      | 1200/3064 [1:19:23<2:01:22,  3.91s/it]

Step 1200 | Loss: 7.6809


Epoch 1:  39%|███▉      | 1210/3064 [1:20:02<2:03:19,  3.99s/it]

Step 1210 | Loss: 7.9577


Epoch 1:  40%|███▉      | 1220/3064 [1:20:42<2:00:08,  3.91s/it]

Step 1220 | Loss: 8.5304


Epoch 1:  40%|████      | 1230/3064 [1:21:22<2:02:02,  3.99s/it]

Step 1230 | Loss: 7.8113


Epoch 1:  40%|████      | 1240/3064 [1:22:02<2:01:17,  3.99s/it]

Step 1240 | Loss: 7.3457


Epoch 1:  41%|████      | 1250/3064 [1:22:42<2:00:16,  3.98s/it]

Step 1250 | Loss: 8.1462


Epoch 1:  41%|████      | 1260/3064 [1:23:22<2:02:10,  4.06s/it]

Step 1260 | Loss: 7.8605


Epoch 1:  41%|████▏     | 1270/3064 [1:24:02<1:59:15,  3.99s/it]

Step 1270 | Loss: 7.6542


Epoch 1:  42%|████▏     | 1280/3064 [1:24:41<1:55:47,  3.89s/it]

Step 1280 | Loss: 7.5151


Epoch 1:  42%|████▏     | 1290/3064 [1:25:21<1:58:02,  3.99s/it]

Step 1290 | Loss: 7.6724


Epoch 1:  42%|████▏     | 1300/3064 [1:26:01<1:55:38,  3.93s/it]

Step 1300 | Loss: 7.4819


Epoch 1:  43%|████▎     | 1310/3064 [1:26:40<1:54:38,  3.92s/it]

Step 1310 | Loss: 7.5954


Epoch 1:  43%|████▎     | 1320/3064 [1:27:19<1:53:30,  3.91s/it]

Step 1320 | Loss: 7.3919


Epoch 1:  43%|████▎     | 1330/3064 [1:27:59<1:52:34,  3.90s/it]

Step 1330 | Loss: 7.6396


Epoch 1:  44%|████▎     | 1340/3064 [1:28:38<1:53:24,  3.95s/it]

Step 1340 | Loss: 7.6972


Epoch 1:  44%|████▍     | 1350/3064 [1:29:18<1:54:47,  4.02s/it]

Step 1350 | Loss: 7.6899


Epoch 1:  44%|████▍     | 1360/3064 [1:29:58<1:53:44,  4.01s/it]

Step 1360 | Loss: 7.5126


Epoch 1:  45%|████▍     | 1370/3064 [1:30:39<1:54:00,  4.04s/it]

Step 1370 | Loss: 7.6991


Epoch 1:  45%|████▌     | 1380/3064 [1:31:19<1:52:22,  4.00s/it]

Step 1380 | Loss: 7.3269


Epoch 1:  45%|████▌     | 1390/3064 [1:31:59<1:52:01,  4.02s/it]

Step 1390 | Loss: 7.6952


Epoch 1:  46%|████▌     | 1400/3064 [1:32:39<1:50:26,  3.98s/it]

Step 1400 | Loss: 7.7168


Epoch 1:  46%|████▌     | 1410/3064 [1:33:18<1:48:00,  3.92s/it]

Step 1410 | Loss: 7.4265


Epoch 1:  46%|████▋     | 1420/3064 [1:33:58<1:50:14,  4.02s/it]

Step 1420 | Loss: 7.8007


Epoch 1:  47%|████▋     | 1430/3064 [1:34:38<1:48:38,  3.99s/it]

Step 1430 | Loss: 7.3482


Epoch 1:  47%|████▋     | 1440/3064 [1:35:18<1:49:22,  4.04s/it]

Step 1440 | Loss: 7.6020


Epoch 1:  47%|████▋     | 1450/3064 [1:35:59<1:48:41,  4.04s/it]

Step 1450 | Loss: 7.1015


Epoch 1:  48%|████▊     | 1460/3064 [1:36:38<1:45:20,  3.94s/it]

Step 1460 | Loss: 7.1860


Epoch 1:  48%|████▊     | 1470/3064 [1:37:19<1:47:32,  4.05s/it]

Step 1470 | Loss: 7.4721


Epoch 1:  48%|████▊     | 1480/3064 [1:37:58<1:41:27,  3.84s/it]

Step 1480 | Loss: 7.5433


Epoch 1:  49%|████▊     | 1490/3064 [1:38:35<1:36:36,  3.68s/it]

Step 1490 | Loss: 7.3938


Epoch 1:  49%|████▉     | 1499/3064 [1:39:09<1:36:22,  3.70s/it]

Step 1500 | Loss: 6.9832


Epoch 1:  49%|████▉     | 1500/3064 [1:39:13<1:44:57,  4.03s/it]

Saved checkpoint at step 1500 to checkpoints\checkpoint_step1500.pt


Epoch 1:  49%|████▉     | 1510/3064 [1:39:50<1:36:03,  3.71s/it]

Step 1510 | Loss: 7.2346


Epoch 1:  50%|████▉     | 1520/3064 [1:40:27<1:34:32,  3.67s/it]

Step 1520 | Loss: 7.3264


Epoch 1:  50%|████▉     | 1530/3064 [1:42:09<4:27:37, 10.47s/it]

Step 1530 | Loss: 7.2673


Epoch 1:  50%|█████     | 1540/3064 [1:42:49<1:46:23,  4.19s/it]

Step 1540 | Loss: 7.4743


Epoch 1:  51%|█████     | 1550/3064 [1:43:28<1:39:50,  3.96s/it]

Step 1550 | Loss: 7.5079


Epoch 1:  51%|█████     | 1560/3064 [1:44:08<1:38:40,  3.94s/it]

Step 1560 | Loss: 7.2950


Epoch 1:  51%|█████     | 1570/3064 [1:44:47<1:38:03,  3.94s/it]

Step 1570 | Loss: 6.9988


Epoch 1:  52%|█████▏    | 1580/3064 [1:45:27<1:37:00,  3.92s/it]

Step 1580 | Loss: 7.6163


Epoch 1:  52%|█████▏    | 1590/3064 [1:46:06<1:37:43,  3.98s/it]

Step 1590 | Loss: 7.1000


Epoch 1:  52%|█████▏    | 1600/3064 [1:46:46<1:37:07,  3.98s/it]

Step 1600 | Loss: 7.4431


Epoch 1:  53%|█████▎    | 1610/3064 [1:47:26<1:36:27,  3.98s/it]

Step 1610 | Loss: 7.3124


Epoch 1:  53%|█████▎    | 1620/3064 [1:48:05<1:33:55,  3.90s/it]

Step 1620 | Loss: 7.3745


Epoch 1:  53%|█████▎    | 1630/3064 [1:48:45<1:36:41,  4.05s/it]

Step 1630 | Loss: 7.2413


Epoch 1:  54%|█████▎    | 1640/3064 [1:49:25<1:35:19,  4.02s/it]

Step 1640 | Loss: 7.6381


Epoch 1:  54%|█████▍    | 1650/3064 [1:50:05<1:33:29,  3.97s/it]

Step 1650 | Loss: 7.5202


Epoch 1:  54%|█████▍    | 1660/3064 [1:50:45<1:32:58,  3.97s/it]

Step 1660 | Loss: 7.1256


Epoch 1:  55%|█████▍    | 1670/3064 [1:51:24<1:31:45,  3.95s/it]

Step 1670 | Loss: 7.1618


Epoch 1:  55%|█████▍    | 1680/3064 [1:52:05<1:33:26,  4.05s/it]

Step 1680 | Loss: 7.5799


Epoch 1:  55%|█████▌    | 1690/3064 [1:52:45<1:32:33,  4.04s/it]

Step 1690 | Loss: 7.0638


Epoch 1:  55%|█████▌    | 1700/3064 [1:53:25<1:31:11,  4.01s/it]

Step 1700 | Loss: 7.1949


Epoch 1:  56%|█████▌    | 1710/3064 [1:54:05<1:30:31,  4.01s/it]

Step 1710 | Loss: 7.0703


Epoch 1:  56%|█████▌    | 1720/3064 [1:54:45<1:30:01,  4.02s/it]

Step 1720 | Loss: 7.4250


Epoch 1:  56%|█████▋    | 1730/3064 [1:55:24<1:26:27,  3.89s/it]

Step 1730 | Loss: 7.1787


Epoch 1:  57%|█████▋    | 1740/3064 [1:56:04<1:24:57,  3.85s/it]

Step 1740 | Loss: 7.2663


Epoch 1:  57%|█████▋    | 1750/3064 [1:56:42<1:24:44,  3.87s/it]

Step 1750 | Loss: 7.1353


Epoch 1:  57%|█████▋    | 1760/3064 [1:57:20<1:21:05,  3.73s/it]

Step 1760 | Loss: 7.2555


Epoch 1:  58%|█████▊    | 1770/3064 [1:57:59<1:23:41,  3.88s/it]

Step 1770 | Loss: 7.0368


Epoch 1:  58%|█████▊    | 1780/3064 [1:58:38<1:23:51,  3.92s/it]

Step 1780 | Loss: 7.2614


Epoch 1:  58%|█████▊    | 1790/3064 [1:59:17<1:22:42,  3.90s/it]

Step 1790 | Loss: 6.8321


Epoch 1:  59%|█████▊    | 1800/3064 [1:59:56<1:21:46,  3.88s/it]

Step 1800 | Loss: 7.1141


Epoch 1:  59%|█████▉    | 1810/3064 [2:00:35<1:20:25,  3.85s/it]

Step 1810 | Loss: 7.1938


Epoch 1:  59%|█████▉    | 1820/3064 [2:01:14<1:20:22,  3.88s/it]

Step 1820 | Loss: 7.2435


Epoch 1:  60%|█████▉    | 1830/3064 [2:01:53<1:21:22,  3.96s/it]

Step 1830 | Loss: 6.8577


Epoch 1:  60%|██████    | 1840/3064 [2:02:33<1:21:05,  3.98s/it]

Step 1840 | Loss: 7.1604


Epoch 1:  60%|██████    | 1850/3064 [2:03:12<1:19:01,  3.91s/it]

Step 1850 | Loss: 7.0240


Epoch 1:  61%|██████    | 1860/3064 [2:03:51<1:19:36,  3.97s/it]

Step 1860 | Loss: 6.9987


Epoch 1:  61%|██████    | 1870/3064 [2:04:31<1:19:02,  3.97s/it]

Step 1870 | Loss: 7.4055


Epoch 1:  61%|██████▏   | 1880/3064 [2:05:10<1:16:59,  3.90s/it]

Step 1880 | Loss: 7.1225


Epoch 1:  62%|██████▏   | 1890/3064 [2:05:50<1:17:22,  3.95s/it]

Step 1890 | Loss: 7.0019


Epoch 1:  62%|██████▏   | 1900/3064 [2:06:27<1:12:36,  3.74s/it]

Step 1900 | Loss: 6.9823


Epoch 1:  62%|██████▏   | 1910/3064 [2:07:04<1:11:31,  3.72s/it]

Step 1910 | Loss: 6.9294


Epoch 1:  63%|██████▎   | 1920/3064 [2:07:42<1:12:28,  3.80s/it]

Step 1920 | Loss: 6.8901


Epoch 1:  63%|██████▎   | 1930/3064 [2:09:08<1:25:21,  4.52s/it]

Step 1930 | Loss: 7.1839


Epoch 1:  63%|██████▎   | 1940/3064 [2:09:47<1:14:07,  3.96s/it]

Step 1940 | Loss: 7.0447


Epoch 1:  64%|██████▎   | 1950/3064 [2:10:26<1:11:52,  3.87s/it]

Step 1950 | Loss: 6.9677


Epoch 1:  64%|██████▍   | 1960/3064 [2:11:06<1:14:15,  4.04s/it]

Step 1960 | Loss: 7.1734


Epoch 1:  64%|██████▍   | 1970/3064 [2:11:47<1:13:19,  4.02s/it]

Step 1970 | Loss: 6.9977


Epoch 1:  65%|██████▍   | 1980/3064 [2:12:27<1:11:47,  3.97s/it]

Step 1980 | Loss: 6.9903


Epoch 1:  65%|██████▍   | 1990/3064 [2:13:07<1:11:41,  4.01s/it]

Step 1990 | Loss: 6.9306


Epoch 1:  65%|██████▌   | 1999/3064 [2:13:43<1:11:48,  4.05s/it]

Step 2000 | Loss: 6.8639


Epoch 1:  65%|██████▌   | 2000/3064 [2:13:48<1:17:06,  4.35s/it]

Saved checkpoint at step 2000 to checkpoints\checkpoint_step2000.pt


Epoch 1:  66%|██████▌   | 2010/3064 [2:14:29<1:10:15,  4.00s/it]

Step 2010 | Loss: 6.9828


Epoch 1:  66%|██████▌   | 2020/3064 [2:15:08<1:08:18,  3.93s/it]

Step 2020 | Loss: 6.7955


Epoch 1:  66%|██████▋   | 2030/3064 [2:15:47<1:07:34,  3.92s/it]

Step 2030 | Loss: 6.9240


Epoch 1:  67%|██████▋   | 2040/3064 [2:16:27<1:07:15,  3.94s/it]

Step 2040 | Loss: 7.1710


Epoch 1:  67%|██████▋   | 2050/3064 [2:17:06<1:07:06,  3.97s/it]

Step 2050 | Loss: 6.9247


Epoch 1:  67%|██████▋   | 2060/3064 [2:17:45<1:05:19,  3.90s/it]

Step 2060 | Loss: 6.9379


Epoch 1:  68%|██████▊   | 2070/3064 [2:18:25<1:04:53,  3.92s/it]

Step 2070 | Loss: 7.0991


Epoch 1:  68%|██████▊   | 2080/3064 [2:19:03<1:02:33,  3.81s/it]

Step 2080 | Loss: 6.8762


Epoch 1:  68%|██████▊   | 2090/3064 [2:19:42<1:04:03,  3.95s/it]

Step 2090 | Loss: 6.7362


Epoch 1:  69%|██████▊   | 2100/3064 [2:20:22<1:04:24,  4.01s/it]

Step 2100 | Loss: 6.9443


Epoch 1:  69%|██████▉   | 2110/3064 [2:21:02<1:02:37,  3.94s/it]

Step 2110 | Loss: 6.8705


Epoch 1:  69%|██████▉   | 2120/3064 [2:21:41<1:01:48,  3.93s/it]

Step 2120 | Loss: 6.7116


Epoch 1:  70%|██████▉   | 2130/3064 [2:22:20<1:01:37,  3.96s/it]

Step 2130 | Loss: 6.8341


Epoch 1:  70%|██████▉   | 2140/3064 [2:23:00<1:01:25,  3.99s/it]

Step 2140 | Loss: 6.9017


Epoch 1:  70%|███████   | 2150/3064 [2:23:39<1:00:21,  3.96s/it]

Step 2150 | Loss: 7.1808


Epoch 1:  70%|███████   | 2160/3064 [2:24:19<1:00:09,  3.99s/it]

Step 2160 | Loss: 6.7900


Epoch 1:  71%|███████   | 2170/3064 [2:24:59<58:58,  3.96s/it]  

Step 2170 | Loss: 7.1484


Epoch 1:  71%|███████   | 2180/3064 [2:25:38<57:48,  3.92s/it]

Step 2180 | Loss: 7.0465


Epoch 1:  71%|███████▏  | 2190/3064 [2:26:18<58:14,  4.00s/it]

Step 2190 | Loss: 7.0085


Epoch 1:  72%|███████▏  | 2200/3064 [2:26:58<57:32,  4.00s/it]

Step 2200 | Loss: 6.7521


Epoch 1:  72%|███████▏  | 2210/3064 [2:27:38<55:30,  3.90s/it]

Step 2210 | Loss: 7.0240


Epoch 1:  72%|███████▏  | 2220/3064 [2:28:17<56:00,  3.98s/it]

Step 2220 | Loss: 6.8689


Epoch 1:  73%|███████▎  | 2230/3064 [2:28:57<55:36,  4.00s/it]

Step 2230 | Loss: 7.0830


Epoch 1:  73%|███████▎  | 2240/3064 [2:29:37<55:24,  4.03s/it]

Step 2240 | Loss: 6.6976


Epoch 1:  73%|███████▎  | 2250/3064 [2:30:18<54:49,  4.04s/it]

Step 2250 | Loss: 6.6619


Epoch 1:  74%|███████▍  | 2260/3064 [2:30:58<53:46,  4.01s/it]

Step 2260 | Loss: 6.9001


Epoch 1:  74%|███████▍  | 2270/3064 [2:31:38<52:31,  3.97s/it]

Step 2270 | Loss: 6.7624


Epoch 1:  74%|███████▍  | 2280/3064 [2:32:18<52:12,  4.00s/it]

Step 2280 | Loss: 7.0919


Epoch 1:  75%|███████▍  | 2290/3064 [2:32:57<50:57,  3.95s/it]

Step 2290 | Loss: 6.6708


Epoch 1:  75%|███████▌  | 2300/3064 [2:33:36<49:13,  3.87s/it]

Step 2300 | Loss: 6.8664


Epoch 1:  75%|███████▌  | 2310/3064 [2:34:16<49:38,  3.95s/it]

Step 2310 | Loss: 6.7764


Epoch 1:  76%|███████▌  | 2320/3064 [2:34:56<49:22,  3.98s/it]

Step 2320 | Loss: 6.7028


Epoch 1:  76%|███████▌  | 2330/3064 [2:35:36<48:53,  4.00s/it]

Step 2330 | Loss: 7.0098


Epoch 1:  76%|███████▋  | 2340/3064 [2:36:16<48:21,  4.01s/it]

Step 2340 | Loss: 6.7581


Epoch 1:  77%|███████▋  | 2350/3064 [2:36:55<46:24,  3.90s/it]

Step 2350 | Loss: 6.9110


Epoch 1:  77%|███████▋  | 2360/3064 [2:37:34<44:53,  3.83s/it]

Step 2360 | Loss: 7.1373


Epoch 1:  77%|███████▋  | 2370/3064 [2:38:13<44:24,  3.84s/it]

Step 2370 | Loss: 6.9893


Epoch 1:  78%|███████▊  | 2380/3064 [2:38:53<45:36,  4.00s/it]

Step 2380 | Loss: 6.6471


Epoch 1:  78%|███████▊  | 2390/3064 [2:39:32<43:27,  3.87s/it]

Step 2390 | Loss: 6.8735


Epoch 1:  78%|███████▊  | 2400/3064 [2:40:10<41:49,  3.78s/it]

Step 2400 | Loss: 7.0051


Epoch 1:  79%|███████▊  | 2410/3064 [2:40:49<42:31,  3.90s/it]

Step 2410 | Loss: 6.8386


Epoch 1:  79%|███████▉  | 2420/3064 [2:41:29<43:09,  4.02s/it]

Step 2420 | Loss: 6.7546


Epoch 1:  79%|███████▉  | 2430/3064 [2:42:09<41:43,  3.95s/it]

Step 2430 | Loss: 6.8091


Epoch 1:  80%|███████▉  | 2440/3064 [2:42:48<40:31,  3.90s/it]

Step 2440 | Loss: 6.7990


Epoch 1:  80%|███████▉  | 2450/3064 [2:43:27<38:56,  3.81s/it]

Step 2450 | Loss: 6.6950


Epoch 1:  80%|████████  | 2460/3064 [2:44:07<40:35,  4.03s/it]

Step 2460 | Loss: 6.9773


Epoch 1:  81%|████████  | 2470/3064 [2:44:47<39:35,  4.00s/it]

Step 2470 | Loss: 6.6666


Epoch 1:  81%|████████  | 2480/3064 [2:45:27<38:53,  4.00s/it]

Step 2480 | Loss: 6.8946


Epoch 1:  81%|████████▏ | 2490/3064 [2:46:07<38:08,  3.99s/it]

Step 2490 | Loss: 6.8015


Epoch 1:  82%|████████▏ | 2499/3064 [2:46:43<37:35,  3.99s/it]

Step 2500 | Loss: 6.5987


Epoch 1:  82%|████████▏ | 2500/3064 [2:46:48<39:53,  4.24s/it]

Saved checkpoint at step 2500 to checkpoints\checkpoint_step2500.pt


Epoch 1:  82%|████████▏ | 2510/3064 [2:47:27<36:58,  4.00s/it]

Step 2510 | Loss: 6.6285


Epoch 1:  82%|████████▏ | 2520/3064 [2:48:07<36:07,  3.98s/it]

Step 2520 | Loss: 6.8079


Epoch 1:  83%|████████▎ | 2530/3064 [2:48:47<35:25,  3.98s/it]

Step 2530 | Loss: 6.5108


Epoch 1:  83%|████████▎ | 2540/3064 [2:49:27<34:39,  3.97s/it]

Step 2540 | Loss: 6.7026


Epoch 1:  83%|████████▎ | 2550/3064 [2:50:06<33:30,  3.91s/it]

Step 2550 | Loss: 6.5410


Epoch 1:  84%|████████▎ | 2560/3064 [2:50:46<33:36,  4.00s/it]

Step 2560 | Loss: 6.5251


Epoch 1:  84%|████████▍ | 2570/3064 [2:51:26<32:55,  4.00s/it]

Step 2570 | Loss: 6.8613


Epoch 1:  84%|████████▍ | 2580/3064 [2:52:05<30:36,  3.79s/it]

Step 2580 | Loss: 6.3567


Epoch 1:  85%|████████▍ | 2590/3064 [2:52:43<31:08,  3.94s/it]

Step 2590 | Loss: 6.7738


Epoch 1:  85%|████████▍ | 2600/3064 [2:53:22<30:17,  3.92s/it]

Step 2600 | Loss: 6.8552


Epoch 1:  85%|████████▌ | 2610/3064 [2:54:01<29:33,  3.91s/it]

Step 2610 | Loss: 6.6231


Epoch 1:  86%|████████▌ | 2620/3064 [2:54:40<29:00,  3.92s/it]

Step 2620 | Loss: 6.6077


Epoch 1:  86%|████████▌ | 2630/3064 [2:55:19<27:44,  3.84s/it]

Step 2630 | Loss: 6.7037


Epoch 1:  86%|████████▌ | 2640/3064 [2:55:58<27:51,  3.94s/it]

Step 2640 | Loss: 6.7132


Epoch 1:  86%|████████▋ | 2650/3064 [2:56:37<26:04,  3.78s/it]

Step 2650 | Loss: 6.8017


Epoch 1:  87%|████████▋ | 2660/3064 [2:57:15<26:16,  3.90s/it]

Step 2660 | Loss: 6.6917


Epoch 1:  87%|████████▋ | 2670/3064 [2:57:54<25:29,  3.88s/it]

Step 2670 | Loss: 6.2368


Epoch 1:  87%|████████▋ | 2680/3064 [2:58:32<23:57,  3.74s/it]

Step 2680 | Loss: 6.8562


Epoch 1:  88%|████████▊ | 2690/3064 [2:59:10<23:58,  3.85s/it]

Step 2690 | Loss: 6.7689


Epoch 1:  88%|████████▊ | 2700/3064 [2:59:49<23:33,  3.88s/it]

Step 2700 | Loss: 6.8791


Epoch 1:  88%|████████▊ | 2710/3064 [3:00:28<22:52,  3.88s/it]

Step 2710 | Loss: 6.7700


Epoch 1:  89%|████████▉ | 2720/3064 [3:01:07<22:01,  3.84s/it]

Step 2720 | Loss: 6.9101


Epoch 1:  89%|████████▉ | 2730/3064 [3:01:47<21:40,  3.89s/it]

Step 2730 | Loss: 6.5708


Epoch 1:  89%|████████▉ | 2740/3064 [3:02:25<20:36,  3.82s/it]

Step 2740 | Loss: 6.4658


Epoch 1:  90%|████████▉ | 2750/3064 [3:03:04<19:45,  3.78s/it]

Step 2750 | Loss: 6.8030


Epoch 1:  90%|█████████ | 2760/3064 [3:03:42<19:06,  3.77s/it]

Step 2760 | Loss: 6.6775


Epoch 1:  90%|█████████ | 2770/3064 [3:04:22<19:24,  3.96s/it]

Step 2770 | Loss: 6.5502


Epoch 1:  91%|█████████ | 2780/3064 [3:05:01<18:48,  3.97s/it]

Step 2780 | Loss: 6.6682


Epoch 1:  91%|█████████ | 2790/3064 [3:05:41<18:09,  3.98s/it]

Step 2790 | Loss: 6.4375


Epoch 1:  91%|█████████▏| 2800/3064 [3:06:21<17:33,  3.99s/it]

Step 2800 | Loss: 6.5346


Epoch 1:  92%|█████████▏| 2810/3064 [3:07:01<16:48,  3.97s/it]

Step 2810 | Loss: 6.7375


Epoch 1:  92%|█████████▏| 2820/3064 [3:07:40<15:53,  3.91s/it]

Step 2820 | Loss: 6.3818


Epoch 1:  92%|█████████▏| 2830/3064 [3:08:19<15:19,  3.93s/it]

Step 2830 | Loss: 6.7206


Epoch 1:  93%|█████████▎| 2840/3064 [3:08:58<14:41,  3.93s/it]

Step 2840 | Loss: 6.7847


Epoch 1:  93%|█████████▎| 2850/3064 [3:09:36<13:21,  3.75s/it]

Step 2850 | Loss: 6.6212


Epoch 1:  93%|█████████▎| 2860/3064 [3:10:15<13:27,  3.96s/it]

Step 2860 | Loss: 6.8413


Epoch 1:  94%|█████████▎| 2870/3064 [3:10:55<12:57,  4.01s/it]

Step 2870 | Loss: 6.7283


Epoch 1:  94%|█████████▍| 2880/3064 [3:11:36<12:22,  4.03s/it]

Step 2880 | Loss: 7.1946


Epoch 1:  94%|█████████▍| 2890/3064 [3:12:16<11:41,  4.03s/it]

Step 2890 | Loss: 6.4891


Epoch 1:  95%|█████████▍| 2900/3064 [3:12:56<10:58,  4.02s/it]

Step 2900 | Loss: 6.6439


Epoch 1:  95%|█████████▍| 2910/3064 [3:13:36<10:21,  4.03s/it]

Step 2910 | Loss: 6.5029


Epoch 1:  95%|█████████▌| 2920/3064 [3:14:15<09:18,  3.88s/it]

Step 2920 | Loss: 6.7114


Epoch 1:  96%|█████████▌| 2930/3064 [3:14:54<08:44,  3.91s/it]

Step 2930 | Loss: 6.7245


Epoch 1:  96%|█████████▌| 2940/3064 [3:15:34<08:16,  4.00s/it]

Step 2940 | Loss: 6.6242


Epoch 1:  96%|█████████▋| 2950/3064 [3:16:14<07:36,  4.01s/it]

Step 2950 | Loss: 6.9582


Epoch 1:  97%|█████████▋| 2960/3064 [3:16:54<06:57,  4.01s/it]

Step 2960 | Loss: 6.8198


Epoch 1:  97%|█████████▋| 2970/3064 [3:17:34<06:16,  4.00s/it]

Step 2970 | Loss: 6.8696


Epoch 1:  97%|█████████▋| 2980/3064 [3:18:14<05:33,  3.97s/it]

Step 2980 | Loss: 6.4796


Epoch 1:  98%|█████████▊| 2990/3064 [3:18:52<04:40,  3.78s/it]

Step 2990 | Loss: 6.6068


Epoch 1:  98%|█████████▊| 2999/3064 [3:19:28<04:18,  3.98s/it]

Step 3000 | Loss: 6.8251


Epoch 1:  98%|█████████▊| 3000/3064 [3:19:32<04:34,  4.28s/it]

Saved checkpoint at step 3000 to checkpoints\checkpoint_step3000.pt


Epoch 1:  98%|█████████▊| 3010/3064 [3:20:13<03:37,  4.03s/it]

Step 3010 | Loss: 6.5430


Epoch 1:  99%|█████████▊| 3020/3064 [3:20:53<02:58,  4.05s/it]

Step 3020 | Loss: 6.5822


Epoch 1:  99%|█████████▉| 3030/3064 [3:21:32<02:08,  3.77s/it]

Step 3030 | Loss: 6.3616


Epoch 1:  99%|█████████▉| 3040/3064 [3:22:11<01:35,  3.97s/it]

Step 3040 | Loss: 6.5693


Epoch 1: 100%|█████████▉| 3050/3064 [3:22:49<00:54,  3.90s/it]

Step 3050 | Loss: 6.5972


Epoch 1: 100%|█████████▉| 3060/3064 [3:23:28<00:15,  3.86s/it]

Step 3060 | Loss: 6.8759


Epoch 1: 100%|██████████| 3064/3064 [3:23:41<00:00,  3.99s/it]


Epoch 1 complete | Avg Loss: 12.7187
Saved model at end of epoch 1 to checkpoints\model_epoch1.pt


In [10]:
torch.save(model.state_dict(), "nanollama_weights.pth")

In [15]:
import torch

def generate_autoregressive(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int = 50,
    eos_token_id: int = 3,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
    model.eval()
    input_ids = tokenizer.encode(prompt, return_tensor=True).to(device)
    
    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(input_ids)  # logits: [1, seq_len, vocab_size]
            logits = outputs[:, -1, :] / temperature  # take last token's logits
            if top_k is not None:
                top_k_logits, top_k_indices = torch.topk(logits, top_k)
                probs = torch.nn.functional.softmax(top_k_logits, dim=-1)
                next_token = top_k_indices[0, torch.multinomial(probs, 1)]
            else:
                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)

            # Squeeze the next_token and concatenate with input_ids
            next_token = next_token.squeeze(1)  # Remove extra dimension
            input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=1)  # Concatenate

            if next_token.item() == eos_token_id:
                break

    return tokenizer.decode(input_ids[0].tolist())


In [16]:
generated = generate_autoregressive(
    model=model,
    tokenizer=tokenizer,
    prompt="The future of artificial intelligence is",
    max_new_tokens=100,
    top_k=40
)

print(generated)


<s> The future of artificial intelligence is</s> as the S-Du-croy-nas-Ata in the Middle East" of many Roman states, a matter of the natural result in the early as a third-in-Ap-Doni or Staci of a result of a century, the Sayian, the Bot, or others, the first century, social structure or more visible matter of some to the species of a star, a very similar of the social social structure of the general
