In [2]:
import os
import numpy as np
import librosa
import faiss
import logging
import torch
from transformers import ClapProcessor, ClapModel
import crepe
from scipy.signal import butter, filtfilt
from scipy.spatial.distance import cosine
from fastdtw import fastdtw
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# CLAP model setup.
# embedding_dim must equal the width of the vectors returned by
# model.get_audio_features(), because those vectors are added to this index.
# NOTE(review): the FAISS index is filled but never searched in this file;
# compare_query_to_database() scores every track with a brute-force loop.
embedding_dim = 512
index = faiss.IndexFlatL2(embedding_dim)
processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
model = ClapModel.from_pretrained("laion/clap-htsat-unfused")

# In-memory feature database. The lists are index-aligned: entry i of each
# list belongs to the track whose file name is database_tracks[i].
database_embeddings = []
database_spectral = []
database_chroma = []
database_onsets = []
database_tempos = []
database_melody_sequences = []
database_tracks = []

# Scale for onset-envelope DTW similarity (sigma = MAX_DISTANCE_ONSETS / 10).
# initialize_database() / load_features() overwrite it. This default only
# keeps save_features() and compare_query_to_database() from raising
# NameError when they run before either of those.
MAX_DISTANCE_ONSETS = 1000

# Paths and constants
# NOTE(review): audio is read from an ".../ambient" folder but the cache lives
# under ".../lofi", and main() queries a lofi track. Confirm these are meant
# to differ; an existing lofi cache means the ambient folder is never read.
AUDIO_FOLDER_PATH = "/home/ubuntu/mahesh_YUE/Finetuned_songs/ambient"  # <-- Your input directory
EMBEDDINGS_DIR = "/home/ubuntu/mahesh_YUE/music_similarity_embeddings/genrewise1/lofi"
ALLOWED_FORMATS = ('.mp3', '.wav', '.aiff')

# Low-pass filter used to band-limit audio before melody (pitch) extraction
def lowpass_filter(signal, sr, cutoff=2500, order=4):
    """Apply a zero-phase Butterworth low-pass filter to ``signal``.

    Args:
        signal: 1-D sample array.
        sr: sample rate of ``signal`` in Hz.
        cutoff: cutoff frequency in Hz.
        order: Butterworth filter order (default 4, as before).

    Returns:
        The filtered signal from ``scipy.signal.filtfilt``. When ``cutoff`` is
        at or above the Nyquist frequency (sr / 2), there is nothing to remove,
        so the input is returned unchanged as a float array.
    """
    nyquist = 0.5 * sr
    if cutoff >= nyquist:
        # butter() rejects normalised cutoffs >= 1; a low-pass at or above
        # Nyquist is the identity anyway.
        return np.asarray(signal, dtype=float)
    b, a = butter(order, cutoff / nyquist, btype='low', analog=False)
    # filtfilt runs the filter forward and backward, so no phase shift.
    return filtfilt(b, a, signal)

