In [None]:
!pip install -q bitsandbytes accelerate transformers datasets peft sentencepiece huggingface_hub streamlit pyngrok

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m28.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.4/491.4 kB[0m [31m27.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m101.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m107.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!pip install -q bert-score

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import re
from collections import Counter
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
    EarlyStoppingCallback
)
from bert_score import BERTScorer

In [None]:
# Mount Google Drive: the dataset, the BiLSTM-GRU checkpoint and all training
# outputs below are read from / written to paths under /content/drive/MyDrive/NLP.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
class BiLSTM_GRU_Classifier(nn.Module):
    """Word-level classifier: embedding -> BiLSTM -> GRU -> mean over time -> linear.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: embedding width.
        hidden_dim: hidden size of each LSTM direction and of the GRU.
        num_classes: number of output logits.
        padding_idx: embedding row reserved for padding (its vector stays zero).

    The submodule attribute names (embedding, bilstm, gru, classifier) are the keys
    of the state_dict loaded from disk later, so they must not be renamed.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes, padding_idx=0):
        super().__init__()
        # The BiLSTM concatenates forward and backward states, doubling its output width.
        lstm_out_dim = 2 * hidden_dim
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.bilstm = nn.LSTM(embed_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.gru = nn.GRU(lstm_out_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map token ids (batch, seq_len) to class logits (batch, num_classes).

        The mean over the time axis includes every position, padding positions too.
        """
        embedded = self.embedding(x)
        lstm_states, _ = self.bilstm(embedded)
        gru_states, _ = self.gru(lstm_states)
        pooled = gru_states.mean(dim=1)
        return self.classifier(pooled)

In [None]:
TOP_N = 100  # keep only songs whose primary genre is among the TOP_N most frequent

# Load lyrics, drop rows missing lyrics or tags, and use the first ';'-separated tag
# as the song's single (primary) genre label.
df = (
    pd.read_csv("/content/drive/MyDrive/NLP/Datasets/lyrics_song_info.csv")
    .dropna(subset=["lyrics", "tags"])
    .assign(primary_genre=lambda frame: frame["tags"].str.split(";").str[0])
)

top_genres = df["primary_genre"].value_counts().nlargest(TOP_N).index
df = df.loc[df["primary_genre"].isin(top_genres)].reset_index(drop=True)

def tokenize(text):
    """Lower-case `text` and split it into word tokens and the marks . , ! ? ;

    A word token is a run of word characters that may contain a single apostrophe
    (e.g. "don't", "goin'"); any other character is dropped. The vocabulary below
    is built with this function, so it presumably has to match the tokenization
    the BiLSTM-GRU checkpoint was trained with.
    """
    token_pattern = re.compile(r"\w+'?\w*|[.,!?;]")
    lowered = text.lower()
    return token_pattern.findall(lowered)

# Word-level vocabulary for the BiLSTM-GRU: up to 20,000 most frequent tokens across
# all lyrics get ids starting at 2 (most frequent first); ids 0 and 1 are reserved
# for "<pad>" and "<unk>". Presumably this must reproduce the vocabulary the loaded
# checkpoint was trained with (its embedding has 20002 rows per the model printout).
counter = Counter(token for lyric in df["lyrics"] for token in tokenize(lyric))
vocab = {
    token: token_id
    for token_id, (token, _) in enumerate(counter.most_common(20_000), start=2)
}
vocab["<pad>"] = 0
vocab["<unk>"] = 1

In [None]:
# Alphabetical genre order defines the class ids. The BiLSTM-GRU is built with
# num_classes=len(labels) and its predicted ids are decoded through id2label, so
# presumably the checkpoint was trained with this same sorted order.
labels = sorted(df["primary_genre"].unique())
id2label = dict(enumerate(labels))
label2id = {genre: idx for idx, genre in id2label.items()}

print(f"Working with {len(labels)} music genres: {', '.join(labels)}")

Working with 100 music genres: Acoustic, Adult Alternative, Alternative, Alternative Metal, Alternative Pop, Alternative R&B, Alternative Rock, Art, Art Pop, Art Rock, Atlanta, Australia, Ballad, Baroque Pop, Beef, België/Belgique, Boom Bap, Boy Band, British Rock, Canada, Chicago Drill, Christian, Christian Rap, Cloud Rap, Conscious Hip-Hop, Country, Cover, DMV, Dance, Dance-Pop, Dancehall, Demo, Dirty South, Drill, East Coast, Eighties, Electro-Pop, Electronic, Electronic Rock, Emo, Emo Rap, Experimental, France, Freestyle, French Rap, Funk, G-Funk, Gangsta Rap, Glam Rock, Hard Rock, Hardcore Hip-Hop, Heavy Metal, History, Horrorcore, Indie, Indie Pop, Indie Rap, Indie Rock, Industrial Rock, Interlude, Ireland, Memes, Motown, Musicals, Neo Soul, Non-Music, Piano, Polska, Polska Muzyka, Polski Rap, Pop, Pop-Punk, Pop-Rock, Post-Hardcore, Power Pop, Producer, Progressive Rock, Psychedelic, Punk Rock, R&B, Rap, Rap Rock, Remix, Rock, Screen, Singer-Songwriter, Skit, Soul, Soul Pop, Soun

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters must match the saved checkpoint loaded in the next cell
# (the printout there shows Embedding(20002, 128), GRU(512, 256), Linear(256, 100)).
bilstm_gru_model = BiLSTM_GRU_Classifier(
    num_classes=len(labels),
    hidden_dim=256,
    embed_dim=128,
    vocab_size=len(vocab),
)
bilstm_gru_model = bilstm_gru_model.to(device)

In [None]:
# Load the pre-trained BiLSTM-GRU weights (a state_dict, passed straight to
# load_state_dict). weights_only=True restricts unpickling to tensors and plain
# containers instead of arbitrary Python objects (it is PyTorch's default from 2.6).
bilstm_gru_model.load_state_dict(
    torch.load(
        "/content/drive/MyDrive/NLP/bilstm_gru_classifier_100.pt",
        map_location=device,
        weights_only=True,
    )
)
bilstm_gru_model.eval()

BiLSTM_GRU_Classifier(
  (embedding): Embedding(20002, 128, padding_idx=0)
  (bilstm): LSTM(128, 256, batch_first=True, bidirectional=True)
  (gru): GRU(512, 256, batch_first=True)
  (classifier): Linear(in_features=256, out_features=100, bias=True)
)

In [None]:
# Interactive Hugging Face Hub login.
# NOTE(review): roberta-base is a public model and push_to_hub=False in the training
# arguments below, so this login does not appear to be required by this notebook.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# RoBERTa-base backbone plus a sequence-classification head with one output per
# genre; the head weights are newly initialised (see the load warning below).
# id2label/label2id are passed so the model config maps class ids to genre names.
model_name = "FacebookAI/roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
from sklearn.model_selection import train_test_split

# 90/10 split, stratified on the genre so each class keeps roughly its share in both.
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
    stratify=df["primary_genre"],
)

