In [14]:
import os

import time
from typing import Dict, TypeVar, Set, Optional, Any

import numpy as np
import pandas as pd
from implicit.als import AlternatingLeastSquares
from loguru import logger
from scipy.sparse import csr_matrix
from tqdm.auto import tqdm

In [15]:
# Project root directory. Set the YANDEX_CUP_BASE_DIR environment variable to run the
# notebook on another machine; the original local checkout remains the default so
# existing setups keep working unchanged.
BASE_DIR = os.environ.get("YANDEX_CUP_BASE_DIR", "/Users/artemvopilov/Programming/yandex_cup_2023")

In [16]:
# Interaction tables (train / test CSV files) live under <BASE_DIR>/data.
DATA_DIR = BASE_DIR + "/data"

TRAIN_DF_PATH = DATA_DIR + "/train.csv"
TEST_DF_PATH = DATA_DIR + "/test.csv"

# One directory per embedding variant, all located directly under BASE_DIR.
NORMED_EMBEDDINGS_DIR = BASE_DIR + "/normed_embeddings"
PCA_EMBEDDINGS_DIR = BASE_DIR + "/pca_embeddings"
VAE_EMBEDDINGS_DIR = BASE_DIR + "/vae_embeddings"
NORMED_LSTM_EMBEDDINGS_DIR = BASE_DIR + "/normed_lstm_embeddings"

### Read data

In [17]:
# Read the train and test tables (in that order) from the configured CSV paths.
train_df, test_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_DF_PATH, TEST_DF_PATH))

In [None]:
# Load every per-track embedding file from VAE_EMBEDDINGS_DIR into memory.
# Keys are the part of the file name before the first dot; values are float32 arrays.
track_id_to_embeddings = {}
for fn in tqdm(sorted(os.listdir(VAE_EMBEDDINGS_DIR)), desc="Loading VAE embeddings"):
    # Hidden entries (e.g. macOS `.DS_Store`) are not embeddings: they would yield an
    # empty track id and make `np.load` fail, so skip them.
    if fn.startswith('.'):
        continue
    fp = f"{VAE_EMBEDDINGS_DIR}/{fn}"

    track_id = fn.split('.')[0]
    embeddings = np.load(fp).astype(np.float32)
    track_id_to_embeddings[track_id] = embeddings

  0%|          | 0/76714 [00:00<?, ?it/s]

### ALS trainer

In [8]:
# Generic id type used by the indexing helpers (buyer ids and item ids).
# Defined next to the class so that this cell runs on a fresh kernel.
T_id_type = TypeVar("T_id_type")


