In [None]:
# Install psycopg2 into the environment used by this kernel.
# NOTE(review): psycopg2 is not imported directly in the cells visible here; presumably it is
# the PostgreSQL driver SQLAlchemy uses behind create_engine - confirm. The install is
# unpinned; consider pinning a version so re-runs are reproducible.
%conda install -y psycopg2

Channels:
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done

# All requested packages already installed.


Note: you may need to restart the kernel to use updated packages.


In [34]:
import pickle
import joblib
import logging
import os
from pathlib import Path
from tqdm import tqdm
from dotenv import find_dotenv, load_dotenv
from sklearn.feature_extraction.text import TfidfVectorizer
from sqlalchemy import create_engine, Column, Integer, String, ARRAY
from sqlalchemy.orm import declarative_base, sessionmaker
from pgvector.sqlalchemy import Vector

# Configure the root logger for the notebook: INFO level, each message prefixed
# with a timestamp and the level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)

In [35]:
# --- Notebook configuration -------------------------------------------------
# Logging is already configured in the import cell; a second
# logging.basicConfig() call is a no-op once the root logger has a handler,
# so it is not repeated here.

# Playlists with fewer valid tracks than this are discarded after filtering.
MIN_PLAYLIST_LENGTH = 5
# Directory holding the pickled inputs read by this notebook.
INPUT_PATH = "../scripts/data/02_processed"
# Upper bound on the TF-IDF vocabulary size (passed as max_features).
VECTOR_DIM = 1000

# PostgreSQL connection settings, overridable through environment variables.
# NOTE(review): load_dotenv/find_dotenv are imported but not called in the
# cells visible here; if a .env file is expected to supply these values it
# must be loaded before the os.getenv calls below - confirm.
PG_HOST = os.getenv("POSTGRES_HOST", "localhost")
# Always an int: os.getenv returns a str when the variable is set, whereas the
# fallback used to be an int. An empty value is treated as unset.
PG_PORT = int(os.getenv("POSTGRES_PORT") or 5432)
PG_USER = os.getenv("POSTGRES_USER", "postgres")
# SECURITY: hard-coded fallback password - acceptable only for a local
# throwaway database; supply POSTGRES_PASSWORD for anything else.
PG_PASSWORD = os.getenv("POSTGRES_PASSWORD", "123456")
PG_DB = os.getenv("POSTGRES_DB", "testdb")


In [36]:
# Declarative base class that the ORM models in this notebook (e.g. Playlist) subclass.
Base = declarative_base()

In [29]:
def filter_valid_tracks(playlists, valid_tracks):
    """Remove tracks that are not in ``valid_tracks`` from every playlist.

    No playlist is dropped here: a playlist whose tracks are all invalid is
    returned with an empty ``tracks`` list, so the result has exactly one
    entry per input playlist, in the same order. Dropping short playlists is
    left to the caller.

    Args:
        playlists: Sized sequence of dicts with at least the keys ``'name'``
            and ``'tracks'`` (an iterable of tracks).
        valid_tracks: Sized container of allowed tracks, tested with ``in``.
            A list or tuple is converted to a set once so each lookup is O(1)
            instead of a linear scan; in that case tracks must be hashable.
            Sets, dicts and other containers are used as given.

    Returns:
        A new list of ``{'name': ..., 'tracks': [...]}`` dicts. The input
        playlists are not modified.
    """
    logging.info(f"Filtering playlists using {len(valid_tracks)} valid tracks...")

    # Build a hashed lookup only for plain sequences; fall back to the
    # original container if its entries turn out to be unhashable.
    lookup = valid_tracks
    if isinstance(valid_tracks, (list, tuple)):
        try:
            lookup = set(valid_tracks)
        except TypeError:
            lookup = valid_tracks

    filtered = []
    for pl in tqdm(playlists, total=len(playlists), desc="Filtering playlists"):
        filtered_tracks = [t for t in pl['tracks'] if t in lookup]
        filtered.append({'name': pl['name'], 'tracks': filtered_tracks})
    logging.info(f"Filtered down to {len(filtered)} playlists.")
    return filtered

