In [1]:
import os
import random
import numpy as np
import torch
from transformers import AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
import json
import time
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
# import  evaluate as evaluate_model

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

In [2]:
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: value used for every random number generator.
    """
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # only affects cuDNN (CUDA) kernels
    cudnn.benchmark = False


## Model


class MLP(nn.Module):
    """Position-wise feed-forward network: expand 4x, GELU, dropout, project back.

    Args:
        n_embd: input/output width.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, n_embd),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # last dimension must be n_embd; output has the same shape as x
        return self.net(x)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with bias-free query/key/value/output projections.

    Args:
        n_embd: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        seq_length: stored on the instance (not used by the computation below).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_dim = n_embd // n_head  # width of each head's q, k and v
        assert (
            self.head_dim * n_head == self.n_embd
        ), "n_embd must be divisible by n_head"
        self.seq_length = seq_length
        self.drop = nn.Dropout(p=dropout)

        # Creation order (query, key, value, out) is significant: it fixes how
        # seeded weight initialisation draws from the RNG.
        self.query = nn.Linear(n_embd, n_embd, bias=False)
        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.out = nn.Linear(n_embd, n_embd, bias=False)  # merges the heads

    def split_heads(self, x):
        """Reshape (B, S, n_embd) into (B, n_head, S, head_dim)."""
        batch, seq, _ = x.size()
        per_head = x.view(batch, seq, self.n_head, self.head_dim)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (B, n_head, S, head_dim) -> (B, S, n_embd)."""
        batch, _, seq, _ = x.size()
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch, seq, self.n_embd)

    def scaled_dot_product(self, q, k, v, dropout, mask=None):
        """Attention over (B, n_head, S, head_dim) inputs.

        ``mask`` (optional) is a bool tensor broadcastable to (B, n_head, S, S);
        True entries are filled with -inf before the softmax, i.e. blocked.
        """
        scale = np.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # (B, n_head, S, S)
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        weights = dropout(F.softmax(scores, dim=-1))
        return torch.matmul(weights, v)

    def forward(self, x, mask=None):
        """Apply attention to ``x`` of shape (B, S, n_embd); returns the same shape."""
        q, k, v = (
            self.split_heads(projection(x))
            for projection in (self.query, self.key, self.value)
        )
        attended = self.scaled_dot_product(q, k, v, self.drop, mask)
        return self.out(self.combine_heads(attended))


class Block(nn.Module):
    """One pre-LayerNorm transformer layer: attention then MLP, each residual.

    Args:
        n_embd: model width.
        n_head: number of attention heads.
        seq_length: forwarded to ``MultiHeadAttention``.
        dropout: dropout probability for sub-layers and residual branches.
    """

    def __init__(self, n_embd, n_head, seq_length, dropout=0.1):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, seq_length, dropout)
        self.mlp = MLP(n_embd, dropout)
        # LayerNorm is applied to the input of each sub-layer (pre-LN)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """``mask`` is passed straight to attention (True = blocked)."""
        attn_branch = self.sa(self.ln1(x), mask)
        x = x + self.drop(attn_branch)
        mlp_branch = self.mlp(self.ln2(x))
        return x + self.drop(mlp_branch)


class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding from the original Transformer paper:
    PE(pos, 2i (even)) = sin(pos/(10000^{2i/d_model}))
    PE(pos, 2i+1 (odd)) = cos(pos/(10000^{2i/d_model}))

    Works for odd ``d_model`` too (the last column is then a sine column).

    See reference for more details:
    https://kikaben.com/transformers-positional-encoding/

    Args:
        d_model: encoding width (embedding dimension).
        max_len: number of positions precomputed in the ``pe`` buffer.
    """

    def __init__(self, d_model, max_len):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        divisor = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )  # [ceil(d_model / 2)]
        angles = position * divisor  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; trim the extra angle
        # column when d_model is odd so the shapes match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)
        # result: self.pe = [max_len, d_model], one row per position index

    def forward(self, x):
        """Return the first ``x.size(0)`` rows of the table, shape (len, d_model).

        Raises:
            ValueError: if more positions are requested than were precomputed.
        """
        n_positions = x.size(0)
        if n_positions > self.pe.size(0):
            raise ValueError(
                f"requested {n_positions} positions but max_len is {self.pe.size(0)}"
            )
        return self.pe[:n_positions]


