In [4]:
# %pip install evaluate
# %pip install datasets
# %pip install accelerate -U
# %pip install transformers[torch]

In [1]:
# train a machine learning model to predict the genre of a song based on its lyrics with pytorch
import torch
import numpy as np
import evaluate
from datasets import load_dataset
from datasets import load_metric
from datasets import Dataset
from datasets import DatasetDict
import datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

  from .autonotebook import tqdm as notebook_tqdm


### Setting a device depending on whats available
* ´cuda´ for GPU
* ´cpu´ for CPU
* ´mps´ for Apple silicon

In [2]:
# set random seeds to make sure results are reproducible
SEED = 42
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# set device to cuda or mps if available

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(device)

cuda


### Split data in train, test and validation

In [3]:
# load the dataset
raw_datasets = load_dataset("csv", data_files="data.csv", split="train")

# change column name lyrics to sentence and playlist_genre to label
raw_datasets = raw_datasets.rename_column("lyrics", "sentence")
raw_datasets = raw_datasets.rename_column("playlist_genre", "label")

# Split the dataset into train, validation, and test sets
train_testvalid = raw_datasets.train_test_split(test_size=0.2, seed=SEED)
test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=SEED)

# Assign the resulting datasets to variables
train_dataset = train_testvalid["train"]
valid_dataset = test_valid["train"]
test_dataset = test_valid["test"]

# Now you have train_dataset, valid_dataset, and test_dataset
raw_datasets = {
    "train": train_dataset,
    "validation": valid_dataset,
    "test": test_dataset,
}
raw_datasets = datasets.DatasetDict(raw_datasets)
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'sentence'],
        num_rows: 155966
    })
    validation: Dataset({
        features: ['label', 'sentence'],
        num_rows: 19496
    })
    test: Dataset({
        features: ['label', 'sentence'],
        num_rows: 19496
    })
})

### Tokenize the data

In [4]:
model_name = "distilbert-base-uncased"

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=6)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Function to tokenize a single row
def tokenize_function(row):
    try:
        # Ensure that 'lyrics' is always a string. Replace non-strings with a placeholder.
        lyrics = [
            str(lyric) if isinstance(lyric, str) else "" for lyric in row["sentence"]
        ]
        return tokenizer(lyrics, truncation=True, padding="max_length", max_length=512)
    except Exception as e:
        print(f"Error tokenizing row: {row}")
        print(f"Exception: {e}")
        return None


# Apply tokenization to each subset
tokenized_train = raw_datasets["train"].map(tokenize_function, batched=True)
tokenized_validation = raw_datasets["validation"].map(tokenize_function, batched=True)
tokenized_test = raw_datasets["test"].map(tokenize_function, batched=True)

# Combine back into a DatasetDict
tokenized_datasets = DatasetDict(
    {
        "train": tokenized_train,
        "validation": tokenized_validation,
        "test": tokenized_test,
    }
)

# Check the result
tokenized_datasets

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DatasetDict({
    train: Dataset({
        features: ['label', 'sentence', 'input_ids', 'attention_mask'],
        num_rows: 155966
    })
    validation: Dataset({
        features: ['label', 'sentence', 'input_ids', 'attention_mask'],
        num_rows: 19496
    })
    test: Dataset({
        features: ['label', 'sentence', 'input_ids', 'attention_mask'],
        num_rows: 19496
    })
})

In [5]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

### Implement evaluation metrics

In [6]:
# Load the metric function
accuracy = load_metric("accuracy")
precision = load_metric("precision")
recall = load_metric("recall")
f1 = load_metric("f1")


def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # Calculate accuracy
    acc = accuracy_score(labels, preds)

    # Calculate precision, recall, and f1
    prec, rec, f1_score, _ = precision_recall_fscore_support(
        labels, preds, average="weighted"
    )

    return {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1_score,
    }

  accuracy = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


### Start training the model

In [7]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=12,
    weight_decay=0.01,
)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
# train from latest checkpoint
trainer.train(resume_from_checkpoint=True)

  0%|          | 0/116976 [00:00<?, ?it/s]You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 78%|███████▊  | 91500/116976 [01:50<13:13, 32.11it/s]   

{'loss': 0.084, 'learning_rate': 4.355765285186706e-06, 'epoch': 9.39}


 79%|███████▊  | 92000/116976 [03:38<1:29:20,  4.66it/s]