class AlsTrainer:
    """Wrapper around an ALS model that maintains id -> matrix-index mappings.

    The training mode is selected by which initial embeddings are supplied:

    * ``init_item_embeds`` given -> "user step": item factors are taken from the
      provided embeddings and only user factors are fitted.
    * ``init_user_embeds`` given (and no item embeddings) -> "item step": user
      factors are taken from the provided embeddings and only item factors are fitted.
    * neither given -> full ALS training on the interactions.

    Datasets passed to the methods are data frames with ``buyeruid``, ``item_id``
    and ``target`` columns.
    """

    # Keyword arguments forwarded to the ALS model constructor.
    _params: Dict[str, Any]
    # Optional pre-computed embeddings keyed by buyer id / item id.
    _init_user_embeddings: Optional[Dict]
    _init_item_embeddings: Optional[Dict]
    # String forward reference; the model class is only needed when `_reset` builds it.
    _als: "AlternatingLeastSquares"
    # Row / column positions of every known buyer and item in the factor matrices.
    _buyeruid_to_index: Dict[int, int]
    _item_id_to_index: Dict[str, int]

    def __init__(self, params: Dict[str, Any], init_user_embeds: Optional[Dict], init_item_embeds: Optional[Dict]) -> None:
        """Store the configuration and build a fresh model via :meth:`reset`.

        Args:
            params: keyword arguments for the ALS model constructor.
            init_user_embeds: optional mapping buyer id -> embedding.
            init_item_embeds: optional mapping item id -> embedding.
        """
        self._params = params
        self._init_user_embeddings = init_user_embeds
        self._init_item_embeddings = init_item_embeds
        self.reset()

    @property
    def params(self) -> Dict[str, Any]:
        """Keyword arguments used to construct the ALS model."""
        return self._params

    @property
    def init_user_embeddings(self) -> Optional[Dict]:
        """Initial user embeddings passed to the constructor (may be ``None``)."""
        return self._init_user_embeddings

    @property
    def init_item_embeddings(self) -> Optional[Dict]:
        """Initial item embeddings passed to the constructor (may be ``None``)."""
        return self._init_item_embeddings

    def reset(self) -> None:
        """Recreate the model, clear the index mappings and re-apply initial embeddings.

        When both initial embedding dictionaries are present, the user-step branch
        wins and only the item embeddings are installed.
        """
        self._reset()
        if self.is_user_step():
            self._set_item_embeddings()
        elif self.is_item_step():
            self._set_user_embeddings()

    def train(self, dataset: pd.DataFrame) -> None:
        """Dispatch to user step, item step or full training depending on the mode."""
        if self.is_user_step():
            self.user_step(dataset)
        elif self.is_item_step():
            self.item_step(dataset)
        else:
            self.full_train(dataset)

    def full_train(self, dataset: pd.DataFrame) -> None:
        """Index all buyers and items of ``dataset`` and fit the model on their interactions."""
        start_time = time.time()
        train_buyeruid_to_index = self._index(set(dataset['buyeruid']), self._buyeruid_to_index)
        train_item_id_to_index = self._index(set(dataset['item_id']), self._item_id_to_index)
        self._buyeruid_to_index = {**self._buyeruid_to_index, **train_buyeruid_to_index}
        self._item_id_to_index = {**self._item_id_to_index, **train_item_id_to_index}
        user_item_matrix = self._build_interactions_matrix(dataset, self._buyeruid_to_index, self._item_id_to_index)
        self._als.fit(user_item_matrix)
        logger.info(f'Training took {(time.time() - start_time) / 60} minutes, '
                    f'users: {len(self.get_buyeruids())}, items: {len(self.get_item_ids())}')

    def user_step(self, dataset: pd.DataFrame) -> None:
        """Fit (or refit) the factors of the buyers in ``dataset`` with item factors fixed."""
        start_time = time.time()
        step_buyeruid_to_index = self._index(set(dataset['buyeruid']), self._buyeruid_to_index)
        self._buyeruid_to_index = {**self._buyeruid_to_index, **step_buyeruid_to_index}

        # The model needs an allocated user-factor matrix before a partial fit.
        if self._als.user_factors is None:
            self._als.user_factors = np.zeros((len(step_buyeruid_to_index), self._als.factors), dtype=self._als.dtype)
        step_user_item_matrix = self._build_interactions_matrix(dataset, self._buyeruid_to_index, self._item_id_to_index)
        step_user_indices = list(step_buyeruid_to_index.values())
        self._als.partial_fit_users(step_user_indices, step_user_item_matrix[step_user_indices])
        # User factors changed: invalidate the model's cached `_XtX`
        # (presumably the user-factor Gram matrix - confirm against the library).
        self._als._XtX = None
        logger.info(f'User step took {(time.time() - start_time) / 60} minutes, '
                    f'users: {len(self.get_buyeruids())}, items: {len(self.get_item_ids())}')

    def item_step(self, dataset: pd.DataFrame) -> None:
        """Fit (or refit) the factors of the items in ``dataset`` with user factors fixed.

        Mirror image of :meth:`user_step`: the interactions are transposed to an
        item x user matrix and handed to the model's ``partial_fit_items``.
        """
        start_time = time.time()
        step_item_id_to_index = self._index(set(dataset['item_id']), self._item_id_to_index)
        self._item_id_to_index = {**self._item_id_to_index, **step_item_id_to_index}

        # The model needs an allocated item-factor matrix before a partial fit.
        if self._als.item_factors is None:
            self._als.item_factors = np.zeros((len(step_item_id_to_index), self._als.factors), dtype=self._als.dtype)
        step_user_item_matrix = self._build_interactions_matrix(dataset, self._buyeruid_to_index, self._item_id_to_index)
        # Row-select on CSR: one row per item, columns are user indices.
        step_item_user_matrix = step_user_item_matrix.T.tocsr()
        step_item_indices = list(step_item_id_to_index.values())
        self._als.partial_fit_items(step_item_indices, step_item_user_matrix[step_item_indices])
        # Item factors changed: invalidate the model's cached `_YtY`
        # (presumably the item-factor Gram matrix - confirm against the library).
        self._als._YtY = None
        logger.info(f'Item step took {(time.time() - start_time) / 60} minutes, '
                    f'users: {len(self.get_buyeruids())}, items: {len(self.get_item_ids())}')

    def user_inference(self, dataset: pd.DataFrame) -> Dict[int, list]:
        """Compute embeddings for the buyers in ``dataset`` without storing them.

        Buyers are indexed locally (independently of the trained buyer index); the
        item index of the trained model is reused.

        Returns:
            Mapping buyer id -> embedding as a plain Python list of floats.
        """
        start_time = time.time()
        buyeruid_to_index = self._index(set(dataset['buyeruid']))
        user_item_matrix = self._build_interactions_matrix(dataset, buyeruid_to_index, self._item_id_to_index)
        buyeruids = list(buyeruid_to_index.keys())
        user_indices = [buyeruid_to_index[buyeruid] for buyeruid in buyeruids]
        user_embeddings = self._als.recalculate_user(user_indices, user_item_matrix[user_indices])
        buyeruid_to_embedding = dict(zip(buyeruids, user_embeddings.tolist()))
        logger.info(f'Inference took {(time.time() - start_time) / 60} minutes, '
                    f'users: {len(buyeruid_to_embedding)}')
        return buyeruid_to_embedding

    def is_user_step(self) -> bool:
        """True when initial item embeddings were supplied (only users are fitted)."""
        return self._init_item_embeddings is not None

    def is_item_step(self) -> bool:
        """True when initial user embeddings were supplied."""
        return self._init_user_embeddings is not None

    def get_buyeruids(self) -> Set[int]:
        """Return the ids of all indexed buyers."""
        return set(self._buyeruid_to_index.keys())

    def get_item_ids(self) -> Set[str]:
        """Return the ids of all indexed items."""
        return set(self._item_id_to_index.keys())

    def get_user_embeddings(self, buyeruids: Optional[Set[int]] = None) -> Dict[int, np.ndarray]:
        """Return embeddings of the requested known buyers (all known buyers if ``None``)."""
        buyeruids = buyeruids & self.get_buyeruids() if buyeruids is not None else self.get_buyeruids()
        return {buyeruid: self.get_user_embedding(buyeruid) for buyeruid in buyeruids}

    def get_user_embedding(self, buyeruid: int) -> np.ndarray:
        """Return the factor row of one buyer; raises ``KeyError`` for an unknown id."""
        return self._als.user_factors[self._buyeruid_to_index[buyeruid]]

    def get_item_embeddings(self, item_ids: Optional[Set[str]] = None) -> Dict[str, np.ndarray]:
        """Return embeddings of the requested known items (all known items if ``None``)."""
        item_ids = item_ids & self.get_item_ids() if item_ids is not None else self.get_item_ids()
        return {item_id: self.get_item_embedding(item_id) for item_id in item_ids}

    def get_item_embedding(self, item_id: str) -> np.ndarray:
        """Return the factor row of one item; raises ``KeyError`` for an unknown id."""
        return self._als.item_factors[self._item_id_to_index[item_id]]

    def _reset(self) -> None:
        """Build a new model from the stored params and drop both index mappings."""
        self._als = AlternatingLeastSquares(**self._params)
        self._buyeruid_to_index = {}
        self._item_id_to_index = {}
        logger.info('Als reset')

    def _set_user_embeddings(self) -> None:
        """Index the buyers of the initial user embeddings and install them as user factors."""
        self._buyeruid_to_index = self._index(set(self._init_user_embeddings.keys()))
        user_embeddings = self._build_embeddings_matrix(self._buyeruid_to_index, self._init_user_embeddings)
        self._als.user_factors = user_embeddings
        logger.info('User embeddings set')

    def _set_item_embeddings(self) -> None:
        """Index the items of the initial item embeddings and install them as item factors."""
        self._item_id_to_index = self._index(set(self._init_item_embeddings.keys()))
        item_embeddings = self._build_embeddings_matrix(self._item_id_to_index, self._init_item_embeddings)
        self._als.item_factors = item_embeddings
        logger.info('Item embeddings set')

    def _build_embeddings_matrix(
            self,
            id_to_index: Dict[T_id_type, int],
            id_to_embeddings: Dict
    ) -> np.ndarray:
        """Stack per-id embeddings into a dense matrix ordered by ``id_to_index``.

        The result has one row per indexed id and ``self._als.factors`` columns, so
        every embedding must have exactly that many components.
        """
        embeddings = np.zeros((len(id_to_index), self._als.factors), dtype=self._als.dtype)
        for id_, index in tqdm(id_to_index.items()):
            embedding = np.array(id_to_embeddings[id_], dtype=self._als.dtype)
            embeddings[index, :] = embedding
        return embeddings

    @staticmethod
    def _build_interactions_matrix(
            dataset: pd.DataFrame,
            buyeruid_to_index: Dict[int, int],
            item_id_to_index: Dict[str, int]
    ) -> csr_matrix:
        """Build a sparse user x item matrix of ``target`` values.

        The matrix shape is inferred from the largest row / column index present in
        ``dataset``.
        NOTE(review): ids missing from the mappings are mapped to NaN by
        ``Series.map``; callers are expected to pass only indexed ids.
        """
        targets = dataset['target']
        users_indices = dataset['buyeruid'].map(buyeruid_to_index).tolist()
        item_indices = dataset['item_id'].map(item_id_to_index).tolist()
        return csr_matrix((targets, (users_indices, item_indices)))

    @staticmethod
    def _index(ids: Set[T_id_type], id_to_index: Optional[Dict[T_id_type, int]] = None) -> Dict[T_id_type, int]:
        """Map every id in ``ids`` to an index, reusing indices of already-known ids.

        Unknown ids receive consecutive indices starting at ``len(id_to_index)``,
        which keeps the combined mapping dense and collision-free as long as the
        existing indices are ``0..len(id_to_index) - 1``.

        Args:
            ids: ids to index.
            id_to_index: existing mapping to extend; it is not modified.

        Returns:
            Mapping for exactly the ids in ``ids`` (known and new).
        """
        if id_to_index is None:
            id_to_index = {}
        new_id_to_index = {}
        # Count only genuinely new ids; known ids must not advance the counter,
        # otherwise indices get gaps and can later collide.
        next_index = len(id_to_index)
        for id_ in ids:
            if id_ in id_to_index:
                new_id_to_index[id_] = id_to_index[id_]
            else:
                new_id_to_index[id_] = next_index
                next_index += 1
        return new_id_to_index