In [17]:
!pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.8-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.15-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.16-py3-none-any.whl.metadata (20 kB)
Downloading ipywidgets-8.1.8-py3-none-any.whl (139 kB)
Downloading jupyterlab_widgets-3.0.16-py3-none-any.whl (914 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m914.9/914.9 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading widgetsnbextension-4.0.15-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
Successfully installed ipywidgets-8.1.8 jupyterlab_widgets-3.0.16 widgetsnbextension-4.0.15

[1m[[0m[3

In [40]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
Collecting aiohttp (from torch_geometric)
  Downloading aiohttp-3.13.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.1 kB)
Collecting pyparsing (from torch_geometric)
  Downloading pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)
Collecting xxhash (from torch_geometric)
  Downloading xxhash-3.6.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Collecting aiohappyeyeballs>=2.5.0 (from aiohttp->torch_geometric)
  Downloading aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.4.0 (from aiohttp->torch_geometric)
  Downloading aiosignal-1.4.0-py3-none-any.whl.metadata (3.7 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch_geometric)
  Downloading frozenlist-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (20

In [1]:
import numpy as np
import networkx as nx
from tqdm import tqdm
import os
import json
import argparse
import random
import time
from collections import defaultdict
from typing import Dict, List, Tuple
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv

# PyG hetero stuff
from torch_geometric.data import HeteroData
from torch_geometric.nn import HGTConv, Linear

try:
    import faiss
except Exception:
    faiss = None



    # Constants
AUDIO_FEATURE_KEYS = [
    'danceability','energy','valence','tempo','loudness',
    'speechiness','instrumentalness','liveness','acousticness'
]

SEED = 42

In [2]:
# Cell 2: MPD iterator & small sample fallback

def iter_mpd_playlists(mpd_dir: str, n=None):
    """Yield playlists from each .json file in mpd_dir (sorted)."""
    files = [os.path.join(mpd_dir, f) for f in os.listdir(mpd_dir) if f.endswith('.json')]
    if n is None:
        files = sorted(files)
    else:
        files = sorted(files)[:n]

    print(f"There is {len(files)} files")
    for fp in files:
        with open(fp, 'r', encoding='utf-8') as fh:
            try:
                data = json.load(fh)
            except Exception as e:
                print(f'Warning: failed to load {fp}: {e}')
                continue
            for pl in data.get('playlists', []):
                yield pl



In [3]:
import http.client

conn = http.client.HTTPSConnection("api.reccobeats.com")
payload = ''
headers = {
  'Accept': 'application/json'
}


In [15]:
def augment_tracks(playlists: List[Dict], sp=None, cache_path: str = None, batch=50):
    """Return dict track_uri -> feature dict for AUDIO_FEATURE_KEYS. Use cache if available."""
    unique = {}
    for pl in playlists:
        for t in pl['tracks']:
            unique[t['track_uri']] = t
    tids = list(unique.keys())
    features = {}
    cache = {}
    if cache_path and os.path.exists(cache_path):
        try:
            with open(cache_path, 'r', encoding='utf-8') as fh:
                cache = json.load(fh)
        except Exception:
            cache = {}

    def af_to_vec(af):
        if af is None:
            return {k: 0.0 for k in AUDIO_FEATURE_KEYS}
        return {k: float(af.get(k, 0.0)) for k in AUDIO_FEATURE_KEYS}


    if sp is None:
        for tid in tqdm(tids, desc='augment (synth)'):
            if tid in cache:
                features[tid] = cache[tid]; continue
            vec = {k: float(np.random.rand()) for k in AUDIO_FEATURE_KEYS}
            vec['tempo'] = 60.0 + np.random.rand() * 120.0
            vec['loudness'] = -60.0 + np.random.rand() * 60.0
            features[tid] = vec
            cache[tid] = vec
    else:
        tid2sp = {tid: tid.split(':')[-1] for tid in tids}
        batch_ids, batch_tids = [], []
        for tid, spid in tqdm(tid2sp.items(), desc='augment (spotify)'):
            if tid in cache:
                features[tid] = cache[tid]; continue
            batch_ids.append(spid); batch_tids.append(tid)
            if len(batch_ids) >= batch:
                try:
                    conn.request("GET", f"/v1/track/{tid}/audio-features", payload, headers)
                    res = conn.getresponse()
                    data = res.read()
                    afs = data.decode("utf-8")
                    print(afs)
                except Exception as e:
                    print('Spotify API error:', e); afs = [None]*len(batch_ids)
                for bt, af in zip(batch_tids, afs):
                    vec = af_to_vec(af)
                    features[bt] = vec; cache[bt] = vec
                batch_ids, batch_tids = [], []
        if batch_ids:
            try:

                afs = sp.audio_features(batch_ids)
            except Exception as e:
                print('Spotify API error:', e); afs = [None]*len(batch_ids)
            for bt, af in zip(batch_tids, afs):
                vec = af_to_vec(af)
                features[bt] = vec; cache[bt] = vec

    if cache_path:
        try:
            with open(cache_path, 'w', encoding='utf-8') as fh:
                json.dump(cache, fh)
        except Exception as e:
            print('Warning: failed to write cache:', e)
    return features


In [16]:
# Cell 4: build full heterogeneous KG (NetworkX + PyG HeteroData)

def build_full_kg(playlists: List[Dict], track_features: Dict[str, Dict]=None):
    nxg = nx.MultiDiGraph()
    node_ids = {nt: {} for nt in ['playlist','track','artist','album']}
    counters = {nt: 0 for nt in node_ids}

    # Build NX graph and maps
    for pl in playlists:
        pid_raw = pl.get('pid') or pl.get('uri') or f"pl:{time.time()}:{random.randint(0,1e6)}"
        pid = f'playlist:{pid_raw}'
        if pid not in node_ids['playlist']:
            node_ids['playlist'][pid] = counters['playlist']; counters['playlist'] += 1
            nxg.add_node(pid, type='playlist')
        for t in pl['tracks']:
            tid = t.get('track_uri')
            print(tid)
            art = t.get('artist_uri', 'artist:unknown')
            alb = t.get('album_uri', 'album:unknown')
            if tid not in node_ids['track']:
                node_ids['track'][tid] = counters['track']; counters['track'] += 1
                nxg.add_node(tid, type='track')
            if art not in node_ids['artist']:
                node_ids['artist'][art] = counters['artist']; counters['artist'] += 1
                nxg.add_node(art, type='artist')
            if alb not in node_ids['album']:
                node_ids['album'][alb] = counters['album']; counters['album'] += 1
                nxg.add_node(alb, type='album')
            nxg.add_edge(pid, tid, relation='contains')
            nxg.add_edge(tid, art, relation='by')
            nxg.add_edge(tid, alb, relation='on')

    # # Build HeteroData
    data = HeteroData()
    for ntype in ['playlist','track','artist','album']:
        ncount = counters[ntype]
        if ncount == 0:
            data[ntype].x = torch.zeros((0, len(AUDIO_FEATURE_KEYS)))
            continue
        if ntype == 'track':
            track_list = [None] * ncount
            feat_list = [None] * ncount
            for tid, idx in node_ids['track'].items():
                track_list[idx] = tid
                tf = track_features.get(tid, {k: 0.0 for k in AUDIO_FEATURE_KEYS})
                feat_list[idx] = np.array([tf.get(k, 0.0) for k in AUDIO_FEATURE_KEYS], dtype=float)
            feats_np = np.vstack(feat_list)
            feats_np = StandardScaler().fit_transform(feats_np)
            data['track'].x = torch.tensor(feats_np, dtype=torch.float)
            data['track'].tid_list = track_list
        else:
            inverse = [None] * counters[ntype]
            for k, v in node_ids[ntype].items():
                inverse[v] = k
            feat_list = []
            for node in inverse:
                deg = nxg.degree(node)
                feat_list.append([float(deg)])
            feats_np = np.vstack(feat_list) if len(feat_list)>0 else np.zeros((0,1))
            if feats_np.shape[0] > 0:
                feats_np = StandardScaler().fit_transform(feats_np)
                data[ntype].x = torch.tensor(feats_np, dtype=torch.float)
            else:
                data[ntype].x = torch.zeros((0,1), dtype=torch.float)
            data[ntype].id_list = inverse
    #
    # # edges: playlist->track, track->artist, track->album with reverse relations
    # def edges_from_nx(source_type, target_type, relation):
    #     src_idxs = []
    #     dst_idxs = []
    #     for u, v, d in nxg.edges(data=True):
    #         if d.get('relation') != relation:
    #             continue
    #         if nxg.nodes[u].get('type') != source_type or nxg.nodes[v].get('type') != target_type:
    #             continue
    #         src_idxs.append(node_ids[source_type][u])
    #         dst_idxs.append(node_ids[target_type][v])
    #     if len(src_idxs) == 0:
    #         return torch.empty((2,0), dtype=torch.long)
    #     return torch.tensor([src_idxs, dst_idxs], dtype=torch.long)
    #
    # data['playlist', 'contains', 'track'].edge_index = edges_from_nx('playlist','track','contains')
    # if data['playlist', 'contains', 'track'].edge_index.numel() > 0:
    #     data['track', 'rev_contains', 'playlist'].edge_index = data['playlist', 'contains', 'track'].edge_index.flip(0)
    # else:
    #     data['track', 'rev_contains', 'playlist'].edge_index = torch.empty((2,0), dtype=torch.long)
    #
    # data['track', 'by', 'artist'].edge_index = edges_from_nx('track','artist','by')
    # if data['track','by','artist'].edge_index.numel() > 0:
    #     data['artist', 'rev_by', 'track'].edge_index = data['track','by','artist'].edge_index.flip(0)
    # else:
    #     data['artist', 'rev_by', 'track'].edge_index = torch.empty((2,0), dtype=torch.long)
    #
    # data['track', 'on', 'album'].edge_index = edges_from_nx('track','album','on')
    # if data['track','on','album'].edge_index.numel() > 0:
    #     data['album', 'rev_on', 'track'].edge_index = data['track','on','album'].edge_index.flip(0)
    # else:
    #     data['album', 'rev_on', 'track'].edge_index = torch.empty((2,0), dtype=torch.long)
    #
    return nxg, node_ids


In [17]:
mpd_dir = "./archive/data"  # e.g. "/home/user/datasets/spotify_mpd"
use_spotify = True
spotify_client_id = None
spotify_client_secret = None
cache_path = 'track_features_cache.json'
sample_rate = 0.05
limit = 2000
device = 'cpu'
epochs = 8

playlists = []
for pl in iter_mpd_playlists(mpd_dir, 10):
    playlists.append(pl)
    if limit and len(playlists) >= limit:
        break

print(f'Loaded {len(playlists)} playlists')

There is 10 files
Loaded 2000 playlists


In [18]:

if not os.path.exists(cache_path):
    with open(cache_path, 'w') as file:
        file.write("This is a new file created because it didn't exist.\n")
    print(f"File '{cache_path}' created successfully.")
else:
    print(f"File '{cache_path}' already exists. No new file was created.")
track_features = augment_tracks(playlists, cache_path=cache_path)

try:
    os.remove(cache_path)
    print(f"File '{cache_path}' deleted successfully.")
except FileNotFoundError:
    print(f"Error: File '{cache_path}' not found.")
except Exception as e:
    print(f"An error occurred: {e}")
# nxg, node_id_maps = build_full_kg(playlists)
# print('HeteroData node types:', hetero_data.node_types, 'edge types:', hetero_data.edge_types)
# print(nxg)


File 'track_features_cache.json' created successfully.


augment (synth): 100%|██████████| 57884/57884 [00:00<00:00, 73359.89it/s]


File 'track_features_cache.json' deleted successfully.


In [13]:
print(track_features)

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)