In [30]:
# Load the pickled playlists and the valid-track collection, then keep only
# playlists that still have enough tracks once invalid ones are removed.
# SECURITY: pickle.load can execute arbitrary code - only load files you trust.
logging.info("Loading playlist data...")
data_dir = Path(INPUT_PATH)

with (data_dir / "filtered_playlists_clustering.pkl").open("rb") as fh:
    playlists = pickle.load(fh)

with (data_dir / "valid_tracks_clustering.pkl").open("rb") as fh:
    valid_tracks_dict = pickle.load(fh)

logging.info(f"Loaded {len(playlists)} playlists.")

# Strip invalid tracks first, then drop playlists that ended up too short.
filtered = [
    p
    for p in filter_valid_tracks(playlists, valid_tracks_dict)
    if len(p['tracks']) >= MIN_PLAYLIST_LENGTH
]
logging.info(f"Retained {len(filtered)} playlists with at least {MIN_PLAYLIST_LENGTH} tracks.")


2025-08-01 13:57:18,879 [INFO] Loading playlist data...
2025-08-01 13:57:21,758 [INFO] Loaded 774682 playlists.
2025-08-01 13:57:21,759 [INFO] Filtering playlists using 661 valid tracks...
Filtering playlists: 100%|██████████| 774682/774682 [00:01<00:00, 390696.61it/s]
2025-08-01 13:57:23,743 [INFO] Filtered down to 774682 playlists.
2025-08-01 13:57:23,816 [INFO] Retained 513108 playlists with at least 5 tracks.


In [31]:
# Split the retained playlists into two parallel lists: names[i] and tracks[i]
# belong to the same playlist.
names = []
tracks = []
for p in filtered:
    names.append(p['name'])
    tracks.append(p['tracks'])

In [37]:
logging.info("Fitting TF-IDF vectorizer on playlist names...")

# Only the (unfitted) vectorizer is constructed here.
vectorizer = TfidfVectorizer(
    lowercase=True,               # fold case before tokenizing
    stop_words='english',         # drop common English stop words
    token_pattern=r'\b\w+\b',     # every run of word characters, including single characters
    max_features=VECTOR_DIM,      # cap the vocabulary size at VECTOR_DIM terms
)

2025-08-01 13:58:57,827 [INFO] Fitting TF-IDF vectorizer on playlist names...


In [38]:
# Learn the vocabulary from the playlist names and encode every name as a TF-IDF row.
name_vectors = vectorizer.fit_transform(names)
logging.info("Vectorization complete.")

# Persist the fitted vectorizer. The artifact directory is created first so
# the dump does not fail with FileNotFoundError on a fresh checkout.
vectorizer_path = Path("../scripts/data/03_artifacts/vectorizer.pkl")
vectorizer_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(vectorizer, vectorizer_path)
logging.info("Saved vectorizer to 'vectorizer.pkl'.")

2025-08-01 13:59:14,743 [INFO] Vectorization complete.
2025-08-01 13:59:14,752 [INFO] Saved vectorizer to 'vectorizer.pkl'.


In [39]:
class Playlist(Base):
    """ORM model mapped to the ``playlists`` table."""
    __tablename__ = "playlists"
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Playlist name; required (NOT NULL).
    name = Column(String, nullable=False)
    # Tracks stored as an array of strings; required (NOT NULL).
    tracks = Column(ARRAY(String), nullable=False)
    # pgvector column declared with 1024 dimensions; nullable.
    # NOTE(review): VECTOR_DIM is 1000 and the TF-IDF vectorizer is capped at
    # max_features=VECTOR_DIM, so its vectors have at most 1000 components. A
    # fixed-size vector column is expected to reject vectors of a different
    # length - confirm whether vectors are padded to 1024 before insert or
    # whether this should be Vector(VECTOR_DIM).
    embedding = Column(Vector(1024))