{'loss': 0.0879, 'learning_rate': 4.270277663794283e-06, 'epoch': 9.44}


 79%|███████▉  | 92500/116976 [05:26<1:27:22,  4.67it/s]

{'loss': 0.0871, 'learning_rate': 4.184790042401861e-06, 'epoch': 9.49}


 80%|███████▉  | 93000/116976 [07:14<1:25:32,  4.67it/s]

{'loss': 0.0824, 'learning_rate': 4.0993024210094386e-06, 'epoch': 9.54}


 80%|███████▉  | 93500/116976 [09:02<1:24:00,  4.66it/s]

{'loss': 0.0902, 'learning_rate': 4.013814799617016e-06, 'epoch': 9.59}


 80%|████████  | 94000/116976 [10:50<1:22:02,  4.67it/s]

{'loss': 0.0785, 'learning_rate': 3.928327178224594e-06, 'epoch': 9.64}


 81%|████████  | 94500/116976 [12:38<1:20:19,  4.66it/s]

{'loss': 0.0842, 'learning_rate': 3.842839556832171e-06, 'epoch': 9.69}


 81%|████████  | 95000/116976 [14:26<1:18:32,  4.66it/s]

{'loss': 0.0844, 'learning_rate': 3.7573519354397486e-06, 'epoch': 9.75}


 82%|████████▏ | 95500/116976 [16:14<1:16:43,  4.67it/s]

{'loss': 0.0913, 'learning_rate': 3.6718643140473266e-06, 'epoch': 9.8}


 82%|████████▏ | 96000/116976 [18:02<1:15:05,  4.66it/s]

{'loss': 0.0919, 'learning_rate': 3.5863766926549037e-06, 'epoch': 9.85}


 82%|████████▏ | 96500/116976 [19:50<1:13:08,  4.67it/s]

{'loss': 0.0796, 'learning_rate': 3.5008890712624816e-06, 'epoch': 9.9}


 83%|████████▎ | 97000/116976 [21:39<1:11:55,  4.63it/s]

{'loss': 0.0914, 'learning_rate': 3.4154014498700587e-06, 'epoch': 9.95}


                                                        
 83%|████████▎ | 97480/116976 [24:55<1:07:23,  4.82it/s]

{'eval_loss': 1.7385746240615845, 'eval_accuracy': 0.7715941731637259, 'eval_precision': 0.7721286606428058, 'eval_recall': 0.7715941731637259, 'eval_f1': 0.7716238520543374, 'eval_runtime': 91.5375, 'eval_samples_per_second': 212.984, 'eval_steps_per_second': 26.623, 'epoch': 10.0}


 83%|████████▎ | 97500/116976 [24:59<1:20:18,  4.04it/s]  

{'loss': 0.0785, 'learning_rate': 3.3299138284776366e-06, 'epoch': 10.0}


 84%|████████▍ | 98000/116976 [26:47<1:08:14,  4.63it/s]

{'loss': 0.0682, 'learning_rate': 3.2444262070852146e-06, 'epoch': 10.05}


 84%|████████▍ | 98500/116976 [28:36<1:06:15,  4.65it/s]

{'loss': 0.0556, 'learning_rate': 3.1589385856927917e-06, 'epoch': 10.1}


 85%|████████▍ | 99000/116976 [30:25<1:04:13,  4.66it/s]

{'loss': 0.0635, 'learning_rate': 3.0734509643003696e-06, 'epoch': 10.16}


 85%|████████▌ | 99500/116976 [32:16<1:04:38,  4.51it/s]

{'loss': 0.0673, 'learning_rate': 2.9879633429079476e-06, 'epoch': 10.21}


 85%|████████▌ | 100000/116976 [34:08<1:02:38,  4.52it/s]

{'loss': 0.0761, 'learning_rate': 2.9024757215155247e-06, 'epoch': 10.26}


 86%|████████▌ | 100500/116976 [35:58<56:37,  4.85it/s]  

{'loss': 0.0689, 'learning_rate': 2.8169881001231026e-06, 'epoch': 10.31}


 86%|████████▋ | 101000/116976 [37:43<55:00,  4.84it/s]  

{'loss': 0.0593, 'learning_rate': 2.7315004787306797e-06, 'epoch': 10.36}


 87%|████████▋ | 101500/116976 [39:27<53:51,  4.79it/s]  

