In [27]:
import copy

import numpy as np
import torch
from datasets import load_dataset, Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchinfo import summary

### My pooling method

In [44]:
class learnable_gaussian(nn.Module):
    '''
    Gaussian bump exp(-(x - center)^2 / (2 * var)) whose width is learned.

    Args:
        center: fixed (non-learnable) position of the peak.
        var: initial variance; stored as a one-element nn.Parameter so the optimiser can
            widen or narrow the bump.

    NOTE: var is unconstrained. If training drives it to 0 or below, the output becomes
    NaN/inf or grows instead of decaying.
    '''
    def __init__(self, center, var = 128):
        super().__init__()
        self.center = center
        self.var = nn.Parameter(torch.Tensor([var]))

    def forward(self, x):
        offset = x - self.center
        return torch.exp(-(offset ** 2) / (2 * self.var))

class SwiGLU(nn.Module):
    '''
    SwiGLU feed-forward layer: output_layer(silu(fc1(x)) * fc2(x)), acting on the last dimension.

    Args:
        input_dim: size of the last dimension of x.
        hidden_dim: width of the gated hidden representation.
        output_dim: size of the last dimension of the result.
        use_layernorm: if True, apply LayerNorm over the output features.
    '''
    def __init__(self, input_dim, hidden_dim, output_dim, use_layernorm = False):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)   # gate branch (passed through SiLU)
        self.fc2 = nn.Linear(input_dim, hidden_dim)   # value branch
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        self.use_layernorm = use_layernorm
        if self.use_layernorm:
            # The norm is applied to the projected output, so it must be sized to output_dim,
            # not hidden_dim (the two differ in general and a mismatch fails at forward time).
            self.ln = nn.LayerNorm(output_dim)

    def forward(self, x):
        output = self.output_layer(F.silu(self.fc1(x)) * self.fc2(x))
        return self.ln(output) if self.use_layernorm else output

class SwiGLUNetwork(nn.Module):
    '''
    Stack of SwiGLU layers mapping the last dimension from `input_dim` down to 1.

    Widths step down through the multiples of 128 below `input_dim` and then to `mid_dim`,
    e.g. input_dim=512 gives 512 -> 384 -> 256 -> 128 -> 64 -> 1. When input_dim is not a
    multiple of 128 the largest multiple below it is skipped (300 -> 128 -> 64 -> 1), and
    inputs narrower than 128 go straight to mid_dim (100 -> 64 -> 1).

    Args:
        input_dim: size of the last dimension of the input (must be positive).
        mid_dim: width of the last hidden layer before the 1-wide output.
        use_layernorm: accepted for API compatibility but not applied to any layer.
            NOTE(review): wiring it through would change the default (True) architecture
            and existing checkpoints.

    Raises:
        ValueError: if input_dim < 1.
    '''
    def __init__(self, input_dim, mid_dim = 64, use_layernorm = True):
        super().__init__()
        if input_dim < 1:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        self.layers = nn.ModuleList()

        # Multiples of 128 in [0, input_dim]; the first entry becomes mid_dim, the last input_dim.
        depth_dim = list(range(0, input_dim + 1, 128))
        if len(depth_dim) < 2:
            # input_dim < 128 leaves only [0]; without a second slot both ends would collapse into
            # one entry and the final SwiGLU(mid_dim, ...) would receive input_dim features.
            depth_dim.append(input_dim)
        depth_dim[0], depth_dim[-1] = mid_dim, input_dim
        depth_dim.reverse()  # widest first

        for in_dim, out_dim in zip(depth_dim[:-1], depth_dim[1:]):
            self.layers.append(SwiGLU(in_dim, in_dim, out_dim))
        # Final projection to a single value per row
        self.layers.append(SwiGLU(mid_dim, mid_dim, 1, use_layernorm = False))

    def forward(self, x):
        '''Apply the layers in order; the last dimension goes from input_dim to 1.'''
        for layer in self.layers:
            x = layer(x)
        return x


