# Convert the CSM checkpoint to Fish format

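Before running this notebook, put `HUGGINGFACE_TOKEN=<your token>` in a `.env` file that `load_dotenv()` can find: the notebook's working directory or one of its parent directories. Your Hugging Face account must also have been granted access to the gated official checkpoint (`sesame/csm-1b`).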

In [2]:
import torch
from safetensors.torch import save_file
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path
import os

# Pull HUGGINGFACE_TOKEN (and anything else in .env) into os.environ first.
load_dotenv()
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# Official CSM weights; fetched with the token read above.
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    token=hf_token,
)

# Llama-3.2-1B repo minus its weights file: only metadata such as
# config.json is used from it below.
config_dir = snapshot_download(
    repo_id="unsloth/Llama-3.2-1B",
    ignore_patterns=["model.safetensors"],
)
config_path = Path(config_dir)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 111212.61it/s]


In [3]:
# Load the raw CSM state dict onto the CPU. weights_only=True restricts the
# unpickler to tensors and plain containers instead of arbitrary Python
# objects, which matters because ckpt.pt is a downloaded pickle.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
for name, param in state_dict.items():
    print(f"{name}: {param.shape}")


backbone.layers.0.attn.q_proj.weight: torch.Size([2048, 2048])
backbone.layers.0.attn.k_proj.weight: torch.Size([512, 2048])
backbone.layers.0.attn.v_proj.weight: torch.Size([512, 2048])
backbone.layers.0.attn.output_proj.weight: torch.Size([2048, 2048])
backbone.layers.0.mlp.w1.weight: torch.Size([8192, 2048])
backbone.layers.0.mlp.w2.weight: torch.Size([2048, 8192])
backbone.layers.0.mlp.w3.weight: torch.Size([8192, 2048])
backbone.layers.0.sa_norm.scale: torch.Size([2048])
backbone.layers.0.mlp_norm.scale: torch.Size([2048])
backbone.layers.1.attn.q_proj.weight: torch.Size([2048, 2048])
backbone.layers.1.attn.k_proj.weight: torch.Size([512, 2048])
backbone.layers.1.attn.v_proj.weight: torch.Size([512, 2048])
backbone.layers.1.attn.output_proj.weight: torch.Size([2048, 2048])
backbone.layers.1.mlp.w1.weight: torch.Size([8192, 2048])
backbone.layers.1.mlp.w2.weight: torch.Size([2048, 8192])
backbone.layers.1.mlp.w3.weight: torch.Size([8192, 2048])
backbone.layers.1.sa_norm.scale: torc

In [4]:
# Ordered (old, new) substring rewrites from CSM parameter names to Fish names.
# Order matters. "sa_norm.scale" and "mlp_norm.scale" must be rewritten before
# the generic "norm.scale" rule would turn them into "*_norm.weight". Likewise,
# "norm.scale" -> "norm.weight" must run before the
# "decoder.norm.weight" -> "fast_norm.weight" rule can match.
CSM_TO_FISH_KEY_RENAMES = (
    ("backbone.", ""),
    ("decoder.layers", "fast_layers"),
    ("attn", "attention"),
    ("k_proj", "wk"),
    ("q_proj", "wq"),
    ("v_proj", "wv"),
    ("output_proj", "wo"),
    ("sa_norm.scale", "attention_norm.weight"),
    ("mlp.", "feed_forward."),
    ("mlp_norm.scale", "ffn_norm.weight"),
    ("norm.scale", "norm.weight"),
    ("decoder.norm.weight", "fast_norm.weight"),
    ("audio_embeddings", "codebook_embeddings"),
    ("text_embeddings", "embeddings"),
    ("projection", "fast_project_in"),
)


def rename_csm_key(key):
    """Return *key* with every CSM_TO_FISH_KEY_RENAMES rewrite applied in order."""
    for old, new in CSM_TO_FISH_KEY_RENAMES:
        key = key.replace(old, new)
    return key


renamed_tensors = {rename_csm_key(key): tensor for key, tensor in state_dict.items()}
list(renamed_tensors.keys())

['layers.0.attention.wq.weight',
 'layers.0.attention.wk.weight',
 'layers.0.attention.wv.weight',
 'layers.0.attention.wo.weight',
 'layers.0.feed_forward.w1.weight',
 'layers.0.feed_forward.w2.weight',
 'layers.0.feed_forward.w3.weight',
 'layers.0.attention_norm.weight',
 'layers.0.ffn_norm.weight',
 'layers.1.attention.wq.weight',
 'layers.1.attention.wk.weight',
 'layers.1.attention.wv.weight',
 'layers.1.attention.wo.weight',
 'layers.1.feed_forward.w1.weight',
 'layers.1.feed_forward.w2.weight',
 'layers.1.feed_forward.w3.weight',
 'layers.1.attention_norm.weight',
 'layers.1.ffn_norm.weight',
 'layers.2.attention.wq.weight',
 'layers.2.attention.wk.weight',
 'layers.2.attention.wv.weight',
 'layers.2.attention.wo.weight',
 'layers.2.feed_forward.w1.weight',
 'layers.2.feed_forward.w2.weight',
 'layers.2.feed_forward.w3.weight',
 'layers.2.attention_norm.weight',
 'layers.2.ffn_norm.weight',
 'layers.3.attention.wq.weight',
 'layers.3.attention.wk.weight',
 'layers.3.attention.w

In [5]:
import re

def merge_attention_weights(state_dict):
    """Fuse per-module q/k/v projection weights into a single ``wqkv`` weight.

    Keys of the form ``<prefix>.attention.{wq,wk,wv}.weight`` are grouped by
    ``<prefix>.attention``. A group with all three projections becomes one
    ``<prefix>.attention.wqkv.weight`` entry: the tensors are concatenated
    along dim 0 in q, k, v order, which presumes ``[out_features, in_features]``
    layout. A group missing any projection is passed through under its
    original keys. Every other entry is copied unchanged.

    Args:
        state_dict: mapping from parameter name to tensor.

    Returns:
        A new dict. Non-attention entries come first, in input order, followed
        by the merged (or passed-through) attention entries. The input is not
        modified.
    """
    qkv_key = re.compile(r"^(.*\.attention)\.(wq|wk|wv)\.weight$")

    merged = {}
    # e.g. {"layers.0.attention": {"wq": tensor, "wk": tensor, "wv": tensor}}
    pending = {}

    for name, tensor in state_dict.items():
        hit = qkv_key.match(name)
        if hit is None:
            merged[name] = tensor
            continue
        prefix, proj = hit.groups()
        pending.setdefault(prefix, {})[proj] = tensor

    for prefix, projections in pending.items():
        if projections.keys() >= {"wq", "wk", "wv"}:
            fused = torch.cat(
                [projections["wq"], projections["wk"], projections["wv"]], dim=0
            )
            merged[prefix + ".wqkv.weight"] = fused
            continue
        # Incomplete group: keep whichever projections exist, unmerged.
        for proj, tensor in projections.items():
            merged[f"{prefix}.{proj}.weight"] = tensor

    return merged

# Fuse each attention module's separate Q/K/V weights into one wqkv tensor.
wqkv_dict = merge_attention_weights(renamed_tensors)
list(wqkv_dict)

['layers.0.attention.wo.weight',
 'layers.0.feed_forward.w1.weight',
 'layers.0.feed_forward.w2.weight',
 'layers.0.feed_forward.w3.weight',
 'layers.0.attention_norm.weight',
 'layers.0.ffn_norm.weight',
 'layers.1.attention.wo.weight',
 'layers.1.feed_forward.w1.weight',
 'layers.1.feed_forward.w2.weight',
 'layers.1.feed_forward.w3.weight',
 'layers.1.attention_norm.weight',
 'layers.1.ffn_norm.weight',
 'layers.2.attention.wo.weight',
 'layers.2.feed_forward.w1.weight',
 'layers.2.feed_forward.w2.weight',
 'layers.2.feed_forward.w3.weight',
 'layers.2.attention_norm.weight',
 'layers.2.ffn_norm.weight',
 'layers.3.attention.wo.weight',
 'layers.3.feed_forward.w1.weight',
 'layers.3.feed_forward.w2.weight',
 'layers.3.feed_forward.w3.weight',
 'layers.3.attention_norm.weight',
 'layers.3.ffn_norm.weight',
 'layers.4.attention.wo.weight',
 'layers.4.feed_forward.w1.weight',
 'layers.4.feed_forward.w2.weight',
 'layers.4.feed_forward.w3.weight',
 'layers.4.attention_norm.weight',
 'la

In [6]:
import json
# Llama-3.2-1B architecture hyperparameters; the backbone fields of the Fish
# config are derived from these. JSON is UTF-8 (RFC 8259), so decode it
# explicitly rather than relying on the platform's default encoding.
hf_config = json.loads((config_path / "config.json").read_text(encoding="utf-8"))

print(json.dumps(hf_config, indent=2))

{
  "_name_or_path": "meta-llama/Llama-3.2-1B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pad_token_id": 128004,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.1",
  "unsloth_fixed": true,
  "use_cache": true,
  "vocab_size": 128256
}


In [None]:
# Audio vocabulary size per codebook.
# TODO: this is hard-coded rather than read from the checkpoint. Confirm 2051
# matches the checkpoint's audio embedding table before relying on the output.
CODEBOOK_SIZE = 2051

# Fish model config.
# - Backbone fields are read from the Llama-3.2-1B config loaded earlier.
# - The fast_* / n_fast_layer fields describe the checkpoint's `decoder` stack.
#   That config does not cover it, so these values are hard-coded
#   (fast_dim / fast_n_head = 1024 / 8 = fast_head_dim).
# NOTE(review): confirm the fast_* values against the decoder tensor shapes.
config = {
    "attention_qkv_bias": False,
    "codebook_size": CODEBOOK_SIZE,
    "dim": hf_config["hidden_size"],
    "dropout": 0.0,
    "fast_attention_qkv_bias": False,
    "fast_dim": 1024,
    "fast_intermediate_size": 8192,
    "fast_head_dim": 128,
    "fast_n_head": 8,
    "fast_n_local_heads": 2,
    "n_fast_layer": 4,
    "head_dim": hf_config["head_dim"],
    "initializer_range": hf_config["initializer_range"],
    "intermediate_size": hf_config["intermediate_size"],
    "is_reward_model": False,
    "max_seq_len": 2048,
    "model_type": "csm",
    "n_head": hf_config["num_attention_heads"],
    "norm_eps": hf_config["rms_norm_eps"],
    "n_layer": hf_config["num_hidden_layers"],
    "n_local_heads": hf_config["num_key_value_heads"],
    "num_codebooks": 32,
    # The Llama config stores this as a float (500000.0); write it as an int.
    "rope_base": int(hf_config["rope_theta"]),
    "scale_codebook_embeddings": False,
    "share_codebook_embeddings": True,
    "use_gradient_checkpointing": False,
    "vocab_size": hf_config["vocab_size"],
    "rope_scaling": hf_config["rope_scaling"],
}

# Output path is relative to the current working directory.
out_dir = Path("../../inits/csm_1b")
out_dir.mkdir(parents=True, exist_ok=True)
# Save config
with open(out_dir / "config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

In [9]:
# Write the renamed, QKV-fused weights alongside config.json.
weights_path = out_dir / "model.safetensors"
save_file(wqkv_dict, str(weights_path))

## Tokenizer: reproduce CSM's BOS/EOS post-processing

In [None]:
from transformers import AutoTokenizer
from tokenizers.processors import TemplateProcessing

# Copying from official repo `generator.py`
# Mirrors the tokenizer setup in the official CSM repo's `generator.py`:
# every encoded sequence is wrapped as BOS ... EOS.
tokenizer_name = "unsloth/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
bos, eos = tokenizer.bos_token, tokenizer.eos_token

single_template = f"{bos}:0 $A:0 {eos}:0"
pair_template = f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1"
bos_eos_ids = [
    (f"{bos}", tokenizer.bos_token_id),
    (f"{eos}", tokenizer.eos_token_id),
]

# `_tokenizer` is the fast tokenizer's private backend object; its
# post_processor is what inserts the special tokens at encode time.
tokenizer._tokenizer.post_processor = TemplateProcessing(
    single=single_template,
    pair=pair_template,
    special_tokens=bos_eos_ids,
)
tokenizer.save_pretrained(out_dir)

('../../inits/csm_1b/tokenizer_config.json',
 '../../inits/csm_1b/special_tokens_map.json',
 '../../inits/csm_1b/tokenizer.json')