In [2]:
import os
import sys
sys.path.append('../../')
from configuration_tiny_mixtral import TinyMixtralConfig
from modeling_tiny_mixtral import TinyMixtralForCausalLM
from src.models.moe.config import ModelConfig
import torch
from dataclasses import asdict
from transformers import AutoTokenizer
import shutil

In [None]:
def convert_checkpoint_to_hf(
    checkpoint_path,
    output_dir,
    config_dict,
    tokenizer_name="gpt2"
):
    """Convert your checkpoint to HuggingFace format"""
    shutil.copy("modeling_tiny_mixtral.py", f"{output_dir}/modeling_tiny_mixtral.py")
    
    config = TinyMixtralConfig(**config_dict)
    
    hf_model = TinyMixtralForCausalLM(config)
    
    # Load your checkpoint
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    hf_model.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token
    
    # Save to HuggingFace format
    hf_model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    
    return hf_model, tokenizer

config_dict = asdict(ModelConfig())
config_dict.pop('top_k', None)  # Remove top_k if it exists
# Rename to the HF-expected name
config_dict['top_k_experts'] = ModelConfig.top_k

hf_model, tokenizer = convert_checkpoint_to_hf(
    checkpoint_path="last_epoch_moe_1.pt",
    output_dir="./hf-mixtral-active",
    config_dict=config_dict
)

  checkpoint = torch.load(checkpoint_path, map_location='cpu')


In [3]:
print(hf_model)

TinyMixtralForCausalLM(
  (model): tiny_mixtral(
    (tok_embedding): Embedding(50257, 768)
    (layers): ModuleList(
      (0-4): 5 x layer(
        (attention): SimpleMultiHeadAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=False)
          (c_proj): Linear(in_features=768, out_features=768, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ffn): SparseMOE(
          (experts): ModuleList(
            (0-7): 8 x SwiGLUFFN(
              (w_1): Linear(in_features=768, out_features=384, bias=True)
              (w_2): Linear(in_features=768, out_features=384, bias=True)
              (out): Linear(in_features=384, out_features=768, bias=True)
            )
          )
          (router): Linear(in_features=768, out_features=8, bias=True)
        )
        (attn_norm): RMSNorm()
        (ffn_norm): RMSNorm()
      )
    )
    (norm): RMSNorm()
    (output): Linear

In [None]:
from huggingface_hub import HfApi, create_repo

def push_to_hub(
    local_dir="./hf-mixtral-active",
    repo_name="Marmik/tiny-mixtral-5l-active",
    private=False
):
    
    # Create repository
    create_repo(
        repo_id=repo_name,
        private=False,
        exist_ok=True
    )
    
    # Initialize API
    api = HfApi()
    
    # Upload all files
    api.upload_folder(
        folder_path=local_dir,
        repo_id=repo_name,
        repo_type="model"
    )
    
    print(f"Model uploaded to: https://huggingface.co/{repo_name}")

# Push to hub
push_to_hub(
    local_dir="./hf-mixtral-active",
    repo_name="Marmik/tiny-mixtral-5l-active",
    private=False
)

model.safetensors:   0%|          | 0.00/498M [00:00<?, ?B/s]

Model uploaded to: https://huggingface.co/Marmik/tiny-mixtral-5l-total


In [None]:
state = torch.load("last_epoch_moe_total_1.pt", map_location="cpu", weights_only=False)
state_dict = state['model_state_dict']
cfg   = TinyMixtralConfig()
hf_model = TinyMixtralForCausalLM(cfg)
missing, unexpected = hf_model.model.load_state_dict(state_dict, strict=False)
print(missing)
print(unexpected)

In [None]:
# Generate text using the model
input_ids = torch.tensor([[68, 26, 1024, 38943, 500]])  # Example input tokens
with torch.no_grad():
    output = hf_model.generate(
        input_ids=input_ids,
        max_length=20,
        do_sample=True,
        temperature=0.7,
        pad_token_id=50256
    )
print("Generated output:", output)

In [None]:
# for name, module in hf_model.named_modules():
#     print(f"{name}: {type(module)}")