In [None]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torchmetrics.functional import pairwise_cosine_similarity
from datasets import Dataset
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

In [None]:
def sample_triples(train_data, label_column, id_column, topk=10):
    """Sample hard (anchor, positive, negative) id triples for triplet training.

    Every example labelled 1 is used as an anchor. For each anchor the most
    cosine-similar *other* examples labelled 1 become the triple positives and
    the most cosine-similar examples labelled 0 become the triple negatives;
    one triple is emitted for every positive/negative combination.

    Args:
        train_data: mapping-like object (e.g. a DataFrame) indexable by column
            name. Must provide an 'embeddings' column holding one 1-D torch
            tensor per row, plus the label and id columns named below.
        label_column: name of the column with binary labels; the set of labels
            must be exactly {0, 1}.
        id_column: name of the column with the data id of each row.
        topk: upper bound on the number of positives and on the number of
            negatives taken per anchor.

    Returns:
        List of (anchor_id, positive_id, negative_id) tuples, grouped by anchor
        in input order, with positives and negatives ordered from most to
        least similar to the anchor.

    Raises:
        AssertionError: if the labels are not exactly the set {0, 1}.
    """
    # get embeddings, labels and ids (in aligned order)
    embeddings = list(train_data['embeddings'])
    labels = list(train_data[label_column])
    ids = list(train_data[id_column])
    assert set(labels) == {0, 1}

    # store embeddings of positive and negative examples and their ids
    pos_vecs, neg_vecs, pos_data_ids, neg_data_ids = [], [], [], []
    for vec, label, data_id in zip(embeddings, labels, ids):
        if label == 1:
            pos_vecs.append(vec)
            pos_data_ids.append(data_id)
        elif label == 0:
            neg_vecs.append(vec)
            neg_data_ids.append(data_id)

    # cosine similarity = dot product of L2-normalised rows; normalise each
    # group once and reuse it for both similarity matrices
    pos_unit = torch.nn.functional.normalize(torch.stack(pos_vecs), dim=1)
    neg_unit = torch.nn.functional.normalize(torch.stack(neg_vecs), dim=1)
    posneg_sim = pos_unit @ neg_unit.T
    pospos_sim = pos_unit @ pos_unit.T

    # an anchor is always maximally similar to itself: mask the diagonal so the
    # self-match cannot occupy one of the top-k positive slots
    self_mask = torch.eye(len(pos_vecs), dtype=torch.bool, device=pospos_sim.device)
    pospos_sim = pospos_sim.masked_fill(self_mask, float('-inf'))

    # number of negatives per anchor, and number of positives per anchor
    # (there are only len(pos_vecs) - 1 candidates other than the anchor itself)
    sample_size = min([topk, len(pos_vecs), len(neg_vecs)])
    n_positives = min(sample_size, len(pos_vecs) - 1)

    # collect a set of triples for every positive embedding as anchor embedding
    triples = []
    if n_positives > 0:
        # topk returns indices sorted from most to least similar, row by row
        nearest_pos = torch.topk(pospos_sim, n_positives, dim=1).indices.tolist()
        nearest_neg = torch.topk(posneg_sim, sample_size, dim=1).indices.tolist()
        for anchor_id, pos_row, neg_row in zip(pos_data_ids, nearest_pos, nearest_neg):
            for pos_idx in pos_row:
                for neg_idx in neg_row:
                    triples.append((anchor_id, pos_data_ids[pos_idx], neg_data_ids[neg_idx]))

    # returns list of triples as tuples of data ids
    print('Resulting number of triples:', len(triples))
    return triples

In [None]:
def train_contrastive_model(train_data, model_dir, pretrained_model_name, batch_size=16, epochs=3, triplet_margin=1):
    """Fine-tune a pre-trained Sentence Transformer on triples with a triplet loss.

    Args:
        train_data: sequence of triples; element 0 of each triple is used as the
            anchor, element 1 as the positive and element 2 as the negative.
        model_dir: directory passed to the trainer as output_dir and used to
            save the model once training has finished.
        pretrained_model_name: name of the SentenceTransformer model to load.
        batch_size: per-device training batch size.
        epochs: number of training epochs.
        triplet_margin: margin handed to losses.TripletLoss.

    Returns:
        None. The trained model is written to model_dir.
    """
    # one dataset column per triple slot, in anchor / positive / negative order
    column_names = ("anchor", "positive", "negative")
    columns = {
        name: [triple[slot] for triple in train_data]
        for slot, name in enumerate(column_names)
    }
    train_dataset = Dataset.from_dict(columns)

    # keep the model on CPU: the mps device does not work here
    model = SentenceTransformer(pretrained_model_name).cpu()
    train_loss = losses.TripletLoss(model=model, triplet_margin=triplet_margin)

    args = SentenceTransformerTrainingArguments(
        output_dir=model_dir,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )
    trainer.train()
    model.save_pretrained(model_dir)

In [None]:
# info on datafile and pre-trained model
# semicolon-separated CSV file with the dataset
input_path = 'HateWiC_T5Defs_MajorityLabels.csv'
# column holding the data id of each row
id_column = 'id'
# column holding the text that is encoded and used to build training triples
sentence_column = 'T5generated_definition'
# column holding the binary label (values 0 and 1)
label_column = 'majority_binary_annotation'

# Sentence Transformer model used for encoding and as starting point for fine-tuning
pretrained_model_name = 'sentence-transformers/all-mpnet-base-v2'
# directory in which the fine-tuned model is stored
trained_model_dir = 'CL-model/'

In [None]:
# load data
data = pd.read_csv(input_path, sep=';')
model = SentenceTransformer(pretrained_model_name).cpu() # device='mps' gives error

print('Encoding sentences with Sentence Transformer...')
# encode() is documented for a str or a list of str and indexes its input by
# integer position, so pass a plain list instead of a pandas Series
data['embeddings'] = list(model.encode(data[sentence_column].tolist(), convert_to_tensor=True, show_progress_bar=True))

# 80% train, then the remaining 20% split evenly into dev and test (fixed seed)
train_data, dev_test_data = train_test_split(data, train_size=0.8, random_state=12)
dev_data, test_data = train_test_split(dev_test_data, train_size=0.5, random_state=12)

In [None]:
# build id triples from the training split, translate the ids into lower-cased
# sentences and fine-tune the pre-trained model on the resulting sentence triples
id_triples = sample_triples(train_data, label_column, id_column)
id2sentence = dict(
    (data_id, sent.lower())
    for data_id, sent in zip(data[id_column], data[sentence_column])
)
sentence_triples = [
    [id2sentence[data_id] for data_id in triple]
    for triple in id_triples
]
train_contrastive_model(sentence_triples, trained_model_dir, pretrained_model_name)