In [1]:
import os
import gc
from dataclasses import dataclass
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoConfig, AutoTokenizer, LlamaModel
from peft import LoraConfig, get_peft_model, TaskType


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig

model_id = "meta-llama/Llama-3.1-8B" 

hf_config = AutoConfig.from_pretrained(model_id)

# hidden state 뽑기 좋은 AutoModel 사용
backbone = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",)

        # 1) backbone freeze
for p in backbone.parameters():
    p.requires_grad = False

tokenizer = AutoTokenizer.from_pretrained(model_id)

backbone.eval()

tokenizer.pad_token = tokenizer.eos_token
# padding='max_length', max_length = 4800

print(backbone)

Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.36s/it]
Some parameters are on the meta device because they were offloaded to the cpu.


LlamaModel(
  (embed_tokens): Embedding(128256, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
    )
  )
  (norm): LlamaRMSNorm((4096,), eps=1e-05)
  (rotary_emb): LlamaRotaryEmbedding()
)


In [5]:
@dataclass
class DownstreamConfig:
    model_id: str = "meta-llama/Meta-Llama-3.1-8B"  # base로 쓰는 걸 추천(인스트럭트도 가능)
    num_labels: int = 2
    task: str = "classification"  # "classification" or "regression"
    pooling: str = "last"         # "last" or "mean"
    dora_r: int = 8
    dora_alpha: int = 16
    dora_dropout: float = 0.0
    target_modules: tuple = ("q_proj", "v_proj")  # attention에만
    torch_dtype: torch.dtype = torch.bfloat16

args = DownstreamConfig()

In [None]:
TaskType.

<enum 'TaskType'>

In [None]:
# PEFT 설정
peft_args = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=args.dora_r,
    lora_alpha=args.dora_alpha,
    lora_dropout=args.dora_dropout,
    target_modules=list(args.target_modules),
    bias="none",
    use_dora=True,)

In [7]:
peft_backbone = get_peft_model(backbone, peft_args)

