In [0]:
%pip install torch torchaudio transformers datasets librosa

In [0]:
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperConfig
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset

import librosa

In [0]:
def load_whisper_model(
    whisper_model_name = "openai/whisper-small",
) -> tuple[WhisperForConditionalGeneration, WhisperProcessor]:
    """Load a pre-trained Whisper checkpoint together with its processor.

    The returned model is frozen: every parameter has ``requires_grad`` set to
    ``False`` so the pre-trained weights are not updated by later training.

    Args:
        whisper_model_name: Checkpoint identifier passed to ``from_pretrained``
            for both the model and the processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)
    processor = WhisperProcessor.from_pretrained(whisper_model_name)

    # Freeze the whole base model in one call (same effect as looping over
    # model.parameters() and clearing requires_grad on each one).
    model.requires_grad_(False)

    return model, processor

In [0]:
class WhisperWithClassification(nn.Module):
    """Whisper encoder followed by a linear classification head.

    The wrapped model is stored as ``self.whisper``; only its encoder
    (``whisper_model.model.encoder``) is used in the forward pass. The head
    ``self.fc`` maps ``whisper_model.config.d_model`` features to
    ``num_classes`` logits.
    """

    def __init__(self, whisper_model, num_classes=2):
        """Store the Whisper model and build the classification head.

        Args:
            whisper_model: Model exposing ``config.d_model`` and ``model.encoder``.
            num_classes: Number of output logits produced by the head.
        """
        super().__init__()
        self.whisper = whisper_model
        hidden_size = whisper_model.config.d_model
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_features):
        """Encode ``input_features`` and return classification logits.

        Args:
            input_features: Tensor forwarded unchanged to the Whisper encoder.

        Returns:
            Output of the linear head applied to the encoder hidden states
            averaged over dimension 1.
        """
        hidden_states = self.whisper.model.encoder(input_features).last_hidden_state

        # Global average pooling over dim 1 -- assumes hidden states are laid
        # out as (batch, time, d_model); TODO confirm.
        return self.fc(hidden_states.mean(dim=1))

In [0]:
def preprocess_audio(file_path, processor):
    """Load an audio file and convert it to Whisper input features.

    Args:
        file_path: Path handed to ``librosa.load``.
        processor: Callable invoked as
            ``processor(audio, sampling_rate=..., return_tensors="pt")`` whose
            result exposes an ``input_features`` attribute.

    Returns:
        The ``input_features`` produced by ``processor`` for this file.
    """
    # Whisper expects 16 kHz audio, so request that rate when loading.
    waveform, sample_rate = librosa.load(file_path, sr=16000)
    encoded = processor(waveform, sampling_rate=sample_rate, return_tensors="pt")
    return encoded.input_features

In [0]:
def train(model, dataloader, criterion, optimizer, device, epochs=5):
    """Run a supervised training loop and report the mean loss per epoch.

    Args:
        model: Module called as ``model(audio_inputs)`` to produce outputs.
        dataloader: Re-iterable yielding ``(audio_inputs, labels)`` batches.
            It does not need to implement ``__len__``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the updates.
        device: Device that each batch is moved to before the forward pass.
        epochs: Number of passes over ``dataloader``.

    Returns:
        A list with one float per epoch: the mean batch loss of that epoch,
        or ``nan`` for an epoch in which ``dataloader`` yielded no batches.
    """
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        # Count batches ourselves instead of calling len(dataloader): this
        # works for iterables without __len__ and avoids a ZeroDivisionError
        # when the loader is empty.
        num_batches = 0
        for audio_inputs, labels in dataloader:
            # Move the batch to the target device
            audio_inputs, labels = audio_inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(audio_inputs)

            # Compute loss and update the trainable parameters
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        mean_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")

    return epoch_losses

In [0]:
def predict(model, audio_path, processor, device):
    """Classify one audio file as ``"Spam"`` or ``"Ham"``.

    Switches ``model`` to evaluation mode, preprocesses the file with
    ``preprocess_audio``, and takes the argmax of the logits over dimension 1.

    Args:
        model: Classifier called on the preprocessed features.
        audio_path: Path of the audio file to classify.
        processor: Processor forwarded to ``preprocess_audio``.
        device: Device the features are moved to before inference.

    Returns:
        ``"Spam"`` when the predicted class index is 1, otherwise ``"Ham"``.
    """
    model.eval()

    # Inference only: gradients are not needed.
    with torch.no_grad():
        features = preprocess_audio(audio_path, processor).to(device)
        # .item() requires a single prediction, i.e. one clip per call.
        predicted_class = model(features).argmax(dim=1).item()

    if predicted_class == 1:
        return "Spam"
    return "Ham"

In [0]:
class LabeledAudioDataset(Dataset):
    """Map-style dataset over ``(audio_path, label)`` pairs.

    Each item is ``(features, label_tensor)`` where ``features`` comes from
    ``preprocess_audio`` and ``label_tensor`` is a ``torch.long`` scalar, the
    target format ``nn.CrossEntropyLoss`` accepts for class indices.
    """

    def __init__(self, samples, processor):
        self.samples = list(samples)
        self.processor = processor

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        audio_path, label = self.samples[index]
        # Assumes preprocess_audio returns a leading batch axis of size 1 (as
        # predict() relies on); drop it so the DataLoader can stack items
        # into a batch. squeeze(0) is a no-op if that axis is not size 1.
        features = preprocess_audio(audio_path, self.processor).squeeze(0)
        return features, torch.tensor(label, dtype=torch.long)


if __name__ == "__main__":
    model, processor = load_whisper_model(whisper_model_name="openai/whisper-small")

    # Wrap Whisper model with classification head
    num_classes = 2  # Spam vs Ham
    model_with_classifier = WhisperWithClassification(model, num_classes)

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_with_classifier.to(device)

    # Example audio file path. `input_features` is not consumed below in this
    # cell; it is kept in case later cells read it.
    audio_path = "example.wav"
    input_features = preprocess_audio(audio_path, processor)

    # Move input to the correct device
    input_features = input_features.to(device)

    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that can actually receive
    # gradients (the base Whisper weights are frozen by load_whisper_model).
    trainable_params = [p for p in model_with_classifier.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=1e-4)

    ############################
    ## Model Training
    ############################
    # TODO: fill in with (audio_path, label) pairs; label 1 = Spam, 0 = Ham
    # (matching the mapping used by predict()).
    train_samples = []
    if not train_samples:
        # DataLoader(None, ...) used to fail here with an opaque
        # "object of type 'NoneType' has no len()" error; fail with an
        # actionable message instead.
        raise ValueError(
            "train_samples is empty: add (audio_path, label) pairs before training."
        )
    dataloader = DataLoader(
        LabeledAudioDataset(train_samples, processor),
        batch_size=8,
        shuffle=True,
    )
    train(model_with_classifier, dataloader, criterion, optimizer, device)

    ############################
    ## Inference
    ############################
    # Test on new audio
    test_audio = "new_call.wav"
    print(f"Prediction: {predict(model_with_classifier, test_audio, processor, device)}")