In [None]:
import numpy as np
import pandas as pd
import tiktoken as tk
import torch
import torch.nn as nn
import torch.nn.functional as F

### `build_ids(data, add_eot=True) -> torch.Tensor`
- **Input:** string or list of strings  
- **Output:** 1D tensor of token IDs `[N]`  
- **Process:** encodes each string to token IDs, optionally appending `<|endoftext|>` after every string

### `batch_loader(raw_dataset, T=64, B=8, device="cuda") -> (x, y)`
- **Input:** raw text dataset  
- **Output:** `(x, y)` tensors of shape `[B, T]`  
- **Process:** samples random slices of length `T` from tokenized data for training batches  


In [None]:
# GPT-2 byte-pair encoding used by the enc/dec helpers in this file.
encoder = tk.get_encoding("gpt2")
# Token IDs of the end-of-text marker, kept as a list so it can be appended
# to a token buffer with list.extend.
EOT = encoder.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})  # [50256]

# Round-trip smoke test. The results are deliberately NOT bound to `enc` /
# `dec`: those names belong to the helper functions of this file, and
# re-running this cell after they are defined would otherwise replace the
# functions with a list and a str.
sample_ids = encoder.encode("Hello world", allowed_special={"<|endoftext|>"})
sample_text = encoder.decode(sample_ids)

print(sample_ids)
print(sample_text)

In [None]:
def enc(s: str) -> list[int]:
    """Encode ``s`` into token IDs with the module-level ``encoder``.

    ``<|endoftext|>`` is passed as an allowed special token, so the marker
    may appear in the input text.
    """
    permitted_special = {"<|endoftext|>"}
    return encoder.encode(s, allowed_special=permitted_special)
def dec(ids: list[int]) -> str:
    """Decode a list of token IDs back to text with the module-level ``encoder``."""
    text = encoder.decode(ids)
    return text


def build_ids(data, add_eot: bool = True) -> torch.Tensor:
    """Tokenize one string or an iterable of strings into a flat ID tensor.

    Args:
        data: A single string, or any iterable of strings.
        add_eot: When true, the ``EOT`` token IDs are appended after every
            string (including the last one).

    Returns:
        A 1-D ``torch.long`` tensor holding the concatenated token IDs.
    """
    # A bare string is treated as a one-element dataset.
    texts = [data] if isinstance(data, str) else list(data)

    token_ids = []
    for text in texts:
        token_ids += enc(text)
        if add_eot:
            token_ids += EOT

    return torch.tensor(token_ids, dtype=torch.long)





@torch.no_grad()  # sampling needs no autograd bookkeeping
def batch_loader(raw_dataset, T: int = 64, B: int = 8, device: str = "cuda"):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        raw_dataset: Either a string / iterable of strings, which is
            tokenized on every call via ``build_ids`` with an end-of-text
            token after each string, or an already-tokenized 1-D tensor of
            token IDs. Passing a tensor lets the caller tokenize once and
            reuse the result across batches.
        T: Window (context) length in tokens.
        B: Number of windows per batch.
        device: Device the returned tensors are moved to.

    Returns:
        ``(x, y)``, two ``[B, T]`` tensors where ``y`` is ``x`` shifted one
        token to the right (``y[:, t]`` is the token following ``x[:, t]``).

    Raises:
        ValueError: If fewer than ``T + 1`` tokens are available, or if a
            tensor input is not 1-D.
    """
    if isinstance(raw_dataset, torch.Tensor):
        if raw_dataset.dim() != 1:
            raise ValueError(
                f"Expected a 1-D tensor of token IDs, got shape {tuple(raw_dataset.shape)}."
            )
        ids = raw_dataset.to(torch.long)
    else:
        ids = build_ids(raw_dataset, add_eot=True)

    # One window needs T inputs plus one extra token for the shifted targets.
    N = ids.size(0)
    if N < T + 1:
        raise ValueError(f"Need at least T+1 tokens ({T + 1}), got {N}.")

    # Valid start offsets are 0 .. N-T-1 inclusive. randint's upper bound is
    # exclusive, so pass N - T; using N - T - 1 would never sample the last
    # window and the final token could never appear as a target.
    starts = torch.randint(0, N - T, (B,)).tolist()

    # Gather the slices first, then move each stacked batch in one transfer
    # rather than issuing many small ones.
    x = torch.stack([ids[s:s + T] for s in starts], dim=0)
    y = torch.stack([ids[s + 1:s + T + 1] for s in starts], dim=0)
    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)
