In [1]:
import os
import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, DataCollatorWithPadding
import pandas as pd
from sklearn.model_selection import train_test_split
import epitran
from functools import lru_cache
from difflib import SequenceMatcher

In [2]:
is_phonetic = True
epi = epitran.Epitran("eng-Latn")

@lru_cache(maxsize=None)
def xsampa_list(word: str) -> list:
    return epi.xsampa_list(word)


In [3]:
pd_dataset = pd.read_csv('/home/toure215/BERT_phonetic/DATASETS/verses/super_verses.csv')
pd_dataset.head()

Unnamed: 0,Verse
0,ah why this boding start this sudden pain
1,that wings my pulse and shoots from vein to vein
2,what mean regardless of yon midnight bell
3,these earthborn visions saddening oer my cell
4,what strange disorder prompts these thoughts t...


In [4]:
pd_dataset = pd_dataset.rename({"Verse": "input_text"}, axis='columns')
pd_dataset.head()

Unnamed: 0,input_text
0,ah why this boding start this sudden pain
1,that wings my pulse and shoots from vein to vein
2,what mean regardless of yon midnight bell
3,these earthborn visions saddening oer my cell
4,what strange disorder prompts these thoughts t...


In [5]:
def get_last_word(verse: str) -> str:
    return verse.split()[-1]

def mask_last_word(verse: str) -> str:
    words = verse.split()
    words[-1] = '[MASK]'
    return ' '.join(words)

In [6]:
pd_dataset['target_word'] = pd_dataset['input_text'].apply(get_last_word)
# pd_dataset['input_text'] = pd_dataset['input_text'].apply(mask_last_word)

In [7]:
pd_dataset.head()

Unnamed: 0,input_text,target_word
0,ah why this boding start this sudden pain,pain
1,that wings my pulse and shoots from vein to vein,vein
2,what mean regardless of yon midnight bell,bell
3,these earthborn visions saddening oer my cell,cell
4,what strange disorder prompts these thoughts t...,glow


In [8]:
train, test = train_test_split(pd_dataset, test_size=0.1, random_state=42, shuffle=True)
train, val = train_test_split(train, test_size=0.1, random_state=42, shuffle=True)

In [9]:
hf_train = Dataset.from_pandas(train)
hf_val = Dataset.from_pandas(val)
hf_test = Dataset.from_pandas(test)

In [10]:
hf_dataset = DatasetDict({"train": hf_train, "validation": hf_val, "test": hf_test}).remove_columns(['__index_level_0__'])
hf_dataset

DatasetDict({
    train: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 440950
    })
    validation: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 48995
    })
    test: Dataset({
        features: ['input_text', 'target_word'],
        num_rows: 54439
    })
})

In [11]:
model_path = [
    "bert-base-uncased",
    "psktoure/BERT_BPE_phonetic_wikitext-103-raw-v1",
    "psktoure/BERT_WordLevel_phonetic_wikitext-103-raw-v1",
]

if is_phonetic:
    model = AutoModelForMaskedLM.from_pretrained(model_path[-1])
    tokenizer = AutoTokenizer.from_pretrained(model_path[-1])
else:
    model = AutoModelForMaskedLM.from_pretrained(model_path[0])
    tokenizer = AutoTokenizer.from_pretrained(model_path[0])

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [12]:
def translate_sentence(sentence: str) -> str:
    words = sentence.split()
    for i in range(len(words)):
        words[i] = ' '.join(xsampa_list(words[i]))
    return ' '.join(words)

def translate_function(examples):
    examples['input_text'] = [translate_sentence(sentence) for sentence in examples['input_text']]
    examples['target_word'] = [' '.join(xsampa_list(word)) for word in examples['target_word']]
    return examples
    

def tokenize_function(examples):
    inputs = tokenizer(examples['input_text'], padding='max_length', truncation=True, max_length=50)
    targets = tokenizer(examples['target_word'], padding='max_length', truncation=True, max_length=5)
    inputs['labels'] = targets['input_ids']
    return inputs

In [13]:

if is_phonetic:
    hf_dataset = hf_dataset.map(translate_function, batched=True, num_proc=15)

Map (num_proc=15):   0%|          | 0/440950 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/48995 [00:00<?, ? examples/s]

Map (num_proc=15):   0%|          | 0/54439 [00:00<?, ? examples/s]

