In [1]:
import os
import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, DataCollatorWithPadding
import pandas as pd
from sklearn.model_selection import train_test_split
import epitran
from functools import lru_cache
from difflib import SequenceMatcher

In [2]:
# Pipeline switch: when True the phonetic checkpoint is loaded and the dataset
# is transliterated to X-SAMPA; when False the bert-base-uncased baseline is used.
is_phonetic = False
# Epitran instance for the "eng-Latn" language code; xsampa_list delegates to it.
epi = epitran.Epitran("eng-Latn")

@lru_cache(maxsize=None)
def xsampa_list(word: str) -> list:
    """Return the X-SAMPA segments that epitran produces for ``word``.

    Results are memoised without a size bound, so repeated calls with the
    same word return the same cached list object; callers should treat the
    result as read-only.
    """
    segments = epi.xsampa_list(word)
    return segments


In [3]:
# Load the verse corpus; the recorded output shows a single "Verse" text column.
# NOTE(review): this is an absolute, user-specific path, so the notebook will not
# run on another machine as-is — consider a configurable data directory.
pd_dataset = pd.read_csv('/home/toure215/BERT_phonetic/DATASETS/verses/super_verses.csv')
pd_dataset.head()

Unnamed: 0,Verse
0,ah why this boding start this sudden pain
1,that wings my pulse and shoots from vein to vein
2,what mean regardless of yon midnight bell
3,these earthborn visions saddening oer my cell
4,what strange disorder prompts these thoughts t...


In [4]:
# Rename the text column to "input_text", the key the collator and the
# translate/tokenize helpers read.
pd_dataset = pd_dataset.rename({"Verse": "input_text"}, axis='columns')
pd_dataset.head()

Unnamed: 0,input_text
0,ah why this boding start this sudden pain
1,that wings my pulse and shoots from vein to vein
2,what mean regardless of yon midnight bell
3,these earthborn visions saddening oer my cell
4,what strange disorder prompts these thoughts t...


In [5]:
def get_last_word(verse: str) -> str:
    """Return the last whitespace-delimited word of ``verse``.

    Args:
        verse: A line of text.

    Returns:
        The final word, or an empty string when ``verse`` is empty or
        contains only whitespace (previously this raised ``IndexError``).
    """
    words = verse.split()
    return words[-1] if words else ''

def mask_last_word(verse: str, mask_token: str = '[MASK]') -> str:
    """Replace the last whitespace-delimited word of ``verse`` with a mask token.

    Words are re-joined with single spaces, so runs of whitespace in the
    input are collapsed.

    Args:
        verse: A line of text.
        mask_token: Token that replaces the last word. Defaults to the BERT
            mask token ``'[MASK]'``.

    Returns:
        The verse with its last word masked, or an empty string when the
        verse has no words (previously this raised ``IndexError``).
    """
    words = verse.split()
    if not words:
        # Nothing to mask in an empty or whitespace-only verse.
        return ''
    words[-1] = mask_token
    return ' '.join(words)

In [6]:
# Derive the prediction target: the final word of each verse.
pd_dataset['target_word'] = pd_dataset['input_text'].apply(get_last_word)
# Text-level masking is left disabled; the data collator inserts the mask
# token at batch time instead.
# pd_dataset['input_text'] = pd_dataset['input_text'].apply(mask_last_word)

In [7]:
pd_dataset.head()

Unnamed: 0,input_text,target_word
0,ah why this boding start this sudden pain,pain
1,that wings my pulse and shoots from vein to vein,vein
2,what mean regardless of yon midnight bell,bell
3,these earthborn visions saddening oer my cell,cell
4,what strange disorder prompts these thoughts t...,glow


In [8]:
# Hold out 10% of the rows for test, then 10% of the remainder for validation
# (roughly an 81/9/10 split). The fixed random_state makes the shuffle repeatable.
train, test = train_test_split(pd_dataset, test_size=0.1, random_state=42, shuffle=True)
train, val = train_test_split(train, test_size=0.1, random_state=42, shuffle=True)

In [9]:
# Convert each pandas split into a Hugging Face Dataset.
hf_train = Dataset.from_pandas(train)
hf_val = Dataset.from_pandas(val)
hf_test = Dataset.from_pandas(test)

In [10]:
# Bundle the splits and drop the '__index_level_0__' column (presumably the
# shuffled pandas index carried over by from_pandas), leaving only
# 'input_text' and 'target_word'.
hf_dataset = DatasetDict({"train": hf_train, "validation": hf_val, "test": hf_test}).remove_columns(['__index_level_0__'])
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 440950
    })
    validation: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 48995
    })
    test: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 54439
    })
})

