In [None]:
import os
from datasets import Dataset
import torch
from torch.utils.data import DataLoader
from torch import amp
from tqdm import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer,
    get_scheduler,
)
from peft import get_peft_model, LoraConfig
import json
import glob
from torch import nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [None]:
def build_triplets(data, negatives_per_positive: int = 6):
    """Expand per-anchor candidate annotations into (anchor, positive, negative) triplets.

    ``data`` maps each anchor id to a dict with two parallel lists:
    ``'pair'`` (candidate ids) and ``'is_double'`` (1 marks a positive,
    0 a negative; any other value is ignored). Within each group, candidates
    keep their original order.

    The i-th positive is paired with its own consecutive slice of negatives
    ``negatives[i*k:(i+1)*k]`` (``k = negatives_per_positive``). So every
    positive yields exactly ``k`` triplets, and no negative is shared between
    positives of the same anchor. Negatives beyond ``k * len(positives)`` are
    not used.

    Args:
        data: mapping ``anchor -> {'pair': [...], 'is_double': [...]}``.
        negatives_per_positive: number of negatives assigned to each positive
            (default 6, the value that used to be hard-coded).

    Returns:
        List of ``(anchor, positive, negative)`` tuples.

    Raises:
        AssertionError: if ``pair`` and ``is_double`` differ in length for an
            anchor, or an anchor has fewer than
            ``negatives_per_positive * len(positives)`` negatives.
    """
    k = negatives_per_positive
    triplets = []

    for anchor, info in data.items():
        pair = info['pair']
        is_double = info['is_double']
        # zip() below would silently drop the tail of the longer list and hide a data error.
        assert len(pair) == len(is_double), f"'pair' and 'is_double' differ in length for anchor {anchor}"

        positives = [p for p, d in zip(pair, is_double) if d == 1]
        negatives = [p for p, d in zip(pair, is_double) if d == 0]

        assert len(negatives) >= k * len(positives), f"Not enough negatives for anchor {anchor}"

        for i, pos in enumerate(positives):
            for neg in negatives[i * k:(i + 1) * k]:
                triplets.append((anchor, pos, neg))

    return triplets

In [None]:
class TripletDataset(torch.utils.data.Dataset):
    """Map-style dataset of (anchor, positive, negative) triplets over pre-tokenized texts.

    This derives from ``torch.utils.data.Dataset``, not Hugging Face
    ``datasets.Dataset``. The latter expects an Arrow table, and recent versions
    define ``__getitems__``. ``DataLoader`` prefers ``__getitems__`` for batched
    fetching, and it would forward a *list* of indices into ``__getitem__``
    below, which cannot handle one.

    Args:
        triplets: sequence of ``(anchor_id, positive_id, negative_id)`` tuples.
        tokenized_data: mapping ``id -> {'input_ids': Tensor, 'attention_mask': Tensor}``.
    """

    def __init__(self, triplets, tokenized_data):
        self.triplets = triplets
        self.tokenized_data = tokenized_data

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the six tensors of triplet ``idx``, keyed ``<role>_input_ids`` / ``<role>_attention_mask``."""
        anchor_id, positive_id, negative_id = self.triplets[idx]
        item = {}
        for role, text_id in (('anchor', anchor_id), ('positive', positive_id), ('negative', negative_id)):
            encoded = self.tokenized_data[text_id]
            # squeeze(0) drops a leading size-1 batch dim if present; it is a no-op on 1-D tensors.
            item[f'{role}_input_ids'] = encoded['input_ids'].squeeze(0)
            item[f'{role}_attention_mask'] = encoded['attention_mask'].squeeze(0)
        return item

In [None]:
def extract_numbers(filename):
    """Parse ``(epoch, step)`` out of a checkpoint file name.

    Expects names shaped like ``lora_triplet_<epoch>_step_<step>.pt``. The
    returned tuple of ints is used as a sort key to find the latest checkpoint.
    """
    fields = filename.replace(".pt", "").split("_")
    epoch_field, step_field = fields[2], fields[4]
    return int(epoch_field), int(step_field)

In [None]:
def tokenize_texts_in_dict(data: dict, tokenizer, max_length: int = 384):
    """Tokenize every text value of ``data`` in place and return ``data``.

    Each text is replaced by a dict of 1-D tensors produced by ``tokenizer``.
    Texts are padded/truncated to ``max_length``, and ``token_type_ids`` is
    dropped. Values that are already tokenized (non-string containers holding
    ``input_ids``) are left untouched, so the function can be re-run on a
    partially processed dict.

    Args:
        data: mapping ``key -> text``; mutated in place.
        tokenizer: callable with the Hugging Face tokenizer call signature.
        max_length: target sequence length.

    Returns:
        The same ``data`` object.
    """
    for key, value in tqdm(data.items()):
        # Check the type first: on a raw string, ``'input_ids' in value`` is a
        # substring test and would wrongly skip any text that mentions "input_ids".
        if not isinstance(value, str) and 'input_ids' in value:
            continue
        tokens = tokenizer(
            value,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        tokens.pop("token_type_ids", None)
        # Drop the batch dimension: (1, seq_len) -> (seq_len,)
        data[key] = {
            name: tensor.squeeze(0) for name, tensor in tokens.items()
        }
    return data

In [None]:
class AttnPooling(nn.Module):
    """Learnable attention pooling over token states.

    A single trainable query vector ``q`` scores every token. Padding tokens
    are masked out, and the softmax-normalised scores weight the sum of hidden
    states, so each token contributes a different amount to the pooled vector.
    """

    def __init__(self, hidden_size: int = 312):
        super().__init__()
        # Trainable "query", optimised jointly with the rest of the network.
        self.q = nn.Parameter(torch.randn(hidden_size))

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_states`` (B, T, H) into (B, H).

        ``attention_mask`` is (B, T). Positions where it equals 0 get a score of
        -1e4 and therefore (near-)zero weight.
        """
        scale = hidden_states.size(-1) ** 0.5
        # Scaled dot product of the query with every token -> (B, T)
        scores = torch.matmul(hidden_states, self.q) / scale
        is_pad = attention_mask == 0
        scores = scores.masked_fill(is_pad, -1e4)
        # (B, T) -> (B, T, 1) so the weights broadcast over the hidden dimension
        weights = scores.softmax(dim=1)[..., None]
        return torch.sum(weights * hidden_states, dim=1)

