In [None]:
# Install into the running kernel's environment (%pip, unlike !pip, targets the kernel's interpreter).
# TODO: pin versions (e.g. torch==X.Y.Z) so results are reproducible.
%pip install -q torch sentence-transformers transformers numpy pandas scikit-learn tqdm



In [None]:
import csv
import math
import os
import random

import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.autonotebook import tqdm, trange
from transformers import AutoTokenizer, AutoModelForSequenceClassification


In [None]:
def load_lyrics_dataset(file_path):
    """Read the lyrics CSV at ``file_path`` into a pandas DataFrame.

    Anything ``pd.read_csv`` accepts (a path or a file-like buffer) works;
    no column validation is performed here.
    """
    return pd.read_csv(file_path)

def retrieve_top_k_songs(query, song_embeddings, song_metadata, bi_encoder, k=10):
    """Retrieve the top-k songs for ``query`` by cosine similarity.

    Args:
        query: Free-text query string.
        song_embeddings: 2-D tensor of song embeddings whose rows line up with
            ``song_metadata`` (same order).
        song_metadata: Sequence of per-song records, indexed by embedding row.
        bi_encoder: Object exposing ``encode_texts(list_of_str)``; assumed to
            return a 2-D tensor (one row per text) — TODO confirm for custom encoders.
        k: Maximum number of songs to return. It is clamped to the number of
            embedding rows, so a ``k`` larger than the corpus returns every
            song instead of making ``torch.topk`` raise.

    Returns:
        List of up to ``k`` entries of ``song_metadata``, most similar first.
        An empty corpus (or ``k <= 0``) yields ``[]``.
    """
    k = min(k, int(song_embeddings.shape[0]))
    if k <= 0:
        return []

    query_embedding = bi_encoder.encode_texts([query])
    # L2-normalise both sides so the matrix product is cosine similarity
    # (the same computation as sentence_transformers.util.cos_sim).
    query_unit = torch.nn.functional.normalize(query_embedding, p=2, dim=1)
    songs_unit = torch.nn.functional.normalize(song_embeddings, p=2, dim=1)
    similarities = query_unit @ songs_unit.T  # shape (1, num_songs)

    top_k_indices = torch.topk(similarities, k=k).indices[0]
    return [song_metadata[i] for i in top_k_indices.tolist()]

def preprocess_lyrics_dynamic(lyrics, min_segment_size=3, max_segments=10):
    """Split lyrics into contiguous word segments usable as synthetic queries.

    The segment length is chosen so that at most ``max_segments`` segments are
    produced. Segments are never shorter than ``min_segment_size`` words,
    except the final one, which holds the leftover words.

    Args:
        lyrics: Lyrics text. Non-string values (e.g. the float ``NaN`` pandas
            uses for an empty CSV cell) and whitespace-only strings yield no
            segments.
        min_segment_size: Minimum number of words per segment.
        max_segments: Upper bound on the number of segments; values below 1
            are treated as 1.

    Returns:
        List of segment strings. Lyrics with at most ``min_segment_size``
        words are returned unchanged as a single segment.
    """
    # pandas represents missing cells as float NaN, which has no .split();
    # blank lyrics would only produce an empty, useless query.
    if not isinstance(lyrics, str) or not lyrics.strip():
        return []

    words = lyrics.split()
    total_words = len(words)
    if total_words <= min_segment_size:
        return [lyrics]

    # ceil(total / max_segments) words per segment caps the segment count.
    segment_size = max(min_segment_size, math.ceil(total_words / max(1, max_segments)))
    return [
        " ".join(words[start:start + segment_size])
        for start in range(0, total_words, segment_size)
    ]

