In [1]:
import os, sys
from google.colab import drive
# Colab-only: mount Google Drive at /content/drive.
# NOTE(review): in the visible cells nothing reads or writes under /content/drive
# and `sys` is never used — confirm both are still needed.
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install sentence_transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence_transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 1.5 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 41.3 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 67.3 MB/s 
[?25hCollecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 76.4 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 62.7 MB/s 
Building wheels for collected pa

Testing the Lorentz distance

In [None]:
# model = SentenceTransformer('all-MiniLM-L6-v2')
# train_examples = [
#     InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
#     InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
# train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
# train_loss = losses.ContrastiveLoss(model=model, distance_metric= lorentz_dist)
# model.fit([(train_dataloader, train_loss)], show_progress_bar=True)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import math
from sentence_transformers import models, losses, util
from sentence_transformers import LoggingHandler, SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from sentence_transformers import SentencesDataset
from sentence_transformers.readers import *
import logging
from datetime import datetime
import os
import csv
import gzip

In [4]:
def lorentz_dist(u, v, beta = 0.1):
  """Squared Lorentzian distance between vectors lifted onto a hyperboloid.

  Each vector ``x`` (last dimension = features) is lifted to
  ``(x, sqrt(||x||^2 + beta))``, which satisfies ``<a, a>_L = -beta`` for the
  Lorentzian inner product ``<a, b>_L = sum_i a_i b_i - a_0 b_0``.
  The squared distance is ``-2*beta - 2*<u, v>_L``.

  Args:
    u, v: tensors whose shapes broadcast against each other; the last
      dimension is treated as the feature dimension.
    beta: constant defining the hyperboloid ``<a, a>_L = -beta``.

  Returns:
    Tensor with the last dimension reduced, holding non-negative squared
    distances (exactly 0 up to rounding when ``u == v``).
  """
  u_time = torch.sqrt(torch.pow(u, 2).sum(-1, keepdim=True) + beta)
  # v's time coordinate is negated so that a plain dot product over the lifted
  # vectors equals the Lorentzian inner product (space terms minus time term).
  v_time = -torch.sqrt(torch.pow(v, 2).sum(-1, keepdim=True) + beta)
  u_lifted = torch.cat((u, u_time), -1)
  v_lifted = torch.cat((v, v_time), -1)
  result = -2 * beta - 2 * torch.sum(u_lifted * v_lifted, dim=-1)
  # Mathematically result >= 0, but for nearby points two large terms cancel
  # and float rounding can produce tiny negative values; clamp them away.
  return torch.clamp(result, min=0.0)


In [5]:
class customLoss(nn.Module):
    """Contrastive loss over a pluggable distance, defaulting to ``lorentz_dist``.

    Follows the form of sentence-transformers' ``ContrastiveLoss``:
    ``0.5 * (y * d**2 + (1 - y) * relu(margin - d)**2)``. Pairs labelled 1
    are pulled together; pairs labelled 0 are pushed apart until their
    distance exceeds ``margin``.

    Args:
        model: SentenceTransformer whose output dict contains
            ``'sentence_embedding'``.
        distance_metric: callable ``(anchor, other) -> distances`` giving one
            distance per pair.
        margin: distance beyond which label-0 pairs stop contributing.
        size_average: if True return the batch mean, otherwise the batch sum.
    """

    def __init__(self, model: SentenceTransformer, distance_metric=lorentz_dist,
                 margin: float = 0.5, size_average: bool = True):
        # This class derives from nn.Module directly, so use the zero-argument
        # super() rather than naming an unrelated class.
        super().__init__()
        self.distance_metric = distance_metric
        self.model = model
        self.margin = margin
        self.size_average = size_average

    def forward(self, sentence_features, labels):
        """Embed both sides of each pair and return the scalar contrastive loss."""
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        assert len(reps) == 2
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        labels = labels.float()
        # Named so it does not shadow the imported `losses` module.
        pair_losses = 0.5 * (labels * distances.pow(2)
                             + (1 - labels) * F.relu(self.margin - distances).pow(2))
        return pair_losses.mean() if self.size_average else pair_losses.sum()

In [6]:
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

# Read the dataset
batch_size = 32
model_save_path = 'output/training_sts-hyperboloid-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


#Check if dataset exsist. If not, download and extract  it
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'

if not os.path.exists(sts_dataset_path):
    util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score']) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row['sentence1'], row['sentence2']], label=score)

        if row['split'] == 'dev':
            dev_samples.append(inp_example)
        elif row['split'] == 'test':
            test_samples.append(inp_example)
        else:
            train_samples.append(inp_example)


# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)


logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')


  0%|          | 0.00/392k [00:00<?, ?B/s]

In [7]:

# Build a BERT + mean-pooling SentenceTransformer and fine-tune it on STSb with
# ContrastiveLoss using lorentz_dist as the distance.
word_embedding_model = models.Transformer('bert-base-uncased')
embedding_dim = word_embedding_model.get_word_embedding_dimension()

# Mean pooling only: one fixed-size sentence vector per input.
pooling_model = models.Pooling(
    embedding_dim,
    pooling_mode_cls_token=False,
    pooling_mode_mean_tokens=True,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# NOTE(review): STSb labels here are continuous scores in [0, 1], while a
# contrastive loss is normally fed binary 0/1 labels — confirm this is intended.
train_loss = losses.ContrastiveLoss(model=model, distance_metric=lorentz_dist)

num_epochs = 10
steps_per_epoch = len(train_dataloader)
warmup_steps = math.ceil(steps_per_epoch * num_epochs * 0.1)  # warm up over the first 10% of all steps
logging.info("Warmup-steps: {}".format(warmup_steps))

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=int(steps_per_epoch * 0.1),
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)

# Reload the checkpoint written to model_save_path and score it on the STSb test split.
# NOTE(review): presumably EmbeddingSimilarityEvaluator uses its own built-in
# similarity functions rather than lorentz_dist — confirm that is what we want to measure.
model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name='sts-test')
model.evaluate(test_evaluator)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

