In [2]:
%pip install -q pymongo

Note: you may need to restart the kernel to use updated packages.


In [4]:
import pickle
from pathlib import Path
from pymongo import MongoClient, UpdateOne
from tqdm import tqdm


In [5]:
# Directory (relative to the notebook's working directory) containing the pickled inputs;
# presumably the output of an earlier filtering step -- confirm against whatever produces it.
input_path = "processed/01_filtered/"

# NOTE: pickle.load can execute arbitrary code during deserialization; only load files you trust.
with open(Path(input_path) / "filtered_playlists.pkl", "rb") as f:
    playlists = pickle.load(f)  # list of dicts {'name':..., 'tracks':[track_uris]}

# Used later as a mapping of track URI -> metadata mapping (it is iterated with .items() and
# each value is unpacked into a MongoDB document).
with open(Path(input_path) / "valid_tracks.pkl", "rb") as f:
    valid_tracks_dict = pickle.load(f)

In [6]:
# Filter playlists to keep only valid tracks
def filter_valid_tracks(playlists, valid_tracks):
    """Return new playlist dicts that contain only tracks found in ``valid_tracks``.

    Args:
        playlists: Iterable of dicts, each with a 'name' and a 'tracks' entry
            ('tracks' being an iterable of track identifiers).
        valid_tracks: Any container supporting ``in`` (for a dict, membership
            is tested against its keys).

    Returns:
        A list with one ``{'name': ..., 'tracks': [...]}`` dict per input
        playlist, in input order. Track order (and duplicates) are preserved;
        the input playlists are not modified and other keys are not copied.
    """
    def _restrict(playlist):
        # Keep only the identifiers that pass the membership test.
        kept = [uri for uri in playlist['tracks'] if uri in valid_tracks]
        return {'name': playlist['name'], 'tracks': kept}

    return [_restrict(playlist) for playlist in playlists]

# Drop from every playlist any track URI that is not a key of valid_tracks_dict.
filtered_playlists = filter_valid_tracks(playlists, valid_tracks_dict)


In [None]:
from pymongo import MongoClient
from pymongo.errors import BulkWriteError

def load_data_to_mongo(valid_tracks_dict, filtered_playlists, mongo_uri="your_mongo_uri", db_name="spotify", batch_size=1000):
    """Rebuild the ``tracks`` and ``playlists`` collections from in-memory data.

    WARNING: both collections are dropped before anything is inserted, so any
    existing contents are lost.

    Args:
        valid_tracks_dict: Mapping of track URI -> metadata mapping. Each entry
            becomes one ``tracks`` document whose ``_id`` is the URI (a
            metadata key named ``_id`` would override it).
        filtered_playlists: Sequence of dicts with 'name' and 'tracks' keys.
        mongo_uri: Connection string handed to ``MongoClient``. The default is
            a placeholder and must be replaced with a real URI.
        db_name: Name of the database to write into.
        batch_size: Number of playlist documents per ``insert_many`` call.

    Returns:
        The ``MongoClient`` that was opened. It is left open; the caller owns it.

    Notes:
        A ``BulkWriteError`` raised while inserting a playlist batch is printed
        and the remaining batches are still attempted. Errors from the track
        insert are not caught.
    """
    client = MongoClient(mongo_uri)
    db = client[db_name]

    # Start from a clean slate.
    db.tracks.drop()
    db.playlists.drop()

    # Tracks: a single insert_many call, one document per URI.
    track_docs = []
    for uri, meta in valid_tracks_dict.items():
        track_docs.append({'_id': uri, **meta})
    if len(track_docs) > 0:
        db.tracks.insert_many(track_docs)

    # Playlists: only the name and track list are stored.
    playlist_docs = []
    for pl in filtered_playlists:
        playlist_docs.append({'name': pl['name'], 'tracks': pl['tracks']})

    # Insert in fixed-size slices; unordered so one failing document does not
    # stop the rest of its batch.
    for start in range(0, len(playlist_docs), batch_size):
        chunk = playlist_docs[start:start + batch_size]
        try:
            db.playlists.insert_many(chunk, ordered=False)
        except BulkWriteError as e:
            print(f"⚠️ Bulk write error in playlist batch: {e.details}")

    # Index on the track array, which the recommendation queries filter on.
    db.playlists.create_index("tracks")

    print("✅ Data loaded into MongoDB.")
    return client


