In [None]:
# Î™®Îç∏.py

In [1]:
import random
import math
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd

from datasets import Dataset
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from transformers import AutoTokenizer, AutoModel
from transformers import get_linear_schedule_with_warmup
from transformers import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model, TaskType #, AdapterConfig

In [2]:
def set_seed(seed: int = 42):
    random.seed(seed)            # Í∏∞Î≥∏ Python random Í≥†Ï†ï
    np.random.seed(seed)         # NumPy ÎûúÎç§ Í≥†Ï†ï
    torch.manual_seed(seed)      # CPU Ïó∞ÏÇ∞ ÎûúÎç§ Í≥†Ï†ï
    torch.cuda.manual_seed(seed) # GPU Î™®Îì† ÎîîÎ∞îÏù¥Ïä§ ÎûúÎç§ Í≥†Ï†ï
    torch.cuda.manual_seed_all(seed)  # Î©ÄÌã∞ GPUÏùº Îïå

    # Ïó∞ÏÇ∞ Ïû¨ÌòÑÏÑ±
    torch.backends.cudnn.deterministic = True  # cuDNN Ïó∞ÏÇ∞ÏùÑ determinisitcÏúºÎ°ú Í∞ïÏ†ú
    torch.backends.cudnn.benchmark = False     # CUDA ÏÑ±Îä• ÏûêÎèô ÌäúÎãù Í∏∞Îä• ÎÅî ‚Üí ÏôÑÏ†Ñ Ïû¨ÌòÑ Í∞ÄÎä•

set_seed(42)

In [3]:
EPOCHS = 20
WARMUP_RATIO = 0.1
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
TEMPERATURE = 0.05
NEG_RATIO = 0.2

In [6]:
class E5LoRABackbone(nn.Module):
    def __init__(self, model_name: str, lora_cfg: dict):
        super().__init__()
        
        base_model = AutoModel.from_pretrained(model_name)

        # Linear(d ‚Üí d) -‚Üí Linear(d ‚Üí d) + LoRA(d ‚Üí d)
        lora_config = LoraConfig( 
            task_type=TaskType.FEATURE_EXTRACTION, # ÏûÑÎ≤†Îî© fine-tuning
            # LoRAÍ∞Ä Î∂ÑÎ•òÍ∏∞ÏôÄ Í∞ôÏùÄ output headÏóê Ï†ÅÏö©ÎêòÎäî Í≤ÉÏù¥ ÏïÑÎãàÎùº
            # Î™®Îç∏Ïùò Transformer Î∏îÎ°ù(encoder)ÏóêÎßå Ï†ÅÏö©ÎêòÎèÑÎ°ù
            r=lora_cfg["r"],    # LoRA rank
            lora_alpha=lora_cfg["alpha"],
            lora_dropout=lora_cfg["dropout"],
            bias="none"
        )

        self.encoder = get_peft_model(base_model, lora_config)
        self.config = self.encoder.config # hidden_size Í∞ôÏùÄÍ±∞ headÏóêÏÑú ÏïåÏïÑÏïºÌï®
    
    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        # return outputs.last_hidden_state[:, 0] single vector ÌÖåÏä§Ìä∏Ìï†Îïê ÏùºÏºÄ ÌñàÎäîÎîî..
        return outputs.last_hidden_state # headÏóê ÎÑ£ÏùÑÍ±∞Îùº CLS pooling ÏïàÌï®