In [11]:
# Candidate checkpoints: [0] plain-text baseline, [1] BPE phonetic,
# [2] word-level phonetic (not selected by the code below).
model_path = ['bert-base-uncased','psktoure/BERT_BPE_phonetic_wikitext-103-raw-v1','psktoure/BERT_WordLevel_phonetic_wikitext-103-raw-v1']

# Pick the checkpoint that matches the pipeline mode, then load the masked-LM
# model and its tokenizer from that same checkpoint.
checkpoint = model_path[1] if is_phonetic else model_path[0]
model = AutoModelForMaskedLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [12]:
def translate_sentence(sentence: str) -> str:
    """Transliterate every whitespace-delimited word of ``sentence`` to X-SAMPA.

    Each word's X-SAMPA segments are concatenated without a separator and the
    transliterated words are re-joined with single spaces.
    """
    return ' '.join(''.join(xsampa_list(token)) for token in sentence.split())

def translate_function(examples):
    """Batched ``Dataset.map`` callback that transliterates both text columns.

    Overwrites ``examples['input_text']`` with the X-SAMPA form of each verse
    and ``examples['target_word']`` with the X-SAMPA form of each target word,
    then returns the same ``examples`` object.
    """
    verses = examples['input_text']
    targets = examples['target_word']
    examples['input_text'] = list(map(translate_sentence, verses))
    examples['target_word'] = [''.join(xsampa_list(target)) for target in targets]
    return examples
    

def tokenize_function(examples):
    """Batched ``Dataset.map`` callback that tokenizes verses and target words.

    Verses are padded/truncated to 50 tokens and target words to 5 tokens;
    the target-word ids are stored under ``'labels'`` of the returned encoding.
    """
    encoded = tokenizer(
        examples['input_text'],
        padding='max_length',
        truncation=True,
        max_length=50,
    )
    target_encoding = tokenizer(
        examples['target_word'],
        padding='max_length',
        truncation=True,
        max_length=5,
    )
    # NOTE(review): labels are 5 ids long while the inputs are 50 ids long, so
    # they are not position-aligned with the inputs.
    encoded['labels'] = target_encoding['input_ids']
    return encoded

In [13]:

# Only the phonetic pipeline needs the text columns transliterated to X-SAMPA.
if is_phonetic:
    hf_dataset = hf_dataset.map(translate_function, batched=True, num_proc=15)

In [14]:
# tokenized_dataset = hf_dataset.map(tokenize_function, batched=True, num_proc=15, remove_columns=['input_text', 'target_word'])

In [15]:
# tokenized_dataset

In [None]:
import torch   
from transformers import PreTrainedTokenizerBase

