## Imports

In [1]:
from typing import Literal
from datasets import Dataset, DatasetDict, load_dataset
import numpy as np
import pandas as pd
from pathlib import Path
import ast


np.random.seed(10) # Make sure we always sample the same folks
root = Path("")
dataset_dir = Path("Dataset")
dataset_dir.mkdir(exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


## Download class

In [2]:
class YambdaDataset:
    INTERACTIONS = frozenset([
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    ])

    def __init__(
        self,
        dataset_type: Literal["flat", "sequential"] = "flat",
        dataset_size: Literal["50m", "500m", "5b"] = "50m"
    ):
        assert dataset_type in {"flat", "sequential"}
        assert dataset_size in {"50m", "500m", "5b"}
        self.dataset_type = dataset_type
        self.dataset_size = dataset_size

    def interaction(self, event_type: Literal[
        "likes", "listens", "multi_event", "dislikes", "unlikes", "undislikes"
    ]) -> Dataset:
        assert event_type in YambdaDataset.INTERACTIONS
        return self._download(f"{self.dataset_type}/{self.dataset_size}", event_type)

    def audio_embeddings(self) -> Dataset:
        return self._download("", "embeddings")

    def album_item_mapping(self) -> Dataset:
        return self._download("", "album_item_mapping")

    def artist_item_mapping(self) -> Dataset:
        return self._download("", "artist_item_mapping")


    @staticmethod
    def _download(data_dir: str, file: str) -> Dataset:
        data = load_dataset("yandex/yambda", data_dir=data_dir, data_files=f"{file}.parquet")
        # Returns DatasetDict; extracting the only split
        assert isinstance(data, DatasetDict)
        return data["train"]
    
dataset = YambdaDataset('flat', '50m')

## Download and write locally to CSV's

In [3]:
def locally_save_df(dataset_type):
    if not (dataset_dir / f"{dataset_type}.csv").exists():
        df = dataset.interaction(f"{dataset_type}").to_pandas()
        df.to_csv(dataset_dir / f"{dataset_type}.csv", index=False)
        del df

In [4]:
# Write files locally
dataset_types = ["likes", "listens", "dislikes", "unlikes", "undislikes"]

for dt in dataset_types:
    locally_save_df(dt)

if not (dataset_dir / "embeddings.csv").exists():
    embeddings = dataset.audio_embeddings().to_pandas()
    embeddings.to_csv(dataset_dir / "embeddings.csv", index=False)
    del embeddings

Generating train split: 881456 examples [00:00, 23232294.56 examples/s]


## Create our tiny test dataset

In [5]:
if not (dataset_dir / "listens_subset.csv").exists():
    listens = pd.read_csv(dataset_dir / "listens.csv")
    users = np.random.choice(listens['uid'].unique(), size=2).tolist()

    listens_subset = listens.loc[listens['uid'].isin(users)]
    listens_subset.to_csv(dataset_dir / "listens_subset.csv", index=False)
else:
    listens_subset = pd.read_csv(dataset_dir / "listens_subset.csv")
listens_subset[listens_subset['is_organic'] == 1]

Unnamed: 0,uid,timestamp,item_id,is_organic,played_ratio_pct,track_length_seconds
0,140100,5193935,6133189,1,0,250
1,140100,5193935,3246648,1,1,205
2,140100,5193935,5357236,1,1,255
3,140100,5193935,5046316,1,0,235
4,140100,5193935,3710074,1,0,220
...,...,...,...,...,...,...
11429,788300,25950960,5369887,1,30,135
11430,788300,25950960,999732,1,1,200
11431,788300,25951020,9029013,1,26,220
11432,788300,25951215,5734320,1,100,180


In [6]:
def parse_embedding(s):
    # Remove brackets
    s = s.strip().strip('[]')
    # Split by whitespace
    return np.fromstring(s, sep=' ')




if not (dataset_dir / "embeddings_subset.csv").exists():
    embedded = pd.read_csv(dataset_dir/ "embeddings.csv")
    embedded_subset = embedded[embedded["item_id"].isin(listens_subset["item_id"])]
    embedded_subset.to_csv(dataset_dir/ "embeddings_subset.csv")
    del embedded
else:
    embedded_subset = pd.read_csv(dataset_dir / "embeddings_subset.csv")
    
    # Convert the embeddings from string column back to np.array
    embedded_subset["embed"] = embedded_subset["embed"].apply(parse_embedding)
    embedded_subset["normalized_embed"] = embedded_subset["normalized_embed"].apply(parse_embedding)

In [7]:
embedded_subset

Unnamed: 0.1,Unnamed: 0,item_id,embed,normalized_embed
0,9066,10994,"[-2.30585122, 1.15795255, 0.36453152, -1.11705...","[-0.11932126, 0.05992076, 0.01886347, -0.05780..."
1,10974,13304,"[2.49134803, -1.52789295, 4.20174217, 0.314149...","[0.08874871, -0.05442778, 0.14967769, 0.011190..."
2,13119,15905,"[-3.65013242, 0.83235425, 2.50837851, -0.70407...","[-0.14780727, 0.03370508, 0.10157346, -0.02851..."
3,17605,21385,"[2.72873878, -3.73627281, -1.67920911, 0.26456...","[0.0649961847, -0.0889947689, -0.0399973007, 0..."
4,19897,24177,"[-0.46834391, -2.22887564, 0.65408915, 0.70357...","[-0.01265283, -0.06021555, 0.01767094, 0.01900..."
...,...,...,...,...
2523,7692468,9355074,"[2.59113169, -0.71516579, 3.07850266, 2.180783...","[0.07803622, -0.0215384, 0.0927142, 0.06567789..."
2524,7708927,9375106,"[2.82731962, 1.36490476, -0.06637979, 3.004100...","[0.08144361, 0.03931737, -0.00191213, 0.086535..."
2525,7712136,9378983,"[-4.06189728, -0.95195591, -1.51909649, -3.914...","[-0.13074855, -0.03064254, -0.04889825, -0.126..."
2526,7719554,9387974,"[1.96107447, 1.21009099, -0.231464684, -1.0612...","[0.0639159156, 0.0394396414, -0.00754396501, -..."
