In [1]:
# file: quick_moe_train.py
import torch
from transformers import AutoTokenizer
from changeMoE import ChangeMoE

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DTYPE = torch.float16
MODEL_ID = "meta-llama/Llama-3.2-1B"

changer = ChangeMoE(
    model_id = MODEL_ID,
    num_experts = 4,
    top_k = 2,
    dtype = DTYPE,
    device = DEVICE,
)
model= changer.get_model()
tokenizer = changer.get_tokenizer()
tokenizer.pad_token= tokenizer.eos_token
tokenizer.padding_side = "left"

  from .autonotebook import tqdm as notebook_tqdm


48
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_em

In [2]:
texts = ["저녁 메뉴 추천좀.", "참치김밥 말고."]
batch = tokenizer(texts, return_tensors="pt", padding=True).to(DEVICE)
targets = batch["input_ids"] # GT 입력 그대로

In [3]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
texts = ["초밥이 좋아요? 국수가 좋아요?","저녁 메뉴 추천좀요."]
    
inputs = tokenizer(texts, return_tensors="pt",
                    padding=True, truncation=True).to(DEVICE)
model.eval()
with torch.no_grad():
    gen_ids = model.generate(
        input_ids       = inputs["input_ids"],
        attention_mask  = inputs["attention_mask"],
        do_sample       = False # Greedy
    )

decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
for i, output in enumerate(decoded):
    print(f"Model Input : {texts[i]} , Model output : {output}")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Model Input : 초밥이 좋아요? 국수가 좋아요? , Model output : 초밥이 좋아요? 국수가 좋아요? 저는 초밥이 좋고 국수가 좋아요. 저는 초밥이 좋
Model Input : 저녁 메뉴 추천좀요. , Model output : 저녁 메뉴 추천좀요. 2018-12-12 16:00:00
저녁 메뉴 추천좀


In [4]:
model.train()
optim = torch.optim.AdamW(model.parameters(), lr=1e-5)
out = model(**batch, labels=targets)
loss = out.loss
loss.backward()
optim.step()
optim.zero_grad()

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacity of 23.54 GiB of which 30.44 MiB is free. Including non-PyTorch memory, this process has 22.92 GiB memory in use. Of the allocated memory 22.54 GiB is allocated by PyTorch, and 63.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# 가중치 진짜 변했나?
with torch.no_grad():
    w_after = model.model.layers[0].mlp.experts[0].gate_proj.weight
    changed = not torch.allclose(w_before, w_after)
    print(f"First-layer MoE gate_proj.weight changed? {changed}")
    print(f"Difference: {(w_before - w_after).abs().mean().item():.6f}")

In [None]:
# HF 저장 그래도 가능한지?
model.save_pretrained("moe_hf_checkpoint")
tokenizer.save_pretrained("moe_hf_checkpoint")

fresh = ChangeMoE(
    model_id    = MODEL_ID,
    num_experts = 4,
    top_k       = 2,
    dtype       = DTYPE,
    device      = DEVICE,
)
fresh_model = fresh.get_model()
fresh_model.load_state_dict(torch.load("moe_finetuned.pth"), strict=False)