In [1]:
import os
import torch
import numpy as np
import epitran
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    AutoModelForMaskedLM,
)
import kagglehub
from functools import lru_cache

In [2]:
# path = kagglehub.dataset_download("jamiewelsh2/rap-lyrics")

# print("Path to dataset files:", path)

In [3]:
# pd_data_frame = pd.read_csv(
#     "/home/toure215/Documents/BERT_phonetic/DATASETS/rap/updated_rappers.csv"
# )
# pd_data_frame.head(20)

In [4]:
# print(len(pd_data_frame))
# print(pd_data_frame["song"].nunique())

In [5]:
# ds = load_dataset("Cropinky/rap_lyrics_english")
# ds

In [6]:
# print(ds["train"][:39])
# ds["train"] = ds["train"][40:]

In [7]:
# ds["train"] = Dataset.from_dict(ds["train"])

In [8]:
# ds = ds.filter(
#     lambda x: "[" not in x["text"] or "]" not in x["text"], num_proc=15
# ).filter(lambda x: x["text"] != "", num_proc=15)
# ds

In [9]:
# ds["train"][:10]

In [10]:
# pd_ds = ds["train"].to_pandas()
# pd_ds.head(10)

In [11]:
epi = epitran.Epitran("eng-Latn")


@lru_cache(maxsize=None)
def xsampa_list(word):
    return epi.xsampa_list(word)


def is_rhyming(word1, word2):
    sound1 = xsampa_list(word1)
    sound2 = xsampa_list(word2)
    if len(sound1) < 2 or len(sound2) < 2:
        return False
    return sound1[-2:] == sound2[-2:]


# Pre-compute phonetic endings for all verses
def get_last_phonetic(word):
    phonemes = xsampa_list(word)
    return phonemes[-2:] if len(phonemes) >= 2 else phonemes

In [12]:
# pd_ds["last_word"] = pd_ds["text"].apply(lambda x: x.split()[-1])
# pd_ds["phonetic_ending"] = pd_ds["last_word"].apply(get_last_phonetic)
# pd_ds.head()

In [13]:
# # Convert phonetic_ending lists to tuples for hashing
# pd_ds["phonetic_ending"] = pd_ds["phonetic_ending"].apply(tuple)

# # Group verses by their phonetic endings for quick access to rhyming pairs
# rhyme_groups = (
#     pd_ds.groupby("phonetic_ending").apply(lambda x: x.index.tolist()).to_dict()
# )

In [14]:
# phonetic_endings = list(rhyme_groups.keys())

In [15]:
# # Build the dataset
# import random

# rap_ds = pd.DataFrame(columns=["id", "sentence1", "sentence2", "label"])

# for i in range(0, len(pd_ds), 2):
#     last = len(rap_ds)
#     word1 = pd_ds.iloc[i]["last_word"]
#     phonetic1 = pd_ds.iloc[i]["phonetic_ending"]

#     # Find a rhyming pair
#     rhyming_indices = rhyme_groups.get(phonetic1, [])
#     rhyming_idx = i  # Default to self if no other rhyme is found
#     for idx in rhyming_indices:
#         if idx != i:
#             rhyming_idx = idx
#             break

#     rap_ds.loc[last] = [
#         last,
#         pd_ds.iloc[i]["text"],
#         pd_ds.iloc[rhyming_idx]["text"],
#         1,  # Label for rhyming
#     ]

#     # Find a non-rhyming pair by selecting from different phonetic endings
#     non_rhyme_phonetic = phonetic1
#     while non_rhyme_phonetic == phonetic1:
#         non_rhyme_phonetic = random.choice(phonetic_endings)
#     non_rhyme_idx = np.random.choice(rhyme_groups[non_rhyme_phonetic])

#     rap_ds.loc[last + 1] = [
#         last + 1,
#         pd_ds.iloc[i]["text"],
#         pd_ds.iloc[non_rhyme_idx]["text"],
#         0,  # Label for non-rhyming
#     ]