In [28]:
# Rebuild the MongoDB collections from the in-memory data using the function's default
# mongo_uri, db_name and batch_size, and keep the returned client for the queries below.
client = load_data_to_mongo(valid_tracks_dict, filtered_playlists)

✅ Data loaded into MongoDB.


In [29]:
from collections import Counter
import random

def filter_valid_tracks_mongo(client, playlists):
    """Restrict playlists to tracks that exist in the ``spotify.tracks`` collection.

    Args:
        client: Object exposing ``client.spotify.tracks.find(filter, projection)``
            (the database name ``spotify`` is hard-coded here).
        playlists: Iterable of dicts with 'name' and 'tracks' entries.

    Returns:
        A new list of ``{'name': ..., 'tracks': [...]}`` dicts, one per input
        playlist and in the same order, where 'tracks' keeps only identifiers
        that appear as an ``_id`` in the tracks collection.
    """
    # Pull every track id once and hold them in a set for fast membership tests.
    id_cursor = client.spotify.tracks.find({}, {'_id': 1})
    known_ids = {doc['_id'] for doc in id_cursor}

    result = []
    for playlist in playlists:
        kept = [uri for uri in playlist['tracks'] if uri in known_ids]
        result.append({'name': playlist['name'], 'tracks': kept})
    return result

def get_all_playlists(client):
    """Load every document of ``spotify.playlists`` into a list.

    The query uses an empty filter and a projection requesting the ``name``
    and ``tracks`` fields. The whole result set is materialised in memory.
    """
    projection = {'name': 1, 'tracks': 1}
    cursor = client.spotify.playlists.find({}, projection)
    return list(cursor)

def build_inverted_index(playlists):
    """Map each track to the set of playlist positions that contain it.

    Args:
        playlists: Iterable of dicts with a 'tracks' entry.

    Returns:
        dict of track -> set of integer positions (the playlist's index in
        ``playlists`` as enumerated from 0).
    """
    index = {}
    for position, playlist in enumerate(playlists):
        for uri in playlist['tracks']:
            # Create the bucket on first sight of a track, then record the position.
            index.setdefault(uri, set()).add(position)
    return index

def find_similar_playlists(query_tracks, track_to_playlists, max_neighbors=500):
    """Rank playlists by how many of the query tracks they contain.

    Args:
        query_tracks: Iterable of track identifiers to look up.
        track_to_playlists: Mapping of track -> collection of playlist ids
            (e.g. the result of ``build_inverted_index``). Tracks missing from
            the mapping contribute nothing.
        max_neighbors: Maximum number of playlist ids to return.

    Returns:
        List of playlist ids ordered by descending number of shared query
        tracks, truncated to ``max_neighbors`` entries.
    """
    hits_per_track = (track_to_playlists.get(uri, set()) for uri in query_tracks)

    overlap = Counter()
    for playlist_ids in hits_per_track:
        overlap.update(playlist_ids)

    ranked = overlap.most_common(max_neighbors)
    return [entry[0] for entry in ranked]

def recommend_tracks_mongo(query_uris, client, top_k=100, max_neighbors=500):
    """Recommend tracks by neighbour voting over an in-memory playlist index.

    On every call this loads all playlists from MongoDB, keeps a fixed 80%
    subset of them, builds a track -> playlist inverted index over that subset,
    picks the playlists sharing the most tracks with the query, and returns the
    tracks that occur most often in those playlists.

    Args:
        query_uris: Iterable of track URIs describing the query. Any iterable
            is accepted; it is consumed once.
        client: Mongo client handed to ``get_all_playlists`` and
            ``filter_valid_tracks_mongo``.
        top_k: Maximum number of recommendations to return.
        max_neighbors: Maximum number of similar playlists that get to vote.

    Returns:
        Up to ``top_k`` track URIs ordered by descending frequency among the
        neighbouring playlists, excluding the query tracks themselves.
    """
    # Materialise the query once: it is iterated for the neighbour search and
    # then probed for every candidate, so a set keeps the exclusion check O(1).
    query_list = list(query_uris)
    seen = set(query_list)

    playlists = get_all_playlists(client)
    filtered_playlists = filter_valid_tracks_mongo(client, playlists)

    # Train/test split: deterministic 80% subset. A private Random(42) produces
    # the same permutation as random.seed(42) + random.shuffle(), but does not
    # reset the process-wide generator that other notebook cells may rely on.
    idxs = list(range(len(filtered_playlists)))
    random.Random(42).shuffle(idxs)
    split_idx = int(0.8 * len(idxs))
    train_playlists = [filtered_playlists[i] for i in idxs[:split_idx]]

    # Build inverted index (positions refer to train_playlists).
    track_to_playlists = build_inverted_index(train_playlists)

    # Find similar playlists
    similar_pids = find_similar_playlists(query_list, track_to_playlists, max_neighbors=max_neighbors)

    # Aggregate track frequencies across the neighbours.
    track_counter = Counter()
    for pid in similar_pids:
        track_counter.update(train_playlists[pid]['tracks'])

    # Filter out already seen tracks
    recommendations = [t for t, _ in track_counter.most_common() if t not in seen]
    return recommendations[:top_k]


