In [1]:
!pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tiktoken
Successfully installed tiktoken-0.9.0


In [1]:
import jax
import jax.numpy as jnp
from flax import nnx
import numpy as np
from functools import lru_cache
from dataclasses import dataclass
from typing import Any
from jax import checkpoint

In [2]:
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the model classes defined in this notebook."""
    # Number of token ids: rows of the embedding table and width of the output head.
    vocab_size: int
    # Number of positions for which the RoPE cos/sin tables are precomputed.
    context_length: int
    # Width of the residual stream (embedding features, attention d_in/d_out).
    embedding_dim: int
    # Number of query heads in each attention layer.
    n_heads: int
    # Number of TransformerBlock instances stacked in Llama3Model.
    n_layers: int
    # Inner width of the FeedForward projections.
    hidden_dim: int
    # Number of key/value head groups; n_heads must be divisible by it.
    n_kv_groups: int
    # Base ("theta") used to derive the RoPE inverse frequencies.
    rope_base: float
    # RoPE frequency-rescaling settings passed to precompute_rope_params as
    # freq_config; read keys: "factor", "low_freq_factor", "high_freq_factor",
    # "original_context_length".
    rope_freq: dict
    # Passed as param_dtype to the embedding, final norm and output head.
    param_dtype: jnp.dtype = jnp.float32
    # Passed as dtype (computation dtype) to the layers.
    dtype: jnp.dtype = jnp.bfloat16

## Model Definition

In [3]:
class FeedForward(nnx.Module):
    """Gated feed-forward sub-layer: ``fc3(silu(fc1(x)) * fc2(x))``.

    All three projections are bias-free. ``fc1`` and ``fc2`` map
    ``embedding_dim -> hidden_dim``; ``fc3`` maps back to ``embedding_dim``.

    Args:
        args: model hyper-parameters; ``embedding_dim``, ``hidden_dim``,
            ``dtype`` and ``param_dtype`` are read.
        rngs: NNX random-number streams used to initialise the kernels.
    """

    def __init__(self, args: ModelArgs, rngs: nnx.Rngs):
        super().__init__()
        # Forward param_dtype as well as dtype so the kernels are created in the
        # dtype requested by the config (as Llama3Model does for its own layers)
        # instead of silently falling back to nnx.Linear's default.
        linear_kwargs = dict(
            use_bias=False,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs,
        )
        self.fc1 = nnx.Linear(args.embedding_dim, args.hidden_dim, **linear_kwargs)
        self.fc2 = nnx.Linear(args.embedding_dim, args.hidden_dim, **linear_kwargs)
        self.fc3 = nnx.Linear(args.hidden_dim, args.embedding_dim, **linear_kwargs)

    def __call__(self, x):
        """Apply the gated projection to ``x`` and return the result of ``fc3``."""
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        # SiLU-activated branch gates the linear branch element-wise.
        x = nnx.silu(x_fc1) * x_fc2
        return self.fc3(x)

def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):
    """Build the cosine and sine lookup tables used by rotary position embeddings.

    Args:
        head_dim: per-head feature size; must be even.
        theta_base: base from which the inverse frequencies are derived.
        context_length: number of positions (rows) in the returned tables.
        freq_config: optional dict with keys "original_context_length",
            "low_freq_factor", "high_freq_factor" and "factor". When given, the
            inverse frequencies are rescaled: long wavelengths are divided by
            "factor", a middle band is blended between the scaled and unscaled
            values, and short wavelengths are left untouched.

    Returns:
        ``(cos, sin)``, each of shape ``(context_length, head_dim)``.
    """
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # One inverse frequency per pair of feature dimensions.
    exponents = jnp.arange(0, head_dim, 2) / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    if freq_config is not None:
        orig_len = freq_config["original_context_length"]
        low_factor = freq_config["low_freq_factor"]
        high_factor = freq_config["high_freq_factor"]
        scale = freq_config["factor"]

        low_freq_wavelen = orig_len / low_factor
        high_freq_wavelen = orig_len / high_factor
        wavelen = 2 * jnp.pi / inv_freq

        # Fully scaled frequencies (used for the longest wavelengths).
        scaled_inv_freq = jnp.where(wavelen > low_freq_wavelen, inv_freq / scale, inv_freq)

        # Linear blend between scaled and unscaled values for the middle band.
        smooth_factor = (orig_len / wavelen - low_factor) / (high_factor - low_factor)
        blended_inv_freq = (1 - smooth_factor) * (inv_freq / scale) + smooth_factor * inv_freq

        in_middle_band = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        inv_freq = jnp.where(in_middle_band, blended_inv_freq, scaled_inv_freq)

    # Outer product position x frequency -> (context_length, head_dim // 2).
    half_angles = jnp.arange(context_length)[:, None] * inv_freq[None, :]

    # Repeat the angles so both halves of the head dimension share them.
    full_angles = jnp.concatenate([half_angles, half_angles], axis=-1)

    return jnp.cos(full_angles), jnp.sin(full_angles)

def compute_rope(x, cos, sin):
    """Rotate ``x`` with precomputed rotary-embedding tables.

    Args:
        x: array of shape ``(batch, heads, seq_len, head_dim)``; ``head_dim``
            must be even.
        cos, sin: tables indexed as ``[position, feature]``; only the first
            ``seq_len`` rows are used.

    Returns:
        Array with the same shape as ``x``: ``x * cos + rotate_half(x) * sin``.
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dim must be even"

    half = head_dim // 2
    front, back = x[..., :half], x[..., half:]

    # Swap the halves and negate the one moved to the front.
    rotated = jnp.concatenate([-back, front], axis=-1)

    # Trim to the sequence length and add leading batch/head axes for broadcasting.
    cos_b = cos[None, None, :seq_len, :]
    sin_b = sin[None, None, :seq_len, :]

    return (x * cos_b) + (rotated * sin_b)


