In [1]:
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from transformers import BertTokenizer, BertModel
import whisper
from tqdm import tqdm
import wandb
from sklearn.metrics import mean_squared_error, mean_absolute_error


In [2]:
# Pick the compute device for every model and tensor in this notebook.
# The second GPU (index 1) is preferred when the host has more than one GPU;
# a hard-coded "cuda:1" would be selected on single-GPU hosts too and fail as
# soon as anything is moved to it, so fall back to "cuda:0" there and to the
# CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
device


device(type='cuda', index=1)

In [3]:
# --- Load Whisper ("base.en" checkpoint) with gradients enabled, and BERT in eval mode ---
whisper_model = whisper.load_model("base.en")
whisper_model = whisper_model.to(device)
# Mark every Whisper parameter as trainable and switch the module to train mode.
whisper_model.requires_grad_(True)
whisper_model.train()

# Tokenizer and encoder come from the same "bert-base-uncased" checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model = bert_model.to(device).eval()

In [4]:
# --- Initialize wandb ---
# Start a Weights & Biases run; the wandb.log(...) calls in the training and
# evaluation functions of this notebook report their metrics to it.
wandb.init(project="somos-ensemble2-ssl-sbs", name="finetune-whisper_b+bert+sbs")
# IPython shell escape that invokes the wandb CLI; the recorded cell output
# reports that scripts run from this directory will sync to the cloud.
# NOTE(review): this executes after wandb.init() has already created the run —
# confirm whether it was meant to run first.
!wandb online

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01128888888957186, max=1.0)…

W&B online. Running your script from this directory will now sync to the cloud.


In [5]:
# --- SBS Dataset with Optional Subsetting ---
class SBSDataset(Dataset):
    """Dataset of side-by-side (SBS) audio pairs described by a CSV file.

    Every CSV row must provide:
      * ``utterance_pairs`` - two audio file names joined by a single ``+``;
      * ``SBS_1`` / ``SBS_2`` - the score of the first / second clip;
      * ``right_text`` (training CSV) or ``text`` (test CSV) - the shared text.
    """

    def __init__(self, csv_file, base_dir, subset=False, is_test=False, subset_frac=0.05):
        """
        csv_file: Path to train or test CSV file.
        base_dir: Base directory that is prepended to the audio file names.
        subset: If True, keep only a random ``subset_frac`` share of the rows
            (sampled with a fixed seed of 42, so the subset is reproducible).
        is_test: If True the text is read from the ``text`` column, otherwise
            from ``right_text``.
        subset_frac: Fraction of rows kept when ``subset`` is True; must lie in
            (0, 1]. Defaults to 0.05, i.e. 5% of the data.

        Raises:
            ValueError: if ``subset`` is True and ``subset_frac`` is outside (0, 1].
        """
        self.df = pd.read_csv(csv_file)

        if subset:
            if not 0 < subset_frac <= 1:
                raise ValueError(f"subset_frac must be in (0, 1], got {subset_frac!r}")
            # reset_index keeps positional and label indices aligned after sampling.
            self.df = self.df.sample(frac=subset_frac, random_state=42).reset_index(drop=True)

        self.base_dir = base_dir
        self.is_test = is_test

    def __len__(self):
        """Number of audio pairs (CSV rows) available."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(audio1_path, audio2_path, text, sbs1, sbs2)`` for row ``idx``.

        Audio is not loaded here; only the paths are built. The two scores are
        returned as Python floats.
        """
        row = self.df.iloc[idx]
        audio1_name, audio2_name = row["utterance_pairs"].split("+")
        audio1_path = os.path.join(self.base_dir, audio1_name)
        audio2_path = os.path.join(self.base_dir, audio2_name)

        # The train and test CSVs store the text under different column names.
        text_column = "text" if self.is_test else "right_text"
        text = row[text_column]

        return audio1_path, audio2_path, text, float(row["SBS_1"]), float(row["SBS_2"])

