# UKW Marine Radio Chatter - Bridge 2 Bridge Communication
This notebook uses pretrained models to transcribe the audio files from the UKW Marine Radio Chatter - Bridge 2 Bridge Communication dataset. <br>
The dataset contains audio files and their corresponding transcriptions. Further we classify the speakers contained in the audio files.

In [3]:
import os
import IPython
import torchaudio
import torch
import wandb
from pydub import AudioSegment
from pytorch_lightning.utilities.types import STEP_OUTPUT
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from src.utils import txt_to_dataframe
import numpy as np

## Configuration - Data Directories

In [4]:
class Config:
    DATA_DIR = '../data/'
    AUDIO_DIR = DATA_DIR + 'audio/'
    TEXT_DIR = DATA_DIR + 'text/'
    DATASET_DIR = 'dataset/'
    
    KAGGLE_DATA_TAG = 'linogova/marine-radio-chatter-bridge-2-bridge-communication/1'
    KAGGLE_DATA_DIR = 'Marine_audio/'

config = Config()

In [5]:
from scipy import signal
import os
import time
import pandas as pd
import torch
import torchaudio
from joblib import Parallel, delayed
from torch.utils.data import Dataset
from src.utils import bcolors

c = bcolors()

def batch_data(data, max_duration=30):
    batches = []
    current_batch = []
    curr_start = 0
    
    if len(data) == 0:
        return []
    
    dur = data[-1]["end_time"] - data[0]["start_time"]
    
    if dur < 10:
        return []

    for entry in data:
        if entry["end_time"] - entry['start_time'] > max_duration:
            
            continue
        if entry['end_time'] - curr_start + 0.2  > max_duration:
            if len(current_batch) == 0:
                curr_start = entry['start_time'] - 0.2
            else:
                batches.append(current_batch)
                current_batch = []
                curr_start = entry['start_time'] - 0.2

        current_batch.append(entry)

    if current_batch:
        batches.append(current_batch)

    return batches

def inner_merge_batches(data):
    return [{
        'text': ' '.join([entry['transcript'] for entry in batch]),
        'start': batch[0]['start_time'],
        'end': batch[-1]['end_time']
    } for i, batch in enumerate(data)]

def lowpass_filter(audio_data, sr):
    # Create a lowpass filter
    b, a = signal.butter(4, 1300, 'low', fs=sr)
    # Apply the lowpass filter
    filtered_audio_data = signal.filtfilt(b, a, audio_data)
    return filtered_audio_data

def apply_rms_normalization(waveform):
    rms_value = waveform.pow(2).mean().sqrt()  # Calculate RMS value of the waveform
    target_rms = 0.1  # Example target RMS value
    normalized_waveform = waveform * (target_rms / rms_value)  # Scale waveform to desired RMS value
    return normalized_waveform

