In [1]:
# figure out what's going on with ColBERT scores

from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional as F
from pylate.utils.tensor import convert_to_tensor


def colbert_scores(
    queries_embeddings: list | np.ndarray | torch.Tensor,
    documents_embeddings: list | np.ndarray | torch.Tensor,
    queries_mask: torch.Tensor = None,
    documents_mask: torch.Tensor = None,
) -> torch.Tensor:

    queries_embeddings = convert_to_tensor(queries_embeddings)
    documents_embeddings = convert_to_tensor(documents_embeddings)

    scores = torch.einsum(
        "ash,bth->abst",
        queries_embeddings,
        documents_embeddings,
    )

    if queries_mask is not None:
        queries_mask = convert_to_tensor(queries_mask)
        scores = scores * queries_mask.unsqueeze(1).unsqueeze(3)

    if documents_mask is not None:
        documents_mask = convert_to_tensor(documents_mask)
        scores = scores * documents_mask.unsqueeze(0).unsqueeze(2)
    scores = scores.max(axis=-1).values.max(axis=-1)
    return scores

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch

def compute_avg_pairwise_sim(embeddings):
    B, T, D = embeddings.shape
    if T == 1:
        return torch.tensor([0.0])
    
    # Identify non-zero vectors (any non-zero element in last dim)
    non_zero_mask = (embeddings.abs().sum(dim=-1) > 0)  # (B, T)
    
    # Compute similarity matrix
    sim = embeddings @ embeddings.transpose(-1, -2)  # (B, T, T)
    
    # Create mask for valid pairs (both vectors non-zero, excluding diagonal)
    valid_pairs = non_zero_mask.unsqueeze(-1) & non_zero_mask.unsqueeze(-2)  # (B, T, T)
    valid_pairs = valid_pairs & ~torch.eye(T, dtype=bool, device=embeddings.device)
    
    # Count valid pairs per batch
    num_valid = valid_pairs.sum(dim=(1, 2))  # (B,)
    
    # Sum similarities for valid pairs
    masked_sim = torch.where(valid_pairs, sim, torch.zeros_like(sim))
    sum_sim = masked_sim.sum(dim=(1, 2))  # (B,)
    
    # Compute average, handling case where no valid pairs exist
    avg_sim = torch.where(num_valid > 0, sum_sim / num_valid, torch.zeros_like(sum_sim))
    
    return avg_sim


def test_compute_avg_pairwise_sim():
    print("Test 1: No zero vectors - should match original behavior")
    embeddings = torch.randn(2, 4, 8)
    result = compute_avg_pairwise_sim(embeddings)
    print(f"Shape: {result.shape}, Values: {result}")
    
    # Manual verification for first batch
    sim = embeddings[0] @ embeddings[0].T
    manual = sim[~torch.eye(4, dtype=bool)].mean()
    print(f"Manual calculation for batch 0: {manual}")
    print(f"Match: {torch.allclose(result[0], manual)}\n")
    
    
    print("Test 2: Some zero vectors")
    embeddings = torch.tensor([
        [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [1.0, 1.0]],  # One zero vector
        [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]]   # Two zero vectors
    ])
    result = compute_avg_pairwise_sim(embeddings)
    print(f"Result: {result}")
    
    # Manual verification for batch 0: vectors 0,1,3 are non-zero
    # Pairs: (0,1)=0, (0,3)=1, (1,3)=1 -> avg = 2/3
    v = embeddings[0][[0,1,3]]
    sim = v @ v.T
    manual = sim[~torch.eye(3, dtype=bool)].mean()
    print(f"Expected for batch 0: {manual}")
    print(f"Match: {torch.allclose(result[0], manual)}\n")
    
    
    print("Test 3: All zero vectors")
    embeddings = torch.zeros(2, 3, 5)
    result = compute_avg_pairwise_sim(embeddings)
    print(f"Result (should be zeros): {result}")
    print(f"All zeros: {torch.all(result == 0)}\n")
    
    
    print("Test 4: Only one non-zero vector per batch")
    embeddings = torch.tensor([
        [[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]],
        [[0.0, 0.0], [3.0, 4.0], [0.0, 0.0]]
    ])
    result = compute_avg_pairwise_sim(embeddings)
    print(f"Result (should be zeros - no pairs): {result}")
    print(f"All zeros: {torch.all(result == 0)}\n")
    
    
    print("Test 5: T=1 edge case")
    embeddings = torch.randn(2, 1, 5)
    result = compute_avg_pairwise_sim(embeddings)
    print(f"Result for T=1: {result}\n")
    
    
    print("Test 6: Differentiability test")
    embeddings = torch.tensor([
        [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]
    ], requires_grad=True)
    result = compute_avg_pairwise_sim(embeddings)
    loss = result.sum()
    loss.backward()
    print(f"Result: {result}")
    print(f"Gradient shape: {embeddings.grad.shape}")
    print(f"Gradient (should be non-zero for non-zero vectors):\n{embeddings.grad}")
    print(f"Zero vector gradient (should be zero): {embeddings.grad[0, 2]}")
    print(f"Is differentiable: {embeddings.grad is not None}\n")
    
    
    print("Test 7: Mixed scenario")
    embeddings = torch.tensor([
        [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]],
    ])
    result = compute_avg_pairwise_sim(embeddings)
    # Non-zero vectors: v0=[1,0], v1=[0,1], v2=[1,1]
    # Similarities: v0·v1=0, v0·v2=1, v1·v2=1
    # Average: (0 + 1 + 1) / 3 = 0.667
    print(f"Result: {result}")
    print(f"Expected: ~0.667")
    print(f"Match: {torch.allclose(result, torch.tensor([2/3]))}\n")

