In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.7, **kwargs):
        """
        Args:
            teacher_model: 학습된 원본 모델 (Teacher)
            temperature (float): 로짓을 부드럽게 만드는 온도 (T). 높을수록 정답 외의 확률 정보가 많이 전달됨.
            alpha (float): KD Loss와 일반 Task Loss 사이의 가중치. 1.0이면 KD Loss만 사용.
        """
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        
        # Teacher 모델은 학습하지 않으므로 VRAM 절약을 위해 Gradient 계산을 끈다.
        # 또한, Dropout 등이 작동하지 않도록 eval 모드로 설정.
        self.teacher_model.eval()
        self.teacher_model.requires_grad_(False)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Student 모델과 Teacher 모델의 출력을 비교하여 Loss를 계산합니다.
        """
        
        # 1. Student 모델의 Forward Pass
        student_outputs = model(**inputs)
        
        # HuggingFace Trainer 내부 로직 호환성 (Loss 유무 확인)
        if student_outputs.loss is not None:
            student_loss = student_outputs.loss
        else:
            # Labels가 있으면 CrossEntropy 수동 계산
            logits = student_outputs.logits
            labels = inputs.get("labels")
            loss_fct = nn.CrossEntropyLoss()
            student_loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))

        # 2. Teacher 모델의 Forward Pass (Gradient 계산 안 함)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        
        # 3. Logits 추출
        student_logits = student_outputs.logits
        teacher_logits = teacher_outputs.logits

        # 4. KL-Divergence Loss (KD Loss) 계산
        # 핵심: Temperature(T)로 나누어 분포를 부드럽게 만듦 (Soft Target)
        # reduction='batchmean'이 수학적으로 올바른 KL Div 계산법.
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1),
            reduction="batchmean",
            log_target=False
        ) * (self.temperature ** 2) # Gradients 크기 보정을 위해 T^2를 곱함

        # 5. 최종 Loss 결합
        # Alpha가 클수록 Teacher를 더 많이 닮으려고 노력함
        total_loss = (self.alpha * kd_loss) + ((1 - self.alpha) * student_loss)

        return (total_loss, student_outputs) if return_outputs else total_loss

In [10]:
# Need latest bitsandbytes package for quantized teacher model load
!pip install -U -q bitsandbytes

[0m

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, DataCollatorForLanguageModeling, BitsAndBytesConfig
from datasets import load_dataset


# 1. Configuration
TEACHER_ID = "meta-llama/Llama-3.2-3B-Instruct"  # 원본 모델
STUDENT_PATH = "/mnt/workspace/quantization/llm_compressor/Llama-3.2-3B-Instruct2of4-sparse"           # Pruning 완료된 로컬 모델 경로
DATASET_ID = "HuggingFaceH4/ultrachat_200k"                     # 데이터셋
TRAIN_SPLIT = "train_sft"
TEST_SPLIT = "test_sft"
MAX_SEQUENCE_LENGTH = 2048

# KD 하이퍼파라미터
TEMPERATURE = 2.0 
ALPHA = 0.7


# 2. 모델 로드
print("Loading Teacher Model...")
# Teacher는 메모리를 아끼기 위해 4-bit 또는 8-bit로 로드하는 것을 강력 추천.
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_ID,
    device_map="auto",
    quantization_config = BitsAndBytesConfig(load_in_4bit=True),
    dtype=torch.bfloat16
)

print("Loading Student Model...")
# Student는 실제로 Weight Update 해야 하므로 bf16 원본으로 로드
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_PATH,
    device_map="auto",
    dtype=torch.bfloat16
)

# 토크나이저는 공유
tokenizer = AutoTokenizer.from_pretrained(TEACHER_ID)
tokenizer.pad_token = tokenizer.eos_token   # Llama 계열 필수


# 3. 훈련 준비
# 데이터셋 로드 (예시: HuggingFace Dataset)
def preprocess(example):
    """Preprocess dataset examples."""
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}

def tokenize(sample):
    """Tokenize dataset examples."""
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )

train_ds = load_dataset(
    DATASET_ID, 
    split=f"{TRAIN_SPLIT}[:8192]"
).shuffle(seed=47)
train_ds = train_ds.map(preprocess)
train_ds = train_ds.map(tokenize, remove_columns=train_ds.column_names)

test_ds = load_dataset(
    DATASET_ID, 
    split=f"{TRAIN_SPLIT}[:512]"
).shuffle(seed=47)
test_ds = test_ds.map(preprocess)
test_ds = test_ds.map(tokenize, remove_columns=test_ds.column_names)

# Training setting
training_args = TrainingArguments(
    output_dir="./distilled",
    num_train_epochs=3,
    per_device_train_batch_size=4, # VRAM에 맞춰 조절
    gradient_accumulation_steps=4,
    learning_rate=2e-5,            # Pruning 후 복구는 작은 LR 추천
    bf16=True,                     # 최신 GPU라면 bf16 필수
    logging_steps=16,
    save_strategy="epoch",
    remove_unused_columns=False,   # KD 할 때는 Input이 Teacher/Student 둘 다 들어가야 하므로 False 추천
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# 4. Trainer 실행
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    temperature=TEMPERATURE,
    alpha=ALPHA,
    model=student_model,
    args=training_args,
    train_dataset=train_ds, # 준비된 데이터셋
    eval_dataset=test_ds,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

print("Starting Distillation...")
trainer.train()

print("Saving Distilled Student Model...")
trainer.save_model("./final_distilled_model")

Loading Teacher Model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`run_compressed` is only supported for quantized_compressed models and not for sparsified models. Setting `run_compressed=False`


Loading Student Model...


Compressing model: 196it [00:00, 2314.55it/s]
Decompressing model: 196it [00:07, 27.34it/s]
  super().__init__(*args, **kwargs)
The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


Starting Distillation...


OutOfMemoryError: CUDA out of memory. Tried to allocate 128.00 MiB. GPU 0 has a total capacity of 23.57 GiB of which 50.94 MiB is free. Process 1780711 has 23.48 GiB memory in use. Of the allocated memory 22.80 GiB is allocated by PyTorch, and 391.72 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)