In [14]:
# tokenized_dataset = hf_dataset.map(tokenize_function, batched=True, num_proc=15, remove_columns=['input_text', 'target_word'])

In [15]:
# tokenized_dataset

In [16]:
import torch   
from transformers import PreTrainedTokenizerBase

class CustomDataCollator:
    def __init__(
        self, tokenizer: PreTrainedTokenizerBase, padding=True, max_length=128
    ):
        self.tokenizer = tokenizer
        self.mask_token_id = tokenizer.mask_token_id
        self.padding = padding
        self.max_length = max_length

    def __call__(self, examples):

        input_text = [example["input_text"] for example in examples]
        targets = [example["target_word"] for example in examples]

        encoded_targets = self.tokenizer(
            targets,
            add_special_tokens=False,
        )["input_ids"]

        batch = self.tokenizer(
            input_text,
            padding=self.padding,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = batch["input_ids"]
        labels = input_ids.clone()

        for i, idx in enumerate(input_ids):
            sep_token_indices = torch.where(idx == self.tokenizer.sep_token_id)[0]
            start = sep_token_indices[0] - len(encoded_targets[i])
            end = sep_token_indices[0]
            input_ids[i, start:end] = self.mask_token_id
            labels[i, :start] = -100
            labels[i, end:] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": batch["attention_mask"],
            "labels": labels,
        }


In [17]:
# data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
data_collator = CustomDataCollator(tokenizer)

In [18]:
sample = hf_dataset['train'][:2]
sample_list = [{key: sample[key][i] for key in sample} for i in range(len(sample['input_text']))]
print("mask_token_id: ", tokenizer.mask_token_id, "\nmask_token:", tokenizer.mask_token)
print(sample_list)
c = data_collator(sample_list)
for key in c:
    print(key ,":", c[key])
print(tokenizer.decode(0))

mask_token_id:  4 
mask_token: [MASK]
[{'input_text': 'w a j d w e j v D @ f l e j m I N s O r\\ d { n d s E n d o w s E n d', 'target_word': 's E n d'}, {'input_text': 'D @ m A r\\ b @ l l i p s O r\\ S r\\ I N k s O r\\ b r\\= n z', 'target_word': 'b r\\= n z'}]
input_ids : tensor([[ 1, 16, 30, 14, 12, 16, 31, 14, 26, 23,  5, 29, 13, 31, 14, 19,  9, 36,
         10, 32,  6, 11, 12, 17,  7, 12, 10, 20,  7, 12, 33, 16,  4,  4,  4,  4,
          2],
        [ 1, 23,  5, 19, 25,  6, 11, 28,  5, 13, 13, 18, 24, 10, 32,  6, 11, 37,
          6, 11,  9, 36, 15, 10, 32,  6, 11,  4,  4,  4,  4,  4,  2,  3,  3,  3,
          3]])
attention_mask : tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
labels : tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
  

In [19]:
def rhyme_score(word1: str, word2: str) -> int:
    if not is_phonetic:    
        end1 = xsampa_list(word1)
        end2 = xsampa_list(word2)
    else:
        end1 = word1
        end2 = word2
    length = min(len(end1), len(end2), 3)
    end1 = end1[-length:]
    end2 = end2[-length:]
    return SequenceMatcher(None, end1, end2).ratio()

