In [None]:
import os, sys
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /content/drive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install sentence_transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence_transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 4.6 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 77.7 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 74.6 MB/s 
[?25hCollecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 87.7 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 64.7 MB/s 
Building wheels for collected 

In [None]:
from torch.utils.data import DataLoader
import math
from sentence_transformers import models, losses
from sentence_transformers import LoggingHandler, SentenceTransformer, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
import logging
from datetime import datetime
import sys
import os
import gzip
import csv
import random
import torch
from torch import nn, Tensor
from typing import Iterable, Dict

In [None]:

# Emit log records at INFO level and above through LoggingHandler,
# prefixed with a timestamp.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Local paths of the gzipped TSV datasets.
nli_dataset_path = 'data/AllNLI.tsv.gz'
sts_dataset_path = 'data/stsbenchmark.tsv.gz'

# Check whether the AllNLI dataset exists locally; if not, download it.
nli_already_downloaded = os.path.exists(nli_dataset_path)
if not nli_already_downloaded:
    util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

# The STS benchmark download is currently disabled:
# if not os.path.exists(sts_dataset_path):
#     util.http_get('https://sbert.net/datasets/stsbenchmark.tsv.gz', sts_dataset_path)

  0%|          | 0.00/40.8M [00:00<?, ?B/s]

  0%|          | 0.00/392k [00:00<?, ?B/s]

In [None]:
# Build the training set from AllNLI.tsv.gz. Only rows of the 'train' split
# labelled contradiction (0) or entailment (1) are kept; neutral rows are dropped.
logging.info("Read AllNLI train dataset")

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = []

with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        if row['split'] != 'train':
            continue
        label_id = label2int[row['label']]
        if label_id not in (0, 1):
            continue
        sentence_pair = [row['sentence1'], row['sentence2']]
        train_samples.append(InputExample(texts=sentence_pair, label=label_id))

# Randomise the order of the collected examples in place.
random.shuffle(train_samples)


In [None]:
# Build the dev and test sets from AllNLI.tsv.gz, keeping only rows labelled
# contradiction (0) or entailment (1).
logging.info("Read AllNLI dev and test datasets")

label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
dev_samples = []
test_samples = []
# Maps the value of the 'split' column to the list that collects its rows.
samples_by_split = {'dev': dev_samples, 'test': test_samples}
with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
    reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        label_id = label2int[row['label']]
        if label_id not in (0, 1):
            continue
        target_samples = samples_by_split.get(row['split'])
        if target_samples is None:
            # Rows from any other split (e.g. 'train') are ignored here.
            continue
        target_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))

# Randomise the order of both lists in place (dev first, then test).
random.shuffle(dev_samples)
random.shuffle(test_samples)

In [None]:
# Report how many examples were collected for each split.
split_sizes = [
    ('Number of training samples: ', len(train_samples)),
    ('Number of dev samples: ', len(dev_samples)),
    ('Number of test samples: ', len(test_samples)),
]
for caption, size in split_sizes:
    print(caption, size)

Number of training samples:  628405
Number of positive samples:  314315
Number of dev samples:  13299
Number of test samples:  13308


In [None]:
import pickle

In [None]:
# `positive_samples` is not created by any earlier cell, so on a fresh kernel
# this cell failed with NameError. Rebuild it here from the training data.
# NOTE(review): assumed to be the entailment (label == 1) subset of
# train_samples -- the recorded counts (314315 positive out of 628405 training
# samples) are consistent with that, but confirm against the original definition.
positive_samples = [sample for sample in train_samples if sample.label == 1]