for split_name, split_df in (("Train set", train_df), ("Validation set", val_df)):
    print(f"{split_name}: {len(split_df)} samples")

Train set: 32055 samples
Validation set: 3562 samples


In [None]:
def get_bilstm_predictions(model, texts, max_tokens=512):
    """Run the word-level BiLSTM-GRU classifier over `texts`, one text at a time.

    Each text is split with `tokenize`, truncated to its first `max_tokens` tokens
    and mapped to ids through the global `vocab` (unknown words -> "<unk>").
    Texts are scored individually (batch of one), so no padding is fed to the
    model's mean-pooling.

    Args:
        model: the BiLSTM-GRU classifier (an nn.Module mapping a (1, seq_len)
            LongTensor to (1, num_classes) logits). It is switched to eval mode.
        texts: iterable of raw lyric strings.
        max_tokens: truncation length in word tokens (512, as before, by default).

    Returns:
        (predictions, confidences): per text, the argmax class id (int) and its
        softmax probability (float).
    """
    model.eval()

    # Send inputs to the device the model actually lives on instead of relying on
    # the notebook-level `device` global matching it.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    unk_id = vocab["<unk>"]

    predictions = []
    confidences = []

    # A single no-grad context for the whole loop instead of re-entering it per text.
    with torch.no_grad():
        for text in tqdm(texts, desc="Getting BiLSTM-GRU predictions"):
            ids = [vocab.get(tok, unk_id) for tok in tokenize(text)[:max_tokens]]
            if not ids:
                # A text with no tokens would be a zero-length sequence; feed one <unk>.
                ids = [unk_id]

            input_tensor = torch.tensor([ids], dtype=torch.long, device=model_device)
            probs = torch.softmax(model(input_tensor), dim=1)
            pred_idx = torch.argmax(probs, dim=1).item()

            predictions.append(pred_idx)
            confidences.append(probs[0, pred_idx].item())

    return predictions, confidences


