In [2]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm
import torch
import torchaudio
import torchaudio.transforms as T
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from whisper import load_model
import whisper
import pandas as pd
from sklearn.model_selection import train_test_split



In [3]:
# Prefer the second GPU ("cuda:1", as in the original setup), but fall back to the first
# GPU when only one is visible (moving tensors to a missing ordinal fails with
# "invalid device ordinal"), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
device


device(type='cuda', index=1)

In [4]:
# --- Load Whisper "base.en" (English-only base checkpoint) and enable its gradients ---
whisper_model = whisper.load_model("base.en").to(device)
# NOTE(review): these gradients are never used for an update in this notebook: no
# optimizer receives `whisper_model` parameters, and the sklearn weak learners turn the
# embeddings into NumPy arrays, which cuts the autograd graph. Freeze the model
# (requires_grad_(False) + eval()) unless real end-to-end fine-tuning is added.
whisper_model.requires_grad_(True)
whisper_model.train()

# BERT is used as a fixed text feature extractor (eval mode).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

In [5]:
# --- Initialize wandb ---
# Request online mode through `wandb.init(mode=...)` so it applies to *this* run.
# A shell `!wandb online` issued after init() only changes the saved directory setting
# (see its "will now sync" message), not the run that is already starting.
wandb.init(project="somos-ensemble2-ssl", name="finetune-whisper_b+bert+norm", mode="online")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

W&B online. Running your script from this directory will now sync to the cloud.


