In [None]:
!pip install -U transformers

In [None]:
import json
import torch
import pandas as pd
from transformers import BertModel, BertTokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
# --- Run configuration -------------------------------------------------
MODEL_NAME = 'bert-base-uncased'   # checkpoint used for both model and tokenizer
BATCH_SIZE = 8                     # samples per optimisation step
NUM_EPOCHS = 50                    # full passes over the dataset
OUTPUT_DIM = 896                   # width of EACH of the two predicted vectors

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Keyword arguments forwarded to the AdamW optimizer.
# NOTE(review): lr=0.01 and beta2=0.9 are far from the values usually used
# when fine-tuning BERT (lr around 2e-5..5e-5, beta2=0.999) - confirm intended.
HYPERPARAMS = {
    'lr': 0.01,
    'weight_decay': 0.01,
    'betas': (0.9, 0.9),
}

In [None]:
# Read the training records from the Kaggle input dataset into a DataFrame.
# JSON is UTF-8 by specification, so the encoding is stated explicitly instead
# of relying on the platform's default locale encoding.
with open('/kaggle/input/diploma-two-vectors/training_results.json', 'r', encoding='utf-8') as f:
    data = pd.DataFrame(json.load(f))
# Show the first rows as a quick sanity check of what was loaded.
data.head(5)

In [None]:
class Model(torch.nn.Module):
    """BERT encoder followed by a linear head that predicts two vectors.

    The head maps the hidden state at sequence position 0 to
    ``2 * output_dim`` values; ``forward`` splits them into two halves.

    Args:
        model_name: identifier handed to ``BertModel.from_pretrained``.
        output_dim: width of each of the two predicted vectors.
        freeze_bert: when true, gradients are disabled for every BERT
            parameter whose name contains none of ``layer.9``, ``layer.10``
            or ``layer.11``.
    """

    def __init__(self, model_name, output_dim, freeze_bert=True):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # One linear layer emits both vectors side by side.
        self.regressor = torch.nn.Linear(self.bert.config.hidden_size, output_dim * 2)
        if not freeze_bert:
            return
        # Parameters whose name carries one of these fragments stay trainable.
        trainable_markers = ('layer.9', 'layer.10', 'layer.11')
        for param_name, weight in self.bert.named_parameters():
            if any(marker in param_name for marker in trainable_markers):
                continue
            weight.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch and return the two predicted vectors as a pair.

        Args:
            input_ids: token ids forwarded to the BERT encoder.
            attention_mask: optional mask forwarded to the BERT encoder.

        Returns:
            Tuple ``(e, m)``: the first and second half (along dim 1) of the
            regressor output.
        """
        encoder_out = self.bert(input_ids, attention_mask)
        # Use the hidden state of the first token as the sequence summary.
        first_token_state = encoder_out.last_hidden_state[:, 0, :]
        stacked = self.regressor(first_token_state)
        half_width = stacked.size(1) // 2
        first_half, second_half = torch.split(stacked, half_width, dim=1)
        return first_half, second_half

class TextDataset(Dataset):
    """Map-style dataset that pairs each text with its target vectors.

    Args:
        texts: indexable, sized collection of input texts.
        vectors: indexable, sized collection of targets; ``vectors[i]``
            belongs to ``texts[i]``.

    Raises:
        ValueError: if ``texts`` and ``vectors`` have different lengths.
    """

    def __init__(self, texts, vectors):
        # Fail early: a mismatch would otherwise either silently drop the
        # surplus vectors or raise IndexError in the middle of an epoch.
        if len(texts) != len(vectors):
            raise ValueError(
                f'texts and vectors must have the same length, '
                f'got {len(texts)} and {len(vectors)}'
            )
        self.texts = texts
        self.vectors = vectors

    def __len__(self):
        """Return the number of (text, vectors) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``(text, vectors)`` pair stored at position ``idx``."""
        return self.texts[idx], self.vectors[idx]

def collate_fn(batch, tokenizer, device):
    """Turn a list of ``(text, vectors)`` samples into model-ready tensors.

    Args:
        batch: list of ``(text, vectors)`` pairs as produced by the dataset.
        tokenizer: callable tokenizer that accepts a list of texts together
            with ``padding``/``truncation``/``return_tensors`` and returns a
            mapping with ``input_ids`` and ``attention_mask``.
        device: torch device every returned tensor is moved to.

    Returns:
        dict with keys ``input_ids``, ``attention_mask`` and ``vectors``;
        ``vectors`` stacks the per-sample targets along a new first
        dimension as float32.
    """
    texts = [item[0] for item in batch]

    # Tokenize the whole batch in one call. The tokenizer pads to the longest
    # sequence and supplies the matching attention mask, so the mask no longer
    # has to be guessed from the pad id. truncation=True caps sequences at the
    # tokenizer's maximum length; without it an over-long text makes the
    # encoder fail on its position embeddings.
    encoded = tokenizer(
        texts,
        add_special_tokens=True,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Force float32 so the targets always match the dtype of the predictions,
    # whatever numeric type the raw data happens to use.
    vectors = torch.stack(
        [torch.as_tensor(item[1], dtype=torch.float32) for item in batch]
    ).to(device)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'vectors': vectors
    }

In [None]:
# Instantiate everything the training loop needs: tokenizer, model (with all
# BERT weights left trainable), dataset and a shuffled batch loader.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = Model(MODEL_NAME, OUTPUT_DIM, freeze_bert=False).to(DEVICE)

text_dataset = TextDataset(
    list(data['instruction']),
    list(data['best_vectors']),
)

# The lambda binds the tokenizer and device so the loader can call the
# collate function with the batch alone.
text_dataloader = DataLoader(
    text_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda samples: collate_fn(samples, tokenizer, DEVICE),
)

In [None]:
# Fine-tune the model with AdamW, minimising the sum of the MSE losses of the
# two predicted vectors against their targets.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=HYPERPARAMS['lr'],
    betas=HYPERPARAMS['betas'],
    weight_decay=HYPERPARAMS['weight_decay'],
)
loss_fn = torch.nn.MSELoss()

# from_pretrained() hands back the encoder in eval mode (dropout disabled);
# switch the whole model to training mode before optimising.
model.train()

# Normalise the epoch loss by the real dataset size rather than a hard-coded
# sample count; max() guards against division by zero on an empty dataset.
num_samples = max(len(text_dataset), 1)

for i in range(NUM_EPOCHS):
    total_loss = 0.0
    for batch in text_dataloader:
        optimizer.zero_grad()

        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # Targets hold both vectors per sample, indexed along dim 1.
        e_target = batch['vectors'][:, 0, :]
        m_target = batch['vectors'][:, 1, :]

        e_pred, m_pred = model(input_ids, attention_mask)
        loss = loss_fn(e_pred, e_target) + loss_fn(m_pred, m_target)
        # Weight by the actual batch size: the last batch of an epoch can be
        # smaller than BATCH_SIZE.
        total_loss += loss.item() * input_ids.size(0)

        loss.backward()
        optimizer.step()

    print(f'Epoch: {i + 1}; Loss: {total_loss / num_samples}')