In [None]:
# All imports for the notebook, deduplicated and grouped stdlib -> third-party.
# (The previous `import torcg` was a typo for `torch`, already imported below, and made this cell fail.)
import gc
import logging
import os
import pickle
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from IPython.display import clear_output
from matplotlib import pyplot as plt
from sentence_transformers import (
    InputExample,
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    evaluation,
    losses,
)
from torch.utils.data import DataLoader
from tqdm import tqdm

# Disable the HuggingFace `tokenizers` library's internal parallelism for this process.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

In [None]:
def _load_pickle(path):
    """Return the object pickled at ``path``, closing the file handle afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code - only point this at
    trusted, locally generated files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _model_save_root(params):
    """Return the run's output directory: ./finetuned_model_<HARD_PATH file stem>_<MARGINE>_<ALL_DATA>."""
    data_prefix = params['HARD_PATH'].split('/')[-1].split('.')[0]
    # Double quotes inside the f-string: reusing the outer quote type is a SyntaxError before Python 3.12.
    return f"./finetuned_model_{data_prefix}_{params['MARGINE']}_{params['ALL_DATA']}"


def _tokenize_triplets(model, batch):
    """Tokenize a collated batch into the [anchor, positive, negative] feature dicts passed to the loss.

    ``batch`` is what the default DataLoader collate builds from the HF Dataset rows -
    presumably one list of strings per column (confirm if the dataset schema changes).
    Only ``input_ids`` and ``attention_mask`` are kept, moved to the model's device.
    """
    features = []
    for column in ('anchor', 'positive', 'negative'):
        tokenized = model.tokenize(batch[column])
        features.append({
            key: tokenized[key].to(model.device)
            for key in ('input_ids', 'attention_mask')
        })
    return features


def run_with_params(params: dict) -> None:
    """Fine-tune a SentenceTransformer with TripletLoss for one hyper-parameter configuration.

    Required keys of ``params``:
        SOFT_PATH, HARD_PATH -- pickled sequences of (anchor, positive, negative) triples
                                (concatenated with ``+``, so presumably lists - confirm).
        ALL_DATA             -- if true train on hard + soft negatives, otherwise hard only.
        MARGINE              -- triplet margin passed to ``losses.TripletLoss``.
    Optional keys (defaults reproduce the previous hard-coded behaviour):
        NUM_EPOCHS=1, MODEL_PATH='e5_large/', DEVICE='cuda:0', BATCH_SIZE=8,
        LR=1e-5, LOG_INTERVAL=10, SEED=None (None = unseeded shuffling).

    Side effects: plots the running loss every LOG_INTERVAL steps and saves a checkpoint
    after each epoch to ``<save root>/epoch_<n>``. GPU memory held by the model is
    released before returning, also when training raises.
    """
    # --- data (loaded before the model so a bad path fails before GPU allocation) ---
    soft_negatives = _load_pickle(params['SOFT_PATH'])
    hard_negatives = _load_pickle(params['HARD_PATH'])
    if params['ALL_DATA']:
        train_data = hard_negatives + soft_negatives
    else:
        train_data = hard_negatives
    train_dataset = Dataset.from_dict({
        'anchor': [triplet[0] for triplet in train_data],
        'positive': [triplet[1] for triplet in train_data],
        'negative': [triplet[2] for triplet in train_data],
    })

    # shuffle=True draws a fresh permutation each epoch, so no separate pre-shuffle of
    # train_data is needed; a seeded generator makes that order reproducible.
    seed = params.get('SEED')
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=params.get('BATCH_SIZE', 8),
        shuffle=True,
        generator=generator,
    )

    # --- model, loss, optimizer ---
    model = SentenceTransformer(
        params.get('MODEL_PATH', 'e5_large/'),
        device=params.get('DEVICE', 'cuda:0'),
    )
    train_loss = losses.TripletLoss(
        model=model,
        # TripletDistanceMetric is the enum TripletLoss is written against. Per the
        # sentence-transformers source, COSINE is 1 - cosine similarity, the same formula
        # as the SiameseDistanceMetric.COSINE_DISTANCE previously passed here.
        distance_metric=losses.TripletDistanceMetric.COSINE,
        triplet_margin=params['MARGINE'],
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=params.get('LR', 1e-5))

    num_epochs = params.get('NUM_EPOCHS', 1)
    log_interval = params.get('LOG_INTERVAL', 10)

    save_root = _model_save_root(params)
    os.makedirs(save_root, exist_ok=True)

    logged_losses = []  # one entry per logging window: mean loss over its steps
    try:
        for epoch in range(num_epochs):
            model.train()
            window = []
            for step, batch in enumerate(train_dataloader):
                loss_value = train_loss(_tokenize_triplets(model, batch), None)
                loss_value.backward()
                optimizer.step()
                optimizer.zero_grad()
                current_loss = loss_value.item()
                del loss_value  # drop the autograd graph before the next forward pass
                window.append(current_loss)
                if step % log_interval == log_interval - 1:
                    # Mean over the whole window (previously only the last step's loss was logged).
                    logged_losses.append(float(np.mean(window)))
                    window = []
                    clear_output()
                    print('at_epoch:', epoch, '\nat_step:', step, '/', len(train_dataloader),
                          '\nloss:', current_loss, '\nwindow_mean_loss:', logged_losses[-1])
                    plt.plot(logged_losses)
                    plt.xlabel(f'logging window ({log_interval} steps)')
                    plt.ylabel('mean triplet loss')
                    plt.show()
            epoch_dir = os.path.join(save_root, f'epoch_{epoch}')
            os.makedirs(epoch_dir, exist_ok=True)
            model.save(epoch_dir)
    finally:
        # The loss module and the optimizer also reference the model's parameters, so all
        # three must be released before the CUDA cache can actually be emptied.
        del model, train_loss, optimizer
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Hyper-parameter grid for run_with_params: each triplet margin is paired with and
# without the soft negatives (ALL_DATA True / False), margin-major order.
# NOTE: the key is spelled 'MARGINE' because run_with_params reads it under that name.
SOFT_NEGATIVES_PATH = 'data/soft_negatives_single_83_001_1_75.pkl'
HARD_NEGATIVES_PATH = 'data/hard_negatives_single_83_001_1_75.pkl'
MARGINS = (0.3, 0.7, 0.9)

params = [
    {
        "SOFT_PATH": SOFT_NEGATIVES_PATH,
        "HARD_PATH": HARD_NEGATIVES_PATH,
        "ALL_DATA": all_data,
        "MARGINE": margin,
        "NUM_EPOCHS": 1,
    }
    for margin in MARGINS
    for all_data in (True, False)
]