In [1]:
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torchmetrics.functional import pairwise_cosine_similarity
from datasets import Dataset
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.trainer import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def sample_triples(train_data, label_column, id_column, topk=10):
    """Mine (anchor, positive, negative) id triples from binary-labelled embeddings.

    Every example with label 1 is used as an anchor. For each anchor, the
    ``sample_size`` most cosine-similar label-1 examples are paired with the
    ``sample_size`` most cosine-similar label-0 examples, where
    ``sample_size = min(topk, #positives, #negatives)``.

    Note: the anchor itself takes part in the positive ranking and is only
    filtered out afterwards (by id), so an anchor that ranks inside its own
    top ``sample_size`` contributes ``sample_size - 1`` positives.

    Args:
        train_data: column-indexable container (e.g. a DataFrame) providing an
            ``'embeddings'`` column of stackable torch tensors plus the label
            and id columns named below.
        label_column: name of the column holding the labels; must contain both
            0 and 1 and no other value.
        id_column: name of the column holding the example ids (expected to be
            unique per row).
        topk: upper bound on the number of nearest positives / negatives that
            are combined per anchor.

    Returns:
        List of ``(anchor_id, pos_id, neg_id)`` tuples of data ids.

    Raises:
        AssertionError: if the label column is not exactly the set {0, 1}.
    """
    # get embeddings, labels and ids (in aligned order)
    embeddings = list(train_data['embeddings'])
    labels = list(train_data[label_column])
    ids = list(train_data[id_column])
    assert set(labels) == {0, 1}, 'label column must contain both 0 and 1 and no other values'

    # store embeddings of positive and negative examples and their ids
    pos_vecs, neg_vecs, pos_data_ids, neg_data_ids = [], [], [], []
    for vec, label, data_id in zip(embeddings, labels, ids):
        if label == 1:
            pos_vecs.append(vec)
            pos_data_ids.append(data_id)
        elif label == 0:
            neg_vecs.append(vec)
            neg_data_ids.append(data_id)

    # pairwise cosine similarities: positives vs. negatives and positives vs. positives.
    # Because a second matrix is passed explicitly, torchmetrics leaves the diagonal
    # of the positive/positive matrix untouched, i.e. each anchor is compared to itself.
    pos_matrix = torch.stack(pos_vecs)
    posneg_pairwise_sim = pairwise_cosine_similarity(pos_matrix, torch.stack(neg_vecs))
    pospos_pairwise_sim = pairwise_cosine_similarity(pos_matrix, pos_matrix)

    # move the similarity matrices to host memory explicitly so that the numpy
    # ranking below also works when the embeddings live on an accelerator
    posneg_sim = posneg_pairwise_sim.detach().cpu().numpy()
    pospos_sim = pospos_pairwise_sim.detach().cpu().numpy()

    # collect a set of triples for every positive embedding as anchor embedding
    sample_size = max(0, min([topk, len(pos_vecs), len(neg_vecs)]))
    triples = []
    for row, anchor_id in enumerate(pos_data_ids):
        # argsort ranks ascending: reverse it and keep only the most similar entries
        nearest_pos_ids = [pos_data_ids[col] for col in np.argsort(pospos_sim[row])[::-1][:sample_size]]
        nearest_neg_ids = [neg_data_ids[col] for col in np.argsort(posneg_sim[row])[::-1][:sample_size]]
        # triple positive: top k positive embeddings that are most similar to anchor embedding
        for pos_id in nearest_pos_ids:
            if anchor_id != pos_id:
                # triple negative: top k negative embeddings that are most similar to anchor embedding
                for neg_id in nearest_neg_ids:
                    triples.append((anchor_id, pos_id, neg_id))

    # returns list of triples as tuples of data ids
    print('Resulting number of triples:', len(triples))
    return triples