class BetterTransformer(nn.Module):
    """Decoder-only language model: token + sinusoidal position embeddings,
    a stack of pre-LN ``Block`` layers, and a linear LM head.

    Args:
        vocab_size: vocabulary size (embedding rows and logits width).
        seq_length: maximum context length; also the positional-table size.
        n_embd: model width.
        n_head: attention heads per block.
        n_layer: number of stacked blocks.
        pad_idx: padding token id; masked as an attention key and ignored by the loss.
        eos_idx: end-of-sequence id; ``generate`` stops once it is sampled.
        device: device recorded on the model (callers read ``model.device``).
        dropout: dropout probability used throughout.
    """

    def __init__(
        self,
        vocab_size,
        seq_length,
        n_embd,
        n_head,
        n_layer,
        pad_idx,
        eos_idx,
        device,
        dropout=0.1,
    ):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd, padding_idx=pad_idx)
        self.position_embedding = PositionalEncoding(n_embd, seq_length)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head, seq_length, dropout) for _ in range(n_layer)]
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.drop = nn.Dropout(dropout)
        self.seq_length = seq_length
        self.pad_idx = pad_idx
        self.eos_idx = eos_idx
        self.device = device
        self.init_params()

    def init_params(self, default_initialization=False):
        """Xavier-uniform initialise every parameter with more than one dimension.

        Args:
            default_initialization: if True, keep PyTorch's default init.
        """
        if default_initialization:
            return
        for _, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # xavier_uniform_ also overwrote the embedding's padding row; restore the
        # zeros that nn.Embedding(padding_idx=...) starts with (that row never
        # receives gradient, so it would otherwise stay random forever).
        padding_idx = self.token_embedding.padding_idx
        if padding_idx is not None:
            with torch.no_grad():
                self.token_embedding.weight[padding_idx].zero_()

    def get_causal_mask(self, x):
        """Causal mask for ``x`` of shape (batch_size, seq_len).

        Returns:
            Bool tensor (1, seq_len, seq_len) on ``x.device``: True on and below
            the main diagonal (positions a query may attend to), False above.
        """
        seq_len = x.size(-1)
        ones = torch.ones(1, seq_len, seq_len, dtype=torch.bool, device=x.device)
        return torch.tril(ones)

    def get_pad_mask(self, x, pad_idx):
        """Padding mask: bool (B, 1, 1, seq_len), True where ``x`` is not ``pad_idx``."""
        return (x != pad_idx).unsqueeze(1).unsqueeze(-2)

    def forward(self, x, targets=None):
        """Run the model on token ids ``x`` of shape (B, S).

        Returns:
            ``(logits, loss)``. Without ``targets``: logits (B, S, vocab_size)
            and ``None``. With ``targets`` (B, S): logits flattened to
            (B*S, vocab_size) and the cross-entropy loss ignoring ``pad_idx``.
        """
        x = x.to(torch.int64)  # explicit cast in case ids arrive as another dtype
        _, S = x.shape

        # True = allowed: key is not padding AND key is not in the future.
        # (B, 1, 1, S) & (1, S, S) broadcasts to (B, 1, S, S).
        allowed = self.get_pad_mask(x, self.pad_idx) & self.get_causal_mask(x)
        # attention fills True entries with -inf, so hand it the blocked set
        blocked = ~allowed

        tok_emb = self.token_embedding(x)
        pos_emb = self.position_embedding(torch.arange(S, device=x.device))
        x = self.drop(tok_emb + pos_emb)  # (B, S, n_embd)
        for block in self.blocks:
            x = block(x, blocked)
        logits = self.lm_head(x)  # (B, S, vocab_size)

        if targets is None:
            return logits, None

        # Teacher forcing: B*S next-token predictions against B*S targets.
        B, S, C = logits.shape  # C = vocab_size
        logits = logits.view(B * S, C)
        # reshape (not view): targets are often a non-contiguous slice like ids[:, 1:]
        loss = F.cross_entropy(
            logits, targets.reshape(B * S), ignore_index=self.pad_idx
        )
        return logits, loss

    def _next_token(self, input_ids, method, temp, p_nucleus, k):
        """Pick the next token for every row of ``input_ids``; returns (B, 1)."""
        # condition on at most the most recent `seq_length` tokens
        text_cond = input_ids[:, -self.seq_length :]
        logits, _ = self(text_cond)
        logits = logits[:, -1, :]  # last position only: (B, vocab_size)

        if method == "temperature":
            logits = logits / temp
        probs = F.softmax(logits, dim=-1)

        if method == "greedy":
            return probs.argmax(dim=-1, keepdim=True)

        if method == "nucleus":
            sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
            idx_remove = sorted_probs.cumsum(dim=-1) > p_nucleus
            # shift right so the token that crosses the threshold is kept
            idx_remove[..., 1:] = idx_remove[..., :-1].clone()
            idx_remove[..., 0] = False
            # map the removal flags back to vocabulary order
            remove_mask = idx_remove.gather(dim=-1, index=sorted_idx.argsort(dim=-1))
            probs[remove_mask] = 0
        elif method == "top-k":
            # k-th largest probability per row, shape (B, 1), used as the cutoff
            kth_largest = torch.topk(probs, k).values[..., -1, None]
            probs[probs < kth_largest] = 0

        # multinomial renormalises the remaining (unzeroed) probabilities
        return torch.multinomial(probs, num_samples=1)

    def generate(
        self,
        input_ids,
        method="multinomial",
        max_new_tokens=1000,
        temp=None,
        num_beams=None,
        p_nucleus=None,
        k=None,
    ):
        """Autoregressively extend ``input_ids`` (batch_size, seq_len).

        Args:
            method: "multinomial", "temperature", "greedy", "nucleus" or "top-k".
            max_new_tokens: maximum number of tokens to append.
            temp: temperature (> 0), required for "temperature".
            num_beams: unused; reserved for a future beam-search method.
            p_nucleus: cumulative-probability cutoff in (0, 1] for "nucleus".
            k: number of candidates (>= 1) for "top-k".

        Returns:
            ``input_ids`` with the generated tokens appended. Generation stops
            early as soon as ANY row samples ``eos_idx`` (the EOS is kept).

        Raises:
            ValueError: unknown ``method`` or missing/invalid method parameter.

        The model's train/eval mode is restored afterwards.
        """
        if method not in ("multinomial", "temperature", "greedy", "nucleus", "top-k"):
            raise ValueError(f"unsupported generation method: {method!r}")
        if method == "temperature" and (temp is None or temp <= 0):
            raise ValueError("temperature sampling requires temp > 0")
        if method == "nucleus" and (p_nucleus is None or not 0 < p_nucleus <= 1):
            raise ValueError("nucleus sampling requires 0 < p_nucleus <= 1")
        if method == "top-k" and (k is None or k < 1):
            raise ValueError("top-k sampling requires k >= 1")

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    next_idx = self._next_token(input_ids, method, temp, p_nucleus, k)
                    input_ids = torch.cat((input_ids, next_idx), dim=-1)
                    if (next_idx == self.eos_idx).any():
                        break
        finally:
            # generate() is called mid-training; don't leave dropout disabled
            self.train(was_training)

        return input_ids