def compute_metrics(pred):
    preds = torch.argmax(pred.predictions, dim=-1)
    batch_size = 32  
    total_rhyme_score = 0
    count = 0

    for i in range(0, len(preds), batch_size):
        batch_preds = preds[i:i + batch_size]
        batch_labels = pred.label_ids[i:i + batch_size]

        # Decode predictions and labels in batches
        decoded_preds = tokenizer.batch_decode(batch_preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(batch_labels, skip_special_tokens=True)

        # Compute rhyme scores for the batch
        batch_rhyme_scores = [
            rhyme_score(pred_word, target_word)
            for pred_word, target_word in zip(decoded_preds, decoded_labels)
        ]
        
        total_rhyme_score += sum(batch_rhyme_scores)
        count += len(batch_rhyme_scores)

    # Compute the mean rhyme score
    mean_rhyme_score = total_rhyme_score / count if count > 0 else 0
    return {"rhyme_score": mean_rhyme_score}



In [20]:
training_args = TrainingArguments(
    output_dir='/tmp/verses',
    num_train_epochs=3,
    learning_rate=1e-4,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=256,
    eval_strategy='epoch',
    logging_strategy='no',
    save_strategy='no',
    report_to='none',
    remove_unused_columns=False,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=hf_dataset['train'],
    eval_dataset=hf_dataset['validation'],
    tokenizer=tokenizer,
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [21]:
trainer.train()

  0%|          | 0/5169 [00:00<?, ?it/s]

  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 2.8288426399230957, 'eval_runtime': 10.4753, 'eval_samples_per_second': 4677.186, 'eval_steps_per_second': 18.329, 'epoch': 1.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 2.7825145721435547, 'eval_runtime': 10.7593, 'eval_samples_per_second': 4553.723, 'eval_steps_per_second': 17.845, 'epoch': 2.0}


  0%|          | 0/192 [00:00<?, ?it/s]

{'eval_loss': 2.764601230621338, 'eval_runtime': 10.6198, 'eval_samples_per_second': 4613.566, 'eval_steps_per_second': 18.079, 'epoch': 3.0}
{'train_runtime': 698.4524, 'train_samples_per_second': 1893.973, 'train_steps_per_second': 7.401, 'train_loss': 2.813008024097988, 'epoch': 3.0}


TrainOutput(global_step=5169, training_loss=2.813008024097988, metrics={'train_runtime': 698.4524, 'train_samples_per_second': 1893.973, 'train_steps_per_second': 7.401, 'total_flos': 4.457153434383468e+16, 'train_loss': 2.813008024097988, 'epoch': 3.0})

In [22]:
def evaluate_rhyme(model, dataset, tokenizer):
    model = model.to("cuda")
    model.eval()
    rhyme_scores = []
    batch_size = 256

    for i in range(0, len(dataset), batch_size):
        # if i > 16:
        #     break
        print(f"Processing example {i}/{len(dataset)} ...", end="\r")
        batch = dataset[i : i + batch_size]
        batch_sequence = [{key: batch[key][j] for key in batch} for j in range(len(batch["input_text"]))]
        inputs = data_collator(batch_sequence)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        for j in range(len(batch["input_text"])):
            masked_token_index = torch.where(inputs["input_ids"][j] == tokenizer.mask_token_id)[0]
            predicted_index = logits[j, masked_token_index].argmax(-1)
            predicted_word = tokenizer.decode(predicted_index)
            target = tokenizer.decode(inputs["labels"][j, masked_token_index])
            if i < 16 and j < 8:
                print('predicted_word:', predicted_word, '-- target_word:', target)
            rhyme_scores.append(rhyme_score(predicted_word, target))

    return {"score": np.mean(rhyme_scores)}


In [23]:
print(tokenizer.mask_token_id)
print(tokenizer.mask_token)
print(hf_dataset['test'][0])
print(tokenizer(hf_dataset['test'][0]['input_text']))
print(tokenizer.decode(103))
print(tokenizer.decode(102))

4
[MASK]
{'input_text': 'I z h i h u z v I z @ dZ I n D @ l e j z i m I s t', 'target_word': 'm I s t'}
{'input_ids': [1, 9, 21, 35, 18, 35, 34, 21, 26, 9, 21, 5, 39, 9, 7, 23, 5, 13, 31, 14, 21, 18, 19, 9, 10, 8, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
b_
s_


In [24]:
evaluate_rhyme(model, hf_dataset['test'], tokenizer)


predicted_word: s e j d -- target_word: m I s t
predicted_word: d I s a j n d -- target_word: w I T s t U d
predicted_word: b O n -- target_word: w O l
predicted_word: h r r \ z -- target_word: h { n d z
predicted_word: I r \ r r r r @ r e j S @ n -- target_word: h I r \ @ f r \= d S a j r \
predicted_word: s e j z -- target_word: v e j s
predicted_word: f r r r z -- target_word: o w v r \=
predicted_word: s e j d -- target_word: w a j f
Processing example 54272/54439 ...

{'score': np.float64(0.48488522321619915)}

In [25]:
# del model

In [26]:
# import gc
# import ctypes
# import torch
# gc.collect()
# libc = ctypes.CDLL("libc.so.6") # clearing cache
# libc.malloc_trim(0)
# torch.cuda.empty_cache()