In [None]:
# Delete the locally cached copy of the databricks-dolly-15k dataset directory.
# NOTE(review): presumably done so the next load_dataset() call re-downloads a clean copy — confirm it is still needed.
!rm -rf ~/.cache/huggingface/datasets/databricks__databricks-dolly-15k

In [None]:
!pip install datasets --upgrade

Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency r

In [None]:
import math
from typing import Tuple, Dict, Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm

# ────────────────────────────────────────────────────────────────────────────────
# Basic configuration
# ────────────────────────────────────────────────────────────────────────────────
model_name = "gpt2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ────────────────────────────────────────────────────────────────────────────────
# Load tokenizer & models
# ────────────────────────────────────────────────────────────────────────────────

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token, so the EOS token doubles as the padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Teacher: frozen reference model, kept in eval mode (no gradients are computed for it).
teacher_model = AutoModelForCausalLM.from_pretrained(model_name)
teacher_model = teacher_model.to(device)
teacher_model.eval()

# Student: loaded from the same checkpoint as the teacher.
student_model = AutoModelForCausalLM.from_pretrained(model_name)
student_model = student_model.to(device)

# ────────────────────────────────────────────────────────────────────────────────
# Soft Prompt 모듈
# ────────────────────────────────────────────────────────────────────────────────

class SoftPrompt(nn.Module):
    """Learnable soft prompt of shape (length, hidden_size) that is prepended on the left.

    Attributes:
        length: number of prompt vectors.
        prompt: trainable parameter of shape (length, hidden_size).
    """

    def __init__(self, length: int, hidden_size: int):
        super().__init__()
        self.length = length
        # Initial values are drawn from a standard normal distribution.
        initial_values = torch.randn(length, hidden_size)
        self.prompt = nn.Parameter(initial_values)

    def forward(self, input_embeds: torch.Tensor) -> torch.Tensor:
        """Prepend the prompt to every sequence in the batch.

        Args:
            input_embeds: embeddings of shape (batch, seq, hidden_size).

        Returns:
            Tensor of shape (batch, length + seq, hidden_size).
        """
        num_sequences = input_embeds.shape[0]
        # Broadcast the single prompt across the batch dimension (a view, not a copy).
        batched_prompt = self.prompt[None, :, :].expand(num_sequences, *self.prompt.shape)
        return torch.cat((batched_prompt, input_embeds), dim=1)

# ────────────────────────────────────────────────────────────────────────────────
# Load & preprocess the dataset
# ────────────────────────────────────────────────────────────────────────────────

raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
# Keep only the first 2% of the rows.
subset_size = int(len(raw_dataset) * 0.02)
raw_dataset = raw_dataset.select(range(subset_size))

def preprocess(example: Dict[str, str]) -> Dict[str, str]:
    """Build the prompt/target pair for one example.

    Args:
        example: record with "instruction", "context" and "response" fields.

    Returns:
        Dict with "x" (instruction and context joined by a newline) and "y" (the response).
    """
    prompt_text = "{}\n{}".format(example["instruction"], example["context"])
    return {"x": prompt_text, "y": example["response"]}

# Run preprocess over every example of raw_dataset; it returns the "x" (prompt) and "y" (target) fields.
dataset = raw_dataset.map(preprocess)

# Collate Fn --------------------------------------------------------------------

def collate_fn(batch: Any) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    """Tokenize a list of examples into padded prompt tensors and target token ids.

    Args:
        batch: sequence of examples, each providing an "x" (prompt) and a "y" (target) string.

    Returns:
        (inputs, labels): `inputs` maps the tokenizer's output names to tensors for the
        prompts; `labels` holds the padded token ids of the targets. Both are moved to
        the module-level `device`.
    """
    prompts = [item["x"] for item in batch]
    targets = [item["y"] for item in batch]

    inputs = tokenizer(prompts,
                       return_tensors="pt",
                       padding=True,
                       truncation=True)
    # `text_target=` is the supported replacement for the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=targets,
                       return_tensors="pt",
                       padding=True,
                       truncation=True).input_ids

    inputs = {k: v.to(device) for k, v in inputs.items()}
    return inputs, labels.to(device)

