In [29]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [30]:
df_specs = pd.read_csv('vanilla_audio/data.csv')
df_specs = df_specs.rename(columns={'input': 'audio', 'target': 'sentence'})
df_specs['sentence'] = df_specs['sentence'].apply(lambda x: x.capitalize())

df_specs = df_specs.sample(frac=1, random_state=42).reset_index(drop=True)
df_specs

Unnamed: 0,audio,sentence
0,/home/g7/Desktop/names-dataset/vanilla_audio/2...,I tell eeli his conduct is often condemnable a...
1,/home/g7/Desktop/names-dataset/vanilla_audio/1...,Perhaps he is ill aylen said you know he is ne...
2,/home/g7/Desktop/names-dataset/vanilla_audio/4...,Dora nodded so he said then she hastened to ad...
3,/home/g7/Desktop/names-dataset/vanilla_audio/5...,After much perplexed consultation gottfried ha...
4,/home/g7/Desktop/names-dataset/vanilla_audio/4...,From jesper of wild roses clover and honeysuck...
...,...,...
3965,/home/g7/Desktop/names-dataset/vanilla_audio/1...,And not receiving the assistance he expected f...
3966,/home/g7/Desktop/names-dataset/vanilla_audio/7...,A household book once on a time i discovered m...
3967,/home/g7/Desktop/names-dataset/vanilla_audio/7...,Frank had an idea they would be visited by a b...
3968,/home/g7/Desktop/names-dataset/vanilla_audio/6...,For the fourth piece charisma harding slightly...


In [31]:
df_train, df_test = train_test_split(df_specs, test_size=0.2, random_state=42)

In [32]:
df_test

Unnamed: 0,audio,sentence
2078,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Glad the journey had been interrupted valérie ...
1971,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Geir did not sing as usual while preparing for...
2686,/home/g7/Desktop/names-dataset/vanilla_audio/4...,And it is just as well to postpone any engagem...
211,/home/g7/Desktop/names-dataset/vanilla_audio/5...,And i am just wondering if she has changed her...
3532,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Therefore the work was pushed on briskly gray ...
...,...,...
2254,/home/g7/Desktop/names-dataset/vanilla_audio/2...,How ingenious conseil said to reduce dividing ...
602,/home/g7/Desktop/names-dataset/vanilla_audio/4...,Charlie sloane's name was written up with em w...
442,/home/g7/Desktop/names-dataset/vanilla_audio/8...,Here she blushed and lord conniston gusta walk...
3807,/home/g7/Desktop/names-dataset/vanilla_audio/4...,The day passed pleasantly enough however in a ...


In [33]:
df_train

Unnamed: 0,audio,sentence
2092,/home/g7/Desktop/names-dataset/vanilla_audio/6...,They did not think either of the danger which ...
2189,/home/g7/Desktop/names-dataset/vanilla_audio/1...,Simon screecher was more than willing and they...
736,/home/g7/Desktop/names-dataset/vanilla_audio/4...,Still retains in its broad hospitable lines so...
3300,/home/g7/Desktop/names-dataset/vanilla_audio/4...,On which account said the major grasping the l...
3215,/home/g7/Desktop/names-dataset/vanilla_audio/4...,You don't think it was intentional surely i sa...
...,...,...
1130,/home/g7/Desktop/names-dataset/vanilla_audio/4...,There were always people going in and out of t...
1294,/home/g7/Desktop/names-dataset/vanilla_audio/8...,Dad will want to know how we are doing as he w...
860,/home/g7/Desktop/names-dataset/vanilla_audio/2...,Only the english seem to have a fortissimo tas...
3507,/home/g7/Desktop/names-dataset/vanilla_audio/4...,And since that though it's quite a little whil...


In [34]:
from datasets import Dataset, DatasetDict, Audio

In [35]:
dataset = DatasetDict()
dataset["train"] = Dataset.from_pandas(df_train).cast_column("audio", Audio())
dataset["test"] = Dataset.from_pandas(df_test).cast_column("audio", Audio())

In [36]:
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence', '__index_level_0__'],
        num_rows: 3176
    })
    test: Dataset({
        features: ['audio', 'sentence', '__index_level_0__'],
        num_rows: 794
    })
})

In [37]:
dataset = dataset.remove_columns(["__index_level_0__"])
dataset

DatasetDict({
    train: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 3176
    })
    test: Dataset({
        features: ['audio', 'sentence'],
        num_rows: 794
    })
})

In [38]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [39]:
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")


In [40]:
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="English", task="transcribe")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [41]:
input_str = dataset["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")


Input:                 They did not think either of the danger which threatened them should the convicts return or of the precautions to be taken for the future but on this day while pencroft watched by the sick bed justus harding and the reporter consulted as to what it would be best to do
Decoded w/ special:    <|startoftranscript|><|en|><|transcribe|><|notimestamps|>They did not think either of the danger which threatened them should the convicts return or of the precautions to be taken for the future but on this day while pencroft watched by the sick bed justus harding and the reporter consulted as to what it would be best to do<|endoftext|>
Decoded w/out special: They did not think either of the danger which threatened them should the convicts return or of the precautions to be taken for the future but on this day while pencroft watched by the sick bed justus harding and the reporter consulted as to what it would be best to do
Are equal:             True