# Feature extraction functions
def extract_clap_embedding(audio, sr=48000):
    """Embed one waveform with the module-level CLAP processor and model.

    Args:
        audio: waveform handed to ``processor`` as a one-item batch.
        sr: value passed as ``sampling_rate`` to the processor. NOTE(review):
            callers in this file pass audio loaded at its native rate without
            an sr, so it is labelled 48000 Hz regardless; confirm or resample.

    Returns:
        The squeezed audio-feature tensor as a float32 NumPy array, or None if
        any step raises (the error is logged).
    """
    try:
        batch = processor(
            audios=[audio],
            sampling_rate=sr,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            features = model.get_audio_features(**batch)
        vector = features.squeeze().numpy()
        return np.array(vector, dtype=np.float32)
    except Exception as err:
        logger.error(f"Error extracting CLAP embedding: {err}")
        return None

def extract_melody_sequence(audio, sr=16000, confidence_threshold=0.7):
    """Estimate a pitch contour with CREPE, keeping only confident frames.

    The audio is downmixed to mono if multi-dimensional and low-passed with
    lowpass_filter(). It is then run through CREPE's 'tiny' model with
    step_size=20.

    Args:
        audio: waveform, assumed to be sampled at ``sr``. NOTE(review): callers
            in this file pass native-rate audio without an sr, so it is treated
            as 16000 Hz; confirm or resample.
        sr: sample rate of ``audio`` in Hz.
        confidence_threshold: frames whose CREPE confidence is not strictly
            above this value are dropped (default 0.7, as before).

    Returns:
        float32 array of the retained CREPE frequency estimates, or None if no
        frame passes the threshold or any step raises (the error is logged).
    """
    try:
        mono = librosa.to_mono(audio) if audio.ndim > 1 else audio
        filtered = lowpass_filter(mono, sr)
        # verbose=0 stops CREPE printing a Keras progress bar on every call.
        _, frequency, confidence, _ = crepe.predict(
            filtered, sr, model_capacity='tiny', step_size=20, verbose=0
        )
        confident = confidence > confidence_threshold
        if not np.any(confident):
            return None
        return frequency[confident].astype(np.float32)
    except Exception as e:
        logger.error(f"Error extracting melody: {e}")
        return None

def extract_spectral_features(audio, sr):
    """Summarise timbre as the per-coefficient mean of 13 MFCCs over time.

    Args:
        audio: waveform passed to librosa as ``y``.
        sr: sample rate of ``audio``.

    Returns:
        float32 array of mean MFCC values (13 entries for 1-D input), or None
        if extraction raises (the error is logged).
    """
    try:
        mfcc_frames = librosa.feature.mfcc(y=audio, sr=sr, n_mfcc=13)
        per_coefficient_mean = np.mean(mfcc_frames, axis=1)
        return per_coefficient_mean.astype(np.float32)
    except Exception as exc:
        logger.error(f"Error extracting spectral features: {exc}")
        return None

def extract_chroma_features(audio, sr):
    """Summarise harmony as the time-averaged STFT chromagram.

    Args:
        audio: waveform passed to librosa as ``y``.
        sr: sample rate of ``audio``.

    Returns:
        float32 array holding the mean of each chroma row, or None if
        extraction raises (the error is logged).
    """
    try:
        chromagram = librosa.feature.chroma_stft(y=audio, sr=sr)
        row_means = np.mean(chromagram, axis=1)
        return row_means.astype(np.float32)
    except Exception as exc:
        logger.error(f"Error extracting chroma features: {exc}")
        return None

def extract_onset_features(audio, sr):
    """Compute a truncated onset-strength envelope and a scalar tempo.

    Args:
        audio: waveform passed to librosa as ``y``.
        sr: sample rate of ``audio``.

    Returns:
        ``(onset_prefix, tempo)``. ``onset_prefix`` holds the first
        ``len(beats)`` values of librosa's onset-strength envelope, so it is
        empty when no beats are detected. ``tempo`` is librosa's tempo
        estimate as a Python float. Returns ``(None, None)`` if any step
        raises (the error is logged).
    """
    try:
        onset_env = librosa.onset.onset_strength(y=audio, sr=sr)
        tempo, beats = librosa.beat.beat_track(onset_envelope=onset_env, sr=sr)
        # Depending on the librosa version, tempo may be a scalar, a 0-d array
        # or a 1-element array. atleast_1d handles all of them; the old
        # np.isscalar check sent 0-d arrays to tempo[0], which raises.
        tempo_scalar = float(np.atleast_1d(tempo)[0])
        # The stored feature is the envelope prefix whose length equals the
        # number of detected beats. Kept as-is so cached features stay comparable.
        return onset_env[:len(beats)], tempo_scalar
    except Exception as e:
        logger.error(f"Error extracting onset features: {e}")
        return None, None

def process_audio_file(audio_path):
    """Load up to 30 s of one file and extract every similarity feature.

    Args:
        audio_path: path of the audio file to analyse.

    Returns:
        ``(audio_path, (clap, spectral, chroma, onsets, tempo, melody))``.
        A feature whose extractor failed is None. All six are None if loading
        the file raises (the error is logged).

    NOTE(review): ``y`` keeps the file's native sample rate (sr=None), but
    extract_clap_embedding() and extract_melody_sequence() are called without
    an sr, so they assume 48000 Hz and 16000 Hz. Unless the files are already
    at those rates, CLAP and CREPE receive mislabelled audio. Changing this
    alters every cached feature, so it must be done together with
    compare_query_to_database() and followed by a cache rebuild.
    """
    try:
        y, sr = librosa.load(audio_path, sr=None, duration=30.0)
        # Onset strength plus beat tracking is expensive; run it once and
        # use both of its outputs.
        onsets, tempo = extract_onset_features(y, sr)
        return audio_path, (
            extract_clap_embedding(y),
            extract_spectral_features(y, sr),
            extract_chroma_features(y, sr),
            onsets,
            tempo,
            extract_melody_sequence(y),
        )
    except Exception as e:
        logger.error(f"Error processing {audio_path}: {e}")
        return audio_path, (None, None, None, None, None, None)

# Distance and similarity functions
def scalar_distance(x, y):
    """Return |x - y|, the point-wise cost handed to fastdtw."""
    gap = x - y
    return abs(gap)

def dtw_similarity(seq1, seq2, sigma):
    """Map the DTW distance between two sequences to a similarity in (0, 1].

    Both sequences are downsampled by taking every second element. They are
    aligned with fastdtw using scalar_distance, and the result is
    ``exp(-distance / sigma)``.

    Args:
        seq1, seq2: 1-D sequences, or None.
        sigma: distance scale; larger values make the similarity decay slower.

    Returns:
        The similarity. Returns 0.0 when either sequence is None or shorter
        than 2 elements, or when the DTW computation raises (the error is
        logged).
    """
    unusable = (
        seq1 is None
        or seq2 is None
        or len(seq1) < 2
        or len(seq2) < 2
    )
    if unusable:
        return 0.0
    try:
        distance, _path = fastdtw(seq1[::2], seq2[::2], dist=scalar_distance)
        return np.exp(-distance / sigma)
    except Exception as err:
        logger.error(f"Error in DTW computation: {err}")
        return 0.0

# Save features
def _to_object_array(items):
    """Pack a list of 1-D sequences into a 1-D NumPy object array.

    ``np.array(items, dtype=object)`` builds a 2-D array instead when every
    item has the same length, which changes what is read back on load.
    Assigning element by element always gives shape ``(len(items),)``.
    """
    packed = np.empty(len(items), dtype=object)
    for position, item in enumerate(items):
        packed[position] = item
    return packed


def save_features():
    """Write the in-memory feature database to EMBEDDINGS_DIR as .npy files.

    Writes one file per database list, plus ``max_distance_onsets.npy``
    holding MAX_DISTANCE_ONSETS. The variable-length onset and melody
    sequences are stored as 1-D object arrays, so loading them needs
    ``allow_pickle=True``.

    Returns:
        True on success. False if there are no tracks, the directory is not
        writable, or any step raises (the error is logged).
    """
    if not database_tracks:
        logger.warning("No valid tracks to save")
        return False

    try:
        os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
        if not os.access(EMBEDDINGS_DIR, os.W_OK):
            # Report once and stop, rather than raising into this function's
            # own handler below.
            logger.error(f"No write permission for {EMBEDDINGS_DIR}")
            return False

        def _target(name):
            return os.path.join(EMBEDDINGS_DIR, name)

        np.save(_target("embeddings.npy"), np.array(database_embeddings, dtype=np.float32))
        np.save(_target("spectral.npy"), np.array(database_spectral, dtype=np.float32))
        np.save(_target("chroma.npy"), np.array(database_chroma, dtype=np.float32))
        np.save(_target("onsets.npy"), _to_object_array(database_onsets))
        np.save(_target("tempos.npy"), np.array(database_tempos, dtype=np.float32))
        np.save(_target("melody_sequences.npy"), _to_object_array(database_melody_sequences))
        np.save(_target("tracks.npy"), np.array(database_tracks, dtype=str))
        np.save(_target("max_distance_onsets.npy"), np.array([MAX_DISTANCE_ONSETS], dtype=np.float32))
        logger.info(f"Successfully saved features to {EMBEDDINGS_DIR} with {len(database_tracks)} tracks")
        return True
    except Exception as e:
        logger.error(f"Failed to save features: {e}")
        return False

# Load features
def load_features():
    """Populate the in-memory database and FAISS index from cached .npy files.

    Reads the files written by save_features() from EMBEDDINGS_DIR. Everything
    is first loaded into locals and validated: the lists must be non-empty and
    equal in length, and the embeddings must match ``index.d``. Only then are
    the module-level lists, MAX_DISTANCE_ONSETS and ``index`` replaced. A
    failure therefore never leaves a half-replaced database, which matters
    because initialize_database() appends to these lists when this returns
    False.

    Returns:
        True if the cache was loaded. False if a file is missing, a load
        raises, or the cached arrays are empty or inconsistent.
    """
    global database_embeddings, database_spectral, database_chroma, database_onsets, database_tempos, database_melody_sequences, database_tracks, MAX_DISTANCE_ONSETS
    required_files = ["embeddings.npy", "spectral.npy", "chroma.npy", "onsets.npy", "tempos.npy", "melody_sequences.npy", "tracks.npy", "max_distance_onsets.npy"]

    missing_files = [f for f in required_files if not os.path.exists(os.path.join(EMBEDDINGS_DIR, f))]
    if missing_files:
        logger.warning(f"Feature files missing in {EMBEDDINGS_DIR}: {missing_files}")
        return False

    def _load(name, **kwargs):
        return np.load(os.path.join(EMBEDDINGS_DIR, name), **kwargs)

    try:
        embeddings = _load("embeddings.npy").tolist()
        spectral = _load("spectral.npy").tolist()
        chroma = _load("chroma.npy").tolist()
        # allow_pickle is needed for the ragged object arrays. It unpickles
        # file contents, so only point EMBEDDINGS_DIR at caches this script wrote.
        onsets = _load("onsets.npy", allow_pickle=True).tolist()
        tempos = _load("tempos.npy").tolist()
        melodies = _load("melody_sequences.npy", allow_pickle=True).tolist()
        tracks = _load("tracks.npy").tolist()
        max_distance_onsets = float(_load("max_distance_onsets.npy")[0])
        embedding_matrix = np.asarray(embeddings, dtype=np.float32)
    except Exception as e:
        logger.error(f"Error loading features: {e}")
        return False

    lengths = {len(embeddings), len(spectral), len(chroma), len(onsets), len(tempos), len(melodies), len(tracks)}
    if (not tracks or len(lengths) != 1
            or embedding_matrix.ndim != 2 or embedding_matrix.shape[1] != index.d):
        logger.error(f"Error loading features: cached arrays in {EMBEDDINGS_DIR} are empty or inconsistent")
        return False

    # Commit only after everything loaded and validated.
    database_embeddings = embeddings
    database_spectral = spectral
    database_chroma = chroma
    database_onsets = onsets
    database_tempos = tempos
    database_melody_sequences = melodies
    database_tracks = tracks
    MAX_DISTANCE_ONSETS = max_distance_onsets

    index.reset()
    index.add(embedding_matrix)
    logger.info(f"Successfully loaded features from {EMBEDDINGS_DIR} with {len(database_tracks)} tracks")
    logger.info(f"Loaded MAX_DISTANCE_ONSETS: {MAX_DISTANCE_ONSETS}")
    return True

# Database initialization
def _list_audio_files(folder):
    """Return sorted full paths of files directly in *folder* whose name ends
    with an ALLOWED_FORMATS extension (case-insensitive)."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(ALLOWED_FORMATS)
    )


def _max_onset_dtw_distance(onset_sequences, default=1000):
    """Return the largest fastdtw distance over all pairs of onset sequences.

    Sequences are downsampled by 2, as in dtw_similarity(). Pairs where either
    side is None or has fewer than 2 values are skipped, matching
    dtw_similarity(), which never scores them. This also keeps empty envelopes
    (tracks with no detected beats) away from fastdtw. Returns *default* when
    no pair qualifies or the maximum is not positive, so the derived sigma
    (result / 10) is never zero.
    """
    usable = [seq for seq in onset_sequences if seq is not None and len(seq) >= 2]
    largest = None
    for i in range(len(usable)):
        for j in range(i + 1, len(usable)):
            distance, _ = fastdtw(usable[i][::2], usable[j][::2], dist=scalar_distance)
            if largest is None or distance > largest:
                largest = distance
    if largest is None or largest <= 0:
        return default
    return largest


def initialize_database():
    """Load the cached feature database, or build it from AUDIO_FOLDER_PATH.

    On a cache miss, this function:
      1. extracts features for every allowed audio file with a thread pool;
      2. keeps only tracks whose six features are all non-None;
      3. sets MAX_DISTANCE_ONSETS from the pairwise onset DTW distances;
      4. fills the FAISS index and writes the cache with save_features().
    """
    global MAX_DISTANCE_ONSETS
    # Check if features are already saved
    if load_features():
        return

    # Start from empty stores, so entries left over from a failed cache load
    # cannot mix with (and misalign) the freshly extracted ones.
    for store in (database_embeddings, database_spectral, database_chroma, database_onsets,
                  database_tempos, database_melody_sequences, database_tracks):
        store.clear()
    index.reset()

    audio_files = _list_audio_files(AUDIO_FOLDER_PATH)
    # NOTE(review): worker threads share the single CLAP and CREPE models;
    # confirm both are safe to call concurrently.
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(process_audio_file, audio_files), total=len(audio_files), desc="Processing audio files"))

    for audio_path, features in results:
        if any(value is None for value in features):
            continue  # at least one extractor failed for this track
        clap, spectral, chroma, onsets, tempo, melody = features
        database_embeddings.append(clap)
        database_spectral.append(spectral)
        database_chroma.append(chroma)
        database_onsets.append(onsets)
        database_tempos.append(tempo)
        database_melody_sequences.append(melody)
        database_tracks.append(os.path.basename(audio_path))

    MAX_DISTANCE_ONSETS = _max_onset_dtw_distance(database_onsets)
    logger.info(f"Dynamic MAX_DISTANCE_ONSETS: {MAX_DISTANCE_ONSETS}")

    if not database_embeddings:
        logger.warning("Database is empty!")
        return

    index.add(np.array(database_embeddings, dtype=np.float32))
    logger.info(f"Database initialized with {len(database_embeddings)} tracks")
    save_features()

# Query comparison
def _clipped_cosine(query_vec, db_vec):
    """Cosine similarity floored at 0. Returns 0.0 when *query_vec* is None or
    either vector has zero norm."""
    if query_vec is None:
        return 0.0
    denom = np.linalg.norm(query_vec) * np.linalg.norm(db_vec)
    if denom == 0:
        return 0.0
    return max(0.0, float(np.dot(query_vec, db_vec) / denom))


def _minmax_normalize(values):
    """Min-max scale *values* to [0, 1]. If all values are equal, each v maps
    to v / (v + 1e-8) instead (≈1 for positive values, 0 for zeros)."""
    arr = np.asarray(values, dtype=float)
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        return (arr - lo) / (hi - lo + 1e-8)
    return arr / (arr + 1e-8)


def compare_query_to_database(query_audio_path, top_k=20):
    """Rank database tracks by a variance-weighted multi-feature similarity.

    Six per-track similarities are computed against the query:
      - H: CLAP cosine;
      - S: MFCC cosine;
      - C: chroma cosine;
      - B: onset DTW, with sigma = MAX_DISTANCE_ONSETS / 10;
      - M: melody DTW, with sigma = 50;
      - T: tempo, exp(-|Δtempo| / 10).
    Each similarity is min-max normalised across the database. The weight of
    each feature is its variance divided by the total variance, and the score
    is the weighted sum.

    Args:
        query_audio_path: audio file to compare; its first 30 s are used.
        top_k: number of results to return (default 20).

    Returns:
        Up to *top_k* dicts ``{'track': name, 'similarity_score': float}``,
        best first. Returns [] when the database is empty.

    NOTE(review): as in process_audio_file(), the query is loaded at its
    native rate but the CLAP and melody extractors are called with their
    default sr (48000 / 16000 Hz). Keep both paths in sync if this changes.
    """
    if not database_embeddings:
        logger.warning("Database is empty; nothing to compare against.")
        return []

    y, sr = librosa.load(query_audio_path, sr=None, duration=30.0)
    query_clap = extract_clap_embedding(y)
    query_spectral = extract_spectral_features(y, sr)
    query_chroma = extract_chroma_features(y, sr)
    query_onsets, query_tempo = extract_onset_features(y, sr)
    query_melody_seq = extract_melody_sequence(y)

    onset_sigma = MAX_DISTANCE_ONSETS / 10  # loop-invariant
    # Key order H, M, B, C, S, T is kept so the logged dicts read as before.
    raw = {'H': [], 'M': [], 'B': [], 'C': [], 'S': [], 'T': []}
    rows = zip(database_embeddings, database_spectral, database_chroma,
               database_onsets, database_melody_sequences, database_tempos)
    for emb, spec, chroma, onsets, melody, tempo in rows:
        raw['H'].append(_clipped_cosine(query_clap, emb))
        raw['M'].append(dtw_similarity(query_melody_seq, melody, sigma=50))
        raw['B'].append(dtw_similarity(query_onsets, onsets, sigma=onset_sigma))
        raw['C'].append(_clipped_cosine(query_chroma, chroma))
        raw['S'].append(_clipped_cosine(query_spectral, spec))
        raw['T'].append(np.exp(-abs(query_tempo - tempo) / 10) if query_tempo is not None else 0.0)

    normalized = {name: _minmax_normalize(values) for name, values in raw.items()}

    # Features that vary more across the database get more weight.
    variances = {name: np.var(values) for name, values in normalized.items()}
    total_weight = sum(variances.values()) + 1e-8
    adjusted_weights = {name: var / total_weight for name, var in variances.items()}
    logger.info(f"Dynamic weights: {adjusted_weights}")
    logger.info(f"Variances: {variances}")

    combined = sum(adjusted_weights[name] * normalized[name] for name in adjusted_weights)
    similarities = [
        {'track': track, 'similarity_score': float(score)}
        for track, score in zip(database_tracks, combined)
    ]
    # Stable sort: tied scores keep database order.
    similarities.sort(key=lambda item: item['similarity_score'], reverse=True)
    top = similarities[:top_k]
    logger.info(f"Top {top_k} tracks: {[t['track'] for t in top]}")
    return top

# Main execution
def main(query_audio="/home/ubuntu/mahesh_YUE/Finetuned_songs/lofi/aunt - Ocean Ride.mp3"):
    """Build or load the database, then print the tracks most similar to a query.

    Args:
        query_audio: path of the query file. The default is the originally
            hard-coded track; replace it with a dataset song.
    """
    initialize_database()
    if not database_embeddings:
        logger.warning("No valid audio data available.")
        return
    if not os.path.isfile(query_audio):
        logger.error(f"Query audio not found: {query_audio}")
        return

    top_similar = compare_query_to_database(query_audio)
    print("\nTop similar tracks:")
    query_filename = os.path.basename(query_audio)
    for t in top_similar:
        print(f"Track: {t['track']}, Score: {t['similarity_score']:.3f}")
        # Sanity check: a query taken from the dataset should rank itself highly.
        if t['track'] == query_filename:
            logger.info(f"Query audio {query_filename} found with score {t['similarity_score']:.3f}")

# Entry point: run the full pipeline when this code is executed directly.
# An `import` of this module skips it.
if __name__ == "__main__":
    main()

2025-08-21 11:53:48,193 - INFO - Successfully loaded features from /home/ubuntu/mahesh_YUE/music_similarity_embeddings/genrewise1/lofi with 177 tracks
2025-08-21 11:53:48,194 - INFO - Loaded MAX_DISTANCE_ONSETS: 71.14161682128906
I0000 00:00:1755777230.930685 3412980 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1030 MB memory:  -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:07:00.0, compute capability: 8.6
I0000 00:00:1755777232.113785 3413172 service.cc:148] XLA service 0x7c67ac003fd0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1755777232.113868 3413172 service.cc:156]   StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6
2025-08-21 11:53:52.139208: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1755777232.193696 3413172 cuda_dnn.cc:529] Loaded cu

[1m 60/130[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m0s[0m 3ms/step    

I0000 00:00:1755777232.867670 3413172 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 8ms/step


2025-08-21 11:53:55,017 - INFO - Dynamic weights: {'H': 0.11507052389472819, 'M': 0.029641170207862784, 'B': 0.2067701882614766, 'C': 0.08402225247675282, 'S': 0.1200334440474605, 'T': 0.44446236834870795}
2025-08-21 11:53:55,018 - INFO - Variances: {'H': 0.021808937957094796, 'M': 0.005617793507487299, 'B': 0.03918847376845869, 'C': 0.01592446118484047, 'S': 0.02274954389190373, 'T': 0.08423749095335847}
2025-08-21 11:53:55,019 - INFO - Top 20 tracks: ['aunt - Ocean Ride.mp3', 'Casiio - Central Park.mp3', 'Cmd q - Morse.mp3', 'aroramo - machine.mp3', 'Vincent Rayn - Birds Eye.mp3', 'parrow - Pink.mp3', 'Justnormal - Fika.mp3', 'galaxx - Lazy Streets.mp3', 'Domo 759 - Meantime.mp3', 'Dryden - Summer Memories.mp3', 'Jost Esser - lemon soda.mp3', 'Erwin Do - Fragments.mp3', 'Tojié Cai - Niwa.mp3', 'Summer Clarke - devotion.mp3', 'MODALiST - Jouissance.mp3', 'Lownas - Spirit.mp3', 'Phresh Milk - no way but up.mp3', 'rocomoco - Aurola - Inf & Rainn Remix.mp3', 'morningtime - headway.mp3', 


Top similar tracks:
Track: aunt - Ocean Ride.mp3, Score: 1.000
Track: Casiio - Central Park.mp3, Score: 0.834
Track: Cmd q - Morse.mp3, Score: 0.825
Track: aroramo - machine.mp3, Score: 0.816
Track: Vincent Rayn - Birds Eye.mp3, Score: 0.803
Track: parrow - Pink.mp3, Score: 0.796
Track: Justnormal - Fika.mp3, Score: 0.780
Track: galaxx - Lazy Streets.mp3, Score: 0.769
Track: Domo 759 - Meantime.mp3, Score: 0.748
Track: Dryden - Summer Memories.mp3, Score: 0.746
Track: Jost Esser - lemon soda.mp3, Score: 0.731
Track: Erwin Do - Fragments.mp3, Score: 0.723
Track: Tojié Cai - Niwa.mp3, Score: 0.666
Track: Summer Clarke - devotion.mp3, Score: 0.666
Track: MODALiST - Jouissance.mp3, Score: 0.664
Track: Lownas - Spirit.mp3, Score: 0.653
Track: Phresh Milk - no way but up.mp3, Score: 0.649
Track: rocomoco - Aurola - Inf & Rainn Remix.mp3, Score: 0.620
Track: morningtime - headway.mp3, Score: 0.619
Track: C4C - Call a Friend.mp3, Score: 0.615