In [6]:
# --- Utility Functions ---
def load_json(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return the decoded object."""
    with open(filepath, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

def load_mos_values(csv_path):
    """Load MOS values from the new normalized dataset.

    Returns a dict mapping each row's first-column id to its ``new_scale`` value.
    If an id repeats, the later row wins.
    """
    table = pd.read_csv(csv_path)
    ids = table.iloc[:, 0]
    scores = table["new_scale"]
    return {audio_id: score for audio_id, score in zip(ids, scores)}

def load_transcripts(transcript_path):
    """Read ``<audio_id><TAB><text>`` lines into an ``{audio_id: text}`` dict.

    Each line is stripped of surrounding whitespace first. Lines that do not split into
    exactly two tab-separated fields (blank lines, lines without a tab, lines with
    extra tabs) are skipped.
    """
    transcript_dict = {}
    with open(transcript_path, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t")
            if len(fields) != 2:
                continue
            audio_id, text = fields
            transcript_dict[audio_id] = text
    return transcript_dict

def process_audio_path(audio_id, base_dir="archive/update_SOMOS_v2/update_SOMOS_v2/all_audios/all_wavs"):
    """Return the path ``<base_dir>/<audio_id>.wav`` (joined with ``os.path.join``)."""
    wav_name = f"{audio_id}.wav"
    return os.path.join(base_dir, wav_name)


In [7]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    """SOMOS items as ``(audio_path, transcript, mos)`` triples.

    Reads the MOS table from ``csv_file`` (first column = utterance id, column
    ``new_scale`` = normalised MOS), loads tab-separated transcripts, and keeps either
    the train or the held-out part of a seeded ``train_test_split``.

    Args:
        csv_file: CSV whose first column is the utterance id and which has ``new_scale``.
        transcript_file: text file of ``<id><TAB><text>`` lines.
        base_dir: directory containing ``<id>.wav`` files.
        split: ``"train"`` selects the training part; any other value selects the
            held-out (validation) part.
        test_size, seed: forwarded to ``train_test_split``. Two instances form
            complementary partitions only when both use the same values.
    """

    def __init__(self, csv_file, transcript_file, base_dir="archive/update_SOMOS_v2/update_SOMOS_v2/all_audios/all_wavs", split="train", test_size=0.2, seed=42):
        self.df = pd.read_csv(csv_file)
        self.transcripts = self.load_transcripts(transcript_file)
        self.base_dir = base_dir

        # Use the new MOS scale
        self.df["mos"] = self.df["new_scale"]

        # Split data into train and validation sets (deterministic for a fixed seed)
        train_df, val_df = train_test_split(self.df, test_size=test_size, random_state=seed)
        self.df = train_df if split == "train" else val_df

    def load_transcripts(self, transcript_file):
        """Return ``{audio_id: text}`` built from tab-separated lines.

        Lines that do not have exactly two fields after stripping are skipped.
        """
        transcripts = {}
        with open(transcript_file, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                if len(parts) == 2:
                    transcripts[parts[0]] = parts[1]
        return transcripts

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(audio_path, transcript, mos)`` for position ``idx`` of this split.

        The transcript is ``""`` when the id has no transcript line. ``mos`` is a
        0-d float32 tensor.
        """
        # Read the id straight from the first column. `self.df.iloc[idx]` builds a row
        # Series with a single common dtype, so with all-numeric columns an integer id
        # such as 123 came back as 123.0 ("123.0.wav", missed transcript lookup).
        # str() makes the id match the string keys of self.transcripts.
        file_name = str(self.df.iloc[idx, 0])
        mos = torch.tensor(self.df["mos"].iloc[idx], dtype=torch.float)

        # Get text from transcript
        text = self.transcripts.get(file_name, "")

        # Build audio path
        audio_path = os.path.join(self.base_dir, f"{file_name}.wav")

        return audio_path, text, mos




class AttentionPooling(torch.nn.Module):
    """Learned attention pooling over axis 1 of the input.

    A linear layer scores every position, the scores are softmax-normalised over
    dim 1, and the inputs are summed with those weights.

    Args:
        embed_dim: size of the last (feature) axis of the input.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.attention = torch.nn.Linear(embed_dim, 1)

    def forward(self, x):
        # x presumably (batch, seq_len, embed_dim); pooling happens over dim=1.
        scores = self.attention(x)
        weights = torch.softmax(scores, dim=1)
        return torch.sum(x * weights, dim=1)


# Attention pooling sized for Whisper base's 512-dim encoder embeddings.
# NOTE(review): `attn_pool` is not referenced by `collate_fn` (which mean-pools the
# encoder output) and its parameters are not given to any optimizer in this notebook,
# so it is currently unused and untrained.
attn_pool = AttentionPooling(embed_dim=512).to(device)



def collate_fn(batch):
    """Turn ``(audio_path, text, mos)`` items into embeddings on ``device``.

    Audio: each clip is loaded, padded/trimmed to Whisper's fixed input length,
    converted to a log-mel spectrogram, encoded with ``whisper_model.encoder`` and
    mean-pooled over time (dim 1).
    Text: BERT-tokenised (max 128 tokens); the first ([CLS]) token of the last hidden
    state is kept.

    Because both encoders run here on ``device``, the DataLoaders using this function
    are expected to keep the default in-process loading (num_workers=0).

    Returns:
        (audio_embeddings, text_embeddings, labels), all on ``device``.
    """
    audio_paths, texts, labels = zip(*batch)

    # Build the spectrogram batch on the CPU and copy it to the device once.
    mel_spectrograms = torch.stack([
        whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(path)))
        for path in audio_paths
    ]).to(device)

    # No autograd graph: these embeddings are fed to the sklearn weak learners (NumPy),
    # and Whisper's parameters are not in any optimizer, so a graph here could never
    # be back-propagated and would only cost memory.
    # NOTE(review): the mean also covers the zero-padded tail added by pad_or_trim;
    # masking to each clip's real length may give better features.
    with torch.no_grad():
        audio_embeddings = whisper_model.encoder(mel_spectrograms).mean(dim=1)

    # Process texts using BERT
    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        text_embeddings = bert_model(**inputs).last_hidden_state[:, 0, :]

    labels = torch.stack(labels).to(device)
    return audio_embeddings, text_embeddings, labels



In [8]:
class WeakLearners(nn.Module):
    """Ridge, SVR and decision-tree regressors over concatenated audio+text embeddings.

    Wrapped as an ``nn.Module`` so it can be a submodule of ``SSLEnsembleModel``, but
    the sklearn regressors are plain attributes. ``.to(device)`` does not move them,
    ``state_dict()`` does not contain them, and their predictions carry no gradient.

    Args:
        audio_dim, text_dim: stored for reference; not used by ``fit``/``forward``.
        device: where ``forward`` puts its predictions. ``None`` (default) uses the
            device of the ``audio_emb`` passed to ``forward``; pass e.g. ``"cuda:1"`` to
            force a specific device. (The previous hard-coded ``"cuda:1"`` default
            crashed on machines without that GPU, since ``.to()`` does not update it.)
    """

    def __init__(self, audio_dim, text_dim, device=None):
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        self.fitted = False

    def fit(self, train_loader):
        """Fit all three regressors on every batch yielded by ``train_loader``.

        Each batch is ``(audio_emb, text_emb, labels)``. Features are
        ``[audio_emb | text_emb]`` concatenated along the feature axis.

        NOTE(review): if the meta-learner is later trained on the same data, it sees
        in-sample predictions (an unbounded-depth decision tree can reproduce its
        training labels almost exactly). Out-of-fold predictions (e.g.
        sklearn.model_selection.cross_val_predict) avoid this stacking leakage.

        Raises:
            RuntimeError: if the loader yields no batches.
        """
        print("Fitting weak learners...")

        audio_chunks, text_chunks, label_chunks = [], [], []
        for audio_emb, text_emb, labels in tqdm(train_loader, desc="Processing embeddings", unit="batch"):
            audio_chunks.append(audio_emb.detach().cpu().numpy())
            text_chunks.append(text_emb.detach().cpu().numpy())
            label_chunks.append(labels.detach().cpu().numpy())

        if not audio_chunks:
            raise RuntimeError("No embeddings found in the dataset! Check if the train_loader is correctly loading data.")

        features = np.hstack((np.vstack(audio_chunks), np.vstack(text_chunks)))
        targets = np.hstack(label_chunks)

        print("Training weak learners...")
        for regressor, name in zip([self.ridge_regressor, self.svr, self.dtr],
                                   ["Ridge Regression", "SVR", "Decision Tree"]):
            with tqdm(total=1, desc=f"Training {name}", unit="step") as pbar:
                regressor.fit(features, targets)
                pbar.update(1)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        """Return ``(ridge_pred, svr_pred, dtr_pred)``, each a float32 tensor of shape ``(batch,)``.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        # sklearn works on NumPy; no torch autograd is involved, so no no_grad() is needed.
        features = torch.cat([audio_emb, text_emb], dim=1).detach().cpu().numpy()
        target_device = audio_emb.device if self.device is None else self.device

        return tuple(
            torch.from_numpy(regressor.predict(features)).float().to(target_device)
            for regressor in (self.ridge_regressor, self.svr, self.dtr)
        )


In [9]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    """Two-layer MLP that maps the weak learners' predictions to a single score.

    Input's last axis has size ``weak_output_dim``; output's last axis has size 1.
    """

    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super().__init__()
        # The names fc1/fc2 appear in saved state_dict keys; keep them stable.
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, weak_outputs):
        hidden = torch.relu(self.fc1(weak_outputs))
        return self.fc2(hidden)