test_compute_avg_pairwise_sim()

Test 1: No zero vectors - should match original behavior
Shape: torch.Size([2]), Values: tensor([-0.6790, -0.1605])
Manual calculation for batch 0: -0.6789754033088684
Match: True

Test 2: Some zero vectors
Result: tensor([0.6667, 0.0000])
Expected for batch 0: 0.6666666865348816
Match: True

Test 3: All zero vectors
Result (should be zeros): tensor([0., 0.])
All zeros: True

Test 4: Only one non-zero vector per batch
Result (should be zeros - no pairs): tensor([0., 0.])
All zeros: True

Test 5: T=1 edge case
Result for T=1: tensor([0.])

Test 6: Differentiability test
Result: tensor([11.], grad_fn=<WhereBackward0>)
Gradient shape: torch.Size([1, 3, 2])
Gradient (should be non-zero for non-zero vectors):
tensor([[[3., 4.],
         [1., 2.],
         [0., 0.]]])
Zero vector gradient (should be zero): tensor([0., 0.])
Is differentiable: True

Test 7: Mixed scenario
Result: tensor([0.6667])
Expected: ~0.667
Match: True



In [20]:
a = torch.tensor([[[1, 2], [-2, 1], [-2, 1]]]).float()
norma = F.normalize(a, dim=-1)
print(a.shape)
sim = norma @ norma.transpose(-1, -2)
print(sim.shape)
als = sim[:, ~torch.eye(sim.shape[1], dtype=bool)].mean(dim=1)
print(als)


torch.Size([1, 3, 2])
torch.Size([1, 3, 3])
tensor([0.3333])


In [14]:
als

tensor([0.9853])

In [10]:
als.shape

torch.Size([1])

In [2]:
queries_embeddings = torch.tensor([
    [[1.], [0.], [0.], [0.], [1]],
    [[0.], [2.], [0.], [0.], [1]],
    [[0.], [0.], [3.], [0.], [1]],
])

documents_embeddings = torch.tensor([   
    [[50.], [5.], [1.]],
    [[0.], [100.], [10.]],
    [[1.], [0.], [500.]],
])

# documents_mask = torch.tensor([ 
#     [1., 1., 1.],
#     [1., 0., 1.],
#     [1., 1., 1.],
# ])

# query_mask = torch.tensor([
#     [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 0., 1.]
# ])

scores = colbert_scores(
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
    # queries_mask=query_mask,
    # documents_mask=documents_mask,
)

scores

torch.return_types.max(
values=tensor([[  50.,  100.,  500.],
        [ 100.,  200., 1000.],
        [ 150.,  300., 1500.]]),
indices=tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2]]))