In [None]:
def get_embeddings(model, attn_pool, input_ids, attention_mask):
    """Encode a batch with ``model`` and pool its token states with ``attn_pool``.

    The encoder's ``last_hidden_state`` is cast to float32 before pooling. The
    function returns whatever ``attn_pool`` returns; with ``AttnPooling`` that
    is one vector per sequence.
    """
    encoder_output = model(input_ids=input_ids, attention_mask=attention_mask)
    token_states = encoder_output.last_hidden_state.float()
    return attn_pool(token_states, attention_mask)

In [None]:
# Work from the dataset directory; every relative path in later cells resolves against it.
DATA_DIR = 'avitotech_data/avitotech_data'
os.chdir(DATA_DIR)

In [None]:
# Load the per-anchor candidate annotations (consumed by build_triplets) and the
# card texts to tokenize. JSON is UTF-8 by spec, so the encoding is passed
# explicitly; otherwise the platform's locale codec is used, which can garble or
# reject non-ASCII (presumably Cyrillic) text.
with open('to_undergo.json', 'r', encoding='utf-8') as f:
    ids_trunc = json.load(f)

with open("cards_train.json", "r", encoding='utf-8') as file:
    cards_train = json.load(file)

In [None]:
triplets = build_triplets(ids_trunc)
train_triplets, val_triplets = train_test_split(triplets, test_size=0.1, random_state=42)

tokenized_data_text = tokenize_texts_in_dict(cards_train, tokenizer)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

lora_config = LoraConfig(
    r=8,  # размер low-rank матриц A и B
    lora_alpha=16,
    target_modules=["query", "value"],  # модули attention, которые дообучаем
    lora_dropout=0.1,
    bias="none",
    task_type="FEATURE_EXTRACTION"
)

model_lora = get_peft_model(model, lora_config)

In [None]:
# Datasets look up the pre-tokenized cards by id for each triplet.
train_dataset = TripletDataset(train_triplets, tokenized_data_text)
val_dataset = TripletDataset(val_triplets, tokenized_data_text)

BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Pooling head over the encoder's token states. 312 is meant to match the
# encoder's hidden size — NOTE(review): confirm against model_lora.config.hidden_size.
attn_pool = AttnPooling(312).to(device)

In [None]:
# Only the pooling head and the LoRA adapters are trainable (get_peft_model
# freezes the base weights). Passing frozen params to AdamW is harmless but wasteful.
trainable_lora_params = [p for p in model_lora.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    list(attn_pool.parameters()) + trainable_lora_params,
    lr=2e-5
)
criterion = nn.TripletMarginLoss(margin=2, p=2)