In [10]:
# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    """Stacking ensemble: fitted weak learners feed a small MLP meta-learner.

    Args:
        audio_dim, text_dim: accepted for interface compatibility; not used here.
        hidden_dim: hidden width of the StackingMetaLearner.
        weak_learners: an object exposing ``fitted`` and returning three per-sample
            prediction tensors when called (a WeakLearners instance). Required.

    Raises:
        ValueError: if ``weak_learners`` is None.
    """

    def __init__(self, audio_dim, text_dim, hidden_dim=256, weak_learners=None):
        super().__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")

        self.weak_learners = weak_learners
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        """Stack the three weak predictions column-wise and pass them to the meta-learner.

        Raises:
            RuntimeError: if the weak learners have not been fitted.
        """
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        # One column per weak learner, in (ridge, svr, dtr) order.
        stacked = torch.stack(self.weak_learners(audio_emb, text_emb), dim=1)
        return self.stacking_meta_learner(stacked)


In [11]:
import numpy as np
from sklearn.metrics import mean_squared_error
from scipy.stats import kendalltau

def evaluate(model, test_loader, device):
    """Run ``model`` over ``test_loader`` without gradients and report MOS metrics.

    Args:
        model: callable ``model(audio_emb, text_emb)`` returning predictions of shape
            ``(batch, 1)`` or ``(batch,)``.
        test_loader: iterable of ``(audio_emb, text_emb, labels)`` batches.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        ``(mse, rmse, lcc, k_tau, accuracy)``. ``accuracy`` is the fraction of
        predictions within ±0.5 of the label, ``lcc`` is Pearson's r, and ``k_tau`` is
        Kendall's tau.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    pred_batches, label_batches = [], []

    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc="Evaluation", leave=False)
        for audio_emb, text_emb, labels in test_pbar:
            audio_emb, text_emb, labels = audio_emb.to(device), text_emb.to(device), labels.to(device)
            outputs = model(audio_emb, text_emb)
            # Flatten (batch, 1) -> (batch,). A bare squeeze() turns a final batch of one
            # into a 0-d tensor, which cannot be extended into a list or sliced below.
            preds = outputs.reshape(-1).cpu().numpy()
            batch_labels = labels.reshape(-1).cpu().numpy()

            pred_batches.append(preds)
            label_batches.append(batch_labels)

            test_pbar.set_postfix({
                "predicted": [f"{p:.2f}" for p in preds[:5]],
                "ground_truth": [f"{l:.2f}" for l in batch_labels[:5]]
            })

    if not pred_batches:
        raise ValueError("evaluate() received an empty test_loader; no metrics can be computed.")

    all_preds = np.concatenate(pred_batches)
    all_labels = np.concatenate(label_batches)

    # Accuracy (up to ±0.5)
    accuracy = np.mean(np.abs(all_preds - all_labels) <= 0.5)

    # MSE and RMSE
    mse = mean_squared_error(all_labels, all_preds)
    rmse = np.sqrt(mse)

    # LCC (Linear Correlation Coefficient)
    lcc = np.corrcoef(all_labels, all_preds)[0, 1]

    # KTAU (Kendall's Tau)
    k_tau, _ = kendalltau(all_labels, all_preds)

    # Print metrics
    print(f"Accuracy (±0.5): {accuracy*100:.2f}%")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"LCC: {lcc:.4f}")
    print(f"KTAU: {k_tau:.4f}")

    # Show up to 5 examples (fewer if the evaluation set is smaller).
    print("\n5 Examples of Predicted and Ground Truth MOS:")
    for pred, truth in zip(all_preds[:5], all_labels[:5]):
        print(f"Pred: {pred:.2f}, GT: {truth:.2f}")

    return mse, rmse, lcc, k_tau, accuracy