In [7]:
class SimpleMultiVectorHead(nn.Module):
    def __init__(self, num_vectors=3,  input_dim=384):
        super().__init__()

        # kÍ∞úÏùò ÌïôÏäµ Í∞ÄÎä•Ìïú ÏøºÎ¶¨ ÌÜ†ÌÅ∞ ÏÉùÏÑ±
        self.query_tokens = nn.Parameter( # ÏùºÎã® ÌÖêÏÑúÎûë Îã§Î•¥Í≤å ÌïôÏäµ Í∞ÄÎä•Ìï®
            torch.randn(1, num_vectors, input_dim) # (1, K, D) Ï∞®ÏõêÏùò ÎûúÎç§Í∞í
        )
        # nn.init.normal_(self.query_tokens, std=0.02)
        # nn.init.orthogonal_(self.query_tokens)
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=8, batch_first=True)
        self.norm = nn.LayerNorm(input_dim)
    
    def forward(self, seq_out, attn_mask):
        """
        # shapes
        query = queries  # (B, K, D)
        key = seq_out    # (B, L, D)
        value = seq_out  # (B, L, D)
        """
        batch_size = seq_out.shape[0] # (B)
        queries = self.query_tokens.repeat(batch_size, 1, 1)  # Query ÌôïÏû• (1, K, D) -> (B, K, D)
        key_padding_mask = ~attn_mask.bool()

        vectors, _ = self.attention( # (B, K, D)
            query=queries,
            key=seq_out,
            value=seq_out,
            key_padding_mask=key_padding_mask
        )

        # 2. üî• [Pro Tip] ÏûîÏ∞® Ïó∞Í≤∞ (Residual) + LayerNorm
        # "ÏÉàÎ°ú Î∞∞Ïö¥ Ï†ïÎ≥¥(vectors)Ïóê ÏõêÎûò ÎÇ¥ ÏûêÏïÑ(queries)Î•º ÏÑûÎäîÎã§"
        # Ïù¥Î†áÍ≤å ÌïòÎ©¥ ÌïôÏäµÏù¥ Ìõ®Ïî¨ ÏïàÏ†ïÏ†ÅÏúºÎ°ú Î≥ÄÌï©ÎãàÎã§.
        # vectors = self.norm(queries + vectors)
        return vectors

In [8]:
class BookEmbeddingModel(nn.Module):
    def __init__(self, model_name: str, lora_config: dict):
        super().__init__()
        self.backbone = E5LoRABackbone(model_name, lora_config)
        self.head = SimpleMultiVectorHead(num_vectors=2, input_dim=self.backbone.config.hidden_size)
    
    def forward(self, input_ids, attention_mask, **kargs):
        sequence_output = self.backbone(input_ids, attention_mask) # (B, L, D)
        embeddings = self.head(sequence_output, attention_mask) # (B, k, D)
        return F.normalize(embeddings, p=2, dim=2) # contrastive loss Í≥ÑÏÇ∞ÌïòÎ†§Î©¥ ÌïÑÏàò

In [10]:
model_name = "intfloat/e5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

teacher_model = AutoModel.from_pretrained(model_name)
teacher_model.eval()
for param in teacher_model.parameters():
    param.requires_grad = False 

lora = {'r': 16, 'alpha': 32, 'dropout': 0.1}
model = BookEmbeddingModel(model_name, lora)
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_model.to(device)
model.to(device)