epochs = 3
gradient_accumulation_steps = 5
# The training loop advances the scheduler once per *optimizer* step, i.e. every
# `gradient_accumulation_steps` batches. So size the schedule in optimizer
# steps; sized in batches, the linear decay would only be ~1/5 complete when
# training ends. Ceil division also covers a trailing partial accumulation window.
optimizer_steps_per_epoch = -(-len(train_loader) // gradient_accumulation_steps)
num_training_steps = epochs * optimizer_steps_per_epoch
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
# Loss scaling is only needed for CUDA float16 autocast; make the CPU fallback explicit.
scaler = amp.GradScaler(enabled=device.type == "cuda")

In [None]:
# Checkpoints are written to and resumed from this directory (relative to the data dir).
CHECKPOINT_DIR = "trained_text_models"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.chdir(CHECKPOINT_DIR)

In [None]:
# Move the model to its device BEFORE restoring optimizer state:
# Optimizer.load_state_dict casts each state tensor to its parameter's current
# device. Loading while the params were still on CPU, and only then moving the
# model, left the AdamW moments on CPU and the params on CUDA.
model_lora = model_lora.to(device)

# Resume from the most recent checkpoint (highest epoch, then step), if any.
# NOTE(review): only model / pooling / optimizer / scaler state is restored; the
# training loop still starts at epoch 1 and the LR schedule restarts.
models = glob.glob('lora_triplet_*_step_*.pt')

if models:
    last_chpt = max(models, key=extract_numbers)
    checkpoint = torch.load(last_chpt, map_location=device)

    model_lora.load_state_dict(checkpoint['model_state_dict'])
    # The pooling head is saved with every checkpoint but was never restored,
    # so a resume used to re-initialise it randomly.
    attn_pool.load_state_dict(checkpoint['attn_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    print(f"Resumed from {last_chpt}")

In [None]:
def compute_triplet_loss(batch):
    """Return the unscaled triplet loss for one ``TripletDataset`` batch.

    Embeds the anchor, positive and negative texts with the global
    ``model_lora`` + ``attn_pool`` on ``device`` and applies ``criterion``.
    Shared by training and validation so both compute the loss identically.
    """
    anchor_emb, positive_emb, negative_emb = (
        get_embeddings(
            model_lora,
            attn_pool,
            batch[f'{role}_input_ids'].to(device),
            batch[f'{role}_attention_mask'].to(device),
        )
        for role in ('anchor', 'positive', 'negative')
    )
    return criterion(anchor_emb, positive_emb, negative_emb)


# float16 autocast only makes sense on CUDA.
use_amp = device.type == 'cuda'
num_batches = len(train_loader)
# Checkpoint roughly 100 times per epoch. max(1, ...) avoids a modulo-by-zero
# crash when the loader has fewer than 100 batches.
# NOTE(review): each checkpoint holds the full model + optimizer state and none
# are ever deleted — watch disk usage.
checkpoint_every = max(1, num_batches // 100)

for epoch in range(1, epochs + 1):  # `epoch` is already 1-based
    model_lora.train()
    attn_pool.train()
    total_loss = 0.0   # sum of unscaled per-batch losses, comparable to the val loss
    window_loss = 0.0  # mean loss of the current accumulation window (for display)
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")

    optimizer.zero_grad()
    for step, batch in enumerate(progress_bar, start=1):
        # Size of the current accumulation window; the last one in an epoch may be shorter.
        window_start = (step - 1) // gradient_accumulation_steps * gradient_accumulation_steps
        window_size = min(gradient_accumulation_steps, num_batches - window_start)

        with amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            batch_loss = compute_triplet_loss(batch)

        # Divide so the accumulated gradient is the mean over the window.
        scaler.scale(batch_loss / window_size).backward()

        # .item() keeps plain floats, so no autograd graph is held between steps.
        loss_value = batch_loss.item()
        total_loss += loss_value
        window_loss += loss_value / window_size

        # Step at the end of every window, including a trailing partial one.
        # Otherwise its gradients would be discarded by the next epoch's zero_grad().
        if step % gradient_accumulation_steps == 0 or step == num_batches:
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()

            progress_bar.set_description(
                f"Epoch {epoch} | Loss: {window_loss:.5f}"
            )
            window_loss = 0.0

        if step % checkpoint_every == 0:
            torch.save({
                'model_state_dict': model_lora.state_dict(),
                'attn_state_dict': attn_pool.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
            }, f'lora_triplet_{epoch}_step_{step}.pt')

    model_lora.eval()
    attn_pool.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            with amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                val_loss += compute_triplet_loss(batch).item()

    avg_val_loss = val_loss / max(1, len(val_loader))
    print(f"Validation Loss: {avg_val_loss:.4f}")

    avg_loss = total_loss / max(1, num_batches)
    print(f"Epoch {epoch} completed — Avg Loss: {avg_loss:.4f}")