{'loss': 0.0671, 'learning_rate': 2.6460128573382576e-06, 'epoch': 10.41}


 87%|████████▋ | 102000/116976 [41:12<52:15,  4.78it/s]  

{'loss': 0.0772, 'learning_rate': 2.5605252359458356e-06, 'epoch': 10.46}


 88%|████████▊ | 102500/116976 [42:56<49:43,  4.85it/s]  

{'loss': 0.0607, 'learning_rate': 2.4750376145534127e-06, 'epoch': 10.51}


 88%|████████▊ | 103000/116976 [44:40<48:11,  4.83it/s]  

{'loss': 0.0695, 'learning_rate': 2.3895499931609906e-06, 'epoch': 10.57}


 88%|████████▊ | 103500/116976 [46:24<46:22,  4.84it/s]  

{'loss': 0.0606, 'learning_rate': 2.304062371768568e-06, 'epoch': 10.62}


 89%|████████▉ | 104000/116976 [48:08<44:44,  4.83it/s]  

{'loss': 0.063, 'learning_rate': 2.2185747503761456e-06, 'epoch': 10.67}


 89%|████████▉ | 104500/116976 [49:52<42:52,  4.85it/s]  

{'loss': 0.062, 'learning_rate': 2.133087128983723e-06, 'epoch': 10.72}


 90%|████████▉ | 105000/116976 [51:36<41:09,  4.85it/s]  

{'loss': 0.0721, 'learning_rate': 2.047599507591301e-06, 'epoch': 10.77}


 90%|█████████ | 105500/116976 [53:20<39:27,  4.85it/s]  

{'loss': 0.0622, 'learning_rate': 1.9621118861988786e-06, 'epoch': 10.82}


 91%|█████████ | 106000/116976 [55:04<37:51,  4.83it/s]  

{'loss': 0.0733, 'learning_rate': 1.8766242648064561e-06, 'epoch': 10.87}


 91%|█████████ | 106500/116976 [56:48<36:03,  4.84it/s]  

{'loss': 0.0658, 'learning_rate': 1.7911366434140337e-06, 'epoch': 10.93}


 91%|█████████▏| 107000/116976 [58:32<34:15,  4.85it/s]  

{'loss': 0.077, 'learning_rate': 1.7056490220216116e-06, 'epoch': 10.98}


                                                         
 92%|█████████▏| 107228/116976 [1:00:48<32:09,  5.05it/s]

{'eval_loss': 1.79966139793396, 'eval_accuracy': 0.7745691423881822, 'eval_precision': 0.7745954563806603, 'eval_recall': 0.7745691423881822, 'eval_f1': 0.7745381877578411, 'eval_runtime': 88.2394, 'eval_samples_per_second': 220.944, 'eval_steps_per_second': 27.618, 'epoch': 11.0}


 92%|█████████▏| 107500/116976 [1:01:46<33:55,  4.65it/s]   

{'loss': 0.0599, 'learning_rate': 1.6201614006291891e-06, 'epoch': 11.03}


 92%|█████████▏| 108000/116976 [1:03:35<32:13,  4.64it/s]  

{'loss': 0.0497, 'learning_rate': 1.5346737792367666e-06, 'epoch': 11.08}


 93%|█████████▎| 108500/116976 [1:05:23<30:32,  4.63it/s]

{'loss': 0.051, 'learning_rate': 1.4491861578443442e-06, 'epoch': 11.13}


 93%|█████████▎| 109000/116976 [1:07:12<28:34,  4.65it/s]

{'loss': 0.059, 'learning_rate': 1.3636985364519217e-06, 'epoch': 11.18}


 94%|█████████▎| 109500/116976 [1:09:00<26:45,  4.66it/s]

{'loss': 0.0514, 'learning_rate': 1.2782109150594996e-06, 'epoch': 11.23}


 94%|█████████▍| 110000/116976 [1:10:49<24:49,  4.68it/s]

{'loss': 0.0611, 'learning_rate': 1.1927232936670771e-06, 'epoch': 11.28}


 94%|█████████▍| 110500/116976 [1:12:37<23:07,  4.67it/s]

{'loss': 0.0503, 'learning_rate': 1.1072356722746546e-06, 'epoch': 11.34}


 95%|█████████▍| 111000/116976 [1:14:25<22:05,  4.51it/s]