class my_pooling(nn.Module):
    '''
    Mean pooling plus a small learned correction built from fixed-length resamplings of each sentence.

    Each sentence's real (un-padded) tokens are linearly resampled along the token axis to 512, 384,
    256 and 128 positions. Each resampling goes through its own SwiGLUNetwork, which collapses the
    token axis to one value per hidden channel. It is then weighted by a learnable Gaussian of the
    sentence length centred on that resampling size. The sum of the four branches, scaled by
    `unconventional_pooling_weight` (initialised to 0.01), is added to the L2-normalised mean
    embedding, and the result is L2-normalised again.
    '''
    def __init__(self):
        super().__init__()
        self.ffn_512 = SwiGLUNetwork(512)
        self.ffn_384 = SwiGLUNetwork(384)
        self.ffn_256 = SwiGLUNetwork(256)
        self.ffn_128 = SwiGLUNetwork(128)

        self.gaussian_512 = learnable_gaussian(512)
        self.gaussian_384 = learnable_gaussian(384)
        self.gaussian_256 = learnable_gaussian(256)
        self.gaussian_128 = learnable_gaussian(128)

        # Small initial value so the untrained model stays close to plain mean pooling
        self.unconventional_pooling_weight = nn.Parameter(data =  torch.Tensor([0.01]))

    @staticmethod
    def linear_interpolate_sentence(masked_embeddings_tranposed, embed_length, sizes = (512, 384, 256, 128)):
        '''
        Linearly resample each sentence's real tokens to several fixed lengths.

        Args:
            masked_embeddings_tranposed: (batch, hidden, seq) embeddings with the token axis last.
                Only the first embed_length[i] positions of row i are read, so padding is assumed
                to be on the right. NOTE(review): confirm for the tokenizer in use.
            embed_length: number of real tokens per row (tensor or sequence of ints).
            sizes: target lengths; one output is returned per entry.

        Returns:
            tuple with one (batch, hidden, size) tensor per entry of `sizes`, on the same
            device and dtype as the input.
        '''
        batch, hidden = masked_embeddings_tranposed.shape[:2]
        if batch == 0:
            return tuple(masked_embeddings_tranposed.new_empty((0, hidden, size)) for size in sizes)

        # One device->host transfer for all lengths instead of one per sentence
        if torch.is_tensor(embed_length):
            lengths = embed_length.tolist()
        else:
            lengths = [int(n) for n in embed_length]

        resampled = [[] for _ in sizes]
        for sentence, length in zip(masked_embeddings_tranposed, lengths):
            real_tokens = sentence[:, :length].unsqueeze(0)  # (1, hidden, length)
            for bucket, size in zip(resampled, sizes):
                bucket.append(F.interpolate(real_tokens, size = size, mode = 'linear'))
        # A single cat per size avoids re-copying a growing tensor on every iteration and
        # keeps the result on the input's device and dtype.
        return tuple(torch.cat(bucket, dim = 0) for bucket in resampled)

    def forward(self, features):
        '''
        Pool token embeddings into one L2-normalised sentence embedding.

        Reads features["token_embeddings"] (batch, seq, hidden) and features["attention_mask"]
        (batch, seq); writes features["sentence_embedding"] (batch, hidden) and returns the
        same dict.
        '''
        token_embeddings = features["token_embeddings"]
        attention_mask = features["attention_mask"]
        embed_length = attention_mask.sum(1)  # real (non-padding) tokens per sentence

        # Masked mean pooling, normalised to unit length
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
        masked_embeddings = token_embeddings * input_mask_expanded
        mean_embeddings = torch.sum(masked_embeddings, 1)/torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        mean_embeddings = F.normalize(mean_embeddings, p=2, dim=1)

        # Token axis last so the SwiGLU networks' Linear layers act along the (resampled) sequence;
        # padding positions are dropped by the per-sentence length slice in linear_interpolate_sentence.
        fnn_raw_embeddings = torch.transpose(token_embeddings, -2, -1)
        fnn_512_embeddings, fnn_384_embeddings, fnn_256_embeddings, fnn_128_embeddings = self.linear_interpolate_sentence(fnn_raw_embeddings, embed_length)

        sum_unconv_embed = (
            self.ffn_512(fnn_512_embeddings) * self.gaussian_512(embed_length)[:, None, None] +
            self.ffn_384(fnn_384_embeddings) * self.gaussian_384(embed_length)[:, None, None] +
            self.ffn_256(fnn_256_embeddings) * self.gaussian_256(embed_length)[:, None, None] +
            self.ffn_128(fnn_128_embeddings) * self.gaussian_128(embed_length)[:, None, None])

        # sum_unconv_embed is (batch, hidden, 1): drop only the trailing axis so the batch axis
        # survives even for a batch of one.
        output_mod = (sum_unconv_embed * self.unconventional_pooling_weight).squeeze(-1)
        output = F.normalize(mean_embeddings + output_mod, p=2, dim=1)

        features.update({"sentence_embedding": output})
        return features

### Let's test and see that it runs

In [45]:
# Baseline: the stock pretrained model; sentence_trans[0] (its transformer module) is reused below.
# eval() switches every submodule to inference mode (e.g. dropout off) so the comparison is deterministic.
sentence_trans = SentenceTransformer('sentence-transformers/multi-qa-distilbert-cos-v1')
sentence_trans.eval();

