## Multioutput regression 

- Captures correlations between related targets , the model learns joint patterns
- More efficient, train once instead of multiple times
- Can improve performance, related outputs help regularize each other
- Shares feature importance across targets

In [58]:
import numpy as np
import pandas as pd
import librosa
from pathlib import Path
from tqdm import tqdm
import pickle
import hashlib

In [59]:
# Configuration - SET THIS ACCORDING TO YOUR ENVIRONMENT
RUNNING_ON_WSL = False  

if RUNNING_ON_WSL:
    # Running on Linux/WSL
    PROJECT_ROOT = Path("/home/rime/music-recom")
else:
    # Running on Windows accessing WSL files
    PROJECT_ROOT = Path(r"\\wsl.localhost\Ubuntu-22.04\home\rime\music-recom")

DATA_DIR = PROJECT_ROOT / "data"
PROCESSED_DIR = DATA_DIR / "processed"
CACHE_DIR = PROCESSED_DIR / "cache"


In [60]:
def fix_audio_path(path_str, running_on_wsl=RUNNING_ON_WSL):
    """Convert stored path to the correct format for current environment"""
    # Extract just the filename
    filename = Path(path_str).name
    
    # Build correct path
    correct_path = DATA_DIR / "processed" / 'audio' / filename
    
    return correct_path

Cache files

In [61]:
# Create directories if they don't exist
CACHE_DIR.mkdir(parents=True, exist_ok=True)

def get_cache_key(audio_path, n_mfcc, sr):
    """Generate unique cache key for audio file"""
    file_stat = Path(audio_path).stat()
    cache_string = f"{audio_path}_{n_mfcc}_{sr}_{file_stat.st_size}_{file_stat.st_mtime}"
    return hashlib.md5(cache_string.encode()).hexdigest()

def load_from_cache(cache_key):
    """Load features from cache if available"""
    cache_file = CACHE_DIR / f"{cache_key}.pkl"
    if cache_file.exists():
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    return None

def save_to_cache(cache_key, features):
    """Save features to cache"""
    cache_file = CACHE_DIR / f"{cache_key}.pkl"
    with open(cache_file, 'wb') as f:
        pickle.dump(features, f)

In [62]:
def extract_mfcc_features(audio_path, n_mfcc=13, sr=22050, use_cache=True):
    try:
        # Check cache first
        if use_cache:
            cache_key = get_cache_key(audio_path, n_mfcc, sr)
            cached_features = load_from_cache(cache_key)
            if cached_features is not None:
                return cached_features, True
        
        # Load audio
        signal, sr = librosa.load(audio_path, sr=sr)
        
        # Extract MFCCs
        mfccs = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=n_mfcc)
        
        # Extract deltas
        delta_mfccs = librosa.feature.delta(mfccs)
        delta2_mfccs = librosa.feature.delta(mfccs, order=2)
        
        # Compute statistics
        features = {}
        
        # MFCC statistics
        features['mfcc_mean'] = np.mean(mfccs, axis=1)
        features['mfcc_std'] = np.std(mfccs, axis=1)
        features['mfcc_q25'] = np.percentile(mfccs, 25, axis=1)
        features['mfcc_q50'] = np.percentile(mfccs, 50, axis=1)  # median
        features['mfcc_q75'] = np.percentile(mfccs, 75, axis=1)
        
        # Delta statistics
        features['delta_mean'] = np.mean(delta_mfccs, axis=1)
        features['delta_std'] = np.std(delta_mfccs, axis=1)
        features['delta_q25'] = np.percentile(delta_mfccs, 25, axis=1)
        features['delta_q50'] = np.percentile(delta_mfccs, 50, axis=1)
        features['delta_q75'] = np.percentile(delta_mfccs, 75, axis=1)
        
        # Delta2 statistics
        features['delta2_mean'] = np.mean(delta2_mfccs, axis=1)
        features['delta2_std'] = np.std(delta2_mfccs, axis=1)
        features['delta2_q25'] = np.percentile(delta2_mfccs, 25, axis=1)
        features['delta2_q50'] = np.percentile(delta2_mfccs, 50, axis=1)
        features['delta2_q75'] = np.percentile(delta2_mfccs, 75, axis=1)
        
        # Flatten all features into a single vector
        feature_vector = np.concatenate([
            features['mfcc_mean'], features['mfcc_std'],
            features['mfcc_q25'], features['mfcc_q50'], features['mfcc_q75'],
            features['delta_mean'], features['delta_std'],
            features['delta_q25'], features['delta_q50'], features['delta_q75'],
            features['delta2_mean'], features['delta2_std'],
            features['delta2_q25'], features['delta2_q50'], features['delta2_q75']
        ])
        
        # Save to cache
        if use_cache:
            save_to_cache(cache_key, feature_vector)
        
        return feature_vector, True
        
    except Exception as e:
        print(f"Error processing {audio_path}: {str(e)}")
        return None, False

In [63]:
def save_checkpoint(ids, vectors, n_mfcc, checkpoint_path, existing_df=None):
    """Save progress checkpoint"""
    stats = ['mean', 'std', 'q25', 'q50', 'q75']
    feature_names = []
    for prefix in ['mfcc', 'delta', 'delta2']:
        for stat in stats:
            for i in range(n_mfcc):
                feature_names.append(f"{prefix}_{stat}_{i}")
    
    new_df = pd.DataFrame(vectors, columns=feature_names)
    new_df.insert(0, 'track_id', ids)
    
    if existing_df is not None:
        checkpoint_df = pd.concat([existing_df, new_df], ignore_index=True)
    else:
        checkpoint_df = new_df
    
    checkpoint_df.to_csv(checkpoint_path, index=False)
    print(f"\nCheckpoint saved: {len(checkpoint_df)} songs processed")

In [64]:
def extract_features_for_dataset(df, n_mfcc=13, sr=22050, use_cache=True, 
                                checkpoint_interval=100, checkpoint_path=None):
    if checkpoint_path is None:
        checkpoint_path = PROCESSED_DIR / "mfcc_checkpoint.csv"
    
    # Try to load existing checkpoint
    if checkpoint_path.exists():
        print(f"Found checkpoint at {checkpoint_path}")
        existing_df = pd.read_csv(checkpoint_path)
        processed_ids = set(existing_df['track_id'].values)
        print(f"Resuming from checkpoint with {len(processed_ids)} already processed songs")
    else:
        existing_df = None
        processed_ids = set()
    
    successful_ids = []
    feature_vectors = []
    
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Extracting MFCC features"):
        track_id = row['track_id']
        audio_path_from_csv = row['audio_path']
        
        # Skip if already processed
        if track_id in processed_ids:
            continue
        
        # Convert to correct format for current environment
        audio_path = fix_audio_path(audio_path_from_csv)
        
        # Extract features
        features, success = extract_mfcc_features(audio_path, n_mfcc=n_mfcc, 
                                                  sr=sr, use_cache=use_cache)
        
        if success:
            successful_ids.append(track_id)
            feature_vectors.append(features)
            
            # Save checkpoint periodically
            if len(successful_ids) % checkpoint_interval == 0:
                save_checkpoint(successful_ids, feature_vectors, n_mfcc, 
                              checkpoint_path, existing_df)
    
    # Generate column names
    feature_names = []
    stats = ['mean', 'std', 'q25', 'q50', 'q75']
    for prefix in ['mfcc', 'delta', 'delta2']:
        for stat in stats:
            for i in range(n_mfcc):
                feature_names.append(f"{prefix}_{stat}_{i}")
    
    # Create new features DataFrame
    new_features_df = pd.DataFrame(feature_vectors, columns=feature_names)
    new_features_df.insert(0, 'track_id', successful_ids)
    
    # Combine with existing if checkpoint exists
    if existing_df is not None:
        features_df = pd.concat([existing_df, new_features_df], ignore_index=True)
    else:
        features_df = new_features_df
    
    print(f"\n✓ Successfully extracted features for {len(features_df)}/{len(df)} songs")
    print(f"✓ Feature vector dimension: {len(feature_names)} ({n_mfcc} MFCCs × 5 statistics × 3 types)")
    
    # Remove checkpoint file after successful completion
    if checkpoint_path.exists():
        checkpoint_path.unlink()
        print(f"✓ Removed checkpoint file")
    
    return features_df