class CustomDataCollator:
    """Collate raw verses into masked-LM batches that mask the verse ending.

    Each example is a mapping with an ``input_text`` string. The batch is
    tokenized on the fly, the token immediately before the first ``[SEP]`` is
    replaced by the mask token, and ``labels`` holds the original id at that
    position and ``-100`` everywhere else.

    NOTE(review): exactly one token is masked per row. When the last word is
    split into several wordpieces, the earlier pieces stay visible and only
    the final piece becomes the target. When a verse is truncated to
    ``max_length``, the masked token is the last surviving token rather than
    the verse's last word.
    """

    # The annotation is quoted so the class can be defined even when
    # PreTrainedTokenizerBase has not been imported yet.
    def __init__(self, tokenizer: "PreTrainedTokenizerBase", padding=True, max_length=50):
        """Store the tokenizer and the padding/truncation settings.

        Args:
            tokenizer: Callable tokenizer exposing ``mask_token_id`` and
                ``sep_token_id``.
            padding: Forwarded to the tokenizer call as ``padding``.
            max_length: Forwarded to the tokenizer call as ``max_length``.

        Raises:
            ValueError: If the tokenizer defines no mask token.
        """
        if tokenizer.mask_token_id is None:
            raise ValueError("tokenizer has no mask token; cannot build masked-LM batches")
        self.tokenizer = tokenizer
        self.mask_token_id = tokenizer.mask_token_id
        self.padding = padding
        self.max_length = max_length

    def __call__(self, examples):
        """Tokenize ``examples`` and mask the token before the first ``[SEP]``.

        Args:
            examples: Sequence of mappings, each with an ``'input_text'`` key.

        Returns:
            Dict with ``input_ids`` (masked), ``attention_mask`` and
            ``labels`` tensors of identical shape.

        Raises:
            ValueError: If a row has no ``[SEP]`` token, or has no content
                token before ``[SEP]`` (for example an empty verse).
        """
        input_texts = [example['input_text'] for example in examples]

        batch = self.tokenizer(
            input_texts,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        # Start fully ignored; only the masked position receives a real label.
        labels = torch.full_like(input_ids, -100)

        sep_token_id = self.tokenizer.sep_token_id
        cls_token_id = getattr(self.tokenizer, "cls_token_id", None)

        for row, ids in enumerate(input_ids):
            sep_positions = torch.where(ids == sep_token_id)[0]
            if len(sep_positions) == 0:
                raise ValueError(f"[SEP] token not found in input_ids: {ids}")

            target_pos = int(sep_positions[0]) - 1
            # Reject rows without a content token: index -1 would wrap to the
            # last column, and masking [CLS] would train on a meaningless target.
            if target_pos < 0 or (cls_token_id is not None and int(ids[target_pos]) == cls_token_id):
                raise ValueError(f"no token to mask before [SEP] in input_ids: {ids}")

            # Record the label before overwriting the input (ids is a view of input_ids).
            labels[row, target_pos] = ids[target_pos]
            input_ids[row, target_pos] = self.mask_token_id

        return {
            "input_ids": input_ids,
            "attention_mask": batch["attention_mask"],
            "labels": labels,
        }


In [17]:
# Random 15% masking is left disabled; the custom collator masks the token
# before [SEP] of every verse instead.
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
data_collator = CustomDataCollator(tokenizer)

In [18]:
# Sanity check: run the collator on the first two training rows and print the result.
sample = hf_dataset['train'][:2]
# Slicing a Dataset yields a dict of columns; rebuild the per-example dicts the collator expects.
sample_list = [{key: sample[key][i] for key in sample} for i in range(len(sample['input_text']))]
print("mask_token_id: ", tokenizer.mask_token_id, "\nmask_token:", tokenizer.mask_token)
print(sample_list)
c = data_collator(sample_list)
for key in c:
    print(key ,":", c[key])
# NOTE(review): 12342 is an ad-hoc vocabulary id (it decodes to "shine" in the
# recorded output) and does not appear in the sample batch above.
print(tokenizer.decode(12342))

mask_token_id:  103 
mask_token: [MASK]
[{'input_text': 'wide wave the flaming sword and send o send', 'target_word': 'send'}, {'input_text': 'the marble leaps or shrinks or burns', 'target_word': 'burns'}]
input_ids : tensor([[  101,  2898,  4400,  1996, 19091,  4690,  1998,  4604,  1051,   103,
           102],
        [  101,  1996,  7720, 29195,  2030, 22802,  2015,  2030,   103,   102,
             0]])
attention_mask : tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]])
labels : tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, 4604, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, 7641, -100, -100]])
shine


In [19]:
def rhyme_score(word1: str, word2: str, suffix_len: int = 3) -> float:
    """Score how closely the endings of two words match, in ``[0.0, 1.0]``.

    In the non-phonetic pipeline both words are first transliterated with
    ``xsampa_list``; in the phonetic pipeline they are compared as given.
    The last ``suffix_len`` units of each sequence (fewer if either sequence
    is shorter) are compared with ``difflib.SequenceMatcher``.

    Args:
        word1: First word.
        word2: Second word.
        suffix_len: Maximum number of trailing units to compare. Defaults to 3.

    Returns:
        The ``SequenceMatcher`` ratio of the two endings, or ``0.0`` when
        either sequence is empty or ``suffix_len`` is not positive.
    """
    if not is_phonetic:
        end1 = xsampa_list(word1)
        end2 = xsampa_list(word2)
    else:
        end1 = word1
        end2 = word2
    length = min(len(end1), len(end2), suffix_len)
    if length <= 0:
        # Slicing with [-0:] keeps the whole sequence, and SequenceMatcher
        # rates two empty sequences as a perfect 1.0 match; neither is a rhyme.
        return 0.0
    return SequenceMatcher(None, end1[-length:], end2[-length:]).ratio()

