In [None]:
# Install the notebook's dependencies into the kernel's environment.
# NOTE(review): versions are unpinned — consider pinning them for reproducible runs.
%pip install scikit-learn sentence-transformers torch datasets wandb einops

In [None]:
import wandb
import tqdm as notebook_tqdm
from sentence_transformers import SentenceTransformer, losses, InputExample, util
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch
from datasets import Dataset
import os
import pandas as pd

In [None]:
# Load the records that are later read through their 'anchor' and 'positive' columns.
# NOTE(review): "---" looks like a placeholder path — TODO point it at the real JSON file.
data = pd.read_json("---")

In [None]:
def load_and_split(data, validation_size=0.2, random_state=42):
    """
    Splits anchor-positive pairs into training and validation sets.

    Args:
        data (list | pd.DataFrame): Records with 'anchor' and 'positive' keys.
            Anything accepted by ``pd.DataFrame(data)`` works, including an
            existing DataFrame.
        validation_size (float): Proportion of data for validation.
        random_state (int): Random state for reproducible splitting.

    Returns:
        tuple: (train_pairs, val_pairs). Each element is itself a tuple
        ``(anchors, positives)`` of two lists aligned by position.

    Raises:
        ValueError: If data contains no rows.
        KeyError: If the 'anchor' or 'positive' column is missing.
    """
    df = pd.DataFrame(data)  # Accepts a list of dicts or an existing DataFrame

    # Validate before splitting so a bad input fails with a message about the
    # data rather than an error raised from inside the split or column lookup.
    if len(df) == 0:
        raise ValueError("data contains no rows to split")
    missing = [column for column in ('anchor', 'positive') if column not in df.columns]
    if missing:
        raise KeyError(f"data is missing required column(s): {missing}")

    train_df, val_df = train_test_split(df, test_size=validation_size, random_state=random_state)

    def create_pairs_from_dataframe(frame):
        """Return the (anchors, positives) columns of frame as two lists."""
        return frame['anchor'].tolist(), frame['positive'].tolist()

    return create_pairs_from_dataframe(train_df), create_pairs_from_dataframe(val_df)

# Split the loaded records into training and validation pairs.
train_pairs, val_pairs = load_and_split(data)  # 'data' comes from the pd.read_json cell

# Each *_pairs value is an (anchors, positives) tuple of lists.
train_anchors, train_positives = train_pairs
val_anchors, val_positives = val_pairs

print(f"Number of training pairs: {len(train_anchors)}")
print(f"Number of validation pairs: {len(val_anchors)}")

# Show the first training pair as a quick sanity check.
print(train_anchors[0])
print(train_positives[0])

In [None]:
# Log in to Weights & Biases.
# NOTE(review): wandb is not referenced by the visible training code — confirm this login is needed.
wandb.login()

In [None]:
def evaluate_pairs(model, anchors, positives):
    """Evaluates anchor-positive pairs. Prints and returns the mean cosine similarity.

    Args:
        model: Object exposing ``encode(texts, convert_to_tensor=True)`` that
            returns one embedding row per text (a SentenceTransformer in this
            notebook).
        anchors: Iterable of anchor texts.
        positives: Iterable of positive texts, aligned by position with anchors.

    Returns:
        float | None: Mean cosine similarity over the pairs, or None when
        there are no pairs to evaluate.
    """
    # zip() keeps the original pairing rule: surplus items in the longer input are ignored.
    pairs = list(zip(anchors, positives))
    if not pairs:
        # Avoid dividing by zero on an empty validation set.
        print("No pairs to evaluate.")
        return None

    anchor_texts = [anchor for anchor, _ in pairs]
    positive_texts = [positive for _, positive in pairs]

    # Encode each side in one batched call instead of two encode() calls per pair.
    anchor_embs = model.encode(anchor_texts, convert_to_tensor=True)
    positive_embs = model.encode(positive_texts, convert_to_tensor=True)

    # Row-wise cosine similarity: entry i compares anchor i with positive i.
    similarities = torch.nn.functional.cosine_similarity(anchor_embs, positive_embs, dim=1).tolist()

    mean_similarity = sum(similarities) / len(similarities)
    print(f"Mean positive similarity: {mean_similarity}")
    return mean_similarity

def _latest_checkpoint_epoch(save_path):
    """Return the highest N among 'checkpoint_epoch_N' entries in save_path, or 0 if there are none."""
    if not os.path.isdir(save_path):
        return 0
    prefix = 'checkpoint_epoch_'
    found = [
        int(name[len(prefix):])
        for name in os.listdir(save_path)
        # Skip entries whose suffix is not a plain number (e.g. 'checkpoint_epoch_backup').
        if name.startswith(prefix) and name[len(prefix):].isdecimal()
    ]
    return max(found, default=0)


def finetune_pair_model(model_name, train_pairs, val_pairs=None,
                        epochs=1, batch_size=16, save_path='fine-tuned'):
    """
    Fine-tunes a SentenceTransformer on anchor-positive pairs, one epoch per
    ``model.fit`` call, saving a checkpoint after every epoch.

    If ``save_path`` already contains ``checkpoint_epoch_N`` folders, training
    resumes from the highest N: that checkpoint is loaded and only the
    remaining epochs are run.

    Args:
        model_name (str): Name of the pre-trained model.
        train_pairs (tuple): Tuple of (anchors, positives) for training.
        val_pairs (tuple, optional): Tuple of (anchors, positives) for validation.
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for training.
        save_path (str): Path to save the fine-tuned model.
    """

    # Determine the starting epoch from any checkpoints left by an earlier run.
    start_epoch = _latest_checkpoint_epoch(save_path)

    if start_epoch > 0:
        print(f"Resuming training from epoch {start_epoch + 1}")
        # Load the checkpoint folder itself: the root of save_path is only
        # written after the final epoch, so it may not hold a model yet.
        model = SentenceTransformer(os.path.join(save_path, f'checkpoint_epoch_{start_epoch}'))
    else:
        model = SentenceTransformer(model_name)

    # Use the GPU when one is present instead of failing on CPU-only machines.
    model.to('cuda' if torch.cuda.is_available() else 'cpu')

    train_anchors, train_positives = train_pairs
    train_examples = [
        InputExample(texts=[anchor, positive], label=1.0)
        for anchor, positive in zip(train_anchors, train_positives)
    ]

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    # NOTE(review): every example is labelled 1.0, so this loss never sees a
    # dissimilar pair. Consider losses.MultipleNegativesRankingLoss for
    # anchor-positive data — confirm the intended objective.
    train_loss = losses.CosineSimilarityLoss(model=model)

    for epoch in range(start_epoch, epochs):
        # One fit() call per epoch so a checkpoint can be written in between;
        # warmup_steps is therefore passed again for every epoch.
        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=1,
            warmup_steps=len(train_dataloader) // 10,
        )

        checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch + 1}')
        model.save(checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    model.save(save_path)
    print(f"Fine-tuned model saved to {save_path}")

    if val_pairs:
        val_anchors, val_positives = val_pairs
        evaluate_pairs(model, val_anchors, val_positives)

In [None]:
# Fine-tune all-mpnet-base-v2 for one epoch; checkpoints and the final model are saved under 'mpnet_pair'.
finetune_pair_model('all-mpnet-base-v2', train_pairs, val_pairs, epochs=1, batch_size=6, save_path='mpnet_pair')