# Training pair of models

In [46]:
import os
import torch
import glob
from sklearn.model_selection import train_test_split
from torch import nn, amp
from torch.utils.data import Dataset, DataLoader
import json
from transformers import (
    AutoModel,
    AutoTokenizer,
    get_scheduler,
)
from tqdm import tqdm
import torch.nn.functional as F

In [3]:
# Enter the dataset root. os.path.join picks the right separator for the
# current OS instead of hard-coding the Windows backslash.
os.chdir(os.path.join("avitotech_data", "avitotech_data"))

In [4]:
def build_triplets(data, negatives_per_positive: int = 6):
    """Build ``(anchor, positive, negative)`` triplets from labelled pairs.

    Args:
        data: Mapping ``anchor -> info``. ``info['pair']`` is a sequence of
            candidates and ``info['is_double']`` the parallel sequence of
            labels: ``1`` marks a positive, ``0`` a negative, any other value
            is ignored.
        negatives_per_positive: How many negatives are attached to each
            positive. Defaults to 6.

    Returns:
        A list of ``(anchor, positive, negative)`` tuples. The i-th positive
        of an anchor is combined with the i-th consecutive chunk of
        ``negatives_per_positive`` negatives, so a negative is never reused
        within one anchor.

    Raises:
        ValueError: If ``negatives_per_positive`` is smaller than 1, or an
            anchor has fewer than ``negatives_per_positive * len(positives)``
            negatives.
    """
    if negatives_per_positive < 1:
        raise ValueError("negatives_per_positive must be >= 1")

    triplets = []

    for anchor, info in data.items():
        labelled = list(zip(info['pair'], info['is_double']))
        positives = [cand for cand, label in labelled if label == 1]
        negatives = [cand for cand, label in labelled if label == 0]

        # Explicit raise instead of `assert`: asserts vanish under `python -O`,
        # which would silently yield short chunks of negatives.
        if len(negatives) < negatives_per_positive * len(positives):
            raise ValueError(f"Not enough negatives for anchor {anchor}")

        for i, pos in enumerate(positives):
            start = i * negatives_per_positive
            chunk = negatives[start:start + negatives_per_positive]
            triplets.extend((anchor, pos, neg) for neg in chunk)

    return triplets

In [5]:
def tokenize_texts_in_dict(data: dict, tokenizer, max_length: int = 384):
    """Replace every raw text in ``data`` with its tokenized tensors, in place.

    Args:
        data: Mapping ``key -> text``. Entries that already hold a tokenized
            mapping (one containing ``'input_ids'``) are left untouched, so
            the function can be re-run on a partially processed dict.
        tokenizer: Callable invoked as ``tokenizer(text, padding=...,
            truncation=..., max_length=..., return_tensors='pt')``.
        max_length: Length every sequence is padded / truncated to.

    Returns:
        The same ``data`` object; each processed value is a dict of tensors
        with the leading batch dimension removed and ``token_type_ids``
        dropped.
    """
    for key, text in tqdm(data.items()):
        # Only a non-string value can be an already-tokenized entry. On a raw
        # string `'input_ids' in text` would be a substring test and would
        # skip any text that merely mentions "input_ids".
        if not isinstance(text, str) and 'input_ids' in text:
            continue
        tokens = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        tokens.pop("token_type_ids", None)
        # Drop the batch dimension: (1, seq_len) -> (seq_len,)
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over data.items().
        data[key] = {
            k: v.squeeze(0) for k, v in tokens.items()
        }
    return data

In [None]:
def extract_numbers(filename):
    """Return the two integers of a ``..._<a>_step_<b>.pt`` checkpoint name.

    The numbers are located relative to the trailing ``_step_`` marker rather
    than by fixed position after splitting on ``_``. This makes the function
    work for names with a different prefix (``lora_triplet_3_step_40.pt``,
    ``encoder_2_step_100.pt``) and for paths whose directory part itself
    contains underscores.

    Args:
        filename: Checkpoint file name or path ending in ``<a>_step_<b>.pt``.

    Returns:
        Tuple ``(a, b)`` of ints, suitable as a sort key.

    Raises:
        ValueError: If ``filename`` does not end with ``<a>_step_<b>.pt``.
    """
    import re  # local import: the file's import cell does not provide `re`

    match = re.search(r"(\d+)_step_(\d+)\.pt$", filename)
    if match is None:
        raise ValueError(
            f"Cannot parse checkpoint numbers from file name: {filename!r}"
        )
    return int(match.group(1)), int(match.group(2))

