In [None]:
!pip install -U transformers

In [None]:
import json
import torch
import pandas as pd
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [None]:
MODEL_NAME = 'FacebookAI/xlm-roberta-base'   # encoder whose first-token embedding feeds the heads
DECODER_MODEL_NAME = 'Qwen/Qwen2.5-0.5B'     # in this notebook only its tokenizer is used (token counts)
BATCH_SIZE = 8
NUM_EPOCHS = 20
# Size of the predicted e/m vectors; presumably must match the dimensionality of
# the target vectors in the dataset — confirm against the data.
OUTPUT_DIM = 896
RANDOM_STATE = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
HYPERPARAMS = {
    'lr': 0.0005,
    'weight_decay': 0.01,
    # NOTE(review): beta2=0.9 is far below AdamW's default 0.999 — confirm it is intentional.
    'betas': (0.9, 0.9),
}

# Seed torch's global RNG (CPU and CUDA) so linear-head initialisation and
# DataLoader shuffling are reproducible. RANDOM_STATE is also passed to the
# train/test split.
torch.manual_seed(RANDOM_STATE)

In [None]:
DATA_PATH = '/kaggle/input/diploma-two-vectors/training_results.json'

# JSON is UTF-8 by specification; be explicit instead of relying on the
# platform's locale encoding (matters for non-ASCII instruction text).
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    data = pd.DataFrame(json.load(f))
data.head(5)

In [None]:
class Model(torch.nn.Module):
    """Pretrained encoder with four linear heads on the first-token embedding.

    The heads predict two vectors (``e`` and ``m``) of size ``output_dim`` and
    two scalars (``mu`` and ``std``) per example.

    Args:
        model_name: Model id passed to ``AutoModel.from_pretrained``.
        output_dim: Output size of the ``e`` and ``m`` projections.
        freeze_bert: If True, encoder parameters get ``requires_grad=False`` so
            only the linear heads receive gradients.
    """

    def __init__(self, model_name, output_dim, freeze_bert=True):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.e_proj = torch.nn.Linear(hidden_size, output_dim)
        self.m_proj = torch.nn.Linear(hidden_size, output_dim)
        self.mu = torch.nn.Linear(hidden_size, 1)
        # Despite its name, this head's output is used as a log-variance by
        # gaussian_loss (which exponentiates it). The attribute name is kept so
        # existing state_dict keys stay valid.
        self.std = torch.nn.Linear(hidden_size, 1)
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch and apply the heads.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq)``.
            attention_mask: Optional mask of the same shape (1 = real token).

        Returns:
            Tuple ``(e, m, mu, std)`` with shapes ``(batch, output_dim)``,
            ``(batch, output_dim)``, ``(batch, 1)`` and ``(batch, 1)``.
        """
        # Keyword arguments avoid depending on the positional order of the
        # encoder's forward() signature.
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden layer (the start special token when the
        # tokenizer adds one) serves as the sequence summary.
        cls = out.last_hidden_state[:, 0, :]
        return self.e_proj(cls), self.m_proj(cls), self.mu(cls), self.std(cls)

class TextDataset(Dataset):
    """Map-style dataset that pairs ``texts[i]`` with ``vectors[i]``.

    Args:
        texts: Indexable sequence of input strings.
        vectors: Indexable sequence of targets aligned with ``texts``.
    """

    def __init__(self, texts, vectors):
        self.texts, self.vectors = texts, vectors

    def __len__(self):
        # The size comes from ``texts``; ``vectors`` is assumed to be aligned.
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        vector = self.vectors[idx]
        return text, vector

def collate_fn(batch, tokenizer, decoder_tokenizer, device, max_length=None):
    """Turn a list of ``(text, vectors)`` items into padded tensors on ``device``.

    Args:
        batch: Sequence of ``(text, vectors)`` pairs, as yielded by ``TextDataset``.
        tokenizer: Encoder tokenizer; must define a pad token.
        decoder_tokenizer: Tokenizer whose per-text token count (without
            truncation) becomes the ``lengths`` target.
        device: Device every returned tensor is moved to.
        max_length: Encoder truncation length. ``None`` lets the tokenizer fall
            back to its model maximum, so over-long texts no longer overflow
            the encoder's position embeddings.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (padded to the longest
        text in the batch), ``vectors`` (stacked targets) and ``lengths``
        (one token count per text).
    """
    texts = [text for text, _ in batch]
    # Batched tokenisation pads and truncates in one call; the returned mask
    # marks real tokens with 1 and padding with 0.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    vectors = torch.stack([torch.tensor(vec) for _, vec in batch])
    # Target lengths are counted on the full (untruncated) text.
    lengths = torch.tensor([len(ids) for ids in decoder_tokenizer(texts)['input_ids']])
    # NOTE: moving to `device` here is only safe in the main process — keep
    # DataLoader num_workers=0 when `device` is CUDA.
    return {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].long().to(device),
        'vectors': vectors.to(device),
        'lengths': lengths.to(device),
    }

def info_nce_loss(e_pred, m_pred, e_target, m_target, temperature=0.07):
    """One-directional InfoNCE loss averaged over the ``e`` and ``m`` pairs.

    For each pair, predictions and targets are L2-normalised along the last
    dimension. A prediction-vs-target cosine-similarity matrix is scaled by
    ``1 / temperature``. Cross-entropy then treats column ``i`` as the
    positive for row ``i``; every other target in the batch is a negative.
    Assumes 2-D ``(batch, dim)`` inputs.

    Returns:
        Scalar tensor ``0.5 * (loss_e + loss_m)``.
    """
    def contrastive(pred, target):
        logits = F.normalize(pred, dim=-1) @ F.normalize(target, dim=-1).T / temperature
        labels = torch.arange(pred.size(0), device=pred.device)
        return F.cross_entropy(logits, labels)

    return 0.5 * (contrastive(e_pred, e_target) + contrastive(m_pred, m_target))

def gaussian_loss(target, mu, std):
    """Mean Gaussian negative log-likelihood, constant term dropped.

    Args:
        target: Observed values (e.g. a ``(batch,)`` tensor of token counts);
            integer dtypes are cast to ``mu``'s dtype.
        mu: Predicted means (e.g. ``(batch, 1)`` from a linear head).
        std: Predicted **log-variance** (not a standard deviation), with the
            same number of elements as ``mu``.

    Returns:
        Scalar tensor ``mean(0.5 * (log_var + (target - mu)**2 / exp(log_var)))``.

    Raises:
        RuntimeError: if ``target`` or ``std`` cannot be reshaped to ``mu``'s
            shape (different number of elements).
    """
    # Align shapes explicitly. A (batch,) target against a (batch, 1) prediction
    # would otherwise broadcast to (batch, batch) and silently compare every
    # target with every prediction.
    target = target.to(mu.dtype).reshape_as(mu)
    log_var = std.reshape_as(mu)
    return (0.5 * (log_var + (target - mu) ** 2 / torch.exp(log_var))).mean()

In [None]:
# Encoder tokenizer: produces the model inputs.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Decoder tokenizer: only used to count tokens for the length target.
decoder_tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_NAME)
# Encoder weights are frozen, so only the linear heads are trained.
model = Model(MODEL_NAME, OUTPUT_DIM, freeze_bert=True).to(DEVICE)

In [None]:
# Inputs are the instruction texts; targets are the per-row vectors.
X = data['instruction'].to_list()
y = data['best_vectors'].to_list()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=RANDOM_STATE
)

train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)


def collate_batch(batch):
    """Collate with this notebook's tokenizers and target device."""
    return collate_fn(batch, tokenizer, decoder_tokenizer, DEVICE)