def create_finetuning_dataset(df, num_queries=2, num_negative_pairs=10, min_segment_size=3, max_segments=10, qrels_path="qrels.csv", seed=None):
    """Build synthetic (queries, corpus, qrels) data for retrieval fine-tuning.

    Lyric segments produced by ``preprocess_lyrics_dynamic`` become queries,
    up to ``num_queries`` per song. Each query gets:

    * one positive judgement (score 1) for its own song;
    * up to ``num_negative_pairs`` negatives (score 0), drawn uniformly
      without replacement from the other songs.

    Judgements are streamed to ``qrels_path`` as CSV with columns
    ``_query_id, song_id, score``.

    Song ids are the 1-based *row positions* in ``df`` (as strings). Corpus
    ids and qrels ids therefore agree even when ``df`` has a non-default
    index.

    Args:
        df: DataFrame with ``track_name``, ``lyrics`` and ``artist_name`` columns.
        num_queries: Maximum number of queries sampled per song.
        num_negative_pairs: Negatives per query; capped at ``len(df) - 1`` so
            small datasets do not make ``random.sample`` raise.
        min_segment_size: Forwarded to ``preprocess_lyrics_dynamic``.
        max_segments: Forwarded to ``preprocess_lyrics_dynamic``.
        qrels_path: Output CSV path (overwritten).
        seed: Optional seed for a private ``random.Random``. When ``None``,
            the module-level ``random`` state is used; call ``random.seed``
            beforehand for reproducibility.

    Returns:
        ``(queries, corpus)``:

        * ``queries``: dicts with keys ``_query_id`` and ``query``;
        * ``corpus``: dicts with keys ``_id``, ``track_name``, ``lyrics`` and
          ``artist_name``.
    """
    rng = random if seed is None else random.Random(seed)
    num_songs = len(df)
    # random.sample raises ValueError when asked for more items than exist.
    negatives_per_query = max(0, min(num_negative_pairs, num_songs - 1))

    queries = []
    corpus = []

    with open(qrels_path, mode="w", newline="", encoding="utf-8") as qrels_file:
        qrels_writer = csv.DictWriter(qrels_file, fieldnames=["_query_id", "song_id", "score"])
        qrels_writer.writeheader()

        rows = df[["track_name", "lyrics", "artist_name"]].itertuples(index=False, name=None)
        for position, (title, lyrics, artist) in enumerate(rows):
            song_id = f"{position + 1}"
            corpus.append({"_id": song_id, "track_name": title, "lyrics": lyrics, "artist_name": artist})

            segments = preprocess_lyrics_dynamic(lyrics, min_segment_size, max_segments)
            selected_queries = rng.sample(segments, min(len(segments), max(0, num_queries)))
            for query in selected_queries:
                query_id = f"q{len(queries) + 1}"
                queries.append({"_query_id": query_id, "query": query})
                qrels_writer.writerow({"_query_id": query_id, "song_id": song_id, "score": 1})

                # Sample the other num_songs - 1 positions without materialising an
                # O(num_songs) list per query: draw from range(num_songs - 1) and
                # shift picks at/after this song's position up by one.
                for pick in rng.sample(range(num_songs - 1), negatives_per_query):
                    neg_position = pick if pick < position else pick + 1
                    qrels_writer.writerow({"_query_id": query_id, "song_id": f"{neg_position + 1}", "score": 0})

    return queries, corpus

In [None]:
class BiEncoder:
    """Thin wrapper around a SentenceTransformer used to embed both songs and queries."""

    def __init__(self, model_name='sentence-transformers/all-mpnet-base-v2'):
        """Load the SentenceTransformer checkpoint ``model_name``."""
        self.model = SentenceTransformer(model_name)

    def encode_texts(self, texts, show_progress_bar=True):
        """Encode ``texts`` (lyrics or queries) into embeddings.

        Args:
            texts: List of strings to embed.
            show_progress_bar: Whether ``encode`` displays its progress bar.
                Pass ``False`` for tiny inputs such as a single query.

        Returns:
            Whatever ``SentenceTransformer.encode`` returns with
            ``convert_to_tensor=True``; presumably a (len(texts), dim) tensor
            for a list input — TODO confirm for the installed version.
        """
        return self.model.encode(texts, convert_to_tensor=True, show_progress_bar=show_progress_bar)