In [12]:
# --- Main Training Loop ---
def main():
    """Train the SOMOS MOS-prediction stacking ensemble end to end.

    Builds train/val loaders over ``archive/normalised_somos.csv``, fits the sklearn
    weak learners on the training split, then trains the stacking meta-learner for
    ``num_epochs`` epochs with MSE loss. After each epoch it evaluates on the
    validation split, logs metrics to wandb, and saves the ``state_dict`` with the
    lowest validation MSE to ``best_model.pth``.
    """
    train_csv = "archive/normalised_somos.csv"
    transcript_file = "archive/all_transcripts.txt"

    train_dataset = SOMOSDataset(train_csv, transcript_file, split="train")
    val_dataset = SOMOSDataset(train_csv, transcript_file, split="val")

    # collate_fn runs Whisper/BERT on `device`, so keep the default num_workers=0.
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    # Inspect one batch. It also supplies the embedding sizes, so no second batch has
    # to be encoded for that.
    audio_embeddings, text_embeddings, labels = next(iter(train_loader))
    print("Audio Embeddings Shape:", audio_embeddings.shape)  # Should be (batch_size, embed_dim)
    print("Text Embeddings Shape:", text_embeddings.shape)  # Should be (batch_size, embed_dim)
    print("Labels:", labels)  # Check if MOS labels are correctly loaded

    # Feature sizes are the last axis; len() of a (batch, dim) tensor is the batch size.
    audio_dim = audio_embeddings.shape[-1]
    text_dim = text_embeddings.shape[-1]

    # NOTE(review): the weak learners are fitted on the same training split that the
    # meta-learner is trained on below, so the meta-learner sees in-sample predictions
    # (stacking leakage). Out-of-fold predictions would avoid this.
    weak_learners = WeakLearners(audio_dim, text_dim).to(device)
    weak_learners.fit(train_loader)

    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256, weak_learners=weak_learners).to(device)

    wandb.watch(model, log="all", log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=1e-6)
    criterion = nn.MSELoss()

    num_epochs = 20
    best_mse = float('inf')

    for epoch in range(num_epochs):
        # NOTE(review): every epoch re-encodes all audio through Whisper in collate_fn.
        # No optimizer updates whisper_model, so the embeddings are presumably identical
        # each epoch; caching them once would make epochs far cheaper.
        model.train()
        running_loss, total_samples = 0.0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()
            outputs = model(audio_emb, text_emb)
            # (batch, 1) -> (batch,) so the shape matches `labels` even for a final batch
            # of one (squeeze() would give a 0-d tensor and trigger MSELoss broadcasting).
            loss = criterion(outputs.reshape(-1), labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)
            wandb.log({"train_loss": loss.item()})
            train_pbar.set_postfix(loss=loss.item())

        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        # Validation
        val_mse, val_rmse, val_lcc, val_k_tau, val_acc = evaluate(model, val_loader, device)
        wandb.log({"val_mse": val_mse, "val_rmse": val_rmse, "val_lcc": val_lcc, "val_k_tau": val_k_tau, "val_accuracy": val_acc})

        if val_mse < best_mse:
            best_mse = val_mse
            # NOTE(review): state_dict() holds only the meta-learner weights. The sklearn
            # weak learners are plain attributes and are not saved, so reloading this
            # checkpoint requires refitting (or separately pickling) `weak_learners`.
            torch.save(model.state_dict(), "best_model.pth")

    print("Training complete! Best validation MSE:", best_mse)