In [6]:
# --- Collate Function for SBS ---
def _encode_audio_batch(paths):
    """Return one mean-pooled Whisper encoder embedding per audio file in ``paths``.

    Each file is loaded, padded/trimmed by ``whisper.pad_or_trim``, converted
    to a log-mel spectrogram on ``device`` and passed through
    ``whisper_model.encoder``; the encoder output is averaged over dim 1.
    """
    mels = torch.stack([
        whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(path))).to(device)
        for path in paths
    ])
    # No autograd graph is recorded here: every consumer in this notebook
    # detaches these embeddings and converts them to NumPy for the scikit-learn
    # regressors, so a graph through the encoder could never be back-propagated
    # and would only hold on to memory.
    with torch.no_grad():
        return whisper_model.encoder(mels).mean(dim=1)


def collate_fn_sbs(batch):
    """Turn a list of ``SBSDataset`` items into embedding and score tensors.

    Args:
        batch: sequence of ``(audio1_path, audio2_path, text, sbs1, sbs2)`` tuples.

    Returns:
        ``(audio1_emb, audio2_emb, text_emb, sbs1_tensor, sbs2_tensor)``, all on
        ``device``. The audio embeddings are mean-pooled Whisper encoder
        outputs, ``text_emb`` is BERT's hidden state at token position 0, and
        the score tensors are float tensors with one entry per pair.
    """
    audio1_paths, audio2_paths, texts, sbs1_list, sbs2_list = zip(*batch)

    # Both clips of a pair go through exactly the same audio pipeline.
    audio1_emb = _encode_audio_batch(audio1_paths)
    audio2_emb = _encode_audio_batch(audio2_paths)

    # Process the text once per pair.
    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        text_emb = bert_model(**inputs).last_hidden_state[:, 0, :]

    sbs1_tensor = torch.tensor(sbs1_list, dtype=torch.float).to(device)
    sbs2_tensor = torch.tensor(sbs2_list, dtype=torch.float).to(device)

    return audio1_emb, audio2_emb, text_emb, sbs1_tensor, sbs2_tensor


