In [1]:
from typing import List, Dict, Tuple, Callable, Union, Optional, Any
import logging

import os
import sys
sys.path.append("..")

import json

import torch
import torch.nn as nn

from sklearn.decomposition import TruncatedSVD


logger = logging.getLogger(__name__)


class DirectionRemover(nn.Module):
    
    def __init__(self, direction: List[float]):
        super(DirectionRemover, self).__init__()
        self.direction = nn.Parameter(torch.FloatTensor(direction), requires_grad=False)
        
    def forward(self, features: Dict[str, torch.Tensor]):
        """Performs direction removal by first project the sentence embedding onto the
            direction specified by `self.direction` and then subtracting the projection.`"""
        
        sentence_embedding = features["sentence_embedding"]
        sentence_embedding = sentence_embedding - torch.sum(sentence_embedding * self.direction, dim=1, keepdim=True) * self.direction
        features.update({"sentence_embedding": sentence_embedding})
        return features

    
def compute_pc(X, npc=1):
    """Compute the principal components. DO NOT MAKE THE DATA ZERO MEAN!
    :param X: 
        The sentence embedding matrix; `X[i,:]` is a data point
    :param npc: 
        Number of principal components to remove
    :return: 
        `component_[i,:]` is the i-th pc
    """

    svd = TruncatedSVD(n_components=npc, n_iter=7, random_state=0)
    svd.fit(X)
    return svd.components_

In [2]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.models import Transformer, Pooling

from src.models import SIFPooling

from datasets import load_dataset, get_dataset_config_names


model_card = "bert-base-uncased"

data_stsb = load_dataset("mteb/stsbenchmark-sts")
eval_stsb = EmbeddingSimilarityEvaluator(sentences1=data_stsb["test"]["sentence1"], 
                                         sentences2=data_stsb["test"]["sentence2"], 
                                         scores=data_stsb["test"]["score"], 
                                         name="stsb-dev",
                                         write_csv=False)

data_sick = load_dataset("sick", split="test")
eval_sick = EmbeddingSimilarityEvaluator(sentences1=data_sick["sentence_A"],
                                         sentences2=data_sick["sentence_B"],
                                         scores=data_sick["relatedness_score"],
                                         name="sick-dev",
                                         write_csv=False)

corpus = [s.encode("utf-8").decode("utf-8") for s in 
            data_stsb["train"]["sentence1"] + data_stsb["train"]["sentence2"]]

Using custom data configuration mteb--stsbenchmark-sts-998a21523b45a16a
Found cached dataset json (/home/dogdog/.cache/huggingface/datasets/mteb___json/mteb--stsbenchmark-sts-998a21523b45a16a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset sick (/home/dogdog/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)


In [3]:
embedding_layer = Transformer(model_card)
normal_pooling_layer = Pooling(embedding_layer.get_word_embedding_dimension(), pooling_mode="mean")
weighted_pooling_layer = SIFPooling.from_corpus_hf(model_card, corpus)

model_b = SentenceTransformer(modules=[embedding_layer, normal_pooling_layer])
model_w = SentenceTransformer(modules=[embedding_layer, weighted_pooling_layer])

# normal_dr_layer = DirectionRemover(compute_pc(model_b.encode(corpus)))
# weighted_dr_layer = DirectionRemover(compute_pc(model_w.encode(corpus)))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
model_w.save("../models/sif-bert-base-uncased")

In [5]:
model_w.load("../models/sif-bert-base-uncased")

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): SIFPooling(
    (emb_layer): Embedding(30522, 1)
  )
)

In [18]:
eval_stsb(model_b, output_path=".")

0.48862744348863496

In [58]:
eval_stsb(model_w, output_path=".")

0.6453034678942327

In [24]:
model_br = SentenceTransformer(modules=[embedding_layer, normal_pooling_layer, normal_dr_layer])
eval_stsb(model_br, output_path=".")

0.5503251654706054

In [25]:
model_wr = SentenceTransformer(modules=[embedding_layer, weighted_pooling_layer, weighted_dr_layer])
eval_stsb(model_wr, output_path=".")

0.6581081183448381

In [33]:
embedding_layer = models.Transformer(model_card)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [40]:
cross_layer_pooling = models.WeightedLayerPooling(embedding_layer.get_word_embedding_dimension(),
                                                  num_hidden_layers=12,
                                                  layer_start=11)

model_t = SentenceTransformer(modules=[embedding_layer, cross_layer_pooling, weighted_pooling_layer, weighted_dr_layer])

In [41]:
eval_stsb(model_t, output_path=".")

0.6619439511647357