In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os
import torch.nn as nn

# `torchtune` Exploration

I would like to understand how to use `torchtune`.
To start things off, I will follow the architecture of [sesame](https://github.com/SesameAILabs/csm/blob/main/models.py) and work on overfitting a pair of conversations from [seamless interaction](https://github.com/facebookresearch/seamless_interaction).

In [None]:
import torchtune

from torchtune.models import llama3_2
import safetensors.torch

from transformers import AutoTokenizer

In [None]:
def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
    """Build a Llama 3.2 decoder using the "1B" hyperparameters listed below.

    Returns:
        The object returned by ``llama3_2.llama3_2`` for this configuration.
        No checkpoint is loaded here.
    """
    hyperparams = dict(
        # Model size
        vocab_size=128_256,
        num_layers=16,
        embed_dim=2048,
        intermediate_dim=8192,
        # Attention layout
        num_heads=32,
        num_kv_heads=8,
        max_seq_len=2048,
        attn_dropout=0.0,
        # Normalisation and rotary-embedding settings
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2.llama3_2(**hyperparams)


def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
    """Build a small Llama 3.2 decoder using the "100M" hyperparameters below.

    Returns:
        The object returned by ``llama3_2.llama3_2`` for this configuration.
        No checkpoint is loaded here.
    """
    hyperparams = dict(
        # Model size
        vocab_size=128_256,
        num_layers=4,
        embed_dim=1024,
        intermediate_dim=8192,
        # Attention layout
        num_heads=8,
        num_kv_heads=2,
        max_seq_len=2048,
        attn_dropout=0.0,
        # Normalisation and rotary-embedding settings
        norm_eps=1e-5,
        rope_base=500_000,
        scale_factor=32,
    )
    return llama3_2.llama3_2(**hyperparams)


# Registry of backbone constructors keyed by flavor name. Each value is a
# zero-argument callable; call it to obtain a freshly built model.
FLAVORS = {
    "llama-1B": llama3_2_1B,
    "llama-100M": llama3_2_100M,
}


def _prepare_transformer(model):
    embed_dim = model.tok_embeddings.embedding_dim
    # model.tok_embeddings = nn.Identity()
    model.output = nn.Identity()
    return model, embed_dim

In [None]:
# Build the "llama-100M" flavor and replace its output layer with nn.Identity();
# embed_dim is that model's tok_embeddings.embedding_dim.
# NOTE(review): `llama` is rebound to a 1B model in a later cell of this notebook.
llama, embed_dim = _prepare_transformer(FLAVORS["llama-100M"]())

In [None]:
# Tokenizer for the Llama 3.2 family, fetched from the Hugging Face Hub.
tokenizer_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name,
    # Use the explicit token when the variable is set. Otherwise pass None so
    # that from_pretrained falls back to its default credential lookup (e.g. a
    # cached `huggingface-cli login`) instead of failing with a KeyError here.
    token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
)

In [None]:
def convert_hf_to_torchtune(hf_state_dict, num_heads=32, num_kv_heads=8):
    """Convert a HuggingFace Llama state dict to torchtune's layout.

    Two things happen for every recognised key:

    * the key is renamed to torchtune's naming scheme, and
    * ``q_proj`` / ``k_proj`` weights have their rows re-ordered. HuggingFace
      checkpoints store these rows permuted for its "rotate half" rotary
      embedding, whereas torchtune's rotary embedding rotates adjacent pairs.
      Copying the weights unchanged loads without error but yields wrong
      attention scores.

    Layer indices are read from the keys themselves, so checkpoints of any
    depth are converted (not only 16-layer ones).

    Args:
        hf_state_dict: Mapping from HuggingFace parameter names to 2-D weight
            tensors (1-D for norm scales).
        num_heads: Number of query heads, used to un-permute ``q_proj``.
            Defaults to 32 (the Llama-3.2-1B value).
        num_kv_heads: Number of key/value heads, used to un-permute ``k_proj``.
            Defaults to 8 (the Llama-3.2-1B value).

    Returns:
        A new dict with torchtune key names, in the iteration order of
        ``hf_state_dict``. Keys that have no torchtune counterpart here (for
        example ``lm_head.weight`` or rotary ``inv_freq`` buffers) are skipped.

    Raises:
        ValueError: If a ``q_proj`` / ``k_proj`` weight's row count is not an
            even multiple of the corresponding head count.
    """
    top_level = {
        "model.embed_tokens.weight": "tok_embeddings.weight",
        "model.norm.weight": "norm.scale",
    }
    # HF name -> torchtune name, both relative to the per-layer prefix.
    per_layer = {
        "self_attn.q_proj.weight": "attn.q_proj.weight",
        "self_attn.k_proj.weight": "attn.k_proj.weight",
        "self_attn.v_proj.weight": "attn.v_proj.weight",
        "self_attn.o_proj.weight": "attn.output_proj.weight",
        "mlp.gate_proj.weight": "mlp.w1.weight",
        "mlp.down_proj.weight": "mlp.w2.weight",
        "mlp.up_proj.weight": "mlp.w3.weight",
        "input_layernorm.weight": "sa_norm.scale",
        "post_attention_layernorm.weight": "mlp_norm.scale",
    }
    # Projections that need the rotary re-ordering, and the head count to use.
    heads_for = {
        "self_attn.q_proj.weight": num_heads,
        "self_attn.k_proj.weight": num_kv_heads,
    }

    def _undo_hf_rope_permutation(weight, n_heads):
        # Within each head, HF stores rows as [first halves | second halves] of
        # the rotary pairs; interleave them back into adjacent pairs.
        out_features, in_features = weight.shape
        if n_heads <= 0 or out_features % (2 * n_heads) != 0:
            raise ValueError(
                f"cannot split {out_features} rows into {n_heads} heads "
                "of even size"
            )
        head_dim = out_features // n_heads
        return (
            weight.reshape(n_heads, 2, head_dim // 2, in_features)
            .transpose(1, 2)
            .reshape(out_features, in_features)
        )

    layer_prefix = "model.layers."
    torchtune_state_dict = {}

    for hf_key, value in hf_state_dict.items():
        if hf_key in top_level:
            torchtune_state_dict[top_level[hf_key]] = value
            continue
        if not hf_key.startswith(layer_prefix):
            continue

        # "model.layers.<idx>.<suffix>" -> ("<idx>", "<suffix>")
        layer_idx, _, suffix = hf_key[len(layer_prefix):].partition(".")
        if suffix not in per_layer:
            continue
        if not (layer_idx.isascii() and layer_idx.isdigit()):
            continue

        if suffix in heads_for:
            value = _undo_hf_rope_permutation(value, heads_for[suffix])
        torchtune_state_dict[f"layers.{layer_idx}.{per_layer[suffix]}"] = value

    return torchtune_state_dict


# Location of the HuggingFace-format checkpoint. Set the LLAMA_WEIGHTS_PATH
# environment variable to point elsewhere; the default is the path this notebook
# was originally written against.
WEIGHTS_PATH = os.environ.get(
    "LLAMA_WEIGHTS_PATH",
    "/home/henry/model_weights/Llama-3.2-1B-Instruct/model.safetensors",
)

# Fresh 1B model with its output layer intact.
# NOTE: this rebinds `llama`, replacing the headless 100M model that the earlier
# _prepare_transformer cell assigned to the same name.
llama = llama3_2_1B()

# Load the weights from the safetensors file
hf_state_dict = safetensors.torch.load_file(WEIGHTS_PATH)

# Convert to torchtune format
torchtune_state_dict = convert_hf_to_torchtune(hf_state_dict)

# Load the converted state dict (kept as the cell's last expression so the
# key-matching report is displayed)
llama.load_state_dict(torchtune_state_dict)

In [None]:
# Display the class of the object currently bound to `llama`.
type(llama)

In [None]:
# Use CUDA when torch reports it as available, otherwise stay on the CPU,
# then move the model to the chosen device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
llama = llama.to(device)

In [None]:
# Greedy decoding demo: tokenize a prompt, then extend it one token at a time.
prompt = "The capital of Spain is"
inputs = tokenizer(prompt, return_tensors="pt")

# Evaluation mode for inference
llama.eval()

# No gradients are needed while generating
with torch.no_grad():
    input_ids = inputs["input_ids"].to(device)

    # Running transcript, seeded with the prompt
    generated_text = prompt

    # Ten decoding steps; the whole sequence so far is fed back in each time
    for i in range(10):
        outputs = llama(input_ids)

        # Accept either a bare logits tensor or an object carrying `.logits`
        logits = getattr(outputs, "logits", outputs)

        # Scores at the final position, then the highest-scoring token id
        next_token_logits = logits[:, -1, :]
        next_token_id = next_token_logits.argmax(dim=-1)

        # Turn the id back into text and extend the transcript
        predicted_token = tokenizer.decode(next_token_id)
        generated_text = generated_text + predicted_token

        # Grow the model input by the chosen token for the next step
        input_ids = torch.cat((input_ids, next_token_id[None]), dim=1)

        print(f"Step {i+1}: Added '{predicted_token}' -> '{generated_text}'")

    print("\nFinal result:")
    print(f"Input: {prompt}")
    print(f"Generated: {generated_text}")