# In[ ]:
import sys
sys.path.insert(0, "../lib")

from STSDataReaderBinary import STSDataReaderBinary
from STSDataReaderBinaryPositives import STSDataReaderBinaryPositives
from BSCLoss import BSCLoss, ComboBSCLoss
from BSCShuffler import ShuffledSentencesDataset, ShuffledSentenceTransformer
from BSCShuffler import BSCShuffler, ModelBSCShuffler, ModelExampleBasedShuffler

from torch.utils.data import DataLoader
import math
import os
from sentence_transformers import models, losses
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, evaluation
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, TripletEvaluator, SimilarityFunction
from evaluator import BinaryClassificationEvaluator
from sentence_transformers.readers import *
import pandas as pd
import logging
import csv

# Configure the root logger: INFO level and above, each record prefixed with a
# timestamp, emitted through sentence-transformers' LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

# How many independent train/evaluate repetitions to run per configuration.
num_runs = 5

# In[ ]:
import shutil  # portable replacement for the former `!rm -rf` shell call

# ---------------------------------------------------------------------------
# Repeat the MRPC experiment `num_runs` times. Each run builds a fresh
# BERT + mean-pooling ShuffledSentenceTransformer, trains it with the combined
# BSC loss, reloads what fit() saved to `model_save_path`, and appends the
# resulting dev-set metric to the results file.
# ---------------------------------------------------------------------------
data_dir = 'datasets/MRPC'
metrics_path = 'intervals_estimates/metrics-MRPC.txt'
model_save_path = 'checkpoints/bsc_mrpc'
train_batch_size = 30
num_epochs = 6

# Header line identifying this configuration in the append-only results file.
# NOTE(review): the label says "shuffled 4-same", but the shuffler below uses
# group_size=3. Confirm which one is intended.
with open(metrics_path, 'a', encoding='utf8') as f:
    f.write('\n')
    f.write('combo, shuffled 4-same, 2e-5, bias False, norm 1, tau_lr 0.1, mu 0.1, 6 epochs\n')

# Dev split, parsed once because it does not change between runs.
# Rows are tab-separated: column 0 is the integer label, columns 3 and 4 the
# two sentences.
dev_sentences1 = []
dev_sentences2 = []
dev_labels = []
with open(os.path.join(data_dir, 'dev.tsv'), encoding='utf8') as fIn:
    for line in fIn:
        # Drop the line terminator first so it never ends up inside the last field.
        row = line.rstrip('\r\n').split('\t')
        dev_sentences1.append(row[3])
        dev_sentences2.append(row[4])
        dev_labels.append(int(row[0]))

for _ in range(num_runs):
    # Fresh encoder for every run so the repetitions are independent.
    word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=90)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                                   pooling_mode_mean_tokens=True,
                                   pooling_mode_cls_token=False,
                                   pooling_mode_max_tokens=False)

    model = ShuffledSentenceTransformer(modules=[word_embedding_model, pooling_model])

    # Readers are rebuilt per run, exactly as before (their internal state,
    # if any, is not visible from this cell).
    sts_reader_pos = STSDataReaderBinaryPositives(data_dir, quoting=csv.QUOTE_NONE,
                                                  s1_col_idx=3, s2_col_idx=4, score_col_idx=0,
                                                  normalize_scores=False, thr=0.6, get_positives=False)

    sts_reader = STSDataReader(data_dir, s1_col_idx=3, s2_col_idx=4, score_col_idx=0, normalize_scores=False)

    # The only training objective: combined BSC loss. The DataLoader does not
    # shuffle; presumably batch composition is driven by the `shuffler` passed
    # to fit() (confirm against ShuffledSentenceTransformer.fit).
    train_data_bsc = ShuffledSentencesDataset(sts_reader_pos.get_examples('train.tsv'), model)
    train_dataloader_bsc = DataLoader(train_data_bsc, shuffle=False, batch_size=train_batch_size)
    train_loss_bsc = ComboBSCLoss(model=model, norm_dim=1, mu=0.1, tau=0.1)

    # The plain STS training set is used only to size the warm-up (10% of all
    # training steps); it is not passed to fit() as an objective.
    train_data = SentencesDataset(sts_reader.get_examples('train.tsv'), model)

    # A new evaluator per run, so nothing it might record carries over between runs.
    binary_acc_evaluator = BinaryClassificationEvaluator(dev_sentences1, dev_sentences2, dev_labels)
    binary_acc_evaluator.main_similarity = SimilarityFunction.COSINE

    warmup_steps = math.ceil(len(train_data) * num_epochs / train_batch_size * 0.1)

    shuffler = ModelExampleBasedShuffler(group_size=3, allow_same=True)

    # Start each run from an empty checkpoint directory (no-op if it does not exist).
    shutil.rmtree(model_save_path, ignore_errors=True)

    model.fit(train_objectives=[(train_dataloader_bsc, train_loss_bsc)],
              evaluator=binary_acc_evaluator,
              epochs=num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              optimizer_params={'alpha_lr': 0.1, 'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
              output_path=model_save_path,
              output_path_ignore_not_empty=True,
              shuffler=shuffler,
              shuffle_idxs=[0]
             )

    # Score the model that fit() saved to disk, not the in-memory one.
    model = SentenceTransformer(model_save_path)
    metric = model.evaluate(binary_acc_evaluator)
    with open(metrics_path, 'a', encoding='utf8') as f:
        f.write(str(metric) + '\n')