In [None]:
from collections import defaultdict

import numpy as np
import json
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [None]:
# Use the first CUDA device when one is visible to PyTorch, otherwise stay on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Input: pickled embeddings read by the training cell.
embeddings_input_path = '../data/Beauty/content_embeddings.pkl'
# Output: JSON file the semantic-id index is dumped to.
semantic_index_output_path = '../data/Beauty/index_rqvae.json'

In [None]:
class EmbeddingsDatasets:
    """Index-addressable view over a column-oriented mapping.

    `data` must expose an 'item_id' and an 'embedding' entry, each of which
    supports integer indexing; 'item_id' must also support `len()`.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        """Number of samples, taken from the length of the 'item_id' column."""
        item_ids = self.data['item_id']
        return len(item_ids)

    def __getitem__(self, idx):
        """Return the `idx`-th sample as a dict with 'item_id' and 'embedding' keys."""
        columns = self.data
        sample = {}
        for field in ('item_id', 'embedding'):
            sample[field] = columns[field][idx]
        return sample

In [None]:
def batch_processor(samples):
    """Collate a list of samples into a batch of tensors.

    Args:
        samples: sequence of dicts, each with an 'item_id' and an 'embedding'
            entry (the shape produced by `EmbeddingsDatasets.__getitem__`).

    Returns:
        dict with
          - 'item_id': int64 tensor built from the sample ids (left on the CPU),
          - 'embedding': float32 tensor of the stacked embeddings, moved to DEVICE.
    """
    item_ids = torch.LongTensor([sample['item_id'] for sample in samples])

    # Stack into a single contiguous float32 array before handing the data to
    # torch: building a tensor straight from a Python list of per-row arrays
    # goes element by element and is very slow for wide embeddings.
    stacked = np.asarray([sample['embedding'] for sample in samples], dtype=np.float32)
    embeddings = torch.from_numpy(stacked).to(DEVICE)

    return {
        'item_id': item_ids,
        'embedding': embeddings
    }


In [None]:
class RQVAE(nn.Module):
    """Autoencoder whose latent vector is quantized by a cascade of codebooks.

    The encoder maps the input embedding to a `hidden_dim` latent. Each
    codebook in turn picks the centroid nearest to the current residual, and
    the residual left after subtracting that centroid is passed to the next
    codebook. The decoder reconstructs the input from the sum of the selected
    centroids.

    Args:
        input_dim: size of the input embeddings.
        hidden_dim: size of the latent space and of every centroid.
        beta: weight of the term that pulls the residual towards its centroid.
        codebook_sizes: iterable with the number of centroids of each codebook.
    """

    def __init__(
            self,
            input_dim,
            hidden_dim,
            beta,
            codebook_sizes
    ):
        super().__init__()
        self.register_buffer('beta', torch.tensor(beta))
        self.mse_loss = torch.nn.MSELoss()

        self.encoder = self.make_encoding_tower(input_dim, hidden_dim)
        self.decoder = self.make_encoding_tower(hidden_dim, input_dim)

        self.codebook_sizes = codebook_sizes
        self.codebooks = torch.nn.ParameterList()
        for codebook_size in codebook_sizes:
            # Register an explicit Parameter instead of relying on ParameterList
            # to promote a raw tensor, and allocate with torch.empty rather
            # than the legacy FloatTensor constructor.
            cb = torch.nn.Parameter(torch.empty(codebook_size, hidden_dim))
            # trunc_normal_ fills in place without recording autograd history.
            torch.nn.init.trunc_normal_(cb, std=0.02, a=-2 * 0.02, b=2 * 0.02)
            self.codebooks.append(cb)

    @staticmethod
    def make_encoding_tower(d1, d2, bias=False):
        """Build a three-layer MLP d1 -> d1 -> d2 -> d2 with ReLU in between.

        `bias` only controls the bias of the last linear layer.
        """
        return torch.nn.Sequential(
            torch.nn.Linear(d1, d1),
            torch.nn.ReLU(),
            torch.nn.Linear(d1, d2),
            torch.nn.ReLU(),
            torch.nn.Linear(d2, d2, bias=bias)
        )

    @staticmethod
    def get_codebook_indices(remainder, codebook):
        """Return, for every row of `remainder`, the index of the closest row of `codebook`."""
        dist = torch.cdist(remainder, codebook)
        return dist.argmin(dim=-1)

    def forward(self, inputs):
        """Quantize and reconstruct `inputs['embedding']`.

        Returns a dict with:
          - 'loss': scalar training loss (reconstruction + quantization terms),
          - 'clusters_counts': per codebook, how many rows chose each centroid,
          - 'clusters': integer tensor, one row per input, one column per codebook,
          - 'embedding_hat': decoder output,
          - 'metrics': 'loss' as a detached tensor, 'recon_loss' and
            'rqvae_loss' as Python floats.
        """
        latent_vector = self.encoder(inputs['embedding'])

        latent_restored = 0
        rqvae_loss = 0
        clusters = []
        remainder = latent_vector
        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            clusters.append(codebook_indices)
            quantized = codebook[codebook_indices]
            # Straight-through estimator: the value is `quantized`, the
            # gradient flows to `remainder` as if quantization were identity.
            codebook_vectors = remainder + (quantized - remainder).detach()

            # Pull the residual towards its centroid (weighted by beta) ...
            rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
            # ... and the centroid towards the residual.
            rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())

            latent_restored += codebook_vectors
            # NOTE: in value this is `remainder - quantized`. Because
            # `codebook_vectors` is the straight-through form, the two
            # `remainder` terms cancel in autograd, so the residual handed to
            # the next level carries no gradient back to the encoder.
            remainder = remainder - codebook_vectors

        embeddings_restored = self.decoder(latent_restored)
        recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])
        loss = (recon_loss + rqvae_loss).mean()

        clusters_counts = []
        for codebook_size, cluster in zip(self.codebook_sizes, clusters):
            clusters_counts.append(torch.bincount(cluster, minlength=codebook_size))

        return {
            'loss': loss,
            'clusters_counts': clusters_counts,
            'clusters': torch.stack(clusters).T,
            'embedding_hat': embeddings_restored,
            'metrics': dict(
                loss=loss.detach(),
                recon_loss=recon_loss.mean().item(),
                rqvae_loss=rqvae_loss.mean().item()
            ),
        }


In [None]:
class InitCodebooks:
    """Seed every codebook with encoded data rows.

    Codebook `i` is overwritten with the residuals of randomly chosen rows:
    the rows are encoded, then the nearest centroid of each already
    initialized codebook `0..i-1` is subtracted.

    Args:
        model: object exposing `encoder`, `codebooks` and
            `get_codebook_indices` (an `RQVAE`).
        dataset: iterable of batches; each batch is a mapping whose
            'embedding' entry is a 2-D tensor of input rows.
    """

    def __init__(self, model: "RQVAE", dataset):
        self.model = model
        # One shared iterator: successive codebooks consume successive batches.
        self.dataset_iter = iter(map(lambda x: x['embedding'], dataset))

    def _draw_rows(self, num_rows):
        """Pull batches from the shared iterator until at least `num_rows` rows are collected."""
        chunks = []
        collected = 0
        # Always consume at least one batch, then keep going while short.
        while not chunks or collected < num_rows:
            try:
                chunk = next(self.dataset_iter)
            except StopIteration:
                raise ValueError(
                    f'Not enough data to initialize a codebook with {num_rows} centroids: '
                    f'only {collected} rows were left in the dataset'
                ) from None
            chunks.append(chunk)
            collected += chunk.shape[0]
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)

    @torch.no_grad()
    def __call__(self):
        """Overwrite every codebook of the model in place.

        Raises:
            ValueError: if the dataset runs out before a codebook can be filled.
        """
        for i in range(len(self.model.codebooks)):
            codebook = self.model.codebooks[i]
            codebook_size = codebook.shape[0]

            # A single batch may hold fewer rows than the codebook has
            # centroids; gather as many batches as needed so the codebook
            # keeps its size instead of silently shrinking.
            X = self._draw_rows(codebook_size)
            idx = torch.randperm(X.shape[0], device=X.device)[:codebook_size]
            remainder = self.model.encoder(X[idx])
            for j in range(i):
                codebook_indices = self.model.get_codebook_indices(remainder, self.model.codebooks[j])
                codebook_vectors = self.model.codebooks[j][codebook_indices]
                remainder = remainder - codebook_vectors

            # In-place copy keeps the Parameter object (and anything holding a
            # reference to it, e.g. an optimizer) intact.
            codebook.copy_(remainder)


class FixDeadCentroids:
    """Re-seed centroids that no data row is assigned to.

    For each codebook, every stored batch is encoded and assigned to its
    nearest centroid at that level. Centroids that receive no row ("dead")
    are overwritten with residuals of rows from one randomly chosen batch.

    Args:
        model: object exposing `encoder`, `codebooks` and
            `get_codebook_indices` (an `RQVAE`).
        dataset: iterable of batches; each batch is a mapping whose
            'embedding' entry is a 2-D tensor of input rows. All batches are
            materialized once and kept for later calls.
    """

    def __init__(self, model: "RQVAE", dataset):
        self.model = model
        self.dataset = list(map(lambda x: x['embedding'], dataset))

    def _residual_at_level(self, batch, level):
        """Encode `batch` and subtract the nearest centroid of every codebook below `level`."""
        remainder = self.model.encoder(batch)
        for l in range(level):
            lower_codebook = self.model.codebooks[l]
            ind = self.model.get_codebook_indices(remainder, lower_codebook)
            remainder = remainder - lower_codebook[ind]
        return remainder

    @torch.no_grad()
    def fix_dead_codebooks(self):
        """Revive dead centroids in place.

        Returns:
            list with, per codebook, the number of dead centroids that were
            found. If the sampled batch has fewer rows than there are dead
            centroids, only as many centroids as there are rows are re-seeded.
        """
        num_fixed = []
        for codebook_idx, codebook in enumerate(self.model.codebooks):
            # Count on the codebook's own device so the counter always matches
            # the indices produced from it, wherever the model lives.
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            random_batch_idx = torch.randint(len(self.dataset), (1,)).item()

            for batch in self.dataset:
                remainder = self._residual_at_level(batch, codebook_idx)
                indices = self.model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            random_batch = self.dataset[random_batch_idx]
            remainder = self._residual_at_level(random_batch, codebook_idx)
            shuffled = torch.randperm(remainder.shape[0], device=codebook.device)
            replacements = remainder[shuffled][:num_dead]

            # Pair replacements with dead slots one to one. When fewer rows
            # than dead centroids are available, only the first dead slots
            # are re-seeded instead of failing on a shape mismatch.
            dead_indices = dead_mask.nonzero(as_tuple=True)[0][:replacements.shape[0]]
            codebook[dead_indices] = replacements

        return num_fixed

    def __call__(self):
        """Run the fix and return the dead-centroid counts keyed as 'num_dead/<codebook index>'."""
        return {
            f'num_dead/{i}': num_fixed
            for i, num_fixed in enumerate(self.fix_dead_codebooks())
        }

In [None]:
# Numerical-precision switches: 'high' float32 matmul precision, and TF32
# explicitly allowed for both CUDA matmuls and cuDNN.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# Optimization schedule.
num_epochs, batch_size, lr = 10, 256, 3e-4

# Model geometry: input embedding width, latent width, and the codebook cascade.
input_dim, hidden_dim = 4096, 32
num_codebooks, codebook_size = 3, 256

# Weight of the loss term that pulls the residual towards its selected centroid.
beta = 0.25


if __name__ == '__main__':
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only point this at a file from a trusted source.
    with open(embeddings_input_path, 'rb') as f:
        data = pickle.load(f)

    dataset = EmbeddingsDatasets(data)
    # Shuffled, full batches only: used for the gradient steps.
    train_dataloader = DataLoader(
        dataset=dataset, 
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=batch_processor
    )
    # Deterministic order, keeps the last partial batch: used to seed the
    # codebooks and, after training, to assign semantic ids to every item.
    valid_dataloader = DataLoader(
        dataset=dataset, 
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=batch_processor
    )
    # The whole dataset as one batch: scanned by the dead-centroid fixer.
    index_dataloader = DataLoader(
        dataset=dataset, 
        batch_size=len(dataset),
        shuffle=False,
        drop_last=False,
        collate_fn=batch_processor
    )

    NUM_STEPS = len(train_dataloader)
    LOG_EVERY_NUM_STEPS = max(NUM_STEPS // 10, 1)
    print(NUM_STEPS, LOG_EVERY_NUM_STEPS)

    model = RQVAE(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        beta=beta,
        codebook_sizes=[codebook_size] * num_codebooks
    ).to(DEVICE)
    
    # Create callbacks
    init_codebooks = InitCodebooks(model, valid_dataloader)
    fix_dead_centroids = FixDeadCentroids(model, index_dataloader)

    # Create optimizer. The fused kernel is only requested when the model is
    # on CUDA: DEVICE may fall back to the CPU, where PyTorch builds without a
    # CPU fused Adam would reject fused=True.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, fused=(DEVICE.type == 'cuda'))

    # Initialize codebooks from randomly sampled encoded rows (residuals for deeper levels)
    init_codebooks()

    current_step = 0
    for epoch in range(num_epochs):
        print(f'Start epoch #{epoch}')
        for batch in train_dataloader:
            model.train()
            metrics = {}
            
            model_output = model(batch)
            metrics.update(model_output['metrics'])

            loss = model_output['loss']

            optimizer.zero_grad()
            loss.backward()

            # Each sub-module's gradient norm is clipped to 1.0 independently.
            nn.utils.clip_grad_norm_(model.encoder.parameters(), 1.0)
            nn.utils.clip_grad_norm_(model.codebooks.parameters(), 1.0)
            nn.utils.clip_grad_norm_(model.decoder.parameters(), 1.0)

            optimizer.step()

            # Re-seed unused centroids after every update and log how many were dead.
            model.eval()
            dead_centroids_nums = fix_dead_centroids()
            metrics.update(dead_centroids_nums)

            current_step += 1

            if current_step % LOG_EVERY_NUM_STEPS == 0:
                print(metrics)



In [None]:
# Create semantics mapping: item id -> list of codebook indices (one per codebook),
# plus the reverse grouping of item ids by identical code tuple.
inter = {}
sem_2_ids = defaultdict(list)
model.eval()
with torch.inference_mode():
    for batch in valid_dataloader:
        model_output = model(batch)
        for item_id, semantic_ids in zip(batch['item_id'].tolist(), model_output['clusters'].cpu().tolist()):
            inter[item_id] = semantic_ids
            sem_2_ids[tuple(semantic_ids)].append(item_id)

# Resolve collisions: items that share a code tuple each get one extra,
# distinct trailing id drawn from [0, collision_vocab_size).
collision_vocab_size = 256
# Dedicated seeded generator so the saved index is reproducible across runs.
collision_rng = np.random.default_rng(42)
for semantics, item_ids in sem_2_ids.items():
    if len(item_ids) > collision_vocab_size:
        raise ValueError(
            f'{len(item_ids)} items share semantic ids {semantics}, but only '
            f'{collision_vocab_size} distinct collision ids are available'
        )
    collision_solvers = collision_rng.permutation(collision_vocab_size)[:len(item_ids)].tolist()
    for item_id, collision_solver in zip(item_ids, collision_solvers):
        inter[item_id].append(collision_solver)

# Save semantics (JSON serializes the integer item ids as string keys)
with open(semantic_index_output_path, 'w') as f:
    json.dump(inter, f)