class GroupedQueryAttention(nnx.Module):
    """Causal self-attention with grouped key/value heads and rotary embeddings.

    Queries use ``num_heads`` heads while keys and values use only
    ``num_kv_groups`` heads; each key/value head is repeated
    ``num_heads // num_kv_groups`` times so it is shared by a group of query
    heads. RoPE cos/sin tables are precomputed once for ``context_length``
    positions in the constructor.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        context_length: number of positions covered by the RoPE tables.
        num_heads: number of query heads.
        num_kv_groups: number of key/value heads; must divide ``num_heads``.
        rope_base: base passed to ``precompute_rope_params`` as ``theta_base``.
        rope_config: optional frequency-rescaling dict passed as ``freq_config``.
        dtype: computation dtype of the projections and of the RoPE tables.
        rngs: ``nnx.Rngs`` instance used to initialise the projections. When
            omitted, ``nnx.Rngs(0)`` is used.
    """

    def __init__(
        self, d_in, d_out, context_length, num_heads, num_kv_groups,
        rope_base=10_000, rope_config=None, dtype=jnp.float32, rngs=None
    ):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        # The previous default was the nnx.Rngs *class* itself, which cannot be
        # used to draw keys; fall back to a real instance instead.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.d_out = d_out
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.head_dim = d_out // num_heads
        self.group_size = num_heads // num_kv_groups

        self.W_query = nnx.Linear(d_in, d_out, use_bias=False, dtype=dtype, rngs=rngs)
        self.W_key = nnx.Linear(d_in, num_kv_groups * self.head_dim, use_bias=False, dtype=dtype, rngs=rngs)
        self.W_value = nnx.Linear(d_in, num_kv_groups * self.head_dim, use_bias=False, dtype=dtype, rngs=rngs)
        self.out_proj = nnx.Linear(d_out, d_out, use_bias=False, dtype=dtype, rngs=rngs)

        # Precompute the RoPE tables once for the full context length.
        cos, sin = precompute_rope_params(
            head_dim=self.head_dim,
            theta_base=rope_base,
            context_length=context_length,
            freq_config=rope_config
        )
        self.cos = cos.astype(dtype)
        self.sin = sin.astype(dtype)
        self.rope_base = rope_base
        self.freq_config = rope_config
        self.context_length = context_length

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply causal grouped-query attention.

        Args:
            x: array of shape ``(batch, seq_len, d_in)``.

        Returns:
            Array of shape ``(batch, seq_len, d_out)``.
        """
        b, seq_len, _ = x.shape

        q = self.W_query(x).reshape(b, seq_len, self.num_heads, self.head_dim)
        k = self.W_key(x).reshape(b, seq_len, self.num_kv_groups, self.head_dim)
        v = self.W_value(x).reshape(b, seq_len, self.num_kv_groups, self.head_dim)

        # Transpose for attention shape: (b, heads, seq, dim)
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        # Use only the table rows needed for this sequence length.
        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]

        # Apply RoPE
        q = compute_rope(q, cos, sin)
        k = compute_rope(k, cos, sin)

        # Expand k and v from kv_groups to full head count; each key/value head
        # is repeated consecutively so it lines up with its group of query heads.
        k = jnp.repeat(k, self.group_size, axis=1)
        v = jnp.repeat(v, self.group_size, axis=1)

        # Attention scores
        attn_scores = jnp.einsum("bhqd,bhkd->bhqk", q, k)  # (b, heads, query_len, key_len)

        # Causal mask: keep positions at or before the query, drop the rest.
        mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        attn_scores = jnp.where(mask, attn_scores, -jnp.inf)

        # Softmax over the key axis of the scaled scores
        attn_weights = jax.nn.softmax(attn_scores / jnp.sqrt(self.head_dim), axis=-1)

        # Attention output
        context = jnp.einsum("bhqk,bhkd->bhqd", attn_weights, v)
        context = jnp.transpose(context, (0, 2, 1, 3))  # (b, seq, heads, dim)
        context = context.reshape(b, seq_len, self.d_out)

        return self.out_proj(context)

class TransformerBlock(nnx.Module):
    """Pre-norm transformer block.

    Computes ``x + att(norm1(x))`` followed by ``x + ff(norm2(x))``. Both
    sub-layers are wrapped in ``jax.checkpoint`` so their activations are
    recomputed rather than stored when differentiating.

    Args:
        args: model hyper-parameters.
        rngs: NNX random-number streams used to initialise the sub-layers.
    """

    def __init__(self, args: ModelArgs, rngs: nnx.Rngs):
        super().__init__()

        # Computation dtype used when casting the normalised activations; taken
        # from the config instead of being hard-coded to bfloat16.
        self.dtype = args.dtype

        self.att = GroupedQueryAttention(
            d_in=args.embedding_dim,
            d_out=args.embedding_dim,
            context_length=args.context_length,
            num_heads=args.n_heads,
            num_kv_groups=args.n_kv_groups,
            rope_base=args.rope_base,
            rope_config=args.rope_freq,
            dtype=args.dtype,
            rngs=rngs
        )
        self.ff = FeedForward(args, rngs=rngs)
        # Same dtype/param_dtype settings as the model's final norm.
        self.norm1 = nnx.RMSNorm(
            args.embedding_dim,
            epsilon=1e-5,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs
        )
        self.norm2 = nnx.RMSNorm(
            args.embedding_dim,
            epsilon=1e-5,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs
        )

    def __call__(self, x, cos=None, sin=None, cache=None, cache_index=None):
        """Run the block on ``x`` and return an array of the same shape.

        ``cos``, ``sin``, ``cache`` and ``cache_index`` are accepted for
        interface compatibility but are not used by this implementation.
        """
        shortcut = x
        x = self.norm1(x)

        x = checkpoint(self.att)(x.astype(self.dtype))
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = checkpoint(self.ff)(x.astype(self.dtype))
        x = x + shortcut
        return x


class Llama3Model(nnx.Module):
    """Decoder-only language model: embedding -> transformer blocks -> norm -> head.

    Args:
        args: model hyper-parameters.
        rngs: NNX random-number streams used to initialise every layer.
    """

    def __init__(self, args: ModelArgs, rngs: nnx.Rngs):
        super().__init__()

        # Computation dtype used for the cast ahead of the output head; taken
        # from the config instead of being hard-coded to bfloat16.
        self.dtype = args.dtype

        self.tok_emb = nnx.Embed(
            num_embeddings=args.vocab_size,
            features=args.embedding_dim,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs
        )

        self.trf_blocks = [
            TransformerBlock(args, rngs=rngs)
            for _ in range(args.n_layers)
        ]

        self.final_norm = nnx.RMSNorm(
            args.embedding_dim,
            epsilon=1e-5,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs
        )

        self.out_head = nnx.Linear(
            args.embedding_dim,
            args.vocab_size,
            use_bias=False,
            dtype=args.dtype,
            param_dtype=args.param_dtype,
            rngs=rngs
        )

    def __call__(self, in_idx: jax.Array):
        """Return next-token logits.

        Args:
            in_idx: integer token ids of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.
        """
        x = self.tok_emb(in_idx)
        for block in self.trf_blocks:
            x = block(x)
        x = self.final_norm(x)
        logits = self.out_head(x.astype(self.dtype))
        return logits

In [24]:
# Hyper-parameters unpacked into ModelArgs(**LLAMA32_CONFIG) further down.
LLAMA32_CONFIG = dict(
    vocab_size=128_256,
    context_length=131_072,
    embedding_dim=2048,
    n_heads=32,
    n_layers=16,
    hidden_dim=8192,
    n_kv_groups=8,
    rope_base=500_000.0,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    # RoPE frequency-rescaling settings (freq_config of precompute_rope_params).
    rope_freq=dict(
        factor=32.0,
        low_freq_factor=1.0,
        high_freq_factor=4.0,
        original_context_length=8192,
    ),
)


In [None]:
# Rescale theta (RoPE base)
def rescale_theta(theta_old, context_length_old, context_length_new):
    """Scale the RoPE base linearly with the change in context length.

    Returns ``theta_old * context_length_new / context_length_old``.
    """
    scaling_factor = context_length_new / context_length_old
    return theta_old * scaling_factor

# Save the old context length exactly once, *before* it is overwritten.
# (Reading it again after the assignment below would give 8192 / 8192 = 1 and
# leave theta unchanged.)
old_context_length = LLAMA32_CONFIG["context_length"]

# Set new context length
LLAMA32_CONFIG["context_length"] = 8192

# Re-running this cell is harmless: the old and new lengths are then equal, so
# the scaling factor is 1 and theta is not rescaled a second time.
LLAMA32_CONFIG["rope_base"] = rescale_theta(
    LLAMA32_CONFIG["rope_base"],
    old_context_length,
    LLAMA32_CONFIG["context_length"]
)

print("New RoPE theta:", LLAMA32_CONFIG["rope_base"])


New RoPE theta: 31250.0


### Miscellaneous Checks

In [6]:
# Build the typed argument object from the config dict.
args = ModelArgs(**LLAMA32_CONFIG)

# NNX random-number streams (seed 0) used to initialise the layers.
rngs = nnx.Rngs(0)

# Instantiate the model
model = Llama3Model(args=args, rngs=rngs)
# NOTE(review): `rng` is not referenced by any other cell visible here — confirm
# whether it is still needed.
rng = jax.random.PRNGKey(0)


# # Example dummy input (batch_size=1, seq_len=128)
# dummy_input = jnp.ones((1, 128), dtype=jnp.int32)

# # Forward pass (no cache in this case)
# logits = model(dummy_input)
# print("Logits shape:", logits.shape)  # Should be (1, 128, vocab_size)

In [8]:
# Example dummy input (batch_size=1, seq_len=128)
dummy_input = jnp.ones((1, 128), dtype=jnp.int32)

# Forward pass (no cache in this case)
logits = model(dummy_input)
print("Logits shape:", logits.shape)  # Should be (1, 128, vocab_size)

Logits shape: (1, 128, 128256)


In [15]:
type(model.tok_emb.embedding.value)

jaxlib.xla_extension.ArrayImpl

In [7]:
init_weights = nnx.state(model)

init_weights

State({
  'final_norm': {
    'scale': VariableState( # 2,048 (4.1 KB)
      type=Param,
      value=Array([1, 1, 1, ..., 1, 1, 1], dtype=bfloat16)
    )
  },
  'out_head': {
    'kernel': VariableState( # 262,668,288 (525.3 MB)
      type=Param,
      value=Array([[-0.0336914, 0.0167236, -0.013855, ..., 0.0314941, 0.00521851,
              -0.00765991],
             [0.00135803, -0.000492096, 0.00958252, ..., 0.00570679,
              -0.00282288, 0.000492096],
             [-0.0197754, 0.0402832, -0.0303955, ..., 0.0117188, 0.00196838,
              -0.0167236],
             ...,
             [0.0159912, -0.0427246, 0.00723267, ..., 0.0142212, -0.013855,
              -0.0314941],
             [-0.0253906, -0.0197754, -0.0303955, ..., -0.0117188, -0.0197754,
              0.0117188],
             [0.0117188, -0.0148315, -0.0349121, ..., 0.0032196, -0.0159912,
              -0.0214844]], dtype=bfloat16)
    )
  },
  'tok_emb': {
    'embedding': VariableState( # 262,668,288 (525.3 MB)

In [8]:
loaded_weights = nnx.state(model)

loaded_weights

State({
  'final_norm': {
    'scale': VariableState( # 2,048 (4.1 KB)
      type=Param,
      value=Array([2.45312, 2.25, 1.53906, ..., 2.51562, 2.40625, 2.5], dtype=bfloat16)
    )
  },
  'out_head': {
    'kernel': VariableState( # 262,668,288 (525.3 MB)
      type=Param,
      value=Array([[-0.0336914, 0.0167236, -0.013855, ..., 0.0314941, 0.00521851,
              -0.00765991],
             [0.00135803, -0.000492096, 0.00958252, ..., 0.00570679,
              -0.00282288, 0.000492096],
             [-0.0197754, 0.0402832, -0.0303955, ..., 0.0117188, 0.00196838,
              -0.0167236],
             ...,
             [0.0159912, -0.0427246, 0.00723267, ..., 0.0142212, -0.013855,
              -0.0314941],
             [-0.0253906, -0.0197754, -0.0303955, ..., -0.0117188, -0.0197754,
              0.0117188],
             [0.0117188, -0.0148315, -0.0349121, ..., 0.0032196, -0.0159912,
              -0.0214844]], dtype=bfloat16)
    )
  },
  'tok_emb': {
    'embedding': VariableSt

In [None]:
from jax import tree_util
import jax.numpy as jnp

def check_for_nan_inf(tree):
    """Flag every leaf of ``tree`` that contains a NaN or an infinity.

    Args:
        tree: a pytree whose leaves are arrays.

    Returns:
        A pytree with the same structure whose leaves are scalar boolean
        arrays: ``True`` where the corresponding leaf has at least one
        non-finite element.
    """
    def has_non_finite(x):
        # A single array-level reduction; unlike Python's `or`, it does not
        # force the value to a concrete bool, so it also works under jax.jit.
        return jnp.logical_not(jnp.isfinite(x)).any()

    return tree_util.tree_map(has_non_finite, tree)

nan_flags = check_for_nan_inf(loaded_weights)
print(nan_flags)


[38;2;79;201;177mState[0m[38;2;255;213;3m({[0m[38;2;105;105;105m[0m
  [38;2;156;220;254m'final_norm'[0m[38;2;212;212;212m: [0m[38;2;255;213;3m{[0m[38;2;105;105;105m[0m
    [38;2;156;220;254m'scale'[0m[38;2;212;212;212m: [0m[38;2;79;201;177mVariableState[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (1 B)[0m
      [38;2;156;220;254mtype[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m,
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(False, dtype=bool)
    [38;2;255;213;3m)[0m
  [38;2;255;213;3m}[0m,
  [38;2;156;220;254m'out_head'[0m[38;2;212;212;212m: [0m[38;2;255;213;3m{[0m[38;2;105;105;105m[0m
    [38;2;156;220;254m'kernel'[0m[38;2;212;212;212m: [0m[38;2;79;201;177mVariableState[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (1 B)[0m
      [38;2;156;220;254mtype[0m[38;2;212;212;212m=[0m[38;2;79;201;177mParam[0m,
      [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(False, dtype=bool)
    [38;2;255;213;3

## Loading Pretrained Weights

In [16]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: fineGrained).
The token `llama-3` has been saved to /root/.cache/huggingface/stored_tokens
Your token has been saved to /root/.cache/huggingface/token
Login successful.
The current active token is: `llama-3`


In [17]:
from huggingface_hub import hf_hub_download

# Fetch the model checkpoint from the Hugging Face Hub into a local folder.
weights_file = hf_hub_download(
    repo_id="meta-llama/Llama-3.2-1B-Instruct",
    filename="model.safetensors",
    local_dir="Llama-3.2-1B-Instruct",
)

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

In [18]:
from huggingface_hub import hf_hub_download

# `filename` already contains the "original/" sub-folder, so `local_dir` must be
# the repository root folder. Pointing it at ".../original" would store the file
# under ".../original/original/tokenizer.model", which is not the path the
# tokenizer cells read from.
tokenizer_file = hf_hub_download(
    repo_id="meta-llama/Llama-3.2-1B-Instruct",
    filename="original/tokenizer.model",
    local_dir="Llama-3.2-1B-Instruct",
)

tokenizer.model:   0%|          | 0.00/2.18M [00:00<?, ?B/s]

### Checking model layer names

In [27]:
from safetensors.torch import load_file

# Path to the downloaded safetensors checkpoint
safetensor_path = "/content/Llama-3.2-1B-Instruct/model.safetensors"

# Load every tensor of the checkpoint into a dict keyed by parameter name
pt_state_dict = load_file(safetensor_path)

# Print the name and shape of each tensor
for name, tensor in pt_state_dict.items():
    print(name, tensor.shape)


# Run this cell only once, to look up the layer names: it loads the entire model into memory.

model.embed_tokens.weight torch.Size([128256, 2048])
model.layers.0.input_layernorm.weight torch.Size([2048])
model.layers.0.mlp.down_proj.weight torch.Size([2048, 8192])
model.layers.0.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.0.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.0.post_attention_layernorm.weight torch.Size([2048])
model.layers.0.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048])
model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048])
model.layers.0.self_attn.v_proj.weight torch.Size([512, 2048])
model.layers.1.input_layernorm.weight torch.Size([2048])
model.layers.1.mlp.down_proj.weight torch.Size([2048, 8192])
model.layers.1.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.1.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.1.post_attention_layernorm.weight torch.Size([2048])
model.layers.1.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.1.self_at

In [28]:
(pt_state_dict['model.embed_tokens.weight']).shape

torch.Size([128256, 2048])

In [29]:
(model.tok_emb.embedding.value).shape

(128256, 2048)

### Function for weight key mapping and loading

In [6]:
from safetensors.numpy import safe_open
import jax.numpy as jnp
import gc
from tqdm import tqdm

def assign(dest, src, name):
    """Return ``src`` after checking that its shape matches ``dest``.

    Args:
        dest: the value currently held by the destination parameter.
        src: the replacement value.
        name: identifier used in the error message.

    Returns:
        ``src`` unchanged. Only shapes are compared; dtypes are not checked.

    Raises:
        ValueError: if ``dest.shape`` and ``src.shape`` differ.
    """
    if dest.shape == src.shape:
        return src

    message = (
        f"[Shape Mismatch] → {name}\n"
        f"  Expected: {dest.shape}\n"
        f"  Received: {src.shape}\n"
    )
    raise ValueError(message)

def update_model_from_safetensors(model, safetensors_path):
    """Copy the tensors of a safetensors checkpoint into ``model`` in place.

    Each checkpoint key is looked up in one of two tables that name the
    destination variable and say whether the tensor has to be transposed
    before assignment. Keys that match no entry are ignored. Shapes are
    validated by ``assign``.

    Args:
        model: the model whose variables are overwritten.
        safetensors_path: path of the ``.safetensors`` file to read.
    """
    # key -> (accessor returning the destination variable, transpose first?)
    model_level = {
        'model.embed_tokens.weight': (lambda m: m.tok_emb.embedding, False),
        'model.norm.weight': (lambda m: m.final_norm.scale, False),
        'lm_head.weight': (lambda m: m.out_head.kernel, True),
    }
    # sub-key (after "model.layers.<i>.") -> same kind of entry, per block
    block_level = {
        'self_attn.q_proj.weight': (lambda blk: blk.att.W_query.kernel, True),
        'self_attn.k_proj.weight': (lambda blk: blk.att.W_key.kernel, True),
        'self_attn.v_proj.weight': (lambda blk: blk.att.W_value.kernel, True),
        'self_attn.o_proj.weight': (lambda blk: blk.att.out_proj.kernel, True),
        'input_layernorm.weight': (lambda blk: blk.norm1.scale, False),
        'post_attention_layernorm.weight': (lambda blk: blk.norm2.scale, False),
        'mlp.gate_proj.weight': (lambda blk: blk.ff.fc1.kernel, True),
        'mlp.up_proj.weight': (lambda blk: blk.ff.fc2.kernel, True),
        'mlp.down_proj.weight': (lambda blk: blk.ff.fc3.kernel, True),
    }

    with safe_open(safetensors_path, framework="np") as f:
        for key in tqdm(f.keys(), desc="Loading Weights"):
            tensor = jnp.array(f.get_tensor(key))

            if key.startswith('model.layers.'):
                # "model.layers.<layer_id>.<sub-key>"
                fields = key.split('.')
                owner = model.trf_blocks[int(fields[2])]
                entry = block_level.get('.'.join(fields[3:]))
            else:
                owner = model
                entry = model_level.get(key)

            if entry is not None:
                locate, needs_transpose = entry
                variable = locate(owner)
                source = tensor.T if needs_transpose else tensor
                variable.value = assign(variable.value, source, key)

            # Release the temporary before reading the next tensor.
            del tensor
            gc.collect()


In [18]:
# LLAMA32_CONFIG = {
#         "vocab_size": 128_256,
#         "context_length": 131_072,
#         "embedding_dim": 2048,
#         "n_heads": 32,
#         "n_layers": 16,
#         "hidden_dim": 8192,
#         "n_kv_groups": 8,
#         "rope_base": 500_000.0,
#         "dtype": jnp.bfloat16,
#         "param_dtype": jnp.bfloat16,

#         "rope_freq": {
#             "factor": 32.0,
#             "low_freq_factor": 1.0,
#             "high_freq_factor": 4.0,
#             "original_context_length": 8192,
#         }
#     }

# Step 1: Build the model arguments from the config dict
args = ModelArgs(**LLAMA32_CONFIG)

# Step 2: Instantiate the model. Its parameters are created here with initial
# values (seed 0) and are overwritten by the checkpoint in step 3.
model = Llama3Model(args=args, rngs=nnx.Rngs(0))

# Step 3: Load weights from .safetensors file
update_model_from_safetensors(model, "/content/Llama-3.2-1B-Instruct/model.safetensors")


Loading Weights: 100%|██████████| 146/146 [01:11<00:00,  2.04it/s]


In [None]:
update_model_from_safetensors(model, "/content/Llama-3.2-1B-Instruct/model.safetensors")

In [42]:
model.tok_emb.embedding.shape

(128256, 2048)

In [19]:
# Use the transposed token-embedding matrix as the output-head kernel.
# NOTE(review): presumably needed because the checkpoint provides no
# `lm_head.weight` entry — confirm against the key listing.
model.out_head.kernel = nnx.Param(model.tok_emb.embedding.value.T)

In [20]:
def get_all_params(module):
    """Recursively collect every ``nnx.Param`` reachable from ``module``.

    ``nnx.Module`` instances are searched through their instance attributes;
    lists, tuples and dicts are searched through their items/values. Anything
    else contributes nothing.

    Returns:
        A list of the ``nnx.Param`` objects found, in traversal order.
    """
    if isinstance(module, nnx.Param):
        return [module]

    if isinstance(module, nnx.Module):
        children = vars(module).values()
    elif isinstance(module, (list, tuple)):
        children = module
    elif isinstance(module, dict):
        children = module.values()
    else:
        return []

    return [param for child in children for param in get_all_params(child)]

def count_params_and_size(model):
    """Return ``(number of parameter elements, size in bytes)`` for ``model``.

    Both totals are summed over every parameter found by ``get_all_params``.
    """
    element_count = 0
    byte_count = 0
    for param in get_all_params(model):
        array = param.value
        element_count += array.size
        byte_count += array.size * array.dtype.itemsize
    return element_count, byte_count

def pretty_size(num_bytes):
    """Format a byte count with two decimals as KB, MB or GB (1024-based).

    Values below 1024**2 are shown in KB, values below 1024**3 in MB and
    everything else in GB.
    """
    step = 1024
    for power, unit in ((1, "KB"), (2, "MB")):
        if num_bytes < step ** (power + 1):
            return f"{num_bytes / step ** power:.2f} {unit}"
    return f"{num_bytes / step ** 3:.2f} GB"

# Run it
total_params, total_bytes = count_params_and_size(model)
print(f"Total Parameters: {total_params:,}")
print(f"Model Size: {pretty_size(total_bytes)}")


Total Parameters: 1,498,482,688
Model Size: 2.79 GB


## Chat Inference

In [10]:
!pip install blobfile



In [None]:
from tiktoken.load import load_tiktoken_bpe

# Path to your tokenizer.model file
tokenizer_path = "/content/Llama-3.2-1B-Instruct/original/tokenizer.model" # update the path if necessary

# Load the tokenizer
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)

# Access the vocabulary (tokens and IDs)
vocabulary = mergeable_ranks.items()

# Show the vocabulary size and only a short preview: printing every entry
# floods the notebook output.
PREVIEW_COUNT = 20
print(f"Vocabulary size: {len(mergeable_ranks):,}")
for token, token_id in list(vocabulary)[:PREVIEW_COUNT]:
    print(f"Token: {token}, Token ID: {token_id}")

In [21]:
from huggingface_hub import hf_hub_download
from tiktoken.load import load_tiktoken_bpe

import os
from pathlib import Path

import tiktoken


class Tokenizer:
    """Thin wrapper around a ``tiktoken.Encoding`` built from a BPE rank file.

    Args:
        model_path: path of the tokenizer file; it must exist.

    Attributes:
        special_tokens: mapping from special-token string to token id. It holds
            five named tokens plus ``<|reserved_i|>`` entries for every id in
            ``128002 .. 128257`` that is not already taken by a named token.
        model: the underlying ``tiktoken.Encoding``.
    """

    def __init__(self, model_path):
        assert os.path.isfile(model_path), f"Model file {model_path} not found"
        bpe_ranks = load_tiktoken_bpe(model_path)

        named_tokens = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        used_ids = set(named_tokens.values())

        # Fill the remaining ids of the special range with placeholder names.
        reserved_tokens = {}
        for i in range(256):
            token_id = 128002 + i
            if token_id not in used_ids:
                reserved_tokens[f"<|reserved_{i}|>"] = token_id

        self.special_tokens = {**named_tokens, **reserved_tokens}

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
            mergeable_ranks=bpe_ranks,
            special_tokens=self.special_tokens
        )

    def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
        """Encode ``text`` to a list of token ids.

        ``bos`` prepends the ``<|begin_of_text|>`` id and ``eos`` appends the
        ``<|end_of_text|>`` id. ``allowed_special`` and ``disallowed_special``
        are forwarded to the underlying encoding.
        """
        prefix = [self.special_tokens["<|begin_of_text|>"]] if bos else []
        body = self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
        suffix = [self.special_tokens["<|end_of_text|>"]] if eos else []
        return prefix + body + suffix

    def decode(self, tokens):
        """Decode a sequence of token ids back to text."""
        return self.model.decode(tokens)


class ChatFormat:
    """Wrap plain text in the header/terminator tokens of a single user turn.

    Args:
        tokenizer: object exposing ``special_tokens``, ``encode`` and ``decode``.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def encode_header(self, message):
        """Return the ids for ``<|start_header_id|>role<|end_header_id|>`` followed by a blank line."""
        specials = self.tokenizer.special_tokens
        role_ids = self.tokenizer.encode(message["role"], bos=False, eos=False)
        separator_ids = self.tokenizer.encode("\n\n", bos=False, eos=False)
        return [
            specials["<|start_header_id|>"],
            *role_ids,
            specials["<|end_header_id|>"],
            *separator_ids,
        ]

    def encode(self, text):
        """Encode ``text`` as a user message: header, stripped content, ``<|eot_id|>``."""
        message = {"role": "user", "content": text}

        header_ids = self.encode_header(message)
        content_ids = self.tokenizer.encode(message["content"].strip(), bos=False, eos=False)
        return [*header_ids, *content_ids, self.tokenizer.special_tokens["<|eot_id|>"]]

    def decode(self, token_ids):
        """Decode token ids with the wrapped tokenizer."""
        return self.tokenizer.decode(token_ids)

