In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm
import torch
import torchaudio
import torchaudio.transforms as T
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from whisper import load_model
import whisper

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device


device(type='cuda', index=1)

In [3]:
# --- Load and Unfreeze Whisper‑medium ---
whisper_model = whisper.load_model("base.en").to(device)
for param in whisper_model.parameters():
    param.requires_grad = True
whisper_model.train()


Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-5): 6 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=False)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (out): Linear(in_features=512, out_features=512, bias=True)
        )
        (attn_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (mlp_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((512,), eps=1e-05,

In [4]:
# --- Initialize wandb ---
wandb.init(project="somos-ensemble2-ssl", name="finetune-whisper_s_no_text")
!wandb online

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


W&B online. Running your script from this directory will now sync to the cloud.


In [5]:
# --- Utility Functions ---
def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

In [6]:

# --- Dataset Class (No Text) ---
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir
        self.labels = [float(sample["mos"]) for sample in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        label = torch.tensor(float(sample["mos"]), dtype=torch.float)
        audio_path = process_audio_path(sample["clean path"], self.base_dir)
        return audio_path, label


class AttentionPooling(torch.nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.attention = torch.nn.Linear(embed_dim, 1)

    def forward(self, x):
        weights = torch.nn.functional.softmax(self.attention(x), dim=1)  
        return (weights * x).sum(dim=1)  


attn_pool = AttentionPooling(embed_dim=512).to(device)  # Whisper Base uses 512-dim embeddings


def collate_fn(batch):
    audio_paths, labels = zip(*batch)
    audios = [whisper.load_audio(path) for path in audio_paths]
    audios = [whisper.pad_or_trim(audio) for audio in audios]
    mel_spectrograms = [whisper.log_mel_spectrogram(audio).to(device) for audio in audios]
    mel_spectrograms = torch.stack(mel_spectrograms)

    # Compute audio embeddings with gradients enabled
    audio_embeddings = whisper_model.encoder(mel_spectrograms).mean(dim=1)

    labels = torch.stack(labels).to(device)
    return audio_embeddings, labels

In [7]:
class WeakLearners(nn.Module):
    def __init__(self, audio_dim, device="cuda:1"):
        super(WeakLearners, self).__init__()
        self.audio_dim = audio_dim
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        self.fitted = False

    def fit(self, train_loader):
        """ Train weak learners using train dataset embeddings """
        print("Fitting weak learners...")

        all_audio_emb, all_labels = [], []

        for audio_emb, labels in tqdm(train_loader, desc="Processing embeddings", unit="batch"):
            all_audio_emb.append(audio_emb.cpu().detach().numpy())
            all_labels.append(labels.cpu().detach().numpy())

        if not all_audio_emb or not all_labels:
            raise RuntimeError("No embeddings found in the dataset! Check if the train_loader is correctly loading data.")

        all_audio_emb = np.vstack(all_audio_emb)
        all_labels = np.hstack(all_labels)

        print("Training weak learners...")
        for model, name in zip([self.ridge_regressor, self.svr, self.dtr], 
                               ["Ridge Regression", "SVR", "Decision Tree"]):
            with tqdm(total=1, desc=f"Training {name}", unit="step") as pbar:
                model.fit(all_audio_emb, all_labels)
                pbar.update(1)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb):
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        audio_emb = audio_emb.cpu().detach().numpy()

        with torch.no_grad():
            ridge_pred = self.ridge_regressor.predict(audio_emb)
            svr_pred = self.svr.predict(audio_emb)
            dtr_pred = self.dtr.predict(audio_emb)

        ridge_pred = torch.from_numpy(ridge_pred).float().to(self.device)
        svr_pred = torch.from_numpy(svr_pred).float().to(self.device)
        dtr_pred = torch.from_numpy(dtr_pred).float().to(self.device)

        return ridge_pred, svr_pred, dtr_pred


In [8]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super(StackingMetaLearner, self).__init__()
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, weak_outputs):
        x = F.relu(self.fc1(weak_outputs))
        return self.fc2(x)

In [9]:
# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    def __init__(self, audio_dim, hidden_dim=256, weak_learners=None):
        super(SSLEnsembleModel, self).__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")
        
        self.weak_learners = weak_learners
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb):
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")
        
        # Only use audio embeddings for weak learners
        ridge_pred, svr_pred, dtr_pred = self.weak_learners(audio_emb)

        weak_outputs = torch.stack([ridge_pred, svr_pred, dtr_pred], dim=1)

        final_output = self.stacking_meta_learner(weak_outputs)
        return final_output


In [10]:
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, kendalltau
from tqdm import tqdm

def evaluate(model, test_loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc="Evaluation", leave=False)
        for audio_emb, labels in test_pbar:
            audio_emb, labels = audio_emb.to(device), labels.to(device)
            outputs = model(audio_emb)
            preds = outputs.squeeze()
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            test_pbar.set_postfix({"predicted": preds[:5].cpu().numpy(), "ground_truth": labels[:5].cpu().numpy()})

    # Convert lists to numpy arrays for easier calculation
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Accuracy (up to ±0.5)
    accuracy = np.mean(np.abs(all_preds - all_labels) <= 0.5)

    # MSE and RMSE
    mse = mean_squared_error(all_labels, all_preds)
    rmse = np.sqrt(mse)

    # LCC (Linear Correlation Coefficient)
    lcc = np.corrcoef(all_labels, all_preds)[0, 1]

    # KTAU (Kendall's Tau)
    k_tau, _ = kendalltau(all_labels, all_preds)

    # Print metrics
    print(f"Accuracy (±0.5): {accuracy*100:.2f}%")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"LCC: {lcc:.4f}")
    print(f"KTAU: {k_tau:.4f}")

    # Show 5 examples of predicted and ground truth MOS
    print("\n5 Examples of Predicted and Ground Truth MOS:")
    for i in range(5):
        print(f"Pred: {all_preds[i]:.2f}, GT: {all_labels[i]:.2f}")

In [11]:
# --- Main Training Loop ---
def main():
    train_json = "data/somos/audios/train_new.json"
    test_json = "data/somos/audios/test_new.json"

    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    dummy_audio, _ = next(iter(train_loader))
    audio_dim = dummy_audio.shape[1]
    
    weak_learners = WeakLearners(audio_dim).to(device)
    weak_learners.fit(train_loader)
    
    model = SSLEnsembleModel(audio_dim, hidden_dim=256, weak_learners=weak_learners).to(device)

    wandb.watch(model, log="all", log_freq=100)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

    criterion = torch.nn.MSELoss()
    
    num_epochs = 20
    best_mse = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss, total_samples = 0.0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, labels in train_pbar:
            optimizer.zero_grad()

            outputs = model(audio_emb)
            loss = criterion(outputs.squeeze(), labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * audio_emb.size(0)
            total_samples += labels.size(0)
            wandb.log({"train_loss": loss.item()})
            train_pbar.set_postfix(loss=loss.item())

        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        model.eval()
        test_loss, total_samples = 0.0, 0
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, labels in test_pbar:
                outputs = model(audio_emb)
                loss = criterion(outputs.squeeze(), labels)
                test_loss += loss.item() * audio_emb.size(0)
                total_samples += labels.size(0)
                test_pbar.set_postfix(loss=loss.item())

        test_mse = test_loss / total_samples
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        # Evaluate on validation set
        evaluate(model, test_loader, device)  # Call the evaluation function after each epoch

        if test_mse < best_mse:
            best_mse = test_mse
            torch.save(model.state_dict(), "best_model.pth")

    print("Training complete! Best validation MSE:", best_mse)

In [12]:
main()

Fitting weak learners...


Processing embeddings: 100%|████████████████████████████████████████████████████| 3525/3525 [14:30<00:00,  4.05batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.21step/s]
Training SVR: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.45s/step]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.58s/step]


