In [1]:
import os
import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, DataCollatorWithPadding
import pandas as pd
from sklearn.model_selection import train_test_split
import epitran
from functools import lru_cache
from difflib import SequenceMatcher

In [21]:
is_phonetic = False
epi = epitran.Epitran("eng-Latn")

@lru_cache(maxsize=None)
def xsampa_list(word: str) -> list:
    return epi.xsampa_list(word)

In [22]:
pd_dataset = pd.read_csv('/home/toure215/BERT_phonetic/DATASETS/verses/verse_dataset.csv')
pd_dataset.head()

Unnamed: 0,id,sentence1,sentence2,label
0,0,ah why this boding start this sudden pain,that wings my pulse and shoots from vein to vein,1
1,1,ah why this boding start this sudden pain,those parts of thee that the worlds eye doth view,0
2,2,what mean regardless of yon midnight bell,these earthborn visions saddening o'er my cell,1
3,3,what mean regardless of yon midnight bell,to save their matrons from the brutal rape,0
4,4,what strange disorder prompts these thoughts t...,these sighs to murmur and these tears to flow,1


In [23]:
pd_dataset = pd_dataset.loc[pd_dataset["label"] == 1].drop(columns=["label", "id"])
pd_dataset.head()

Unnamed: 0,sentence1,sentence2
0,ah why this boding start this sudden pain,that wings my pulse and shoots from vein to vein
2,what mean regardless of yon midnight bell,these earthborn visions saddening o'er my cell
4,what strange disorder prompts these thoughts t...,these sighs to murmur and these tears to flow
6,'tis she 'tis eloisa's form restor'd,strike the soft sweet harmonic chord
8,she comes in all her killing charms confest,glares thro' the gloom and pours upon my breast


In [24]:
def get_last_word(text):
    return text.split()[-1]

pd_dataset["label"] = pd_dataset["sentence1"].apply(get_last_word)
pd_dataset.head()

Unnamed: 0,sentence1,sentence2,label
0,ah why this boding start this sudden pain,that wings my pulse and shoots from vein to vein,pain
2,what mean regardless of yon midnight bell,these earthborn visions saddening o'er my cell,bell
4,what strange disorder prompts these thoughts t...,these sighs to murmur and these tears to flow,glow
6,'tis she 'tis eloisa's form restor'd,strike the soft sweet harmonic chord,restor'd
8,she comes in all her killing charms confest,glares thro' the gloom and pours upon my breast,confest


In [25]:
train, test = train_test_split(pd_dataset, test_size=0.1, random_state=42, shuffle=True)
train, val = train_test_split(train, test_size=0.1, random_state=42, shuffle=True)

In [26]:
train_dataset = Dataset.from_pandas(train)
val_dataset = Dataset.from_pandas(val)
test_dataset = Dataset.from_pandas(test)

In [27]:
dataset = DatasetDict({
    "train": train_dataset,
    "validation": val_dataset,
    "test": test_dataset
})
dataset = dataset.remove_columns(column_names=['__index_level_0__'])
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 80595
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 8955
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 9951
    })
})

In [9]:
model_path = ['bert-base-uncased','psktoure/BERT_BPE_phonetic_wikitext-103-raw-v1','psktoure/BERT_WordLevel_phonetic_wikitext-103-raw-v1']

if is_phonetic:
    model = AutoModelForMaskedLM.from_pretrained(model_path[1])
    tokenizer = AutoTokenizer.from_pretrained(model_path[1])
else:
    model = AutoModelForMaskedLM.from_pretrained(model_path[0])
    tokenizer = AutoTokenizer.from_pretrained(model_path[0])

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [10]:
def translate_sentence(sentence: str) -> str:
    words = sentence.split()
    for i in range(len(words)):
        words[i] = ''.join(xsampa_list(words[i]))
    return ' '.join(words)

def translate_function(examples):
    examples['sentence1'] = [translate_sentence(sentence) for sentence in examples['sentence1']]
    examples['sentence2'] = [translate_sentence(sentence) for sentence in examples['sentence2']]
    examples['label'] = [''.join(xsampa_list(word)) for word in examples['label']]
    return examples

In [11]:
if is_phonetic:
    dataset = dataset.map(translate_function, batched=True, num_proc=15)

Map (num_proc=15):   0%|          | 0/80595 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/8955 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/9951 [00:00<?, ? examples/s]

In [28]:
dataset.save_to_disk('/home/toure215/BERT_phonetic/DATASETS/verses/rhyming_verses')

Saving the dataset (0/1 shards):   0%|          | 0/80595 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/8955 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/9951 [00:00<?, ? examples/s]

In [12]:
import torch   
from transformers import PreTrainedTokenizerBase

class CustomDataCollator:
    def __init__(self, tokenizer: PreTrainedTokenizerBase, padding=True, max_length=50):
        self.tokenizer = tokenizer
        self.mask_token_id = tokenizer.mask_token_id
        self.padding = padding
        self.max_length = max_length

    def __call__(self, examples):
       
        sentence1 = [example["sentence1"] for example in examples]
        sentence2 = [example["sentence2"] for example in examples]
        targets = [example["label"] for example in examples]

        encoded_targets = self.tokenizer(
            targets,
            add_special_tokens=False,
        )["input_ids"]
            
        batch = self.tokenizer(
            sentence1,
            sentence2,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        for i, idx in enumerate(input_ids):
            sep_token_indices = torch.where(idx == self.tokenizer.sep_token_id)[0]
            start = sep_token_indices[0] - len(encoded_targets[i])
            end = sep_token_indices[0]
            input_ids[i, start:end] = self.mask_token_id 
            labels[i, :start] = -100 
            labels[i, end:] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": batch["attention_mask"],
            "labels": labels,
        }


