In [None]:
# Install the third-party packages this notebook imports into the kernel's environment.
# NOTE(review): versions are unpinned, so a fresh install may resolve to different releases — consider pinning.
%pip install scikit-learn sentence-transformers torch datasets wandb einops

In [None]:
import wandb
import tqdm as notebook_tqdm
from sentence_transformers import SentenceTransformer, losses, InputExample, util
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import torch
from datasets import Dataset
import os
import pandas as pd

In [None]:
# Load the triplet dataset into a DataFrame; later cells read its 'anchor', 'positive' and 'negative' columns.
# NOTE(review): "---" looks like a redacted/placeholder path — point it at the real JSON file before running.
data = pd.read_json("---")

In [None]:
def load_and_split(data, validation_size=0.2, random_state=42):
    """
    Splits already-loaded triplet data into training and validation sets.

    Args:
        data (pd.DataFrame or iterable of pd.DataFrame): Triplet data with
            'anchor', 'positive' and 'negative' columns. When several frames
            are given (e.g. one per JSON file) they are concatenated before
            splitting.
        validation_size (float): Proportion of data to use for validation.
        random_state (int): Random state for reproducible splitting.

    Returns:
        tuple: (train_triplets, val_triplets), where each element is an
        (anchors, positives, negatives) tuple of equally long lists.

    Raises:
        KeyError: If any of the required columns is missing.
    """
    required_columns = ('anchor', 'positive', 'negative')

    # Accept a single frame (the original call style) or any iterable of frames.
    frames = [data] if isinstance(data, pd.DataFrame) else list(data)
    combined_df = pd.concat(frames, ignore_index=True)

    # Fail before splitting so the error names the missing column(s) directly.
    missing = [column for column in required_columns if column not in combined_df.columns]
    if missing:
        raise KeyError(f"Triplet data is missing required column(s): {missing}")

    train_df, val_df = train_test_split(combined_df, test_size=validation_size, random_state=random_state)

    def create_triplets_from_dataframe(df):
        # One plain Python list per column, in (anchor, positive, negative) order.
        return tuple(df[column].tolist() for column in required_columns)

    train_triplets = create_triplets_from_dataframe(train_df)
    val_triplets = create_triplets_from_dataframe(val_df)

    return train_triplets, val_triplets

# Split once; the resulting names are reused by the training and evaluation cells below.
train_triplets, val_triplets = load_and_split(data)

# Each split is a tuple of three parallel lists.
(train_anchors, train_positives, train_negatives) = train_triplets
(val_anchors, val_positives, val_negatives) = val_triplets

# Sanity-check the split sizes.
print(f"Number of training triplets: {len(train_anchors)}")
print(f"Number of validation triplets: {len(val_anchors)}")

# Peek at the first training triplet, one text per line.
print(train_anchors[0], train_positives[0], train_negatives[0], sep="\n")

In [None]:
# Authenticate this session with Weights & Biases.
# NOTE(review): no visible cell calls wandb again — presumably the training API reports to it; confirm.
wandb.login()

In [None]:
def evaluate_triplets(model, anchors, positives, negatives, batch_size=32):
    """Evaluates triplet data. Prints the mean cosine similarity of the positive and negative pairs.

    Args:
        model: Object exposing ``encode(texts, batch_size=..., convert_to_tensor=True)``
            that returns one embedding row per input text (e.g. a SentenceTransformer).
        anchors (iterable of str): Anchor texts.
        positives (iterable of str): Texts expected to be close to the matching anchor.
        negatives (iterable of str): Texts expected to be far from the matching anchor.
        batch_size (int): Number of texts encoded per forward pass.

    Returns:
        None. Results are printed; if there are no triplets a notice is
        printed instead of dividing by zero.
    """
    # zip() keeps the truncate-to-shortest behaviour and accepts any iterable.
    triplets = list(zip(anchors, positives, negatives))
    if not triplets:
        print("No triplets to evaluate.")
        return

    anchor_texts, positive_texts, negative_texts = (list(column) for column in zip(*triplets))

    def embed(texts):
        # Encode a whole column at once instead of one text per call.
        return model.encode(texts, batch_size=batch_size, convert_to_tensor=True)

    anchor_emb = embed(anchor_texts)
    positive_emb = embed(positive_texts)
    negative_emb = embed(negative_texts)

    # Row-wise cosine similarity: entry i compares anchor i with its own positive/negative.
    positive_similarities = torch.nn.functional.cosine_similarity(anchor_emb, positive_emb, dim=1).cpu().tolist()
    negative_similarities = torch.nn.functional.cosine_similarity(anchor_emb, negative_emb, dim=1).cpu().tolist()

    print(f"Mean positive similarity: {sum(positive_similarities) / len(positive_similarities)}")
    print(f"Mean negative similarity: {sum(negative_similarities) / len(negative_similarities)}")