# Persist both lists so later sessions can skip re-reading the TSV file.
with open('train_samples.pickle', 'wb') as handle:
    pickle.dump(train_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('positive_samples.pickle', 'wb') as handle:
    pickle.dump(positive_samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

Load Datafiles

In [None]:
# Read the pickled positive samples back from disk into `b`.
# NOTE: pickle.load executes arbitrary code from the file; only load files
# that this notebook itself produced.
with open('positive_samples.pickle', 'rb') as pickle_file:
    b = pickle.load(pickle_file)

In [None]:
# The recorded output of `b == positive_samples` is False directly after a
# dump/reload of the same list, which indicates the example objects are not
# compared by value. Compare the stored fields explicitly so that the check
# actually reflects whether the pickle round trip preserved the data.
round_trip_ok = len(b) == len(positive_samples) and all(
    loaded.texts == original.texts and loaded.label == original.label
    for loaded, original in zip(b, positive_samples)
)
print(round_trip_ok)

False


In [None]:
def lorentz_dist(u, v, beta = 1):
  """Return ``-2 * beta - 2 * <u, v>_L`` for each pair of rows of ``u`` and ``v``.

  Each input is extended along its last dimension with one extra coordinate,
  ``sqrt(sum(x ** 2) + beta)``. The extra coordinate of ``v`` is negated before
  the element-wise product, so the summed product equals the spatial dot
  product minus the product of the two extra coordinates.

  Args:
    u: tensor whose last dimension holds the coordinates.
    v: tensor with the same shape as ``u``.
    beta: constant added under the square root and used in the offset.

  Returns:
    Tensor with the last dimension of the inputs reduced away.
  """
  # Extra coordinate for each input, kept as a trailing dimension of size 1.
  time_u = (u.pow(2).sum(dim=-1, keepdim=True) + beta).sqrt()
  time_v = (v.pow(2).sum(dim=-1, keepdim=True) + beta).sqrt().neg()
  lifted_u = torch.cat([u, time_u], dim=-1)
  lifted_v = torch.cat([v, time_v], dim=-1)
  inner = (lifted_u * lifted_v).sum(dim=-1)
  return - 2 * beta - 2 * inner

In [None]:

class customLoss(nn.Module):
    """Weighted sum of an alignment term and a uniformity term.

    * Alignment: for the pairs in the batch whose label equals 1, the value
      returned by ``distance_metric`` for the two sentence embeddings is raised
      (in absolute value) to the power ``align_alpha`` and averaged.
    * Uniformity: ``log(mean(exp(-unif_t * d ** 2)))`` over all Euclidean
      pairwise distances ``d`` (``torch.pdist``) inside the batch, averaged
      over the anchor embeddings and the other embeddings.

    The returned loss is ``w_align * alignment + w_unif * uniformity``.
    """

    def __init__(self, model: SentenceTransformer, distance_metric = lorentz_dist, align_alpha = 2, unif_t = 5, w_align = 0.9, w_unif = 0.1):
        """Store the encoder and the hyper-parameters of the two loss terms.

        Args:
            model: module called on each feature dict; its output must contain
                the key ``'sentence_embedding'``.
            distance_metric: callable ``(anchor, other) -> distances`` used by
                the alignment term.
            align_alpha: exponent applied to the distances in the alignment term.
            unif_t: temperature of the uniformity term.
            w_align: weight of the alignment term.
            w_unif: weight of the uniformity term.
        """
        super(customLoss, self).__init__()
        self.distance_metric = distance_metric
        self.model = model
        self.align_alpha = align_alpha
        self.unif_t = unif_t
        self.w_align = w_align
        self.w_unif = w_unif

    def align_loss(self, distance, alpha):
        """Return the mean over rows of ``||distance_row||_2 ** alpha``."""
        return distance.norm(p=2, dim=1).pow(alpha).mean()

    def uniform_loss(self, x, t):
        """Return ``log(mean(exp(-t * d ** 2)))`` over pairwise distances of ``x`` rows.

        With fewer than two rows there are no pairs, and the mean over an empty
        tensor would be NaN; in that case a zero that is still attached to the
        autograd graph is returned so the term contributes nothing.
        """
        if x.size(0) < 2:
            return x.sum() * 0.0
        return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """Compute the combined loss for one batch of sentence pairs.

        Args:
            sentence_features: exactly two feature dicts (anchor side, other
                side), each passed through ``self.model``.
            labels: one label per pair; pairs with label 1 are the positives
                used by the alignment term.

        Returns:
            Scalar tensor ``w_align * alignment + w_unif * uniformity``.
        """
        rep_anchor, rep_other = [
            self.model(sentence_feature)['sentence_embedding']
            for sentence_feature in sentence_features
        ]

        pos_mask = (labels == 1)
        if pos_mask.any():
            distances = self.distance_metric(rep_anchor[pos_mask], rep_other[pos_mask])
            a_loss = self.align_loss(distances.unsqueeze(dim = 1), self.align_alpha)
        else:
            # No positive pair in this batch: averaging over an empty tensor
            # would give NaN and poison the weights. Use a graph-attached zero
            # so backward() still works and the term adds no gradient.
            a_loss = rep_anchor.sum() * 0.0

        u_loss = (self.uniform_loss(rep_anchor, self.unif_t) + self.uniform_loss(rep_other, self.unif_t)) / 2
        return self.w_align * a_loss + self.w_unif * u_loss

       

In [None]:
# Batch size shared by the training loader and both evaluators.
train_batch_size = 64

# Training batches are drawn in shuffled order.
train_dataloader = DataLoader(train_samples, batch_size=train_batch_size, shuffle=True)

# Evaluators built from the dev and test examples.
dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    dev_samples,
    name='nli-dev',
    batch_size=train_batch_size,
)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
    test_samples,
    name='nli-test',
    batch_size=train_batch_size,
)

In [None]:
# Encoder: 'bert-base-uncased' token embeddings followed by mean pooling.
word_embedding_model = models.Transformer('bert-base-uncased')
embedding_dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(
    embedding_dim,
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

train_loss = customLoss(model = model)
num_epochs = 10

# Warm up over the first 10% of all training steps (batches per epoch * epochs).
total_training_steps = len(train_dataloader) * num_epochs
warmup_steps = math.ceil(total_training_steps * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

# Output directory name carries the start timestamp of this run.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = 'output/training_nli_custom-' + run_timestamp

# Train with the custom loss, evaluating on the dev evaluator.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=dev_evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)




Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/9819 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
# Reload the model from model_save_path and run the test evaluator on it.
# NOTE(review): this assumes training already wrote a model to model_save_path.
# The recorded training run ended with KeyboardInterrupt during the first
# epoch, so confirm that a saved model exists before running this cell.
model = SentenceTransformer(model_save_path)
model.evaluate(test_evaluator)