In [12]:
# figure out what's going on with ColBERT scores

from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional as F
from pylate.utils.tensor import convert_to_tensor


def colbert_scores(
    queries_embeddings: list | np.ndarray | torch.Tensor,
    documents_embeddings: list | np.ndarray | torch.Tensor,
    queries_mask: torch.Tensor = None,
    documents_mask: torch.Tensor = None,
) -> torch.Tensor:

    queries_embeddings = convert_to_tensor(queries_embeddings)
    documents_embeddings = convert_to_tensor(documents_embeddings)

    scores = torch.einsum(
        "ash,bth->abst",
        queries_embeddings,
        documents_embeddings,
    )

    if queries_mask is not None:
        queries_mask = convert_to_tensor(queries_mask)
        scores = scores * queries_mask.unsqueeze(1).unsqueeze(3)

    if documents_mask is not None:
        documents_mask = convert_to_tensor(documents_mask)
        scores = scores * documents_mask.unsqueeze(0).unsqueeze(2)
    scores = scores.max(axis=-1).values.max(axis=-1)
    return scores

In [20]:
a = torch.tensor([[[1, 2], [-2, 1], [-2, 1]]]).float()
norma = F.normalize(a, dim=-1)
print(a.shape)
sim = norma @ norma.transpose(-1, -2)
print(sim.shape)
als = sim[:, ~torch.eye(sim.shape[1], dtype=bool)].mean(dim=1)
print(als)


torch.Size([1, 3, 2])
torch.Size([1, 3, 3])
tensor([0.3333])


In [14]:
als

tensor([0.9853])

In [10]:
als.shape

torch.Size([1])

In [2]:
queries_embeddings = torch.tensor([
    [[1.], [0.], [0.], [0.], [1]],
    [[0.], [2.], [0.], [0.], [1]],
    [[0.], [0.], [3.], [0.], [1]],
])

documents_embeddings = torch.tensor([   
    [[50.], [5.], [1.]],
    [[0.], [100.], [10.]],
    [[1.], [0.], [500.]],
])

# documents_mask = torch.tensor([ 
#     [1., 1., 1.],
#     [1., 0., 1.],
#     [1., 1., 1.],
# ])

# query_mask = torch.tensor([
#     [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 0., 1.]
# ])

scores = colbert_scores(
    queries_embeddings=queries_embeddings,
    documents_embeddings=documents_embeddings,
    # queries_mask=query_mask,
    # documents_mask=documents_mask,
)

scores

torch.return_types.max(
values=tensor([[  50.,  100.,  500.],
        [ 100.,  200., 1000.],
        [ 150.,  300., 1500.]]),
indices=tensor([[0, 0, 0],
        [1, 1, 1],
        [2, 2, 2]]))