BookEmbeddingModel(
  (backbone): E5LoRABackbone(
    (encoder): PeftModelForFeatureExtraction(
      (base_model): LoraModel(
        (model): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(30522, 384, padding_idx=0)
            (position_embeddings): Embedding(512, 384)
            (token_type_embeddings): Embedding(2, 384)
            (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSdpaSelfAttention(
                    (query): lora.Linear(
                      (base_layer): Linear(in_features=384, out_features=384, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inplace=False)
                      )
               

In [11]:
# book_path = './data/e5_book_meta.parquet'
book_path = './test/book_meta.parquet'
books = pd.read_parquet(book_path)

In [12]:
def build_text(row): # ÏûÖÎ†• ÌÖçÏä§Ìä∏ ÏÉùÏÑ± (ÌÉÄÏù¥ÌãÄ + ÏÑ§Î™Ö + Ï†ÄÏûê Îì± Í≤∞Ìï©)
    parts = [
        f"Title: {row['title']} |",
        # f"Category: {row['category']} |", # oracle
        f"Description: {row['description']}"
    ]
    return " ".join( # Î¶¨Ïä§Ìä∏Ïùò Î¨∏ÏûêÏó¥Îì§ÏùÑ Í≥µÎ∞±ÏúºÎ°ú Ïó∞Í≤∞Ìï†Í±¥Îç∞.....
        [p for p in parts if isinstance(p, str)] # NaNÏù¥ÎÇò NoneÏù¥ ÏûàÏúºÎ©¥ Ï†úÏô∏Ìï®
    ) # ÏµúÏ¢ÖÏ†ÅÏúºÎ°ú ÌïòÎÇòÏùò Î¨∏Ïû• ÌòïÌÉúÎ°ú Î∞òÌôòÌïúÎã§Í≥† Ìï®!! "Title: ... Category: ... Description: ..."

books["text"] = books.apply(build_text, axis=1) # ÏÉà Ïª¨Îüº textÏóê ÎåÄÌï¥ÏÑú.... Î¨∏Ïû• ÎßåÎì¶

# 100Í∞ú ÎØ∏ÎßåÏù∏ Ïπ¥ÌÖåÍ≥†Î¶¨Îäî ÎÖ∏Ïù¥Ï¶àÎ°ú Í∞ÑÏ£ºÌïòÍ≥† ÏÇ≠Ï†ú
counts = books['category'].value_counts()
valid_categories = counts[counts > 100].index
books = books[books['category'].isin(valid_categories)]

In [13]:
dataset = Dataset.from_pandas(books)

le = LabelEncoder()
le.fit(dataset['category'])   # Ï†ÑÏ≤¥ Îç∞Ïù¥ÌÑ∞Î°ú ÌïôÏäµ

def encode_label(x):
    return {"label": le.transform([x["category"]])[0]}

dataset = dataset.map(encode_label)

num_classes = len(le.classes_)

Map:   0%|          | 0/81845 [00:00<?, ? examples/s]

collate_fnÏùÄ raw textÏôÄ labelÏùÑ ÌÖêÏÑúÎ°ú Î¨∂Ïñ¥ Î™®Îç∏Ïù¥ ÌïôÏäµÌï† Ïàò ÏûàÎäî ÌòïÌÉúÎ°ú ÎßåÎì§Ïñ¥Ï§å
DataLoaderÎäî Ïù¥ Ìï®ÏàòÎ°ú ÎØ∏Î¶¨ Ï†ÑÏ≤òÎ¶¨Ìïú batchÎ•º Î™®Îç∏Ïóê Í≥µÍ∏âÌïòÎäî Ïó≠Ìï†ÏùÑ Ìï®
```
Dataset row(dict)
     ‚Üì (DataLoader)
batch = [row1, row2, ...] (list)
     ‚Üì (collate_fn)
ÌÖçÏä§Ìä∏ Î¶¨Ïä§Ìä∏ + ÎùºÎ≤® Î¶¨Ïä§Ìä∏
     ‚Üì (tokenizer)
input_ids, attention_mask (tensor)
     ‚Üì
(inputs, labels)
     ‚Üì
model(**inputs)
```

In [14]:
# Transformer Î™®Îç∏ÏùÄ Ïù¥Îü∞ raw ÌÖçÏä§Ìä∏Î•º Î∞îÎ°ú Ï≤òÎ¶¨ Î™ª ÌïòÍ≥†
# ÌÜ†ÌÅ¨ÎÇòÏù¥Ï†ÄÎ•º Í±∞Ï≥ê tensor(batch_input_ids, batch_attention_mask) ÌòïÌÉúÍ∞Ä ÌïÑÏöîÌï®.
def collate_fn(batch): # DataLoaderÍ∞Ä batchÎßàÎã§ Ìò∏Ï∂ú
    # texts = [f"passage: {x['text']}" for x in batch]
    texts = [f"query: {x['text']}" for x in batch]
    labels = torch.tensor([x['label'] for x in batch])  # ÎùºÎ≤®ÏùÑ int Î¶¨Ïä§Ìä∏ ‚Üí torch.tensor Î°ú Î≥ÄÌôò

    """
    ÌÜ†ÌÅ¨ÎÇòÏù¥Ï†Ä:
    ÌÖçÏä§Ìä∏Î•º token idÎ°ú Î≥ÄÌôò (input_ids), attention_mask ÏÉùÏÑ±,
    batchÏùò ÏµúÎåÄ lengthÏóê ÎßûÏ∂∞ Ìå®Îî©, Ï∂úÎ†• ÌÉÄÏûÖÏùÄ PyTorch tensor

    { 'input_ids': tensor([[101,  ... , 102], ...]),
      'attention_mask': tensor([[1,1,1,0,0...], ...) }
    """
    inputs = tokenizer(
      texts, padding=True, truncation=True, max_length=256, return_tensors="pt")

    return inputs, labels

In [15]:
total_len = len(dataset)
train_len = int(total_len * 0.8)
valid_len = total_len - train_len

train_dataset, valid_dataset = random_split(dataset, [train_len, valid_len])

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn
)
valid_loader = DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn
)

ÏßÄÍ∏àÏùò Íµ¨Ï°∞Îäî contrastive loss, genre/content KD, dynamic Œ±, GradNorm, multi-vector head Îì±Ïù¥ ÏÑúÎ°ú ÏñΩÌòÄ ÏûàÏñ¥ LR Î≥ÄÌôîÏóê Îß§Ïö∞ ÎØºÍ∞êÌï®

- contrastive(genre) ÎπÑÏ§ëÏù¥ ÌÅ∞ Œ±=0.8 ÏÉÅÌÉú
- genre/content ÏñëÎ∞©Ìñ• KD ÏïïÎ†•
- multi-vector head Ï†ïÎ†¨ Ï†Ñ
- GradNorm ÎπÑÏú®Ïù¥ ÏïÑÏßÅ ÏàòÎ†¥ Ïïà Îê®
- Œ± decayÍ∞Ä ÏïÑÏßÅ Ï†ÅÏö© Ï†Ñ(Ïû•Î•¥ ÎπÑÏ§ë Í≥ºÎã§)

warmupÏùÑ ÎÑàÎ¨¥ ÏßßÍ≤å(0\~2%) Ï£ºÎ©¥ learning rateÍ∞Ä ÏßÄÎÇòÏπòÍ≤å Îπ®Î¶¨ ÏÉÅÏäπÌï¥ Ï¥àÍ∏∞ Îã®Í≥ÑÏóêÏÑú gradient explosionÏù¥ÎÇò embedding Î∂ïÍ¥¥Í∞Ä ÏùºÏñ¥ÎÇòÍ≥†, Î∞òÎåÄÎ°ú ÎÑàÎ¨¥ Í∏∏Í≤å(10\~20%) Ï£ºÎ©¥ LRÏù¥ Ï≤úÏ≤úÌûà Ïò¨ÎùºÍ∞ÄÎäî ÎèôÏïà contrastive Ï™Ω ÌëúÌòÑ ÌïôÏäµÏù¥ ÏßÄÏó∞ÎêòÍ≥† embedding collapse ÏúÑÌóòÏù¥ Ïª§Ï†∏ ÏµúÏ¢Ö ÏÑ±Îä•(MRR, top-1)ÍπåÏßÄ Îñ®Ïñ¥Ïßê

Ïä§ÏºÄÏ•¥Îü¨ ÏûêÏ≤¥ÎèÑ Î∞îÍø®ÎäîÎç∞, Linear decayÎäî warmup Ïù¥ÌõÑÏóê LRÏù¥ ÏßÅÏÑ†ÏúºÎ°ú Í∏âÍ≤©Ìûà Îñ®Ïñ¥Ïßê. Ïù¥Îäî ÌïôÏäµ ÌõÑÎ∞òÎ∂Ä KD alignmentÏôÄ contrastive alignmentÏùò ÎØ∏ÏÑ∏ Ï°∞Ï†ïÏùÑ ÎßâÏïÑ ÌïôÏäµÏù¥ ÏÇ¨Ïã§ÏÉÅ Î©àÏ∂îÎäî Î¨∏Ï†úÍ∞Ä ÏûàÏùå. Î∞òÎ©¥ Cosine Ïä§ÏºÄÏ§ÑÎü¨Îäî warmup Ïù¥ÌõÑ LRÏùÑ ÏôÑÎßåÌïú Í≥°ÏÑ† ÌòïÌÉúÎ°ú Í∞êÏÜåÏãúÌÇ§Í∏∞ ÎïåÎ¨∏Ïóê, Ï¥àÎ∞òÏóêÎäî ÏïàÏ†ïÏÑ±ÏùÑ Ï†úÍ≥µÌïòÍ≥†, Ï§ëÎ∞òÏóêÎäî Ï∂©Î∂ÑÌïú LRÏùÑ Ïú†ÏßÄÌïòÎ©∞, ÌõÑÎ∞òÏóêÎèÑ ÏûëÏùÄ Ìè≠Ïù¥ÏßÄÎßå ÏùòÎØ∏ ÏûàÎäî ÏóÖÎç∞Ïù¥Ìä∏Î•º Ïù¥Ïñ¥Í∞à Ïàò ÏûàÏùÑ Í≤ÉÏûÑ

In [16]:
total_steps = len(train_loader) * EPOCHS

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# AdamW ÏòµÌã∞ÎßàÏù¥Ï†ÄÎ°ú LoRA ÌååÎùºÎØ∏ÌÑ∞Îßå ÌïôÏäµ
# LoRA ÎçïÎ∂ÑÏóê Ïã§Ï†ú ÏóÖÎç∞Ïù¥Ìä∏ÎêòÎäî ÌååÎùºÎØ∏ÌÑ∞Îäî Ï†ÑÏ≤¥Ïùò 1% Ï†ïÎèÑ

# scheduler = get_linear_schedule_with_warmup(
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(total_steps * 0.1), # WARMUP_RATIO), ÏõêÎûò linear ÏùºÎïå ÏçªÎçò Í∞íÏùÄÎç∞ 0.1Ïù¥ÏóàÍ±∞Îì†? 0.05Î°ú Ï§ÑÏù¥Îûò
    num_training_steps=total_steps,
)

In [17]:
import torch
import torch.nn.functional as F

def compute_retrieval_accuracy(model, dataloader, device, tokenizer, k=10):
    """
    Multi-Vector Î™®Îç∏Ïö© Í≤ÄÏ¶ù Ìï®Ïàò
    * Vector 0 (Genre Vector)Îßå ÏÇ¨Ïö©ÌïòÏó¨ Ïπ¥ÌÖåÍ≥†Î¶¨ Í≤ÄÏÉâ ÏÑ±Îä•ÏùÑ Ï∏°Ï†ïÌï©ÎãàÎã§.
    """
    embeddings_list = []
    labels_list = []

    model.eval() # ÌèâÍ∞Ä Î™®Îìú ÌïÑÏàò!

    with torch.no_grad():
        for batch_inputs, labels in dataloader:
            # 1. ÏûÖÎ†• Îç∞Ïù¥ÌÑ∞ Ï§ÄÎπÑ
            # input_ids = batch_inputs.to(device)
            batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
            labels = labels.to(device)

            # 2. Î™®Îç∏ Forward (Multi-Vector Ï∂úÎ†•)
            # output shape: (Batch, K, 768)
            multi_vectors = model(**batch_inputs)

            # 3. ‚òÖÌïµÏã¨‚òÖ: Vector 0 (Ïû•Î•¥ Î≤°ÌÑ∞)Îßå Ï∂îÏ∂ú
            # Ïö∞Î¶¨Í∞Ä Í≤ÄÏ¶ùÌïòÎ†§Îäî Í±¥ "Ïπ¥ÌÖåÍ≥†Î¶¨(Ïû•Î•¥)Î•º Ïûò ÎßûÏ∂îÎäîÍ∞Ä?" Ïù¥ÎØÄÎ°ú 0Î≤àÎßå ÏîÄ
            genre_embeddings = multi_vectors[:, 0, :]
            genre_embeddings = F.normalize(genre_embeddings, p=2, dim=1)

            embeddings_list.append(genre_embeddings)
            labels_list.append(labels)

    # Ï†ÑÏ≤¥ Îç∞Ïù¥ÌÑ∞ Ìï©ÏπòÍ∏∞
    all_embeddings = torch.cat(embeddings_list, dim=0)
    all_labels = torch.cat(labels_list, dim=0)
    similarity = torch.matmul(all_embeddings, all_embeddings.T) # (N, N) ÌñâÎ†¨ Í≥ÑÏÇ∞
    similarity.fill_diagonal_(-1e9) # ÏûêÍ∏∞ ÏûêÏã† Ï†úÏô∏ (-Î¨¥ÌïúÎåÄ ÎßàÏä§ÌÇπ)

    _, topk_idx = similarity.topk(k, dim=1)  # top-k neighbor Ïù∏Îç±Ïä§
    nn_labels_topk = all_labels[topk_idx] # Ïù¥ÏõÉÎì§Ïùò ÎùºÎ≤® Í∞ÄÏ†∏Ïò§Í∏∞ (N, k)

    # Ï†ïÎãµ Ïó¨Î∂Ä ÌñâÎ†¨ (True/False)
    # ÎÇ¥ ÎùºÎ≤®(all_labels)Í≥º Ïù¥ÏõÉ ÎùºÎ≤®(nn_labels_topk)Ïù¥ Í∞ôÏùÄÏßÄ ÎπÑÍµê
    # unsqueeze(1)Î°ú (N, 1) vs (N, k) Î∏åÎ°úÎìúÏ∫êÏä§ÌåÖ ÎπÑÍµê
    hits = (nn_labels_topk == all_labels.unsqueeze(1)) # (N, k) Bool Tensor

    # -----------------------------
    # top-1: Í∞ÄÏû• Í∞ÄÍπåÏö¥ neighbor Ï†ïÎãµ Ïó¨Î∂Ä
    # top-k: kÍ∞ú ÏïàÏóê Ï†ïÎãµÏù¥ ÌïòÎÇòÎùºÎèÑ ÏûàÏúºÎ©¥
    # precision@k: kÍ∞ú neighbor Ï§ë Ï†ïÎãµ ÎπÑÏú® ÌèâÍ∑†
    # MRR: Ï†ïÎãµÏù¥ rank Î™á Î≤àÏß∏Ïù∏ÏßÄÏóê Îî∞Î•∏ ÌèâÍ∑† Ïó≠Ïàò ‚Üí rank 1Ïù¥Î©¥ 1.0, rank 2Ïù¥Î©¥ 0.5
    # -----------------------------
    # Top-1 accuracy
    top1_labels = nn_labels_topk[:, 0]
    top1_acc = (top1_labels == all_labels).float().mean().item()

    # Top-k accuracy (at least 1 match)
    topk_match = (nn_labels_topk == all_labels.unsqueeze(1)).any(dim=1).float()
    topk_acc = topk_match.mean().item()

    # Precision@k
    precision_at_k = (nn_labels_topk == all_labels.unsqueeze(1)).float().mean(dim=1).mean().item()

    # MRR (Mean Reciprocal Rank)
    ranks = (nn_labels_topk == all_labels.unsqueeze(1)).float() # Ï†ïÎãµ label ÏúÑÏπò Ï∞æÍ∏∞
    reciprocal_rank = []
    for i in range(ranks.size(0)):
        pos_positions = torch.nonzero(ranks[i]).flatten()
        if len(pos_positions) == 0: # positive ÏóÜÏúºÎ©¥ reciprocal rank = 0
            reciprocal_rank.append(0.0)
        else:
            reciprocal_rank.append(1.0 / (pos_positions[0].item() + 1))
    mrr = sum(reciprocal_rank) / len(reciprocal_rank)

    metrics = {
        "top1_acc": top1_acc,
        "topk_acc": topk_acc,
        "precision@k": precision_at_k,
        "mrr": mrr
    }

    return metrics

In [18]:
def hard_negative(embeddings, labels, neg_ratio=0.1):
    batch_size = embeddings.size(0)
    k = max(3, int(batch_size * neg_ratio))
    similarity = torch.matmul(embeddings, embeddings.T)

    pos_mask = labels.unsqueeze(1) == labels.unsqueeze(0)

    sim_for_neg = similarity.clone()
    sim_for_neg.masked_fill_(pos_mask, -1e9)
    hard_neg_sims, _ = sim_for_neg.topk(k, dim=1) # (N, k)

    return hard_neg_sims

In [19]:
def noisy_contrastive_loss(embeddings, labels):

    similarity = torch.matmul(embeddings, embeddings.T)

    labels_eq = labels.unsqueeze(1) == labels.unsqueeze(0)
    identity_mask = torch.eye(len(labels), device=labels.device).bool() # ÏûêÍ∏∞ ÏûêÏã† Ï†úÍ±∞ mask
    labels_eq = labels_eq & (~identity_mask)

    pos_mask = labels_eq.float()   # (N, N)

    pos_sim = similarity * pos_mask
    neg_sim = hard_negative(embeddings, labels, neg_ratio=NEG_RATIO)

    pos_sim = pos_sim / TEMPERATURE
    neg_sim = neg_sim / TEMPERATURE

    loss = -torch.log(
        torch.exp(pos_sim).sum(dim=1) /
        (torch.exp(pos_sim).sum(dim=1) + torch.exp(neg_sim).sum(dim=1))
    ).mean()

    return loss

In [21]:
def calc_grad_norm(loss, model_layer):
    """
    ÌäπÏ†ï LossÍ∞Ä ÌäπÏ†ï Î†àÏù¥Ïñ¥(model_layer)Ïùò ÌååÎùºÎØ∏ÌÑ∞Ïóê Í∞ÄÌïòÎäî
    GradientÏùò Ï¥ùÎüâ(Norm)ÏùÑ Í≥ÑÏÇ∞Ìï©ÎãàÎã§.
    """
    # 1. Ìï¥Îãπ Î†àÏù¥Ïñ¥Ïùò ÌååÎùºÎØ∏ÌÑ∞Îßå Í∞ÄÏ†∏Ïò¥ (requires_grad=TrueÏù∏ Í≤ÉÎßå)
    params = [p for p in model_layer.parameters() if p.requires_grad]

    if not params:
        return 0.0

    # 2. Gradient Í≥ÑÏÇ∞ (create_graph=False, retain_graph=True ÌïÑÏàò!)
    # retain_graph=True: Îí§Ïóê ÏßÑÏßú backward()Î•º Îòê Ìï¥Ïïº ÌïòÎØÄÎ°ú Í∑∏ÎûòÌîÑÎ•º ÎÇ†Î¶¨Î©¥ Ïïà Îê®
    grads = torch.autograd.grad(
        loss,
        params,
        retain_graph=True,
        allow_unused=True
    )

    # 3. Norm(ÌÅ¨Í∏∞) Ìï©ÏÇ∞ (L2 Norm)
    total_norm = 0.0
    for g in grads:
        if g is not None:
            total_norm += g.pow(2).sum().item()

    return total_norm ** 0.5

In [22]:
t = 200  # Ï∂©Îèå ÏãúÏûë ÏßÄÏ†ê
base_alpha = 0.8   # Ï¥àÎ∞ò: Ïû•Î•¥(Genre) Ï†ïÎ≥¥Î•º ÌôïÏã§Ìûà Ïû°Ïùå
target_alpha = 0.2 # ÌõÑÎ∞ò: Ï∂©Îèå ÌöåÌîºÎ•º ÏúÑÌï¥ Ïû•Î•¥ ÎπÑÏ§ëÏùÑ ÎÇÆÏ∂§ (Î≥∏Î¨∏ ÏßëÏ§ë)
steps_per_epoch = len(train_loader)
running_ratio = 1.0
beta = 0.95  # Í¥ÄÏÑ± Í≥ÑÏàò (ÌÅ¥ÏàòÎ°ù Î≥ÄÌôîÍ∞Ä Î∂ÄÎìúÎü¨ÏõÄ)

for epoch in range(EPOCHS):
    model.train()
    total_train_loss = 0

    for step, (batch_inputs, labels) in enumerate(tqdm(train_loader, desc = f"Epoch: {epoch+1}")):
        global_step = epoch * steps_per_epoch + step

        if global_step < t:
            alpha = base_alpha
        else:
            x = 5 * (global_step - t) / t
            sigmoid_x = 1 / (1 + math.exp(-x))

            decay_ratio = (sigmoid_x - 0.5) * 2
            if decay_ratio > 1.0: decay_ratio = 1.0

            alpha = base_alpha - (base_alpha - target_alpha) * decay_ratio

        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
        labels = labels.to(device)

        student_vectors = model(**batch_inputs)
        
        genre_vector = student_vectors[:, 0, :]
        content_vectors = student_vectors[:, 1, :]
        
        loss_cont = noisy_contrastive_loss(genre_vector, labels)
    
        with torch.no_grad():
            teacher_outputs = teacher_model(**batch_inputs)
            hidden = teacher_outputs.last_hidden_state       # (B, L, D)
            mask = batch_inputs['attention_mask'].unsqueeze(-1)  # (B, L, 1)
            teacher_embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
            teacher_norm = F.normalize(teacher_embeddings, p=2, dim=1)

        genre_norm = F.normalize(genre_vector, p=2, dim=1)
        loss_kd_genre = F.mse_loss(genre_norm, teacher_norm)
        
        content_norm = F.normalize(content_vector, p=2, dim=1)
        loss_kd_content = F.mse_loss(content_norm, teacher_norm)
        
        loss_kd_combined = alpha * loss_kd_genre + (1 - alpha) * loss_kd_content

        norm_main = calc_grad_norm(loss_cont, model.head.attention)
        norm_sub = calc_grad_norm(loss_kd_combined, model.head.attention)
        target_scale = 0.6
        current_ratio = norm_main / (norm_sub + 1e-8) * target_scale
        if current_ratio > 1000.0: current_ratio = 1000.0
        running_ratio = beta * running_ratio + (1 - beta) * current_ratio

        total_loss = loss_cont + running_ratio * loss_kd_combined

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        scheduler.step()

        total_train_loss += total_loss.item()

    train_loss = total_train_loss / len(train_loader)

    model.eval()
    metrics = compute_retrieval_accuracy(model, valid_loader, device, tokenizer)
    print(f"[Epoch {epoch + 1}] Train Loss: {train_loss:.4f}")
    print(f"Top-1 Accuracy : {metrics['top1_acc']:.4f} | Top-10 Accuracy : {metrics['topk_acc']:.4f}")
    print(f"Precision@10   : {metrics['precision@k']:.4f} | MRR              : {metrics['mrr']:.4f}")

Epoch: 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:24<00:00,  3.56it/s]


[Epoch 1] Train Loss: 2.5856
Top-1 Accuracy : 0.5741 | Top-10 Accuracy : 0.8661
Precision@10   : 0.5402 | MRR              : 0.6755


Epoch: 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.57it/s]


[Epoch 2] Train Loss: 2.4089
Top-1 Accuracy : 0.5984 | Top-10 Accuracy : 0.8666
Precision@10   : 0.5643 | MRR              : 0.6918


Epoch: 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.57it/s]