In [65]:
# Main execution
if __name__ == "__main__":
    # Load audio files list
    audio_list_path = PROCESSED_DIR / "audio_files_list.csv"
    
    print(f"Loading data from: {audio_list_path}")
    print(f"Environment: {'WSL/Linux' if RUNNING_ON_WSL else 'Windows accessing WSL'}")
    
    id_music_list_df = pd.read_csv(audio_list_path)
    
    print(f"\nLoaded {len(id_music_list_df)} audio files")
    print(f"Sample path: {id_music_list_df['audio_path'].iloc[0]}")
    
    # Test path conversion
    test_path = fix_audio_path(id_music_list_df['audio_path'].iloc[0])
    print(f"Converted to: {test_path}")
    print(f"Path exists: {test_path.exists()}")
    
    if not test_path.exists():
        print("\nWARNING: Test path doesn't exist!")
        print("Please check:")
        print("  1. Is RUNNING_ON_WSL set correctly?")
        print("  2. Are the files accessible at the WSL location?")
        response = input("\nContinue anyway? (y/n): ")
        if response.lower() != 'y':
            print("Exiting...")
            exit()
    
    # Extract features
    mfcc_features_df = extract_features_for_dataset(
        id_music_list_df, 
        n_mfcc=13, 
        sr=22050,
        use_cache=True,
        checkpoint_interval=100
    )
    
    # Save results
    output_path = PROCESSED_DIR / "mfcc_features_with_quartiles.csv"
    mfcc_features_df.to_csv(output_path, index=False)
    print(f"\n✓ Saved features to {output_path}")
    
    print(f"\nFeature extraction summary:")
    print(f"  Total features per song: {len(mfcc_features_df.columns) - 1}")
    print(f"  Successfully processed: {len(mfcc_features_df)} songs")
    print(f"  Cache directory: {CACHE_DIR}")

Loading data from: \\wsl.localhost\Ubuntu-22.04\home\rime\music-recom\data\processed\audio_files_list.csv
Environment: Windows accessing WSL

Loaded 1590 audio files
Sample path: \home\rime\music-recom\data\audio\109497.mp3
Converted to: \\wsl.localhost\Ubuntu-22.04\home\rime\music-recom\data\processed\audio\109497.mp3
Path exists: True


Extracting MFCC features:   7%|▋         | 110/1590 [00:01<00:28, 52.15it/s]


Checkpoint saved: 100 songs processed


Extracting MFCC features:  13%|█▎        | 206/1590 [00:03<00:28, 48.75it/s]


Checkpoint saved: 200 songs processed


Extracting MFCC features:  19%|█▉        | 309/1590 [00:05<00:26, 48.16it/s]


Checkpoint saved: 300 songs processed


Extracting MFCC features:  26%|██▌       | 412/1590 [00:07<00:24, 47.37it/s]


Checkpoint saved: 400 songs processed


Extracting MFCC features:  32%|███▏      | 509/1590 [00:09<00:24, 44.25it/s]


Checkpoint saved: 500 songs processed


Extracting MFCC features:  38%|███▊      | 607/1590 [00:11<00:22, 43.27it/s]


Checkpoint saved: 600 songs processed


Extracting MFCC features:  45%|████▍     | 711/1590 [00:13<00:20, 41.97it/s]


Checkpoint saved: 700 songs processed


Extracting MFCC features:  51%|█████     | 807/1590 [00:15<00:20, 37.94it/s]


Checkpoint saved: 800 songs processed


Extracting MFCC features:  57%|█████▋    | 909/1590 [00:17<00:18, 36.02it/s]


Checkpoint saved: 900 songs processed


Extracting MFCC features:  64%|██████▎   | 1010/1590 [00:19<00:16, 35.17it/s]


Checkpoint saved: 1000 songs processed


Extracting MFCC features:  70%|██████▉   | 1107/1590 [00:21<00:14, 34.25it/s]


Checkpoint saved: 1100 songs processed


Extracting MFCC features:  76%|███████▌  | 1206/1590 [00:23<00:11, 33.79it/s]


Checkpoint saved: 1200 songs processed


Extracting MFCC features:  82%|████████▏ | 1311/1590 [00:25<00:08, 33.36it/s]


Checkpoint saved: 1300 songs processed


Extracting MFCC features:  88%|████████▊ | 1407/1590 [00:27<00:05, 30.86it/s]


Checkpoint saved: 1400 songs processed


Extracting MFCC features:  95%|█████████▍| 1510/1590 [00:29<00:02, 31.94it/s]


Checkpoint saved: 1500 songs processed


Extracting MFCC features: 100%|██████████| 1590/1590 [00:31<00:00, 50.70it/s]



✓ Successfully extracted features for 1590/1590 songs
✓ Feature vector dimension: 195 (13 MFCCs × 5 statistics × 3 types)
✓ Removed checkpoint file

✓ Saved features to \\wsl.localhost\Ubuntu-22.04\home\rime\music-recom\data\processed\mfcc_features_with_quartiles.csv

Feature extraction summary:
  Total features per song: 195
  Successfully processed: 1590 songs
  Cache directory: \\wsl.localhost\Ubuntu-22.04\home\rime\music-recom\data\processed\cache


### Grouping similar features and predicting them via a Random Forest multioutput regressor 

In [66]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
import numpy as np

In [67]:
# Rythm/ Movement targets
rythm_features = ['tempo', 'danceability', 'energy']

# Audio characteristics
audio_features = ['acousticness', 'liveness', 'loudness']

# Vocal/ content features
vocal_features = ['speechiness', 'valence', 'intrementalness']

In [68]:
# Get the current working directory 
root = Path.cwd().parent

In [69]:
pd.set_option('display.max_columns', None)

In [70]:
# Build the full path
MFCC_Q_df = pd.read_csv(root / 'data/processed/mfcc_features_with_quartiles.csv')

MFCC_Q_df.head()

