In [23]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import numpy as np

class CrossEncoderCSR(nn.Module):
    def __init__(
        self,
        model_name="roberta-base",
    ):
        super().__init__()
        
        # Initialize encoders
        self.encoder = AutoModel.from_pretrained(model_name)
        self.encoder_config = self.encoder.config
        self.total_layers = len(self.encoder.encoder.layer)

        self.d_k = self.encoder_config.hidden_size

        self.W_pseduo_q = nn.Linear(self.d_k, self.d_k)

        self.W_pseduo_k = nn.Linear(self.d_k, self.d_k)

    def forward(self, s1_input_ids, s1_attention_mask, s2_input_ids, s2_attention_mask, c_input_ids, c_attention_mask):
        # Encode context
        c = self.encoder(c_input_ids, c_attention_mask, output_attentions=True)

        # [CLS] <- (1, d_c)
        cls = c[0][:, 0, :]

        q_cls = self.W_pseduo_q(cls)

        # print("---DEBUG---")
        # print("CLS: ", q_cls.shape)

        # Start with token embeddings
        s_1_h = self.encoder.embeddings(s1_input_ids)
        s_2_h = self.encoder.embeddings(s2_input_ids)

        # Process through first half of layers
        for i in range(self.total_layers // 2):
            layer = self.encoder.encoder.layer[i]
            s_1_h = layer(s_1_h, attention_mask=s1_attention_mask)[0]
            s_2_h = layer(s_2_h, attention_mask=s2_attention_mask)[0]
            
        k_s_1 = self.W_pseduo_k(s_1_h) 
        k_s_2 = self.W_pseduo_k(s_2_h)
        
        # Compute router weights
        w_s_1 = self.get_c_router_weights(q_cls, k_s_1)
        w_s_2 = self.get_c_router_weights(q_cls, k_s_2)

        # Apply weights
        s_1_h = (w_s_1.transpose(1, 2) + 1) * s_1_h
        s_2_h = (w_s_2.transpose(1, 2) + 1) * s_2_h

        # Process through second half of layers
        for i in range(self.total_layers // 2, self.total_layers):
            layer = self.encoder.encoder.layer[i]
            s_1_h = layer(s_1_h, attention_mask=s1_attention_mask)[0]
            s_2_h = layer(s_2_h, attention_mask=s2_attention_mask)[0]

        # return s_1_hidden_state, s_2_hidden_state

        average_pooling = nn.AdaptiveAvgPool1d(1)  

        rs_c_1 = average_pooling(s_1_h.squeeze(0)).squeeze(-1)  # Shape: (786,)
        rs_c_2 = average_pooling(s_2_h.squeeze(0)).squeeze(-1)  # Shape: (786,)      

        return rs_c_1, rs_c_2
    def get_c_router_weights(self, q_c, k_s):
        """
        Compute attention weights using the formula from the paper
        score = q_c * k_s / sqrt(d_k)
        w = softmax(score)
        """
        # Compute similarity scores
        scores = torch.matmul(q_c, k_s.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k).float())
        
        # Apply softmax to get weights
        weights = nn.functional.softmax(scores, dim=-1)
        return weights

def test_cross_encoder_csr():
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    
    # Create model
    model = CrossEncoderCSR(model_name="roberta-base")
    
    # Prepare test inputs
    sentences1 = ["The cat sits on the mat Machine learning is fascinating Hey Hey Hey"]
    sentences2 = ["A feline rests on a carpet AI is an exciting field Ha Ha Ha"]
    context = ["pets technology"]
    
    # Tokenize inputs
    s1_inputs = tokenizer(sentences1, padding='max_length', truncation=True, return_tensors="pt", max_length=512)
    s2_inputs = tokenizer(sentences2, padding='max_length', truncation=True, return_tensors="pt", max_length=512)
    c_inputs = tokenizer(context, padding='max_length', truncation=True, return_tensors="pt", max_length=512)
    
    # Forward pass
    with torch.no_grad():
        s1_hidden, s2_hidden = model(
            s1_input_ids=s1_inputs['input_ids'], 
            s1_attention_mask=s1_inputs['attention_mask'].to(dtype=torch.float),
            s2_input_ids=s2_inputs['input_ids'], 
            s2_attention_mask=s2_inputs['attention_mask'].to(dtype=torch.float),
            c_input_ids=c_inputs['input_ids'], 
            c_attention_mask=c_inputs['attention_mask'].to(dtype=torch.float)
        )
    
    print("S1 Hidden States Shape:", s1_hidden)
    print("S2 Hidden States Shape:", s2_hidden)
    
    # Optional: Compute cosine similarity
    def cosine_similarity(a, b):
        return torch.nn.functional.cosine_similarity(a, b, dim=-1)
    
    # Compute and print similarity
    for i in range(len(sentences1)):
        sim = cosine_similarity(s1_hidden[i], s2_hidden[i])
        print(f"Similarity between '{sentences1[i]}' and '{sentences2[i]}': {sim.item()}")

if __name__ == "__main__":
    test_cross_encoder_csr()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


S1 Hidden States Shape: tensor([0.0170, 0.0162, 0.0150, 0.0162, 0.0155, 0.0169, 0.0149, 0.0171, 0.0175,
        0.0169, 0.0154, 0.0187, 0.0187, 0.0187, 0.0187, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157, 0.0157,
        0.0157, 

## Fix

In [8]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import numpy as np

class CrossEncoderCSR(nn.Module):
    def __init__(self, model_name="roberta-base"):
        super().__init__()
        
        # Initialize encoders
        self.encoder = AutoModel.from_pretrained(model_name)
        self.encoder_config = self.encoder.config
        self.total_layers = len(self.encoder.encoder.layer)
        self.d_k = self.encoder_config.hidden_size

        # Linear projections for pseudo query and key
        self.W_pseduo_q = nn.Linear(self.d_k, self.d_k)
        self.W_pseduo_k = nn.Linear(self.d_k, self.d_k)

    def forward(self, s1_input_ids, s1_attention_mask, s2_input_ids, s2_attention_mask, c_input_ids, c_attention_mask):
        # Encode context and get [CLS] token embedding
        c = self.encoder(c_input_ids, c_attention_mask)
        cls = c.last_hidden_state[:, 0, :]  # Shape: (batch_size, hidden_dim)
        q_cls = self.W_pseduo_q(cls)

        # Embed sentences using embeddings
        s_1_h = self.encoder.embeddings(s1_input_ids)
        s_2_h = self.encoder.embeddings(s2_input_ids)

        # Process through first half of layers
        for i in range(self.total_layers // 2):
            layer = self.encoder.encoder.layer[i]
            s_1_h = layer(s_1_h, attention_mask=s1_attention_mask.unsqueeze(1).unsqueeze(2))[0]
            s_2_h = layer(s_2_h, attention_mask=s2_attention_mask.unsqueeze(1).unsqueeze(2))[0]
        
        # Project keys and compute router weights
        k_s_1 = self.W_pseduo_k(s_1_h)
        k_s_2 = self.W_pseduo_k(s_2_h)
        w_s_1 = self.get_c_router_weights(q_cls, k_s_1)
        w_s_2 = self.get_c_router_weights(q_cls, k_s_2)

        # Apply weights
        s_1_h = (w_s_1 + 1).transpose(1, 2) * s_1_h
        s_2_h = (w_s_2 + 1).transpose(1, 2) * s_2_h

        # Process through second half of layers
        for i in range(self.total_layers // 2, self.total_layers):
            layer = self.encoder.encoder.layer[i]
            s_1_h = layer(s_1_h, attention_mask=s1_attention_mask.unsqueeze(1).unsqueeze(2))[0]
            s_2_h = layer(s_2_h, attention_mask=s2_attention_mask.unsqueeze(1).unsqueeze(2))[0]

        # Apply average pooling
        rs_c_1 = s_1_h.mean(dim=1)  # Shape: (batch_size, hidden_dim)
        rs_c_2 = s_2_h.mean(dim=1)  # Shape: (batch_size, hidden_dim)

        return rs_c_1, rs_c_2

    def get_c_router_weights(self, q_c, k_s):
        """
        Compute attention weights using the formula:
        score = q_c * k_s / sqrt(d_k)
        w = softmax(score)
        """
        scores = torch.matmul(q_c.unsqueeze(1), k_s.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k).float().to(q_c.device))
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return weights

def test_cross_encoder_csr():
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    
    # Create model
    model = CrossEncoderCSR(model_name="roberta-base")
    
    # Prepare test inputs
    sentences1 = ["Da Love"]
    sentences2 = ["Da Love"]
    context = ["What"]

    # Tokenize inputs
    s1_inputs = tokenizer(sentences1, padding="max_length", truncation=True, return_tensors="pt", max_length=128)
    s2_inputs = tokenizer(sentences2, padding="max_length", truncation=True, return_tensors="pt", max_length=128)
    c_inputs = tokenizer(context, padding="max_length", truncation=True, return_tensors="pt", max_length=128)

    # Forward pass
    with torch.no_grad():
        s1_hidden, s2_hidden = model(
            s1_input_ids=s1_inputs['input_ids'], 
            s1_attention_mask=s1_inputs['attention_mask'].to(dtype=torch.float),
            s2_input_ids=s2_inputs['input_ids'], 
            s2_attention_mask=s2_inputs['attention_mask'].to(dtype=torch.float),
            c_input_ids=c_inputs['input_ids'], 
            c_attention_mask=c_inputs['attention_mask'].to(dtype=torch.float)
        )
    
    # Print results
    print("S1 Hidden States Shape:", s1_hidden.shape)
    print("S2 Hidden States Shape:", s2_hidden.shape)
    
    # Compute cosine similarity
    def cosine_similarity(a, b):
        a = a / a.norm(dim=-1, keepdim=True)
        b = b / b.norm(dim=-1, keepdim=True)
        return torch.sum(a * b, dim=-1)
    
    sim = cosine_similarity(s1_hidden, s2_hidden)
    print(f"Similarity: {sim.item()}")

if __name__ == "__main__":
    test_cross_encoder_csr()


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


S1 Hidden States Shape: torch.Size([1, 768])
S2 Hidden States Shape: torch.Size([1, 768])
Similarity: 1.000000238418579
