In [1]:
import os

from typing import Any, Mapping, List, Tuple, Dict

import numpy as np
import pandas as pd

from loguru import logger
from sklearn.model_selection import train_test_split
from collections import OrderedDict
from pathlib import Path
from tqdm.auto import tqdm

import torch
from torch import nn
from torch.nn.init import constant_, kaiming_normal_
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from catalyst import dl, metrics
from catalyst.engines import Engine, CPUEngine, GPUEngine, DataParallelEngine

In [2]:
def get_available_engine() -> "Engine":
    if not torch.cuda.is_available():
        return CPUEngine()
    return GPUEngine() if torch.cuda.device_count() == 1 else DataParallelEngine()

In [3]:
class MultiVaeModel(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, dropout: float = 0.2) -> None:
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.dropout = dropout

        self.encoder_dims = [self.input_dim, self.hidden_dim, self.latent_dim * 2]
        self.decoder_dims = [self.latent_dim, self.hidden_dim, self.input_dim]

        self.encoder = self._build_layers(self.encoder_dims)
        self.decoder = self._build_layers(self.decoder_dims)

        self.apply(self._init_layer)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        h = nn.functional.dropout(x, self.dropout, training=self.training)
        h = self.encoder(h)

        mu = h[:, :self.latent_dim]
        log_var = h[:, self.latent_dim:]

        z = self._reparameterize(mu, log_var)
        z = self.decoder(z)
        return z, mu, log_var

    def reset(self) -> None:
        self.apply(self._init_layer)

    @staticmethod
    def _build_layers(dims: List[int]) -> nn.Sequential:
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i + 1 < len(dims) - 1:
                layers.append(nn.BatchNorm1d(dims[i + 1]))
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def _init_layer(layer: nn.Module) -> None:
        if isinstance(layer, nn.Linear):
            kaiming_normal_(layer.weight.data)
            if layer.bias is not None:
                constant_(layer.bias.data, 0)

    def _reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

In [4]:
class MultiVaeDataset(Dataset):
    _embedding_ids: List[int]
    _embedding_id_to_embedding: Dict[int, np.ndarray[np.float64]]

    def __init__(self, embedding_ids: List[int], embedding_id_to_embedding: Dict[int, np.ndarray[np.float64]]) -> None:
        super(Dataset).__init__()
        self._embedding_ids = embedding_ids
        self._embedding_id_to_embedding = embedding_id_to_embedding

    def __len__(self) -> int:
        return len(self._embedding_ids)

    def __getitem__(self, index: int) -> Mapping[str, Any]:
        embedding_id = self._embedding_ids[index]
        embedding = self._embedding_id_to_embedding[embedding_id]
        return {"embedding_id": embedding_id, "embedding": embedding}

In [5]:
class MultiVaeRunner(dl.Runner):
    _loader_additive_metrics: Dict[str, metrics.AdditiveMetric]

    def __init__(self):
        super(MultiVaeRunner, self).__init__()

    @property
    def logger(self) -> Any:
        pass

    def on_loader_start(self, runner: dl.Runner) -> None:
        super().on_loader_start(runner)
        self._loader_additive_metrics = {
            metric_name: metrics.AdditiveMetric(compute_on_call=False)
            for metric_name in ["loss_ae", "loss_kld", "loss"]}

    def handle_batch(self, batch: Mapping[str, Any]) -> None:
        x = batch["embedding"]
        z, mu, log_var = self.model(x)

        anneal = min(self.hparams["anneal_cap"], self.batch_step / self.hparams["anneal_total_steps"])

        loss_ae = self._compute_loss_ae(x, z)
        loss_kld = self._compute_loss_kld(mu, log_var)
        loss = loss_ae + anneal * loss_kld

        self.batch_metrics = {"loss_ae": loss_ae, "loss_kld": loss_kld, "loss": loss}
        for metric_name, metric in self.batch_metrics.items():
            self._loader_additive_metrics[metric_name].update(metric.item(), self.batch_size)

    def on_loader_end(self, runner: dl.Runner) -> None:
        for metric_name, metric in self._loader_additive_metrics.items():
            self.loader_metrics[metric_name] = metric.compute()[0]
        super().on_loader_end(runner)

    def predict_batch(self, batch: Mapping[str, Any], **kwargs) -> Mapping[str, Any]:
        x = batch["embedding"]
        z, mu, log_var = self.model(x)
        return {**batch, "output": mu}

    @staticmethod
    def _compute_loss_ae(x: torch.Tensor, z: torch.Tensor) -> float:
        return -torch.mean(torch.sum(nn.functional.log_softmax(z, dim=1) * x, dim=1))

    @staticmethod
    def _compute_loss_kld(mu: torch.Tensor, log_var: torch.Tensor) -> float:
        return -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))