Unnamed: 0,track_id,mfcc_mean_0,mfcc_mean_1,mfcc_mean_2,mfcc_mean_3,mfcc_mean_4,mfcc_mean_5,mfcc_mean_6,mfcc_mean_7,mfcc_mean_8,mfcc_mean_9,mfcc_mean_10,mfcc_mean_11,mfcc_mean_12,mfcc_std_0,mfcc_std_1,mfcc_std_2,mfcc_std_3,mfcc_std_4,mfcc_std_5,mfcc_std_6,mfcc_std_7,mfcc_std_8,mfcc_std_9,mfcc_std_10,mfcc_std_11,mfcc_std_12,mfcc_q25_0,mfcc_q25_1,mfcc_q25_2,mfcc_q25_3,mfcc_q25_4,mfcc_q25_5,mfcc_q25_6,mfcc_q25_7,mfcc_q25_8,mfcc_q25_9,mfcc_q25_10,mfcc_q25_11,mfcc_q25_12,mfcc_q50_0,mfcc_q50_1,mfcc_q50_2,mfcc_q50_3,mfcc_q50_4,mfcc_q50_5,mfcc_q50_6,mfcc_q50_7,mfcc_q50_8,mfcc_q50_9,mfcc_q50_10,mfcc_q50_11,mfcc_q50_12,mfcc_q75_0,mfcc_q75_1,mfcc_q75_2,mfcc_q75_3,mfcc_q75_4,mfcc_q75_5,mfcc_q75_6,mfcc_q75_7,mfcc_q75_8,mfcc_q75_9,mfcc_q75_10,mfcc_q75_11,mfcc_q75_12,delta_mean_0,delta_mean_1,delta_mean_2,delta_mean_3,delta_mean_4,delta_mean_5,delta_mean_6,delta_mean_7,delta_mean_8,delta_mean_9,delta_mean_10,delta_mean_11,delta_mean_12,delta_std_0,delta_std_1,delta_std_2,delta_std_3,delta_std_4,delta_std_5,delta_std_6,delta_std_7,delta_std_8,delta_std_9,delta_std_10,delta_std_11,delta_std_12,delta_q25_0,delta_q25_1,delta_q25_2,delta_q25_3,delta_q25_4,delta_q25_5,delta_q25_6,delta_q25_7,delta_q25_8,delta_q25_9,delta_q25_10,delta_q25_11,delta_q25_12,delta_q50_0,delta_q50_1,delta_q50_2,delta_q50_3,delta_q50_4,delta_q50_5,delta_q50_6,delta_q50_7,delta_q50_8,delta_q50_9,delta_q50_10,delta_q50_11,delta_q50_12,delta_q75_0,delta_q75_1,delta_q75_2,delta_q75_3,delta_q75_4,delta_q75_5,delta_q75_6,delta_q75_7,delta_q75_8,delta_q75_9,delta_q75_10,delta_q75_11,delta_q75_12,delta2_mean_0,delta2_mean_1,delta2_mean_2,delta2_mean_3,delta2_mean_4,delta2_mean_5,delta2_mean_6,delta2_mean_7,delta2_mean_8,delta2_mean_9,delta2_mean_10,delta2_mean_11,delta2_mean_12,delta2_std_0,delta2_std_1,delta2_std_2,delta2_std_3,delta2_std_4,delta2_std_5,delta2_std_6,delta2_std_7,delta2_std_8,delta2_std_9,delta2_std_10,delta2_std_11,delta2_std_12,delta2_q25_0,delta2_q25_1,delta2_q25_2,delta2_q25_3,delta2_q25_4,delta2_q25_5,delta2_q25_6,delta2_q25_7,delta2_q25_8,delta2_q25_9,delta2_q25_10,delta2_q25_11,delta2_q25_12,delta2_q50_0,delta2_q50_1,delta2_q50_2,delta2_q50_3,delta2_q50_4,delta2_q50_5,delta2_q50_6,delta2_q50_7,delta2_q50_8,delta2_q50_9,delta2_q50_10,delta2_q50_11,delta2_q50_12,delta2_q75_0,delta2_q75_1,delta2_q75_2,delta2_q75_3,delta2_q75_4,delta2_q75_5,delta2_q75_6,delta2_q75_7,delta2_q75_8,delta2_q75_9,delta2_q75_10,delta2_q75_11,delta2_q75_12
0,109497,-14.46628,99.482414,-13.127948,43.313545,9.995256,5.39969,-5.793285,13.883069,-1.838129,2.187864,-2.674747,3.841628,-9.301169,26.336906,13.115074,12.082673,10.713166,7.633686,5.559512,5.984026,5.907621,5.963255,5.159916,5.725162,4.790164,5.414805,-25.974308,89.79773,-21.092014,38.467827,6.111401,1.58603,-9.500146,9.957172,-5.797652,-1.314428,-6.898641,0.665738,-12.874005,-12.028065,98.77183,-13.556653,44.643456,11.083105,5.215743,-5.622651,14.334005,-1.961504,2.256998,-2.872006,3.641615,-9.346977,0.913519,107.98017,-5.87316,50.60029,15.318432,9.029509,-1.505398,18.285854,1.977354,5.757357,1.137958,7.034313,-5.877173,0.220803,0.009053,-0.002882,-0.003425,0.00075,0.0046,-0.003279,0.011742,0.006375,0.00165,0.000165,-0.000724,0.005746,4.41648,2.52771,1.879202,1.614143,1.372204,0.91686,1.072609,1.153425,1.141011,0.906527,1.142041,0.910949,1.01229,-2.169072,-1.664055,-1.188237,-1.053944,-0.914239,-0.635814,-0.725844,-0.817876,-0.767006,-0.613866,-0.804201,-0.617282,-0.637065,0.089583,0.031641,0.0656,0.032003,0.03831,-0.001774,-0.00971,0.056556,0.040189,-0.014482,-0.020248,0.00035,0.008368,2.173461,1.705773,1.284936,1.063335,0.875187,0.629731,0.694629,0.863577,0.787185,0.644893,0.81583,0.614444,0.648924,-0.135776,-0.035974,0.004075,-0.013851,-0.004475,-0.001072,-0.0025,-0.009171,-0.006136,-0.009151,-0.006148,-0.006779,-0.001375,2.662576,1.463119,1.183564,1.022441,0.859411,0.726815,0.79601,0.860587,0.880833,0.73046,0.814752,0.704768,0.741921,-1.265242,-0.991705,-0.796757,-0.670792,-0.626229,-0.487701,-0.56262,-0.634363,-0.638328,-0.543769,-0.581081,-0.50053,-0.513199,0.06991,-0.016047,-0.014487,0.00524,-0.036464,-0.029755,-0.022141,-0.052846,-0.019029,0.005924,-0.022269,0.003925,-0.008469,1.341433,0.936322,0.792232,0.692057,0.569096,0.473818,0.533445,0.627216,0.612832,0.523362,0.575562,0.491335,0.518873
1,53666,-52.18589,99.23931,-19.81856,34.8412,-4.830251,7.452331,0.821418,-0.551589,0.541429,8.759988,-2.642872,6.46883,-2.554433,17.970594,16.446632,10.71229,8.822542,7.811877,5.619895,6.609158,6.775237,6.076496,6.250626,5.90914,5.741867,5.448841,-63.08838,89.49299,-26.571468,28.994284,-10.797971,3.577901,-3.682878,-5.356423,-3.77431,4.250537,-6.775855,2.378149,-6.145599,-52.409813,99.797485,-19.99194,34.918488,-4.951907,7.278353,0.506644,-0.81201,0.210771,8.576931,-3.146664,6.51309,-2.431867,-39.256344,109.91645,-12.887173,40.430374,1.22944,11.149651,5.539262,4.122548,4.650206,13.182243,1.037986,10.359025,1.06506,0.012402,-0.007382,-0.005155,0.002076,-0.001866,-0.001652,-0.000474,-0.000553,-0.000199,0.011042,0.004957,-0.008891,-0.007915,3.199941,2.723585,2.00516,1.703189,1.259007,1.013505,1.200884,1.042044,1.084845,1.036201,1.091861,0.978752,1.000591,-1.999134,-1.362825,-1.205278,-1.154972,-0.826864,-0.658937,-0.860859,-0.717076,-0.754294,-0.687359,-0.685047,-0.675195,-0.684837,-0.191827,0.232023,-0.003556,0.06164,0.060407,0.014538,0.011139,-0.037082,-0.008744,0.009474,-0.014274,-0.050996,-0.054915,2.035141,1.856711,1.09918,1.128089,0.83752,0.652004,0.874717,0.679773,0.709825,0.683374,0.693426,0.599407,0.676133,-0.011197,-0.015572,-0.007065,-0.002806,-0.001591,-0.006339,-0.008179,-0.007428,-0.002442,-0.002701,-0.006067,-0.005588,-0.005737,2.597663,1.668799,1.340824,1.059314,0.872083,0.693278,0.744236,0.724987,0.736523,0.742568,0.732206,0.654933,0.686628,-1.837787,-1.172514,-0.897175,-0.681935,-0.573201,-0.437319,-0.512194,-0.506406,-0.50773,-0.475746,-0.441582,-0.4135,-0.418753,-0.039706,-0.089775,-0.027551,-0.03317,0.044339,-0.008325,-0.015951,-0.010395,0.014646,0.01691,-0.01641,-0.018389,-0.000896,1.814744,1.032142,0.8526,0.704708,0.576565,0.439989,0.485595,0.474887,0.520407,0.478401,0.471515,0.439307,0.418716
2,55400,-37.400524,124.763824,-36.81307,21.995646,2.225886,21.032818,3.546124,10.978546,-5.663468,1.670306,-1.866043,4.414843,0.263866,25.074936,15.589418,11.059312,9.996412,8.603876,6.485661,7.064106,4.8184,4.59468,4.897986,4.508664,5.303911,5.510609,-55.823593,115.9258,-44.989243,14.944798,-4.011454,16.510422,-1.536884,7.46964,-8.968188,-1.748559,-4.928318,0.673367,-3.620507,-40.126724,128.31274,-36.82352,22.862446,2.857833,21.027534,3.393919,11.081544,-5.586058,1.522695,-1.931237,4.392392,0.225774,-20.72766,135.4742,-29.081856,29.360538,8.64887,25.388668,8.517001,14.297654,-2.515616,5.062921,1.076639,8.111399,3.940616,0.058065,-0.01056,0.011036,0.009491,-0.004898,0.002371,0.01321,0.016682,0.005274,0.004574,0.003347,0.005777,0.008513,4.374987,2.896857,1.517322,1.221072,1.538684,0.871887,1.157159,0.796285,0.784595,0.881469,0.744571,0.868556,0.73955,-3.177753,-1.378299,-1.040217,-0.672707,-0.8364,-0.634032,-0.674083,-0.525178,-0.499175,-0.505695,-0.505284,-0.499896,-0.459783,-0.753589,0.231625,0.094282,0.086951,0.137383,-0.045202,0.156396,0.002043,0.023508,0.025493,0.010294,0.098622,0.010783,2.283091,1.80902,1.129577,0.835514,1.050787,0.593606,0.833639,0.567711,0.542888,0.622507,0.522446,0.593174,0.487758,-0.057128,0.014514,-0.004043,0.006608,0.01231,-0.003013,0.000503,-0.001101,0.002705,-0.000753,-0.00103,0.000794,-0.001403,3.734957,1.728411,1.449256,1.055884,0.866639,0.662977,0.788613,0.562866,0.613871,0.658512,0.639172,0.61351,0.579289,-2.413319,-1.008883,-1.082545,-0.767986,-0.61327,-0.481714,-0.560757,-0.381659,-0.449736,-0.454691,-0.460656,-0.429691,-0.402365,-0.364135,-0.153551,-0.015322,-0.052667,0.023132,0.010028,-0.007917,-0.005446,0.029381,0.013197,0.028722,0.032815,-0.009185,2.70037,0.727889,1.150816,0.740516,0.621775,0.454694,0.579547,0.385677,0.433737,0.492632,0.459254,0.428915,0.39296
3,10589,-380.00287,181.90225,-65.369286,10.870298,16.465267,-47.996254,-3.810154,-8.47399,-42.44927,-8.006559,-6.592747,-21.92328,-0.598965,65.70725,25.217339,45.143143,18.352695,13.137667,17.60754,8.177697,8.135131,13.13912,8.508199,9.304217,8.381541,8.219736,-441.36676,159.42926,-104.954346,-4.825752,7.888224,-61.16475,-9.917351,-14.74465,-53.827713,-13.605358,-13.424345,-27.304317,-5.747596,-371.26947,180.5116,-66.08597,10.847923,16.253841,-47.161537,-4.374521,-8.86621,-41.02721,-8.333242,-8.221781,-21.264357,-0.674093,-321.4391,205.4904,-22.16711,25.337595,24.259344,-32.145584,2.749909,-2.479402,-31.048077,-1.929297,-0.330194,-16.018957,4.672289,0.189026,0.083782,-0.052729,-0.005151,-0.000277,-0.026846,-0.010993,-0.0137,-0.027458,0.001743,0.007929,-0.008412,-0.008311,7.055398,3.073463,4.952223,2.348252,1.871211,1.966379,1.128496,1.124765,1.474231,1.269974,1.269533,1.053716,1.284437,-3.956972,-1.703191,-1.546891,-1.382452,-0.804074,-0.806837,-0.617009,-0.698046,-0.628123,-0.674188,-0.673521,-0.584642,-0.764081,-0.703907,-0.201381,0.534755,0.111816,0.090634,0.188336,-0.009244,-0.033566,0.066667,0.006995,-0.0324,0.008554,-0.040826,2.512165,1.473111,2.843905,1.570254,0.985183,1.065468,0.607956,0.704702,0.786409,0.661649,0.627425,0.634066,0.647253,-0.054324,-0.050451,0.013551,-0.010189,-0.012286,0.013427,-0.00487,-0.000656,0.008195,-2.5e-05,0.00531,0.010322,-0.001069,3.407578,1.517492,2.385228,1.016233,0.909341,0.9996,0.677607,0.646453,0.7891,0.651975,0.664161,0.649972,0.697728,-1.321799,-0.650242,-0.895908,-0.536443,-0.451264,-0.472133,-0.382774,-0.351289,-0.395578,-0.361905,-0.381472,-0.367546,-0.367758,0.045249,0.015976,0.011183,0.015646,-0.002959,-0.01119,-0.010822,0.003489,0.001338,-0.001073,0.031543,-0.014531,-0.012989,1.264395,0.62027,1.032508,0.585428,0.417241,0.496032,0.366111,0.358032,0.410059,0.387136,0.406306,0.371518,0.374446
4,55923,-18.598597,71.36076,-16.790304,51.837322,-2.433789,12.741719,6.08001,5.622876,0.253179,3.234265,-1.363661,1.0434,-7.887192,21.66691,18.164318,15.200961,13.062137,8.76108,9.015322,8.081576,7.085149,7.131298,6.38119,5.982719,5.748267,5.225229,-29.887802,57.067654,-27.065834,41.674706,-8.032942,6.919233,0.672681,0.647905,-4.756876,-1.157414,-5.327281,-2.891277,-11.493351,-17.792692,70.70049,-16.452028,50.660637,-2.545805,11.25412,6.214553,5.592544,0.633819,2.894148,-1.651002,0.56816,-8.043936,-4.971285,85.113754,-4.907349,60.607765,3.752512,17.062355,11.718766,10.113415,5.475327,7.40704,2.175229,4.356773,-4.360371,0.066447,0.025715,-0.015879,0.012799,-0.001647,-0.011966,-0.006075,0.003678,0.003172,0.002557,-0.000214,-0.003642,-0.020094,4.150834,2.444548,2.64983,2.237288,1.532949,1.640232,1.431198,1.290325,1.357411,1.095927,1.101842,1.017437,0.991913,-2.922354,-1.461451,-1.728062,-1.366675,-1.017409,-1.050978,-1.019049,-0.838868,-0.928503,-0.767651,-0.706379,-0.588422,-0.68716,-0.421276,0.161188,-0.014692,0.047534,-0.00967,-0.053324,0.00571,0.005706,-0.004065,0.024265,-0.035329,-0.025648,-0.061232,2.535571,1.687318,1.642538,1.517204,0.973415,0.993121,0.990486,0.841111,0.893974,0.747224,0.756076,0.592153,0.584655,-0.040584,0.003627,0.011505,9.5e-05,0.00042,-0.011483,-0.005313,0.010726,0.000119,-0.002157,-5.5e-05,0.005269,-0.004375,3.154008,1.603595,1.603042,1.200993,0.932761,0.907508,0.857162,0.766474,0.761395,0.690974,0.697007,0.658661,0.626155,-1.841837,-1.103803,-1.001154,-0.805113,-0.591796,-0.563174,-0.56959,-0.527515,-0.496164,-0.486195,-0.415531,-0.449461,-0.398625,0.071399,-0.047183,0.057941,-0.081383,-0.018571,0.005258,-0.029624,0.022076,-0.035216,0.006162,-0.017363,-0.007889,-0.006645,1.994204,1.118296,1.091109,0.765444,0.579457,0.59363,0.547581,0.539624,0.505126,0.500238,0.486141,0.450131,0.416221