In [3]:
def train_contrastive_model(train_data, model_dir, pretrained_model_name, batch_size=16, epochs=3, triplet_margin=1):
    """Fine-tune a pre-trained SentenceTransformer with triplet loss and save it.

    Args:
        train_data: sequence of triples whose items 0, 1 and 2 become the
            "anchor", "positive" and "negative" dataset columns respectively.
        model_dir: used both as the trainer's output directory and as the
            location the trained model is saved to.
        pretrained_model_name: name/path handed to ``SentenceTransformer``.
        batch_size: per-device training batch size.
        epochs: number of training epochs.
        triplet_margin: margin passed to ``losses.TripletLoss``.

    Returns:
        None. The trained model is written to ``model_dir``.
    """
    # position i inside every triple feeds the i-th dataset column
    columns = {
        column_name: [triple[position] for triple in train_data]
        for position, column_name in enumerate(("anchor", "positive", "negative"))
    }
    triplet_dataset = Dataset.from_dict(columns)

    # load on CPU, because the mps device doesn't work
    encoder = SentenceTransformer(pretrained_model_name).cpu()
    triplet_loss = losses.TripletLoss(model=encoder, triplet_margin=triplet_margin)

    training_args = SentenceTransformerTrainingArguments(
        output_dir=model_dir,
        per_device_train_batch_size=batch_size,
        num_train_epochs=epochs,
    )
    # no evaluator is passed: training runs on the training set only
    trainer = SentenceTransformerTrainer(
        model=encoder,
        args=training_args,
        train_dataset=triplet_dataset,
        loss=triplet_loss,
    )
    trainer.train()
    encoder.save_pretrained(model_dir)

In [4]:
# pre-trained sentence encoder to start from, and where the fine-tuned model is written
pretrained_model_name = 'sentence-transformers/all-mpnet-base-v2'
trained_model_dir = 'CL-model/'

# input data file and the names of the columns that are read from it
input_path = 'HateWiC_T5Defs_MajorityLabels.csv'
id_column, sentence_column, label_column = (
    'id',
    'T5generated_definition',
    'majority_binary_annotation',
)

In [5]:
# load data (semicolon-separated file) and the pre-trained sentence encoder
data = pd.read_csv(input_path, sep=';')
model = SentenceTransformer(pretrained_model_name).cpu() # device='mps' gives error

print('Encoding sentences with Sentence Transformer...')
# hand encode() a plain list of strings (its documented input type) instead of a
# pandas Series, so the call does not depend on how the Series happens to be indexed;
# list(...) then splits the returned tensor into one embedding per row
data['embeddings'] = list(model.encode(data[sentence_column].tolist(), convert_to_tensor=True, show_progress_bar=True))

# split into 80% train and 10% / 10% dev / test, with a fixed seed for reproducibility
train_data, dev_test_data = train_test_split(data, train_size=0.8, random_state=12)
dev_data, test_data = train_test_split(dev_test_data, train_size=0.5, random_state=12)
# sanity check: the three splits together cover every row exactly once
assert len(train_data) + len(dev_data) + len(test_data) == len(data)



Encoding sentences with Sentence Transformer...


Batches: 100%|██████████| 121/121 [00:18<00:00,  6.47it/s]


In [6]:
# mine (anchor, positive, negative) id triples from the training split
id_triples = sample_triples(train_data, label_column, id_column)

# lookup table from data id to its lower-cased sentence (built over the full dataset)
id2sentence = dict(
    (data_id, sent.lower())
    for data_id, sent in zip(data[id_column], data[sentence_column])
)

# replace every id in a triple by its sentence, keeping the anchor/positive/negative order
sentence_triples = [
    [id2sentence[anchor_id], id2sentence[pos_id], id2sentence[neg_id]]
    for anchor_id, pos_id, neg_id in id_triples
]

# fine-tune the pre-trained model on the sentence triples with contrastive (triplet) learning
train_contrastive_model(sentence_triples, trained_model_dir, pretrained_model_name)

Resulting number of triples: 130910


  2%|▏         | 500/24546 [03:16<3:00:56,  2.21it/s]

{'loss': 0.6612, 'grad_norm': 5.451143741607666, 'learning_rate': 4.898150411472338e-05, 'epoch': 0.06}


  4%|▍         | 1000/24546 [06:35<2:23:48,  2.73it/s]