Weak learners training completed.


                                                                                                                       

Epoch 1/20 - Train MSE: 12.8441


                                                                                                                       

Epoch 1/20 - Val MSE: 10.2577


                                                                                                                       

Accuracy (±0.5): 0.00%
MSE: 10.2577
RMSE: 3.2028
LCC: 0.0297
KTAU: 0.0296

5 Examples of Predicted and Ground Truth MOS:
Pred: 0.09, GT: 4.00
Pred: 0.07, GT: 4.00
Pred: 0.07, GT: 3.73
Pred: 0.09, GT: 3.40
Pred: 0.08, GT: 3.00


                                                                                                                       

Epoch 2/20 - Train MSE: 8.0394


                                                                                                                       

Epoch 2/20 - Val MSE: 6.1099


                                                                                                                       

Accuracy (±0.5): 0.03%
MSE: 6.1099
RMSE: 2.4718
LCC: 0.5064
KTAU: 0.3437

5 Examples of Predicted and Ground Truth MOS:
Pred: 0.90, GT: 4.00
Pred: 0.91, GT: 4.00
Pred: 0.86, GT: 3.73
Pred: 0.90, GT: 3.40
Pred: 0.80, GT: 3.00


                                                                                                                       

Epoch 3/20 - Train MSE: 4.4579


                                                                                                                       