In [42]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="English", task="transcribe")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [43]:
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [44]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    
    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch


In [45]:
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names["train"], num_proc=1)

Map:   0%|          | 0/3176 [00:00<?, ? examples/s]

Map:   0%|          | 0/794 [00:00<?, ? examples/s]

In [46]:
dataset

DatasetDict({
    train: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 3176
    })
    test: Dataset({
        features: ['input_features', 'labels'],
        num_rows: 794
    })
})

In [47]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")


In [48]:
model.generation_config.language = "english"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None

In [49]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


In [50]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)


In [3]:
# import evaluate

# metric = evaluate.load("wer")

from jiwer import wer

print(wer("hello", "hillo") )


1.0


In [52]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
nlp = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [53]:
# def compute_metrics(pred):
#     pred_ids = pred.predictions
#     label_ids = pred.label_ids

#     # replace -100 with the pad_token_id
#     label_ids[label_ids == -100] = tokenizer.pad_token_id

#     # we do not want to group tokens when computing the metrics
#     pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
#     label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

#     wer = 100 * metric.compute(predictions=pred_str, references=label_str)

#     return {"wer": wer}



def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Handle padding masking
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # Decode predictions and labels from their token IDs
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pred_entities = []
    for result in pred_str:
        result = nlp(result)
        doc_entities = []
        for entity in result:
            # Normalize B-PER and I-PER to PER and include only if PER
            if entity['entity'] == 'B-PER' or entity['entity'] == 'I-PER':
                doc_entities.append((entity['word'], 'PER'))
        # if doc_entities:  # Only append if there are PER entities in the document
        pred_entities.append(len(doc_entities))
        # print("pred_entity: ", pred_entities)

    true_entities = []
    for result in label_str:
        result = nlp(result)
        doc_entities = []
        for entity in result:
            # Normalize B-PER and I-PER to PER and include only if PER
            if entity['entity'] == 'B-PER' or entity['entity'] == 'I-PER':
                doc_entities.append((entity['word'], 'PER'))
        # if doc_entities:  # Only append if there are PER entities in the document
        true_entities.append(len(doc_entities))
        # print("true_entity: ", true_entities)

    ratio = 0
    count = 0
    for i in range(len(pred_entities)):
        if true_entities[i]:
            ratio += 100 * (pred_entities[i] / true_entities[i])
            count += 1
    ratio /= count

    # Calculate WER if needed
    print("computing wer metric")
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"ner percent": ratio, "wer": wer}



In [57]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./names-whisper-en-spectrogram-vanilla",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
    push_to_hub_model_id="names-whisper-en-spectrogram-vanilla",
)




In [58]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)


max_steps is given, it will override any value given in num_train_epochs


In [59]:
trainer.train()

  0%|          | 0/5000 [00:00<?, ?it/s]

`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`...