{'loss': 0.4004, 'grad_norm': 3.5781755447387695, 'learning_rate': 4.7963008229446755e-05, 'epoch': 0.12}


  6%|▌         | 1500/24546 [09:56<2:23:45,  2.67it/s]

{'loss': 0.3047, 'grad_norm': 3.08046817779541, 'learning_rate': 4.694451234417013e-05, 'epoch': 0.18}


  8%|▊         | 2000/24546 [13:17<2:21:05,  2.66it/s] 

{'loss': 0.2727, 'grad_norm': 2.0618505477905273, 'learning_rate': 4.592601645889351e-05, 'epoch': 0.24}


 10%|█         | 2500/24546 [16:38<2:20:04,  2.62it/s] 

{'loss': 0.241, 'grad_norm': 7.711935043334961, 'learning_rate': 4.490752057361688e-05, 'epoch': 0.31}


 12%|█▏        | 3000/24546 [20:02<2:20:06,  2.56it/s] 

{'loss': 0.2497, 'grad_norm': 2.2996647357940674, 'learning_rate': 4.3889024688340266e-05, 'epoch': 0.37}


 14%|█▍        | 3500/24546 [23:25<2:19:16,  2.52it/s] 

{'loss': 0.217, 'grad_norm': 0.5952454209327698, 'learning_rate': 4.2870528803063635e-05, 'epoch': 0.43}


 16%|█▋        | 4000/24546 [26:49<2:18:54,  2.47it/s] 

{'loss': 0.2225, 'grad_norm': 1.1201153993606567, 'learning_rate': 4.185203291778702e-05, 'epoch': 0.49}


 18%|█▊        | 4500/24546 [30:10<2:11:17,  2.54it/s] 

{'loss': 0.206, 'grad_norm': 5.716087818145752, 'learning_rate': 4.083353703251039e-05, 'epoch': 0.55}


 20%|██        | 5000/24546 [33:38<2:05:37,  2.59it/s] 

{'loss': 0.2003, 'grad_norm': 0.33098939061164856, 'learning_rate': 3.981504114723377e-05, 'epoch': 0.61}


 22%|██▏       | 5500/24546 [37:09<2:13:47,  2.37it/s] 

{'loss': 0.2056, 'grad_norm': 0.17984157800674438, 'learning_rate': 3.8796545261957146e-05, 'epoch': 0.67}


 24%|██▍       | 6000/24546 [40:43<2:08:14,  2.41it/s] 

{'loss': 0.1996, 'grad_norm': 0.31419193744659424, 'learning_rate': 3.7778049376680516e-05, 'epoch': 0.73}


 26%|██▋       | 6500/24546 [44:11<2:00:42,  2.49it/s] 

{'loss': 0.203, 'grad_norm': 12.842952728271484, 'learning_rate': 3.67595534914039e-05, 'epoch': 0.79}


 29%|██▊       | 7000/24546 [47:45<1:47:52,  2.71it/s] 

{'loss': 0.2022, 'grad_norm': 0.33287352323532104, 'learning_rate': 3.574105760612727e-05, 'epoch': 0.86}


 31%|███       | 7500/24546 [51:19<1:48:29,  2.62it/s]

{'loss': 0.1868, 'grad_norm': 0.4142511487007141, 'learning_rate': 3.472256172085065e-05, 'epoch': 0.92}


 33%|███▎      | 8000/24546 [54:46<1:37:38,  2.82it/s]

{'loss': 0.1887, 'grad_norm': 9.452204704284668, 'learning_rate': 3.370406583557403e-05, 'epoch': 0.98}


 35%|███▍      | 8500/24546 [58:21<1:44:31,  2.56it/s]

{'loss': 0.1852, 'grad_norm': 0.21046246588230133, 'learning_rate': 3.26855699502974e-05, 'epoch': 1.04}


 37%|███▋      | 9000/24546 [1:01:47<1:50:59,  2.33it/s]

