In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
from tqdm import tqdm
import os
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class EnhancedSongDataset(Dataset):
    """Custom Dataset for songs"""

    def __init__(self, features, lyrics_embeddings):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.lyrics_embeddings = torch.tensor(lyrics_embeddings, dtype=torch.float32)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return (self.features[idx], self.lyrics_embeddings[idx]), self.features[idx]


class SimplerRecommenderNet(nn.Module):
    """Simplified Neural Network for Song Recommendation"""

    def __init__(self, input_dim, lyrics_dim=768):
        super().__init__()
        combined_dim = input_dim + lyrics_dim

        self.network = nn.Sequential(
            nn.Linear(combined_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
        )

    def forward(self, features, lyrics_embedding):
        combined = torch.cat([features, lyrics_embedding], dim=1)
        return self.network(combined)


class EarlyStopping:
    """Early stopping to prevent overfitting"""

    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


class LyricProcessor:
    """Process lyrics using BERT"""

    def __init__(self, cache_file="lyrics_embeddings_cache.pkl"):
        self.cache_file = cache_file
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"LyricProcessor using device: {self.device}")

        # Initialize BERT
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bert = BertModel.from_pretrained("bert-base-uncased").to(self.device)

    def process_lyrics_batch(self, lyrics_list):
        """Process lyrics in batches"""
        embeddings = []
        batch_size = 32

        # Ensure all lyrics are strings and clean them
        cleaned_lyrics = []
        for lyric in lyrics_list:
            if isinstance(lyric, (float, int)):
                lyric = str(lyric)
            if not lyric or lyric.isspace():
                lyric = "no lyrics available"
            cleaned_lyrics.append(lyric)

        for i in tqdm(range(0, len(cleaned_lyrics), batch_size), desc="Processing lyrics"):
            batch = cleaned_lyrics[i : i + batch_size]

            # Ensure batch is a list of strings
            batch = [str(text) for text in batch]

            try:
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=128,
                    return_tensors="pt",
                )

                # Move inputs to device
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.bert(**inputs)
                    batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                    embeddings.extend(batch_embeddings)

            except Exception as e:
                print(f"Error processing batch: {e}")
                # Add zero embeddings for failed batch
                zero_embeddings = np.zeros((len(batch), 768))
                embeddings.extend(zero_embeddings)

        return np.array(embeddings)


    def get_cached_embeddings(self, lyrics_list):
        if os.path.exists(self.cache_file):
            print("Loading cached embeddings...")
            with open(self.cache_file, "rb") as f:
                cached_embeddings = pickle.load(f)
                if len(cached_embeddings) == len(lyrics_list):
                    return cached_embeddings
                print("Cache size mismatch. Recomputing embeddings...")

        print("Computing BERT embeddings...")
        embeddings = self.process_lyrics_batch(lyrics_list)

        print("Saving embeddings to cache...")
        with open(self.cache_file, "wb") as f:
            pickle.dump(embeddings, f)

        return embeddings