# Model-size tag interpolated into the Hugging Face repo id and the local folder name
LLAMA_SIZE_STR = "1B"

# Download the tokenizer file into the local model folder
tokenizer_file_path = hf_hub_download(
    repo_id=f"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct",
    filename="original/tokenizer.model",
    local_dir=f"Llama-3.2-{LLAMA_SIZE_STR}-Instruct"
)

# Plain tokenizer plus the chat wrapper that adds the user-turn header tokens
tokenizer = Tokenizer(tokenizer_file_path)
chat_tokenizer = ChatFormat(tokenizer)

### Miscellaneous Tests

In [37]:
input_ids = tokenizer.encode("What do llamas eat?")
input_ids

[3923, 656, 9507, 29189, 8343, 30]

In [38]:
input_ids = tokenizer.encode("<|start_header_id|> What do llamas eat?")
input_ids

[27, 91, 2527, 8932, 851, 91, 29, 3639, 656, 9507, 29189, 8343, 30]

In [39]:
tokenizer.decode(input_ids)

'<|start_header_id|> What do llamas eat?'

In [40]:
# NOTE(review): this produces shape (len(input_ids), 1). The model reads its
# input as (batch, seq_len), so every token is processed as its own one-token
# sequence rather than as one prompt — confirm that this is intended.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = jnp.array(input_ids).reshape(len(input_ids),1)