Epoch 3/20 - Val MSE: 3.1233


                                                                                                                       

Accuracy (±0.5): 1.10%
MSE: 3.1233
RMSE: 1.7673
LCC: 0.5469
KTAU: 0.3726

5 Examples of Predicted and Ground Truth MOS:
Pred: 1.69, GT: 4.00
Pred: 1.74, GT: 4.00
Pred: 1.63, GT: 3.73
Pred: 1.68, GT: 3.40
Pred: 1.51, GT: 3.00


                                                                                                                       

Epoch 4/20 - Train MSE: 2.0106


                                                                                                                       

Epoch 4/20 - Val MSE: 1.2452


                                                                                                                       

Accuracy (±0.5): 14.90%
MSE: 1.2452
RMSE: 1.1159
LCC: 0.5591
KTAU: 0.3814

5 Examples of Predicted and Ground Truth MOS:
Pred: 2.44, GT: 4.00
Pred: 2.53, GT: 4.00
Pred: 2.36, GT: 3.73
Pred: 2.44, GT: 3.40
Pred: 2.18, GT: 3.00


                                                                                                                       

Epoch 5/20 - Train MSE: 0.6326


                                                                                                                       

Epoch 5/20 - Val MSE: 0.3804


                                                                                                                       

Accuracy (±0.5): 53.67%
MSE: 0.3804
RMSE: 0.6168
LCC: 0.5624
KTAU: 0.3839

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.11, GT: 4.00
Pred: 3.22, GT: 4.00
Pred: 3.01, GT: 3.73
Pred: 3.11, GT: 3.40
Pred: 2.78, GT: 3.00


                                                                                                                       

Epoch 6/20 - Train MSE: 0.1477


                                                                                                                       

Epoch 6/20 - Val MSE: 0.2243


                                                                                                                       

Accuracy (±0.5): 70.70%
MSE: 0.2243
RMSE: 0.4736
LCC: 0.5547
KTAU: 0.3776

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.52, GT: 4.00
Pred: 3.63, GT: 4.00
Pred: 3.40, GT: 3.73
Pred: 3.53, GT: 3.40
Pred: 3.14, GT: 3.00


                                                                                                                       

Epoch 7/20 - Train MSE: 0.0886


                                                                                                                       

Epoch 7/20 - Val MSE: 0.2307


                                                                                                                       

Accuracy (±0.5): 70.23%
MSE: 0.2307
RMSE: 0.4803
LCC: 0.5368
KTAU: 0.3637

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.57, GT: 4.00
Pred: 3.67, GT: 4.00
Pred: 3.44, GT: 3.73
Pred: 3.60, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 8/20 - Train MSE: 0.0778


                                                                                                                       

Epoch 8/20 - Val MSE: 0.2393


                                                                                                                       

Accuracy (±0.5): 69.17%
MSE: 0.2393
RMSE: 0.4892
LCC: 0.5166
KTAU: 0.3482

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.56, GT: 4.00
Pred: 3.64, GT: 4.00
Pred: 3.43, GT: 3.73
Pred: 3.63, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 9/20 - Train MSE: 0.0678


                                                                                                                       

Epoch 9/20 - Val MSE: 0.2488


                                                                                                                       

Accuracy (±0.5): 68.17%
MSE: 0.2488
RMSE: 0.4988
LCC: 0.4961
KTAU: 0.3328

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.56, GT: 4.00
Pred: 3.62, GT: 4.00
Pred: 3.41, GT: 3.73
Pred: 3.66, GT: 3.40
Pred: 3.17, GT: 3.00


                                                                                                                       

Epoch 10/20 - Train MSE: 0.0585


                                                                                                                       

Epoch 10/20 - Val MSE: 0.2598


                                                                                                                       

Accuracy (±0.5): 67.00%
MSE: 0.2598
RMSE: 0.5097
LCC: 0.4755
KTAU: 0.3176

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.57, GT: 4.00
Pred: 3.61, GT: 4.00
Pred: 3.41, GT: 3.73
Pred: 3.70, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 11/20 - Train MSE: 0.0497


                                                                                                                       

Epoch 11/20 - Val MSE: 0.2716


                                                                                                                       