In [9]:
class AttnPooling(nn.Module):
    """
    Learnable attention pooling: every token gets its own weight.
    """

    def __init__(self, hidden_size: int = 312):
        super().__init__()
        # Query vector "q", trained together with the rest of the network.
        self.q = nn.Parameter(torch.randn(hidden_size))

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Collapse the token axis into a single vector per sequence.

        hidden_states : (B, T, H)
        attention_mask: (B, T), zeros mark padded positions
        returns       : (B, H)
        """
        scale = hidden_states.size(-1) ** 0.5

        # (B, T): scaled dot product between q and every token.
        logits = torch.matmul(hidden_states, self.q) / scale

        # Push padded positions far down so softmax gives them ~0 weight.
        is_pad = attention_mask == 0
        logits = logits.masked_fill(is_pad, -1e4)

        token_weights = torch.softmax(logits, dim=1)            # (B, T)
        weighted = token_weights.unsqueeze(-1) * hidden_states  # (B, T, H)
        return torch.sum(weighted, dim=1)                       # (B, H)

In [10]:
class PairEncoder(nn.Module):
    """Fuse an image vector and a text vector into one unit-length embedding.

    The image input is projected from 768 features and the text input from
    312 features into a shared width ``d``; both projections are
    L2-normalised, concatenated and passed through a small fusion MLP whose
    output is L2-normalised again.
    """

    def __init__(self, d=512):
        super().__init__()
        self.proj_img = nn.Sequential(
            nn.Linear(768, d),
            nn.GELU(),
            nn.LayerNorm(d),
        )
        self.proj_txt = nn.Sequential(
            nn.Linear(312, d),
            nn.GELU(),
            nn.LayerNorm(d),
        )
        joint_width = 2 * d
        self.fuse = nn.Sequential(
            nn.Linear(joint_width, joint_width),
            nn.GELU(),
            nn.Linear(joint_width, d),
            nn.LayerNorm(d),
        )

    @staticmethod
    def _unit(x):
        """L2-normalise ``x`` along its last dimension."""
        return F.normalize(x, dim=-1)

    def forward(self, v_img, v_txt):
        """Return the fused embedding of shape (B, d)."""
        img_unit = self._unit(self.proj_img(v_img))
        txt_unit = self._unit(self.proj_txt(v_txt))
        joint = torch.cat((img_unit, txt_unit), dim=-1)
        return self._unit(self.fuse(joint))

In [11]:
def get_embeddings(model, attn_pool, encoder, input_ids, attention_mask, image_embed):
    """Run text model -> attention pooling -> pair encoder for one batch.

    The text model's ``last_hidden_state`` is cast to float32, pooled with
    ``attn_pool`` under ``attention_mask``, and handed to ``encoder`` together
    with ``image_embed``. The encoder's output is returned unchanged.
    """
    text_output = model(input_ids=input_ids, attention_mask=attention_mask)
    token_states = text_output.last_hidden_state.float()
    text_vector = attn_pool(token_states, attention_mask)
    return encoder(image_embed, text_vector)

In [14]:
class TripletDataset(Dataset):
    """Dataset of ``(anchor, positive, negative)`` id triplets.

    Each item carries, for all three roles, the tokenized text
    (``*_input_ids``, ``*_attention_mask``) and the image embedding
    (``*_encoded_image``) looked up by id. Ids missing from a lookup table
    are replaced by all-zero placeholders that have the same shape and dtype
    as real entries, so that batches can still be collated.

    Args:
        triplets: Sequence of ``(anchor, positive, negative)`` ids.
        tokenized_data_text: Mapping ``id -> {'input_ids', 'attention_mask'}``.
        tokenized_data_img: Mapping ``id -> image embedding tensor``.
        max_length: Length of the text placeholder, used only when
            ``tokenized_data_text`` has no entry to copy the shape from.
    """

    def __init__(self, triplets, tokenized_data_text, tokenized_data_img,
                 max_length: int = 384):
        self.triplets = triplets
        self.tokenized_data_img = tokenized_data_img
        self.tokenized_data_text = tokenized_data_text

        self.keys_img = set(tokenized_data_img.keys())
        self.keys_text = set(tokenized_data_text.keys())

        # Text placeholder: copy length and integer dtype from a real entry.
        # A float vector of a different length cannot be stacked with real
        # token ids in a batch.
        sample = next(iter(tokenized_data_text.values()), None)
        if isinstance(sample, dict) and 'input_ids' in sample:
            self.text_empty_vector = torch.zeros_like(
                sample['input_ids'].squeeze(0)
            )
        else:
            self.text_empty_vector = torch.zeros(max_length, dtype=torch.long)
        self.img_empty_vector = torch.zeros(768)

    def get_img(self, key):
        """Return the image embedding for ``key`` or the zero image vector."""
        return self.tokenized_data_img.get(key, self.img_empty_vector)

    def get_text(self, key):
        """Return the tokenized text for ``key`` or an all-zero placeholder.

        The placeholder's attention mask is all zeros, i.e. every position is
        marked as padding.
        """
        return self.tokenized_data_text.get(key, {
            'input_ids': self.text_empty_vector,
            'attention_mask': self.text_empty_vector
        })

    def __getitem__(self, idx):
        a, p, n = self.triplets[idx]
        item = {}
        for role, key in (('anchor', a), ('positive', p), ('negative', n)):
            text = self.get_text(key)
            # squeeze(0) drops a leading batch dimension of size 1, if any.
            item[f'{role}_input_ids'] = text['input_ids'].squeeze(0)
            item[f'{role}_attention_mask'] = text['attention_mask'].squeeze(0)
            item[f'{role}_encoded_image'] = self.get_img(key).squeeze(0)
        return item

    def __len__(self):
        return len(self.triplets)

## Loading image embeddings

In [None]:
# Path to the merged image embeddings; os.path.join avoids hard-coding the
# Windows path separator.
PATH = os.path.join('unzipped', 'train_images_embeddings_merged.pt')
# Keep the lookup table on the CPU; tensors are moved to the device per batch.
tokenized_data_img = torch.load(PATH, map_location="cpu")

## Loading the trained text model

In [23]:
# All saved text-model checkpoints; os.path.join keeps the pattern portable
# instead of hard-coding the Windows path separator.
models = glob.glob(os.path.join('trained_text_models', 'lora_triplet_*_step_*.pt'))

In [None]:
# LoraConfig and get_peft_model come from the `peft` package. They are used
# below but are not imported anywhere else in this file, so import them here.
from peft import LoraConfig, get_peft_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "cointegrated/rubert-tiny2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

lora_config = LoraConfig(
    r=8,  # rank of the low-rank matrices A and B
    lora_alpha=16,
    target_modules=["query", "value"],  # attention modules that get LoRA adapters
    lora_dropout=0.1,
    bias="none",
    task_type="FEATURE_EXTRACTION"
)

# Wrap the base model with LoRA adapters; 312 is the pooled feature size used
# by AttnPooling and by the text projection of the pair encoder.
model_lora = get_peft_model(model, lora_config)
attn_pool = AttnPooling(312)

In [None]:
# Fail with a clear message instead of max()'s "empty sequence" error when
# the glob above found no checkpoints.
if not models:
    raise FileNotFoundError(
        "No 'lora_triplet_*_step_*.pt' checkpoint found in 'trained_text_models'"
    )

# Pick the checkpoint with the largest numbers in its file name.
last_chpt = max(models, key=extract_numbers)
checkpoint = torch.load(last_chpt, map_location=device)

model_lora.load_state_dict(checkpoint['model_state_dict'])
attn_pool.load_state_dict(checkpoint['attn_state_dict'])
model_lora = model_lora.to(device)
attn_pool = attn_pool.to(device)
# The text side is used for inference only from here on.
model_lora.eval()
attn_pool.eval()

### Freeze model

In [None]:
# Freeze the text model and its pooling head: requires_grad_(False) switches
# off gradient tracking for every parameter of the module.
model_lora.requires_grad_(False)
attn_pool.requires_grad_(False)

## Loading train data

In [None]:
# Read both JSON files explicitly as UTF-8. Without `encoding`, open() uses
# the platform default, which can fail on or garble non-ASCII text.
with open('to_undergo.json', 'r', encoding='utf-8') as f:
    ids_trunc = json.load(f)

with open("cards_train.json", "r", encoding='utf-8') as file:
    cards_train = json.load(file)

In [None]:
# Triplets from the labelled pairs, then a fixed-seed 90/10 train/val split.
triplets = build_triplets(data=ids_trunc)
train_triplets, val_triplets = train_test_split(
    triplets,
    test_size=0.1,
    random_state=42,
)

# Tokenize the card texts (the dict is modified in place and returned).
tokenized_data_text = tokenize_texts_in_dict(data=cards_train, tokenizer=tokenizer)

In [None]:
# One dataset per split; both share the same text and image lookup tables.
train_dataset = TripletDataset(
    triplets=train_triplets,
    tokenized_data_text=tokenized_data_text,
    tokenized_data_img=tokenized_data_img,
)
val_dataset = TripletDataset(
    triplets=val_triplets,
    tokenized_data_text=tokenized_data_text,
    tokenized_data_img=tokenized_data_img,
)

# Only the training batches are shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)

## Settings

In [None]:
encoder = PairEncoder(512)
# Only trainable parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(
    [p for p in encoder.parameters() if p.requires_grad],
    lr=2e-5
)
criterion = nn.TripletMarginLoss(margin=2, p=2)

epochs = 3
gradient_accumulation_steps = 5
# The scheduler is stepped once per optimizer update, i.e. once every
# `gradient_accumulation_steps` batches, so its horizon must be counted in
# updates. Counting batches would leave the linear decay unfinished.
updates_per_epoch = max(1, len(train_loader) // gradient_accumulation_steps)
num_training_steps = epochs * updates_per_epoch
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)
# Loss scaler for float16 mixed-precision training.
scaler = amp.GradScaler()

In [None]:
# Encoder checkpoints are written to (and resumed from) this directory, so
# create it if needed and make it the working directory.
checkpoint_dir = "trained_encoder"
os.makedirs(checkpoint_dir, exist_ok=True)
os.chdir(checkpoint_dir)

In [None]:
models = glob.glob('encoder_*_step_*.pt')

# Move the encoder first: Optimizer.load_state_dict places the restored state
# on the device of the parameters, so they must already be on `device`.
encoder = encoder.to(device)

if models:
    # Resume from the checkpoint with the largest numbers in its file name.
    last_chpt = max(models, key=extract_numbers)
    checkpoint = torch.load(last_chpt, map_location=device)

    # Checkpoints store the weights under 'encoder_state_dict'; fall back to
    # 'model_state_dict' for files that use that key instead.
    if 'encoder_state_dict' in checkpoint:
        encoder_state = checkpoint['encoder_state_dict']
    else:
        encoder_state = checkpoint['model_state_dict']

    encoder.load_state_dict(encoder_state)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])

In [None]:
def _embed_triplet(batch):
    """Return (anchor, positive, negative) embeddings for a collated batch."""
    return tuple(
        get_embeddings(
            model_lora, attn_pool, encoder,
            batch[f'{role}_input_ids'].to(device),
            batch[f'{role}_attention_mask'].to(device),
            batch[f'{role}_encoded_image'].to(device),
        )
        for role in ('anchor', 'positive', 'negative')
    )


# Checkpoint roughly 100 times per epoch. max(1, ...) avoids a modulo by zero
# when an epoch has fewer than 100 batches.
save_every = max(1, len(train_loader) // 100)

for epoch in range(1, epochs + 1):
    encoder.train()
    total_loss = 0.0   # sum of unscaled per-batch losses over the epoch
    intern_loss = 0.0  # loss accumulated since the last optimizer update
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")

    optimizer.zero_grad()
    for step, batch in enumerate(progress_bar, start=1):

        with amp.autocast(device_type='cuda', dtype=torch.float16):
            anchor_emb, positive_emb, negative_emb = _embed_triplet(batch)

            # Divide so that the accumulated gradient is an average over the
            # micro-batches of one optimizer update.
            loss = criterion(anchor_emb, positive_emb, negative_emb) / gradient_accumulation_steps

        scaler.scale(loss).backward()

        # Track plain floats, so the running sums are not attached to the
        # autograd graph.
        loss_value = loss.item()
        intern_loss += loss_value
        # Undo the accumulation scaling so the epoch average is a true
        # per-batch loss.
        total_loss += loss_value * gradient_accumulation_steps

        if step % gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()

            # `epoch` is already 1-based, so it is shown as is.
            progress_bar.set_description(
                f"Epoch {epoch} | Loss: {intern_loss:.5f}"
            )
            intern_loss = 0.0
            optimizer.zero_grad()

        if step % save_every == 0:
            torch.save({
                'encoder_state_dict': encoder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
            }, f'encoder_{epoch}_step_{step}.pt')

    # Validation: the encoder is the module being trained, so it is the one
    # that has to be switched to eval mode.
    encoder.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in val_loader:
            with amp.autocast(device_type='cuda', dtype=torch.float16):
                anchor_emb, positive_emb, negative_emb = _embed_triplet(batch)
                loss = criterion(anchor_emb, positive_emb, negative_emb)
            val_loss += loss.item()

    # max(1, ...) keeps the averages defined for an empty loader.
    avg_val_loss = val_loss / max(1, len(val_loader))
    print(f"Validation Loss: {avg_val_loss:.4f}")

    avg_loss = total_loss / max(1, len(train_loader))
    print(f"Epoch {epoch} completed — Avg Loss: {avg_loss:.4f}")

# Extracting embeddings

In [41]:
class ProvidedDataset(Dataset):
    """Dataset over single item ids, used to extract per-item embeddings.

    Each item holds the id itself, its tokenized text (``input_ids``,
    ``attention_mask``) and its image embedding (``encoded_image``). Ids
    missing from a lookup table are replaced by all-zero placeholders that
    have the same shape and dtype as real entries, so that batches can still
    be collated.

    Args:
        unqiue_ids: Iterable of item ids (materialized into a list).
        tokenized_data_text: Mapping ``id -> {'input_ids', 'attention_mask'}``.
        tokenized_data_img: Mapping ``id -> image embedding tensor``.
        max_length: Length of the text placeholder, used only when
            ``tokenized_data_text`` has no entry to copy the shape from.
    """

    def __init__(self, unqiue_ids, tokenized_data_text, tokenized_data_img,
                 max_length: int = 384):
        self.unqiue_ids = list(unqiue_ids)
        self.tokenized_data_img = tokenized_data_img
        self.tokenized_data_text = tokenized_data_text

        self.keys_img = set(tokenized_data_img.keys())
        self.keys_text = set(tokenized_data_text.keys())

        # Text placeholder: copy length and integer dtype from a real entry.
        # A float vector of a different length cannot be stacked with real
        # token ids in a batch.
        sample = next(iter(tokenized_data_text.values()), None)
        if isinstance(sample, dict) and 'input_ids' in sample:
            self.text_empty_vector = torch.zeros_like(
                sample['input_ids'].squeeze(0)
            )
        else:
            self.text_empty_vector = torch.zeros(max_length, dtype=torch.long)
        self.img_empty_vector = torch.zeros(768)

    def get_img(self, key):
        """Return the image embedding for ``key`` or the zero image vector."""
        return self.tokenized_data_img.get(key, self.img_empty_vector)

    def get_text(self, key):
        """Return the tokenized text for ``key`` or an all-zero placeholder.

        The placeholder's attention mask is all zeros, i.e. every position is
        marked as padding.
        """
        return self.tokenized_data_text.get(key, {
            'input_ids': self.text_empty_vector,
            'attention_mask': self.text_empty_vector
        })

    def __getitem__(self, idx):
        item_id = self.unqiue_ids[idx]
        img_emb = self.get_img(item_id)
        text_emb = self.get_text(item_id)

        # squeeze(0) drops a leading batch dimension of size 1, if any.
        return {
            'item_id':          item_id,
            'input_ids':        text_emb['input_ids'].squeeze(0),
            'attention_mask':   text_emb['attention_mask'].squeeze(0),
            'encoded_image':    img_emb.squeeze(0),
        }

    def __len__(self):
        return len(self.unqiue_ids)

In [42]:
# Every id that occurs anywhere in a triplet, without duplicates.
unqiue_ids = {member for triplet in triplets for member in triplet}
dataset = ProvidedDataset(unqiue_ids, tokenized_data_text, tokenized_data_img)
# Order does not matter for extraction, so no shuffling.
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=False)

In [39]:
# Freeze the pair encoder: requires_grad_(False) switches off gradient
# tracking for every parameter of the module.
encoder.requires_grad_(False)

In [None]:
# item id -> fused embedding (stored on the CPU)
embedding_dict = {}

encoder.eval()
with torch.no_grad():
    for batch in dataloader:
        with amp.autocast(device_type='cuda', dtype=torch.float16):
            embs = get_embeddings(
                model_lora, attn_pool, encoder,
                batch['input_ids'].to(device),
                batch['attention_mask'].to(device),
                batch['encoded_image'].to(device)
            )

        # The default collate function turns numeric ids into a tensor.
        # Tensor elements hash by identity, so they must be unwrapped to plain
        # Python values before being used as dict keys. String ids arrive as
        # a list and are used unchanged.
        item_ids = batch['item_id']
        if torch.is_tensor(item_ids):
            item_ids = item_ids.tolist()

        for item_id, emb in zip(item_ids, embs):
            embedding_dict[item_id] = emb.cpu()


In [None]:
# Step back up to the parent directory.
os.chdir(os.pardir)

In [None]:
# Persist the id -> embedding mapping.
torch.save(embedding_dict, f="train_merged_embed.pt")