class FastSongRecommender:
    def __init__(self, songs_df):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Recommender using device: {self.device}")

        self.songs_df = songs_df
        self.preprocessor = None
        self.model = None
        self.lyric_processor = LyricProcessor()

    def preprocess_data(self):
        print("Starting data preprocessing...")

        numerical_features = [
            "len",
            "danceability",
            "loudness",
            "acousticness",
            "instrumentalness",
            "valence",
            "energy",
            "age",
            "dating",
            "violence",
            "world/life",
            "night/time",
            "shake the audience",
            "family/gospel",
            "romantic",
            "communication",
            "obscene",
            "music",
            "movement/places",
            "light/visual perceptions",
            "family/spiritual",
            "like/girls",
            "sadness",
            "feelings",
        ]

        categorical_features = ["genre", "topic"]

        # Clean and prepare lyrics
        lyrics_list = self.songs_df["lyrics"].fillna("").astype(str).tolist()

        # Convert lyrics to list of strings and clean them
        cleaned_lyrics = []
        for lyric in lyrics_list:
            # Clean and validate lyrics
            if isinstance(lyric, (float, int)):
                lyric = str(lyric)
            if not lyric or lyric.isspace():
                lyric = "no lyrics available"
            cleaned_lyrics.append(lyric)

        # Process lyrics
        lyrics_embeddings = self.lyric_processor.get_cached_embeddings(self.songs_df["lyrics"].fillna("").astype(str).tolist())

        # Create and fit preprocessor
        preprocessor = ColumnTransformer(
            transformers=[
                ("num", StandardScaler(), numerical_features),
                (
                    "cat",
                    OneHotEncoder(handle_unknown="ignore", sparse_output=False),
                    categorical_features,
                ),
            ],
            remainder="drop",
        )

        # Process features
        X = preprocessor.fit_transform(self.songs_df)
        self.preprocessor = preprocessor

        # Ensure matching lengths
        min_length = min(len(X), len(lyrics_embeddings))
        X = X[:min_length]
        lyrics_embeddings = lyrics_embeddings[:min_length]

        print(
            f"Preprocessed data shapes - Features: {X.shape}, "
            f"Lyrics embeddings: {lyrics_embeddings.shape}"
        )

        return X, lyrics_embeddings
    def save_model(self, model_path="model.pth", preprocessor_path="preprocessor.pkl"):
        """Save the trained model and preprocessor"""
        # Save model state
        if self.model is not None:
            torch.save(
                {
                    "model_state_dict": self.model.state_dict(),
                    "input_dim": self.model.network[0].in_features
                    - 768,  # Save input dimension for reconstruction
                },
                model_path,
            )
            print(f"Model saved to {model_path}")
        else:
            print("No model to save!")

        # Save preprocessor
        if self.preprocessor is not None:
            with open(preprocessor_path, "wb") as f:
                pickle.dump(self.preprocessor, f)
            print(f"Preprocessor saved to {preprocessor_path}")
        else:
            print("No preprocessor to save!")


    def load_model(self, model_path="model.pth", preprocessor_path="preprocessor.pkl"):
        """Load the trained model and preprocessor"""
        if os.path.exists(model_path) and os.path.exists(preprocessor_path):
            # Load preprocessor
            with open(preprocessor_path, "rb") as f:
                self.preprocessor = pickle.load(f)

            # Load model
            checkpoint = torch.load(model_path, map_location=self.device)
            input_dim = checkpoint["input_dim"]
            self.model = SimplerRecommenderNet(input_dim=input_dim).to(self.device)
            self.model.load_state_dict(checkpoint["model_state_dict"])
            self.model.eval()

            print("Model and preprocessor loaded successfully!")
            return True
        else:
            print("Model or preprocessor files not found!")
            return False



    def train(self, test_size=0.2, random_state=42, num_epochs=30):
        print("Starting training process...")

        # Preprocess data
        X, lyrics_embeddings = self.preprocess_data()

        # Split data
        X_train, X_test, lyrics_train, lyrics_test = train_test_split(
            X, lyrics_embeddings, test_size=test_size, random_state=random_state
        )

        # Create datasets
        train_dataset = EnhancedSongDataset(X_train, lyrics_train)
        test_dataset = EnhancedSongDataset(X_test, lyrics_test)

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=4
        )

        test_loader = DataLoader(
            test_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=4
        )

        # Initialize model
        self.model = SimplerRecommenderNet(input_dim=X.shape[1]).to(self.device)

        # Setup training
        criterion = nn.MSELoss()
        optimizer = optim.AdamW(self.model.parameters(), lr=0.001, weight_decay=0.01)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=3
        )
        early_stopping = EarlyStopping(patience=5)

        # Training loop
        history = {"train_loss": [], "val_loss": []}

        for epoch in range(num_epochs):
            # Training phase
            self.model.train()
            train_loss = 0.0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for (features, lyrics_embed), targets in progress_bar:
                features = features.to(self.device)
                lyrics_embed = lyrics_embed.to(self.device)
                targets = targets.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(features, lyrics_embed)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                progress_bar.set_postfix({"train_loss": f"{loss.item():.4f}"})

            # Validation phase
            self.model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for (features, lyrics_embed), targets in test_loader:
                    features = features.to(self.device)
                    lyrics_embed = lyrics_embed.to(self.device)
                    targets = targets.to(self.device)

                    outputs = self.model(features, lyrics_embed)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()

            # Calculate average losses
            train_loss /= len(train_loader)
            val_loss /= len(test_loader)

            # Update learning rate
            scheduler.step(val_loss)

            # Check early stopping
            early_stopping(val_loss)

            # Save history
            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)

            print(
                f"Epoch [{epoch+1}/{num_epochs}], "
                f"Train Loss: {train_loss:.4f}, "
                f"Val Loss: {val_loss:.4f}"
            )

            if early_stopping.early_stop:
                print("Early stopping triggered")
                break

        return history

    def get_song_index(self, song_name=None, artist_name=None):
        if song_name and artist_name:
            matches = self.songs_df[
                (self.songs_df["track_name"].str.contains(song_name, case=False))
                & (self.songs_df["artist_name"].str.contains(artist_name, case=False))
            ]
        elif song_name:
            matches = self.songs_df[
                self.songs_df["track_name"].str.contains(song_name, case=False)
            ]
        else:
            raise ValueError("Please provide at least a song name")

        if len(matches) == 0:
            raise ValueError("No matching songs found")

        print("\nMatching Songs:")
        print(matches[["artist_name", "track_name", "genre"]])

        return matches.index[0]

    def recommend_similar_songs(self, song_index, top_k=5):
        self.model.eval()
        X, lyrics_embeddings = self.preprocess_data()

        with torch.no_grad():
            features = torch.tensor(X, dtype=torch.float32).to(self.device)
            lyrics = torch.tensor(lyrics_embeddings, dtype=torch.float32).to(
                self.device
            )
            embeddings = self.model(features, lyrics).cpu().numpy()

        reference_embedding = embeddings[song_index]
        similarities = np.dot(embeddings, reference_embedding) / (
            np.linalg.norm(embeddings, axis=1) * np.linalg.norm(reference_embedding)
        )

        similar_indices = similarities.argsort()[::-1][1 : top_k + 1]
        recommendations = self.songs_df.iloc[similar_indices].copy()
        recommendations["similarity_score"] = similarities[similar_indices]

        return recommendations[
            ["artist_name", "track_name", "genre", "similarity_score"]
        ]