In [71]:
matched_df = pd.read_csv(root / 'data/processed/matched_metadata.csv')

continuous_targets = [ 'danceability', 'energy',  
           'loudness', 'speechiness', 'acousticness', 
           'instrumentalness', 'liveness', 'valence', 'tempo']

matched_df_target = matched_df[['track_id'] + continuous_targets]
matched_df_target.head()

Unnamed: 0,track_id,danceability,energy,loudness,speechiness,acousticness,instrumentalness,liveness,valence,tempo
0,10,0.606,0.916,-8.162,0.0371,0.14,0.356,0.132,0.889,111.563
1,237,0.28,0.64,-7.799,0.123,0.349,0.675,0.136,0.0537,140.368
2,238,0.192,0.411,-9.445,0.0655,0.539,0.709,0.0909,0.139,56.929
3,459,0.584,0.918,-9.883,0.0345,0.0254,0.77,0.348,0.114,108.305
4,459,0.415,0.646,-12.022,0.0399,0.0189,0.948,0.0965,0.123,93.887


In [72]:
c_targets_df = matched_df_target[["track_id"]+ continuous_targets]
merged_df = MFCC_Q_df.merge(c_targets_df, on = 'track_id')

merged_df.head()

Unnamed: 0,track_id,mfcc_mean_0,mfcc_mean_1,mfcc_mean_2,mfcc_mean_3,mfcc_mean_4,mfcc_mean_5,mfcc_mean_6,mfcc_mean_7,mfcc_mean_8,mfcc_mean_9,mfcc_mean_10,mfcc_mean_11,mfcc_mean_12,mfcc_std_0,mfcc_std_1,mfcc_std_2,mfcc_std_3,mfcc_std_4,mfcc_std_5,mfcc_std_6,mfcc_std_7,mfcc_std_8,mfcc_std_9,mfcc_std_10,mfcc_std_11,mfcc_std_12,mfcc_q25_0,mfcc_q25_1,mfcc_q25_2,mfcc_q25_3,mfcc_q25_4,mfcc_q25_5,mfcc_q25_6,mfcc_q25_7,mfcc_q25_8,mfcc_q25_9,mfcc_q25_10,mfcc_q25_11,mfcc_q25_12,mfcc_q50_0,mfcc_q50_1,mfcc_q50_2,mfcc_q50_3,mfcc_q50_4,mfcc_q50_5,mfcc_q50_6,mfcc_q50_7,mfcc_q50_8,mfcc_q50_9,mfcc_q50_10,mfcc_q50_11,mfcc_q50_12,mfcc_q75_0,mfcc_q75_1,mfcc_q75_2,mfcc_q75_3,mfcc_q75_4,mfcc_q75_5,mfcc_q75_6,mfcc_q75_7,mfcc_q75_8,mfcc_q75_9,mfcc_q75_10,mfcc_q75_11,mfcc_q75_12,delta_mean_0,delta_mean_1,delta_mean_2,delta_mean_3,delta_mean_4,delta_mean_5,delta_mean_6,delta_mean_7,delta_mean_8,delta_mean_9,delta_mean_10,delta_mean_11,delta_mean_12,delta_std_0,delta_std_1,delta_std_2,delta_std_3,delta_std_4,delta_std_5,delta_std_6,delta_std_7,delta_std_8,delta_std_9,delta_std_10,delta_std_11,delta_std_12,delta_q25_0,delta_q25_1,delta_q25_2,delta_q25_3,delta_q25_4,delta_q25_5,delta_q25_6,delta_q25_7,delta_q25_8,delta_q25_9,delta_q25_10,delta_q25_11,delta_q25_12,delta_q50_0,delta_q50_1,delta_q50_2,delta_q50_3,delta_q50_4,delta_q50_5,delta_q50_6,delta_q50_7,delta_q50_8,delta_q50_9,delta_q50_10,delta_q50_11,delta_q50_12,delta_q75_0,delta_q75_1,delta_q75_2,delta_q75_3,delta_q75_4,delta_q75_5,delta_q75_6,delta_q75_7,delta_q75_8,delta_q75_9,delta_q75_10,delta_q75_11,delta_q75_12,delta2_mean_0,delta2_mean_1,delta2_mean_2,delta2_mean_3,delta2_mean_4,delta2_mean_5,delta2_mean_6,delta2_mean_7,delta2_mean_8,delta2_mean_9,delta2_mean_10,delta2_mean_11,delta2_mean_12,delta2_std_0,delta2_std_1,delta2_std_2,delta2_std_3,delta2_std_4,delta2_std_5,delta2_std_6,delta2_std_7,delta2_std_8,delta2_std_9,delta2_std_10,delta2_std_11,delta2_std_12,delta2_q25_0,delta2_q25_1,delta2_q25_2,delta2_q25_3,delta2_q25_4,delta2_q25_5,delta2_q25_6,delta2_q25_7,delta2_q25_8,delta2_q25_9,delta2_q25_10,delta2_q25_11,delta2_q25_12,delta2_q50_0,delta2_q50_1,delta2_q50_2,delta2_q50_3,delta2_q50_4,delta2_q50_5,delta2_q50_6,delta2_q50_7,delta2_q50_8,delta2_q50_9,delta2_q50_10,delta2_q50_11,delta2_q50_12,delta2_q75_0,delta2_q75_1,delta2_q75_2,delta2_q75_3,delta2_q75_4,delta2_q75_5,delta2_q75_6,delta2_q75_7,delta2_q75_8,delta2_q75_9,delta2_q75_10,delta2_q75_11,delta2_q75_12,danceability,energy,loudness,speechiness,acousticness,instrumentalness,liveness,valence,tempo
0,109497,-14.46628,99.482414,-13.127948,43.313545,9.995256,5.39969,-5.793285,13.883069,-1.838129,2.187864,-2.674747,3.841628,-9.301169,26.336906,13.115074,12.082673,10.713166,7.633686,5.559512,5.984026,5.907621,5.963255,5.159916,5.725162,4.790164,5.414805,-25.974308,89.79773,-21.092014,38.467827,6.111401,1.58603,-9.500146,9.957172,-5.797652,-1.314428,-6.898641,0.665738,-12.874005,-12.028065,98.77183,-13.556653,44.643456,11.083105,5.215743,-5.622651,14.334005,-1.961504,2.256998,-2.872006,3.641615,-9.346977,0.913519,107.98017,-5.87316,50.60029,15.318432,9.029509,-1.505398,18.285854,1.977354,5.757357,1.137958,7.034313,-5.877173,0.220803,0.009053,-0.002882,-0.003425,0.00075,0.0046,-0.003279,0.011742,0.006375,0.00165,0.000165,-0.000724,0.005746,4.41648,2.52771,1.879202,1.614143,1.372204,0.91686,1.072609,1.153425,1.141011,0.906527,1.142041,0.910949,1.01229,-2.169072,-1.664055,-1.188237,-1.053944,-0.914239,-0.635814,-0.725844,-0.817876,-0.767006,-0.613866,-0.804201,-0.617282,-0.637065,0.089583,0.031641,0.0656,0.032003,0.03831,-0.001774,-0.00971,0.056556,0.040189,-0.014482,-0.020248,0.00035,0.008368,2.173461,1.705773,1.284936,1.063335,0.875187,0.629731,0.694629,0.863577,0.787185,0.644893,0.81583,0.614444,0.648924,-0.135776,-0.035974,0.004075,-0.013851,-0.004475,-0.001072,-0.0025,-0.009171,-0.006136,-0.009151,-0.006148,-0.006779,-0.001375,2.662576,1.463119,1.183564,1.022441,0.859411,0.726815,0.79601,0.860587,0.880833,0.73046,0.814752,0.704768,0.741921,-1.265242,-0.991705,-0.796757,-0.670792,-0.626229,-0.487701,-0.56262,-0.634363,-0.638328,-0.543769,-0.581081,-0.50053,-0.513199,0.06991,-0.016047,-0.014487,0.00524,-0.036464,-0.029755,-0.022141,-0.052846,-0.019029,0.005924,-0.022269,0.003925,-0.008469,1.341433,0.936322,0.792232,0.692057,0.569096,0.473818,0.533445,0.627216,0.612832,0.523362,0.575562,0.491335,0.518873,0.224,0.897,-8.232,0.0874,0.00014,0.363,0.113,0.189,105.241
1,53666,-52.18589,99.23931,-19.81856,34.8412,-4.830251,7.452331,0.821418,-0.551589,0.541429,8.759988,-2.642872,6.46883,-2.554433,17.970594,16.446632,10.71229,8.822542,7.811877,5.619895,6.609158,6.775237,6.076496,6.250626,5.90914,5.741867,5.448841,-63.08838,89.49299,-26.571468,28.994284,-10.797971,3.577901,-3.682878,-5.356423,-3.77431,4.250537,-6.775855,2.378149,-6.145599,-52.409813,99.797485,-19.99194,34.918488,-4.951907,7.278353,0.506644,-0.81201,0.210771,8.576931,-3.146664,6.51309,-2.431867,-39.256344,109.91645,-12.887173,40.430374,1.22944,11.149651,5.539262,4.122548,4.650206,13.182243,1.037986,10.359025,1.06506,0.012402,-0.007382,-0.005155,0.002076,-0.001866,-0.001652,-0.000474,-0.000553,-0.000199,0.011042,0.004957,-0.008891,-0.007915,3.199941,2.723585,2.00516,1.703189,1.259007,1.013505,1.200884,1.042044,1.084845,1.036201,1.091861,0.978752,1.000591,-1.999134,-1.362825,-1.205278,-1.154972,-0.826864,-0.658937,-0.860859,-0.717076,-0.754294,-0.687359,-0.685047,-0.675195,-0.684837,-0.191827,0.232023,-0.003556,0.06164,0.060407,0.014538,0.011139,-0.037082,-0.008744,0.009474,-0.014274,-0.050996,-0.054915,2.035141,1.856711,1.09918,1.128089,0.83752,0.652004,0.874717,0.679773,0.709825,0.683374,0.693426,0.599407,0.676133,-0.011197,-0.015572,-0.007065,-0.002806,-0.001591,-0.006339,-0.008179,-0.007428,-0.002442,-0.002701,-0.006067,-0.005588,-0.005737,2.597663,1.668799,1.340824,1.059314,0.872083,0.693278,0.744236,0.724987,0.736523,0.742568,0.732206,0.654933,0.686628,-1.837787,-1.172514,-0.897175,-0.681935,-0.573201,-0.437319,-0.512194,-0.506406,-0.50773,-0.475746,-0.441582,-0.4135,-0.418753,-0.039706,-0.089775,-0.027551,-0.03317,0.044339,-0.008325,-0.015951,-0.010395,0.014646,0.01691,-0.01641,-0.018389,-0.000896,1.814744,1.032142,0.8526,0.704708,0.576565,0.439989,0.485595,0.474887,0.520407,0.478401,0.471515,0.439307,0.418716,0.459,0.874,-5.437,0.0447,1.9e-05,0.925,0.318,0.5,139.8
2,55400,-37.400524,124.763824,-36.81307,21.995646,2.225886,21.032818,3.546124,10.978546,-5.663468,1.670306,-1.866043,4.414843,0.263866,25.074936,15.589418,11.059312,9.996412,8.603876,6.485661,7.064106,4.8184,4.59468,4.897986,4.508664,5.303911,5.510609,-55.823593,115.9258,-44.989243,14.944798,-4.011454,16.510422,-1.536884,7.46964,-8.968188,-1.748559,-4.928318,0.673367,-3.620507,-40.126724,128.31274,-36.82352,22.862446,2.857833,21.027534,3.393919,11.081544,-5.586058,1.522695,-1.931237,4.392392,0.225774,-20.72766,135.4742,-29.081856,29.360538,8.64887,25.388668,8.517001,14.297654,-2.515616,5.062921,1.076639,8.111399,3.940616,0.058065,-0.01056,0.011036,0.009491,-0.004898,0.002371,0.01321,0.016682,0.005274,0.004574,0.003347,0.005777,0.008513,4.374987,2.896857,1.517322,1.221072,1.538684,0.871887,1.157159,0.796285,0.784595,0.881469,0.744571,0.868556,0.73955,-3.177753,-1.378299,-1.040217,-0.672707,-0.8364,-0.634032,-0.674083,-0.525178,-0.499175,-0.505695,-0.505284,-0.499896,-0.459783,-0.753589,0.231625,0.094282,0.086951,0.137383,-0.045202,0.156396,0.002043,0.023508,0.025493,0.010294,0.098622,0.010783,2.283091,1.80902,1.129577,0.835514,1.050787,0.593606,0.833639,0.567711,0.542888,0.622507,0.522446,0.593174,0.487758,-0.057128,0.014514,-0.004043,0.006608,0.01231,-0.003013,0.000503,-0.001101,0.002705,-0.000753,-0.00103,0.000794,-0.001403,3.734957,1.728411,1.449256,1.055884,0.866639,0.662977,0.788613,0.562866,0.613871,0.658512,0.639172,0.61351,0.579289,-2.413319,-1.008883,-1.082545,-0.767986,-0.61327,-0.481714,-0.560757,-0.381659,-0.449736,-0.454691,-0.460656,-0.429691,-0.402365,-0.364135,-0.153551,-0.015322,-0.052667,0.023132,0.010028,-0.007917,-0.005446,0.029381,0.013197,0.028722,0.032815,-0.009185,2.70037,0.727889,1.150816,0.740516,0.621775,0.454694,0.579547,0.385677,0.433737,0.492632,0.459254,0.428915,0.39296,0.36,0.979,-4.001,0.0893,6.5e-05,0.00159,0.807,0.106,62.513
3,10589,-380.00287,181.90225,-65.369286,10.870298,16.465267,-47.996254,-3.810154,-8.47399,-42.44927,-8.006559,-6.592747,-21.92328,-0.598965,65.70725,25.217339,45.143143,18.352695,13.137667,17.60754,8.177697,8.135131,13.13912,8.508199,9.304217,8.381541,8.219736,-441.36676,159.42926,-104.954346,-4.825752,7.888224,-61.16475,-9.917351,-14.74465,-53.827713,-13.605358,-13.424345,-27.304317,-5.747596,-371.26947,180.5116,-66.08597,10.847923,16.253841,-47.161537,-4.374521,-8.86621,-41.02721,-8.333242,-8.221781,-21.264357,-0.674093,-321.4391,205.4904,-22.16711,25.337595,24.259344,-32.145584,2.749909,-2.479402,-31.048077,-1.929297,-0.330194,-16.018957,4.672289,0.189026,0.083782,-0.052729,-0.005151,-0.000277,-0.026846,-0.010993,-0.0137,-0.027458,0.001743,0.007929,-0.008412,-0.008311,7.055398,3.073463,4.952223,2.348252,1.871211,1.966379,1.128496,1.124765,1.474231,1.269974,1.269533,1.053716,1.284437,-3.956972,-1.703191,-1.546891,-1.382452,-0.804074,-0.806837,-0.617009,-0.698046,-0.628123,-0.674188,-0.673521,-0.584642,-0.764081,-0.703907,-0.201381,0.534755,0.111816,0.090634,0.188336,-0.009244,-0.033566,0.066667,0.006995,-0.0324,0.008554,-0.040826,2.512165,1.473111,2.843905,1.570254,0.985183,1.065468,0.607956,0.704702,0.786409,0.661649,0.627425,0.634066,0.647253,-0.054324,-0.050451,0.013551,-0.010189,-0.012286,0.013427,-0.00487,-0.000656,0.008195,-2.5e-05,0.00531,0.010322,-0.001069,3.407578,1.517492,2.385228,1.016233,0.909341,0.9996,0.677607,0.646453,0.7891,0.651975,0.664161,0.649972,0.697728,-1.321799,-0.650242,-0.895908,-0.536443,-0.451264,-0.472133,-0.382774,-0.351289,-0.395578,-0.361905,-0.381472,-0.367546,-0.367758,0.045249,0.015976,0.011183,0.015646,-0.002959,-0.01119,-0.010822,0.003489,0.001338,-0.001073,0.031543,-0.014531,-0.012989,1.264395,0.62027,1.032508,0.585428,0.417241,0.496032,0.366111,0.358032,0.410059,0.387136,0.406306,0.371518,0.374446,0.413,0.338,-7.432,0.0317,0.767,8e-06,0.113,0.247,94.139
4,55923,-18.598597,71.36076,-16.790304,51.837322,-2.433789,12.741719,6.08001,5.622876,0.253179,3.234265,-1.363661,1.0434,-7.887192,21.66691,18.164318,15.200961,13.062137,8.76108,9.015322,8.081576,7.085149,7.131298,6.38119,5.982719,5.748267,5.225229,-29.887802,57.067654,-27.065834,41.674706,-8.032942,6.919233,0.672681,0.647905,-4.756876,-1.157414,-5.327281,-2.891277,-11.493351,-17.792692,70.70049,-16.452028,50.660637,-2.545805,11.25412,6.214553,5.592544,0.633819,2.894148,-1.651002,0.56816,-8.043936,-4.971285,85.113754,-4.907349,60.607765,3.752512,17.062355,11.718766,10.113415,5.475327,7.40704,2.175229,4.356773,-4.360371,0.066447,0.025715,-0.015879,0.012799,-0.001647,-0.011966,-0.006075,0.003678,0.003172,0.002557,-0.000214,-0.003642,-0.020094,4.150834,2.444548,2.64983,2.237288,1.532949,1.640232,1.431198,1.290325,1.357411,1.095927,1.101842,1.017437,0.991913,-2.922354,-1.461451,-1.728062,-1.366675,-1.017409,-1.050978,-1.019049,-0.838868,-0.928503,-0.767651,-0.706379,-0.588422,-0.68716,-0.421276,0.161188,-0.014692,0.047534,-0.00967,-0.053324,0.00571,0.005706,-0.004065,0.024265,-0.035329,-0.025648,-0.061232,2.535571,1.687318,1.642538,1.517204,0.973415,0.993121,0.990486,0.841111,0.893974,0.747224,0.756076,0.592153,0.584655,-0.040584,0.003627,0.011505,9.5e-05,0.00042,-0.011483,-0.005313,0.010726,0.000119,-0.002157,-5.5e-05,0.005269,-0.004375,3.154008,1.603595,1.603042,1.200993,0.932761,0.907508,0.857162,0.766474,0.761395,0.690974,0.697007,0.658661,0.626155,-1.841837,-1.103803,-1.001154,-0.805113,-0.591796,-0.563174,-0.56959,-0.527515,-0.496164,-0.486195,-0.415531,-0.449461,-0.398625,0.071399,-0.047183,0.057941,-0.081383,-0.018571,0.005258,-0.029624,0.022076,-0.035216,0.006162,-0.017363,-0.007889,-0.006645,1.994204,1.118296,1.091109,0.765444,0.579457,0.59363,0.547581,0.539624,0.505126,0.500238,0.486141,0.450131,0.416221,0.304,0.943,-5.427,0.09,0.121,0.0,0.34,0.566,79.396