In [3]:
def load_tokenizer():
    """Load the EleutherAI/gpt-neo-1.3B tokenizer, adding a "[PAD]" pad token if absent.

    Side effect: sets the module-level ``VOCAB_SIZE`` to the tokenizer's final
    length (which includes the pad token when one had to be added).

    Returns:
        The loaded ``AutoTokenizer`` instance.
    """
    global VOCAB_SIZE

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
    if tokenizer.pad_token is None:
        # adding a special token grows len(tokenizer) by one
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    VOCAB_SIZE = len(tokenizer)
    return tokenizer


In [4]:
# Pre-tokenized TinyStories; rows are returned as torch tensors.
data = load_dataset("terru3/tokenized_tinystories384").with_format("torch")
val_data = data["validation"]

def eval_collate_wrapper(batch):
    """Collate tokenized rows for evaluation.

    Args:
        batch: list of rows, each with an ``"input_ids"`` sequence (list or
            tensor). Rows must share one length because they are stacked.

    Returns:
        ``(inputs_list, targets)``: the per-row 1-D tensors, and the same
        tensors stacked into one (batch, seq_len) tensor.
    """
    # as_tensor (unlike torch.tensor) does not warn or copy when the dataset,
    # formatted as "torch", already yields tensors.
    inputs_list = [torch.as_tensor(item["input_ids"]) for item in batch]
    targets = torch.stack(inputs_list)
    return inputs_list, targets

# Evaluate on a fixed (unshuffled) prefix of the validation split; min() keeps
# select() in range if the split has fewer rows than subset_size.
subset_size = 10000
val_data_subset = val_data.select(range(min(subset_size, len(val_data))))
val_dataloader = DataLoader(
    val_data_subset, shuffle=False, batch_size=32, collate_fn=eval_collate_wrapper
)

README.md:   0%|          | 0.00/489 [00:00<?, ?B/s]

train-00000-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00001-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00002-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00003-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00004-of-00012.parquet:   0%|          | 0.00/145M [00:00<?, ?B/s]

train-00005-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00006-of-00012.parquet:   0%|          | 0.00/144M [00:00<?, ?B/s]

