In [2]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn

In [4]:
def causal_attention_mask(sequence_length: int, dtype=torch.float32):
    mask = torch.tril(torch.ones((1, 1, sequence_length, sequence_length), dtype=dtype))
    mask = mask.masked_fill(mask == 0, float("-inf")) 
    return mask

causal_attention_mask(5, dtype=torch.bfloat16)

tensor([[[[1., -inf, -inf, -inf, -inf],
          [1., 1., -inf, -inf, -inf],
          [1., 1., 1., -inf, -inf],
          [1., 1., 1., 1., -inf],
          [1., 1., 1., 1., 1.]]]], dtype=torch.bfloat16)

In [None]:
a = torch.ones((512, 1024))

x = a.split(2, dim=-1)

In [None]:
non_kv = torch.load("k.tensor")
kv = torch.load("k_kv.tensor")

In [None]:
kv[0, 0, 5, 0], non_kv[0, 0, 5, 0]

In [2]:
from omni.preprocessing.tokenizer import AutoTokenizer

tokenizer = AutoTokenizer.create("EleutherAI/gpt-neo-125m")
tokenizer.add_special_tokens({"pad_token": "<pad>"})

  from .autonotebook import tqdm as notebook_tqdm


1

In [3]:
text = "This is a test"

tokenizer(text, max_length = 20, padding='max_length')

{'input_ids': [1212, 318, 257, 1332, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}

In [3]:
def causal_attention_mask(sequence_length: int, dtype=torch.float32):
    mask = torch.triu(torch.ones((1, 1, sequence_length, sequence_length), dtype=dtype), diagonal=1)
    mask = mask.masked_fill(mask == 1, float("-inf")) 
    return mask

mask = causal_attention_mask(5)

In [4]:
mask.to(torch.bfloat16)

tensor([[[[0., -inf, -inf, -inf, -inf],
          [0., 0., -inf, -inf, -inf],
          [0., 0., 0., -inf, -inf],
          [0., 0., 0., 0., -inf],
          [0., 0., 0., 0., 0.]]]], dtype=torch.bfloat16)

In [None]:
import torch

from omni.architectures.llama import LlamaConfig
from omni.modules.transformer import Transformer
from omni.preprocessing.tokenizer import AutoTokenizer
from omni.utils.system import auto_device

tokenizer = AutoTokenizer.create("EleutherAI/gpt-neo-125m")
tokenizer.add_special_tokens({"pad_token": "<pad>"})

llama_config = LlamaConfig(
    vocab_size=50258,
    seq_len=512,
    d_model=256,
    num_heads=8,
    num_kv_heads=8,
    num_layers=4,
    rope_theta=0.1,
    norm_eps=1e-6,
    activation_fn="silu",
    mlp_bias=False,
    mlp_dropout=0.0,
    attention_bias=False,
    attention_dropout=0.0,
    pos_encoding_type="rope",
    mlp="mlp_swiglu",
    normalization="rmsnorm",
    attention="gqa",
)

model = Transformer(llama_config)

## create KV cache

checkpoint = torch.load("checkpoints/llama-30M_20250123_104138/init.ckpt", weights_only=True)
model.load_state_dict(checkpoint["model"])

device="mps"
model = model.to(device)
model.eval()

In [None]:
device="mps"
prompt = "Once upon a time"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

In [None]:
outputs = model(input_ids)
next_token_logits = outputs[:, -1, :]

In [None]:
top_k_values, top_k_indices = torch.topk(logits, self.top_k, dim=-1)

In [None]:
top_k_values = torch.tensor([float("-inf"), 1.0, 2.0])

top_k_values
nn.functional.softmax(top_k_values, dim=-1)

In [None]:
float("-inf")

In [None]:
sorted_logits, sorted_indices = torch.sort(
    top_k_values, dim=-1, descending=True
)
cumulative_probs = torch.cumsum(
    nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)

In [None]:
next_token_logits

In [None]:
visualize_logit_distribution(next_token_logits[0], top_k=500)

In [None]:
import torch
import matplotlib.pyplot as plt

def visualize_logit_distribution(logits, top_k=500):
    """
    Visualizes the logit distribution by focusing on the top-k logits.

    Args:
        logits (torch.Tensor): Logits tensor of shape (vocab_size,).
        top_k (int): Number of top logits to visualize. Defaults to 500.
    """
    if logits.dim() != 1:
        raise ValueError("Logits tensor must be 1-dimensional (vocab_size,).")

    # Convert logits to probabilities and take log
    log_probs = torch.log_softmax(logits, dim=0)

    # Get the top-k log probabilities and their indices
    top_log_probs, top_indices = torch.topk(log_probs, k=top_k)

    # Sort the top-k log probabilities for visualization
    sorted_log_probs, sorted_indices = torch.sort(top_log_probs, descending=True)

    # Plot the distribution
    plt.figure(figsize=(10, 6))
    plt.plot(sorted_log_probs.detach().cpu().numpy(), marker="o", linestyle="-")
    plt.title(f"Log Probability Distribution (Top {top_k})", fontsize=14)
    plt.xlabel("Rank", fontsize=12)
    plt.ylabel("Log Probability Value", fontsize=12)
    plt.grid(alpha=0.5)
    plt.show()


In [None]:
import tokenizers.processors as processors

def _add_bos_token(tokenizer: AutoTokenizer):
    tokenizer._tokenizer.post_processor = processors.Sequence(
        [
            processors.ByteLevel(trim_offsets=False),
            processors.TemplateProcessing(
                single=f"{tokenizer.bos_token}:0 $A:0",
                pair=f"{tokenizer.bos_token}:0 $A:0 {tokenizer.bos_token}:1 $B:1",
                special_tokens=[
                    (tokenizer.bos_token, tokenizer.bos_token_id),
                ],
            ),
        ]
    )
    return tokenizer

In [None]:
tokenizer2 = _add_bos_token(tokenizer) 

In [None]:
torch.log(torch.tensor(50000))

In [None]:
test = "thjis is a test, this is a test"

output = tokenizer2(test, padding="max_length", max_length=3, truncation=True, return_overflowing_tokens=True)

In [None]:
d_model=128
hidden_dim = 4 * int(2 * d_model / 3)
hidden_dim = 4 * (
    (hidden_dim + 4 - 1) // 4
)

hidden_dim

In [None]:
import torch

logits = torch.tensor([1,5,10], dtype=torch.float32)
torch.softmax(logits, dim=0)

In [None]:
tokenizer.special_tokens_map

In [None]:
tokenizer.encode(x, padding="max_length", max_length=max_length, truncation=True, return_overflowing_tokens=True)

In [None]:
tokenizer(test, )

In [None]:
output

In [None]:
from omni.modules.pos_embeddings import precompute_freqs_cis_real

pos_embeddings = precompute_freqs_cis_real(64, 512)

In [None]:
base = torch.pow(2.0, torch.tensor(-8/16))

In [None]:
def _tokenize(dataset, tokenizer: AutoTokenizer, num_proc):
    print("Tokenizing dataset...")
    dataset = dataset.map(
        lambda x: tokenizer(
            x["text"],
            truncation=True,
            max_length=512,
            padding="max_length",
            return_overflowing_tokens=True,
        ),
        batched=True,
        num_proc=num_proc,
        remove_columns=dataset.column_names,
    )
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
    return dataset

In [None]:
from datasets import load_dataset

ds = load_dataset("roneneldan/TinyStories", split="train", num_proc=1)

In [None]:
_tokenize(ds, tokenizer, 2)