class CrossEncoder:
    """Re-ranks candidate texts against a query with a sequence-classification cross-encoder."""

    def __init__(self, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2'):
        """Load the tokenizer and classification model for ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

    def rank_candidates(self, query, candidates, batch_size=32):
        """Score every (query, candidate) pair and rank candidates best-first.

        Pairs are tokenized and scored ``batch_size`` at a time, padded within
        each batch, instead of one forward pass per candidate.

        Args:
            query: Query string.
            candidates: Sequence of candidate strings.
            batch_size: Pairs per forward pass; values below 1 are treated as 1.

        Returns:
            NumPy array of candidate indices ordered by descending score. Ties
            keep their input order (stable sort), so an upstream ranking
            breaks ties. An empty ``candidates`` yields an empty array.

        Raises:
            ValueError: If the model emits more than one logit per pair; this
                method expects a single relevance logit per pair.
        """
        candidates = list(candidates)
        batch_size = max(1, batch_size)
        scores = []
        with torch.no_grad():
            for start in range(0, len(candidates), batch_size):
                batch = candidates[start:start + batch_size]
                encoded = self.tokenizer(
                    [query] * len(batch), batch,
                    return_tensors='pt', truncation=True, max_length=512, padding=True,
                )
                logits = self.model(**encoded).logits
                if logits.shape[-1] != 1:
                    raise ValueError(f"expected one logit per pair, got shape {tuple(logits.shape)}")
                scores.extend(logits.reshape(-1).tolist())
        # Negate for descending order; a stable sort keeps ties in input order.
        return np.argsort(-np.asarray(scores, dtype=float), kind='stable')

In [None]:
# Main Pipeline
def song_retrieval_pipeline(query, data, bi_encoder_model='sentence-transformers/all-mpnet-base-v2',
                            cross_encoder_model='cross-encoder/ms-marco-MiniLM-L-6-v2', k=10,
                            bi_encoder=None, cross_encoder=None, song_embeddings=None):
    """Two-stage song search: bi-encoder retrieval, then cross-encoder re-ranking.

    Args:
        query: Free-text query.
        data: DataFrame with ``track_name``, ``artist_name`` and ``lyrics`` columns.
        bi_encoder_model: Checkpoint loaded when ``bi_encoder`` is not supplied.
        cross_encoder_model: Checkpoint loaded when ``cross_encoder`` is not supplied.
        k: Number of candidates retrieved and returned after re-ranking.
        bi_encoder: Optional pre-built ``BiEncoder`` to reuse across calls.
        cross_encoder: Optional pre-built ``CrossEncoder`` to reuse across calls.
        song_embeddings: Optional precomputed lyric embeddings. Rows must
            follow the row order of ``data`` and come from the same bi-encoder.
            When omitted, the whole corpus is re-encoded on this call.

    Returns:
        List of song dicts (``track_name``, ``artist_name``, ``lyrics``),
        best match first.
    """
    # Missing lyrics come back from pandas as float NaN rather than strings;
    # treat them as empty text so both encoders receive strings.
    lyrics = data['lyrics'].fillna("").astype(str)
    song_metadata = (
        data[['track_name', 'artist_name']]
        .assign(lyrics=lyrics)
        .to_dict(orient='records')
    )

    # Stage 1: dense retrieval with the bi-encoder.
    if bi_encoder is None:
        bi_encoder = BiEncoder(model_name=bi_encoder_model)
    if song_embeddings is None:
        song_embeddings = bi_encoder.encode_texts(lyrics.tolist())
    top_k_songs = retrieve_top_k_songs(query, song_embeddings, song_metadata, bi_encoder, k)

    # Stage 2: re-rank the shortlist with the cross-encoder.
    if cross_encoder is None:
        cross_encoder = CrossEncoder(model_name=cross_encoder_model)
    ranked_indices = cross_encoder.rank_candidates(query, [song['lyrics'] for song in top_k_songs])
    return [top_k_songs[i] for i in ranked_indices]


In [29]:
# Main program: build the fine-tuning files if missing, then run one example query.
dataset_path = "dataset.csv"
corpus_path = "corpus.csv"
queries_path = "queries.csv"
qrels_path = "qrels.csv"
query = "want a little bit heart"
SEED = 42

# Query and negative sampling draw from the global `random` state; seed it so
# the generated queries/qrels are reproducible across Restart & Run All.
random.seed(SEED)

if os.path.exists(dataset_path):
    print("Dataset obtained.")

    data = load_lyrics_dataset(dataset_path)

    if not all(os.path.exists(path) for path in (corpus_path, queries_path, qrels_path)):
        print("Required files not found. Generating fine-tuning datasets...")

        # Generate fine-tuning datasets
        queries, corpus = create_finetuning_dataset(data, num_queries=2, num_negative_pairs=300, qrels_path=qrels_path)

        pd.DataFrame(queries).to_csv(queries_path, index=False)
        pd.DataFrame(corpus).to_csv(corpus_path, index=False)
        print("Datasets generated and saved.")
    else:
        print("Datasets already exist. Skipping dataset generation.")

    # Run song retrieval pipeline
    results = song_retrieval_pipeline(query, data)
    print("Top retrieved songs:")
    for rank, song in enumerate(results, start=1):
        print(f"{rank}. {song['track_name']} by {song['artist_name']}")
else:
    print("Dataset path does not exist.")

Dataset obtained.
Required files not found. Generating fine-tuning datasets...
Datasets generated and saved.


Batches:   0%|          | 0/887 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Top retrieved songs:
1. your hearts not in it by janie fricke
2. hello my old heart by the oh hellos
3. all of me (loves all of you) by george strait
4. promises, promises by jerry vale
5. two faces have i by lou christie
6. from me to you by the beatles
7. from me to you by del shannon
8. heartfull of soul by the yardbirds
9. somebody by depeche mode
10. love by john lennon
