In [1]:
from typing import List, Dict, Tuple, Callable, Union, Optional, Any
import logging

import os
import sys
sys.path.append("..")

import json

import torch
import torch.nn as nn

from sklearn.decomposition import TruncatedSVD


logger = logging.getLogger(__name__)


class BatchPCRemoval(nn.Module):
    """Subtract the leading singular directions of a batch from its sentence embeddings.

    The directions are re-estimated on every forward pass from the batch that is
    passed in (no statistics are stored on the module). The SVD is taken on the
    embeddings as-is; no mean-centring is applied here.

    :param npc:
        Number of leading right singular vectors to remove.
    """

    def __init__(self, npc: int = 1):
        super().__init__()
        self.npc = npc
        # Attribute names that get_config_dict() / save() serialise.
        self.config_keys = ["npc"]

    def forward(self, features: Dict[str, Any]):
        """Replace ``features["sentence_embedding"]`` with its direction-removed version.

        :param features:
            Feature dict holding a ``"sentence_embedding"`` tensor
            (assumed 2-D, one embedding per row — TODO confirm against callers).
        :return:
            The same dict object, updated in place.
        """
        sentence_vecs = features["sentence_embedding"]
        # The directions and projection coefficients are computed on a detached
        # copy, so gradients only flow through `sentence_vecs` in the subtraction.
        frozen = sentence_vecs.detach()

        # Rows of Vh are the right singular vectors, ordered by singular value.
        right_vecs = torch.linalg.svd(frozen, full_matrices=False).Vh
        top_dirs = right_vecs[:self.npc]
        coeffs = frozen @ top_dirs.transpose(0, 1)

        features["sentence_embedding"] = sentence_vecs - coeffs @ top_dirs
        return features

    # Serialisation helpers ///////////////////////////////////////////////////////
    def get_config_dict(self):
        """Return the constructor arguments listed in ``self.config_keys`` as a dict."""
        attrs = vars(self)
        return {name: attrs[name] for name in self.config_keys}

    def save(self, output_path):
        """Write the module configuration to ``<output_path>/config.json``."""
        config_file = os.path.join(output_path, "config.json")
        with open(config_file, "w") as handle:
            handle.write(json.dumps(self.get_config_dict(), indent=2))

    @classmethod
    def load(cls, input_path):
        """Rebuild a module from ``<input_path>/config.json`` as written by ``save``."""
        config_file = os.path.join(input_path, "config.json")
        with open(config_file) as handle:
            kwargs = json.load(handle)
        return cls(**kwargs)


class DirectionRemover(nn.Module):
    """Remove a fixed direction from sentence embeddings.

    The direction is stored as a non-trainable parameter named ``direction`` so
    that it follows the module across devices and is included in the state dict.

    :param direction:
        The direction to remove. It does not have to be unit length; the
        projection in ``forward`` is normalised by its squared norm.
    """

    def __init__(self, direction: List[float]):
        super(DirectionRemover, self).__init__()
        # `as_tensor(..., dtype=float32)` always interprets the argument as data
        # (the legacy `FloatTensor` constructor treats a bare int as a size);
        # the clone keeps the parameter independent of the caller's buffer.
        direction_tensor = torch.as_tensor(direction, dtype=torch.float32).detach().clone()
        self.direction = nn.Parameter(direction_tensor, requires_grad=False)

    def forward(self, features: Dict[str, torch.Tensor]):
        """Perform direction removal by first projecting the sentence embedding onto
        the direction specified by `self.direction` and then subtracting the projection.

        :param features:
            Feature dict holding a ``"sentence_embedding"`` tensor whose last
            axis matches ``self.direction``.
        :return:
            The same dict object with ``"sentence_embedding"`` replaced.
        """
        sentence_embedding = features["sentence_embedding"]
        direction = self.direction

        # proj_d(x) = (x . d / d . d) d. Dividing by d . d makes the removal exact
        # for non-unit directions; the clamp keeps an all-zero direction a no-op
        # instead of producing 0/0 = NaN.
        squared_norm = torch.sum(direction * direction).clamp_min(torch.finfo(direction.dtype).tiny)
        coefficient = torch.sum(sentence_embedding * direction, dim=-1, keepdim=True) / squared_norm

        features.update({"sentence_embedding": sentence_embedding - coefficient * direction})
        return features

    