In [None]:
def recommend_tracks_mongo_light(query_uris, client, top_k=100, max_neighbors=500):
    """Recommend tracks using a single MongoDB query instead of a local index.

    Fetches playlists from ``spotify.playlists`` whose ``tracks`` array contains
    at least one query URI and counts how often each track occurs in them.

    Args:
        query_uris: Iterable of track URIs describing the query. Any iterable
            is accepted; it is consumed once.
        client: Object exposing ``client.spotify.playlists.find(...)``.
        top_k: Maximum number of recommendations to return.
        max_neighbors: Passed to the cursor's ``limit``. No sort is applied, so
            the candidates are whichever matches the server returns first, not
            the playlists with the largest overlap.

    Returns:
        Up to ``top_k`` track URIs ordered by descending frequency among the
        candidate playlists, excluding the query tracks themselves.
    """
    # Materialise once: the query is needed both as the $in operand and for the
    # exclusion check below. A one-shot iterable would otherwise be exhausted
    # by the first use and the query tracks would leak into the results.
    query_list = list(query_uris)
    seen = set(query_list)

    db = client.spotify
    candidate_playlists = db.playlists.find(
        {"tracks": {"$in": query_list}},
        {"tracks": 1}
    ).limit(max_neighbors)

    track_counter = Counter()
    for pl in candidate_playlists:
        track_counter.update(pl['tracks'])

    # Filter out query tracks from recommendations
    recommendations = [t for t, _ in track_counter.most_common() if t not in seen]

    return recommendations[:top_k]

In [36]:
import time


# Seed tracks for the demo query.
query_uris = [
    "spotify:track:0DdpxWfVvUGgkJv5536tiF",
    "spotify:track:3ZFTkvIE7kyPt6Nu3PEa7V"
]

# Time the full in-memory variant. perf_counter() is a monotonic,
# high-resolution clock, so the measured interval cannot be distorted by
# wall-clock adjustments the way time.time() differences can.
start_time = time.perf_counter()
recommendations = recommend_tracks_mongo(query_uris, client, top_k=10, max_neighbors=500)
end_time = time.perf_counter()

print(f"Recommendation took {end_time - start_time:.4f} seconds\n")

# Time the query-based variant. Note that this overwrites `recommendations`,
# so the OUTPUT section below lists the light variant's results.
start_time = time.perf_counter()
recommendations = recommend_tracks_mongo_light(query_uris, client, top_k=10, max_neighbors=500)
end_time = time.perf_counter()

print(f"Recommendation light took {end_time - start_time:.4f} seconds\n")


# Resolve each URI to its stored metadata for a human-readable listing.
print("INPUT")
for qu in query_uris:
    track = client.spotify.tracks.find_one({'_id': qu})
    print(f"{track['track_name']} by {track['artist_name']}")



print("OUTPUT")
# Display result
for uri in recommendations:
    track = client.spotify.tracks.find_one({'_id': uri})
    print(f"{track['track_name']} by {track['artist_name']}")

Recommendation took 6.6083 seconds

Recommendation light took 0.0385 seconds

INPUT
Sexy Can I feat. Yung Berg by Ray J
Hips Don't Lie by Shakira
OUTPUT
Buy U a Drank (Shawty Snappin') by T-Pain
Ignition - Remix by R. Kelly
Kiss Me Thru The Phone by Soulja Boy
It Wasn't Me by Shaggy
Yeah! by Usher
BedRock by Young Money
Gold Digger by Kanye West
Promiscuous by Nelly Furtado
Smack That - Dirty by Akon
Suga Suga by Baby Bash
