In [None]:
import torch

# Prefer the GPU when one is present, but fall back to CPU so the notebook
# (and the checkpoint load, which uses map_location=device) still runs without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
from transformers import AutoTokenizer

# Load the pretrained GPT-2 tokenizer. GPT-2 ships without a padding token, so
# the end-of-text token is reused as the pad token for any padding-aware call.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
import torch
from torch import nn


class TransformerBlock(nn.Module):
    """Post-norm causal self-attention block: attention then feed-forward,
    each wrapped in a residual connection followed by LayerNorm.

    Args:
        embed_dim: width of the hidden vectors (input is projected to this width).
        num_heads: number of attention heads; must divide ``embed_dim``.
        dropout: dropout probability applied to both sub-layer outputs.

    NOTE(review): earlier versions of ``forward`` computed a masked attention
    result and then discarded it in favour of an unmasked one, so a checkpoint
    trained with that code was trained without causal/padding masking and may
    need retraining or re-evaluation.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        # Separate Q/K/V projections feeding nn.MultiheadAttention (which has its
        # own in-projection as well). Attribute names are kept so saved
        # state_dict keys keep matching.
        self.queries = nn.LazyLinear(out_features=embed_dim)
        self.keys = nn.LazyLinear(out_features=embed_dim)
        self.values = nn.LazyLinear(out_features=embed_dim)

        # batch_first=True: inputs are [batch, seq, embed_dim].
        self.att_block = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feedforward = nn.Sequential(
            nn.LazyLinear(out_features=embed_dim * 4),
            nn.LeakyReLU(),
            nn.LazyLinear(out_features=embed_dim)
        )

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply masked self-attention and the feed-forward sub-layer.

        Args:
            x: hidden states, [batch, seq, features].
            key_padding_mask: optional bool tensor [batch, seq]; True marks
                padding positions that must not be attended to.

        Returns:
            Tensor of shape [batch, seq, embed_dim].
        """
        seq_len = x.size(1)
        # True strictly above the diagonal: query i may not attend to key j > i.
        causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)

        q, k, v = self.queries(x), self.keys(x), self.values(x)

        # Single attention call that actually honours both masks; the averaged
        # attention weights are never used, so skip computing them.
        att_output, _ = self.att_block(
            q, k, v,
            attn_mask=causal_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # Attention sub-layer: residual + dropout, then normalize.
        x = self.norm1(x + self.dropout(att_output))

        # Feed-forward sub-layer: residual + dropout, then normalize.
        ff_output = self.feedforward(x)
        return self.norm2(x + self.dropout(ff_output))


In [None]:
import torch
from torch import nn


class Decoder(nn.Module):
    """Output head that maps each hidden vector to ``vocab_size`` logits.

    The projection is wrapped in an ``nn.Sequential`` so its parameters are
    registered as ``decoder.0.weight`` / ``decoder.0.bias``; keep this layout
    so existing checkpoints still load.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Input width is inferred lazily from the first forward call.
        projection = nn.LazyLinear(out_features=vocab_size)
        self.decoder = nn.Sequential(projection)

    def forward(self, attended_sequence):
        """Return logits whose last dimension is ``vocab_size``."""
        return self.decoder(attended_sequence)

In [None]:
import torch
from torch import nn


class Model(nn.Module):
    """Decoder-only language model.

    Token embeddings plus learned absolute positional embeddings are passed
    through a stack of ``TransformerBlock``s and projected to vocabulary logits
    by ``Decoder``.

    Args:
        vocab_size: number of token ids / output logits.
        embed_dim: hidden width.
        num_heads: attention heads per block.
        num_transformer_blocks: depth of the stack.
        max_sequence_length: number of positional embeddings, i.e. the
            longest sequence ``forward`` accepts.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, num_transformer_blocks, max_sequence_length):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(max_sequence_length, embed_dim)

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads)
            for _ in range(num_transformer_blocks)
        ])

        self.decoder = Decoder(vocab_size=vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Compute next-token logits for every position.

        Args:
            input_ids: integer tensor [batch, seq].
            attention_mask: optional tensor [batch, seq]; 0 marks padding
                positions. ``None`` means no padding.

        Returns:
            Logits of shape [batch, seq, vocab_size].

        Raises:
            ValueError: if ``seq`` exceeds the number of positional embeddings.
        """
        batch_size, seq_length = input_ids.shape

        # Fail with a clear message instead of an embedding index error
        # (or an unrecoverable device-side assert on CUDA).
        max_len = self.positional_embedding.num_embeddings
        if seq_length > max_len:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_sequence_length {max_len}"
            )

        positions = torch.arange(seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_length)
        x = self.embedding(input_ids) + self.positional_embedding(positions)

        # Same padding mask for every block, so build it once (True = ignore).
        key_padding_mask = None if attention_mask is None else (attention_mask == 0)

        for block in self.transformer_blocks:
            x = block(x, key_padding_mask=key_padding_mask)

        return self.decoder(x)


In [None]:
import torch

# Checkpoint whose filename records the hyper-parameters used below.
model_path = "models/vocabsize_(50257,)_embeddim_(768,)_numheads_(12,)_numtransformerblocks_(12,)_maxseqlen_(1024,)_timestamp_20250405_185858.pt"

model = Model(
    vocab_size=50257,
    embed_dim=768,
    num_heads=12,
    num_transformer_blocks=12,
    max_sequence_length=1024,
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
# PyTorch lazy modules materialise their (LazyLinear) parameters from the
# shapes stored in the state dict during this call.
model.load_state_dict(state_dict)

In [None]:
import torch
import torch.nn.functional as F


def _apply_repetition_penalty(logits, token_ids, penalty):
    """Return a copy of ``logits`` [batch, vocab] with every id in ``token_ids`` penalised.

    Negative logits are multiplied by ``penalty`` and non-negative ones divided
    by it, so both become less likely for ``penalty > 1``. Vectorised to avoid
    one host/device sync per token.
    """
    if penalty == 1.0 or not token_ids:
        return logits
    idx = torch.tensor(sorted(set(token_ids)), device=logits.device)
    picked = logits[:, idx]
    penalised = torch.where(picked < 0, picked * penalty, picked / penalty)
    out = logits.clone()
    out[:, idx] = penalised
    return out


def _filter_top_k_top_p(logits, top_k, top_p):
    """Set logits outside the top-k and nucleus (top-p) sets to -inf."""
    if top_k > 0:
        # Clamp so a top_k larger than the vocabulary does not make topk raise.
        k = min(top_k, logits.size(-1))
        threshold = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < threshold, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability exceeds top_p; shifting by one
        # keeps the token that crosses the threshold (so at least one survives).
        remove_sorted = cum_probs > top_p
        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()
        remove_sorted[:, 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    return logits


def _stream_new_text(tokenizer, new_ids, already_printed):
    """Print the not-yet-shown suffix of the decoded continuation; return the text shown so far.

    Byte-level BPE tokenizers (such as GPT-2's) can split one character across
    several tokens, so decoding token-by-token prints U+FFFD. Decoding the whole
    continuation and waiting while it ends in an incomplete character avoids
    that. Re-decoding each step is quadratic but cheap at these lengths.
    """
    text = tokenizer.decode(new_ids, skip_special_tokens=False)
    if text.endswith("\ufffd"):
        return already_printed
    print(text[len(already_printed):], end="", flush=True)
    return text


def generate_and_stream(
        model,
        tokenizer,
        base_text: str,
        max_gen_length: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        device: str = "cuda"
):
    """Autoregressively continue ``base_text``, printing text as it is produced.

    Args:
        model: callable ``model(input_ids, attention_mask)`` returning logits
            [1, seq, vocab]; must expose ``positional_embedding.num_embeddings``
            (used as the context window).
        tokenizer: object with ``encode``, ``decode`` and ``eos_token_id``.
        base_text: prompt to continue.
        max_gen_length: maximum number of new tokens.
        temperature: softmax temperature; ``<= 0`` selects greedy decoding.
        top_k: keep only the k most likely tokens (0 disables).
        top_p: nucleus threshold (1.0 disables).
        repetition_penalty: penalty for every token already in the sequence.
        device: device for the input tensors.

    Returns:
        Prompt plus continuation, decoded with special tokens removed.

    Raises:
        ValueError: if ``base_text`` encodes to no tokens.
    """
    model.eval()
    generated = tokenizer.encode(base_text)  # list of token ids
    if not generated:
        raise ValueError("base_text must encode to at least one token")

    max_context = model.positional_embedding.num_embeddings
    prompt_len = len(generated)
    streamed = ""

    print(base_text, end="", flush=True)

    with torch.no_grad():
        for _ in range(max_gen_length):
            # Only the most recent max_context tokens fit the positional table.
            context = generated[-max_context:]
            input_ids = torch.tensor([context], device=device)
            # One unpadded sequence: every position is real. Deriving the mask
            # from pad_token_id would hide genuine EOS tokens whenever pad == eos.
            attention_mask = torch.ones_like(input_ids)

            logits = model(input_ids, attention_mask)  # [1, seq_len, vocab_size]
            next_logits = logits[:, -1, :]  # [1, vocab_size]
            next_logits = _apply_repetition_penalty(next_logits, generated, repetition_penalty)

            if temperature <= 0:
                # Greedy decoding; dividing by zero would produce NaNs.
                next_id = int(torch.argmax(next_logits, dim=-1).item())
            else:
                next_logits = _filter_top_k_top_p(next_logits / temperature, top_k, top_p)
                probs = F.softmax(next_logits, dim=-1)
                next_id = int(torch.multinomial(probs, num_samples=1).item())

            generated.append(next_id)
            streamed = _stream_new_text(tokenizer, generated[prompt_len:], streamed)

            if next_id == tokenizer.eos_token_id:
                break

    # Flush any text held back waiting for a character to complete, then newline.
    tail = tokenizer.decode(generated[prompt_len:], skip_special_tokens=False)
    print(tail[len(streamed):])
    return tokenizer.decode(generated, skip_special_tokens=True)

In [None]:
# Prompt that the model will continue.
BASE_TEXT = "compassionately explores the seemingly irreconcilable situation between conservative christian parents and their estranged gay and lesbian children"

# A temperature this small sharpens the softmax to nearly one-hot, so decoding
# is effectively greedy; top-k / top-p / repetition penalty still apply first.
full_text = generate_and_stream(
    model,
    tokenizer,
    base_text=BASE_TEXT,
    device=device,
    max_gen_length=200,
    # sampling controls
    temperature=0.0001,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.1,
)
