In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
import whisper
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm


In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device


device(type='cuda', index=1)

In [3]:
# --- Load and Unfreeze Whisper‑medium ---
whisper_model = whisper.load_model("base.en").to(device)
for param in whisper_model.parameters():
    param.requires_grad = True
whisper_model.train()

# Load SSL models (Wav2Vec2, HuBERT, WavLM)
wav2vec2_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-960h").to(device).eval()
hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft").to(device).eval()
wavlm_model = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device).eval()

# Load BERT tokenizer and model (BERT remains frozen here).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [6]:
# --- Initialize wandb ---
wandb.init(project="somos-ensemble2-ssl", name="finetune-ssl-ensemble")
!wandb online

W&B online. Running your script from this directory will now sync to the cloud.


In [7]:
# --- Utility Functions ---
def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

In [8]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir
        self.labels = [float(sample["mos"]) for sample in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        label = torch.tensor(float(sample["mos"]), dtype=torch.float)
        audio_path = process_audio_path(sample["clean path"], self.base_dir)
        return audio_path, text, label

def collate_fn(batch):
    audio_paths, texts, labels = zip(*batch)
    audios = [whisper.load_audio(path) for path in audio_paths]
    audios = [whisper.pad_or_trim(audio) for audio in audios]
    mel_spectrograms = [whisper.log_mel_spectrogram(audio).to(device) for audio in audios]
    mel_spectrograms = torch.stack(mel_spectrograms)

    # Compute audio embeddings with gradients enabled
    audio_embeddings = whisper_model.encoder(mel_spectrograms).mean(dim=1)

    # Process texts using BERT
    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        text_embeddings = bert_model(**inputs).last_hidden_state[:, 0, :]

    labels = torch.stack(labels).to(device)
    return audio_embeddings, text_embeddings, labels

In [14]:
# --- Weak Learners (Ridge, SVR, Decision Trees) ---
class WeakLearners(nn.Module):
    def __init__(self, audio_dim, text_dim):
        super(WeakLearners, self).__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim

        # Initialize models
        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        # Track if models are fitted
        self.fitted = False

    def fit(self, train_loader):
        """ Train weak learners using train dataset embeddings """
        print("Fitting weak learners...")

        all_audio_emb, all_text_emb, all_labels = [], [], []

        for audio_emb, text_emb, labels in train_loader:
            all_audio_emb.append(audio_emb.cpu().detach().numpy())
            all_text_emb.append(text_emb.cpu().detach().numpy())
            all_labels.append(labels.cpu().detach().numpy())

        all_audio_emb = np.vstack(all_audio_emb)
        all_text_emb = np.vstack(all_text_emb)
        all_labels = np.hstack(all_labels)

        combined_embeddings = np.hstack((all_audio_emb, all_text_emb))

        # Fit models
        self.ridge_regressor.fit(combined_embeddings, all_labels)
        self.svr.fit(combined_embeddings, all_labels)
        self.dtr.fit(combined_embeddings, all_labels)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        combined_embeddings = torch.cat([audio_emb, text_emb], dim=1).cpu().detach().numpy()

        # Get predictions
        ridge_pred = self.ridge_regressor.predict(combined_embeddings)
        svr_pred = self.svr.predict(combined_embeddings)
        dtr_pred = self.dtr.predict(combined_embeddings)

        # Convert to tensors
        ridge_pred = torch.tensor(ridge_pred, dtype=torch.float).to(device)
        svr_pred = torch.tensor(svr_pred, dtype=torch.float).to(device)
        dtr_pred = torch.tensor(dtr_pred, dtype=torch.float).to(device)

        return ridge_pred, svr_pred, dtr_pred


In [10]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super(StackingMetaLearner, self).__init__()
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)  # Output a single value for regression

    def forward(self, weak_outputs):
        x = F.relu(self.fc1(weak_outputs))
        return self.fc2(x)

In [11]:

# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, num_weak_learners=3):
        super(SSLEnsembleModel, self).__init__()
        self.num_weak_learners = num_weak_learners
        self.weak_learners = WeakLearners(audio_dim, text_dim)
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=num_weak_learners, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        # Get weak learner predictions
        ridge_pred, svr_pred, dtr_pred = self.weak_learners(audio_emb, text_emb)

        # Stack the predictions from weak learners
        weak_outputs = torch.stack([torch.tensor(ridge_pred).to(device), 
                                    torch.tensor(svr_pred).to(device), 
                                    torch.tensor(dtr_pred).to(device)], dim=1)

        # Meta-learner combines weak learner outputs
        final_output = self.stacking_meta_learner(weak_outputs)
        return final_output

In [15]:
# --- Main Training Loop ---
def main():
    train_json = "data/somos/audios/train_new.json"
    test_json = "data/somos/audios/test_new.json"

    # Load datasets
    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    # Create DataLoader
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    # Get input dimensions for audio and text embeddings
    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]

    # Instantiate model
    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256).to(device)

    # Watch the model with wandb for logging gradients and parameters
    wandb.watch(model, log="all", log_freq=100)

    # Set up optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-6)

    # Loss function
    criterion = nn.MSELoss()
    
    weak_learners = WeakLearners(audio_dim, text_dim).to(device)
    weak_learners.fit(train_loader)
    
    # Train the model
    num_epochs = 100
    best_mse = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss, total_samples = 0.0, 0

        # Training loop
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()

            # Forward pass
            outputs = model(audio_emb, text_emb)
            loss = criterion(outputs.squeeze(), labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Log the loss
            running_loss += loss.item() * audio_emb.size(0)
            total_samples += labels.size(0)
            wandb.log({"train_loss": loss.item()})
            train_pbar.set_postfix(loss=loss.item())

        # Calculate and log training MSE
        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        # Evaluation phase
        model.eval()
        test_loss, total_samples = 0.0, 0
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, text_emb, labels in test_pbar:
                outputs = model(audio_emb, text_emb)
                loss = criterion(outputs.squeeze(), labels)
                test_loss += loss.item() * audio_emb.size(0)
                total_samples += labels.size(0)
                test_pbar.set_postfix(loss=loss.item())

        # Calculate and log test MSE
        test_mse = test_loss / total_samples
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        # Save the best model
        if test_mse < best_mse:
            best_mse = test_mse
            torch.save(model.state_dict(), "best_model.pth")

    print("Training complete! Best validation MSE:", best_mse)

In [16]:
main()


OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 MiB. GPU 1 has a total capacity of 8.00 GiB of which 0 bytes is free. Of the allocated memory 7.23 GiB is allocated by PyTorch, and 85.64 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)