In [7]:
# --- Weak Learners (same as before) ---
class WeakLearners(nn.Module):
    """Three scikit-learn regressors (Ridge, SVR, decision tree) behind a module interface.

    The regressors are fitted on concatenated [audio, text] embeddings and are
    not torch parameters; ``forward`` returns their predictions as tensors.
    """

    def __init__(self, audio_dim, text_dim, device="cuda"):
        super().__init__()
        self.audio_dim, self.text_dim = audio_dim, text_dim
        # Device on which the prediction tensors returned by forward() are placed.
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        # Flipped to True by fit(); forward() refuses to run before that.
        self.fitted = False

    def fit(self, train_loader):
        """Fit all three regressors on the embeddings yielded by ``train_loader``.

        Each batch contributes two groups of samples: (audio1, text) -> sbs1
        and (audio2, text) -> sbs2, in that order.
        """
        print("Fitting weak learners on SBS data...")
        audio_chunks, text_chunks, label_chunks = [], [], []
        progress = tqdm(train_loader, desc="Extracting embeddings", unit="batch")
        for audio1_emb, audio2_emb, text_emb, sbs1, sbs2 in progress:
            text_np = text_emb.cpu().detach().numpy()
            # The text embedding is shared by both clips of the pair.
            for clip_emb, clip_scores in ((audio1_emb, sbs1), (audio2_emb, sbs2)):
                audio_chunks.append(clip_emb.cpu().detach().numpy())
                text_chunks.append(text_np)
                label_chunks.append(clip_scores.cpu().detach().numpy())

        # Feature matrix: audio columns followed by text columns.
        features = np.hstack((np.vstack(audio_chunks), np.vstack(text_chunks)))
        targets = np.hstack(label_chunks)

        named_models = (
            ("Ridge Regression", self.ridge_regressor),
            ("SVR", self.svr),
            ("Decision Tree", self.dtr),
        )
        for name, model in named_models:
            print(f"Training {name}...")
            model.fit(features, targets)
        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        """Return ``(ridge_pred, svr_pred, dtr_pred)`` as float tensors on ``self.device``.

        Raises:
            RuntimeError: if ``fit`` has not been called yet.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")
        # scikit-learn works on NumPy arrays, so the inputs leave the autograd graph here.
        features = torch.cat([audio_emb, text_emb], dim=1).cpu().detach().numpy()
        with torch.no_grad():
            raw_preds = [
                regressor.predict(features)
                for regressor in (self.ridge_regressor, self.svr, self.dtr)
            ]
        ridge_pred, svr_pred, dtr_pred = (
            torch.from_numpy(pred).float().to(self.device) for pred in raw_preds
        )
        return ridge_pred, svr_pred, dtr_pred


In [8]:
# --- Stacking Meta-Learner ---
class StackingMetaLearner(nn.Module):
    """Small MLP that maps the stacked weak-learner predictions to a single score."""

    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super().__init__()
        # weak_output_dim -> hidden_dim -> 1
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, weak_outputs):
        """Apply fc1, a ReLU, then fc2; the last dimension of the result has size 1."""
        hidden = self.fc1(weak_outputs)
        activated = torch.relu(hidden)
        return self.fc2(activated)

In [9]:
# --- SSLEnsembleModel (Ensemble using weak learners and meta-learner) ---
class SSLEnsembleModel(nn.Module):
    """Ensemble that feeds the three weak-learner predictions into a stacking meta-learner.

    ``audio_dim`` and ``text_dim`` are accepted by the constructor but not used by it.
    """

    def __init__(self, audio_dim, text_dim, hidden_dim, weak_learners):
        super().__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")
        self.weak_learners = weak_learners
        # One meta-learner input per weak learner (ridge, SVR, decision tree).
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        """Return the meta-learner output for the given audio and text embeddings.

        Raises:
            RuntimeError: if the weak learners report that they are not fitted.
        """
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")
        first, second, third = self.weak_learners(audio_emb, text_emb)
        # Arrange the three per-sample predictions as columns of one tensor.
        stacked = torch.stack([first, second, third], dim=1)
        return self.stacking_meta_learner(stacked)


In [10]:
# --- Pairwise Ranking Loss Function ---
def ranking_loss(pred1, pred2, sbs1, sbs2, margin=1.0):
    """Pairwise hinge ranking loss between two sets of predictions.

    Args:
        pred1, pred2: predicted scores for the first / second clip of each pair.
        sbs1, sbs2: ground-truth scores for the first / second clip of each pair.
        margin: minimum gap, in the correct direction, that ``pred1 - pred2``
            must reach before a pair stops contributing to the loss.

    Returns:
        Scalar tensor: the hinge penalties averaged over all pairs in the batch.
        Pairs whose ground-truth scores are equal contribute zero.
    """
    # target is +1 if sbs1 > sbs2, -1 if sbs1 < sbs2 and 0 for a tie.
    target = torch.sign(sbs1 - sbs2)
    # The difference between predictions should reflect the sign of the target.
    diff = pred1 - pred2
    # Hinge loss: if the difference is less than the margin in the correct direction, incur a loss.
    hinge = torch.clamp(margin - diff * target, min=0)
    # A tied pair has target == 0, which would turn its hinge term into the
    # constant `margin`: no gradient, but an inflated loss value. Mask ties out
    # while still averaging over the whole batch, so the gradient of every
    # ordered pair is unchanged.
    ordered = (target != 0).to(hinge.dtype)
    return torch.mean(hinge * ordered)

In [11]:
# --- Modified Training Function ---
def train_meta_learner(train_loader, test_loader, ensemble_model, optimizer, mse_criterion, epochs=20, eval_interval=15000, ranking_margin=1.0):
    """Train ``ensemble_model`` with a combined MSE + pairwise ranking objective.

    Args:
        train_loader: yields ``(audio1_emb, audio2_emb, text_emb, sbs1, sbs2)`` batches.
        test_loader: loader handed to ``evaluate`` for the periodic evaluations.
        ensemble_model: model called as ``ensemble_model(audio_emb, text_emb)``.
        optimizer: optimizer stepped once per training batch.
        mse_criterion: regression loss applied to each clip's prediction.
        epochs: number of passes over ``train_loader``.
        eval_interval: run ``evaluate`` every this many batches within an epoch
            (in addition to the evaluation at the end of every epoch).
        ranking_margin: margin forwarded to ``ranking_loss``.

    Per-epoch averages are printed and logged to wandb under ``epoch_loss``,
    ``epoch_mse`` and ``epoch_ranking_loss``.
    """
    for epoch in range(epochs):
        # evaluate() calls ensemble_model.eval(); make sure every epoch starts
        # in train mode instead of silently inheriting eval mode.
        ensemble_model.train()

        total_loss = 0.0
        total_mse = 0.0
        total_rank_loss = 0.0
        batch_count = 0

        for batch_idx, (audio1_emb, audio2_emb, text_emb, sbs1, sbs2) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            optimizer.zero_grad()

            # Forward passes for both audios. reshape(-1) always yields a 1-D
            # tensor; a bare squeeze() would collapse a batch of one sample to
            # a 0-d tensor that no longer matches the (1,)-shaped target.
            pred1 = ensemble_model(audio1_emb, text_emb).reshape(-1)
            pred2 = ensemble_model(audio2_emb, text_emb).reshape(-1)

            # Standard regression loss for each clip of the pair.
            mse_loss = mse_criterion(pred1, sbs1) + mse_criterion(pred2, sbs2)

            # Pairwise ranking loss to enforce the ordering within the pair.
            rnk_loss = ranking_loss(pred1, pred2, sbs1, sbs2, margin=ranking_margin)

            # Combine losses with equal weight.
            loss = mse_loss + rnk_loss
            total_loss += loss.item()
            total_mse += mse_loss.item()
            total_rank_loss += rnk_loss.item()

            # Backpropagation.
            loss.backward()
            optimizer.step()

            batch_count += 1

            # Evaluate periodically.
            if (batch_idx + 1) % eval_interval == 0:
                print(f"Evaluating at batch {batch_idx+1}...")
                evaluate(test_loader, ensemble_model, mse_criterion)
                # Resume training in train mode after the mid-epoch evaluation.
                ensemble_model.train()

        # An empty loader reports zeros instead of raising ZeroDivisionError.
        denom = max(batch_count, 1)
        avg_loss = total_loss / denom
        avg_mse = total_mse / denom
        avg_rank_loss = total_rank_loss / denom

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, MSE={avg_mse:.4f}, RankingLoss={avg_rank_loss:.4f}")

        wandb.log({"epoch_loss": avg_loss, "epoch_mse": avg_mse, "epoch_ranking_loss": avg_rank_loss})
        evaluate(test_loader, ensemble_model, mse_criterion)


# --- Evaluation Function ---
def evaluate(test_loader, ensemble_model, criterion):
    """Evaluate ``ensemble_model`` on ``test_loader`` and log the metrics to wandb.

    Args:
        test_loader: yields ``(audio1_emb, audio2_emb, text_emb, sbs1, sbs2)`` batches.
        ensemble_model: model called as ``ensemble_model(audio_emb, text_emb)``.
        criterion: loss used for the reported "Test Loss".

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the mean of the per-batch
        criterion values over both clips. ``accuracy`` is the share of pairs
        for which ``pred1 > pred2`` agrees with ``sbs1 > sbs2`` (pairs with
        equal ground-truth scores therefore count as "not greater").

    The model is switched to eval mode for the duration of the call and put
    back into whatever mode it was in before.
    """
    was_training = ensemble_model.training
    ensemble_model.eval()
    total_loss = 0.0
    total_mse = 0.0
    total_mae = 0.0
    correct_order = 0
    total_samples = 0
    batch_count = 0

    try:
        with torch.no_grad():
            for audio1_emb, audio2_emb, text_emb, sbs1, sbs2 in tqdm(test_loader, desc="Evaluating"):
                # reshape(-1) keeps predictions 1-D even for a batch of one
                # sample; squeeze() would produce 0-d arrays there, which the
                # scikit-learn metrics below do not accept.
                pred1 = ensemble_model(audio1_emb, text_emb).reshape(-1)
                pred2 = ensemble_model(audio2_emb, text_emb).reshape(-1)

                loss1 = criterion(pred1, sbs1)
                loss2 = criterion(pred2, sbs2)
                total_loss += (loss1.item() + loss2.item())
                batch_count += 1

                # Convert to CPU for the scikit-learn metrics.
                pred1_cpu = pred1.cpu().numpy()
                pred2_cpu = pred2.cpu().numpy()
                sbs1_cpu = sbs1.cpu().numpy()
                sbs2_cpu = sbs2.cpu().numpy()

                total_mse += (mean_squared_error(sbs1_cpu, pred1_cpu) + mean_squared_error(sbs2_cpu, pred2_cpu))
                total_mae += (mean_absolute_error(sbs1_cpu, pred1_cpu) + mean_absolute_error(sbs2_cpu, pred2_cpu))

                # Ranking accuracy: did the model preserve the SBS order?
                correct_order += np.sum((sbs1_cpu > sbs2_cpu) == (pred1_cpu > pred2_cpu))
                total_samples += len(sbs1_cpu)
    finally:
        # Do not leave the caller's model stuck in eval mode.
        ensemble_model.train(was_training)

    # Two clips are scored per batch, hence the factor 2. An empty loader
    # yields zeros, mirroring the existing guard on the accuracy.
    denom = 2 * batch_count
    avg_loss = total_loss / denom if denom else 0.0
    avg_mse = total_mse / denom if denom else 0.0
    avg_mae = total_mae / denom if denom else 0.0
    accuracy = correct_order / total_samples if total_samples > 0 else 0

    print(f"Test Loss: {avg_loss}, Test MSE: {avg_mse}, Test MAE: {avg_mae}, Ranking Accuracy: {accuracy:.4f}")
    wandb.log({"test_loss": avg_loss, "test_mse": avg_mse, "test_mae": avg_mae, "test_ranking_accuracy": accuracy})

    return avg_loss, accuracy


In [None]:
# --- Main Script ---
if __name__ == "__main__":
    # Folder that the audio file names from the CSVs are joined onto.
    base_audio_dir = "archive/update_SOMOS_v2/update_SOMOS_v2/all_audios/all_wavs"

    # True: both datasets are down-sampled by SBSDataset; set to False for the full data.
    subset = True

    # Datasets: the training CSV and the test CSV share the same audio folder.
    train_dataset = SBSDataset(
        "archive/train_same_pairs_text.csv",
        base_dir=base_audio_dir,
        subset=subset,
        is_test=False,
    )
    test_dataset = SBSDataset(
        "archive/test_same_pairs_text.csv",
        base_dir=base_audio_dir,
        subset=subset,
        is_test=True,
    )

    # Loaders: embeddings are computed inside collate_fn_sbs; only training data is shuffled.
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True,
        collate_fn=collate_fn_sbs,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn_sbs,
    )

    # Stage 1: fit the scikit-learn weak learners on the training embeddings.
    weak_learners = WeakLearners(audio_dim=512, text_dim=768, device=device)
    weak_learners.fit(train_loader)

    # Stage 2: wrap the fitted weak learners in the stacking ensemble.
    ensemble_model = SSLEnsembleModel(
        audio_dim=512,
        text_dim=768,
        hidden_dim=256,
        weak_learners=weak_learners,
    ).to(device)

    # Stage 3: optimise only the meta-learner's parameters.
    optimizer = torch.optim.Adam(ensemble_model.stacking_meta_learner.parameters(), lr=1e-5)
    criterion = nn.MSELoss()
    train_meta_learner(train_loader, test_loader, ensemble_model, optimizer, criterion, epochs=20)

    # Final evaluation on the test set.
    evaluate(test_loader, ensemble_model, criterion)

  self.df = pd.read_csv(csv_file)


Fitting weak learners on SBS data...


Extracting embeddings: 100%|████████████████████████████████████████████████████| 5516/5516 [44:26<00:00,  2.07batch/s]


Training Ridge Regression...
Training SVR...
Training Decision Tree...
Weak learners training completed.


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████| 5516/5516 [51:06<00:00,  1.80it/s]


Epoch 1: Loss=1.2359, MSE=0.2884, RankingLoss=0.9475


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 279/279 [02:31<00:00,  1.84it/s]


Test Loss: 0.01966735485696658, Test MSE: 0.019667354864685873, Test MAE: 0.11057937748756887, Ranking Accuracy: 0.6457


Epoch 2: 100%|█████████████████████████████████████████████████████████████████████| 5516/5516 [51:28<00:00,  1.79it/s]


Epoch 2: Loss=0.7901, MSE=0.0147, RankingLoss=0.7754


Evaluating: 100%|████████████████████████████████████████████████████████████████████| 279/279 [02:35<00:00,  1.80it/s]


Test Loss: 0.06107231413784494, Test MSE: 0.06107231390939551, Test MAE: 0.19970894189664967, Ranking Accuracy: 0.6386


Epoch 3:  32%|█████████████████████▋                                               | 1738/5516 [16:33<35:58,  1.75it/s]
