In [1]:
import sys
sys.path.append("..")

import torch

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.models import Transformer, Pooling

from datasets import load_dataset, get_dataset_config_names

from src.models import SIFPooling, BatchPCRemoval


EVAL_BATCH_SIZE = 512  # sentences encoded per batch during evaluation


def make_sts_evaluator(dataset, col_a: str, col_b: str, col_score: str, name: str) -> EmbeddingSimilarityEvaluator:
    """Build a sentence-pair similarity evaluator from one split of a Hugging Face dataset.

    Args:
        dataset: dataset split supporting column access by name (``dataset[col]``).
        col_a: column holding the first sentence of each pair.
        col_b: column holding the second sentence of each pair.
        col_score: column holding the gold similarity / relatedness score.
        name: label the evaluator uses for this benchmark.

    Returns:
        An ``EmbeddingSimilarityEvaluator`` with CSV writing disabled; calling it on a
        model returns a single float score.
    """
    return EmbeddingSimilarityEvaluator(sentences1=dataset[col_a],
                                        sentences2=dataset[col_b],
                                        scores=dataset[col_score],
                                        name=name,
                                        write_csv=False,
                                        batch_size=EVAL_BATCH_SIZE)


# STS Benchmark and SICK test splits.
data_stsb = load_dataset("mteb/stsbenchmark-sts", split="test")
eval_stsb = make_sts_evaluator(data_stsb, "sentence1", "sentence2", "score", name="stsb")

data_sick = load_dataset("sick", split="test")
eval_sick = make_sts_evaluator(data_sick, "sentence_A", "sentence_B", "relatedness_score", name="sick")

# Load wiki-text-2 to estimate the word frequencies passed to the SIF pooling layer.
# All splits are used; blank lines and section headings ("= Title =") are dropped.
# NOTE(review): "wikitext-2-v1" is the pre-tokenised variant containing "<unk>" and
# "@-@" placeholders, which would then contribute to the frequency counts;
# "wikitext-2-raw-v1" may be the intended source — confirm before switching, since
# it will shift the reported scores.
wikitext = load_dataset("wikitext", "wikitext-2-v1")
corpus = [line
          for line in (raw.strip() for split in wikitext.values() for raw in split["text"])
          if line and not line.startswith("=")]

Using custom data configuration mteb--stsbenchmark-sts-998a21523b45a16a
Found cached dataset json (/home/dogdog/.cache/huggingface/datasets/mteb___json/mteb--stsbenchmark-sts-998a21523b45a16a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)
Found cached dataset sick (/home/dogdog/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset wikitext (/home/dogdog/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# Three sentence encoders built on one shared DistilBERT backbone; they differ only
# in the modules stacked after the token-embedding layer:
#   model_b: mean pooling (baseline)
#   model_w: SIFPooling, built from `corpus` (presumably frequency-based token weights)
#   model_r: SIFPooling followed by BatchPCRemoval(n_components=1)
#            (presumably drops the top principal component per batch — see src.models)
model_card = "distilbert-base-uncased"

embedding_layer = Transformer(model_name_or_path=model_card)
normal_pooling_layer = Pooling(
    word_embedding_dimension=embedding_layer.get_word_embedding_dimension(),
    pooling_mode="mean",
)
weighted_pooling_layer = SIFPooling.from_corpus_hf(model_card, corpus)
batch_pc_removal_layer = BatchPCRemoval(n_components=1)

model_b = SentenceTransformer(modules=[
    embedding_layer,
    normal_pooling_layer,
])
model_w = SentenceTransformer(modules=[
    embedding_layer,
    weighted_pooling_layer,
])
model_r = SentenceTransformer(modules=[
    embedding_layer,
    weighted_pooling_layer,
    batch_pc_removal_layer,
])

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
eval_sick(model=model_b, output_path=".")  # SICK test score: mean-pooling baseline

0.6368024616834773

In [4]:
eval_sick(model=model_w, output_path=".")  # SICK test score: SIF-weighted pooling

0.6614225828406428

In [5]:
eval_sick(model=model_r, output_path=".")  # SICK test score: SIF pooling + batch PC removal

0.6755694201801812