# print("Final row count in rap_ds:", len(rap_ds))

In [16]:
# from sklearn.model_selection import train_test_split

# train, test = train_test_split(rap_ds, test_size=0.1, random_state=42)
# train, val = train_test_split(train, test_size=0.1, random_state=42)

# train = Dataset.from_pandas(train)
# val = Dataset.from_pandas(val)
# test = Dataset.from_pandas(test)

# rap_ds_hf = DatasetDict({"train": train, "validation": val, "test": test})
# rap_ds_hf.save_to_disk("/home/toure215/Documents/BERT_phonetic/DATASETS/rap/rap_ds_hf")

In [17]:
rap_ds_hf = load_from_disk(
    "/home/toure215/BERT_phonetic/DATASETS/rap/rap_ds_hf"
)

rap_ds_rhyme = rap_ds_hf.filter(lambda x: x["label"] == 1, num_proc=os.cpu_count() - 1)

In [18]:
rap_ds_rhyme = rap_ds_rhyme.remove_columns(["__index_level_0__", "id"])


def add_rhyme_label(example):
    label = example["sentence1"].split()[-1]
    return {
        "sentence1": example["sentence1"],
        "sentence2": example["sentence2"],
        "label": label,
    }


rap_ds_rhyme = rap_ds_rhyme.map(add_rhyme_label, num_proc=os.cpu_count() - 1)
rap_ds_rhyme

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 411009
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 45717
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 50881
    })
})

In [19]:
is_phonetic = True

model_path = [
    "bert-base-uncased",
    "psktoure/BERT_BPE_phonetic_wikitext-103-raw-v1",
    "psktoure/BERT_WordPiece_wikitext-103-raw-v1",
]

if is_phonetic:
    model = AutoModelForMaskedLM.from_pretrained(model_path[1])
    tokenizer = AutoTokenizer.from_pretrained(model_path[1])
else:
    model = AutoModelForMaskedLM.from_pretrained(model_path[0])
    tokenizer = AutoTokenizer.from_pretrained(model_path[0])

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [20]:
def translate_sentence(sentence: str) -> str:
    words = sentence.split()
    for i in range(len(words)):
        words[i] = "".join(xsampa_list(words[i]))
    return " ".join(words)


def translate_function(examples):
    examples["sentence1"] = [
        translate_sentence(sentence) for sentence in examples["sentence1"]
    ]
    examples["sentence2"] = [
        translate_sentence(sentence) for sentence in examples["sentence2"]
    ]
    examples["label"] = ["".join(xsampa_list(word)) for word in examples["label"]]
    return examples


if is_phonetic:
    rap_ds_rhyme = rap_ds_rhyme.map(
        translate_function, num_proc=os.cpu_count() - 1, batched=True
    )

Map (num_proc=15):   0%|          | 0/411009 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/45717 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/50881 [00:00<?, ? examples/s]

In [21]:
import torch
from transformers import PreTrainedTokenizerBase


class CustomDataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizerBase, padding=True, max_length=64):
        self.tokenizer = tokenizer
        self.mask_token_id = tokenizer.mask_token_id
        self.padding = padding
        self.max_length = max_length

    def __call__(self, examples):

        sentence1 = [example["sentence1"] for example in examples]
        sentence2 = [example["sentence2"] for example in examples]
        targets = [example["label"] for example in examples]

        encoded_targets = self.tokenizer(targets, add_special_tokens=False)

        batch = self.tokenizer(
            sentence1,
            sentence2,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        for i, idx in enumerate(input_ids):
            sep_token_indices = torch.where(idx == self.tokenizer.sep_token_id)[0]
            start = sep_token_indices[0] - len(encoded_targets[i])
            end = sep_token_indices[0]
            input_ids[i, start:end] = self.mask_token_id
            labels[i, :start] = -100
            labels[i, end:] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": batch["attention_mask"],
            "labels": labels,
        }