In [73]:
# Multi-output Random Forest (built-in support)
rf_multi = RandomForestRegressor(n_estimators=200, max_depth=15, random_state=42)

In [74]:
# features
X = merged_df.drop(['track_id'] + continuous_targets, axis=1)

# target groups
target_groups = {
    'rhythm': ['tempo', 'danceability', 'energy'],
    'audio': ['acousticness', 'liveness', 'loudness'],
    'content': ['speechiness', 'instrumentalness', 'valence']
}

# Store results
results = {}

# Train and evaluate each group
for group_name, target_cols in target_groups.items():
    print(f"Training {group_name.upper()} Features")
    
    # targets for this group
    y = merged_df[target_cols]
    
    # Split into train/test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # Train model
    rf_multi = RandomForestRegressor(n_estimators=200, max_depth=15, random_state=42)
    rf_multi.fit(X_train, y_train)
    
    # Predict
    predictions = rf_multi.predict(X_test)
    
    # Evaluate each target
    group_results = {}
    for i, col in enumerate(target_cols):
        r2 = r2_score(y_test.iloc[:, i], predictions[:, i])
        mae = mean_absolute_error(y_test.iloc[:, i], predictions[:, i])
        rmse = np.sqrt(mean_squared_error(y_test.iloc[:, i], predictions[:, i]))
        
        group_results[col] = {'r2': r2, 'mae': mae, 'rmse': rmse}
        print(f"{col:20s} R²: {r2:.3f} | MAE: {mae:.3f} | RMSE: {rmse:.3f}")
    
    # Store results and model
    results[group_name] = {
        'model': rf_multi,
        'metrics': group_results,
        'targets': target_cols
    }