def compute_pc(X, npc=1):
    """Return the leading principal components of `X` from a truncated SVD.

    The matrix is decomposed exactly as passed in: DO NOT mean-centre the data
    before calling this.

    :param X:
        The sentence embedding matrix, one data point per row (`X[i, :]`).
    :param npc:
        Number of leading components to compute.
    :return:
        The fitted `components_` array; row `i` is the i-th principal component.
    """
    # n_iter and random_state are pinned so repeated calls on the same X agree.
    return TruncatedSVD(n_components=npc, n_iter=7, random_state=0).fit(X).components_

In [28]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.models import Transformer, Pooling

from src.models import SIFPooling

from datasets import load_dataset, get_dataset_config_names


# Checkpoint name handed to the Transformer / SIFPooling constructors further down.
model_card = "bert-base-cased"

# --- STS-B evaluator ---------------------------------------------------------
# NOTE(review): the evaluator is named "stsb-dev" but reads the "test" split — confirm intent.
data_stsb = load_dataset("mteb/stsbenchmark-sts")
stsb_test = data_stsb["test"]
eval_stsb = EmbeddingSimilarityEvaluator(
    sentences1=stsb_test["sentence1"],
    sentences2=stsb_test["sentence2"],
    scores=stsb_test["score"],
    name="stsb-dev",
    write_csv=False,
    batch_size=512,
)

# --- SICK relatedness evaluator ----------------------------------------------
# NOTE(review): likewise named "sick-dev" while loading split="test".
data_sick = load_dataset("sick", split="test")
eval_sick = EmbeddingSimilarityEvaluator(
    sentences1=data_sick["sentence_A"],
    sentences2=data_sick["sentence_B"],
    scores=data_sick["relatedness_score"],
    name="sick-dev",
    write_csv=False,
    batch_size=512,
)

# Load wiki-text-2 to estimate the word frequencies.
# Each line is stripped once; blank lines and "=" heading lines are dropped.
corpus = [
    line
    for line in (
        raw.strip()
        for raw in load_dataset("wikitext", "wikitext-2-v1", split="train")["text"]
    )
    if line and not line.startswith("=")
]

Using custom data configuration mteb--stsbenchmark-sts-998a21523b45a16a
Found cached dataset json (/home/dogdog/.cache/huggingface/datasets/mteb___json/mteb--stsbenchmark-sts-998a21523b45a16a/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset sick (/home/dogdog/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
Found cached dataset wikitext (/home/dogdog/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [29]:
# Building blocks. The same Transformer instance is reused by all three models.
embedding_layer = Transformer(model_card)
normal_pooling_layer = Pooling(
    embedding_layer.get_word_embedding_dimension(),
    pooling_mode="mean",
)
# Presumably derives word-frequency weights from `corpus` — see src.models.SIFPooling to confirm.
weighted_pooling_layer = SIFPooling.from_corpus_hf(model_card, corpus)
batch_pc_removal_layer = BatchPCRemoval(npc=1)

# model_b: mean pooling; model_w: SIF pooling; model_r: SIF pooling + batch PC removal.
model_b = SentenceTransformer(modules=[embedding_layer, normal_pooling_layer])
sif_modules = [embedding_layer, weighted_pooling_layer]
model_w = SentenceTransformer(modules=list(sif_modules))
model_r = SentenceTransformer(modules=[*sif_modules, batch_pc_removal_layer])

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Token indices sequence length is longer than the specified maximum sequence length for this model (726 > 512). Running this sequence through the model will r

In [33]:
# Score model_b (transformer + mean pooling) with the SICK evaluator; the cell displays the returned score.
eval_sick(model_b, output_path=".")

0.5910148950535766

In [34]:
# Score model_w (transformer + SIFPooling) with the SICK evaluator; the cell displays the returned score.
eval_sick(model_w, output_path=".")

0.6030681625701952

In [35]:
# Score model_r (transformer + SIFPooling + BatchPCRemoval) with the SICK evaluator; the cell displays the returned score.
eval_sick(model_r, output_path=".")

0.6409534386135679