In [5]:
from datasets import Dataset, Audio
import librosa
import torchaudio
import IPython.display as ipd
from pathlib import Path
import jsonlines


# Dataset To train on 

In [6]:
# Define the path to the directory
data_dir = Path("/home/jupyter/advanced")

# Read data from a jsonl file and reformat it
data = {'key': [], 'audio': [], 'transcript': []}
with jsonlines.open(data_dir / "asr.jsonl") as reader:
    for obj in reader:
        for key, value in obj.items():
            if key == 'audio':
                data[key].append("/home/jupyter/advanced/audio/" + value)
            else:
                data[key].append(value)

# Convert to a Hugging Face dataset
dataset = Dataset.from_dict(data).cast_column("audio", Audio())

# Shuffle the dataset
dataset = dataset.shuffle(seed=42)

# Split the dataset into training, validation, and test sets
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset = dataset.select(range(train_size))
val_dataset = dataset.select(range(train_size, train_size + val_size))
test_dataset = dataset.select(range(train_size + val_size, train_size + val_size + test_size))

def prepare_dataset(batch):
    # load and (possibly) resample audio datato 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # encode target text to label ids
    batch["labels"] = processor.tokenizer(batch["transcript"]).input_ids
    return batch

# Apply preprocessing
#train_dataset = train_dataset.map(prepare_dataset, remove_columns=train_dataset.column_names)
#val_dataset = val_dataset.map(prepare_dataset, remove_columns=val_dataset.column_names)
#test_dataset = test_dataset.map(prepare_dataset, remove_columns=test_dataset.column_names)


In [21]:
#train_dataset[0]


# Model Training 


In [29]:
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small.en")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small.en")

In [30]:
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")

In [31]:
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [32]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

In [33]:
import evaluate

metric = evaluate.load("wer")

# evaluate with the 'normalised' WER
do_normalize_eval = True

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    # if do_normalize_eval:
    #     pred_str = [normalizer(pred) for pred in pred_str]
    #     label_str = [normalizer(label) for label in label_str]
    #     # filtering step to only evaluate the samples that correspond to non-zero references:
    #     pred_str = [pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0]
    #     label_str = [label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0]
    
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer":wer}


In [34]:
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
model.config.use_cache = False

In [35]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small.en",  # your repo name
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=50,
    max_steps=500,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)



In [39]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

max_steps is given, it will override any value given in num_train_epochs


In [40]:
trainer.train(resume_from_checkpoint = True)

There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


Step,Training Loss,Validation Loss


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].


TrainOutput(global_step=501, training_loss=3.8343506582646433e-07, metrics={'train_runtime': 40.9022, 'train_samples_per_second': 782.355, 'train_steps_per_second': 12.224, 'total_flos': 9.49792269533184e+18, 'train_loss': 3.8343506582646433e-07, 'epoch': 11.64})

In [None]:
trainer.save_model("./ASR_Model")

In [3]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("./ASR_Model")
processor = WhisperProcessor.from_pretrained("./ASR_Model")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
from transformers import pipeline

device = "cuda:0" 

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    chunk_length_s=30,
    device=device,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

audio_file = '/home/jupyter/advanced/audio/audio_1670.wav'
audio_input, sample_rate = librosa.load(audio_file, sr=16000)

transcription = pipe(audio_file)

print(transcription['text'])


Transcription: Control here, surface-to-air missiles, heading zero six zero, engage the silver and orange camouflage helicopter. That's zero six zero, over.


In [9]:
import jiwer

# Function to calculate WER
def calculate_wer(references, hypotheses):
    wer_scores = []
    for ref, hyp in zip(references, hypotheses):
        wer_score = jiwer.wer(ref, hyp)
        wer_scores.append(wer_score)
    return wer_scores

# Calculate WER for each pair

In [11]:
from transformers.pipelines.pt_utils import KeyDataset
references = val_dataset['transcript']
hypothesis = []

for out in pipe(KeyDataset(val_dataset, 'audio')):
    hypothesis.append(out['text'])

print(len(references))
print(len(hypothesis))

wer_scores = calculate_wer(references, hypothesis)

350
350


In [13]:
import pandas as pd 

wer_scores = calculate_wer(references, hypothesis)
data = pd.DataFrame(dict(hypothesis=hypothesis, reference=references))
data


Unnamed: 0,hypothesis,reference
0,"Turret Romeo, heading two six five, engage bro...","Turret Romeo, heading two six five, engage bro..."
1,"Control tower to air defense turrets, target i...","Control tower to air defense turrets, target i..."
2,"Turret Bravo, deploy EMP on the silver, blue, ...","Turret Bravo, deploy EMP on the silver, blue, ..."
3,"Control to air defense turrets, deploy EMP tow...","Control to air defense turrets, deploy EMP tow..."
4,"Turret Alpha, deploy surface-to-air missiles a...","Turret Alpha, deploy surface-to-air missiles a..."
...,...,...
345,"Turret Bravo, engage the blue, black, and oran...","Turret Bravo, engage the blue, black, and oran..."
346,"Turret Alpha, deploy EMP on blue, white, and b...","Turret Alpha, deploy EMP on blue, white, and b..."
347,"Turret Bravo, engage grey and white fighter pl...","Turret Bravo, engage grey and white fighter pl..."
348,"Control to air defense turrets, prepare to dep...","Control to air defense turrets, prepare to dep..."


In [15]:
wer = jiwer.wer(list(data["hypothesis"]), list(data["reference"]))

print(f"WER: {wer * 100:.2f} %")

WER: 3.03 %


In [17]:
# Display the results
for i, score in enumerate(wer_scores):
    print(f"Reference {i+1}: {references[i]}")
    print(f"Hypothesis {i+1}: {hypothesis[i]}")
    print(f"WER: {score:.2%}\n")

Reference 1: Turret Romeo, heading two six five, engage brown camouflage fighter plane with anti-air artillery. Target confirmed. Strike with precision. Stand by for impact.
Hypothesis 1: Turret Romeo, heading two six five, engage brown camouflage fighter plane with anti-air artillery. Target confirmed. Strike with precision. Stand by for impact.
WER: 0.00%

Reference 2: Control tower to air defense turrets, target is a white drone heading one eight zero, deploy machine gun.
Hypothesis 2: Control tower to air defense turrets, target is a white drone heading one eight zero, deploy machine gun.
WER: 0.00%

Reference 3: Turret Bravo, deploy EMP on the silver, blue, and purple commercial aircraft heading two niner zero. Target locked, awaiting confirmation.
Hypothesis 3: Turret Bravo, deploy EMP on the silver, blue, and purple commercial aircraft heading two niner zero. Target locked, awaiting confirmation.
WER: 0.00%

Reference 4: Control to air defense turrets, deploy EMP towards the gre