# Print summary
print("SUMMARY - Average R² by Group")
for group_name, data in results.items():
    avg_r2 = np.mean([metrics['r2'] for metrics in data['metrics'].values()])
    print(f"{group_name.upper():15s} Average R²: {avg_r2:.3f}")

Training RHYTHM Features


KeyboardInterrupt: 

In [75]:
from sklearn.model_selection import GridSearchCV

# features
X = merged_df.drop(['track_id'] + continuous_targets, axis=1)

# target groups
target_groups = {
    'rhythm': ['tempo', 'danceability', 'energy'],
    'audio': ['acousticness', 'liveness', 'loudness'],
    'content': ['speechiness', 'instrumentalness', 'valence']
}

# Define parameter grid
param_grid = {
    'n_estimators': [100, 200, 300],
    'max_depth': [10, 15, 20, None],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2', None]
}

# Store results
results = {}

# Train and evaluate each group
for group_name, target_cols in target_groups.items():
    print(f"\n{'='*60}")
    print(f"Training {group_name.upper()} Features with GridSearch")
    print(f"{'='*60}")
    
    # targets for this group
    y = merged_df[target_cols]
    
    # Split into train/test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    
    # GridSearch with cross-validation
    rf_base = RandomForestRegressor(random_state=42)
    grid_search = GridSearchCV(
        estimator=rf_base,
        param_grid=param_grid,
        cv=3,  # 3-fold cross-validation
        scoring='r2',
        n_jobs=-1,  # Use all CPU cores
        verbose=1
    )
    
    print(f"Running GridSearch (this may take a few minutes)...")
    grid_search.fit(X_train, y_train)
    
    # Best parameters
    print(f"\nBest parameters for {group_name}:")
    print(grid_search.best_params_)
    print(f"Best CV score: {grid_search.best_score_:.3f}")
    
    # Use best model for predictions
    best_model = grid_search.best_estimator_
    predictions = best_model.predict(X_test)
    
    # Evaluate each target
    print(f"\nTest Set Performance:")
    group_results = {}
    for i, col in enumerate(target_cols):
        r2 = r2_score(y_test.iloc[:, i], predictions[:, i])
        mae = mean_absolute_error(y_test.iloc[:, i], predictions[:, i])
        rmse = np.sqrt(mean_squared_error(y_test.iloc[:, i], predictions[:, i]))
        
        group_results[col] = {'r2': r2, 'mae': mae, 'rmse': rmse}
        print(f"{col:20s} R²: {r2:.3f} | MAE: {mae:.3f} | RMSE: {rmse:.3f}")
    
    # Store results and model
    results[group_name] = {
        'model': best_model,
        'best_params': grid_search.best_params_,
        'cv_score': grid_search.best_score_,
        'metrics': group_results,
        'targets': target_cols
    }