class UKWFunkSprache(Dataset):
    def __init__(self, 
                 file_ids, 
                 root_dir, 
                 proc=None,
                 rms_norm=False,
                 filter_data=False,
                 n_jobs=-1):
        self.feed_ids = file_ids
        self.root_dir = root_dir
        self.processor = proc
        self.rms_norm = rms_norm
        self.filter_data = filter_data

        print(f"\n{c.OKGREEN}Preloading Samples...{c.ENDC}")
        print(f"\n{c.OKCYAN}Audio Files:         {len(self.feed_ids)}{c.ENDC}")
        print(f"{c.OKCYAN}Jobs:                {n_jobs} {c.ENDC}\n")

        start_time = time.time()
        result = Parallel(n_jobs=n_jobs)(
            delayed(self.process_file)(idx) for idx in range(len(self.feed_ids))
        )
        result = [item for sublist in result for item in sublist]
        print(f"\n{c.OKGREEN}Preloading Complete!{c.ENDC}")

        self.audio_samples = [item['audio'] for item in result]
        self.transcriptions = [item['transcript'] for item in result]
        self.groups = [item['group'] for item in result]
        
        print(f"{c.OKCYAN}Number of Samples:   {len(self.audio_samples)} {c.ENDC}\n")

        end_time = time.time()
        t = end_time - start_time
        print(f"\n{c.OKBLUE}Time taken:      {int((t - (t % 60)) / 60)} min {t % 60} sec {c.ENDC}")

    def process_file(self, idx):
        feed_id = self.feed_ids[idx]
        audio_fpath = os.path.join(self.root_dir, f"audio/{feed_id}.wav")
        text_fpath = os.path.join(self.root_dir, f"text/{feed_id}.csv")

        waveform, sample_rate = torchaudio.load(audio_fpath, channels_first=True)
        waveform = waveform.float()
        transcripts_df = pd.read_csv(text_fpath)
        
        if self.rms_norm:
            waveform = apply_rms_normalization(waveform)
        
        if self.filter_data:
            waveform = lowpass_filter(waveform, sample_rate)
        
        batches = batch_data(transcripts_df.to_dict("records"))
        if len(batches) == 0:
            return []
        
        metadata = inner_merge_batches(batches)

        sample_group = str(feed_id)
        samples = []
        for i in range(len(metadata)):
            start_time = metadata[i]['start'] - 0.2
            end_time = metadata[i]['end'] + 0.2
            transcript = metadata[i]['text']

            start_sample = int(start_time * sample_rate)
            end_sample = int(end_time * sample_rate)

            sample = waveform[:, start_sample:end_sample].squeeze()
            
            if self.processor:
                sample = self.processor.feature_extractor(sample, sampling_rate=sample_rate, return_tensors="pt").input_features.squeeze(0)
                transcript = self.processor.tokenizer(transcript, return_tensors="pt").input_ids.squeeze(0)

            samples.append({
                'group': sample_group,
                'audio': sample,
                'transcript': transcript
            })
        return samples

    def __len__(self):
        return len(self.audio_samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        audio = self.audio_samples[idx]
        transcript = self.transcriptions[idx]

        return {
            "input_features": audio,
            "labels": transcript
        }


# Model

In [6]:
import torchmetrics
from transformers import get_linear_schedule_with_warmup
import pytorch_lightning as pl
import evaluate


class WhisperLightningModule(pl.LightningModule):
    def __init__(self, model_name: str, processor, learning_rate: float, weight_decay: float, warmup_steps: int, num_jobs: int = 8):
        super().__init__()
        self.save_hyperparameters()
        
        self.processor = processor
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.num_jobs = num_jobs
        
        self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
        self.wer = torchmetrics.text.wer.WordErrorRate()
        self.val_preds = []
        self.val_true = []

    def forward(self, input_features, labels):
        return self.model(input_features=input_features, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(batch["input_features"], batch["labels"])
        loss = outputs.loss
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(batch["input_features"], batch["labels"])
        loss = outputs.loss
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        
        self.val_preds.append(outputs.logits.argmax(-1))
        self.val_true.append(batch["labels"])
        
    def on_validation_epoch_end(self) -> None:
        self.val_preds = Parallel(n_jobs=self.num_jobs)(
            delayed(self.processor.decode)(pred, skip_special_tokens=True) for b in self.val_preds for pred in b
        )
        self.val_true = Parallel(n_jobs=self.num_jobs)(
            delayed(self.processor.decode)(true, skip_special_tokens=True) for b in self.val_true for true in b
        )
        
        wer = self.wer(self.val_preds, self.val_true)
        self.log("val_wer", wer, prog_bar=True)
        
        self.val_preds = []
        self.val_true = []

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [7]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SpeechDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, processor, batch_size: int, num_workers: int = 8):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.processor = processor
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn, num_workers=self.num_workers)

    def collate_fn(self, features):
        input_features = [feature["input_features"] for feature in features]
        batch = self.processor.feature_extractor.pad(
            [{"input_features": input_feature} for input_feature in input_features],
            return_tensors="pt"
        )

        labels = [feature["labels"] for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            [{"input_ids": label} for label in labels],
            return_tensors="pt"
        )
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.processor.tokenizer.pad_token_id).all():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch


# Hyperparameter Finetuning

In [8]:
from transformers import WhisperProcessor

model_config = {
    "model_name": "openai/whisper-tiny",
}