# ────────────────────────────────────────────────────────────────────────────────
# Training parameters & optimizer
# ────────────────────────────────────────────────────────────────────────────────

soft_prompt_len = 7
hidden_size = student_model.config.hidden_size
soft_prompt = SoftPrompt(soft_prompt_len, hidden_size)
soft_prompt = soft_prompt.to(device)

# The student weights and the soft prompt are optimized together.
trainable_params = [*student_model.parameters(), *soft_prompt.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)

# Position budget of the student (1024 for the default GPT-2 config) ...
max_pos = student_model.config.n_positions
# ... minus the prompt slots = maximum length the actual text may have.
trim_len = max_pos - soft_prompt_len

# ────────────────────────────────────────────────────────────────────────────────
# Training loop: distill the teacher's next-token distribution into the
# soft-prompted student with a KL loss.
# ────────────────────────────────────────────────────────────────────────────────

loader = DataLoader(dataset,
                    batch_size=4,
                    shuffle=True,
                    collate_fn=collate_fn)

# NOTE(review): student_model is never switched to train mode in this notebook;
# call student_model.train() before this loop if dropout is wanted while fine-tuning.

for epoch in range(1):
    progress = tqdm(loader, desc=f"Epoch {epoch}")
    for inputs, _ in progress:
        # ── Trim the sequence so prompt + text fits the position budget ─────
        input_ids: torch.Tensor = inputs["input_ids"]
        attention_mask: torch.Tensor = inputs["attention_mask"]

        if input_ids.size(1) > trim_len:
            input_ids = input_ids[:, :trim_len]
            attention_mask = attention_mask[:, :trim_len]

        # The teacher sees exactly the same (trimmed) tokens as the student.
        trimmed_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        # ── Teacher forward (no grad) ────────────────────────────────────────
        with torch.no_grad():
            teacher_outputs = teacher_model(**trimmed_inputs)
            teacher_probs = F.softmax(teacher_outputs.logits, dim=-1)

        # ── Student forward ─────────────────────────────────────────────────
        input_embeds = student_model.transformer.wte(input_ids)
        input_embeds = soft_prompt(input_embeds)  # (B, prompt+seq, H)

        # Attention mask: the prompt slots are always attended to (all ones).
        prompt_mask = torch.ones(input_ids.size(0), soft_prompt_len,
                                 dtype=attention_mask.dtype,
                                 device=device)
        extended_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        # Position ids 0 .. L-1 over prompt + text.
        seq_len = input_embeds.size(1)
        position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(input_ids.size(0), -1)

        student_outputs = student_model(inputs_embeds=input_embeds,
                                        attention_mask=extended_attention_mask,
                                        position_ids=position_ids)
        student_log_probs = F.log_softmax(student_outputs.logits, dim=-1)

        # ── KL loss over real tokens only ───────────────────────────────────
        # Drop the prompt positions so student and teacher align token by token.
        token_kl = F.kl_div(student_log_probs[:, soft_prompt_len:, :],
                            teacher_probs,
                            reduction="none").sum(dim=-1)          # (B, seq)
        # Padded positions carry no signal: mask them out instead of letting
        # them contribute to the loss (reduction="batchmean" summed them in).
        token_mask = attention_mask.to(token_kl.dtype)
        # Same normalisation as "batchmean": total KL divided by the batch size.
        loss = (token_kl * token_mask).sum() / input_ids.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        progress.set_postfix({"loss": loss.item()})

    # Reports the loss of the last batch of the epoch.
    print(f"[Epoch {epoch}] Loss: {loss.item():.4f}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/8.20k [00:00<?, ?B/s]

databricks-dolly-15k.jsonl:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/15011 [00:00<?, ? examples/s]

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Epoch 0: 100%|██████████| 75/75 [00:25<00:00,  2.95it/s, loss=14.8]

[Epoch 0] Loss: 14.8452





## Loss Weighting

In [None]:
#entropy = -(teacher_probs * teacher_probs.log()).sum(dim=-1)
#weight = 1 - entropy / math.log(vocab_size)
#kl = F.kl_div(...).sum(dim=-1)
#weighted_kl = (weight * kl).mean()
def weighted_loss(kl_loss: torch.Tensor, ce_loss: Optional[torch.Tensor] = None, *, alpha: float = 1.0, beta: float = 0.0) -> torch.Tensor:
    """Weighted sum of the KL and cross-entropy losses.

    Args:
        kl_loss: scalar KL loss.
        ce_loss: scalar cross-entropy loss, or None to use the KL term alone.
        alpha: weight applied to the KL term.
        beta: weight applied to the cross-entropy term (ignored when ce_loss is None).

    Returns:
        alpha * kl_loss, plus beta * ce_loss when a cross-entropy loss is given.
    """
    kl_term = kl_loss * alpha
    if ce_loss is not None:
        return kl_term + beta * ce_loss
    return kl_term

## Soft Blending

In [None]:
#blend = alpha * teacher_probs + (1 - alpha) * student_probs
#F.kl_div(student_log_probs, blend, reduction='batchmean')
def soft_blend(teacher_probs: torch.Tensor, student_log_probs: torch.Tensor, gamma: float = 0.5) -> torch.Tensor:
    """Mix teacher and student predictions into a new target distribution.

    Args:
        teacher_probs: (B, L, V) softmax probabilities of the teacher.
        student_log_probs: (B, L, V) log-softmax output of the student.
        gamma: share given to the teacher; the student receives 1 - gamma.

    Returns:
        blended_probs: (B, L, V), re-normalized over the last dimension. No gradient
        flows back into the student through this target.
    """
    # Detach so the student's own prediction acts as a constant inside the target.
    student_share = (1.0 - gamma) * student_log_probs.detach().exp()
    mixture = gamma * teacher_probs + student_share
    normalizer = mixture.sum(dim=-1, keepdim=True)
    return mixture / normalizer

## Teacher Assistant Filtering

In [None]:
def teacher_assistant_filter(probs: torch.Tensor, *, top_k: Optional[int] = None, top_p: Optional[float] = None) -> torch.Tensor:
    """Apply top-k and/or top-p (nucleus) filtering, then re-normalize the probabilities.

    Args:
        probs: probabilities with the vocabulary on the last dimension, e.g. (B, L, V).
        top_k: keep only the k most probable tokens (clamped to the range [1, V]).
        top_p: keep the smallest set of most probable tokens whose cumulative
            probability reaches top_p; the most probable token is always kept.
            When both filters are given, top-k is applied first.

    Returns:
        filtered_probs: same shape as `probs`, zero outside the kept tokens and
        summing to 1 over the last dimension. Computed under no_grad (the input
        is assumed to be a teacher output).
    """
    with torch.no_grad():
        if top_k is not None:
            # Clamp k so torch.topk never asks for more entries than the vocabulary has.
            k = min(max(top_k, 1), probs.size(-1))
            _, top_idx = torch.topk(probs, k, dim=-1)
            keep = torch.zeros_like(probs).scatter_(-1, top_idx, 1.0)
            probs = probs * keep
        if top_p is not None:
            sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
            cumulative = sorted_probs.cumsum(dim=-1)
            # Mass accumulated *before* each token: this keeps the first token that
            # pushes the cumulative probability past top_p.
            keep_sorted = (cumulative - sorted_probs) < top_p
            # Never drop every token (e.g. top_p <= 0): always keep the most probable one.
            keep_sorted[..., 0] = True
            # scatter_ requires src to share the destination's dtype, so cast the bool mask.
            keep = torch.zeros_like(probs).scatter_(-1, sorted_idx, keep_sorted.to(probs.dtype))
            probs = probs * keep
        # Re-normalize; the clamp guards against division by zero.
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-9)
    return probs

