In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

from datasets import load_dataset
from sentence_transformers import (
    InputExample,
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import BinaryClassificationEvaluator, TripletEvaluator
from sentence_transformers.losses import ContrastiveLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    pairwise,
)
from torch.utils.data import DataLoader

from models.comment_classifier import OpinionClassifier
from models.summarizer import Summarizer
from models.topic_matcher import TopicMatcher
from utils.data_preprocessing import preprocess_data, split_data

  from tqdm.autonotebook import tqdm, trange


In [2]:
# Load the preprocessed corpus, then split only the opinions into train/test;
# topics and conclusions are kept whole.
[topics, opinions, conclusions] = preprocess_data()
(train_data, test_data) = split_data(opinions)

In [3]:
topic_matcher = TopicMatcher()

# Build both datasets with the same topic-matching preparation
# (train first, then test).
train_dataset, test_dataset = (
    topic_matcher.prepare_data(topics, opinion_split)
    for opinion_split in (train_data, test_data)
)



In [4]:
# Baseline: score the held-out sentence pairs with the model as loaded, before
# this notebook's training run, so the post-training metrics have a reference.
# The result is kept in `baseline_metrics` instead of living only in the output.
test_evaluator = BinaryClassificationEvaluator(
    sentences1=test_dataset["text1"],
    sentences2=test_dataset["text2"],
    labels=test_dataset["label"],
)
baseline_metrics = test_evaluator(topic_matcher.model)
baseline_metrics

{'cosine_accuracy': np.float64(0.8205318951087968),
 'cosine_accuracy_threshold': np.float32(0.25139463),
 'cosine_f1': np.float64(0.8251797922400782),
 'cosine_f1_threshold': np.float32(0.20716247),
 'cosine_precision': 0.7895005096839959,
 'cosine_recall': np.float64(0.864236563139297),
 'cosine_ap': np.float64(0.8827102750892164),
 'dot_accuracy': np.float64(0.8205318951087968),
 'dot_accuracy_threshold': np.float32(0.25139463),
 'dot_f1': np.float64(0.8251797922400782),
 'dot_f1_threshold': np.float32(0.20716244),
 'dot_precision': 0.7895005096839959,
 'dot_recall': np.float64(0.864236563139297),
 'dot_ap': np.float64(0.8827102869788535),
 'manhattan_accuracy': np.float64(0.8186721220011158),
 'manhattan_accuracy_threshold': np.float32(18.940704),
 'manhattan_f1': np.float64(0.8243171806167401),
 'manhattan_f1_threshold': np.float32(19.638052),
 'manhattan_precision': 0.783191026284949,
 'manhattan_recall': np.float64(0.8700018597731077),
 'manhattan_ap': np.float64(0.8823113250813

In [5]:
# Contrastive loss trains on pairs of texts with a binary label, the same
# pair format scored by the BinaryClassificationEvaluator in this notebook.
# (MultipleNegativesRankingLoss is the alternative for anchor/positive data.)
# Uses the directly imported class: the module alias `losses` is never
# imported, so `losses.ContrastiveLoss` raised NameError on a fresh kernel.
train_loss = ContrastiveLoss(topic_matcher.model)

In [6]:
# One epoch of fine-tuning with linear warm-up; periodic evaluation and
# checkpoint saving are both switched off for this run.
training_args = SentenceTransformerTrainingArguments(
    output_dir="./sentence_transformer_output",
    # optimisation schedule
    learning_rate=4e-5,
    warmup_steps=100,
    num_train_epochs=1,
    per_device_train_batch_size=32,
    # no in-training evaluation or checkpoints
    eval_strategy="no",
    save_strategy="no",
)

trainer = SentenceTransformerTrainer(
    model=topic_matcher.model,
    args=training_args,
    train_dataset=train_dataset,
    loss=train_loss,
)
trainer.train()

Step,Training Loss
500,0.0197
1000,0.012


TrainOutput(global_step=1336, training_loss=0.014717698097229004, metrics={'train_runtime': 280.2448, 'train_samples_per_second': 152.481, 'train_steps_per_second': 4.767, 'total_flos': 0.0, 'train_loss': 0.014717698097229004, 'epoch': 1.0})

In [7]:
# Re-score the same held-out pairs after fine-tuning. The evaluator is rebuilt
# here so this cell does not depend on what `test_evaluator` currently holds.
test_evaluator = BinaryClassificationEvaluator(
    sentences1=test_dataset["text1"],
    sentences2=test_dataset["text2"],
    labels=test_dataset["label"],
)

test_evaluator(topic_matcher.model)

{'cosine_accuracy': np.float64(0.8908313185791333),
 'cosine_accuracy_threshold': np.float32(0.7126949),
 'cosine_f1': np.float64(0.8956145768993204),
 'cosine_f1_threshold': np.float32(0.6957381),
 'cosine_precision': 0.8520819341840161,
 'cosine_recall': np.float64(0.943834852148038),
 'cosine_ap': np.float64(0.9197200674022441),
 'dot_accuracy': np.float64(0.8908313185791333),
 'dot_accuracy_threshold': np.float32(0.7126949),
 'dot_f1': np.float64(0.8956145768993204),
 'dot_f1_threshold': np.float32(0.69573814),
 'dot_precision': 0.8520819341840161,
 'dot_recall': np.float64(0.943834852148038),
 'dot_ap': np.float64(0.9197200467455794),
 'manhattan_accuracy': np.float64(0.8911102845452855),
 'manhattan_accuracy_threshold': np.float32(11.935891),
 'manhattan_f1': np.float64(0.8962155455109456),
 'manhattan_f1_threshold': np.float32(11.935891),
 'manhattan_precision': 0.8560785641720284,
 'manhattan_recall': np.float64(0.9403012832434443),
 'manhattan_ap': np.float64(0.919456180525393

In [8]:
# Write the fine-tuned model to disk so it can be reloaded without retraining.
topic_matcher.model.save_pretrained(
    "saved_models/grouping/binary",
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [14]:
# Triplet-style evaluation of the fine-tuned model.
# NOTE(review): the `test_dataset` built earlier in this notebook is pair-format
# (text1/text2/label). This cell only ran because a triplet-format dataset was
# prepared in cells that no longer exist (execution count jumps from 8 to 14).
# TODO: restore that preparation step so the notebook survives Restart & Run All.
# Assumes `test_dataset` is a `datasets.Dataset` (it exposes `column_names`).
missing_columns = {"anchor", "positive", "negative"}.difference(test_dataset.column_names)
assert not missing_columns, (
    "TripletEvaluator needs anchor/positive/negative columns; "
    f"test_dataset is missing {sorted(missing_columns)}"
)

# NOTE(review): in the recorded output, dot_accuracy is about 1 - cosine_accuracy.
# The pair evaluation reports identical cosine and dot metrics, so this looks
# inverted. Verify against the installed sentence-transformers version before
# relying on it.
test_evaluator = TripletEvaluator(
    anchors=test_dataset["anchor"],
    positives=test_dataset["positive"],
    negatives=test_dataset["negative"],
)
test_evaluator(topic_matcher.model)

{'cosine_accuracy': 0.9443927840803422,
 'dot_accuracy': 0.0556072159196578,
 'manhattan_accuracy': 0.941975079040357,
 'euclidean_accuracy': 0.9443927840803422,
 'max_accuracy': 0.9443927840803422}