from datasets import load_dataset

# Load the NPSC dataset
dataset = load_dataset("NbAiLab/NPSC",'16K_mp3_bokmaal', split="train")

# Save the dataset as Arrow format to disk
dataset.save_to_disk("npsc_dataset_arrow")


In [1]:
import wandb
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mbjarkinord[0m ([33mbjarkinord-none[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from datasets import load_from_disk
import wandb

# Initialize WandB
wandb.init(project="whisper-nb-bert-semantic-training")

# bert and whisper model names
SEMANTIC_MODEL_NAME = 'NbAiLab/nb-bert-base'
WHISPER_MODEL_NAME = 'openai/whisper-large-v3-turbo'

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Whisper model and processor
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME).to(device)
whisper_processor.feature_extractor.n_frames = 3000


# Load Norwegian BERT model and tokenizer
semantic_tokenizer = AutoTokenizer.from_pretrained(SEMANTIC_MODEL_NAME)
semantic_model = AutoModel.from_pretrained(SEMANTIC_MODEL_NAME).to(device)
semantic_model.eval()  # Set to evaluation mode

# Freeze semantic model parameters
for param in semantic_model.parameters():
    param.requires_grad = False

# Load dataset from Arrow format
dataset = load_from_disk("npsc_dataset_arrow")

# Custom Dataset class for batch processing
class NSTSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, hf_dataset, processor):
        self.hf_dataset = hf_dataset
        self.processor = processor

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        sample = self.hf_dataset[idx]
        audio = sample["audio"]["array"]  # Audio array
        transcript = sample["text"]  # Ground truth transcription
        return audio, transcript

# Function to collate and pad the audio samples in a batch
def collate_fn(batch):
    # Extract the audio arrays and transcripts from the batch
    audios = [item[0] for item in batch]          # List of audio arrays
    transcripts = [item[1] for item in batch]     # List of transcripts

    # Get the expected sampling rate from the processor
    sampling_rate = 16000

    # Process the audio data with the sampling rate
    inputs = whisper_processor(
        audios,
        sampling_rate=sampling_rate,  # Pass the sampling rate here
        return_tensors="pt",
        padding="max_length"
    )
    audio_inputs = inputs.input_features

    return audio_inputs, transcripts


# Hyperparameters
BATCH_SIZE = 8  # Adjust based on available memory
LEARNING_RATE = 1e-4
EPOCHS = 5
SEMANTIC_LOSS_WEIGHT = 0.5  # Weighting factor λ for semantic loss
GRAD_CLIP = 1.0  # Gradient clipping threshold
TOTAL_STEPS = 100  # Run for 100 steps

# Prepare the dataset and dataloader
nst_dataset = NSTSpeechDataset(dataset, whisper_processor)
data_loader = DataLoader(nst_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Define optimizer and scheduler
optimizer = torch.optim.AdamW(whisper_model.parameters(), lr=LEARNING_RATE)
total_steps = len(data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Define loss functions
token_level_loss_fn = nn.CrossEntropyLoss()

# Define GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler()

# Validation function
def evaluate(model, data_loader, semantic_model, processor, tokenizer, device):
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for audio_inputs, ground_truth_texts in data_loader:
            audio_inputs = audio_inputs.to(device)

            # Tokenize ground truth texts
            ground_truth_tokens = processor.tokenizer(ground_truth_texts, return_tensors="pt", padding=True).input_ids.to(device)

            # Forward pass through Whisper model
            with torch.cuda.amp.autocast():
                outputs = model(input_features=audio_inputs, labels=ground_truth_tokens)
                logits = outputs.logits
                token_loss = outputs.loss

            # Generate predicted texts
            predicted_tokens = torch.argmax(logits, dim=-1)
            predicted_texts = processor.tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)

            # Compute semantic embeddings
            gt_encoding = tokenizer(ground_truth_texts, return_tensors='pt', padding=True, truncation=True).to(device)
            pred_encoding = tokenizer(predicted_texts, return_tensors='pt', padding=True, truncation=True).to(device)

            gt_embeddings = semantic_model(**gt_encoding).last_hidden_state.mean(dim=1)
            pred_embeddings = semantic_model(**pred_encoding).last_hidden_state.mean(dim=1)

            # Compute cosine similarity for semantic loss
            cosine_similarity = nn.functional.cosine_similarity(gt_embeddings, pred_embeddings, dim=-1)
            semantic_loss = 1 - cosine_similarity.mean()

            total_loss = token_loss + SEMANTIC_LOSS_WEIGHT * semantic_loss
            total_val_loss += total_loss.item()

    return total_val_loss / len(data_loader)

# Training Loop
step = 0  # Initialize step counter

for epoch in range(EPOCHS):
    whisper_model.train()
    total_loss = 0.0

    for batch_idx, (audio_inputs, ground_truth_texts) in enumerate(data_loader):
        audio_inputs = audio_inputs.to(device)
        step += 1

        # Tokenize ground truth texts
        ground_truth_tokens = whisper_processor.tokenizer(
            ground_truth_texts, return_tensors="pt", padding=True
        ).input_ids.to(device)

        optimizer.zero_grad()

        # Forward pass through Whisper model
        with torch.cuda.amp.autocast():
            outputs = whisper_model(input_features=audio_inputs, labels=ground_truth_tokens)
            logits = outputs.logits
            token_loss = outputs.loss

        # Generate predicted texts
        predicted_tokens = torch.argmax(logits, dim=-1)
        predicted_texts = whisper_processor.tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True) # Decode predicted tokens

        # Compute semantic embeddings
        gt_encoding = semantic_tokenizer(ground_truth_texts, return_tensors='pt', padding=True, truncation=True).to(device) # Tokenize ground truth texts
        pred_encoding = semantic_tokenizer(predicted_texts, return_tensors='pt', padding=True, truncation=True).to(device) # Tokenize predicted texts

        gt_embeddings = semantic_model(**gt_encoding).last_hidden_state.mean(dim=1) # Compute embeddings for ground truth texts
        pred_embeddings = semantic_model(**pred_encoding).last_hidden_state.mean(dim=1) # Compute embeddings for predicted texts

        # Compute cosine similarity for semantic loss and calculate loss 
        cosine_similarity = nn.functional.cosine_similarity(gt_embeddings, pred_embeddings, dim=-1)
        semantic_loss = 1 - cosine_similarity.mean()

        # Combine losses
        total_batch_loss = token_loss + SEMANTIC_LOSS_WEIGHT * semantic_loss

        # Backpropagation with mixed precision
        scaler.scale(total_batch_loss).backward()

        # Gradient Clipping
        torch.nn.utils.clip_grad_norm_(whisper_model.parameters(), GRAD_CLIP)

        # Update parameters
        scaler.step(optimizer)
        scaler.update()

        # Scheduler step (after optimizer.step())
        scheduler.step()

        total_loss += total_batch_loss.item()

        # Logging to WandB
        if batch_idx % 10 == 0:
            current_lr = scheduler.get_last_lr()[0]
            wandb.log({
                "token_loss": token_loss.item(),
                "semantic_loss": semantic_loss.item(),
                "combined_loss": total_batch_loss.item(),
                "learning_rate": current_lr,
                "step": step,
                "epoch": epoch
            })
            print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx}/{len(data_loader)}], Loss: {total_batch_loss.item():.4f}")
        
        # implement best model saving
        if total_batch_loss < best_loss:
            best_loss = total_batch_loss
            torch.save(whisper_model.state_dict(), 'best_model.pth')
            torch.save(whisper_processor, 'best_processor.pth')

        # early stopping
        if total_batch_loss > best_loss:
            patience += 1
        else:
            patience = 0



        # Check if total steps reached
        if step >= TOTAL_STEPS:
            break  # Exit the inner loop

    if step >= TOTAL_STEPS:
        break  # Exit the outer loop
    average_loss = total_loss / len(data_loader)

    # Validation step
    val_loss = evaluate(whisper_model, data_loader, semantic_model, whisper_processor, semantic_tokenizer, device)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {average_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Log validation loss to WandB
    wandb.log({
        "train_loss": average_loss,
        "val_loss": val_loss,
    })

# Save the fine-tuned model
whisper_model.save_pretrained('fine-tuned-whisper-model')
whisper_processor.save_pretrained('fine-tuned-whisper-processor')
wandb.finish()


  from .autonotebook import tqdm as notebook_tqdm


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Epoch [1/5], Batch [0/11187], Loss: 4.0318
Epoch [1/5], Batch [10/11187], Loss: 0.9202
Epoch [1/5], Batch [20/11187], Loss: 0.8647
Epoch [1/5], Batch [30/11187], Loss: 0.9394
Epoch [1/5], Batch [40/11187], Loss: 0.8588
Epoch [1/5], Batch [50/11187], Loss: 1.0080
Epoch [1/5], Batch [60/11187], Loss: 0.9931
Epoch [1/5], Batch [70/11187], Loss: 1.2033
Epoch [1/5], Batch [80/11187], Loss: 1.2285
Epoch [1/5], Batch [90/11187], Loss: 1.0457


0,1
combined_loss,█▁▁▁▁▁▁▂▂▁
learning_rate,█▇▆▆▅▄▃▃▂▁
semantic_loss,▅▂▁▆▂█▆▄█▇
token_loss,█▁▁▁▁▁▁▂▂▁

0,1
combined_loss,1.04566
learning_rate,0.0001
semantic_loss,0.42539
token_loss,0.83296