In [8]:
peft_backbone

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
              (lora_dropout): ModuleDict(
                (default): Identity()
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=4096, out_features=8, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=8, out_features=4096, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict(
                (default): lora.dora.DoraLinearLayer()
              )
            )
            (k_proj): Linear(in_features=4096, ou

In [9]:
class FeatureTransform(nn.Module):
    """
    DoRA/LoRA의 lora_A 출력 (low-rank vector)에 적용할 변환.
    예: L2 normalization
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (B, L, r)
        norm = z.norm(p=2, dim=-1, keepdim=True).clamp_min(self.eps)
        return z / norm


class AWithTransform(nn.Module):
    """
    기존 lora_A(Linear)를 감싸서 A(x) -> transform(A(x)) 반환
    """
    def __init__(self, a_linear: nn.Module, transform: nn.Module):
        super().__init__()
        self.a_linear = a_linear
        self.transform = transform

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.a_linear(x)
        z = self.transform(z)
        return z

def wrap_all_lora_A_modules_inplace(peft_model: nn.Module, transform: nn.Module) -> int:
    """
    PEFT가 삽입한 LoRA/DoRA 레이어들 중 lora_A 모듈을 찾아 wrapper로 교체.
    (PEFT 내부 구현이 버전별로 조금 달라서 '탐색 기반'으로 최대한 견고하게 작성)

    Returns:
        교체한 lora_A 개수
    """
    replaced = 0

    for module in peft_model.modules():
        # PEFT LoRA layer들은 보통 lora_A, lora_B 같은 attribute를 가짐(버전/타겟 레이어 타입별로 약간 다름)
        if hasattr(module, "lora_A"):
            lora_A = getattr(module, "lora_A")

            # lora_A가 adapter_name -> nn.Module 형태로 담긴 dict/ModuleDict인 경우가 흔함
            if isinstance(lora_A, (nn.ModuleDict, dict)):
                for adapter_name, a_mod in list(lora_A.items()):
                    # 이미 wrapper면 skip
                    if isinstance(a_mod, AWithTransform):
                        continue
                    lora_A[adapter_name] = AWithTransform(a_mod, transform)
                    replaced += 1

            # 일부 구현에서는 단일 모듈일 수도 있으니 방어적으로 처리
            elif isinstance(lora_A, nn.Module) and not isinstance(lora_A, AWithTransform):
                setattr(module, "lora_A", AWithTransform(lora_A, transform))
                replaced += 1

    return replaced

In [16]:
for module in peft_backbone.modules():
    print(module)
    print('-------')

PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
              (lora_dropout): ModuleDict(
                (default): Identity()
              )
              (lora_A): ModuleDict(
                (default): AWithTransform(
                  (a_linear): Linear(in_features=4096, out_features=8, bias=False)
                  (transform): FeatureTransform()
                )
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=8, out_features=4096, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict(
                (def

In [None]:
class LoraAAdaIN(nn.Module):
    """
    lora_A 출력 z를 AdaIN처럼 변환:
      z_hat = (z - mu_c) / sigma_c * sigma_t + mu_t
    - mu_c, sigma_c: 현재 배치(또는 sample)에서 계산
    - mu_t, sigma_t: Dataset A에서 수집한 target stats
    """
    def __init__(
        self,
        a_linear: nn.Module,
        target_payload: Dict[str, Any],
        key: str,
        style_mode: Literal["aggregate", "distribution"] = "aggregate",
        selection: Literal["mean_of_dist", "random", "cycle"] = "mean_of_dist",
        seed: int = 0,
        eps: float = 1e-6,
        instance_wise: bool = True,  # AdaIN 느낌(샘플별)로 할지, 배치 전체로 할지
    ):
        super().__init__()
        self.a_linear = a_linear
        self.key = key
        self.style_mode = style_mode
        self.selection = selection
        self.eps = eps
        self.instance_wise = instance_wise

        self._rng = random.Random(seed)
        self._cycle_idx = 0

        # payload 구조:
        # target_payload = {"mode":..., "r":..., "data":...}
        self._data = target_payload["data"]
        self._payload_mode = target_payload["mode"]

        if self.style_mode != self._payload_mode:
            raise ValueError(
                f"style_mode({self.style_mode})와 payload mode({self._payload_mode})가 다릅니다."
            )

    @torch.no_grad()
    def _get_target_stats(self, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.style_mode == "aggregate":
            entry: LayerStatsAggregate = self._data[self.key]
            mu_t = entry.mean.to(device)
            var_t = entry.var.to(device)
            return mu_t, var_t

        # distribution
        entry: LayerStatsDistribution = self._data[self.key]
        means, vars_ = entry.means, entry.vars

        if len(means) == 0:
            raise RuntimeError(f"[{self.key}] distribution stats가 비어 있습니다.")

        if self.selection == "mean_of_dist":
            mu_t = torch.stack([m.to(device) for m in means], dim=0).mean(dim=0)
            var_t = torch.stack([v.to(device) for v in vars_], dim=0).mean(dim=0)
            return mu_t, var_t

        if self.selection == "random":
            i = self._rng.randrange(len(means))
            return means[i].to(device), vars_[i].to(device)

        # cycle
        i = self._cycle_idx % len(means)
        self._cycle_idx += 1
        return means[i].to(device), vars_[i].to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.a_linear(x)  # (B,L,r) 또는 (B,r)
        if z.dim() == 2:
            z = z.unsqueeze(1)  # (B,1,r)

        device = z.device
        mu_t, var_t = self._get_target_stats(device)
        sigma_t = torch.sqrt(var_t + self.eps)  # (r,)

        # content stats (현재 입력에서)
        if self.instance_wise:
            # 샘플별로 L축 평균/분산: (B,1,r)
            mu_c = z.mean(dim=1, keepdim=True)
            var_c = z.var(dim=1, keepdim=True, unbiased=False).clamp_min(1e-12)
        else:
            # 배치 전체(B*L) 기준 채널별 stats: (1,1,r)
            z2 = z.reshape(-1, z.size(-1))
            mu_c = z2.mean(dim=0).view(1, 1, -1)
            var_c = z2.var(dim=0, unbiased=False).view(1, 1, -1).clamp_min(1e-12)

        sigma_c = torch.sqrt(var_c + self.eps)

        # broadcast: (r,) -> (1,1,r)
        mu_t = mu_t.view(1, 1, -1)
        sigma_t = sigma_t.view(1, 1, -1)

        z_hat = (z - mu_c) / sigma_c * sigma_t + mu_t

        # 원래가 (B,r)였으면 다시 squeeze
        if z_hat.size(1) == 1 and x.dim() == 2:
            z_hat = z_hat.squeeze(1)
        return z_hat

In [13]:
# 3) A 출력(low-rank vector) 변형 삽입
n = wrap_all_lora_A_modules_inplace(peft_backbone, FeatureTransform())
print(f"[INFO] Wrapped lora_A modules: {n}")
print(peft_backbone)

[INFO] Wrapped lora_A modules: 0
PeftModelForFeatureExtraction(
  (base_model): LoraModel(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
              (lora_dropout): ModuleDict(
                (default): Identity()
              )
              (lora_A): ModuleDict(
                (default): AWithTransform(
                  (a_linear): Linear(in_features=4096, out_features=8, bias=False)
                  (transform): FeatureTransform()
                )
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=8, out_features=4096, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector):