In [None]:
import os
import torch
import torchaudio
import pandas as pd
import torch.nn.functional as F
import torchaudio.transforms as T
from transformers import ClapProcessor, ClapModel

In [None]:
# # CLAP-Compatible Text Descriptions for Each Tag
# text_descriptions = {
#     "Soothing": "A gentle, calming song that helps soothe and regulate emotional states.",
#     "Stimulating": "A vibrant and upbeat track designed to increase alertness and sensory engagement.",
#     "Grounding": "A slow, repetitive melody that provides structure and emotional grounding.",
#     "Playful": "An engaging and whimsical tune that encourages joy and playful interaction.",
#     "Focusing": "A structured and predictable melody that helps children sustain attention and reduce distractions.",
#     "Transitional": "A motivating and rhythmically clear piece used to support transitions between activities.",
#     "Interactive": "A musical piece designed to invite turn-taking, imitation, and social play.",
#     "Motivating": "A cheerful and energizing melody that inspires movement and participation.",
#     "Anxiety-Reducing": "A soft and flowing piece that helps reduce sensory overload and anxiety.",
#     "Task-Oriented": "A functional and directive tune used to support task initiation or completion.",
#     "Self-Expressive": "A flexible, emotionally rich song that allows for personal vocal or instrumental expression.",
#     "Sensory-Calming": "A low-frequency, mellow soundscape that promotes sensory regulation and relaxation.",
#     "Attention-Shifting": "A contrasting and dynamic melody designed to redirect or refocus attention in therapeutic settings.",
#     "Rhythmic Synchronizing": "A rhythmically steady piece that supports motor coordination and entrainment.",
#     "Confidence-Building": "A triumphant and empowering melody that encourages self-esteem and initiative."
# }

# text_descriptions = {
#     "Soothing": "A gentle melodic layer with soft pad harmonies and flowing water sounds, promoting calm and relaxation.",
#     "Stimulating": "An energetic musical track featuring rhythmic percussion, sharp bass, and dynamic tempo changes to engage movement.",
#     "Grounding": "A steady rhythmic base and repetitive chord structure that helps stabilize emotional and sensory responses.",
#     "Playful": "A whimsical tune with plucked strings, bell-like melodies, and light percussive textures that encourage joyful interaction.",
#     "Focusing": "A repetitive harp or flute melody with consistent timing and predictable structure, aiding concentration and focus.",
#     "Transitional": "A musically clear phrase with firm bass drum and synthesized rhythm, guiding activity shifts or transitions.",
#     "Interactive": "A multi-layered song with call-and-response melodies and spacious textures for turn-taking and social engagement.",
#     "Motivating": "An upbeat electric guitar and lead synth combination with medium tempo percussion to inspire active participation.",
#     "Anxiety-Reducing": "A warm, slow-paced track using soft pads, descending melodic lines, and ambient textures to lower anxiety.",
#     "Task-Oriented": "A track with consistent rhythmic pulses and directive harmonic cues that support initiation of structured activities.",
#     "Self-Expressive": "An emotionally flexible piece featuring layered melodies and rich harmonies to support vocal or instrumental expression.",
#     "Sensory-Calming": "A soft ambient mix of flowing water, sustained strings, and low-register harmony to reduce sensory overload.",
#     "Attention-Shifting": "A contrasting track with register changes and sudden timbral shifts used to redirect attention gently.",
#     "Rhythmic Synchronizing": "A repetitive percussive beat and walking bass line designed to promote motor coordination and entrainment.",
#     "Confidence-Building": "A strong, uplifting melody with plucked bass and vibrant harmony that supports self-esteem and initiation."
# }

In [None]:
# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports an available GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Functional Emotion Tags and their CLAP text prompts.
#
# `text_descriptions` is the single source of truth: its insertion order is the
# order in which the prompts are embedded (see get_clap_text_embeddings), and
# therefore the column order of the similarity matrix. The arousal/valence
# remarks are the author's quadrant labels for each pair of tags.
text_descriptions = {
    # High Arousal and High Valence
    "Stimulating": "A lively, energetic track with strong rhythm and upbeat tempo to boost movement and attention.",
    "Playful": "A whimsical, bouncy track with light percussion and bright tones to encourage joyful exploration.",
    # Low Arousal and High Valence
    "Soothing": "A calm, gentle track with soft melodies and slow tempo for relaxation and emotional regulation.",
    "Sensory-Calming": "A track with ambient textures and minimal melodic content designed to reduce sensory overload.",
    # Low Arousal and Low Valence
    "Grounding": "A steady, repetitive musical base that promotes sensory stability and emotional anchoring.",
    "Focusing": "A structured, predictable track with consistent rhythm to support sustained attention and task engagement.",
    # High Arousal and Low Valence
    "Transitional": "A cue-based musical track used to signal transitions between activities or routines.",
    "Anxiety-Reduction": "A soft, harmonically stable track specifically designed to ease agitation or anxious behaviors.",
}

# Derived from the prompt dictionary (rather than typed out a second time) so
# the column labels used for the CSV output can never drift out of order with
# respect to the text embeddings they describe.
emotional_tags = list(text_descriptions)

# Load the pretrained CLAP checkpoint. The processor prepares audio/text
# inputs; the model is moved to the selected device.
_clap_checkpoint = "laion/clap-htsat-unfused"
clap_processor = ClapProcessor.from_pretrained(_clap_checkpoint)
clap_model = ClapModel.from_pretrained(_clap_checkpoint)
clap_model = clap_model.to(device)

In [None]:
def load_audio(path, sr=48000):
    """Load an audio file, resample it to ``sr`` if needed, and downmix it.

    Args:
        path: Location of the audio file, passed straight to ``torchaudio.load``.
        sr: Target sampling rate in Hz. Resampling only happens when the
            file's own rate differs from this value.

    Returns:
        The waveform averaged over dimension 0, with that dimension kept as a
        singleton. Assuming ``torchaudio.load`` yields ``(channels, frames)``
        (TODO confirm for the backend in use), the result is a mono signal of
        shape ``(1, frames)``.
    """
    signal, native_rate = torchaudio.load(path)
    if native_rate != sr:
        resampler = T.Resample(orig_freq=native_rate, new_freq=sr)
        signal = resampler(signal)
    # Average across dimension 0 but keep it, so callers still get a 2-D tensor.
    return signal.mean(dim=0, keepdim=True)

def segment_audio(waveform, sr=48000, segment_sec=10, overlap_sec=5):
    """Split a waveform into fixed-length, overlapping windows along dimension 1.

    Windows start every ``segment_sec - overlap_sec`` seconds and only full
    windows are emitted, so samples after the last full window are not covered.
    A non-empty clip shorter than one window is returned whole as a single
    segment instead of producing an empty list, which previously made the
    clip fail downstream.

    Args:
        waveform: 2-D tensor indexed as ``[channel, sample]``.
        sr: Sampling rate in Hz used to convert seconds to samples.
        segment_sec: Window length in seconds (int or float).
        overlap_sec: Overlap between consecutive windows in seconds
            (int or float); must be smaller than ``segment_sec``.

    Returns:
        A list of tensors sliced from ``waveform``. Empty only when the
        waveform has no samples.

    Raises:
        ValueError: If the window length is not positive, or if the overlap
            is not smaller than the window length.
    """
    # Truncate to whole samples so float durations still give valid slice indices.
    segment_len = int(segment_sec * sr)
    step = segment_len - int(overlap_sec * sr)
    if segment_len <= 0:
        raise ValueError("segment_sec * sr must be a positive number of samples")
    if step <= 0:
        raise ValueError("overlap_sec must be smaller than segment_sec")

    total = waveform.size(1)
    if total == 0:
        return []
    if total < segment_len:
        # Too short for a full window: keep the clip as one (shorter) segment.
        return [waveform]

    last_start = total - segment_len
    return [
        waveform[:, start:start + segment_len]
        for start in range(0, last_start + 1, step)
    ]

def get_clap_audio_embedding(segments):
    """Embed each audio segment with CLAP and average the embeddings.

    Every segment has its leading dimension squeezed away, is run through
    ``clap_processor`` at a 48 kHz sampling rate, moved to ``device`` and
    embedded by ``clap_model.get_audio_features``. Gradients are disabled.

    Args:
        segments: Iterable of waveform tensors, one per segment.

    Returns:
        The mean of the per-segment embeddings, taken over the segment axis.

    Raises:
        ValueError: If ``segments`` yields nothing. Without this check the
            failure surfaced as an opaque ``torch.stack`` error.
    """
    with torch.no_grad():
        audio_embeds = []
        for seg in segments:
            inputs = clap_processor(audios=seg.squeeze(0), return_tensors="pt", sampling_rate=48000).to(device)
            audio_embeds.append(clap_model.get_audio_features(**inputs))

        if not audio_embeds:
            raise ValueError(
                "No audio segments to embed (the clip is empty or too short to segment)."
            )

        # Stack along a new leading axis, then average that axis away.
        return torch.mean(torch.stack(audio_embeds), dim=0)

def get_clap_text_embeddings():
    """Embed every prompt in ``text_descriptions`` with the CLAP text encoder.

    The prompts are taken in the dictionary's iteration order, tokenised as a
    single padded batch, moved to ``device`` and encoded without gradients.

    Returns:
        The output of ``clap_model.get_text_features`` for the batch, one
        entry per prompt in dictionary order.
    """
    prompts = [*text_descriptions.values()]
    batch = clap_processor(text=prompts, return_tensors="pt", padding=True)
    batch = batch.to(device)
    with torch.no_grad():
        text_features = clap_model.get_text_features(**batch)
    return text_features

def cosine_similarity_matrix(audio_embeds, text_embeds):
    """Compute all pairwise cosine similarities between audio and text embeddings.

    Args:
        audio_embeds: Tensor of audio embeddings, one row per audio item.
        text_embeds: Tensor of text embeddings, one row per prompt.

    Returns:
        A NumPy array whose entry ``[i, j]`` is the cosine similarity between
        audio row ``i`` and text row ``j``.
    """
    # Insert singleton axes so the two inputs broadcast against each other:
    # audio becomes (A, 1, D), text becomes (1, T, D); reducing dim 2 gives (A, T).
    audio_rows = audio_embeds[:, None]
    text_cols = text_embeds[None]
    scores = F.cosine_similarity(audio_rows, text_cols, dim=2)
    return scores.cpu().numpy()

def run_clap_audio_text_retrieval(input_folder, out_folder):
    """Score every ``.wav`` file under ``input_folder`` against the tag prompts.

    The folder is walked recursively in sorted order. Each file is loaded,
    segmented and embedded with CLAP; files that fail at any of those steps
    are reported and skipped. The cosine similarity between each file's
    embedding and each text prompt is then written to two CSV files in
    ``out_folder``:

    * ``audio_to_text_similarity_refined.csv`` - a ``file_name`` column
      followed by one column per tag.
    * ``text_to_audio_similarity_refined.csv`` - the same scores with the
      relative file path as the row index.

    Args:
        input_folder: Root directory searched for ``.wav`` files (extension
            matched case-insensitively).
        out_folder: Directory for the CSV output; created if missing.

    Raises:
        RuntimeError: If no audio file could be embedded, so there is
            nothing to compare against the text prompts.
    """
    os.makedirs(out_folder, exist_ok=True)

    audio_files = []
    per_file_embeds = []

    for root, dirs, files in os.walk(input_folder):
        # Sort in place so traversal (and therefore CSV row order) is
        # reproducible across filesystems.
        dirs.sort()
        for file in sorted(files):
            if not file.lower().endswith(".wav"):
                continue
            path = os.path.join(root, file)
            try:
                print(f"Processing: {path}")
                waveform = load_audio(path)
                segments = segment_audio(waveform)
                audio_embed = get_clap_audio_embedding(segments).cpu()
                rel_path = os.path.relpath(path, input_folder)
            except Exception as e:
                # Best-effort batch run: report the file and keep going.
                print(f"Failed: {file} — {e}")
                continue
            # Append both together so names and embeddings stay aligned.
            per_file_embeds.append(audio_embed)
            audio_files.append(rel_path)

    if not per_file_embeds:
        raise RuntimeError(
            f"No .wav file under {input_folder!r} could be embedded; nothing to save."
        )

    # Each per-file embedding carries a singleton axis 1 after stacking; drop it
    # so the matrix is (num_files, embedding_dim).
    audio_embeddings = torch.stack(per_file_embeds).squeeze(1).to(device)
    text_embeddings = get_clap_text_embeddings()

    sim_matrix = cosine_similarity_matrix(audio_embeddings, text_embeddings)

    df_audio_to_text = pd.DataFrame(sim_matrix, columns=emotional_tags)
    df_audio_to_text.insert(0, "file_name", audio_files)
    df_audio_to_text.to_csv(os.path.join(out_folder, "audio_to_text_similarity_refined.csv"), index=False)

    # NOTE(review): despite its name, this file has always been written with
    # audio files as rows and tags as columns (the original transposed twice).
    # That layout is preserved here; confirm whether tag-major was intended.
    df_text_to_audio = pd.DataFrame(sim_matrix, index=audio_files, columns=emotional_tags)
    df_text_to_audio.to_csv(os.path.join(out_folder, "text_to_audio_similarity_refined.csv"))

    print("Retrieval results saved.")

In [None]:
# Run
if __name__ == "__main__":
    # Replace both placeholders with real directories before running.
    audio_dir = "<<Path to audio files>>"
    csv_dir = "<<Path to save the csvs>>"
    run_clap_audio_text_retrieval(input_folder=audio_dir, out_folder=csv_dir)