{'loss': 0.1871, 'grad_norm': 0.4188470244407654, 'learning_rate': 3.166707406502078e-05, 'epoch': 1.1}


 39%|███▊      | 9500/24546 [1:05:04<1:28:40,  2.83it/s]

{'loss': 0.1845, 'grad_norm': 0.3806311786174774, 'learning_rate': 3.0648578179744155e-05, 'epoch': 1.16}


 41%|████      | 10000/24546 [1:08:25<1:47:18,  2.26it/s]

{'loss': 0.1883, 'grad_norm': 1.6476167440414429, 'learning_rate': 2.963008229446753e-05, 'epoch': 1.22}


 43%|████▎     | 10500/24546 [1:48:03<1:23:50,  2.79it/s]    

{'loss': 0.1912, 'grad_norm': 0.5160092711448669, 'learning_rate': 2.861158640919091e-05, 'epoch': 1.28}


 45%|████▍     | 11000/24546 [1:51:16<1:22:44,  2.73it/s]

{'loss': 0.1717, 'grad_norm': 0.2540348768234253, 'learning_rate': 2.7593090523914284e-05, 'epoch': 1.34}


 47%|████▋     | 11500/24546 [1:54:30<1:24:18,  2.58it/s]

{'loss': 0.1676, 'grad_norm': 0.36715543270111084, 'learning_rate': 2.6574594638637663e-05, 'epoch': 1.41}


 49%|████▉     | 12000/24546 [1:57:45<1:19:42,  2.62it/s]

{'loss': 0.1687, 'grad_norm': 0.22027261555194855, 'learning_rate': 2.5556098753361036e-05, 'epoch': 1.47}


 51%|█████     | 12500/24546 [2:00:57<1:13:10,  2.74it/s]

{'loss': 0.1742, 'grad_norm': 29.696685791015625, 'learning_rate': 2.4537602868084415e-05, 'epoch': 1.53}


 53%|█████▎    | 13000/24546 [2:04:09<1:11:39,  2.69it/s]

{'loss': 0.1588, 'grad_norm': 0.34169772267341614, 'learning_rate': 2.3519106982807788e-05, 'epoch': 1.59}


 55%|█████▍    | 13500/24546 [2:07:19<1:05:16,  2.82it/s]

{'loss': 0.1587, 'grad_norm': 0.5442103743553162, 'learning_rate': 2.2500611097531164e-05, 'epoch': 1.65}


 57%|█████▋    | 14000/24546 [2:10:30<1:07:06,  2.62it/s]

{'loss': 0.1812, 'grad_norm': 0.38373732566833496, 'learning_rate': 2.1482115212254544e-05, 'epoch': 1.71}


 59%|█████▉    | 14500/24546 [2:13:43<1:01:33,  2.72it/s]

{'loss': 0.17, 'grad_norm': 0.25167903304100037, 'learning_rate': 2.046361932697792e-05, 'epoch': 1.77}


 61%|██████    | 15000/24546 [2:16:50<58:09,  2.74it/s]  

{'loss': 0.1668, 'grad_norm': 0.3091393709182739, 'learning_rate': 1.9445123441701296e-05, 'epoch': 1.83}


 63%|██████▎   | 15500/24546 [2:20:00<58:36,  2.57it/s]  

{'loss': 0.1686, 'grad_norm': 0.12956099212169647, 'learning_rate': 1.8426627556424672e-05, 'epoch': 1.89}


 65%|██████▌   | 16000/24546 [2:23:09<50:27,  2.82it/s]  

{'loss': 0.1676, 'grad_norm': 0.23934434354305267, 'learning_rate': 1.740813167114805e-05, 'epoch': 1.96}


 67%|██████▋   | 16500/24546 [2:26:23<49:28,  2.71it/s]  

{'loss': 0.1675, 'grad_norm': 0.36733344197273254, 'learning_rate': 1.6389635785871428e-05, 'epoch': 2.02}


 69%|██████▉   | 17000/24546 [2:29:33<48:16,  2.61it/s]  

{'loss': 0.161, 'grad_norm': 0.0974527895450592, 'learning_rate': 1.5371139900594804e-05, 'epoch': 2.08}


 71%|███████▏  | 17500/24546 [2:32:42<41:44,  2.81it/s]  

