In [1]:
import pandas as pd
import torch
import transformers

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertModel, DistilBertTokenizer

In [2]:
from torch import cuda

# Use the GPU when CUDA reports one as available; otherwise fall back to CPU.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Run configuration.
# MAX_LEN is the padding/truncation length handed to the tokenizer.
# The remaining values are not referenced by the cells shown here --
# presumably left over from the training notebook; TODO confirm they are needed.
MAX_LEN = 512
TRAIN_BATCH_SIZE, VALID_BATCH_SIZE = 4, 2
EPOCHS = 10
LEARNING_RATE = 1e-5

In [4]:
# Names of the audio features. The length of this list sizes the input of
# DistilBERTClass.pre_classifier (768 + len(audio_features)).
audio_features = [
    "acousticness", "danceability", "energy", "instrumentalness",
    "key", "liveness", "loudness", "mode",
    "speechiness", "tempo", "valence",
]

In [5]:
class DistilBERTClass(torch.nn.Module):
    """DistilBERT encoder combined with audio features for classification.

    The encoder's hidden state at sequence position 0 is concatenated with the
    audio feature vector, passed through Linear -> ReLU -> Dropout, and then
    projected to one logit per class.

    NOTE: ``forward`` only reads the attributes ``l1``, ``pre_classifier``,
    ``dropout`` and ``classifier``. Keep these names stable and do not make
    ``forward`` depend on new attributes: a model object restored from a
    pickle does not re-run ``__init__``, so previously saved models would
    lack anything added here.
    """

    # Width of the encoder hidden state that the head is sized for.
    HIDDEN_SIZE = 768

    def __init__(self, num_audio_features=None, num_labels=5, dropout_prob=0.3):
        """Build the encoder and the classification head.

        Args:
            num_audio_features: Width of the audio vector appended to the
                pooled text representation. Defaults to
                ``len(audio_features)`` (the notebook-level list), which is
                the original behaviour.
            num_labels: Number of output logits. Defaults to 5.
            dropout_prob: Dropout probability applied before the final
                projection. Defaults to 0.3.
        """
        super().__init__()
        if num_audio_features is None:
            num_audio_features = len(audio_features)

        self.l1 = DistilBertModel.from_pretrained("distilbert-base-cased")
        self.pre_classifier = torch.nn.Linear(
            self.HIDDEN_SIZE + num_audio_features, self.HIDDEN_SIZE
        )
        self.dropout = torch.nn.Dropout(dropout_prob)
        self.classifier = torch.nn.Linear(self.HIDDEN_SIZE, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids, audio):
        """Return class logits for a batch.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            token_type_ids: Accepted for call compatibility only; it is not
                forwarded to the encoder.
            audio: 2-D tensor of audio features, concatenated to the pooled
                text vector along dim 1.

        Returns:
            Tensor of logits produced by ``classifier``.
        """
        encoder_output = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = encoder_output[0]
        # Position 0 of every sequence acts as the sentence summary
        # (presumably the [CLS] token -- depends on the tokenizer settings).
        pooled = hidden_state[:, 0]
        # Align dtypes so e.g. a float64 audio tensor does not promote the
        # concatenation and then clash with the float weights of the head.
        audio = audio.to(dtype=pooled.dtype)
        features = torch.cat((pooled, audio), dim=1)
        # Functional ReLU: avoids building a new ReLU module on every call.
        features = torch.relu(self.pre_classifier(features))
        features = self.dropout(features)
        return self.classifier(features)

    def _set_encoder_trainable(self, trainable):
        """Set ``requires_grad`` on every encoder (``l1``) parameter."""
        for param in self.l1.parameters():
            param.requires_grad = trainable

    def freeze_bert_encoder(self):
        """Stop gradients from updating the DistilBERT encoder weights."""
        self._set_encoder_trainable(False)

    def unfreeze_bert_encoder(self):
        """Allow gradients to update the DistilBERT encoder weights again."""
        self._set_encoder_trainable(True)

In [6]:
output_model_file = "model/distilbert_uncleaned_lyrics_audio.bin"
output_vocab_file = "model/distilbert_uncleaned_lyrics_audio_vocab.bin"

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints that come from a trusted source.
# The file holds a whole pickled model object, so DistilBERTClass must already
# be defined in this session for the unpickling to resolve the class.
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects fully pickled modules; if loading fails
# there and the file is trusted, pass weights_only=False.
model = torch.load(output_model_file, map_location=device)
tokenizer = DistilBertTokenizer.from_pretrained(output_vocab_file)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'DistilBertTokenizer'.


In [7]:
def predict_genre(model, tokenizer, lyrics, audio_features, max_len=None):
    """Predict the class index for one song from its lyrics and audio features.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            token_type_ids, audio)`` and returning logits of shape
            ``(1, num_classes)``. It is switched to eval mode and left there.
        tokenizer: Object exposing ``encode_plus`` that returns
            ``input_ids``, ``attention_mask`` and ``token_type_ids``.
        lyrics: Text to classify.
        audio_features: Sequence of numeric audio features. The order must
            match what the model expects; it is not checked here.
        max_len: Padding/truncation length for the tokenizer. Defaults to the
            notebook-level ``MAX_LEN``.

    Returns:
        int: Index of the highest logit (look it up in ``genre_mapping``).
    """
    if max_len is None:
        max_len = MAX_LEN

    model.eval()
    encoded = tokenizer.encode_plus(
        lyrics,
        None,
        add_special_tokens=True,
        max_length=max_len,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True,
    )

    # Put the inputs on the device the model lives on, not on a notebook
    # global, so a model on a different device does not raise a mismatch.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else torch.device("cpu")

    def as_batch(values, dtype):
        """Turn a flat list into a tensor with a leading batch dim of 1."""
        return torch.tensor(values, dtype=dtype, device=target).unsqueeze(0)

    with torch.no_grad():
        logits = model(
            as_batch(encoded["input_ids"], torch.long),
            as_batch(encoded["attention_mask"], torch.long),
            as_batch(encoded["token_type_ids"], torch.long),
            as_batch(audio_features, torch.float),
        )

    return int(torch.argmax(logits, dim=1).item())

In [8]:
# Class index (as returned by predict_genre) -> genre name.
# The order presumably mirrors the label encoding used at training -- verify.
genre_mapping = dict(enumerate(("edm", "r&b", "rap", "rock", "pop")))

In [9]:
lyrics = """NA In the shuffling madness Of the locomotive"""

# Stored under its own name so the notebook-level ``audio_features`` list of
# feature names is not overwritten by this sample's values.
sample_audio = {
    "acousticness": 0.417,
    "danceability": 0.680,
    "energy": 0.530,
    "instrumentalness": 0.0110,
    "key": 11,
    "liveness": 0.0559,
    "loudness": -13.105,
    "mode": 0,
    "speechiness": 0.0889,
    "tempo": 124.551,
    "valence": 0.352,
}

# Build the vector in the order given by ``audio_features`` instead of relying
# on the dict's insertion order; a missing feature raises KeyError here.
audio_vector = [sample_audio[name] for name in audio_features]

genre = predict_genre(model, tokenizer, lyrics, audio_vector)
print("Predicted Genre:", genre_mapping[genre])

Predicted Genre: rock