Training RHYTHM Features with GridSearch
Running GridSearch (this may take a few minutes)...
Fitting 3 folds for each of 324 candidates, totalling 972 fits

Best parameters for rhythm:
{'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 300}
Best CV score: 0.296

Test Set Performance:
tempo                R²: 0.092 | MAE: 21.361 | RMSE: 27.541
danceability         R²: 0.415 | MAE: 0.115 | RMSE: 0.138
energy               R²: 0.535 | MAE: 0.151 | RMSE: 0.181

Training AUDIO Features with GridSearch
Running GridSearch (this may take a few minutes)...
Fitting 3 folds for each of 324 candidates, totalling 972 fits

Best parameters for audio:
{'max_depth': 20, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 300}
Best CV score: 0.306

Test Set Performance:
acousticness         R²: 0.462 | MAE: 0.214 | RMSE: 0.253
liveness             R²: 0.039 | MAE: 0.112 | RMSE: 0.148
loudness             R²: 0.5

In [76]:
print("SUMMARY - Average R² by Group")
for group_name, data in results.items():
    avg_r2 = np.mean([metrics['r2'] for metrics in data['metrics'].values()])
    print(f"{group_name.upper():15s} Average R²: {avg_r2:.3f} | CV Score: {data['cv_score']:.3f}")
    print(f"  Best params: {data['best_params']}")

SUMMARY - Average R² by Group
RHYTHM          Average R²: 0.347 | CV Score: 0.296
  Best params: {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 2, 'n_estimators': 300}
AUDIO           Average R²: 0.337 | CV Score: 0.306
  Best params: {'max_depth': 20, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 300}
CONTENT         Average R²: 0.246 | CV Score: 0.201
  Best params: {'max_depth': 20, 'max_features': 'sqrt', 'min_samples_leaf': 2, 'min_samples_split': 2, 'n_estimators': 300}


In [None]:
import joblib
import numpy as np
import os

# Save to parent folder's models directory
models_dir = os.path.join('..', 'models')
os.makedirs(models_dir, exist_ok=True)

# Save your trained models in the parent's models folder
print("Saving models...")
for group_name, data in results.items():
    model_path = os.path.join(models_dir, f'model_{group_name}.pkl')
    joblib.dump(data['model'], model_path)
    print(f"✓ Saved {group_name} model to {model_path}")


# Create a unified predictor class
class UnifiedMusicPredictor:
    def __init__(self, model_paths):
        """Load all group models"""
        self.models = {}
        self.feature_order = ['tempo', 'danceability', 'energy', 
                             'acousticness', 'liveness', 'loudness',
                             'speechiness', 'instrumentalness', 'valence']
        
        for group_name, path in model_paths.items():
            self.models[group_name] = joblib.load(path)
    
    def predict(self, X):
        # Get predictions from each model
        rhythm_pred = self.models['rhythm'].predict(X)  # [tempo, danceability, energy]
        audio_pred = self.models['audio'].predict(X)    # [acousticness, liveness, loudness]
        content_pred = self.models['content'].predict(X) # [speechiness, instrumentalness, valence]
        
        # Combine into single array
        predictions = np.hstack([rhythm_pred, audio_pred, content_pred])
        
        return predictions
    
    def predict_dict(self, X):
        """Return predictions as dictionary for clarity"""
        predictions = self.predict(X)
        
        result = {}
        for i, feature in enumerate(self.feature_order):
            result[feature] = predictions[:, i]
        
        return result

# Initialize the unified predictor with paths to parent's models folder
model_paths = {
    'rhythm': '../models/model_rhythm.pkl',
    'audio': '../models/model_audio.pkl',
    'content': '../models/model_content.pkl'
}

predictor = UnifiedMusicPredictor(model_paths)

# Test it
print("\nTesting Unified Predictor:")
sample_predictions = predictor.predict(X_test)
print(f"Prediction shape: {sample_predictions.shape}")  # Should be (n_samples, 9)
print(f"First sample predictions:\n{sample_predictions[0]}")

# Or get as dictionary
sample_dict = predictor.predict_dict(X_test[:5])
print("\nFirst 5 samples as dictionary:")
for feature, values in sample_dict.items():
    print(f"{feature:20s}: {values[:5]}")

Saving models...
✓ Saved rhythm model to ..\models\model_rhythm.pkl
✓ Saved audio model to ..\models\model_audio.pkl
✓ Saved content model to ..\models\model_content.pkl

Testing Unified Predictor:
Prediction shape: (346, 9)
First sample predictions:
[ 1.18940290e+02  3.61553255e-01  5.16975502e-01  5.47001657e-01
  2.10794069e-01 -1.36840793e+01  6.21974313e-02  5.47757570e-01
  2.12272608e-01]

First 5 samples as dictionary:
tempo               : [118.94029001 129.83866083 113.28319958 120.42865767 126.26880278]
danceability        : [0.36155326 0.42031939 0.44782903 0.37151744 0.4521525 ]
energy              : [0.5169755  0.72352667 0.3895705  0.58921391 0.558611  ]
acousticness        : [0.54700166 0.22107675 0.70529508 0.36972663 0.54253529]
liveness            : [0.21079407 0.2279997  0.16561447 0.2046611  0.18861882]
loudness            : [-13.68407932  -6.31750117 -12.8438341  -10.58172022 -11.22293446]
speechiness         : [0.06219743 0.05797902 0.04918148 0.06400963 0.053766

Cache for faster predictions

In [None]:
from functools import lru_cache
import hashlib
import pickle

class CachedMusicPredictor(UnifiedMusicPredictor):
    def __init__(self, model_paths, cache_size=1000):
        super().__init__(model_paths)
        self.cache = {}
        self.cache_size = cache_size
    
    def _hash_input(self, X):
        """Create hash of input for cache key"""
        return hashlib.md5(pickle.dumps(X.values if hasattr(X, 'values') else X)).hexdigest()
    
    def predict(self, X, use_cache=True):
        """Predict with optional caching"""
        if not use_cache:
            return super().predict(X)
        
        # Check cache
        cache_key = self._hash_input(X)
        if cache_key in self.cache:
            print("✓ Using cached prediction")
            return self.cache[cache_key]
        
        # Compute prediction
        predictions = super().predict(X)
        
        # Store in cache (with size limit)
        if len(self.cache) >= self.cache_size:
            # Remove oldest entry
            self.cache.pop(next(iter(self.cache)))
        
        self.cache[cache_key] = predictions
        return predictions

# Use cached predictor
cached_predictor = CachedMusicPredictor(model_paths, cache_size=1000)

# First call - computes
predictions1 = cached_predictor.predict(X_test)

# Second call - uses cache
predictions2 = cached_predictor.predict(X_test)  

✓ Using cached prediction