{'loss': 0.1676, 'grad_norm': 0.2582019865512848, 'learning_rate': 1.4352644015318178e-05, 'epoch': 2.14}


 73%|███████▎  | 18000/24546 [2:35:52<40:31,  2.69it/s]  

{'loss': 0.162, 'grad_norm': 0.3718836009502411, 'learning_rate': 1.3334148130041554e-05, 'epoch': 2.2}


 75%|███████▌  | 18500/24546 [2:39:04<38:45,  2.60it/s]  

{'loss': 0.1682, 'grad_norm': 0.5241735577583313, 'learning_rate': 1.2315652244764932e-05, 'epoch': 2.26}


 77%|███████▋  | 19000/24546 [2:42:12<33:40,  2.74it/s]  

{'loss': 0.1674, 'grad_norm': 0.23180615901947021, 'learning_rate': 1.1297156359488308e-05, 'epoch': 2.32}


 79%|███████▉  | 19500/24546 [2:45:21<30:30,  2.76it/s]  

{'loss': 0.156, 'grad_norm': 9.114452362060547, 'learning_rate': 1.0278660474211684e-05, 'epoch': 2.38}


 81%|████████▏ | 20000/24546 [2:48:30<30:18,  2.50it/s]  

{'loss': 0.1559, 'grad_norm': 0.9862040877342224, 'learning_rate': 9.26016458893506e-06, 'epoch': 2.44}


 84%|████████▎ | 20500/24546 [2:51:42<24:32,  2.75it/s]  

{'loss': 0.1523, 'grad_norm': 0.5887149572372437, 'learning_rate': 8.241668703658438e-06, 'epoch': 2.51}


 86%|████████▌ | 21000/24546 [2:54:52<23:23,  2.53it/s]  

{'loss': 0.1567, 'grad_norm': 0.19548410177230835, 'learning_rate': 7.223172818381814e-06, 'epoch': 2.57}


 88%|████████▊ | 21500/24546 [2:58:04<19:46,  2.57it/s]  

{'loss': 0.1516, 'grad_norm': 0.5045796632766724, 'learning_rate': 6.204676933105191e-06, 'epoch': 2.63}


 90%|████████▉ | 22000/24546 [3:01:13<15:42,  2.70it/s]  

{'loss': 0.1513, 'grad_norm': 0.2596152722835541, 'learning_rate': 5.186181047828568e-06, 'epoch': 2.69}


 92%|█████████▏| 22500/24546 [3:04:25<12:27,  2.74it/s]  

{'loss': 0.1551, 'grad_norm': 0.21571578085422516, 'learning_rate': 4.167685162551944e-06, 'epoch': 2.75}


 94%|█████████▎| 23000/24546 [3:07:36<09:18,  2.77it/s]  

{'loss': 0.1582, 'grad_norm': 0.30534952878952026, 'learning_rate': 3.14918927727532e-06, 'epoch': 2.81}


 96%|█████████▌| 23500/24546 [3:10:47<06:15,  2.79it/s]

{'loss': 0.1542, 'grad_norm': 0.2591329514980316, 'learning_rate': 2.1306933919986964e-06, 'epoch': 2.87}


 98%|█████████▊| 24000/24546 [3:13:56<03:19,  2.74it/s]

{'loss': 0.1495, 'grad_norm': 0.40440112352371216, 'learning_rate': 1.112197506722073e-06, 'epoch': 2.93}


100%|█████████▉| 24500/24546 [3:17:06<00:16,  2.81it/s]

{'loss': 0.1517, 'grad_norm': 0.4859429597854614, 'learning_rate': 9.370162144544936e-08, 'epoch': 2.99}


100%|██████████| 24546/24546 [3:17:30<00:00,  2.07it/s]


{'train_runtime': 11850.3092, 'train_samples_per_second': 33.141, 'train_steps_per_second': 2.071, 'train_loss': 0.19657148862636867, 'epoch': 3.0}