In [None]:
def prepare_dataset(df, include_predictions=True):
    """Build a Hugging Face `Dataset` of lyrics and gold labels from a split DataFrame.

    Columns always present: "lyrics", "original_genre" and "original_label_id"
    (via the global `label2id`). With include_predictions=True the BiLSTM-GRU is
    also run over the lyrics, and "bilstm_prediction" (genre name),
    "bilstm_pred_id" and "confidence" (softmax probability) are appended.

    NOTE(review): `preprocess_function` only reads "lyrics" and "original_label_id",
    so the BiLSTM columns (about 8 minutes of inference in the logged run) are not
    consumed by the RoBERTa training in this notebook.
    """
    lyrics = df["lyrics"].tolist()
    genres = df["primary_genre"].tolist()

    # Shared columns are built once; label lookups run before the slow inference so
    # an unknown genre fails fast.
    dataset_dict = {
        "lyrics": lyrics,
        "original_genre": genres,
        "original_label_id": [label2id[genre] for genre in genres],
    }

    if include_predictions:
        bilstm_preds, confidences = get_bilstm_predictions(bilstm_gru_model, lyrics)
        dataset_dict["bilstm_prediction"] = [id2label[pred] for pred in bilstm_preds]
        dataset_dict["bilstm_pred_id"] = bilstm_preds
        dataset_dict["confidence"] = confidences

    return Dataset.from_dict(dataset_dict)

In [None]:
# Materialise both splits as HF Datasets, including the BiLSTM-GRU prediction columns.
train_dataset, val_dataset = (
    prepare_dataset(split_frame, include_predictions=True)
    for split_frame in (train_df, val_df)
)

Getting BiLSTM-GRU predictions: 100%|██████████| 32055/32055 [07:25<00:00, 71.88it/s]
Getting BiLSTM-GRU predictions: 100%|██████████| 3562/3562 [00:49<00:00, 71.86it/s]


In [None]:
def preprocess_function(examples):
    """Tokenize a batch of lyrics for RoBERTa and attach the gold label ids.

    Used with `Dataset.map(batched=True)`, so `examples` is a dict of column lists.
    Sequences are truncated to 512 tokens and left unpadded; padding is applied per
    batch by the DataCollatorWithPadding passed to the Trainer.

    Returns:
        The tokenizer output with an extra "labels" key holding the label ids.
    """
    encoded = tokenizer(examples["lyrics"], padding=False, truncation=True, max_length=512)
    encoded["labels"] = examples["original_label_id"]
    return encoded

In [None]:
# Tokenize both splits in batches; dynamic per-batch padding is left to the collator.
tokenized_train_dataset, tokenized_val_dataset = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, val_dataset)
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map:   0%|          | 0/32055 [00:00<?, ? examples/s]

Map:   0%|          | 0/3562 [00:00<?, ? examples/s]

In [None]:
# Print the installed transformers version; the Trainer / TrainingArguments keyword
# names used below (e.g. `eval_strategy`) depend on the release.
import transformers
print(transformers.__version__)


4.51.3


In [None]:
# Debugging aid: list the keyword arguments the installed TrainingArguments accepts.
# TrainingArguments is already imported in the import cell at the top, so it is not
# re-imported here.
import inspect
print(inspect.signature(TrainingArguments))




In [None]:
# Fine-tuning configuration. Evaluation and checkpointing both run once per epoch so
# the best epoch by validation "accuracy" (from compute_metrics) drives early
# stopping and is restored at the end.
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/NLP/RoBerta/roberta-lstm/results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    metric_for_best_model="accuracy",
    greater_is_better=True,  # higher accuracy is better (also the default for non-loss metrics)
    # Keep only the best and the most recent checkpoint on Drive instead of a full
    # checkpoint (weights + optimizer state) for every epoch; with
    # load_best_model_at_end the best checkpoint is always retained.
    save_total_limit=2,
    seed=42,  # Trainer's default, stated explicitly to match the split's random_state
    warmup_steps=500,
    logging_dir="/content/drive/MyDrive/NLP/RoBerta/roberta-lstm/logs",
    logging_steps=100,
    # NOTE(review): report_to is not set, so every installed logging integration is
    # used; in the logged run that was wandb, which stopped to prompt for an API key.
    # Set report_to="none" (or "tensorboard") for non-interactive runs.
)