In [3]:
try:
    # Load dataset
    songs_df = pd.read_csv("cleaned_lyrics_data.csv")
    print(f"Loaded dataset with {len(songs_df)} songs")

    # Validate required columns
    required_columns = ["lyrics", "artist_name", "track_name", "genre", "topic"] + [
        "len",
        "danceability",
        "loudness",
        "acousticness",
        "instrumentalness",
        "valence",
        "energy",
        "age",
        "dating",
        "violence",
        "world/life",
        "night/time",
        "shake the audience",
        "family/gospel",
        "romantic",
        "communication",
        "obscene",
        "music",
        "movement/places",
        "light/visual perceptions",
        "family/spiritual",
        "like/girls",
        "sadness",
        "feelings",
    ]

    missing_columns = [col for col in required_columns if col not in songs_df.columns]
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")

    # Fill NaN values
    songs_df = songs_df.fillna({"lyrics": "", "genre": "unknown", "topic": "unknown"})

    # Initialize and train recommender
    recommender = FastSongRecommender(songs_df)
    history = recommender.train(num_epochs=30)
    recommender.save_model()

    # Get recommendations
    while True:
        song_name = input("\nEnter song name (or 'quit' to exit): ")
        if song_name.lower() == "quit":
            break

        try:
            reference_song_index = recommender.get_song_index(song_name=song_name)
            recommendations = recommender.recommend_similar_songs(reference_song_index)
            print("\nRecommended Songs:")
            print(recommendations)
        except ValueError as e:
            print(f"Error: {e}")
            continue

except Exception as e:
    print(f"An error occurred: {str(e)}")
    raise


Loaded dataset with 27498 songs
Recommender using device: cuda
LyricProcessor using device: cuda
Starting training process...
Starting data preprocessing...
Loading cached embeddings...
Cache size mismatch. Recomputing embeddings...
Computing BERT embeddings...


Processing lyrics: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 860/860 [05:37<00:00,  2.55it/s]


Saving embeddings to cache...
Preprocessed data shapes - Features: (27498, 39), Lyrics embeddings: (27498, 768)


Epoch 1/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 116.78it/s, train_loss=0.0638]


Epoch [1/30], Train Loss: 0.1478, Val Loss: 0.0267


Epoch 2/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 298.56it/s, train_loss=0.0494]


Epoch [2/30], Train Loss: 0.0562, Val Loss: 0.0216


Epoch 3/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 138.91it/s, train_loss=0.0466]


Epoch [3/30], Train Loss: 0.0474, Val Loss: 0.0169


Epoch 4/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 173.61it/s, train_loss=0.0344]


Epoch [4/30], Train Loss: 0.0424, Val Loss: 0.0158


Epoch 5/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 272.06it/s, train_loss=0.0362]


Epoch [5/30], Train Loss: 0.0389, Val Loss: 0.0143


Epoch 6/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 122.06it/s, train_loss=0.0343]


Epoch [6/30], Train Loss: 0.0361, Val Loss: 0.0129