{'loss': 0.0487, 'learning_rate': 1.0217480508822324e-06, 'epoch': 11.39}


 95%|█████████▌| 111500/116976 [1:16:16<20:10,  4.52it/s]

{'loss': 0.0636, 'learning_rate': 9.362604294898099e-07, 'epoch': 11.44}


 96%|█████████▌| 112000/116976 [1:18:07<18:19,  4.52it/s]

{'loss': 0.0568, 'learning_rate': 8.507728080973876e-07, 'epoch': 11.49}


 96%|█████████▌| 112500/116976 [1:19:57<16:26,  4.54it/s]

{'loss': 0.0621, 'learning_rate': 7.652851867049651e-07, 'epoch': 11.54}


 97%|█████████▋| 113000/116976 [1:21:47<14:22,  4.61it/s]

{'loss': 0.0475, 'learning_rate': 6.797975653125429e-07, 'epoch': 11.59}


 97%|█████████▋| 113500/116976 [1:23:37<13:06,  4.42it/s]

{'loss': 0.0564, 'learning_rate': 5.943099439201204e-07, 'epoch': 11.64}


 97%|█████████▋| 114000/116976 [1:25:29<11:15,  4.41it/s]

{'loss': 0.0491, 'learning_rate': 5.08822322527698e-07, 'epoch': 11.69}


 98%|█████████▊| 114500/116976 [1:27:21<09:05,  4.54it/s]

{'loss': 0.0579, 'learning_rate': 4.2333470113527563e-07, 'epoch': 11.75}


 98%|█████████▊| 115000/116976 [1:29:10<06:54,  4.76it/s]

{'loss': 0.0577, 'learning_rate': 3.3784707974285326e-07, 'epoch': 11.8}


 99%|█████████▊| 115500/116976 [1:30:56<05:14,  4.69it/s]

{'loss': 0.0634, 'learning_rate': 2.523594583504309e-07, 'epoch': 11.85}


 99%|█████████▉| 116000/116976 [1:32:42<03:21,  4.84it/s]

{'loss': 0.0505, 'learning_rate': 1.668718369580085e-07, 'epoch': 11.9}


100%|█████████▉| 116500/116976 [1:34:28<01:39,  4.80it/s]

{'loss': 0.0536, 'learning_rate': 8.13842155655861e-08, 'epoch': 11.95}


                                                         
100%|██████████| 116976/116976 [1:37:41<00:00, 19.96it/s]

{'eval_loss': 1.8341795206069946, 'eval_accuracy': 0.7761592121460813, 'eval_precision': 0.7763562434187938, 'eval_recall': 0.7761592121460813, 'eval_f1': 0.7761422177522612, 'eval_runtime': 92.0434, 'eval_samples_per_second': 211.813, 'eval_steps_per_second': 26.477, 'epoch': 12.0}
{'train_runtime': 5861.1948, 'train_samples_per_second': 319.319, 'train_steps_per_second': 19.958, 'train_loss': 0.01489559734361701, 'epoch': 12.0}





TrainOutput(global_step=116976, training_loss=0.01489559734361701, metrics={'train_runtime': 5861.1948, 'train_samples_per_second': 319.319, 'train_steps_per_second': 19.958, 'train_loss': 0.01489559734361701, 'epoch': 12.0})

In [None]:
# save model
# trainer.save_model("genre_model")

### Evaluate the model

In [8]:
model = AutoModelForSequenceClassification.from_pretrained("./genre_model")

# trainer from checkpoint
trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [10]:
trainer.evaluate()

100%|██████████| 2437/2437 [01:29<00:00, 27.10it/s]


{'eval_loss': 1.8341795206069946,
 'eval_accuracy': 0.7761592121460813,
 'eval_precision': 0.7763562434187938,
 'eval_recall': 0.7761592121460813,
 'eval_f1': 0.7761422177522612,
 'eval_runtime': 90.2115,
 'eval_samples_per_second': 216.114,
 'eval_steps_per_second': 27.014}

In [11]:
trainer.evaluate(eval_dataset=tokenized_datasets['test'])

100%|██████████| 2437/2437 [01:32<00:00, 26.35it/s]


{'eval_loss': 1.8546456098556519,
 'eval_accuracy': 0.7769798933114485,
 'eval_precision': 0.7770985621765307,
 'eval_recall': 0.7769798933114485,
 'eval_f1': 0.7769661126113228,
 'eval_runtime': 92.5364,
 'eval_samples_per_second': 210.685,
 'eval_steps_per_second': 26.336}