### Config

In [6]:
BASE_DIR = "/Users/artemvopilov/Programming/yandex_cup_2023"

In [37]:
NORMED_EMBEDDINGS_DIR = f"{BASE_DIR}/normed_embeddings"
VAE_EMBEDDINGS_DIR = f"{BASE_DIR}/vae_embeddings"

In [13]:
INPUT_DIM = 768
HIDDEN_DIM = 256
LATENT_DIM = 64

TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 2048
INFERENCE_BATCH_SIZE = 8192

LR_SCHEDULER_STEP = 5

EPOCHS = 10

ANNEAL_CAP = 0.5
ANNEAL_TOTAL_STEPS = 20000

### Read embeddings

In [24]:
track_id_to_embeddings = {}
for fn in tqdm(os.listdir(NORMED_EMBEDDINGS_DIR)):
    fp = f"{NORMED_EMBEDDINGS_DIR}/{fn}"

    track_id = fn.split('.')[0]
    embeddings = np.load(fp)
    track_id_to_embeddings[track_id] = embeddings

  0%|          | 0/76714 [00:00<?, ?it/s]

### Prepare data

In [26]:
embedding_id_to_track_id_pos = {}
embedding_ids = []
embedding_id_to_embedding = {}
for ti, embeds in tqdm(track_id_to_embeddings.items()):
    for ei, embed in enumerate(embeds):
        embedding_id = len(embedding_id_to_track_id_pos)

        embedding_id_to_track_id_pos[embedding_id] = (ti, ei)
        embedding_ids.append(embedding_id)
        embedding_id_to_embedding[embedding_id] = embed

  0%|          | 0/76714 [00:00<?, ?it/s]

In [27]:
len(embedding_id_to_track_id_pos), len(embedding_ids), len(embedding_id_to_embedding)

(4452609, 4452609, 4452609)

### Train

In [12]:
train_embedding_ids, valid_embedding_ids = train_test_split(embedding_ids[:1000000], test_size=0.2)
logger.info(f"Divided df into train {len(train_embedding_ids)} and validation {len(valid_embedding_ids)}")

train_dataset = MultiVaeDataset(train_embedding_ids, embedding_id_to_embedding)
valid_dataset = MultiVaeDataset(valid_embedding_ids, embedding_id_to_embedding)
logger.info("Datasets created")

train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE)
valid_loader = DataLoader(valid_dataset, batch_size=VALID_BATCH_SIZE)
loaders = OrderedDict([('train', train_loader), ('valid', valid_loader)])
logger.info("Loaders created")

model = MultiVaeModel(input_dim=INPUT_DIM, hidden_dim=HIDDEN_DIM, latent_dim=LATENT_DIM, dropout=0.1)
logger.info("Model initialized with config")
optimizer = Adam(params=model.parameters())
logger.info("Optimizer initialized")
lr_scheduler = StepLR(optimizer=optimizer, step_size=LR_SCHEDULER_STEP)
logger.info("Scheduler initialized")

num_epochs = EPOCHS

hparams = {"anneal_cap": ANNEAL_CAP, "anneal_total_steps": ANNEAL_TOTAL_STEPS}

engine = get_available_engine()
logger.info(f"Using engine {engine}")

callbacks = [
    dl.BackwardCallback(metric_key="loss"),
    dl.OptimizerCallback("loss", accumulation_steps=1),
    dl.SchedulerCallback(),
    dl.EarlyStoppingCallback(patience=2, loader_key="valid", metric_key="loss", minimize=True)]