train-00007-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00008-of-00012.parquet:   0%|          | 0.00/144M [00:00<?, ?B/s]

train-00009-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

train-00010-of-00012.parquet:   0%|          | 0.00/146M [00:00<?, ?B/s]

train-00011-of-00012.parquet:   0%|          | 0.00/147M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/17.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [5]:
def generate_unconditional(model, tokenizer, empty_tokens, epoch):
    """
    Sample six unconditional continuations with different decoding settings.

    Args:
        model (torch.nn.Module): The trained language model.
        tokenizer (AutoTokenizer): Tokenizer used to decode the samples.
        empty_tokens (torch.Tensor): Starting ids (bos_token_id) for generation.
        epoch (int): Current training epoch (accepted but not used here).

    Returns:
        str: Formatted report of the decoded first row of each sample.
    """
    set_seed(42)

    # (label, generate() keyword arguments) in sampling order. The order is
    # significant: every sampler draws from the RNG that was just seeded.
    runs = (
        ("Top-k (5) (150 max_tokens)",
         dict(method="top-k", k=5, max_new_tokens=150)),
        ("Greedy (150 max_tokens)",
         dict(method="greedy", max_new_tokens=150)),
        ("Nucleus (0.5) (150 max_tokens)",
         dict(method="nucleus", p_nucleus=0.5, max_new_tokens=150)),
        ("Multinomial (150 max_tokens)",
         dict(method="multinomial", max_new_tokens=150)),
        ("Top-k (5) (250 max_tokens)",
         dict(method="top-k", k=5, max_new_tokens=250)),
        ("Nucleus (0.5) (250 max_tokens)",
         dict(method="nucleus", p_nucleus=0.5, max_new_tokens=250)),
    )

    sections = []
    for label, kwargs in runs:
        decoded = tokenizer.batch_decode(model.generate(empty_tokens, **kwargs))[0]
        sections.append(f"\n    {label}:\n    {decoded}\n")

    return "UNCONDITIONAL GENERATION:\n" + "".join(sections) + "    "


def generate_conditional(model, tokenizer, cond_token_list, device):
    """
    Generates model output for each conditional prompt using top-k (k=5) sampling.

    The decoded text of each sample has " || " inserted between the prompt and
    the generated continuation.

    Args:
        model (torch.nn.Module): The trained language model.
        tokenizer (AutoTokenizer): The tokenizer used for the model.
        cond_token_list (iterable): Tokenized prompts, each a 1-D tensor or a
            list of token ids.
        device (torch.device): The device to run generation on (cpu or cuda).

    Returns:
        str: Formatted string containing the conditional generation results.
    """
    set_seed(42)
    # Build the delimiter from however many tokens " || " encodes to.
    delimiter = torch.as_tensor(tokenizer.encode(" || "), dtype=torch.long)

    cond_res_list = []
    for prompt in cond_token_list:
        # as_tensor: no copy-construct warning when prompt is already a tensor
        prompt_ids = torch.as_tensor(prompt, dtype=torch.long)
        gen_tokens = model.generate(
            prompt_ids.unsqueeze(0).to(device),
            method="top-k",
            k=5,
            max_new_tokens=250,
        )[0].cpu()

        # generate() returns the prompt followed by the new tokens
        n_prompt = len(prompt_ids)
        gen_prep = torch.cat(
            (gen_tokens[:n_prompt], delimiter, gen_tokens[n_prompt:])
        )
        cond_res_list.append(tokenizer.decode(gen_prep))

    cond_res_list_str = "\n\n".join(cond_res_list)

    conditional_generation_text = f"""CONDITIONAL GENERATION (Top-k (5), 250 max_tokens):
    {cond_res_list_str}
    -----------------------------------------------------
    """
    return conditional_generation_text


def generate_train(
    model, tokenizer, generation_file_path, empty_tokens, cond_token_list, epoch
):
    """
    Generates unconditional and conditional samples, appends them to a file and prints them.

    Args:
        model (torch.nn.Module): The trained language model; its ``device``
            attribute selects where conditional generation runs.
        tokenizer (AutoTokenizer): The tokenizer used for the model.
        generation_file_path (str): File the report is appended to.
        empty_tokens (torch.Tensor): Tensor of bos_token_id for unconditional generation.
        cond_token_list: Tokenized conditional prompts.
        epoch (int): The current training epoch (shown in the report header).
    """
    unconditional_output = generate_unconditional(model, tokenizer, empty_tokens, epoch)
    conditional_output = generate_conditional(model, tokenizer, cond_token_list, model.device)

    generation_text = f""" Output @Epoch {epoch}
    {unconditional_output}

    #####################################################
    {conditional_output}
    """
    # Decoded text may contain any Unicode character; write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(generation_file_path, "a", encoding="utf-8") as file:
        file.write(generation_text)
    print(generation_text)

