In [1]:
import torch
import spacy
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoTokenizer, AutoConfig, PreTrainedModel

In [2]:
# Genre name -> class index. The index is compared against the classifier's
# argmax at prediction time, so the position of each genre here matters.
GENRE_MAPPING = {
    genre: index
    for index, genre in enumerate(("pop", "rap", "rock", "r&b", "edm"))
}

# Names of the numeric audio features. Their order here fixes the order of the
# feature vector that is concatenated to the lyrics embedding.
AUDIO_FEATURES = [
    "acousticness", "danceability", "energy", "instrumentalness",
    "key", "liveness", "loudness", "mode",
    "speechiness", "tempo", "valence",
]

In [3]:
class LyricsAudioModelInference:
    """Predict a music genre from song lyrics plus numeric audio features.

    The lyrics are cleaned with spaCy, encoded with a Transformers encoder,
    mean-pooled into a single vector, concatenated with the audio features
    listed in ``AUDIO_FEATURES`` and passed through a linear classifier.
    """

    # Token length every lyric is padded/truncated to before encoding.
    MAX_SEQ_LENGTH = 128

    def __init__(self, model_name, num_labels=5):
        """Load the encoder, tokenizer and spaCy pipeline and build the head.

        Args:
            model_name: Identifier or path passed to ``from_pretrained`` for
                both the encoder and the tokenizer.
            num_labels: Number of output classes of the linear classifier.
        """
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.num_labels = num_labels
        # NOTE(review): this head is created with fresh (random) weights and
        # nothing in this class loads trained parameters into it, so the
        # predictions are not meaningful until a trained state dict is loaded
        # into `self.classifier` -- confirm where the trained head lives.
        self.classifier = nn.Linear(
            self.model.config.hidden_size + len(AUDIO_FEATURES), num_labels
        )
        # This class is inference-only; make that explicit for both modules.
        self.model.eval()
        self.classifier.eval()

        self.nlp = spacy.load("en_core_web_sm")

    def predict_genre(self, lyrics: str, audio_features: dict) -> str:
        """Return the predicted genre name for one song.

        Args:
            lyrics: Raw lyrics text; it is cleaned by ``_preprocess_lyrics``
                before tokenization.
            audio_features: Mapping that contains a numeric value for every
                name in ``AUDIO_FEATURES`` (extra keys are ignored).

        Returns:
            The first key of ``GENRE_MAPPING`` whose index equals the argmax
            of the classifier output, or ``"Unknown"`` if no index matches.

        Raises:
            KeyError: If any name in ``AUDIO_FEATURES`` is missing from
                ``audio_features``.
        """
        # Validate up front so a bad input fails before any model work and
        # the error lists every missing key rather than only the first one.
        missing = [name for name in AUDIO_FEATURES if name not in audio_features]
        if missing:
            raise KeyError(f"audio_features is missing required keys: {missing}")

        # Shape (1, len(AUDIO_FEATURES)); order follows AUDIO_FEATURES.
        feature_vector = torch.tensor(
            [audio_features[name] for name in AUDIO_FEATURES],
            dtype=torch.float32,
        ).unsqueeze(0)

        cleaned_lyrics = self._preprocess_lyrics(lyrics)
        # Calling the tokenizer directly replaces the deprecated `encode_plus`.
        encoded = self.tokenizer(
            cleaned_lyrics,
            return_tensors="pt",
            padding="max_length",
            max_length=self.MAX_SEQ_LENGTH,
            truncation=True,
        )

        model_inputs = {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
        }
        # Some tokenizers do not return segment ids; forward them only when
        # present instead of failing with a KeyError.
        if "token_type_ids" in encoded:
            model_inputs["token_type_ids"] = encoded["token_type_ids"]

        with torch.no_grad():
            outputs = self.model(**model_inputs)
            # NOTE(review): the mean runs over all positions, including the
            # padding added by padding="max_length". Presumably this matches
            # how the head was trained -- confirm before switching to an
            # attention-mask-weighted mean.
            lyrics_embedding = outputs.last_hidden_state.mean(dim=1)
            combined_features = torch.cat([lyrics_embedding, feature_vector], dim=1)
            logits = self.classifier(combined_features)
            probabilities = F.softmax(logits, dim=1)
            predicted_index = torch.argmax(probabilities, dim=1).item()

        # First genre whose index matches; "Unknown" when the predicted index
        # has no entry (e.g. num_labels exceeds the mapping size).
        return next(
            (
                genre
                for genre, index in GENRE_MAPPING.items()
                if index == predicted_index
            ),
            "Unknown",
        )

    def _preprocess_lyrics(self, text):
        """Lemmatize ``text`` and keep only lowercase alphabetic content words.

        Tokens are dropped when spaCy marks them as stop words or punctuation,
        or when their lemma is not purely alphabetic. The surviving lemmas are
        lower-cased, stripped and joined with single spaces.
        """
        doc = self.nlp(text)
        return " ".join(
            token.lemma_.lower().strip()
            for token in doc
            if not token.is_stop and token.lemma_.isalpha() and not token.is_punct
        )

In [4]:
# Checkpoint identifier handed to `from_pretrained` for both encoder and tokenizer.
# NOTE(review): the load warning printed by this cell reports the encoder
# weights as "newly initialized", i.e. they were not found in the checkpoint
# under the names a bare BertModel expects -- confirm the checkpoint layout.
REPO_NAME = "PunGrumpy/music-genre-classification"
model_inference = LyricsAudioModelInference(model_name=REPO_NAME)

Some weights of BertModel were not initialized from the model checkpoint at PunGrumpy/music-genre-classification and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.dense.bias

In [5]:
# Sample input: four lines of lyrics, kept as one newline-separated string.
lyrics = """It might seem crazy what I am 'bout to say
Sunshine, she's here, you can take a break
I'm a hot air balloon that could go to space
With the air, like I don't care, baby by the way"""

# Sample input: one value per name in AUDIO_FEATURES.
audio_features = dict(
    acousticness=0.15,
    danceability=0.85,
    energy=0.88,
    instrumentalness=0.03,
    key=7,
    liveness=0.12,
    loudness=-6.3,
    mode=1,
    speechiness=0.05,
    tempo=160.0,
    valence=0.96,
)

In [6]:
# Run the full pipeline on the sample inputs and report the resulting label.
# NOTE(review): the classifier head is built with random weights in
# LyricsAudioModelInference.__init__, so this label may differ between kernel
# restarts unless trained weights are loaded -- confirm.
predicted_genre = model_inference.predict_genre(
    lyrics=lyrics,
    audio_features=audio_features,
)
print(f"Predicted Genre: {predicted_genre}")

Predicted Genre: r&b