train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

In [None]:
# Optimise only parameters that can receive gradients (the heads, when the
# encoder is frozen).
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=HYPERPARAMS['lr'],
    betas=HYPERPARAMS['betas'],
    weight_decay=HYPERPARAMS['weight_decay'],
)


def compute_batch_loss(batch):
    """Contrastive (e, m) loss plus length NLL for one collated batch."""
    # vectors is indexed as (batch, slot, dim): slot 0 is the e target, slot 1 the m target.
    e_target = batch['vectors'][:, 0, :]
    m_target = batch['vectors'][:, 1, :]
    e_pred, m_pred, mu_pred, std_pred = model(batch['input_ids'], batch['attention_mask'])
    # Align the (batch,) lengths with the (batch, 1) predictions so the loss
    # compares each example only with its own prediction (no broadcasting).
    lengths = batch['lengths'].reshape(mu_pred.shape)
    return info_nce_loss(e_pred, m_pred, e_target, m_target) + gaussian_loss(lengths, mu_pred, std_pred)


for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f'epoch {epoch + 1} train', leave=False):
        optimizer.zero_grad()
        loss = compute_batch_loss(batch)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # eval() disables dropout so the validation loss is deterministic.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in test_dataloader:
            val_loss += compute_batch_loss(batch).item()

    # Report per-batch means so train and validation losses are comparable
    # despite the splits having different numbers of batches.
    train_loss /= max(len(train_dataloader), 1)
    val_loss /= max(len(test_dataloader), 1)
    print(f'Epoch: {epoch + 1}; Train Loss: {train_loss:.4f}; Eval Loss: {val_loss:.4f}')