Epoch 7/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 281.29it/s, train_loss=0.0309]


Epoch [7/30], Train Loss: 0.0339, Val Loss: 0.0134


Epoch 8/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 189.01it/s, train_loss=0.0364]


Epoch [8/30], Train Loss: 0.0327, Val Loss: 0.0129


Epoch 9/30: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 153.86it/s, train_loss=0.0312]


Epoch [9/30], Train Loss: 0.0313, Val Loss: 0.0123


Epoch 10/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 268.70it/s, train_loss=0.0280]


Epoch [10/30], Train Loss: 0.0303, Val Loss: 0.0122


Epoch 11/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 126.67it/s, train_loss=0.0289]


Epoch [11/30], Train Loss: 0.0294, Val Loss: 0.0119


Epoch 12/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 237.97it/s, train_loss=0.0256]


Epoch [12/30], Train Loss: 0.0282, Val Loss: 0.0116


Epoch 13/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 270.31it/s, train_loss=0.0286]


Epoch [13/30], Train Loss: 0.0276, Val Loss: 0.0105


Epoch 14/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 121.42it/s, train_loss=0.0275]


Epoch [14/30], Train Loss: 0.0269, Val Loss: 0.0122


Epoch 15/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 277.38it/s, train_loss=0.0240]


Epoch [15/30], Train Loss: 0.0265, Val Loss: 0.0107


Epoch 16/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 145.24it/s, train_loss=0.0241]


Epoch [16/30], Train Loss: 0.0259, Val Loss: 0.0100


Epoch 17/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 192.57it/s, train_loss=0.0229]


Epoch [17/30], Train Loss: 0.0256, Val Loss: 0.0102


Epoch 18/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 274.87it/s, train_loss=0.0274]


Epoch [18/30], Train Loss: 0.0252, Val Loss: 0.0109


Epoch 19/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 124.64it/s, train_loss=0.0212]


Epoch [19/30], Train Loss: 0.0245, Val Loss: 0.0099


Epoch 20/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 244.64it/s, train_loss=0.0238]


Epoch [20/30], Train Loss: 0.0241, Val Loss: 0.0093


Epoch 21/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 199.79it/s, train_loss=0.0246]


Epoch [21/30], Train Loss: 0.0237, Val Loss: 0.0105


Epoch 22/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 143.04it/s, train_loss=0.0239]


Epoch [22/30], Train Loss: 0.0234, Val Loss: 0.0097


Epoch 23/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 263.51it/s, train_loss=0.0232]


Epoch [23/30], Train Loss: 0.0229, Val Loss: 0.0091


Epoch 24/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 122.59it/s, train_loss=0.0226]


Epoch [24/30], Train Loss: 0.0227, Val Loss: 0.0099


Epoch 25/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 266.32it/s, train_loss=0.0238]


Epoch [25/30], Train Loss: 0.0225, Val Loss: 0.0096


Epoch 26/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 164.12it/s, train_loss=0.0204]


Epoch [26/30], Train Loss: 0.0222, Val Loss: 0.0092


Epoch 27/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 177.69it/s, train_loss=0.0217]


Epoch [27/30], Train Loss: 0.0215, Val Loss: 0.0092


Epoch 28/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 276.20it/s, train_loss=0.0184]


Epoch [28/30], Train Loss: 0.0190, Val Loss: 0.0069


Epoch 29/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:02<00:00, 124.54it/s, train_loss=0.0146]


Epoch [29/30], Train Loss: 0.0183, Val Loss: 0.0085


Epoch 30/30: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 344/344 [00:01<00:00, 274.95it/s, train_loss=0.0162]


Epoch [30/30], Train Loss: 0.0179, Val Loss: 0.0075
Model saved to model.pth
Preprocessor saved to preprocessor.pkl

Enter song name (or 'quit' to exit): country boyz

Matching Songs:
       artist_name    track_name    genre
27491  nappy roots  country boyz  hip hop
Starting data preprocessing...
Loading cached embeddings...
Preprocessed data shapes - Features: (27498, 39), Lyrics embeddings: (27498, 768)

Recommended Songs:
        artist_name                              track_name    genre  \
27493       mack 10                         10 million ways  hip hop   
22188     mr. vegas                            party tun up   reggae   
27400   nappy roots                                  sholiz  hip hop   
4614   citizen king  better days (and the bottom drops out)      pop   
27386   nappy roots                            kentucky mud  hip hop   

       similarity_score  
27493          0.977242  
22188          0.955152  
27400          0.951693  
4614           0.950117  
27386  