In [None]:
import os

# Expose only GPU index 2 to CUDA. CUDA reads this variable once, when it first
# initializes, so it must be set before any cell touches the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import triton
import triton.language as tl
import copy
from typing import Optional, Tuple, List, Dict, Any, Union
from transformers.models.llama.modeling_llama import (
    LlamaConfig,
    LlamaModel, 
    LlamaForCausalLM,
    LlamaDecoderLayer,
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm
)
from tqdm.notebook import tqdm

In [None]:
# Source Llama checkpoint identifier to convert.
llama_model_path = "meta-llama/Llama-3.1-8B-Instruct"
# Directory where the converted checkpoint is written (and later reloaded from).
llama_deepseek_dir = "./llama_deepseek_8B_mla_moe"

# Convert the Llama checkpoint to the DeepSeek-style architecture

In [None]:
from llama_deepseek_convert import *

In [None]:
# Run the conversion. It returns the converted model object and its config,
# and receives `llama_deepseek_dir` as the output directory.
new_model, new_config = convert_llama_to_deepseek(
    llama_model_path=llama_model_path,
    output_dir=llama_deepseek_dir,
)

In [None]:
def count_parameters(model: nn.Module) -> Tuple[int, int]:
    """Count the parameters of a PyTorch model.

    Args:
        model: A ``torch.nn.Module`` instance.

    Returns:
        A ``(total_params, trainable_params)`` tuple:

        * ``total_params`` is the number of scalar elements across all parameters.
        * ``trainable_params`` counts only parameters with ``requires_grad=True``.

        Parameters shared between submodules (e.g. tied weights) are counted
        once, because ``model.parameters()`` yields each parameter only once.
    """
    total_params = 0
    trainable_params = 0
    # Single pass over the parameters instead of iterating them twice.
    for param in model.parameters():
        n_elements = param.numel()
        total_params += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    return total_params, trainable_params

In [None]:
total_params, trainable_params = count_parameters(new_model)
# Report both counts in billions (1e9), the convention behind names like "8B".
# Dividing by 1024**3 understates the count by ~7%.
# Author's original note: 12B --> 9B trainable
total_params / 1e9, trainable_params / 1e9

In [None]:
# Display the converted config as the cell's output (rich display rather than
# print), consistent with the bare `model.config` cell in the Reload section.
new_config

# Reload the converted model and inspect its structure

In [None]:
from llama_deepseek_model_test import *

In [None]:
# Settings for reloading the converted checkpoint.
# NOTE: this rebinds `new_model` from the converted model object (Convert section)
# to the checkpoint path. Reuse `llama_deepseek_dir` so the path is defined once.
new_model = llama_deepseek_dir
# The settings below are not used by any cell in this notebook; presumably they
# mirror options of the `llama_deepseek_model_test` helpers — TODO confirm.
inspect_only = False
# Vietnamese test prompt ("Why is Uncle Ho loved?").
prompt = "Tại sao bác Hồ được yêu quý ?"
max_new_tokens = 512

In [None]:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader, DistributedSampler
import transformers
from transformers import (
    LlamaTokenizer, 
    LlamaForCausalLM, 
    Trainer, 
    TrainingArguments,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    AutoModel
)

import os

# NOTE(review): this overrides the "2" set in the first cell. CUDA reads this
# variable only when it first initializes, so if an earlier cell already used
# the GPU, this change likely has no effect until the kernel is restarted.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2,3"})

from tqdm.notebook import tqdm


In [None]:
try:
    # Load the converted model and tokenizer from `new_model` (a checkpoint path here).
    model, tokenizer = load_model(new_model, "auto")

    # Inspect model structure
    logger.info("\nInspecting model structure...")
    inspection_results = inspect_model_structure(model)

    logger.info("\nModel Structure Inspection Results:")
    for key, value in inspection_results.items():
        logger.info(f"{key}: {value}")

except Exception as e:
    import traceback

    logger.error(f"Error loading or testing model: {e}")
    logger.error(traceback.format_exc())
    # Re-raise instead of calling `sys.exit(1)`:
    # - `sys` is not imported explicitly in this notebook, so that call could
    #   itself raise NameError and hide the real error.
    # - SystemExit is the wrong way to stop a notebook.
    # Re-raising halts "Run All" with the original traceback.
    raise



In [None]:
# Show the structure-inspection results from the load cell (defined only if loading succeeded).
inspection_results

In [None]:
total_params, trainable_params = count_parameters(model)
# Report both counts in billions (1e9), the convention behind names like "8B".
# Dividing by 1024**3 understates the count by ~7%.
# Author's original note: 12B --> 9B trainable
total_params / 1e9, trainable_params / 1e9

In [None]:
# Show the reloaded model's configuration, e.g. to compare with `new_config` from the conversion step.
model.config

In [None]:
# Show the reloaded model's repr (for a torch nn.Module this is its module tree).
model