In [13]:
sample = dataset['train'][:1]
for key, value in sample.items():
    print(f'{key}: {value}')

sentence1: ['Its nejtIv fIr\\sn@s stIl D@ fejs r\\Itejnd']
sentence2: ['slip kAnS@ns slip itS Af@l TOt bi dr\\awnd']
label: ['r\\Itejnd']


In [14]:
data_collator = CustomDataCollator(tokenizer)
sample_list = [{key: sample[key][i] for key in sample} for i in range(len(sample['sentence1']))]
c = data_collator(sample_list)
for key in c:
    print(key ,":", c[key])
print(tokenizer.encode("retain'd", add_special_tokens=False))

input_ids : tensor([[   1,  170, 1749,  985,   18, 1416,    6,   35,  733,    8,    6, 1113,
            4,    4,    4,    2, 3720, 2666,    6,   87, 3720,  330, 1309,    6,
           29, 1142,  109,   96,   18,  165,    2]])
attention_mask : tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1]])
labels : tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
           34,   18, 2777, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100]])
[34, 22, 36, 19, 142, 0, 21]


In [15]:
training_args = TrainingArguments(
    output_dir="/tmp/fine_tuned_bert",
    eval_strategy="epoch",
    save_strategy="no",
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_strategy='no',
    remove_unused_columns=False,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    data_collator=data_collator,
)

In [16]:
trainer.train()

  0%|          | 0/3780 [00:00<?, ?it/s]

  0%|          | 0/140 [00:00<?, ?it/s]

{'eval_loss': 1.4906039237976074, 'eval_runtime': 7.5937, 'eval_samples_per_second': 1179.273, 'eval_steps_per_second': 18.436, 'epoch': 1.0}


  0%|          | 0/140 [00:00<?, ?it/s]

{'eval_loss': 1.3292596340179443, 'eval_runtime': 7.6256, 'eval_samples_per_second': 1174.327, 'eval_steps_per_second': 18.359, 'epoch': 2.0}


  0%|          | 0/140 [00:00<?, ?it/s]

{'eval_loss': 1.2923369407653809, 'eval_runtime': 7.5686, 'eval_samples_per_second': 1183.183, 'eval_steps_per_second': 18.498, 'epoch': 3.0}
{'train_runtime': 611.9658, 'train_samples_per_second': 395.096, 'train_steps_per_second': 6.177, 'train_loss': 1.387801494295635, 'epoch': 3.0}


TrainOutput(global_step=3780, training_loss=1.387801494295635, metrics={'train_runtime': 611.9658, 'train_samples_per_second': 395.096, 'train_steps_per_second': 6.177, 'total_flos': 5220059314830300.0, 'train_loss': 1.387801494295635, 'epoch': 3.0})

In [17]:
def rhyme_score(word1: str, word2: str) -> int:
    if not is_phonetic:    
        end1 = xsampa_list(word1)
        end2 = xsampa_list(word2)
    else:
        end1 = word1
        end2 = word2
    length = min(len(end1), len(end2), 3)
    end1 = end1[-length:]
    end2 = end2[-length:]
    return SequenceMatcher(None, end1, end2).ratio()

In [18]:
def evaluate_rhyme(model, dataset, tokenizer):
    model = model.to("cuda")
    model.eval()
    rhyme_scores = []
    batch_size = 256

    for i in range(0, len(dataset), batch_size):
        print(f"Processing example {i}/{len(dataset)} ...", end="\r")
        batch = dataset[i : i + batch_size]
        batch_sequence = [{key: batch[key][j] for key in batch} for j in range(len(batch["sentence1"]))]
        inputs = data_collator(batch_sequence)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        for j in range(len(batch["sentence1"])):
            masked_token_index = torch.where(inputs["input_ids"][j] == tokenizer.mask_token_id)[0]
            predicted_index = logits[j, masked_token_index].argmax(-1)
            predicted_word = tokenizer.decode(predicted_index)
            target = tokenizer.decode(inputs["labels"][j, masked_token_index])
            if i < 16 and j < 8:
                print('predicted_word:', predicted_word, '-- target_word:', target)
            rhyme_scores.append(rhyme_score(predicted_word, target))

    return {"score": np.mean(rhyme_scores)}

In [19]:
evaluate_rhyme(model, dataset['test'], tokenizer)

predicted_word: Ol -- target_word: Ol
predicted_word: hIr \ -- target_word: hIr \
predicted_word: mjuz -- target_word: nuz
predicted_word: pejnz -- target_word: pejnz
predicted_word: pow @ tr \ i -- target_word: tSIm @ str \ i
predicted_word: tSejndZ -- target_word: tSejndZ
predicted_word: juT -- target_word: juT
predicted_word: vd -- target_word: bIlivd
Processing example 9728/9951 ...

{'score': np.float64(0.9285163970120255)}