In [47]:
output = model(input)
output.shape

(13, 1, 128256)

In [23]:
jnp.argmax(output[-1])

Array(53751, dtype=int32)

In [18]:
import torch
tokenizer.decode([53751])

' Improvement'

## Inference

In [22]:
import jax
import jax.numpy as jnp

def text_to_token_ids(text, tokenizer):
    """Encode ``text`` with ``tokenizer`` and return the ids as a ``(1, n)`` array."""
    token_ids = jnp.array(tokenizer.encode(text))
    # Prepend a batch axis of size one.
    return jnp.expand_dims(token_ids, axis=0)


def token_ids_to_text(token_ids, tokenizer):
    """Decode a ``(1, n)`` array of token ids to text with ``tokenizer``."""
    # Drop the leading batch axis (it must have size one).
    id_list = token_ids.squeeze(0).tolist()
    return tokenizer.decode(id_list)

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None, key=None):
    """Autoregressively extend ``idx`` with tokens predicted by ``model``.

    Args:
        model: callable mapping token ids ``(batch, seq_len)`` to logits
            ``(batch, seq_len, vocab_size)``.
        idx: integer array of shape ``(batch, seq_len)`` holding the prompt.
        max_new_tokens: maximum number of tokens to append.
        context_size: only the last ``context_size`` tokens are fed to the model.
        temperature: ``0.0`` selects greedy decoding (argmax); a positive value
            samples from the temperature-scaled logits.
        top_k: when set, logits below the ``top_k``-th largest value of each
            row are masked out before selection.
        eos_id: when set, generation stops as soon as every sequence in the
            batch produces this id (the id itself is not appended).
        key: optional ``jax.random`` key used for sampling; defaults to
            ``jax.random.PRNGKey(0)``. It is split at every step so successive
            tokens are drawn with different randomness.

    Returns:
        Integer array of shape ``(batch, seq_len + n)`` with ``n <= max_new_tokens``.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        logits = model(idx_cond)
        logits = logits[:, -1, :]  # logits of the last position: (batch, vocab)

        # Top-k filtering
        if top_k is not None:
            top_logits, _ = jax.lax.top_k(logits, top_k)
            # Keep the trailing axis so the per-row threshold broadcasts against
            # (batch, vocab) for any batch size.
            min_val = top_logits[:, -1:]
            logits = jnp.where(logits < min_val, -jnp.inf, logits)

        if temperature > 0.0:
            # jax.random.categorical expects unnormalised log-probabilities, so
            # the scaled logits are passed directly (no softmax).
            key, step_key = jax.random.split(key)
            sampled = jax.random.categorical(step_key, logits / temperature, axis=-1)
            idx_next = sampled[:, None]
        else:
            idx_next = jnp.argmax(logits, axis=-1, keepdims=True)

        # Early stopping with eos_id
        if eos_id is not None and jnp.all(idx_next == eos_id):
            break

        idx = jnp.concatenate([idx, idx_next], axis=1)

    return idx

In [25]:
PROMPT = "What do llamas eat?"

# The prompt is encoded with the chat wrapper (user-turn header and end-of-turn
# token); temperature=0. selects the argmax branch of generate().
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, chat_tokenizer),
    max_new_tokens=50,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.
)

# Decode with the plain tokenizer (prompt and generated tokens together).
output_text = token_ids_to_text(token_ids, tokenizer)

In [24]:
#5 tokens
output_text

'<|start_header_id|>user<|end_header_id|>\n\nWhat do llamas eat?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nL'

In [26]:
# 50 tokens
output_text

'<|start_header_id|>user<|end_header_id|>\n\nWhat do llamas eat?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nLlamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n\n1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grass'