In [None]:
# Training parameters

soft_len = 7
hidden = student_model.config.hidden_size
soft_prompt = SoftPrompt(length=soft_len, hidden_size=hidden)
soft_prompt = soft_prompt.to(device)

# The student weights and the (fresh) soft prompt are optimized together.
trainable_params = [*student_model.parameters(), *soft_prompt.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)

max_pos = student_model.config.n_positions  # 1024
trim_len = max_pos - soft_len

# Distillation knobs
top_k_assist = 50   # teacher filtering: number of top tokens kept
blend_gamma = 0.7   # soft blend ratio (teacher dominant)
alpha_kl = 1.0      # weight of the KL term
beta_ce = 0.0       # weight of the CE term; if > 0, CE(y) is used as well

# Training loop (the plan is to run three loops like this one, one per method)
# NOTE(review): student_model continues from whatever state earlier cells left it in,
# and it is never switched to train mode here — reload it / call student_model.train()
# if the methods are meant to be compared from the same starting point.

loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

for epoch in range(1):
    pbar = tqdm(loader, desc=f"Epoch {epoch}")
    for inputs, labels in pbar:
        # ── Trim so prompt + text fits the position budget ──────────────────
        input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
        if input_ids.size(1) > trim_len:
            input_ids = input_ids[:, :trim_len]
            attention_mask = attention_mask[:, :trim_len]
        trimmed_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        # ── Teacher (no grad), filtered to its top-k tokens ─────────────────
        with torch.no_grad():
            t_logits = teacher_model(**trimmed_inputs).logits
            t_probs = F.softmax(t_logits, dim=-1)
            t_probs = teacher_assistant_filter(t_probs, top_k=top_k_assist)

        # ── Student forward ────────────────────────────────────────────────
        embeds = student_model.transformer.wte(input_ids)
        embeds = soft_prompt(embeds)

        # The prompt slots are always attended to (all ones).
        prompt_mask = torch.ones(input_ids.size(0), soft_len, dtype=attention_mask.dtype, device=device)
        ext_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        seq_len = embeds.size(1)
        pos_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(input_ids.size(0), -1)

        s_logits = student_model(inputs_embeds=embeds, attention_mask=ext_mask, position_ids=pos_ids).logits
        s_log_probs = F.log_softmax(s_logits, dim=-1)
        # Drop the prompt positions so student and teacher align token by token.
        s_text_log_probs = s_log_probs[:, soft_len:, :]

        # ── Soft blending ─────────────────────────────────────────────────
        target_probs = soft_blend(t_probs, s_text_log_probs, gamma=blend_gamma)

        # ── Loss ───────────────────────────────────────────────────────────
        token_kl = F.kl_div(s_text_log_probs, target_probs, reduction="none").sum(dim=-1)  # (B, seq)
        # Padded positions carry no signal: mask them out instead of letting
        # them contribute to the loss (reduction="batchmean" summed them in).
        token_mask = attention_mask.to(token_kl.dtype)
        # Same normalisation as "batchmean": total KL divided by the batch size.
        kl_loss = (token_kl * token_mask).sum() / input_ids.size(0)

        # Cross-entropy (teacher vs. ground truth) uses `labels` when enabled.
        # NOTE(review): this branch is dormant while beta_ce == 0. As written it compares
        # teacher logits over the *input* positions with the separately tokenized response,
        # so the flattened shapes generally differ, and t_logits carries no gradient —
        # the target alignment needs to be designed before enabling beta_ce > 0.
        ce_loss = None
        if beta_ce > 0:
            ce_loss = F.cross_entropy(t_logits.view(-1, t_logits.size(-1)), labels.view(-1), ignore_index=tokenizer.pad_token_id)

        loss = weighted_loss(kl_loss, ce_loss, alpha=alpha_kl, beta=beta_ce)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    # Reports the loss of the last batch of the epoch.
    print(f"[Epoch {epoch}] Loss: {loss.item():.4f}")


Epoch 0: 100%|██████████| 75/75 [00:25<00:00,  2.89it/s, loss=53.5835]

[Epoch 0] Loss: 53.5835



