# Compressed BERT Model

This model is a reduced BERT model that is trained with a different objective. The first reduced BERT model was given a reduction head that is learned from the data, and the resulting model had two problems that we hope to solve here:

1. The reduced embeddings did not cluster well. A set of full-size embeddings clustered into around 250 clusters, whereas the corresponding reduced embeddings fell into only 3 groups. Applying further dimensionality reduction to the already reduced embeddings produced more clusters, but the clustering did not appear to be meaningful.
2. After the model was fully trained, only the final reduced embedding size could be used. That is, only the embeddings of size 48 were usable, not embeddings matching the intermediate reduction layer sizes (512, 256, 128, 64).

To address the first problem, we add a contrastive learning objective to the model, which should help the embeddings cluster better. The idea is to pull similar items together and push dissimilar items apart in the embedding space. This would be difficult with unlabeled text alone, but we have the advantage of having the original BERT embeddings to compare against. We can use them to measure how similar two examples are, then train the reduced embeddings so that each pair is as similar (or as dissimilar) in the reduced space as it is in the original space.
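
As a minimal sketch of this objective (not the final training code), the snippet below compares pairwise cosine similarities before and after a reduction. The random tensor and the single `nn.Linear` layer named `reduction_head` are placeholders for real BERT embeddings and the learned reduction head; both are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

def pairwise_cosine(embeddings: torch.Tensor) -> torch.Tensor:
    """Return the cosine similarity of every distinct pair of rows."""
    normed = F.normalize(embeddings, dim=-1)
    sim = normed @ normed.T                        # (N, N) similarity matrix
    rows, cols = torch.triu_indices(sim.size(0), sim.size(0), offset=1)
    return sim[rows, cols]                         # N * (N - 1) / 2 values

torch.manual_seed(0)
full = torch.randn(4, 768)                 # placeholder for original embeddings
reduction_head = nn.Linear(768, 48)        # placeholder for the reduction head
reduced = reduction_head(full)

# Penalize any mismatch between the similarity structure of the two spaces.
contrastive_loss = F.mse_loss(pairwise_cosine(reduced), pairwise_cosine(full))
contrastive_loss.backward()
```

With 4 examples there are 6 distinct pairs, so each similarity vector has 6 entries.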

We can also add a set of decompression layers to the end of the model. This lets us add loss terms to the model that ensure the full-size embeddings can be recovered from the reduced embeddings, which should make the reduced embeddings more meaningful and encourage them to retain the information in the original embeddings.
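
A minimal sketch of the reconstruction term, assuming a single compression layer and a mirrored decompression layer. The plain `nn.Linear` modules and the names `compress` and `decompress` are illustrative only and are not the actual `DimReduce` implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
full = torch.randn(4, 768)             # placeholder for original embeddings

compress = nn.Linear(768, 48)          # hypothetical compression layer
decompress = nn.Linear(48, 768)        # hypothetical decompression layer

reduced = compress(full)
decoded = decompress(reduced)

# Penalize information lost in the round trip through the reduced space.
reconstruction_loss = F.mse_loss(decoded, full)
reconstruction_loss.backward()
```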

The second problem is a bit more difficult to solve, but I see two potential solutions:
1. Train each compression/decompression layer pair separately, freezing the trained layers before adding the next pair to the model. For any choice of intermediate layer size, the compression and decompression layers should then work. However, because the earlier layers are frozen each time another pair is added, the layers cannot adapt to one another, which limits the potential gains from training them jointly.
2. Find a way to get a loss term from each intermediate compression layer to the corresponding decompression layer. This way, we can train the entire model at once and still use the intermediate layers. One option is to run the input through nested subsets of the compression/decompression layers, once for each reduction layer: compression1->decompression1 for the first loss term, then compression1->compression2->decompression2->decompression1 for the next loss term, ..., and compression1->...->compressionN->decompressionN->...->decompression1 for the last loss term. This would increase the computation time significantly, but it could let the model learn the intermediate layers jointly while still allowing each of them to be used independently. We could also weight the loss terms, favoring the first compression/decompression pair early in training and gradually shifting weight toward the deeper pairs as the model trains. Note that outside of training, the model should run through all the compression layers first and then all the decompression layers, so that the intermediate embeddings do not need to be saved.
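
The second option can be sketched roughly as follows. This is only an illustration under stated assumptions: the plain `nn.Linear` layers stand in for the real reduction layers, and `nested_reconstruction_losses` and `depth_weights` are hypothetical helpers, with the linear weighting schedule being just one possible choice. The compression prefix is shared between depths, so only the decompression path is rerun for each loss term.

```python
import torch
import torch.nn.functional as F
from torch import nn

sizes = [768, 512, 256, 128, 64, 48]   # hidden size followed by reduction sizes

# compress[i] maps sizes[i] -> sizes[i + 1]; decompress[i] undoes it.
compress = nn.ModuleList([nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])])
decompress = nn.ModuleList([nn.Linear(b, a) for a, b in zip(sizes[:-1], sizes[1:])])

def nested_reconstruction_losses(x):
    """One reconstruction loss per depth; depth k uses layers 1..k in both directions."""
    losses = []
    hidden = x
    for depth in range(len(compress)):
        hidden = compress[depth](hidden)       # extend the shared compression prefix
        decoded = hidden
        for i in range(depth, -1, -1):         # decompression_k -> ... -> decompression_1
            decoded = decompress[i](decoded)
        losses.append(F.mse_loss(decoded, x))
    return losses

def depth_weights(progress, n_depths):
    """Hypothetical schedule: favor shallow depths at progress=0, deep ones at progress=1."""
    shallow = torch.linspace(1.0, 0.0, n_depths)
    deep = torch.linspace(0.0, 1.0, n_depths)
    weights = (1 - progress) * shallow + progress * deep
    return weights / weights.sum()

torch.manual_seed(0)
x = torch.randn(4, 768)                        # placeholder for original embeddings
losses = nested_reconstruction_losses(x)
weights = depth_weights(progress=0.0, n_depths=len(losses))
total = sum(w * l for w, l in zip(weights, losses))
total.backward()
```

Here `progress` would move from 0 to 1 over the course of training, so the first pair dominates early and the deepest pair dominates late.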

In [1]:
# TODO: Test the intermediate reduction layer embeddings to see how they perform on GLUE tasks compared to the fully reduced model.
# TODO: Determine how to train such that clusters are better
#    - Maybe try using self-supervised contrastive learning, where the full-size embeddings are used as a baseline. (https://encord.com/blog/guide-to-contrastive-learning/#:~:text=NLP%20deals%20with%20the%20processing,semantic%20information%20and%20contextual%20relationships.)
#    - Check out the papers here (https://github.com/ryanzhumich/Contrastive-Learning-NLP-Papers?tab=readme-ov-file#4-contrastive-learning-for-nlp)
# TODO: Check if it's best to train the first reduction layer first, then the second, etc., freezing the previous layers as you go.

# TODO: Fix defaults from BertReducedConfig in the BertReducedForPreTraining class (and potentially other model classes).

In [2]:
from collections import OrderedDict

from torch import nn

from reduced_encoders.modeling_reduced import DimReduceLayer

class Decoder(nn.Sequential):
    """
    Module that includes a sequence of layers that are used to undo the reduction of the input
    embeddings. This module is used in the BertCompressedForPretraining class to train the 
    autoencoder part of that model.

    Args:
        config (PretrainedConfig): Configuration for the base model. Should include the hidden_size
            and reduction_sizes parameters for the dimensions of each layer of this module.
        modules (OrderedDict): An optional ordered dictionary of modules to load the reduction
            layers from. If not specified, the reduction layers will be randomly initialized, 
            using the sizes from the reduction_sizes parameter in the configuration.
    """
    def __init__(self, config, modules=None):
        # Mirror the reduction layers: walk the reduction sizes backwards (skipping the reduced
        # size itself) and finish at the hidden size, e.g. 48 -> 64 -> 128 -> 256 -> 512 -> 768.
        # NOTE: list.reverse() reverses in place and returns None, so build a reversed copy.
        decoder_sizes = list(reversed(config.reduction_sizes))[1:] + [config.hidden_size]

        if modules is None:
            input_size = config.reduced_size
            modules = OrderedDict()
            for i, decoded_size in enumerate(decoder_sizes):
                modules[str(i)] = DimReduceLayer(input_size, decoded_size, config)
                input_size = decoded_size
        elif not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

        super().__init__(modules)
        self.decoder_sizes = decoder_sizes
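
To illustrate the layer sizes such a decoder needs, here is a small sketch that derives them from the config values shown later in this notebook. It assumes the decoder should mirror the reduction layers and end at the hidden size; the variable names are only for illustration.

```python
hidden_size = 768
reduction_sizes = [512, 256, 128, 64, 48]

# list.reverse() reverses in place and returns None, so slice to get a reversed copy.
encoder_dims = [hidden_size] + reduction_sizes
decoder_dims = encoder_dims[::-1]

# Each decoder layer maps one size to the next larger one.
decoder_layers = list(zip(decoder_dims[:-1], decoder_dims[1:]))
print(decoder_layers)
# [(48, 64), (64, 128), (128, 256), (256, 512), (512, 768)]
```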

In [3]:
import torch
import torch.nn.functional as F

from transformers import BertModel
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput

from reduced_encoders import BertReducedPreTrainedModel, DimReduce, BertReducedConfig
from reduced_encoders.models.bert_reduced.modeling_bert_reduced import BertReducedPreTrainingHeads

class BertCompressedForPretraining(BertReducedPreTrainedModel):
    """
    A reduced BERT model used during pretraining. The model has both an MLM and NSP head, and 
    uses a set of decompression layers during training to ensure that the reduced embeddings
    are able to capture all the information from the original embeddings.
    
    Args:
        config (BertReducedConfig): Configuration for the reduced BERT model. 
        base_model: The base BERT model to use. If not specified, a new BERT model will be
            initialized using the config.
        reduce_module: The dimensionality reduction module to use. If not specified, a new
            module will be initialized using the config.
    """
    def __init__(self, config=None, base_model=None, reduce_module=None):
        super().__init__(config)

        self.bert = base_model if base_model is not None else BertModel(self.config)
        self.reduce = reduce_module if reduce_module is not None else DimReduce(self.config) # TODO: Rewrite DimReduce to get intermediate layers
        self.decoder = Decoder(self.config) # The decoder module defined above
        self.cls = BertReducedPreTrainingHeads(self.config)

        self.post_init()

    def _get_similarities(self, embeddings):
        """Return the flattened upper triangular cosine similarity matrix of the given embeddings."""
        cos_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        similarity_matrix = cos_sim(embeddings.unsqueeze(0), embeddings.unsqueeze(1))
        indices = torch.triu_indices(*similarity_matrix.shape, offset=1)
        return similarity_matrix[indices[0], indices[1]]
    
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, 
                labels=None, next_sentence_label=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output, pooled_output = outputs[:2]
        reduced_seq, reduced_pooled = self.reduce(sequence_output), self.reduce(pooled_output) # TODO: Reduce should output intermediate embeddings
        prediction_scores, seq_relationship_score = self.cls(reduced_seq, reduced_pooled)

        print("Pooled output shape:", pooled_output.shape)
        print("Reduced pooled output shape:", reduced_pooled.shape)

        # Cross-entropy loss
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        # Add contrastive loss (MSE between similarity scores of pairs of embeddings -- original vs reduced)
        # TODO: Decide whether to add loss for intermediate embeddings
        similarity = self._get_similarities(pooled_output)
        reduced_similarity = self._get_similarities(reduced_pooled)
        contrastive_loss = F.mse_loss(similarity, reduced_similarity)

        # Add reconstruction loss (MSE of BERT embeddings and decoded reduced embeddings)
        # TODO: Decide whether to add loss for intermediate embeddings
        # NOTE: NEW IDEA -- Don't run each intermediate layer again, just get MSE between intermediate reduced and decoded embeddings
        #                   This would be much faster, and it could be computed all in the same function call for each layer
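        pooled_decoded = self.decoder(reduced_pooled)
        reconstruction_loss = F.mse_loss(pooled_decoded, pooled_output)

        # Fold the auxiliary losses into the training loss (unweighted sum for now)
        if total_loss is not None:
            total_loss = total_loss + contrastive_loss + reconstruction_loss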

        if not return_dict:
            output = (prediction_scores, seq_relationship_score) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [4]:
import torch
from typing import Optional, Tuple
from dataclasses import dataclass
from transformers.utils import ModelOutput

@dataclass
class CompressedModelForPreTrainingOutput(ModelOutput):
    """
    Output type of [`MPNetCompressedForPretraining`].

    Args:
        loss (torch.FloatTensor): Linear combination of the contrastive learning loss and the reconstruction loss. The 
            contrastive learning loss is the MSE between the cosine similarities of data pairs in the original and 
            reduced embeddings. The reconstruction loss is the MSE between the original and decoded reduced embeddings.
        hidden_states (tuple(torch.FloatTensor)): Hidden states of the model at the output of each layer.
        attentions (tuple(torch.FloatTensor)): Attention weights after the attention softmax, used to compute the weighted 
            average in the self-attention heads.
    """
    loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [5]:
import torch
from torch.nn import functional as F

from transformers import MPNetModel
from reduced_encoders import MPNetReducedPreTrainedModel, DimReduce
from reduced_encoders.models.mpnet_reduced.modeling_sbert import SBertPooler

class MPNetCompressedForPretraining(MPNetReducedPreTrainedModel):
    def __init__(self, config=None, base_model=None, reduce_module=None, **kwargs):
        super().__init__(config)

        kwargs['add_pooling_layer'] = False     # We use our own pooling instead
        self.mpnet = base_model or MPNetModel(self.config, **kwargs)
        self.pooler = SBertPooler(self.config)
        self.reduce = reduce_module or DimReduce(self.config)

    def _get_similarities(self, embeddings):
        """Return the flattened upper triangular cosine similarity matrix of the given embeddings."""
        cos_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        similarity_matrix = cos_sim(embeddings.unsqueeze(0), embeddings.unsqueeze(1))
        indices = torch.triu_indices(*similarity_matrix.shape, offset=1)
        return similarity_matrix[indices[0], indices[1]]
        
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, 
                inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mpnet(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]
        pooled_output = self.pooler(sequence_output, attention_mask)  
        reduced_pooled = self.reduce(pooled_output)

        # TODO: There are ways to compute the loss at each layer of the reduction, is that something possible/something we want to do?

        # Compute contrastive loss
        full_similarity = self._get_similarities(pooled_output)
        reduced_similarity = self._get_similarities(reduced_pooled)
        contrastive_loss = F.mse_loss(full_similarity, reduced_similarity)
        print(full_similarity)
        print(reduced_similarity)

        # Compute reconstruction loss
        # TODO: Decide whether to implement this loss
        reconstruction_loss = 0     

        # Compute total loss
        loss = contrastive_loss + reconstruction_loss

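        if not return_dict:
            # Mirror the fields of CompressedModelForPreTrainingOutput, dropping any that are None
            fields = (loss, outputs.hidden_states, outputs.attentions)
            return tuple(field for field in fields if field is not None)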

        return CompressedModelForPreTrainingOutput(
            loss=loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [6]:
from transformers import AutoTokenizer
from transformers import AutoModel

checkpoint = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
mpnet = AutoModel.from_pretrained(checkpoint, add_pooling_layer=False)

In [7]:
from transformers import AutoConfig
from reduced_encoders import MPNetReducedConfig


base_config = AutoConfig.from_pretrained(checkpoint)
config = MPNetReducedConfig.from_config(base_config, reduction_sizes=[512,256,128,64,48])
config

MPNetReducedConfig {
  "_name_or_path": "sentence-transformers/all-mpnet-base-v2",
  "architectures": [
    "MPNetForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "mpnet_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "pooling_mode": "mean",
  "reduced_size": 48,
  "reduction_sizes": [
    512,
    256,
    128,
    64,
    48
  ],
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.31.0.dev0",
  "vocab_size": 30527
}

In [8]:
compressed_model = MPNetCompressedForPretraining(config=config, base_model=mpnet)
compressed_model

MPNetCompressedForPretraining(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [17]:
from reduced_encoders.debug_utils import compare_weights

compare_weights(mpnet, compressed_model.mpnet)

True

In [9]:
text = ['This is a test sentence that is meant to determine whether I can run text through my new compressed SBERT model. Did it work?',
        'This is also a test sentence, but it is different from the first one. I hope this works!',
        'A feral cat walked down the street, hoping to find a place to rest for the night',
        'The last sentence was significantly different from the others to see where the embedding lands']

In [10]:
inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

In [11]:
with torch.no_grad():
    outputs = compressed_model(**inputs)

tensor([0.4352, 0.0172, 0.3200, 0.0596, 0.3839, 0.0697])
tensor([0.9052, 0.8000, 0.7335, 0.8493, 0.7925, 0.7880])


In [12]:
outputs

CompressedModelForPreTrainingOutput(loss=tensor(0.3852), hidden_states=None, attentions=None)
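
As a sanity check, the reported loss can be reproduced by hand from the two tensors printed above. Since the reconstruction loss is currently 0, the total loss should equal the MSE between the full-size similarities (first tensor) and the reduced similarities (second tensor). The snippet below uses plain Python and the printed, rounded values, so it only matches to about four decimal places.

```python
full_similarity = [0.4352, 0.0172, 0.3200, 0.0596, 0.3839, 0.0697]
reduced_similarity = [0.9052, 0.8000, 0.7335, 0.8493, 0.7925, 0.7880]

# Mean squared error over the 6 sentence pairs
squared_errors = [(f - r) ** 2 for f, r in zip(full_similarity, reduced_similarity)]
mse = sum(squared_errors) / len(squared_errors)
print(round(mse, 4))  # 0.3852
```

The reduction head in this run was created from the config rather than loaded from a trained checkpoint, so the large gap between the two similarity vectors is what this loss is meant to shrink during training.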

In [13]:
compressed_model.config.architectures = [compressed_model.__class__.__name__]
compressed_model.config

MPNetReducedConfig {
  "_name_or_path": "sentence-transformers/all-mpnet-base-v2",
  "architectures": [
    "MPNetCompressedForPretraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "mpnet_reduced",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "pooling_mode": "mean",
  "reduced_size": 48,
  "reduction_sizes": [
    512,
    256,
    128,
    64,
    48
  ],
  "relative_attention_num_buckets": 32,
  "transformers_version": "4.31.0.dev0",
  "vocab_size": 30527
}

In [15]:
from huggingface_hub import notebook_login

notebook_login()


In [16]:
compressed_checkpoint = "all-mpnet-base-v2-compressed"
compressed_model.push_to_hub(compressed_checkpoint)
tokenizer.push_to_hub(compressed_checkpoint)


CommitInfo(commit_url='https://huggingface.co/cayjobla/all-mpnet-base-v2-compressed/commit/805dcba348f4af29642076a2c2c4266593448b4c', commit_message='Upload tokenizer', commit_description='', oid='805dcba348f4af29642076a2c2c4266593448b4c', pr_url=None, pr_revision=None, pr_num=None)