<a href="https://colab.research.google.com/github/Netdrum/MARIA/blob/main/20241119_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --upgrade pip
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

In [None]:
import torch
from transformers import WhisperFeatureExtractor, WhisperProcessor, WhisperTokenizer, WhisperForConditionalGeneration, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict
import torchaudio
from torchaudio.transforms import MelSpectrogram
from evaluate import load
import numpy as np

# Load WER metric
wer_metric = load("wer")

# Define global constants
MODEL_NAME = "openai/whisper-small"  # Checkpoint shared by every Whisper component below
# Number of spectrogram frames per sample. Whisper's encoder expects 3000 frames
# (30 s of audio), so this is a model constraint rather than a GPU-memory knob.
MAX_AUDIO_FRAMES = 3000

# Load the dataset
dataset = load_dataset("Netdrum/IRCG_VHF")  # Replace with your dataset name

# Split dataset into train/test if not already split.
# `load_dataset` without `split=` returns a DatasetDict, which has no
# `train_test_split` method; the split has to be made on one of its Dataset
# members. Prefer the "train" split, otherwise fall back to the first one.
if "train" not in dataset or "test" not in dataset:
    source_split = "train" if "train" in dataset else next(iter(dataset))
    # Fixed seed so the held-out set is the same across notebook runs.
    dataset = dataset[source_split].train_test_split(test_size=0.2, seed=42)

# Initialize Whisper components.
# The processor gets the same language/task as the tokenizer so that both
# encode and decode with identical prompt settings.
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = WhisperTokenizer.from_pretrained(MODEL_NAME, language="English", task="transcribe")
processor = WhisperProcessor.from_pretrained(MODEL_NAME, language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)

# Preprocessing function
def preprocess_data(batch):
    """Convert one dataset row into Whisper training inputs.

    Args:
        batch: A single example (this function is mapped with
            ``batched=False``) with an ``"audio"`` dict providing ``"array"``
            and ``"sampling_rate"``, and a ``"transcription"`` string.

    Returns:
        A dict with ``"input_features"`` (log-Mel features from the
        module-level ``feature_extractor``) and ``"labels"`` (an unpadded
        list of token ids), or ``None`` when the sample cannot be processed.
    """
    try:
        audio = batch["audio"]
        waveform = torch.tensor(audio["array"], dtype=torch.float32)
        sampling_rate = audio["sampling_rate"]

        # Resample if necessary: the Whisper feature extractor works at 16 kHz.
        if sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(sampling_rate, 16000)
            waveform = resampler(waveform)

        # Use Whisper's own feature extractor instead of a hand-rolled
        # MelSpectrogram: it applies the window/hop, log scaling and
        # normalisation the pretrained encoder was trained with, and it
        # pads/truncates every sample to the fixed 30 s input length itself.
        input_features = feature_extractor(
            waveform.numpy(), sampling_rate=16000
        ).input_features[0]

        # Tokenize text WITHOUT padding. Padding here would insert real pad
        # token ids that the loss cannot tell apart from targets; the data
        # collator pads each batch with -100 so padding is ignored instead.
        labels = tokenizer(
            batch["transcription"], truncation=True, max_length=256
        ).input_ids

        return {
            "input_features": input_features,
            "labels": labels,
        }
    except Exception as e:
        # Best-effort: report the bad sample and signal failure to the caller.
        # NOTE(review): `Dataset.map` is documented to take a dict per row, so
        # the caller has to handle this `None` explicitly - confirm at call site.
        print(f"Error processing entry: {e}")
        return None

# Apply preprocessing and filter invalid samples.
def _preprocess_or_mark(example):
    """Run ``preprocess_data`` and turn a failed sample into a null-valued row.

    The mapped function must always hand a dict back to ``map``; a row whose
    preprocessing returned ``None`` is marked with null columns so that the
    ``filter`` call below can remove it.
    """
    processed = preprocess_data(example)
    if processed is None:
        return {"input_features": None, "labels": None}
    return processed