# Initialize the processor
processor = WhisperProcessor.from_pretrained(
    model_config["model_name"], 
    language='en', 
    task="transcribe", 
    do_normalize=True, 
    sampling_rate=16000, 
    return_tensors="pt", 
    device="cpu",
    local_files_only=True
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
model_config["filter_data"] = False
model_config["rms_norm"] = False

# Create the Datasets
feed_ids = [f.replace(".wav", "") for f in os.listdir(config.DATASET_DIR + "audio")]

ds_train = UKWFunkSprache(
    feed_ids[:1700], 
    config.DATASET_DIR, 
    proc=processor, 
    filter_data=model_config["filter_data"], 
    rms_norm=model_config["rms_norm"]
)
ds_val = UKWFunkSprache(
    feed_ids[1700:], 
    config.DATASET_DIR, 
    proc=processor, 
    filter_data=model_config["filter_data"], 
    rms_norm=model_config["rms_norm"]
)

model_config["num_train_samples"] = len(ds_train)
model_config["num_val_samples"] = len(ds_val)


[92mPreloading Samples...[0m

[96mAudio Files:         1700[0m
[96mJobs:                -1 [0m


[92mPreloading Complete![0m
[96mNumber of Samples:   3422 [0m


[94mTime taken:      1 min 45.1122784614563 sec [0m

[92mPreloading Samples...[0m

[96mAudio Files:         300[0m
[96mJobs:                -1 [0m


[92mPreloading Complete![0m
[96mNumber of Samples:   586 [0m


[94mTime taken:      0 min 14.88439154624939 sec [0m


In [None]:
import gc
from transformers import WhisperForConditionalGeneration
import optuna
from pytorch_lightning.loggers import WandbLogger
import wandb

torch.manual_seed(42)

def objective(trial):
    # Define the search space for hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3)
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-2)
    
    parameters = {
        "n_epochs": 2,
        "batch_size": 8,
        "learning_rate": learning_rate,
        "warmup_steps": 200,
        "weight_decay": weight_decay,
        "unfreeze_encoder": False,
        "unfreeze_decoder": True,
        "unfreeze_linear": False
    }
    
    # Initialize DataModule with the suggested batch size
    data_module = SpeechDataModule(ds_train, ds_val, processor, parameters["batch_size"])
    
    # Initialize the model with suggested hyperparameters
    model_train = WhisperLightningModule(model_config["model_name"], processor, parameters["learning_rate"], parameters["weight_decay"], parameters["warmup_steps"])
    model_train.model = WhisperForConditionalGeneration.from_pretrained(model_config["model_name"])
    model_train.model.generation_config.language = "en"
    model_train.model.generation_config.task = "transcribe"
    model_train.model.generation_config.is_multilingual = False
    
    # Freeze or unfreeze layers based on the original configuration
    for param in model_train.model.parameters():
        param.requires_grad = False
        
    # Freeze layers in the decoder
    for param in model_train.model.model.decoder.parameters():
        param.requires_grad = parameters["unfreeze_decoder"]
        
    # Freeze layers in the encoder
    for param in model_train.model.model.encoder.parameters():
        param.requires_grad = parameters["unfreeze_encoder"]
        
    # Freeze layers in the linear layer
    model_train.model.proj_out.weight.requires_grad = parameters["unfreeze_linear"]
    
    
    # Initialize the WandbLogger
    wandb_logger = WandbLogger(
        project="ukw-radio-trans_" + model_config["model_name"].split("/")[-1], 
        name=f"lr_{parameters["learning_rate"]:.6f}_wd_{parameters["weight_decay"]:.6f}",
        log_model=False)
    wandb_logger.log_hyperparams(parameters)
    
    # Initialize the Trainer with WandbLogger
    trainer = pl.Trainer(
        max_epochs=parameters["n_epochs"],
        logger=wandb_logger,
        accelerator="auto",
        log_every_n_steps=5,
        num_sanity_val_steps=5,
    )
    
    # Train the model
    trainer.fit(model_train, data_module)
    val_loss = trainer.callback_metrics["val_loss"].item()
    
    # Finish the WandbLogger run
    wandb.finish()
    
    # Free up memory
    del model_train
    del data_module
    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    
    # Return the validation loss
    return val_loss

# Create a study object and optimize the objective function
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)

# Print the best hyperparameters found
print("Best trial:")
best_trial = study.best_trial
print(f"  Value: {best_trial.value}")
print("  Params: ")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")


[I 2024-05-25 18:40:37,746] A new study created in memory with name: no-name-b60596cc-f1a4-425b-ace5-b9c3686d4a5a
[34m[1mwandb[0m: Currently logged in as: [33mtobias-ettling[0m ([33mtobias-ettling-wandb[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | WhisperForConditionalGeneration | 37.8 M
1 | wer   | WordErrorRate                   | 0     
----------------------------------------------------------
9.6 M     Trainable params
28.1 M    Non-trainable params
37.8 M    Total params
151.043   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))