In [22]:
class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

In [23]:
data_collator = CustomDataCollator(tokenizer)

training_args = TrainingArguments(
    output_dir="/tmp/fine_tuned_bert",
    eval_strategy="epoch",
    save_strategy="no",
    learning_rate=5e-5,
    per_device_train_batch_size=264,
    per_device_eval_batch_size=264,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_strategy="no",
    remove_unused_columns=False,
    fp16=True,
)

# Initialize the Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=rap_ds_rhyme["train"],
    eval_dataset=rap_ds_rhyme["validation"],
    data_collator=data_collator,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [24]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,1.263351
2,No log,1.085101
3,No log,1.027311


TrainOutput(global_step=4671, training_loss=1.2224574251364804, metrics={'train_runtime': 752.7146, 'train_samples_per_second': 1638.107, 'train_steps_per_second': 6.206, 'total_flos': 3.80848301034003e+16, 'train_loss': 1.2224574251364804, 'epoch': 3.0})

In [25]:
def evaluate_rhyme_indices(model, dataset, tokenizer, k=5):

    model = model.to("cuda")
    model.eval()
    res = []
    batch_size = 256

    for i in range(0, len(dataset), batch_size):
        print(f"Processing batch {i}/{len(dataset)}...", end="\r")
        batch = dataset[i : i + batch_size]
        batch_sequence = [
            {key: batch[key][j] for key in batch}
            for j in range(len(batch["sentence1"]))
        ]
        inputs = data_collator(batch_sequence)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            labels = inputs["labels"]

        count = 0

        for j in range(len(batch["sentence1"])):
            # Identify the position of the masked token
            masked_token_index = (
                inputs["input_ids"][j] == tokenizer.mask_token_id
            ).nonzero(as_tuple=True)[0]

            targets = labels[j, masked_token_index]
            top_k_indices = logits[j, masked_token_index].topk(k).indices.squeeze(0)
            if i < 16 and j < 8:
                print("targets:", targets, "-- top_k_indices:", top_k_indices)

            ok = True
            for idx, target in enumerate(targets):
                if target not in top_k_indices[idx]:
                    ok = False
            if ok:
                count += 1

        res.append(count / len(batch["sentence1"]))

    return {"score": np.mean(res)}

In [26]:
evaluate_rhyme_indices(model, rap_ds_rhyme["test"], tokenizer, k=5)

targets: tensor([63], device='cuda:0') -- top_k_indices: tensor([  63, 7489,  451, 1465, 1315], device='cuda:0')
targets: tensor([191], device='cuda:0') -- top_k_indices: tensor([10266,   191,  6197,  7365,  1123], device='cuda:0')
targets: tensor([821], device='cuda:0') -- top_k_indices: tensor([ 275, 5786,  499,  821, 1343], device='cuda:0')
targets: tensor([1738,   18], device='cuda:0') -- top_k_indices: tensor([[241,  66, 145, 233,  61],
        [ 18,  40,  42,   6,  62]], device='cuda:0')
targets: tensor([2886], device='cuda:0') -- top_k_indices: tensor([ 2886, 28578,   353,  6662,  8991], device='cuda:0')
targets: tensor([181], device='cuda:0') -- top_k_indices: tensor([  181, 25987,  4511, 16038,  6242], device='cuda:0')
targets: tensor([  28,    6, 4397,    6,   31], device='cuda:0') -- top_k_indices: tensor([[  28,   34, 1279,   56,  580],
        [   6,   18,   42,   41,   59],
        [4397,  808,  839,  383,  402],
        [   6,   59,  129, 6214, 3911],
        [  31,   81

{'score': np.float64(0.7898269277084906)}

In [27]:
{'score': np.float64(0.7740060405655218)}
{'score': np.float64(0.7898269277084906)}

{'score': np.float64(0.7740060405655218)}