In [12]:
trainer.evaluate(eval_dataset=tokenized_datasets['train'])

100%|██████████| 19496/19496 [12:33<00:00, 25.89it/s]


{'eval_loss': 0.03179062902927399,
 'eval_accuracy': 0.9851570214021005,
 'eval_precision': 0.98521158918168,
 'eval_recall': 0.9851570214021005,
 'eval_f1': 0.9851423452199055,
 'eval_runtime': 753.2059,
 'eval_samples_per_second': 207.07,
 'eval_steps_per_second': 25.884}

### Make predictions on some examples

In [11]:
def predict_genre(sentence):
    inputs = tokenizer(sentence, return_tensors="pt").to(device)
    outputs = model(**inputs)
    # print(outputs.logits)
    genre = outputs.logits.argmax(-1)
    if genre == 0:
        return "pop"
    elif genre == 1:
        return "rap"
    elif genre == 2:
        return "rock"
    elif genre == 3:
        return "r&b"
    elif genre == 4:
        return "latin"
    elif genre == 5:
        return "edm"
    return 0

# Eminem - Rap God | RAP
print("Eminem - Rap God | RAP\n", predict_genre("Look, I was gonna go easy on you not to hurt your feelings.")) # is in training data

# random spanish sentence
print("Tu tienes un gato muy bonito\n", predict_genre("Tu tienes un gato muy bonito"))

# Paul Damixie x SERE - You Got Me Like | POP
print("Paul Damixie x SERE - You Got Me Like | POP\n", predict_genre("I tell myself that I'll be better off without you, but you and I know that's a lie. And I can't get you out of my mind. It's like you've got me hypnotized."))
# is not in training data

# compare Japanese with Latin letters and Japanese characters
print("watashi wa raiku desu\n", predict_genre("watashi wa raiku desu"))
print("私はライクです\n", predict_genre("私はライクです"))

Eminem - Rap God | RAP
 rap
Tu tienes un gato muy bonito
 latin
Paul Damixie x SERE - You Got Me Like | POP
 edm
watashi wa raiku desu
 rock
私はライクです
 rock


In [12]:
def predict_genre(sentence):
    inputs = tokenizer(sentence, return_tensors="pt").to(device)
    outputs = model(**inputs)
    print(outputs.logits)
    genre = outputs.logits.argmax(-1)
    if genre == 0:
        return "pop"
    elif genre == 1:
        return "rap"
    elif genre == 2:
        return "rock"
    elif genre == 3:
        return "r&b"
    elif genre == 4:
        return "latin"
    elif genre == 5:
        return "edm"
    return 0

# Eminem - Rap God | RAP
print("Eminem - Rap God | RAP\n", predict_genre("Look, I was gonna go easy on you not to hurt your feelings.")) # is in training data

# random spanish sentence
print("Tu tienes un gato muy bonito\n", predict_genre("Tu tienes un gato muy bonito"))

# Paul Damixie x SERE - You Got Me Like | POP
print("Paul Damixie x SERE - You Got Me Like | POP\n", predict_genre("I tell myself that I'll be better off without you, but you and I know that's a lie. And I can't get you out of my mind. It's like you've got me hypnotized."))
# is not in training data

# compare Japanese with Latin letters and Japanese characters
print("watashi wa raiku desu\n", predict_genre("watashi wa raiku desu"))
print("私はライクです\n", predict_genre("私はライクです"))

tensor([[-7.2529,  9.0998, -3.0546, -6.2239, -0.2718, -5.0436]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Eminem - Rap God | RAP
 rap
tensor([[-3.6563, -1.4989, -1.4896, -2.7540,  2.9963, -4.0533]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Tu tienes un gato muy bonito
 latin
tensor([[-2.5010, -6.6259, -5.5518, -5.0577, -5.0102, 11.3497]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
Paul Damixie x SERE - You Got Me Like | POP
 edm
tensor([[-0.4940, -1.6997,  5.9809, -7.5397, -1.2829, -4.5777]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
watashi wa raiku desu
 rock
tensor([[-3.7927, -2.7125, 10.7992, -6.3658, -2.3098, -5.7035]],
       device='cuda:0', grad_fn=<AddmmBackward0>)
私はライクです
 rock