[Epoch 3] Train Loss: 2.2554
Top-1 Accuracy : 0.6047 | Top-10 Accuracy : 0.8655
Precision@10   : 0.5744 | MRR              : 0.6960


Epoch: 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.57it/s]


[Epoch 4] Train Loss: 2.1418
Top-1 Accuracy : 0.6098 | Top-10 Accuracy : 0.8625
Precision@10   : 0.5796 | MRR              : 0.6983


Epoch: 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.57it/s]


[Epoch 5] Train Loss: 2.0630
Top-1 Accuracy : 0.6127 | Top-10 Accuracy : 0.8585
Precision@10   : 0.5899 | MRR              : 0.7004


Epoch: 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.57it/s]


[Epoch 6] Train Loss: 2.0043
Top-1 Accuracy : 0.6127 | Top-10 Accuracy : 0.8570
Precision@10   : 0.5891 | MRR              : 0.7000


Epoch: 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 512/512 [02:23<00:00,  3.56it/s]


[Epoch 7] Train Loss: 1.9436
Top-1 Accuracy : 0.6137 | Top-10 Accuracy : 0.8553
Precision@10   : 0.5962 | MRR              : 0.6998


Epoch: 8:   9%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà                                                             | 46/512 [00:13<02:13,  3.48it/s]


KeyboardInterrupt: 

In [None]:
import os
save_path = f"./old_multi_vec_{LAMBDA}_20ep.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)