In [1]:
from collections import defaultdict
from typing import List, Tuple, Dict
import torch

In [2]:
class CollisionSolver:
    def __init__(self, residual_length, semantic_id_length, device: torch.device = torch.device('cpu')):
        """
        :param residual_length: Длина остатка для каждого semantic_id
        :param semantic_id_length: Длина semantic_id (без токена решающего коллизии)
        :param device: Устройство
        """
        self._semantic_id_dict = defaultdict(list)
        self.residual_length = residual_length
        self.semantic_id_length = semantic_id_length
        self.device = device

    def _to_device(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Перенос тензора на устройство
        """
        if tensor.device != self.device:
            tensor = tensor.to(self.device)
        return tensor

    def add_item(self, semantic_id: List[int] | torch.Tensor, residual: torch.Tensor) -> None:
        """
        Добавляет новый элемент в словарь хранящий semantic_ids с остатками

        :param semantic_id: Semantic id (без токена решающего коллизии)
        :param residual: Тензор с остатком для данного semantic_id
        """
        if isinstance(semantic_id, torch.Tensor):
            semantic_id = semantic_id.tolist()

        assert isinstance(residual, torch.Tensor)
        assert residual.shape == (self.residual_length,)
        assert len(semantic_id) == self.semantic_id_length

        residual = self._to_device(residual)
        key = tuple(semantic_id)
        self._semantic_id_dict[key].append((len(self._semantic_id_dict[key]), residual))


    def create_query_candidates_dict(self, semantic_ids: torch.Tensor | List[List[int]], residuals: torch.Tensor | List[List[int]]) -> None:
        """
        Создает словарь, который содержит сгруппирированные по semantic id элементы, к ним добавлены токены решающие коллизии (добавляются по порядку начиная с нуля)

        :param semantic_ids: Тензор или список всех semantic_id, полученных из rq-vae (без токенов решающих коллизии)
        :param residuals: Тензор или список остатков для каждого semantic_id
        """
        residuals_count = residuals.shape[0] if isinstance(residuals, torch.Tensor) else len(residuals)
        semantic_ids_count = semantic_ids.shape[0] if isinstance(semantic_ids, torch.Tensor) else len(semantic_ids)
        assert(residuals_count == semantic_ids_count)

        if isinstance(residuals, list):
            residuals = torch.tensor(residuals, device=self.device)
        residuals = self._to_device(residuals)

        for semantic_id, residual in zip(semantic_ids, residuals):
            self.add_item(semantic_id, residual)

    def get_candidates_tensor(self, query_prefixes: List[List[int]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param query_prefixes: [num_prefixes, prefix_len] список из semantic id (без токенов решающих коллизии)

        :return: Кортеж из двух тензоров:
        - candidates_tensor (размерность: [num_prefixes, max_collisions, residual_dim]): тензор, содержащий остатки кандидатов для каждого префикса
          `max_collisions` — максимальное количество кандидатов для каждого префикса
        - mask (размерность: [num_prefixes, max_collisions]): Маска для candidates_tensor

        Примечание:
            Предполагаем что все префиксы из `query_prefixes` уже есть в словаре semantic ids
            Если префикс не найден, будет выброшено исключение
        """
        assert isinstance(query_prefixes, list)
        assert(self.residual_length == len(self._semantic_id_dict[tuple(query_prefixes[0])][0][1]))
        assert(len(query_prefixes[0]) == self.semantic_id_length)

        max_collision_len = max(len(x) for x in self._semantic_id_dict.values())
        candidates_tensor = torch.zeros(len(query_prefixes), max_collision_len, self.residual_length, dtype=torch.float32, device=self.device)
        mask = torch.zeros(len(query_prefixes), max_collision_len, dtype=torch.bool, device=self.device)

        for i, semantic_id in enumerate(query_prefixes):
            key = tuple(semantic_id)
            assert key in self._semantic_id_dict.keys(), f"Не найдено обьектов с semantic id {key}" # нужно что-то с этим делать
            for j, residual in self._semantic_id_dict[key]: #сохранение порядка
                candidates_tensor[i, j] = residual
                mask[i, j] = True
        return candidates_tensor, mask

    def get_semantic_ids(self, query_prefixes: torch.Tensor, query_residuals: torch.Tensor) -> torch.Tensor:
        """
        :param query_prefixes: [num_prefixes, prefix_len] список из semantic id (без токенов решающих коллизии)

        :return: semantic_ids: [num_prefixes, prefix_len + 1] список из semantic id с токенами решающие коллизии
        """
        assert isinstance(query_prefixes, torch.Tensor)
        assert isinstance(query_residuals, torch.Tensor)
        assert(query_prefixes.shape[0] == query_residuals.shape[0])
        assert(query_prefixes.shape[1] == self.semantic_id_length)
        assert(query_residuals.shape[1] == self.residual_length)

        query_prefixes = self._to_device(query_prefixes)
        query_residuals = self._to_device(query_residuals)

        candidates_tensor, mask = self.get_candidates_tensor(query_prefixes.tolist())

        masked_dot_products = torch.einsum('ijk,ik->ij', candidates_tensor, query_residuals).masked_fill(~mask, float('-inf'))
        max_indices = torch.argmax(masked_dot_products, dim=1)
        best_semantic_ids = torch.concat((query_prefixes, max_indices.unsqueeze(1)), dim=1)
        return best_semantic_ids

# Пример использования

In [7]:
# Demo configuration: residual vector size and semantic-id length (without the collision token).
residual_length = 12
semantic_ids_length = 3
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Known items: each row is a 3-token semantic id followed by its collision token.
# Rows are generated per prefix group, in the order the tokens are listed.
semantic_ids = torch.tensor(
    [
        [*prefix, token]
        for prefix, tokens in (
            ((1, 2, 3), range(4)),
            ((1, 2, 4), range(3)),
            ((5, 2, 3), range(7)),
            ((2, 8, 7), (6,)),
        )
        for token in tokens
    ],
    device=torch.device('cpu'),
)

# One random residual per known item.
residuals = torch.rand(len(semantic_ids), residual_length)

# Queries: semantic ids without the collision token, shape [num_prefixes, prefix_len].
query_prefixes = torch.tensor(
    [[1, 2, 3], [1, 2, 4], [5, 2, 3], [5, 2, 3]],
    device=device,
)

# One random residual per query, shape [num_prefixes, emb_dim].
query_residuals = torch.rand(query_prefixes.size(0), residual_length, device=torch.device('cpu'))

In [8]:
# Build a solver on the default (CPU) device.
solver = CollisionSolver(
    residual_length=residual_length,
    semantic_id_length=semantic_ids_length,
)

# Index the known items by their semantic id, dropping the last column (the collision token).
solver.create_query_candidates_dict(semantic_ids[..., :-1], residuals)

# Resolve the collision token for every query; the result is the cell output.
solver.get_semantic_ids(query_prefixes, query_residuals)

tensor([[1, 2, 3, 0],
        [1, 2, 4, 0],
        [5, 2, 3, 0],
        [5, 2, 3, 3]])

In [9]:
# Show the solver's internal store: semantic-id tuple -> list of (collision token, residual).
solver._semantic_id_dict

defaultdict(list,
            {(1,
              2,
              3): [(0,
               tensor([0.7220, 0.9496, 0.4006, 0.8832, 0.6087, 0.4947, 0.1341, 0.9645, 0.7408,
                       0.5972, 0.3433, 0.8700])), (1,
               tensor([0.4457, 0.4410, 0.1333, 0.4391, 0.4153, 0.1703, 0.3044, 0.0940, 0.2773,
                       0.5258, 0.5838, 0.0273])), (2,
               tensor([0.6268, 0.1060, 0.0841, 0.0750, 0.4090, 0.2886, 0.4343, 0.1945, 0.0429,
                       0.8477, 0.1418, 0.6465])), (3,
               tensor([0.0077, 0.8171, 0.1344, 0.2223, 0.9616, 0.2790, 0.3448, 0.1485, 0.7148,
                       0.5900, 0.0154, 0.4752]))],
             (1,
              2,
              4): [(0,
               tensor([0.3480, 0.3537, 0.3771, 0.1443, 0.6877, 0.4845, 0.8278, 0.9831, 0.4941,
                       0.0682, 0.0900, 0.2485])), (1,
               tensor([0.4048, 0.3308, 0.2278, 0.4890, 0.4899, 0.9994, 0.1511, 0.9374, 0.8730,
                       0.7538, 

# Альтернативное решение только через torch

In [10]:
# Put every input on the same device before comparing / multiplying them.
semantic_ids = semantic_ids.to(device)
residuals = residuals.to(device)
query_prefixes = query_prefixes.to(device)
query_residuals = query_residuals.to(device)

batch_size, max_length = semantic_ids.shape
num_prefixes, prefix_len = query_prefixes.shape

# Step 1: broadcast both sides to [num_prefixes, batch_size, prefix_len] to find prefix matches.
semantic_ids_exp = semantic_ids[:, :prefix_len].unsqueeze(0).expand(num_prefixes, batch_size, prefix_len)
prefixes_exp = query_prefixes.unsqueeze(1).expand(num_prefixes, batch_size, prefix_len)
is_prefix_match = (semantic_ids_exp == prefixes_exp).all(dim=2)  # [num_prefixes, batch_size]

# Without a match argmax would silently fall back to row 0, so fail loudly instead.
assert is_prefix_match.any(dim=1).all().item(), "some query prefixes have no matching semantic id"

# Step 2: score every (query, item) pair, then exclude non-matching items with -inf.
# Zeroing non-matching residuals instead would give them a score of 0, which wins
# whenever all matching candidates have a negative dot product.
dot_products = query_residuals @ residuals.T  # [num_prefixes, batch_size]
dot_products = dot_products.masked_fill(~is_prefix_match, float('-inf'))
max_indices = torch.argmax(dot_products, dim=1)  # [num_prefixes]

best_semantic_ids = semantic_ids[max_indices]  # [num_prefixes, max_length]
best_residuals = residuals[max_indices]  # [num_prefixes, emb_dim]

for i, prefix in enumerate(query_prefixes):
    print(f"Префикс: {prefix.tolist()}")
    print(f"Лучший semantic_id: {best_semantic_ids[i].tolist()}")
    print(f"Соответствующий residual: {best_residuals[i]}")

Префикс: [1, 2, 3]
Лучший semantic_id: [1, 2, 3, 0]
Соответствующий residual: tensor([0.7220, 0.9496, 0.4006, 0.8832, 0.6087, 0.4947, 0.1341, 0.9645, 0.7408,
        0.5972, 0.3433, 0.8700])
Префикс: [1, 2, 4]
Лучший semantic_id: [1, 2, 4, 0]
Соответствующий residual: tensor([0.3480, 0.3537, 0.3771, 0.1443, 0.6877, 0.4845, 0.8278, 0.9831, 0.4941,
        0.0682, 0.0900, 0.2485])
Префикс: [5, 2, 3]
Лучший semantic_id: [5, 2, 3, 0]
Соответствующий residual: tensor([0.9524, 0.6976, 0.7598, 0.9994, 0.2881, 0.9854, 0.2537, 0.6400, 0.5632,
        0.5768, 0.2833, 0.0570])
Префикс: [5, 2, 3]
Лучший semantic_id: [5, 2, 3, 3]
Соответствующий residual: tensor([0.5732, 0.6359, 0.1402, 0.0661, 0.6557, 0.5067, 0.7383, 0.7173, 0.3075,
        0.3920, 0.7497, 0.9602])