def _latest_checkpoint_epoch(save_path, prefix='checkpoint_epoch_'):
    """Return the highest N among `<prefix>N` directories inside save_path, or 0 if there are none."""
    if not os.path.isdir(save_path):
        return 0
    epoch_numbers = [
        int(name[len(prefix):])
        for name in os.listdir(save_path)
        if name.startswith(prefix)
        and name[len(prefix):].isdecimal()  # skip stray names such as 'checkpoint_epoch_1.tmp'
        and os.path.isdir(os.path.join(save_path, name))
    ]
    return max(epoch_numbers, default=0)


def finetune_triplet_model(model_name, train_triplets, val_triplets=None,
                           epochs=1, batch_size=16, save_path='fine-tuned'):
    """
    Fine-tunes a sentence-embedding model with a triplet loss, checkpointing after every epoch.

    If `save_path` already contains `checkpoint_epoch_<N>` directories, training
    resumes from the highest-numbered one instead of starting from `model_name`.

    Args:
        model_name (str): Name of the pre-trained model.
        train_triplets (tuple): Tuple of (anchors, positives, negatives) for training.
        val_triplets (tuple, optional): Tuple of (anchors, positives, negatives) for validation.
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for training.
        save_path (str): Path to save the fine-tuned model.

    Returns:
        None. Checkpoints are written to `<save_path>/checkpoint_epoch_<N>`, the
        final model to `save_path`, and validation similarities are printed when
        `val_triplets` is given.
    """

    # Determine the starting epoch from any checkpoints left by an earlier run
    start_epoch = _latest_checkpoint_epoch(save_path)
    if start_epoch > 0:
        # Load the checkpoint directory itself: the root of save_path only holds
        # model files after a run has finished, so an interrupted run could not
        # be resumed from it.
        model_source = os.path.join(save_path, f'checkpoint_epoch_{start_epoch}')
        print(f"Resuming training from epoch {start_epoch + 1}")
    else:
        model_source = model_name

    model = SentenceTransformer(model_source)
    # Only force CUDA when it exists; otherwise keep whatever device the model was loaded on.
    if torch.cuda.is_available():
        model.to('cuda')
    train_anchors, train_positives, train_negatives = train_triplets

    train_examples = [
        InputExample(texts=list(triplet))
        for triplet in zip(train_anchors, train_positives, train_negatives)
    ]

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.TripletLoss(model=model)

    for epoch in range(start_epoch, epochs):
        # NOTE(review): fit() is invoked once per epoch, so warm-up (and presumably
        # optimizer/scheduler state) restarts every epoch — confirm this is intended.
        model.fit(
            train_objectives=[(train_dataloader, train_loss)],
            epochs=1,  # Train for only 1 epoch at a time
            warmup_steps=len(train_dataloader) // 10,  # 10% of data for warm-up
        )

        # Save checkpoint after each epoch
        checkpoint_path = os.path.join(save_path, f'checkpoint_epoch_{epoch + 1}')
        model.save(checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    # Save the final model
    model.save(save_path)
    print(f"Fine-tuned model saved to {save_path}")

    if val_triplets:
        val_anchors, val_positives, val_negatives = val_triplets
        evaluate_triplets(model, val_anchors, val_positives, val_negatives)

In [None]:
# Fine-tune all-MiniLM-L6-v2 on the training triplets for one epoch with batch size 4.
# Per-epoch checkpoints and the final model are written under 'miniLM_triplets';
# the validation triplets are scored once training finishes.
finetune_triplet_model('sentence-transformers/all-MiniLM-L6-v2', train_triplets, val_triplets, epochs=1, batch_size=4, save_path='miniLM_triplets')

In [None]:
# Score the model loaded from 'mpnet_triplets' on the validation triplets.
# NOTE(review): no visible cell writes 'mpnet_triplets' (the training call saves to 'miniLM_triplets');
# presumably it is a local directory from an earlier run — confirm it exists before Run All.
evaluate_triplets(SentenceTransformer('mpnet_triplets'), val_anchors, val_positives, val_negatives)

In [None]:
# NOTE(review): this cell is an exact repeat of the previous evaluation cell and reloads the same
# 'mpnet_triplets' model, which no visible cell produces — likely a leftover; confirm and remove or retarget.
evaluate_triplets(SentenceTransformer('mpnet_triplets'), val_anchors, val_positives, val_negatives)