In [None]:
def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy plus per-genre and averaged precision/recall/F1.

    Args:
        eval_pred: (logits, label_ids) pair from the Trainer; logits are
            (num_examples, num_labels) and label_ids are integer class ids.

    Returns:
        Flat dict of floats: "accuracy", "<genre>_precision|recall|f1" for every
        genre, "macro_precision", "macro_recall", "macro_f1" and "weighted_f1".
        "accuracy" is the metric_for_best_model.
    """
    logits, label_ids = eval_pred
    predictions = np.argmax(logits, axis=1)

    # Report over the full, fixed label set (ordered by id). Without `labels=`,
    # sklearn raises ValueError when a genre appears in neither references nor
    # predictions, because target_names would no longer match. zero_division=0 keeps
    # the 0.0 values sklearn already reported for undefined precision/recall but
    # skips one UndefinedMetricWarning per genre.
    genre_names = sorted(label2id, key=label2id.get)
    class_ids = [label2id[genre] for genre in genre_names]

    report = classification_report(
        label_ids,
        predictions,
        labels=class_ids,
        target_names=genre_names,
        output_dict=True,
        zero_division=0,
    )

    results = {"accuracy": accuracy_score(label_ids, predictions)}

    for genre in genre_names:
        genre_stats = report[genre]
        results[f"{genre}_precision"] = genre_stats["precision"]
        results[f"{genre}_recall"] = genre_stats["recall"]
        results[f"{genre}_f1"] = genre_stats["f1-score"]

    results["macro_precision"] = report["macro avg"]["precision"]
    results["macro_recall"] = report["macro avg"]["recall"]
    results["macro_f1"] = report["macro avg"]["f1-score"]
    results["weighted_f1"] = report["weighted avg"]["f1-score"]

    return results

# Stop training once validation "accuracy" (metric_for_best_model in training_args)
# has not improved for 3 consecutive evaluations, i.e. 3 epochs since
# eval_strategy="epoch". With num_train_epochs=5 it did not trigger in the logged run.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=3)

In [None]:
# Fine-tune RoBERTa with per-epoch evaluation and early stopping.
# `processing_class` replaces Trainer's deprecated `tokenizer=` argument
# (transformers >= 4.46; this run used 4.51.3, where the old keyword emitted the
# deprecation warning seen in the output).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback],
)

print("Starting model training...")
trainer.train()

  trainer = Trainer(


Starting model training...


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mparames3[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,Acoustic Precision,Acoustic Recall,Acoustic F1,Adult alternative Precision,Adult alternative Recall,Adult alternative F1,Alternative Precision,Alternative Recall,Alternative F1,Alternative metal Precision,Alternative metal Recall,Alternative metal F1,Alternative pop Precision,Alternative pop Recall,Alternative pop F1,Alternative r&b Precision,Alternative r&b Recall,Alternative r&b F1,Alternative rock Precision,Alternative rock Recall,Alternative rock F1,Art Precision,Art Recall,Art F1,Art pop Precision,Art pop Recall,Art pop F1,Art rock Precision,Art rock Recall,Art rock F1,Atlanta Precision,Atlanta Recall,Atlanta F1,Australia Precision,Australia Recall,Australia F1,Ballad Precision,Ballad Recall,Ballad F1,Baroque pop Precision,Baroque pop Recall,Baroque pop F1,Beef Precision,Beef Recall,Beef F1,België/belgique Precision,België/belgique Recall,België/belgique F1,Boom bap Precision,Boom bap Recall,Boom bap F1,Boy band Precision,Boy band Recall,Boy band F1,British rock Precision,British rock Recall,British rock F1,Canada Precision,Canada Recall,Canada F1,Chicago drill Precision,Chicago drill Recall,Chicago drill F1,Christian Precision,Christian Recall,Christian F1,Christian rap Precision,Christian rap Recall,Christian rap F1,Cloud rap Precision,Cloud rap Recall,Cloud rap F1,Conscious hip-hop Precision,Conscious hip-hop Recall,Conscious hip-hop F1,Country Precision,Country Recall,Country F1,Cover Precision,Cover Recall,Cover F1,Dmv Precision,Dmv Recall,Dmv F1,Dance Precision,Dance Recall,Dance F1,Dance-pop Precision,Dance-pop Recall,Dance-pop F1,Dancehall Precision,Dancehall Recall,Dancehall F1,Demo Precision,Demo Recall,Demo F1,Dirty south Precision,Dirty south Recall,Dirty south F1,Drill Precision,Drill Recall,Drill F1,East coast Precision,East coast Recall,East coast F1,Eighties Precision,Eighties Recall,Eighties F1,Electro-pop Precision,Electro-pop Recall,Electro-pop F1,Electronic Precision,Electronic Recall,Electronic 
F1,Electronic rock Precision,Electronic rock Recall,Electronic rock F1,Emo Precision,Emo Recall,Emo F1,Emo rap Precision,Emo rap Recall,Emo rap F1,Experimental Precision,Experimental Recall,Experimental F1,France Precision,France Recall,France F1,Freestyle Precision,Freestyle Recall,Freestyle F1,French rap Precision,French rap Recall,French rap F1,Funk Precision,Funk Recall,Funk F1,G-funk Precision,G-funk Recall,G-funk F1,Gangsta rap Precision,Gangsta rap Recall,Gangsta rap F1,Glam rock Precision,Glam rock Recall,Glam rock F1,Hard rock Precision,Hard rock Recall,Hard rock F1,Hardcore hip-hop Precision,Hardcore hip-hop Recall,Hardcore hip-hop F1,Heavy metal Precision,Heavy metal Recall,Heavy metal F1,History Precision,History Recall,History F1,Horrorcore Precision,Horrorcore Recall,Horrorcore F1,Indie Precision,Indie Recall,Indie F1,Indie pop Precision,Indie pop Recall,Indie pop F1,Indie rap Precision,Indie rap Recall,Indie rap F1,Indie rock Precision,Indie rock Recall,Indie rock F1,Industrial rock Precision,Industrial rock Recall,Industrial rock F1,Interlude Precision,Interlude Recall,Interlude F1,Ireland Precision,Ireland Recall,Ireland F1,Memes Precision,Memes Recall,Memes F1,Motown Precision,Motown Recall,Motown F1,Musicals Precision,Musicals Recall,Musicals F1,Neo soul Precision,Neo soul Recall,Neo soul F1,Non-music Precision,Non-music Recall,Non-music F1,Piano Precision,Piano Recall,Piano F1,Polska Precision,Polska Recall,Polska F1,Polska muzyka Precision,Polska muzyka Recall,Polska muzyka F1,Polski rap Precision,Polski rap Recall,Polski rap F1,Pop Precision,Pop Recall,Pop F1,Pop-punk Precision,Pop-punk Recall,Pop-punk F1,Pop-rock Precision,Pop-rock Recall,Pop-rock F1,Post-hardcore Precision,Post-hardcore Recall,Post-hardcore F1,Power pop Precision,Power pop Recall,Power pop F1,Producer Precision,Producer Recall,Producer F1,Progressive rock Precision,Progressive rock Recall,Progressive rock F1,Psychedelic Precision,Psychedelic Recall,Psychedelic F1,Punk rock 
Precision,Punk rock Recall,Punk rock F1,R&b Precision,R&b Recall,R&b F1,Rap Precision,Rap Recall,Rap F1,Rap rock Precision,Rap rock Recall,Rap rock F1,Remix Precision,Remix Recall,Remix F1,Rock Precision,Rock Recall,Rock F1,Screen Precision,Screen Recall,Screen F1,Singer-songwriter Precision,Singer-songwriter Recall,Singer-songwriter F1,Skit Precision,Skit Recall,Skit F1,Soul Precision,Soul Recall,Soul F1,Soul pop Precision,Soul pop Recall,Soul pop F1,Soundtrack Precision,Soundtrack Recall,Soundtrack F1,Spoken word Precision,Spoken word Recall,Spoken word F1,Synth-pop Precision,Synth-pop Recall,Synth-pop F1,Trap Precision,Trap Recall,Trap F1,Uk Precision,Uk Recall,Uk F1,Uk rap Precision,Uk rap Recall,Uk rap F1,Underground hip-hop Precision,Underground hip-hop Recall,Underground hip-hop F1,Unreleased Precision,Unreleased Recall,Unreleased F1,West coast Precision,West coast Recall,West coast F1,Youtube Precision,Youtube Recall,Youtube F1,Русский рэп (russian rap) Precision,Русский рэп (russian rap) Recall,Русский рэп (russian rap) F1,Macro Precision,Macro Recall,Macro F1,Weighted F1
1,2.0966,1.983006,0.541269,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.383721,0.458333,0.417722,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.6,0.857143,0.705882,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.520833,0.943396,0.671141,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.55102,0.675,0.606742,0.0,0.0,0.0,0.0,0.0,0.0,0.461538,1.0,0.631579,0.0,0.0,0.0,0.191919,0.322034,0.240506,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.295515,0.506787,0.373333,0.662593,0.908982,0.766473,0.0,0.0,0.0,0.0,0.0,0.0,0.216814,0.662162,0.326667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.308271,0.226519,0.261146,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5625,0.141732,0.226415,0.0,0.0,0.0,0.666667,1.0,0.8,0.054214,0.077021,0.060276,0.452759
2,1.8696,1.85806,0.54941,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.444444,0.571429,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.589286,0.458333,0.515625,0.714286,0.416667,0.526316,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.6,0.857143,0.705882,0.0,0.0,0.0,0.5,0.254902,0.337662,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.535714,0.211268,0.30303,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.461538,0.285714,0.352941,0.0,0.0,0.0,0.550562,0.924528,0.690141,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.444444,0.666667,0.533333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.553191,0.65,0.597701,0.0,0.0,0.0,0.0,0.0,0.0,0.461538,1.0,0.631579,0.0,0.0,0.0,0.189189,0.355932,0.247059,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.34626,0.565611,0.429553,0.694712,0.865269,0.770667,0.0,0.0,0.0,0.0,0.0,0.0,0.177632,0.72973,0.285714,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.323529,0.18232,0.233216,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.442029,0.480315,0.460377,0.0,0.0,0.0,0.705882,1.0,0.827586,0.087342,0.104758,0.089484,0.483458
3,1.7851,1.791081,0.555867,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.117647,0.095238,0.105263,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.507692,0.458333,0.481752,0.625,0.416667,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.636364,1.0,0.777778,0.0,0.0,0.0,0.545455,0.352941,0.428571,0.0,0.0,0.0,0.4,0.307692,0.347826,0.0,0.0,0.0,1.0,0.088235,0.162162,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.318182,0.388889,0.35,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.535714,0.211268,0.30303,0.0,0.0,0.0,0.0,0.0,0.0,0.1,0.034483,0.051282,0.0,0.0,0.0,0.0,0.0,0.0,0.291667,0.333333,0.311111,0.0,0.0,0.0,0.561798,0.943396,0.704225,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.454545,0.833333,0.588235,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.333333,0.166667,0.222222,0.0,0.0,0.0,0.625,0.625,0.625,0.0,0.0,0.0,0.0,0.0,0.0,0.461538,1.0,0.631579,0.0,0.0,0.0,0.194139,0.449153,0.2711,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.311573,0.475113,0.376344,0.710991,0.867665,0.781553,0.0,0.0,0.0,0.0,0.0,0.0,0.247312,0.621622,0.353846,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.312883,0.281768,0.296512,0.5,0.027778,0.052632,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.509091,0.440945,0.472574,0.0,0.0,0.0,0.705882,1.0,0.827586,0.110058,0.114195,0.100222,0.497703
4,1.4932,1.81272,0.552779,0.0,0.0,0.0,0.0,0.0,0.0,0.066667,0.037037,0.047619,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.454545,0.714286,0.555556,0.0,0.0,0.0,0.0,0.0,0.0,0.446809,0.583333,0.506024,0.5,0.416667,0.454545,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.583333,1.0,0.736842,0.0,0.0,0.0,0.465116,0.392157,0.425532,0.0,0.0,0.0,0.4,0.307692,0.347826,0.0,0.0,0.0,0.75,0.088235,0.157895,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.222222,0.333333,0.266667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.222222,0.190476,0.205128,0.333333,0.267606,0.296875,0.0,0.0,0.0,0.0,0.0,0.0,0.25,0.172414,0.204082,0.0,0.0,0.0,0.0,0.0,0.0,0.307692,0.571429,0.4,0.0,0.0,0.0,0.568182,0.943396,0.70922,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.333333,0.5,0.555556,0.833333,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.142857,0.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.6,0.5,0.545455,0.0,0.0,0.0,0.564103,0.55,0.556962,0.0,0.0,0.0,0.0,0.0,0.0,0.461538,1.0,0.631579,0.0,0.0,0.0,0.210256,0.347458,0.261981,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.34507,0.443439,0.388119,0.750136,0.828743,0.787482,0.0,0.0,0.0,0.0,0.0,0.0,0.254335,0.594595,0.356275,0.0,0.0,0.0,0.150685,0.34375,0.209524,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.311966,0.403315,0.351807,0.09375,0.083333,0.088235,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,0.503937,0.501961,0.0,0.0,0.0,0.705882,1.0,0.827586,0.130734,0.139262,0.122374,0.511399
5,1.3906,1.833924,0.55306,0.0,0.0,0.0,0.0,0.0,0.0,0.071429,0.037037,0.04878,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,0.571429,0.533333,0.0,0.0,0.0,0.0,0.0,0.0,0.473118,0.611111,0.533333,0.625,0.416667,0.5,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.7,1.0,0.823529,0.0,0.0,0.0,0.529412,0.352941,0.423529,0.0,0.0,0.0,0.333333,0.307692,0.32,0.0,0.0,0.0,0.333333,0.088235,0.139535,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.32,0.444444,0.372093,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.222222,0.285714,0.25,0.357143,0.28169,0.314961,0.0,0.0,0.0,0.0,0.0,0.0,0.227273,0.172414,0.196078,0.0,0.0,0.0,0.0,0.0,0.0,0.266667,0.571429,0.363636,0.0,0.0,0.0,0.581395,0.943396,0.719424,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.333333,0.5,0.555556,0.833333,0.666667,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.142857,0.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5,0.333333,0.4,0.0,0.0,0.0,0.55814,0.6,0.578313,0.0,0.0,0.0,0.0,0.0,0.0,0.44,0.916667,0.594595,0.0,0.0,0.0,0.205607,0.372881,0.26506,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.336449,0.488688,0.398524,0.753719,0.819162,0.785079,0.0,0.0,0.0,0.0,0.0,0.0,0.238806,0.648649,0.349091,0.0,0.0,0.0,0.166667,0.25,0.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.331776,0.392265,0.359494,0.142857,0.083333,0.105263,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.474453,0.511811,0.492424,0.0,0.0,0.0,0.705882,1.0,0.827586,0.129502,0.138105,0.123103,0.512443


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

TrainOutput(global_step=10020, training_loss=1.7786762446938398, metrics={'train_runtime': 3330.8747, 'train_samples_per_second': 48.118, 'train_steps_per_second': 3.008, 'total_flos': 4.22072299938816e+16, 'train_loss': 1.7786762446938398, 'epoch': 5.0})

In [None]:
# Re-evaluate on the validation split. With load_best_model_at_end=True this scores
# the restored best checkpoint (the numbers match the epoch-3 row of the training log).
print("\nPerforming evaluation...")
eval_results = trainer.evaluate()

print("\nEvaluation Results:")
for metric_name, metric_value in eval_results.items():
    print(f"{metric_name}: {metric_value:.4f}")



Performing evaluation...



Evaluation Results:
eval_loss: 1.7911
eval_accuracy: 0.5559
eval_Acoustic_precision: 0.0000
eval_Acoustic_recall: 0.0000
eval_Acoustic_f1: 0.0000
eval_Adult Alternative_precision: 0.0000
eval_Adult Alternative_recall: 0.0000
eval_Adult Alternative_f1: 0.0000
eval_Alternative_precision: 0.0000
eval_Alternative_recall: 0.0000
eval_Alternative_f1: 0.0000
eval_Alternative Metal_precision: 0.0000
eval_Alternative Metal_recall: 0.0000
eval_Alternative Metal_f1: 0.0000
eval_Alternative Pop_precision: 0.0000
eval_Alternative Pop_recall: 0.0000
eval_Alternative Pop_f1: 0.0000
eval_Alternative R&B_precision: 0.0000
eval_Alternative R&B_recall: 0.0000
eval_Alternative R&B_f1: 0.0000
eval_Alternative Rock_precision: 0.1176
eval_Alternative Rock_recall: 0.0952
eval_Alternative Rock_f1: 0.1053
eval_Art_precision: 0.0000
eval_Art_recall: 0.0000
eval_Art_f1: 0.0000
eval_Art Pop_precision: 0.0000
eval_Art Pop_recall: 0.0000
eval_Art Pop_f1: 0.0000
eval_Art Rock_precision: 0.0000
eval_Art Rock_recall: 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
# Predict on the validation split with the restored best model and inspect errors.
predictions = trainer.predict(tokenized_val_dataset)
logits = predictions.predictions
preds = np.argmax(logits, axis=1)
# NOTE(review): this rebinds the global `labels` (earlier the list of genre names)
# to the array of true label ids; the BERTScore cell below reads it as ids, so
# re-running earlier cells after this one would see the wrong object.
labels = predictions.label_ids

# Softmax probability of the predicted class for each example.
probs = torch.softmax(torch.tensor(logits), dim=1)
confidences = probs[range(len(probs)), preds].numpy()

true_genres = [id2label[i] for i in labels]
pred_genres = [id2label[i] for i in preds]

# Use the full, fixed label set (ids 0..N-1) so the report never fails on a genre
# missing from both references and predictions, and confusion-matrix rows/columns
# always line up with label ids. zero_division=0 reports undefined precision/recall
# as 0 without one warning per genre.
class_ids = sorted(id2label)
genre_names = [id2label[i] for i in class_ids]

print("\nDetailed Classification Report:")
print(classification_report(labels, preds, labels=class_ids, target_names=genre_names, zero_division=0))

conf_matrix = confusion_matrix(labels, preds, labels=class_ids)
print("\nConfusion Matrix:")
print(conf_matrix)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Detailed Classification Report:
                           precision    recall  f1-score   support

                 Acoustic       0.00      0.00      0.00         4
        Adult Alternative       0.00      0.00      0.00         5
              Alternative       0.00      0.00      0.00        27
        Alternative Metal       0.00      0.00      0.00         5
          Alternative Pop       0.00      0.00      0.00        15
          Alternative R&B       0.00      0.00      0.00        13
         Alternative Rock       0.12      0.10      0.11        21
                      Art       0.00      0.00      0.00         7
                  Art Pop       0.00      0.00      0.00         4
                 Art Rock       0.00      0.00      0.00         9
                  Atlanta       0.51      0.46      0.48        72
                Australia       0.62      0.42      0.50        12
                   Ballad       0.00      0.00      0.00         8
              Baroque Pop   

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
def calculate_bertscore(references, predictions):
    """Mean BERTScore precision, recall and F1 of `predictions` against `references`.

    Uses bert_score's English default model (roberta-large, per the download log)
    with baseline rescaling. A new scorer, and therefore a model load, is created on
    every call.

    Args:
        references: reference strings (here: the true genre names).
        predictions: candidate strings, aligned element-wise with `references`.

    Returns:
        Tuple of three Python floats: (mean precision, mean recall, mean F1).
    """
    scorer = BERTScorer(lang="en", rescale_with_baseline=True)
    precision, recall, f1 = scorer.score(predictions, references)
    return tuple(score.mean().item() for score in (precision, recall, f1))

In [None]:
# BERTScore between predicted and true genre names. True ids come straight from
# `predictions.label_ids` rather than the reused global `labels`, which earlier in
# the notebook held the list of genre names.
true_genres = [id2label[label_id] for label_id in predictions.label_ids]
pred_genres = [id2label[pred] for pred in preds]

try:
    precision, recall, f1 = calculate_bertscore(true_genres, pred_genres)
except Exception as e:  # best-effort metric: it needs a ~1.4 GB model download
    print(f"Error calculating BERTScore: {e}")
else:
    print("\nBERTScore:")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1: {f1:.4f}")


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



BERTScore:
Precision: 0.7653
Recall: 0.7191
F1: 0.7363


In [None]:
# Persist the fine-tuned classifier (the restored best checkpoint, since
# load_best_model_at_end=True) and its tokenizer so both reload via from_pretrained.
# NOTE(review): Trainer.save_model likely already writes the tokenizer files too;
# the explicit tokenizer save is then redundant but harmless.
model_save_path = "/content/drive/MyDrive/NLP/RoBerta/roberta_lyrics_classifier2"
trainer.save_model(output_dir=model_save_path)
tokenizer.save_pretrained(save_directory=model_save_path)
print(f"\nModel saved to {model_save_path}")


Model saved to /content/drive/MyDrive/NLP/RoBerta/roberta_lyrics_classifier2