def compute_metrics(pred):
    """Compute the mean rhyme score over the scored (masked) positions.

    Only positions whose label is not ``-100`` are evaluated: the predicted
    token at each such position is decoded and compared with the decoded
    label token through ``rhyme_score``. If a row has several scored
    positions, each one is scored independently.

    Args:
        pred: Object with ``predictions`` (logits shaped
            ``(batch, seq_len, vocab)``, or a tuple whose first element is
            the logits) and ``label_ids`` shaped ``(batch, seq_len)``.
            Array-likes are converted with ``np.asarray``.

    Returns:
        ``{"rhyme_score": mean}``; the mean is ``0.0`` when nothing is scored.
    """
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    # Predictions are converted to NumPy, so use NumPy's argmax rather than torch.argmax.
    predicted_ids = np.asarray(logits).argmax(axis=-1)
    label_ids = np.asarray(pred.label_ids)

    # -100 marks ignored positions and is not a decodable token id.
    scored = label_ids != -100
    predicted_tokens = predicted_ids[scored].reshape(-1, 1).tolist()
    target_tokens = label_ids[scored].reshape(-1, 1).tolist()

    decoded_preds = tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(target_tokens, skip_special_tokens=True)

    scores = [
        rhyme_score(pred_word, target_word)
        for pred_word, target_word in zip(decoded_preds, decoded_labels)
    ]

    mean_rhyme_score = float(np.mean(scores)) if scores else 0.0
    return {"rhyme_score": mean_rhyme_score}



In [20]:
# Training configuration: 15 epochs, evaluation after every epoch, no
# checkpointing, no logging, no external reporting.
# remove_unused_columns=False keeps the raw 'input_text'/'target_word' columns
# in the dataset, which the custom collator reads at batch time.
# NOTE(review): the recorded run shows eval_loss rising after the first epoch
# (5.19 at epoch 1 to 7.53 at epoch 15) while the reported train_loss is 1.73.
# With save_strategy='no' the final-epoch weights are the ones kept; consider
# fewer epochs or keeping the best checkpoint.
training_args = TrainingArguments(
    output_dir='/tmp/verses',
    num_train_epochs=15,
    learning_rate=1e-4,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    eval_strategy='epoch',
    logging_strategy='no',
    save_strategy='no',
    report_to='none',
    remove_unused_columns=False,
)

# NOTE(review): compute_metrics is not passed here, so evaluation reports only
# the loss and the rhyme metric is never computed during training.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=hf_dataset['train'],
    eval_dataset=hf_dataset['validation'],
    tokenizer=tokenizer,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [21]:
# Fine-tune with the configuration above; evaluation runs once per epoch.
trainer.train()

  0%|          | 0/25845 [00:00<?, ?it/s]

  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 5.191527366638184, 'eval_runtime': 18.5094, 'eval_samples_per_second': 2647.038, 'eval_steps_per_second': 10.373, 'epoch': 1.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 5.19736909866333, 'eval_runtime': 18.6659, 'eval_samples_per_second': 2624.844, 'eval_steps_per_second': 10.286, 'epoch': 2.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 5.389641761779785, 'eval_runtime': 18.7071, 'eval_samples_per_second': 2619.059, 'eval_steps_per_second': 10.263, 'epoch': 3.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 5.762287616729736, 'eval_runtime': 18.762, 'eval_samples_per_second': 2611.394, 'eval_steps_per_second': 10.233, 'epoch': 4.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 6.129112720489502, 'eval_runtime': 18.7715, 'eval_samples_per_second': 2610.072, 'eval_steps_per_second': 10.228, 'epoch': 5.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 6.395214557647705, 'eval_runtime': 18.8069, 'eval_samples_per_second': 2605.161, 'eval_steps_per_second': 10.209, 'epoch': 6.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 6.68097448348999, 'eval_runtime': 18.7488, 'eval_samples_per_second': 2613.237, 'eval_steps_per_second': 10.241, 'epoch': 7.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 6.88824987411499, 'eval_runtime': 18.8287, 'eval_samples_per_second': 2602.145, 'eval_steps_per_second': 10.197, 'epoch': 8.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.077701568603516, 'eval_runtime': 18.8182, 'eval_samples_per_second': 2603.599, 'eval_steps_per_second': 10.203, 'epoch': 9.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.221578121185303, 'eval_runtime': 18.8363, 'eval_samples_per_second': 2601.101, 'eval_steps_per_second': 10.193, 'epoch': 10.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.334627151489258, 'eval_runtime': 18.8381, 'eval_samples_per_second': 2600.848, 'eval_steps_per_second': 10.192, 'epoch': 11.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.419206619262695, 'eval_runtime': 18.8577, 'eval_samples_per_second': 2598.146, 'eval_steps_per_second': 10.182, 'epoch': 12.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.486563205718994, 'eval_runtime': 19.0466, 'eval_samples_per_second': 2572.373, 'eval_steps_per_second': 10.081, 'epoch': 13.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.515157699584961, 'eval_runtime': 18.8864, 'eval_samples_per_second': 2594.191, 'eval_steps_per_second': 10.166, 'epoch': 14.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 7.534763813018799, 'eval_runtime': 18.8661, 'eval_samples_per_second': 2596.986, 'eval_steps_per_second': 10.177, 'epoch': 15.0}
{'train_runtime': 7267.347, 'train_samples_per_second': 910.133, 'train_steps_per_second': 3.556, 'train_loss': 1.7260270059489262, 'epoch': 15.0}