In [13]:
# Run the full pipeline: weak-learner fitting followed by `num_epochs` of meta-learner
# training with per-epoch validation. In the logged run, one pass over the training
# loader ("Processing embeddings") took about 17 minutes.
main()

Audio Embeddings Shape: torch.Size([4, 512])
Text Embeddings Shape: torch.Size([4, 768])
Labels: tensor([3.7345, 2.3077, 3.7401, 2.9805], device='cuda:1')
Fitting weak learners...


Processing embeddings: 100%|████████████████████████████████████████████████████| 4020/4020 [16:44<00:00,  4.00batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.10s/step]
Training SVR: 100%|███████████████████████████████████████████████████████████████████| 1/1 [02:42<00:00, 162.78s/step]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:32<00:00, 32.67s/step]


Weak learners training completed.


                                                                                                                       

Epoch 1/20 - Train MSE: 2.2564


                                                                                                                       

Accuracy (±0.5): 9.38%
MSE: 1.2943
RMSE: 1.1377
LCC: 0.5228
KTAU: 0.3398

5 Examples of Predicted and Ground Truth MOS:
Pred: 2.35, GT: 3.48
Pred: 2.10, GT: 3.36
Pred: 2.17, GT: 3.40
Pred: 1.88, GT: 3.62
Pred: 2.31, GT: 3.71


                                                                                                                       

Epoch 2/20 - Train MSE: 0.6233


                                                                                                                       

Accuracy (±0.5): 58.96%
MSE: 0.3167
RMSE: 0.5628
LCC: 0.5196
KTAU: 0.3373

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.10, GT: 3.48
Pred: 2.76, GT: 3.36
Pred: 2.86, GT: 3.40
Pred: 2.49, GT: 3.62
Pred: 3.04, GT: 3.71


                                                                                                                       

Epoch 3/20 - Train MSE: 0.1023


                                                                                                                       

Accuracy (±0.5): 76.82%
MSE: 0.1769
RMSE: 0.4206
LCC: 0.5118
KTAU: 0.3314

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.51, GT: 3.48
Pred: 3.12, GT: 3.36
Pred: 3.24, GT: 3.40
Pred: 2.82, GT: 3.62
Pred: 3.45, GT: 3.71


                                                                                                                       

Epoch 4/20 - Train MSE: 0.0579


                                                                                                                       

Accuracy (±0.5): 75.90%
MSE: 0.1827
RMSE: 0.4274
LCC: 0.4960
KTAU: 0.3195

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.53, GT: 3.48
Pred: 3.12, GT: 3.36
Pred: 3.26, GT: 3.40
Pred: 2.83, GT: 3.62
Pred: 3.48, GT: 3.71


                                                                                                                       

Epoch 5/20 - Train MSE: 0.0510


                                                                                                                       

Accuracy (±0.5): 75.27%
MSE: 0.1893
RMSE: 0.4350
LCC: 0.4790
KTAU: 0.3072

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.54, GT: 3.48
Pred: 3.09, GT: 3.36
Pred: 3.26, GT: 3.40
Pred: 2.83, GT: 3.62
Pred: 3.49, GT: 3.71


                                                                                                                       

KeyboardInterrupt: 