logger.info(f"Callbacks created: {callbacks}")

runner = MultiVaeRunner()
runner.train(
    loaders=loaders,
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    num_epochs=num_epochs,
    hparams=hparams,
    engine=engine,
    verbose=True,
    timeit=False,
    callbacks=callbacks)

[32m2023-11-06 17:31:01.575[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mDivided df into train 800000 and validation 200000[0m
[32m2023-11-06 17:31:01.576[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m6[0m - [1mDatasets created[0m
[32m2023-11-06 17:31:01.579[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m11[0m - [1mLoaders created[0m
[32m2023-11-06 17:31:01.590[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m14[0m - [1mModel initialized with config[0m
[32m2023-11-06 17:31:01.592[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m16[0m - [1mOptimizer initialized[0m
[32m2023-11-06 17:31:01.595[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m18[0m - [1mScheduler initialized[0m
[32m2023-11-06 17:31:01.656[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m25[0m - [1mUsing engine <catalyst.engines.torch.CPUEngine object at 0x13724

1/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (1/10) loss: -50960.788538240995 | loss_ae: -50968.26594083464 | loss_kld: 126.1702626770017 | lr: 0.001 | momentum: 0.9


1/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (1/10) loss: -143223.86907499997 | loss_ae: -143234.138545 | loss_kld: 64.70226213867187 | lr: 0.001 | momentum: 0.9
* Epoch (1/10) lr: 0.001 | momentum: 0.9


2/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (2/10) loss: -288744.03524749953 | loss_ae: -288757.1866500001 | loss_kld: 55.754120119628936 | lr: 0.001 | momentum: 0.9


2/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (2/10) loss: -459650.98819000006 | loss_ae: -459668.17545000016 | loss_kld: 53.73457075073241 | lr: 0.001 | momentum: 0.9
* Epoch (2/10) lr: 0.001 | momentum: 0.9


3/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (3/10) loss: -684926.9470799983 | loss_ae: -684951.6161200005 | loss_kld: 61.1147610839844 | lr: 0.001 | momentum: 0.9


3/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (3/10) loss: -914947.0939399999 | loss_ae: -914989.1469800003 | loss_kld: 87.42091524414064 | lr: 0.001 | momentum: 0.9
* Epoch (3/10) lr: 0.001 | momentum: 0.9


4/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (4/10) loss: -1223736.6033799993 | loss_ae: -1223777.476380002 | loss_kld: 81.86823143066422 | lr: 0.001 | momentum: 0.9


4/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (4/10) loss: -1549430.0152400003 | loss_ae: -1549470.0472000001 | loss_kld: 80.07871560546876 | lr: 0.001 | momentum: 0.9
* Epoch (4/10) lr: 0.001 | momentum: 0.9


5/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (5/10) loss: -1949670.0936400013 | loss_ae: -1949729.2105200004 | loss_kld: 118.23254907714833 | lr: 0.001 | momentum: 0.9


5/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (5/10) loss: -2482695.2803999996 | loss_ae: -2482753.2991199996 | loss_kld: 116.01313360595704 | lr: 0.001 | momentum: 0.9
* Epoch (5/10) lr: 0.0001 | momentum: 0.9


6/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (6/10) loss: -2385286.6455200007 | loss_ae: -2385331.21748 | loss_kld: 89.1446086157226 | lr: 0.0001 | momentum: 0.9


6/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (6/10) loss: -2440712.9044 | loss_ae: -2440750.107440001 | loss_kld: 74.40477052734374 | lr: 0.0001 | momentum: 0.9
* Epoch (6/10) lr: 0.0001 | momentum: 0.9


7/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (7/10) loss: -2465186.4771200023 | loss_ae: -2465221.404239993 | loss_kld: 69.85346609252946 | lr: 0.0001 | momentum: 0.9


7/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (7/10) loss: -2522021.5968 | loss_ae: -2522054.4688 | loss_kld: 65.73312286376955 | lr: 0.0001 | momentum: 0.9
* Epoch (7/10) lr: 0.0001 | momentum: 0.9


8/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (8/10) loss: -2546054.129360002 | loss_ae: -2546085.914800003 | loss_kld: 63.5701711865236 | lr: 0.0001 | momentum: 0.9


8/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (8/10) loss: -2602613.8075999995 | loss_ae: -2602644.051600001 | loss_kld: 60.48229815673829 | lr: 0.0001 | momentum: 0.9
* Epoch (8/10) lr: 0.0001 | momentum: 0.9


9/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (9/10) loss: -2628227.0280000055 | loss_ae: -2628256.609840008 | loss_kld: 59.163841883544855 | lr: 0.0001 | momentum: 0.9


9/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (9/10) loss: -2685532.2866399996 | loss_ae: -2685560.674 | loss_kld: 56.801065078125006 | lr: 0.0001 | momentum: 0.9
* Epoch (9/10) lr: 0.0001 | momentum: 0.9


10/10 * Epoch (train):   0%|          | 0/3125 [00:00<?, ?it/s]

train (10/10) loss: -2711708.5190400044 | loss_ae: -2711736.440799999 | loss_kld: 55.84546118652339 | lr: 0.0001 | momentum: 0.9


10/10 * Epoch (valid):   0%|          | 0/98 [00:00<?, ?it/s]

valid (10/10) loss: -2771299.6408799994 | loss_ae: -2771326.361839999 | loss_kld: 53.44989376220705 | lr: 0.0001 | momentum: 0.9
* Epoch (10/10) lr: 1e-05 | momentum: 0.9


### Inference

In [34]:
model.eval()

inference_dataset = MultiVaeDataset(embedding_ids, embedding_id_to_embedding)
inference_loader = DataLoader(inference_dataset, batch_size=INFERENCE_BATCH_SIZE)

batches_n = np.ceil(len(inference_dataset) / INFERENCE_BATCH_SIZE)
batches_5_perc = np.ceil(batches_n / 20)

logger.info(f"Computing embeddings by {batches_n} batches of size {INFERENCE_BATCH_SIZE}")

track_id_to_vae_embeddings = {ti: [None] * len(embeddings) for ti, embeddings in track_id_to_embeddings.items()}
batch_i = 0
for predictions in runner.predict_loader(loader=inference_loader, model=model, engine=engine):
    batch_embedding_ids = predictions["embedding_id"].detach().cpu().numpy()
    batch_vae_embeddings = predictions["output"].detach().cpu().numpy()

    for ei, vae_embed in zip(batch_embedding_ids, batch_vae_embeddings):
        ti, pos = embedding_id_to_track_id_pos[ei]
        track_id_to_vae_embeddings[ti][pos] = vae_embed

    batch_i += 1
    if batch_i % batches_5_perc == 0:
        logger.info(f'{round(100 * batch_i / batches_n, 2)} % batches processed')
logger.info("Computed embeddings")

[32m2023-11-06 17:51:57.609[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m9[0m - [1mComputing embeddings by 544.0 batches of size 8192[0m
[32m2023-11-06 17:52:03.770[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m5.15 % batches processed[0m
[32m2023-11-06 17:52:09.397[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m10.29 % batches processed[0m
[32m2023-11-06 17:52:14.687[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m15.44 % batches processed[0m
[32m2023-11-06 17:52:21.486[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m20.59 % batches processed[0m
[32m2023-11-06 17:52:26.967[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m25.74 % batches processed[0m
[32m2023-11-06 17:52:32.640[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1m30.88 % batches processed[0m
[32m2023-1

In [35]:
len(track_id_to_vae_embeddings)

76714

### Save

In [40]:
os.mkdir(VAE_EMBEDDINGS_DIR)

In [41]:
for ti, embeddings in tqdm(track_id_to_vae_embeddings.items()):
    fn = f"{ti}.npy"
    fp = f"{VAE_EMBEDDINGS_DIR}/{fn}"
    np.save(fp, embeddings)

  0%|          | 0/76714 [00:00<?, ?it/s]