In [46]:
# The module objects passed via `modules=` are registered as-is (not copied). Giving the custom model
# its own copy of the transformer keeps any later fine-tuning of my_pooling_model from silently
# changing the sentence_trans baseline as well. The copy starts with identical weights.
unconv_pooling = my_pooling()
my_pooling_model = SentenceTransformer(modules=[copy.deepcopy(sentence_trans[0]), unconv_pooling])
my_pooling_model.eval();

In [41]:
# CQADupStack physics corpus from MTEB; only the "corpus" configuration is needed here
dataset = load_dataset(path="mteb/cqadupstack-physics", name="corpus")

In [42]:
# Slice the rows first, then pick the column: only five documents are read instead of the
# whole 'text' column (older `datasets` versions materialise the full column for ds['text']).
example_sentence = dataset['corpus'][:5]['text']
example_sentence

["Let's discuss about $SU(3)$. I understand that the most important representations (relevant to physics) are the defining and the adjoint. In the defining representation of $SU(3)$; namely $\\mathbf{3}$, the Gell-Mann matrices are used to represent the generators $$ \\left[T^{A}\\right]_{ij} = \\dfrac{1}{2}\\lambda^{A}, $$ where $T^A$ are the generators and $\\lambda^A$ the Gell-Mann matrices. In adjoint representation, on the other hand, an $\\mathbf{8}$, the generators are represented by matrices according to $$ \\left[ T_{i} \\right]_{jk} = -if_{ijk}, $$ where $f_{ijk}$ are the structure constants. My question is this, how can one represent the generators in the $\\mathbf{10}$ of $SU(3)$, which corresponds to a symmetric tensor with 3 upper or lower indices (or for that matter how to represent the $\\mathbf{6}$ with two symmetric indices). What is the general procedure to represent the generators in an arbitrary representation?",
 'So in the context of a set of notes I am reading a

In [32]:
with torch.no_grad():
    # Tokenize once per model: my_pooling.forward writes into and returns the dict it is given
    # (the stock modules presumably do the same), so one shared dict would be overwritten by the
    # second call.
    tokenized, my_tokenized = (model[0].tokenize(example_sentence) for model in (sentence_trans, my_pooling_model))
    built_in_out = sentence_trans(tokenized)
    my_out = my_pooling_model(my_tokenized)

Ignore the cell above (manual tokenize + forward pass); use the `encode` cell below to get the embeddings.

In [47]:
with torch.no_grad():
    # One embedding per example text from each model (numpy arrays, as the output below shows)
    built_in_embedding, my_embedding = (model.encode(example_sentence) for model in (sentence_trans, my_pooling_model))


We see that the two embeddings are almost the same. The custom pooling starts out as only a small modification, since the learnable weight is initialized to 0.01. I expect that training will make the modification much larger.

In [48]:
# Element-wise difference between the baseline and the custom-pooling embeddings
np.subtract(built_in_embedding, my_embedding)

array([[9.9465251e-06, 9.9614263e-06, 9.9707395e-06, ..., 9.8515302e-06,
        9.8710880e-06, 9.8878518e-06],
       [2.4335459e-05, 2.4091452e-05, 2.3666769e-05, ..., 2.3588538e-05,
        2.4154317e-05, 2.4106121e-05],
       [1.6091019e-04, 1.6538799e-04, 1.6475096e-04, ..., 1.6450509e-04,
        1.6423408e-04, 1.6508624e-04],
       [2.4216250e-05, 2.4129637e-05, 2.3532659e-05, ..., 2.4538487e-05,
        2.3966655e-05, 2.4194364e-05],
       [5.8986247e-05, 5.8498234e-05, 5.8483332e-05, ..., 5.8760867e-05,
        5.8815815e-05, 5.8639795e-05]], dtype=float32)

My pooling adds roughly 2.2 million parameters, almost all of them in the four SwiGLU networks.

In [28]:
# Parameter counts per submodule (torchinfo)
summary(model=my_pooling_model)

Layer (type:depth-idx)                                       Param #
SentenceTransformer                                          --
├─Transformer: 1-1                                           --
│    └─DistilBertModel: 2-1                                  --
│    │    └─Embeddings: 3-1                                  23,835,648
│    │    └─Transformer: 3-2                                 42,527,232
├─my_pooling: 1-2                                            1
│    └─SwiGLUNetwork: 2-2                                    --
│    │    └─ModuleList: 3-3                                  1,330,689
│    └─SwiGLUNetwork: 2-3                                    --
│    │    └─ModuleList: 3-4                                  608,385
│    └─SwiGLUNetwork: 2-4                                    --
│    │    └─ModuleList: 3-5                                  214,145
│    └─SwiGLUNetwork: 2-5                                    --
│    │    └─ModuleList: 3-6                                  49,665