## Imports

In [None]:
from typing import Literal
from datasets import Dataset, DatasetDict, load_dataset
import numpy as np
import pandas as pd
from pathlib import Path


np.random.seed(1 0) # Make sure we always sample the same folks
root = Path("")
dataset_dir = Path("Dataset")
dataset_dir = root / "Dataset"
dataset_dir.mkdir(exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


## Download class

In [2]:

class YambdaDataset:
    INTERACTIONS = frozenset([
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    ])

    def __init__(
        self,
        dataset_type: Literal["flat", "sequential"] = "flat",
        dataset_size: Literal["50m", "500m", "5b"] = "50m"
    ):
        assert dataset_type in {"flat", "sequential"}
        assert dataset_size in {"50m", "500m", "5b"}
        self.dataset_type = dataset_type
        self.dataset_size = dataset_size

    def interaction(self, event_type: Literal[
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    ]) -> Dataset:
        assert event_type in YambdaDataset.INTERACTIONS
        return self._download(f"{self.dataset_type}/{self.dataset_size}", event_type)

    def audio_embeddings(self) -> Dataset:
        return self._download("", "embeddings")

    def album_item_mapping(self) -> Dataset:
        return self._download("", "album_item_mapping")

    def artist_item_mapping(self) -> Dataset:
        return self._download("", "artist_item_mapping")


    @staticmethod
    def _download(data_dir: str, file: str) -> Dataset:
        data = load_dataset("yandex/yambda", data_dir=data_dir, data_files=f"{file}.parquet")
        # Returns DatasetDict; extracting the only split
        assert isinstance(data, DatasetDict)
        return data["train"]
    
dataset = YambdaDataset('flat', '50m')

## Download and write locally to CSV's

In [3]:
def locally_save_df(dataset_type):
    if not (dataset_dir / f"{dataset_type}.csv").exists():
        df = dataset.interaction(f"{dataset_type}").to_pandas()
        df.to_csv(dataset_dir / f"{dataset_type}.csv", index=False)
        del df

In [4]:
# Write files locally
dataset_types = ["likes", "listens", "dislikes", "unlikes", "undislikes"]

for dt in dataset_types:
    locally_save_df(dt)

if not (dataset_dir / "embeddings.csv").exists():
    embeddings = dataset.audio_embeddings().to_pandas()
    embeddings.to_csv(dataset_dir / "embeddings.csv", index=False)
    del embeddings

## Create our tiny dataset

In [5]:
if not (dataset_dir / "listens_subset.csv").exists():
    listens = pd.read_csv(dataset_dir / "listens.csv")
    users = np.random.choice(listens['uid'].unique(), size=2).tolist()

    listens_subset = listens.loc[listens['uid'].isin(users)]
    listens_subset.to_csv(dataset_dir / "listens_subset.csv", index=False)
else:
    listens_subset = pd.read_csv(dataset_dir / "listens_subset.csv")
listens_subset

Unnamed: 0,uid,timestamp,item_id,is_organic,played_ratio_pct,track_length_seconds
6454893,140100,5193935,6133189,1,0,250
6454894,140100,5193935,3246648,1,1,205
6454895,140100,5193935,5357236,1,1,255
6454896,140100,5193935,5046316,1,0,235
6454897,140100,5193935,3710074,1,0,220
...,...,...,...,...,...,...
36860005,788300,25950960,5369887,1,30,135
36860006,788300,25950960,999732,1,1,200
36860007,788300,25951020,9029013,1,26,220
36860008,788300,25951215,5734320,1,100,180


In [6]:
if not (dataset_dir / "embeddings_subset.csv").exists():
    embedded = pd.read_csv(dataset_dir/ "embeddings.csv")
    embedded_subset = embedded[embedded["item_id"].isin(listens_subset["item_id"])]
    embedded_subset.to_csv(dataset_dir/ "embeddings_subset.csv")
    del embedded
else:
    embedded_subset = pd.read_csv(dataset_dir / "embeddings_subset.csv")

In [7]:
embedded_subset

Unnamed: 0,item_id,embed,normalized_embed
9066,10994,[-2.30585122 1.15795255 0.36453152 -1.117058...,[-0.11932126 0.05992076 0.01886347 -0.057804...
10974,13304,[ 2.49134803 -1.52789295 4.20174217 0.314149...,[ 0.08874871 -0.05442778 0.14967769 0.011190...
13119,15905,[-3.65013242 0.83235425 2.50837851 -0.704075...,[-0.14780727 0.03370508 0.10157346 -0.028510...
17605,21385,[ 2.72873878e+00 -3.73627281e+00 -1.67920911e+...,[ 6.49961847e-02 -8.89947689e-02 -3.99973007e-...
19897,24177,[-0.46834391 -2.22887564 0.65408915 0.703571...,[-0.01265283 -0.06021555 0.01767094 0.019007...
...,...,...,...
7692468,9355074,[ 2.59113169 -0.71516579 3.07850266 2.180783...,[ 0.07803622 -0.0215384 0.0927142 0.065677...
7708927,9375106,[ 2.82731962 1.36490476 -0.06637979 3.004100...,[ 0.08144361 0.03931737 -0.00191213 0.086535...
7712136,9378983,[-4.06189728 -0.95195591 -1.51909649 -3.914383...,[-0.13074855 -0.03064254 -0.04889825 -0.126000...
7719554,9387974,[ 1.96107447e+00 1.21009099e+00 -2.31464684e-...,[ 6.39159156e-02 3.94396414e-02 -7.54396501e-...
