# CSM checkpoint conversion script to Fish format

Before running this, set `HUGGINGFACE_TOKEN` in a .env file in this repo somewhere. Make sure you're approved for the official checkpoint.

In [5]:
import torch
from safetensors.torch import save_file
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download, snapshot_download
from pathlib import Path
import os

load_dotenv()
model_path = hf_hub_download(
    repo_id="sesame/csm-1b",
    filename="ckpt.pt",
    token=os.getenv("HUGGINGFACE_TOKEN")
)
config_dir = snapshot_download(
    repo_id="unsloth/Llama-3.2-1B",
    ignore_patterns=["model.safetensors"]
)
config_path = Path(config_dir)

Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 95948.13it/s]


In [6]:
state_dict = torch.load(model_path, map_location="cpu")
for name, param in state_dict.items():
    print(f"{name}: {param.shape}")


backbone.layers.0.attn.q_proj.weight: torch.Size([2048, 2048])
backbone.layers.0.attn.k_proj.weight: torch.Size([512, 2048])
backbone.layers.0.attn.v_proj.weight: torch.Size([512, 2048])
backbone.layers.0.attn.output_proj.weight: torch.Size([2048, 2048])
backbone.layers.0.mlp.w1.weight: torch.Size([8192, 2048])
backbone.layers.0.mlp.w2.weight: torch.Size([2048, 8192])
backbone.layers.0.mlp.w3.weight: torch.Size([8192, 2048])
backbone.layers.0.sa_norm.scale: torch.Size([2048])
backbone.layers.0.mlp_norm.scale: torch.Size([2048])
backbone.layers.1.attn.q_proj.weight: torch.Size([2048, 2048])
backbone.layers.1.attn.k_proj.weight: torch.Size([512, 2048])
backbone.layers.1.attn.v_proj.weight: torch.Size([512, 2048])
backbone.layers.1.attn.output_proj.weight: torch.Size([2048, 2048])
backbone.layers.1.mlp.w1.weight: torch.Size([8192, 2048])
backbone.layers.1.mlp.w2.weight: torch.Size([2048, 8192])
backbone.layers.1.mlp.w3.weight: torch.Size([8192, 2048])
backbone.layers.1.sa_norm.scale: torc

In [13]:
renamed_tensors = {
    key.replace('backbone.', '')
        .replace('decoder.layers', 'fast_layers')
        .replace('decoder.norm.scale', 'fast_norm.weight')
        .replace('attn', 'attention')       
        .replace('k_proj', 'wk')
        .replace('q_proj', 'wq')
        .replace('v_proj', 'wv')
        .replace('o_proj', 'wo')
        .replace('sa_norm.scale', 'attention_norm.weight')
        .replace('mlp.', 'feed_forward.')
        .replace('mlp_norm.scale', 'ffn_norm.weight')
        .replace('norm.scale', 'norm.weight')
        .replace('audio_embeddings', 'codebook_embeddings')
        .replace('text_embeddings', 'embeddings')
        .replace('audio_head', 'fast_output')
        .replace('projection', 'fast_project_in')
    : tensor

    for key, tensor in state_dict.items()
}
list(renamed_tensors.keys())

['layers.0.attention.wq.weight',
 'layers.0.attention.wk.weight',
 'layers.0.attention.wv.weight',
 'layers.0.attention.output_proj.weight',
 'layers.0.feed_forward.w1.weight',
 'layers.0.feed_forward.w2.weight',
 'layers.0.feed_forward.w3.weight',
 'layers.0.attention_norm.weight',
 'layers.0.ffn_norm.weight',
 'layers.1.attention.wq.weight',
 'layers.1.attention.wk.weight',
 'layers.1.attention.wv.weight',
 'layers.1.attention.output_proj.weight',
 'layers.1.feed_forward.w1.weight',
 'layers.1.feed_forward.w2.weight',
 'layers.1.feed_forward.w3.weight',
 'layers.1.attention_norm.weight',
 'layers.1.ffn_norm.weight',
 'layers.2.attention.wq.weight',
 'layers.2.attention.wk.weight',
 'layers.2.attention.wv.weight',
 'layers.2.attention.output_proj.weight',
 'layers.2.feed_forward.w1.weight',
 'layers.2.feed_forward.w2.weight',
 'layers.2.feed_forward.w3.weight',
 'layers.2.attention_norm.weight',
 'layers.2.ffn_norm.weight',
 'layers.3.attention.wq.weight',
 'layers.3.attention.wk.weig

In [18]:
import json
with open(config_path / 'config.json') as f:
    hf_config = json.load(f)

print(json.dumps(hf_config, indent=2))

{
  "_name_or_path": "meta-llama/Llama-3.2-1B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pad_token_id": 128004,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": true,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.1",
  "unsloth_fixed": true,
  "use_cache": true,
  "vocab_size": 128256
}


In [None]:
CODEBOOK_SIZE=2048

config = {
    "attention_qkv_bias": False,
    "codebook_size": CODEBOOK_SIZE,
    "dim": 2048,
    'dropout': 0.0,
    'fast_attention_qkv_bias': False,
    'fast_dim': 1024,
    'fast_intermediate_size': 8192,
    'fast_head_dim': 64,
    'fast_n_head': 8,
    'fast_n_local_heads': 2,
    'n_fast_layer': 4,
    'head_dim': hf_config['head_dim'],
    "initializer_range": hf_config['initializer_range'],
    'intermediate_size': hf_config['intermediate_size'],
    "is_reward_model": False,
    "max_seq_len": 2048,
    "model_type": "csm",
    "n_head": hf_config['num_attention_heads'],
    "norm_eps": hf_config['rms_norm_eps'],
    "n_layer": hf_config['num_hidden_layers'],
    "n_local_heads": hf_config['num_key_value_heads'],
    "num_codebooks": 32,
    "rope_base": 500_000,
    "scale_codebook_embeddings": False,
    "share_codebook_embeddings": True,
    "use_gradient_checkpointing": False,
    "vocab_size": hf_config['vocab_size']
}