Accuracy (±0.5): 66.00%
MSE: 0.2716
RMSE: 0.5212
LCC: 0.4547
KTAU: 0.3028

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.57, GT: 4.00
Pred: 3.59, GT: 4.00
Pred: 3.40, GT: 3.73
Pred: 3.73, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 12/20 - Train MSE: 0.0416


                                                                                                                       

Epoch 12/20 - Val MSE: 0.2844


                                                                                                                       

Accuracy (±0.5): 64.77%
MSE: 0.2844
RMSE: 0.5333
LCC: 0.4345
KTAU: 0.2884

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.57, GT: 4.00
Pred: 3.57, GT: 4.00
Pred: 3.39, GT: 3.73
Pred: 3.77, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 13/20 - Train MSE: 0.0343


                                                                                                                       

Epoch 13/20 - Val MSE: 0.2982


                                                                                                                       

Accuracy (±0.5): 64.00%
MSE: 0.2982
RMSE: 0.5461
LCC: 0.4153
KTAU: 0.2748

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.55, GT: 4.00
Pred: 3.38, GT: 3.73
Pred: 3.81, GT: 3.40
Pred: 3.19, GT: 3.00


                                                                                                                       

Epoch 14/20 - Train MSE: 0.0278


                                                                                                                       

Epoch 14/20 - Val MSE: 0.3128


                                                                                                                       

Accuracy (±0.5): 62.20%
MSE: 0.3128
RMSE: 0.5593
LCC: 0.3964
KTAU: 0.2616

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.52, GT: 4.00
Pred: 3.37, GT: 3.73
Pred: 3.84, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 15/20 - Train MSE: 0.0220


                                                                                                                       

Epoch 15/20 - Val MSE: 0.3282


                                                                                                                       

Accuracy (±0.5): 61.20%
MSE: 0.3282
RMSE: 0.5729
LCC: 0.3784
KTAU: 0.2492

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.57, GT: 4.00
Pred: 3.50, GT: 4.00
Pred: 3.35, GT: 3.73
Pred: 3.86, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 16/20 - Train MSE: 0.0170


                                                                                                                       

Epoch 16/20 - Val MSE: 0.3445


                                                                                                                       

Accuracy (±0.5): 59.63%
MSE: 0.3445
RMSE: 0.5869
LCC: 0.3619
KTAU: 0.2380

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.49, GT: 4.00
Pred: 3.34, GT: 3.73
Pred: 3.90, GT: 3.40
Pred: 3.18, GT: 3.00


                                                                                                                       

Epoch 17/20 - Train MSE: 0.0126


                                                                                                                       

Epoch 17/20 - Val MSE: 0.3616


                                                                                                                       

Accuracy (±0.5): 58.60%
MSE: 0.3616
RMSE: 0.6013
LCC: 0.3464
KTAU: 0.2276

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.47, GT: 4.00
Pred: 3.34, GT: 3.73
Pred: 3.94, GT: 3.40
Pred: 3.19, GT: 3.00


                                                                                                                       

Epoch 18/20 - Train MSE: 0.0090


                                                                                                                       

Epoch 18/20 - Val MSE: 0.3791


                                                                                                                       

Accuracy (±0.5): 57.40%
MSE: 0.3791
RMSE: 0.6157
LCC: 0.3319
KTAU: 0.2181

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.45, GT: 4.00
Pred: 3.33, GT: 3.73
Pred: 3.98, GT: 3.40
Pred: 3.19, GT: 3.00


                                                                                                                       

Epoch 19/20 - Train MSE: 0.0061


                                                                                                                       

Epoch 19/20 - Val MSE: 0.3971


                                                                                                                       

Accuracy (±0.5): 56.03%
MSE: 0.3971
RMSE: 0.6302
LCC: 0.3184
KTAU: 0.2095

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.58, GT: 4.00
Pred: 3.44, GT: 4.00
Pred: 3.32, GT: 3.73
Pred: 4.01, GT: 3.40
Pred: 3.19, GT: 3.00


                                                                                                                       

Epoch 20/20 - Train MSE: 0.0038


                                                                                                                       

Epoch 20/20 - Val MSE: 0.4156


                                                                                                                       

Accuracy (±0.5): 55.00%
MSE: 0.4156
RMSE: 0.6447
LCC: 0.3061
KTAU: 0.2018

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.59, GT: 4.00
Pred: 3.42, GT: 4.00
Pred: 3.32, GT: 3.73
Pred: 4.05, GT: 3.40
Pred: 3.19, GT: 3.00
Training complete! Best validation MSE: 0.22432248597902557