Iteration:   0%|          | 0/180 [00:00<?, ?it/s]

0.7788653879618485

Testing on NLI (AllNLI test split, contradiction vs. entailment pairs)

In [8]:
# Fetch AllNLI only if it is not already on disk. The file is read
# gzip-compressed later, so nothing is extracted.
nli_dataset_path = 'data/AllNLI.tsv.gz'
if not os.path.exists(nli_dataset_path):
    util.http_get(
        'https://sbert.net/datasets/AllNLI.tsv.gz',
        nli_dataset_path,
    )

  0%|          | 0.00/40.8M [00:00<?, ?B/s]

In [9]:
# Read the AllNLI test split, keeping only contradiction/entailment pairs so the
# labels are binary (0 = contradiction, 1 = entailment) for the evaluator below.
logging.info("Read AllNLI test dataset")

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
test_s1 = []
test_s2 = []
test_labels = []

with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] != 'test':
            continue
        label_id = label2int[row['label']]
        # Neutral pairs (2) have no binary counterpart and are skipped.
        if label_id in (0, 1):
            test_s1.append(row['sentence1'])
            test_s2.append(row['sentence2'])
            test_labels.append(label_id)

assert len(test_s1) == len(test_s2) == len(test_labels) > 0, "no AllNLI test pairs were read"

In [10]:
model_test_path = "output/"  # passed as output_path to the AllNLI binary evaluator

In [11]:
# Score the reloaded STSb-trained model on AllNLI contradiction (0) vs. entailment (1) pairs.
test_evaluator = BinaryClassificationEvaluator(
    sentences1=test_s1,
    sentences2=test_s2,
    labels=test_labels,
)
test_evaluator(model, output_path=model_test_path)

0.7216995171413687

In [None]:
!zip -r /content/bert_custdistfn.zip /content/output/training_sts-hyperboloid-2022-12-09_16-35-04

  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/ (stored 0%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/config_sentence_transformers.json (deflated 26%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/pytorch_model.bin (deflated 7%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/README.md (deflated 58%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/special_tokens_map.json (deflated 42%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/vocab.txt (deflated 53%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/tokenizer.json (deflated 71%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/1_Pooling/ (stored 0%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/1_Pooling/config.json (deflated 49%)
  adding: content/output/training_sts-hyperboloid-2022-12-09_16-35-04/config.json (deflated 48%)