dataset = dataset.map(_preprocess_or_mark, remove_columns=["audio", "transcription"], batched=False)
# The predicate receives column values, never `None` in place of a row, so the
# check has to look at a column. `input_columns` keeps the (large) spectrograms
# from being loaded just to evaluate it.
dataset = dataset.filter(lambda labels: labels is not None, input_columns=["labels"])

# Define custom data collator
class WhisperDataCollator:
    """Collate preprocessed samples into a batch of model-ready tensors."""

    def __call__(self, features):
        """Build one batch from a list of samples.

        Args:
            features: Sequence of dicts, each with ``"input_features"``
                (equally-shaped numeric arrays) and ``"labels"`` (a sequence
                of token ids, possibly of different lengths per sample).

        Returns:
            A dict with ``"input_features"`` as a stacked float32 tensor and
            ``"labels"`` as an int64 tensor right-padded with -100.
        """
        # Stack the per-sample spectrograms along a new leading batch axis.
        spectrograms = [
            torch.tensor(sample["input_features"], dtype=torch.float32)
            for sample in features
        ]
        batch_inputs = torch.stack(spectrograms)

        # Start from a label matrix filled with -100 (the index ignored by
        # the cross-entropy loss) and copy each sequence into its row.
        label_rows = [
            torch.tensor(sample["labels"], dtype=torch.long) for sample in features
        ]
        longest = max(row.numel() for row in label_rows)
        batch_labels = torch.full((len(label_rows), longest), -100, dtype=torch.long)
        for row_index, row in enumerate(label_rows):
            batch_labels[row_index, : row.numel()] = row

        return {"input_features": batch_inputs, "labels": batch_labels}

# Split dataset into train/test
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Define training arguments.
# This must be the ONLY assignment to `training_args`: a second
# `TrainingArguments(...)` used to follow and replaced this object, which
# reset every unspecified option to its default (that is why training kept
# running the default number of epochs instead of 5).
training_args = TrainingArguments(
    output_dir="./results",
    # NOTE(review): newer transformers releases rename this to
    # `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
    # 5e-3 is far too high for fine-tuning all weights of a pretrained model
    # and was never actually in effect (see note above); 1e-5 is the value
    # used by the Hugging Face Whisper fine-tuning guide.
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=4,
    num_train_epochs=5,
    save_steps=500,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,  # Log training loss every 10 steps
    remove_unused_columns=False,
    # Mixed precision needs a CUDA device; fall back to fp32 elsewhere
    # instead of failing at start-up.
    fp16=torch.cuda.is_available(),
)


# Build the Trainer from the objects prepared above: the model, its training
# configuration, both dataset splits, the batch collator and the processor.
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": test_dataset,
    "data_collator": WhisperDataCollator(),
    "tokenizer": processor,
}
trainer = Trainer(**trainer_kwargs)

# Run the fine-tuning loop.
trainer.train()

# Evaluate the model
def evaluate_model():
    """Transcribe every sample of ``test_dataset`` and report the word error rate.

    Uses the module-level ``model``, ``processor``, ``tokenizer`` and
    ``wer_metric``. Samples are decoded one at a time.

    Returns:
        The value computed by ``wer_metric`` over all test samples. It is also
        printed, as before.
    """
    model.eval()  # Set model to evaluation mode
    predictions = []
    references = []

    for batch in test_dataset:
        # Add a batch dimension, then match the model's actual device and
        # dtype. Deciding from `training_args.fp16` is wrong: mixed-precision
        # training keeps the weights in fp32, so half inputs would mismatch,
        # and after training the model usually lives on the GPU while freshly
        # built tensors are on the CPU.
        inputs = torch.tensor(batch["input_features"], dtype=torch.float32).unsqueeze(0)
        inputs = inputs.to(device=model.device, dtype=model.dtype)

        with torch.no_grad():
            outputs = model.generate(inputs)

        decoded_preds = processor.batch_decode(outputs, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode([batch["labels"]], skip_special_tokens=True)

        predictions.extend(decoded_preds)
        references.extend(decoded_labels)

    wer = wer_metric.compute(predictions=predictions, references=references)
    print(f"Word Error Rate (WER): {wer}")
    return wer