TrainOutput(global_step=25845, training_loss=1.7260270059489262, metrics={'train_runtime': 7267.347, 'train_samples_per_second': 910.133, 'train_steps_per_second': 3.556, 'total_flos': 7.01194194977256e+16, 'train_loss': 1.7260270059489262, 'epoch': 15.0})

In [22]:
def evaluate_rhyme(model, dataset, tokenizer, batch_size: int = 8, device=None, num_preview: int = 16):
    """Compute the mean rhyme score between predicted and true masked tokens.

    The dataset is processed in batches through the module-level
    ``data_collator``, which is expected to place exactly one mask token in
    each row. For every row, the highest-scoring token at the masked position
    is decoded and compared with the decoded label via ``rhyme_score``.

    Args:
        model: Masked-LM model; it is moved to ``device`` and put in eval mode.
        dataset: Dataset whose slices are dicts of columns including
            ``'input_text'``.
        tokenizer: Tokenizer used to find the mask token and decode ids.
        batch_size: Number of examples per forward pass. Defaults to 8.
        device: Target device. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"`` (previously ``"cuda"`` was hard-coded).
        num_preview: Number of leading examples whose prediction/target pair
            is printed. Defaults to 16.

    Returns:
        ``{"score": mean_rhyme_score}``; the mean is ``0.0`` for an empty
        dataset.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    rhyme_scores = []
    total = len(dataset)

    for start in range(0, total, batch_size):
        print(f"Processing example {start}/{total} ...", end="\r")
        batch = dataset[start : start + batch_size]
        num_rows = len(batch["input_text"])
        # Slicing yields a dict of columns; rebuild the per-example dicts the collator expects.
        batch_sequence = [{key: batch[key][j] for key in batch} for j in range(num_rows)]
        inputs = data_collator(batch_sequence)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = model(**inputs).logits

        for j in range(num_rows):
            # .item() requires exactly one mask token in the row.
            masked_token_index = torch.where(inputs["input_ids"][j] == tokenizer.mask_token_id)[0].item()
            predicted_index = logits[j, masked_token_index].argmax(-1).item()
            predicted_word = tokenizer.decode(predicted_index)
            target = tokenizer.decode(inputs["labels"][j, masked_token_index])
            if start + j < num_preview:
                print('predicted_word:', predicted_word, '-- target_word:', target)
            rhyme_scores.append(rhyme_score(predicted_word, target))

    if not rhyme_scores:
        # np.mean([]) would return nan and emit a RuntimeWarning.
        return {"score": 0.0}
    return {"score": np.mean(rhyme_scores)}


In [23]:
# Inspect the tokenizer's mask token and the encoding of the first test verse.
print(tokenizer.mask_token_id)
print(tokenizer.mask_token)
print(hf_dataset['test'][0])
print(tokenizer(hf_dataset['test'][0]['input_text']))
# In the recorded output, ids 103 and 102 decode to [MASK] and [SEP].
print(tokenizer.decode(103))
print(tokenizer.decode(102))

103
[MASK]
{'input_text': 'is he whose visage in the lazy mist', 'target_word': 'mist'}
{'input_ids': [101, 2003, 2002, 3005, 9425, 3351, 1999, 1996, 13971, 11094, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
[MASK]
[SEP]


In [24]:
# Score the fine-tuned model on the held-out test split.
evaluate_rhyme(model, hf_dataset['test'], tokenizer)


predicted_word: waves -- target_word: mist
predicted_word: ##od -- target_word: ##od
predicted_word: ridge -- target_word: wall
predicted_word: eyes -- target_word: hands
predicted_word: ##s -- target_word: ##shire
predicted_word: cold -- target_word: vase
predicted_word: to -- target_word: over
predicted_word: curled -- target_word: wife
predicted_word: iii -- target_word: she
predicted_word: wood -- target_word: plant
predicted_word: laid -- target_word: fly
predicted_word: walls -- target_word: mayhem
predicted_word: to -- target_word: down
predicted_word: girl -- target_word: trick
predicted_word: lend -- target_word: there
predicted_word: ##pers -- target_word: ##pers
Processing example 54432/54439 ...

{'score': np.float64(0.31773789623860343)}

In [25]:
# del model

In [26]:
# import gc
# import ctypes
# import torch
# gc.collect()
# libc = ctypes.CDLL("libc.so.6") # clearing cache
# libc.malloc_trim(0)
# torch.cuda.empty_cache()