{'loss': 0.7751, 'grad_norm': 8.88510513305664, 'learning_rate': 4.6000000000000004e-07, 'epoch': 0.13}
{'loss': 0.6491, 'grad_norm': 6.019869804382324, 'learning_rate': 9.600000000000001e-07, 'epoch': 0.25}
{'loss': 0.5225, 'grad_norm': 4.296834945678711, 'learning_rate': 1.46e-06, 'epoch': 0.38}
{'loss': 0.3885, 'grad_norm': 3.8177144527435303, 'learning_rate': 1.9600000000000003e-06, 'epoch': 0.5}
{'loss': 0.258, 'grad_norm': 3.545436143875122, 'learning_rate': 2.46e-06, 'epoch': 0.63}
{'loss': 0.1984, 'grad_norm': 3.859245538711548, 'learning_rate': 2.96e-06, 'epoch': 0.75}
{'loss': 0.1589, 'grad_norm': 2.6619577407836914, 'learning_rate': 3.46e-06, 'epoch': 0.88}
{'loss': 0.1546, 'grad_norm': 2.9576919078826904, 'learning_rate': 3.96e-06, 'epoch': 1.01}
{'loss': 0.103, 'grad_norm': 2.112499713897705, 'learning_rate': 4.4600000000000005e-06, 'epoch': 1.13}
{'loss': 0.0939, 'grad_norm': 3.0689077377319336, 'learning_rate': 4.960000000000001e-06, 'epoch': 1.26}
{'loss': 0.092, 'grad_

  0%|          | 0/100 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


computing wer metric
{'eval_loss': 0.035529669374227524, 'eval_ner percent': 98.59038142620231, 'eval_wer': 0.9813998209608434, 'eval_runtime': 172.482, 'eval_samples_per_second': 4.603, 'eval_steps_per_second': 0.58, 'epoch': 5.03}




{'loss': 0.0013, 'grad_norm': 0.049123238772153854, 'learning_rate': 8.83777777777778e-06, 'epoch': 5.15}
{'loss': 0.0013, 'grad_norm': 0.054848361760377884, 'learning_rate': 8.782222222222223e-06, 'epoch': 5.28}
{'loss': 0.0016, 'grad_norm': 0.04139916971325874, 'learning_rate': 8.726666666666667e-06, 'epoch': 5.4}
{'loss': 0.0015, 'grad_norm': 0.32663199305534363, 'learning_rate': 8.671111111111113e-06, 'epoch': 5.53}
{'loss': 0.0018, 'grad_norm': 0.048426542431116104, 'learning_rate': 8.615555555555555e-06, 'epoch': 5.65}
{'loss': 0.0013, 'grad_norm': 0.1585712730884552, 'learning_rate': 8.560000000000001e-06, 'epoch': 5.78}
{'loss': 0.0013, 'grad_norm': 0.11793075501918793, 'learning_rate': 8.504444444444445e-06, 'epoch': 5.9}
{'loss': 0.0012, 'grad_norm': 0.031709056347608566, 'learning_rate': 8.448888888888889e-06, 'epoch': 6.03}
{'loss': 0.0015, 'grad_norm': 0.03784472495317459, 'learning_rate': 8.393333333333335e-06, 'epoch': 6.16}
{'loss': 0.001, 'grad_norm': 0.034115739166736

  0%|          | 0/100 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


computing wer metric
{'eval_loss': 0.03694279119372368, 'eval_ner percent': 97.67827529021558, 'eval_wer': 0.9847153608965219, 'eval_runtime': 172.575, 'eval_samples_per_second': 4.601, 'eval_steps_per_second': 0.579, 'epoch': 10.05}




{'loss': 0.0005, 'grad_norm': 0.01860830746591091, 'learning_rate': 6.615555555555556e-06, 'epoch': 10.18}
{'loss': 0.0007, 'grad_norm': 0.013279497623443604, 'learning_rate': 6.560000000000001e-06, 'epoch': 10.3}
{'loss': 0.0006, 'grad_norm': 0.01789858192205429, 'learning_rate': 6.504444444444446e-06, 'epoch': 10.43}
{'loss': 0.0005, 'grad_norm': 0.015045026317238808, 'learning_rate': 6.448888888888889e-06, 'epoch': 10.55}
{'loss': 0.0005, 'grad_norm': 0.022168666124343872, 'learning_rate': 6.393333333333334e-06, 'epoch': 10.68}
{'loss': 0.0005, 'grad_norm': 0.016081007197499275, 'learning_rate': 6.3377777777777786e-06, 'epoch': 10.8}
{'loss': 0.0004, 'grad_norm': 0.0125707583501935, 'learning_rate': 6.282222222222223e-06, 'epoch': 10.93}
{'loss': 0.0004, 'grad_norm': 0.02541368082165718, 'learning_rate': 6.2266666666666675e-06, 'epoch': 11.06}
{'loss': 0.0004, 'grad_norm': 0.015265116468071938, 'learning_rate': 6.171111111111112e-06, 'epoch': 11.18}
{'loss': 0.0004, 'grad_norm': 0.0

  0%|          | 0/100 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


computing wer metric
{'eval_loss': 0.03863412141799927, 'eval_ner percent': 97.67827529021558, 'eval_wer': 0.9946619807035576, 'eval_runtime': 168.5908, 'eval_samples_per_second': 4.71, 'eval_steps_per_second': 0.593, 'epoch': 15.08}




{'loss': 0.0003, 'grad_norm': 0.008313920348882675, 'learning_rate': 4.393333333333334e-06, 'epoch': 15.2}
{'loss': 0.0002, 'grad_norm': 0.008601675741374493, 'learning_rate': 4.337777777777778e-06, 'epoch': 15.33}
{'loss': 0.0002, 'grad_norm': 0.008553043007850647, 'learning_rate': 4.282222222222222e-06, 'epoch': 15.45}
{'loss': 0.0002, 'grad_norm': 0.008230718784034252, 'learning_rate': 4.226666666666667e-06, 'epoch': 15.58}
{'loss': 0.0003, 'grad_norm': 0.00891111046075821, 'learning_rate': 4.171111111111111e-06, 'epoch': 15.7}
{'loss': 0.0003, 'grad_norm': 0.008070187643170357, 'learning_rate': 4.115555555555556e-06, 'epoch': 15.83}
{'loss': 0.0002, 'grad_norm': 0.007887458428740501, 'learning_rate': 4.060000000000001e-06, 'epoch': 15.95}
{'loss': 0.0002, 'grad_norm': 0.009624909609556198, 'learning_rate': 4.004444444444445e-06, 'epoch': 16.08}
{'loss': 0.0002, 'grad_norm': 0.007502677850425243, 'learning_rate': 3.948888888888889e-06, 'epoch': 16.21}
{'loss': 0.0002, 'grad_norm': 0

  0%|          | 0/100 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


computing wer metric
{'eval_loss': 0.039706721901893616, 'eval_ner percent': 97.92703150912105, 'eval_wer': 1.0079241404462715, 'eval_runtime': 169.5154, 'eval_samples_per_second': 4.684, 'eval_steps_per_second': 0.59, 'epoch': 20.1}




{'loss': 0.0002, 'grad_norm': 0.006983945611864328, 'learning_rate': 2.1711111111111113e-06, 'epoch': 20.23}
{'loss': 0.0002, 'grad_norm': 0.00742976414039731, 'learning_rate': 2.1155555555555557e-06, 'epoch': 20.35}
{'loss': 0.0002, 'grad_norm': 0.006333294790238142, 'learning_rate': 2.06e-06, 'epoch': 20.48}
{'loss': 0.0002, 'grad_norm': 0.005302124656736851, 'learning_rate': 2.0044444444444446e-06, 'epoch': 20.6}
{'loss': 0.0002, 'grad_norm': 0.005807811859995127, 'learning_rate': 1.948888888888889e-06, 'epoch': 20.73}
{'loss': 0.0002, 'grad_norm': 0.0060289171524345875, 'learning_rate': 1.8933333333333333e-06, 'epoch': 20.85}
{'loss': 0.0002, 'grad_norm': 0.0048184082843363285, 'learning_rate': 1.837777777777778e-06, 'epoch': 20.98}
{'loss': 0.0002, 'grad_norm': 0.006392026785761118, 'learning_rate': 1.7822222222222225e-06, 'epoch': 21.11}
{'loss': 0.0002, 'grad_norm': 0.0052548483945429325, 'learning_rate': 1.7266666666666667e-06, 'epoch': 21.23}
{'loss': 0.0002, 'grad_norm': 0.00

  0%|          | 0/100 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


computing wer metric
{'eval_loss': 0.040159180760383606, 'eval_ner percent': 97.92703150912105, 'eval_wer': 1.0079241404462715, 'eval_runtime': 169.077, 'eval_samples_per_second': 4.696, 'eval_steps_per_second': 0.591, 'epoch': 25.13}


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


{'train_runtime': 7492.6385, 'train_samples_per_second': 10.677, 'train_steps_per_second': 0.667, 'train_loss': 0.020407422031089664, 'epoch': 25.13}


TrainOutput(global_step=5000, training_loss=0.020407422031089664, metrics={'train_runtime': 7492.6385, 'train_samples_per_second': 10.677, 'train_steps_per_second': 0.667, 'total_flos': 2.3029114945536e+19, 'train_loss': 0.020407422031089664, 'epoch': 25.12562814070352})

In [60]:
kwargs = {
    "dataset_tags": "Libri",
    "dataset": "LibriSpeech",  # a 'pretty' name for the training dataset
    "dataset_args": "config: en, split: test",
    "language": "en",
    "model_name": "names Whisper small",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}


In [61]:
trainer.save_model() 
trainer.push_to_hub()
tokenizer.push_to_hub("names-whisper-en-spectrogram-vanilla") 


Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618

README.md:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/shahd237/names-whisper-en-spectrogram-vanilla/commit/e4cf18e828e978df64dc7307237790a5a8685b25', commit_message='Upload tokenizer', commit_description='', oid='e4cf18e828e978df64dc7307237790a5a8685b25', pr_url=None, pr_revision=None, pr_num=None)

# Testing

In [62]:
df_test

Unnamed: 0,audio,sentence
2078,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Glad the journey had been interrupted valérie ...
1971,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Geir did not sing as usual while preparing for...
2686,/home/g7/Desktop/names-dataset/vanilla_audio/4...,And it is just as well to postpone any engagem...
211,/home/g7/Desktop/names-dataset/vanilla_audio/5...,And i am just wondering if she has changed her...
3532,/home/g7/Desktop/names-dataset/vanilla_audio/3...,Therefore the work was pushed on briskly gray ...
...,...,...
2254,/home/g7/Desktop/names-dataset/vanilla_audio/2...,How ingenious conseil said to reduce dividing ...
602,/home/g7/Desktop/names-dataset/vanilla_audio/4...,Charlie sloane's name was written up with em w...
442,/home/g7/Desktop/names-dataset/vanilla_audio/8...,Here she blushed and lord conniston gusta walk...
3807,/home/g7/Desktop/names-dataset/vanilla_audio/4...,The day passed pleasantly enough however in a ...


In [83]:
# from torch.utils.data import DataLoader, Dataset
# import torch

# # Define a simple dataset
# class MyDataset(Dataset):
#     def __init__(self):
#         self.data = torch.arange(0, 100) 
#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, index):
#         return self.data[index]

# # Create an instance of the dataset
# my_dataset = MyDataset()

In [2]:
print(df_test.head())

NameError: name 'df_test' is not defined

In [103]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_features = [item['input_features'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences
    input_features_padded = pad_sequence(input_features, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {"input_features": input_features_padded, "labels": labels_padded}

In [105]:
from torch.utils.data import DataLoader, Dataset
import torch
from datasets import load_metric
import torchaudio
from torch.nn.utils.rnn import pad_sequence

# Define the metric
metric = load_metric("wer")

class TestDataset(Dataset):
    def __init__(self, df, processor):
        self.df = df
        self.processor = processor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        audio_path = item['audio']
        sentence = item['sentence']

        # Load and preprocess the audio
        waveform, sample_rate = torchaudio.load(audio_path)
        input_features = self.processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt").input_features[0]

        # Tokenize the sentence
        labels = self.processor.tokenizer(sentence, return_tensors="pt").input_ids[0]

        return {"input_features": input_features, "labels": labels}

def collate_fn(batch):
    input_features = [item['input_features'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences
    input_features_padded = pad_sequence(input_features, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-100)

    return {"input_features": input_features_padded, "labels": labels_padded}

def evaluate(model, processor, dataloader):
    model.eval()
    wer_accumulator = []

    with torch.no_grad():
        for batch in dataloader:
            input_features = batch["input_features"].to(model.device)
            labels = batch["labels"].to(model.device)

            outputs = model.generate(input_features, attention_mask=(input_features > 0))
            pred_str = processor.batch_decode(outputs, skip_special_tokens=True)
            label_str = processor.batch_decode(labels, skip_special_tokens=True)

            wer_accumulator.append(100* metric.compute(predictions=pred_str, references=label_str))

    avg_wer = sum(wer_accumulator) / len(wer_accumulator)
    return avg_wer

# Assuming df_test is already defined and contains the test data
test_dataset = TestDataset(df_test, processor)
test_dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

# Calculate WER
wer = evaluate(model, processor, test_dataloader)
print(f"Word Error Rate: {wer:.2f}%")


# Calculate WER if needed
# print("computing wer metric")
# wer = 100 * metric.compute(predictions=pred_str, references=label_str)

# return {"ner percent": ratio, "wer": wer}

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Word Error Rate: 2.83%


# Testing on random Librispeech

In [11]:
# Replace 'your-username/the-name-you-picked' with the appropriate model identifier
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
finetuned_model = WhisperForConditionalGeneration.from_pretrained("shahd237/names-whisper-en-spectrogram-vanilla")
finetuned_processor = WhisperProcessor.from_pretrained("shahd237/names-whisper-en-spectrogram-vanilla")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
import librosa
import numpy as np
import whisper
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, Audio

In [7]:
#load data from random_1000_from_librispeech.csv
df_specs = pd.read_csv('random_1000_from_librispeech.csv')
df_specs

Unnamed: 0.1,Unnamed: 0,audio,sentence
0,0,/home/g7/Desktop/LibriSpeech/train-clean-100/5...,Olenin was as happy as a boy of twelve tie it ...
1,1,/home/g7/Desktop/LibriSpeech/train-clean-100/3...,I could just make out that he had a book as we...
2,2,/home/g7/Desktop/LibriSpeech/train-clean-100/2...,And what a centrepiece it was it required the ...
3,3,/home/g7/Desktop/LibriSpeech/train-clean-100/6...,The iron was rusty the leather torn the wood w...
4,4,/home/g7/Desktop/LibriSpeech/train-clean-100/7...,Will satisfy my everlasting hatred my courage ...
...,...,...,...
995,995,/home/g7/Desktop/LibriSpeech/train-clean-100/7...,The piano bard the piano rhapsodist the piano ...
996,996,/home/g7/Desktop/LibriSpeech/train-clean-100/3...,Then another and a different horror fell to my...
997,997,/home/g7/Desktop/LibriSpeech/train-clean-100/5...,But the cold drove us out and making a large f...
998,998,/home/g7/Desktop/LibriSpeech/train-clean-100/1...,Only i beg it shall not be before midnight


In [8]:
# df_specs['sentence'] = df_specs['sentence'].apply(lambda x: x.capitalize())

df_specs = df_specs.sample(frac=1, random_state=42).reset_index(drop=True)
df_specs

Unnamed: 0.1,Unnamed: 0,audio,sentence
0,521,/home/g7/Desktop/LibriSpeech/train-clean-100/1...,Love and water brought back all her strength s...
1,737,/home/g7/Desktop/LibriSpeech/train-clean-100/1...,That the captain determined to run into wigwam...
2,740,/home/g7/Desktop/LibriSpeech/train-clean-100/4...,Growled sam kitteridge bitterly resenting the ...
3,660,/home/g7/Desktop/LibriSpeech/train-clean-100/7...,Sir sidney colvin regrets that the love letter...
4,411,/home/g7/Desktop/LibriSpeech/train-clean-100/4...,Leading their young full fledged and about as ...
...,...,...,...
995,106,/home/g7/Desktop/LibriSpeech/train-clean-100/4...,And unconscious of the danger stood her ground...
996,270,/home/g7/Desktop/LibriSpeech/train-clean-100/3...,And every knight shall have a squire and two y...
997,860,/home/g7/Desktop/LibriSpeech/train-clean-100/7...,His appearance was welcomed by a joyful shout ...
998,435,/home/g7/Desktop/LibriSpeech/train-clean-100/2...,Well isabel you must be aware that it is an aw...


In [9]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
ner_tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER-uncased")
ner_model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER-uncased")
nlp = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer)

Some weights of the model checkpoint at dslim/bert-base-NER-uncased were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# true_entities = []
# n = len(df_specs)
# for i in range (n):
#     ground_truth = df_specs['sentence'][i]
#     print("true result: ", ground_truth)
#     result = nlp(ground_truth)
#     doc_entities = []
#     for entity in result:
#         # Normalize B-PER and I-PER to PER and include only if PER
#         if entity['entity'] == 'B-PER' or entity['entity'] == 'I-PER':
#             doc_entities.append((entity['word'], 'PER'))
#     # if doc_entities:  # Only append if there are PER entities in the document
#     true_entities.append(len(doc_entities))
#     print("len true_entity: ", len(true_entities))

In [7]:
# print("true_entities: ", true_entities)

In [13]:
import torch
from transformers import pipeline
from jiwer import wer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

pipe = pipeline(
    "automatic-speech-recognition",
    model=finetuned_model,
    tokenizer=finetuned_processor.tokenizer,
    feature_extractor=finetuned_processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

n = len(df_specs)
wer_accumulator = []
cer_accumulator = []
pred_entities = []
predictions = []
sentences = []

for i in range (n):
  sample = df_specs['audio'][i]
  prediction = pipe(sample)["text"]
  print(i, sep=' ')
  #list of predictions
  predictions.append(prediction)
  #list of true sentences
  sentences.append(df_specs['sentence'][i])
  #calculate WER
  wer_ans = wer(prediction.lower(),df_specs['sentence'][i].lower())
  wer_accumulator.append(wer_ans)

# Calculate WER
avg_wer = 100 * (sum(wer_accumulator) / len(wer_accumulator))

print(f"Average Word Error Rate: {avg_wer:.2f}%")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [25]:
# Calculate NER
#align first 
import jiwer

predictions = [prediction.lower() for prediction in predictions]
sentences = [sentence.lower() for sentence in sentences]
out = jiwer.process_words(
    predictions,
    sentences,
)
visualization = jiwer.visualize_alignment(out)
ref_list = []
hyp_list = []

print("out:", out)
print("visualize: ",jiwer.visualize_alignment(out))
lines = visualization.split('\n')
for line in lines:
    if line.startswith("REF:"):
        ref_list.append(line.strip().replace("REF: ", ""))
    elif line.startswith("HYP:"):
        hyp_list.append(line.strip().replace("HYP: ", ""))

# Now you have REF and HYP lines in two separate lists
print("References:", ref_list)
print("Hypotheses:", hyp_list)


visualize:  sentence 4
REF: sir sydney colvin regrets that the love letters of keats to fanny were ever published it would be as reasonable in my opinion
HYP: sir sidney colvin regrets that the love letters of keats to fanny were ever published it would be as reasonable in my opinion
              S                                                                                                                   

sentence 8
REF: taking one absence with the other he had been away from her chiefly in paris pursuing his own course and his own pleasure how fared it with lady isabel just as it must be expected to  fay her and does  fay her
HYP: taking one absence with the other he had been away from her chiefly in paris pursuing his own course and his own pleasure how fared it with lady isabel just as it must be expected to fare *** and does fare ***
                                                                                                                                              

In [26]:
# calculate NER for each true sentence
true_entities = []
for i in range(len(ref_list)):
    result = nlp(ref_list[i])
    doc_entities = []
    for entity in result:
        # Normalize B-PER and I-PER to PER and include only if PER
        if entity['entity'] == 'B-PER' or entity['entity'] == 'I-PER':
            doc_entities.append((entity['word'], 'PER'))
    # if doc_entities:  # Only append if there are PER entities in the document
    true_entities.append(doc_entities)
print("len true_entity: ", len(true_entities))
print("true_entities: ", true_entities)

len true_entity:  547
true_entities:  [[('sydney', 'PER'), ('col', 'PER'), ('##vin', 'PER'), ('ke', 'PER'), ('##ats', 'PER'), ('fanny', 'PER')], [('isabel', 'PER')], [], [], [], [], [], [], [('ala', 'PER'), ('##po', 'PER'), ('salt', 'PER'), ('##are', 'PER'), ('##lo', 'PER')], [], [], [('cr', 'PER'), ('##eman', 'PER'), ('##te', 'PER')], [], [('to', 'PER'), ('crystal', 'PER')], [('diana', 'PER')], [], [], [], [('pr', 'PER'), ('##udence', 'PER'), ('mar', 'PER'), ('##gar', 'PER'), ('##it', 'PER'), ('delicious', 'PER'), ('pr', 'PER'), ('##udence', 'PER')], [], [('carly', 'PER'), ('##le', 'PER'), ('barbara', 'PER'), ('ma', 'PER'), ('##mma', 'PER')], [('sc', 'PER'), ('##ag', 'PER'), ('##zi', 'PER')], [('eric', 'PER'), ('gunn', 'PER'), ('##b', 'PER'), ('##jo', 'PER'), ('##rn', 'PER')], [('henry', 'PER')], [], [('margaret', 'PER')], [], [], [], [], [], [], [('am', 'PER'), ('##nity', 'PER')], [('hai', 'PER'), ('##oc', 'PER'), ('##ent', 'PER'), ('co', 'PER'), ('##rone', 'PER'), ('##l', 'PER'), ('

In [27]:
#merge names to one word if they are in the same entity depending on the hashtags #
true_entities_merged = []
for i in range(len(true_entities)):
    doc_entities = true_entities[i]
    doc_entities_merged = []
    for j in range(len(doc_entities)):
        if doc_entities[j][1] == 'PER':
            doc_entities_merged.append((doc_entities[j][0], 'PER'))
            while(j+1 < len(doc_entities) and doc_entities[j+1][1] == 'PER' and doc_entities[j+1][0][0] =='#'):
                #remove the hashtag and merge the names
                doc_entities_merged[-1] = (doc_entities_merged[-1][0] + doc_entities[j+1][0][2:], 'PER')
                doc_entities[j+1] = (doc_entities[j+1][0], '*')
                j += 1
            # doc_entities_merged.append(doc_entities[j])
    true_entities_merged.append(doc_entities_merged)

print("true_entities_merged: ", true_entities_merged)           

true_entities_merged:  [[('sydney', 'PER'), ('colvin', 'PER'), ('keats', 'PER'), ('fanny', 'PER')], [('isabel', 'PER')], [], [], [], [], [], [], [('alapo', 'PER'), ('saltarelo', 'PER')], [], [], [('cremante', 'PER')], [], [('to', 'PER'), ('crystal', 'PER')], [('diana', 'PER')], [], [], [], [('prudence', 'PER'), ('margarit', 'PER'), ('delicious', 'PER'), ('prudence', 'PER')], [], [('carlyle', 'PER'), ('barbara', 'PER'), ('mamma', 'PER')], [('scagzi', 'PER')], [('eric', 'PER'), ('gunnbjorn', 'PER')], [('henry', 'PER')], [], [('margaret', 'PER')], [], [], [], [], [], [], [('amnity', 'PER')], [('haiocent', 'PER'), ('coronel', 'PER'), ('udo', 'PER')], [('d', 'PER'), ("'", 'PER'), ('artagnan', 'PER')], [], [('jane', 'PER')], [('nielsen', 'PER'), ('billy', 'PER'), ('alice', 'PER'), ('greggory', 'PER')], [('myriam', 'PER'), ('fang', 'PER'), ('nora', 'PER')], [], [('verry', 'PER'), ('cassie', 'PER'), ('grandfather', 'PER')], [('elizabeth', 'PER')], [('headie', 'PER')], [], [('agatha', 'PER'), (

In [28]:
#add the other words add (add entity * to it) to the true entities merged list to have the same length as the predictions
final_true_entities_merged = [] #list of lists of tuples (word, entity) for each sentence for all the words in the sentence
for i in range(len(true_entities_merged)): #for each sentence 
    doc_entities = true_entities_merged[i] 
    ref_sentence = ref_list[i].split()
    k = 0

    final_true_entities = []
    for j in range(len(ref_sentence)):
        if (len(doc_entities) == k or doc_entities[k][0].lower() not in ref_sentence[j].lower()): #if the word is not in the entities list add it with entity *
            # print(ref_sentence[j], doc_entities[k][0])
            final_true_entities.append((ref_sentence[j], '*'))
        else:
            #if the word is a substring of the entity word, add the original word to the list iinstead of the entity word (WITH THE ENTITY)
            final_true_entities.append((ref_sentence[j], doc_entities[k][1]))
            k += 1
    final_true_entities_merged.append(final_true_entities)
print("true_entities_merged: ", final_true_entities_merged)
    



In [29]:
for i in range(len(final_true_entities_merged)):
    ref_sentence = ref_list[i].split()
    doc_entities = final_true_entities_merged[i]

    for j in range(len(ref_sentence)):
        if (len(ref_sentence[j]) != len(doc_entities[j][0])):
            print("ref_sentence: ", ref_sentence[j])
            print("doc_entities: ", doc_entities[j][0])

In [30]:
# calculate NER for each prediction
pred_entities = []
for i in range(len(hyp_list)):
    result = nlp(hyp_list[i])
    doc_entities = []
    for entity in result:
        # Normalize B-PER and I-PER to PER and include only if PER
        if entity['entity'] == 'B-PER' or entity['entity'] == 'I-PER':
            doc_entities.append((entity['word'], 'PER'))
        else:
            doc_entities.append((entity['word'], '*'))
    # if doc_entities:  # Only append if there are PER entities in the document
    pred_entities.append(doc_entities)
print("len pred_entities: ", len(pred_entities))
print("pred_entities: ", pred_entities)

len pred_entities:  547
pred_entities:  [[('sidney', 'PER'), ('col', 'PER'), ('##vin', 'PER'), ('ke', 'PER'), ('##ats', 'PER'), ('fanny', 'PER')], [('paris', '*'), ('isabel', 'PER')], [('mari', '*')], [], [], [('ry', 'PER'), ('##nch', 'PER')], [], [], [('trojan', '*'), ('fe', '*'), ('##e', '*'), ('rome', '*'), ('lap', '*'), ('##o', '*'), ('salt', '*'), ('##ere', '*'), ('##llo', '*')], [], [], [('miss', 'PER'), ('gram', 'PER'), ('##mont', 'PER')], [], [('##r', '*'), ('old', 'PER'), ('toll', 'PER'), ('cr', 'PER'), ('##iste', 'PER'), ('##l', 'PER')], [('diana', 'PER')], [], [], [], [('pr', 'PER'), ('##udence', 'PER'), ('marguerite', 'PER'), ('delicious', 'PER'), ('pr', 'PER'), ('##udence', 'PER')], [], [('carly', 'PER'), ('##le', 'PER'), ('barbara', 'PER'), ('ma', 'PER'), ('##mma', 'PER')], [('ska', 'PER'), ('##ggs', 'PER'), ('##y', 'PER')], [('eric', 'PER'), ('gunn', 'PER'), ('##bio', 'PER'), ('##rn', 'PER'), ('iceland', '*')], [('henry', 'PER')], [], [('margaret', 'PER')], [], [('great'

In [31]:
#merge names to one word if they are in the same entity depending on the hashtags #
pred_entities_merged = []
for i in range(len(pred_entities)):
    doc_entities = pred_entities[i]
    doc_entities_merged = []
    for j in range(len(doc_entities)):
        if doc_entities[j][1] == 'PER':
            doc_entities_merged.append((doc_entities[j][0], 'PER'))
            while(j+1 < len(doc_entities) and doc_entities[j+1][1] == 'PER' and doc_entities[j+1][0][0] =='#'):
                #remove the hashtag and merge the names
                doc_entities_merged[-1] = (doc_entities_merged[-1][0] + doc_entities[j+1][0][2:], 'PER')
                doc_entities[j+1] = (doc_entities[j+1][0], '*')
                j += 1
            # doc_entities_merged.append(doc_entities[j])
    pred_entities_merged.append(doc_entities_merged)

print("pred_entities_merged: ", pred_entities_merged)

pred_entities_merged:  [[('sidney', 'PER'), ('colvin', 'PER'), ('keats', 'PER'), ('fanny', 'PER')], [('isabel', 'PER')], [], [], [], [('rynch', 'PER')], [], [], [], [], [], [('miss', 'PER'), ('grammont', 'PER')], [], [('old', 'PER'), ('toll', 'PER'), ('cristel', 'PER')], [('diana', 'PER')], [], [], [], [('prudence', 'PER'), ('marguerite', 'PER'), ('delicious', 'PER'), ('prudence', 'PER')], [], [('carlyle', 'PER'), ('barbara', 'PER'), ('mamma', 'PER')], [('skaggsy', 'PER')], [('eric', 'PER'), ('gunnbiorn', 'PER')], [('henry', 'PER')], [], [('margaret', 'PER')], [], [], [], [], [], [], [], [('hyacinth', 'PER'), ('coronel', 'PER'), ('udo', 'PER')], [('d', 'PER'), ("'", 'PER'), ('artagnan', 'PER')], [], [('jane', 'PER')], [('neilson', 'PER'), ('billy', 'PER'), ('alice', 'PER'), ('greggory', 'PER')], [('miriam', 'PER'), ('nora', 'PER')], [], [('verry', 'PER'), ('cassy', 'PER'), ('grandfather', 'PER')], [('elizabeth', 'PER')], [('hetty', 'PER')], [], [('agatha', 'PER')], [], [], [], [], [('m

In [32]:
#add the other words add (add entity * to it) to the true entities merged list to have the same length as the predictions
final_pred_entities_merged = [] #list of lists of tuples (word, entity) for each sentence for all the words in the sentence
for i in range(len(pred_entities_merged)): #for each sentence 
    doc_entities = pred_entities_merged[i] 
    hyp_sentence = hyp_list[i].split()
    k = 0

    final_pred_entities = []
    for j in range(len(hyp_sentence)):
        if (len(doc_entities) == k or doc_entities[k][0].lower() not in hyp_sentence[j].lower()): #if the word is not in the entities list add it with entity *
            # print(ref_sentence[j], doc_entities[k][0])
            final_pred_entities.append((hyp_sentence[j], '*'))
        else:
            #if the word is a substring of the entity word, add the original word to the list iinstead of the entity word (WITH THE ENTITY)
            final_pred_entities.append((hyp_sentence[j], doc_entities[k][1]))
            k += 1
    final_pred_entities_merged.append(final_pred_entities)
print("final_pred_entities_merged: ", final_pred_entities_merged)



In [33]:
# Function to extract only the tags and replace '*' with 'O'
def extract_tags(nested_list):
    tags = []
    for sublist in nested_list:
        # Collect tags from each tuple in the sublist, replace '*' with 'O'
        sublist_tags = ['O' if tag == '*' else 'I-' + tag for word, tag in sublist]
        if sublist_tags:  # Only add non-empty lists
            tags.append(sublist_tags)
    return tags


In [34]:
true_extracted_tags = extract_tags(final_true_entities_merged)
predicted_extracted_tags = extract_tags(final_pred_entities_merged)

In [35]:
#calculate precision, recall, f1 score
from seqeval.metrics import precision_score, recall_score, f1_score

precision = precision_score(true_extracted_tags, predicted_extracted_tags)
recall = recall_score(true_extracted_tags, predicted_extracted_tags)
f1 = f1_score(true_extracted_tags, predicted_extracted_tags)

print("Precision: ", precision)
print("Recall: ", recall)
print("F1: ", f1)


Precision:  0.8338028169014085
Recall:  0.8245125348189415
F1:  0.8291316526610644