In [6]:
def load_big_model(tokenizer, device, n_head=16, n_layer=8, n_embd=768, vocab_size=50258, seq_length=384, data_pct=70, path="/kaggle/input/transformer-2"):
    """Build a ``BetterTransformer`` and load weights from the checkpoint at ``path``.

    Loading is best-effort: on a missing file or a load error a message is
    printed and the freshly initialised (seeded) model is returned instead.

    Args:
        tokenizer: supplies ``pad_token_id`` and ``eos_token_id``.
        device: device to map the checkpoint to and move the model onto.
        n_head, n_layer, n_embd, vocab_size, seq_length: architecture settings;
            they must match the checkpoint.
        data_pct: unused; kept for call-site compatibility.
        path: checkpoint file, either a raw state dict or a dict holding it
            under ``"model_state_dict"``.

    Returns:
        The model, on ``device``.
    """
    set_seed(42)

    model = BetterTransformer(
        vocab_size,
        seq_length,
        n_embd,
        n_head,
        n_layer,
        tokenizer.pad_token_id,
        tokenizer.eos_token_id,
        device=device,
    )
    model.init_params()

    # Check up front: the old code left `load_path` unbound for a non-file path,
    # which surfaced as a confusing UnboundLocalError message.
    if not os.path.isfile(path):
        print(f"Error: Model file not found at: {path}")
    else:
        try:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(path, map_location=device)
            if "model_state_dict" in checkpoint:
                model.load_state_dict(checkpoint["model_state_dict"])
            else:
                model.load_state_dict(checkpoint)
            print(f"Loaded model from: {path}")
        except Exception as e:
            print(f"Error loading model: {e}")

    # move even when loading failed so the returned model is always on `device`
    model.to(device)
    return model


In [14]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_path = "/kaggle/input/big-model-10-epochs/transformer-2/model/bt_model_epoch_8.pt"  # ADD MODEL PATH HERE
model_params = {
    "n_head": 4,
    "n_layer": 4,
    "n_embd": 256,
    "vocab_size": 50258,
    "seq_length": 384,
    "data_pct": 70,
}

tokenizer = load_tokenizer()
epoch = 0

# Unconditional generation starts from a single BOS token, shape (1, 1).
empty_tokens = torch.full((1, 1), tokenizer.bos_token_id, dtype=torch.long).to(device)

# ADD STARTING PROMPTS HERE
cond_prompts = [
    "Timmy wanted to buy toys",
    "Once there was a bird named parrot",
]
# Tokenize without padding. With padding=True the shorter prompt carried
# [PAD] tokens, which generation then treated as part of that prompt.
# generate_conditional accepts plain lists of token ids.
cond_token_list = tokenizer(cond_prompts).input_ids

model = load_big_model(
    tokenizer,
    device,
    n_head=model_params["n_head"],
    n_layer=model_params["n_layer"],
    n_embd=model_params["n_embd"],
    vocab_size=model_params["vocab_size"],
    seq_length=model_params["seq_length"],
    data_pct=model_params["data_pct"],
    path=model_path,
)

model.to(device)
generation_file_path = "output.txt"
generate_train(model, tokenizer, generation_file_path, empty_tokens, cond_token_list, epoch)



  checkpoint = torch.load(load_path, map_location=device)


Loaded model from: /kaggle/input/big-model-10-epochs/transformer-2/model/bt_model_epoch_8.pt


  torch.tensor(prompt).unsqueeze(0).long().to(device),


 Output @Epoch 0
    UNCONDITIONAL GENERATION:

    Top-k (5) (150 max_tokens):
    <|endoftext|>One day, a little boy named Tim found a big, red ball. It was so shiny and pretty. Tim was very happy. He wanted to show his friends, Sam, how much he loved the ball.

Tim said to Sam, "Let's test how to play with the red ball. I bet I can run and jump and catch it!" Sam agreed. They played with the ball for a while, and then Tim's friends came.

They played with the red ball all day long. They had lots of fun. When it was time for lunch, Tim and Sam said bye. They played with the red ball again, and they both had a great day.<|endoftext|>

    Greedy (150 max_tokens):
    <|endoftext|>Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, scary monster in the sky. She was scared and ran away.

Her mommy came outside and asked, "What's wrong, Lily?"

"I saw a big monster in the sky," Lily said.

"Don't worry, Lily. It's just 