In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm
import torch
import torchaudio
import torchaudio.transforms as T
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
from whisper import load_model
import whisper

In [2]:
# Pick the compute device. The second GPU ("cuda:1") is preferred, as before, but
# only when it actually exists: single-GPU hosts fall back to "cuda:0" instead of
# failing later with an invalid device ordinal, and CPU-only hosts use "cpu".
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")
device


device(type='cuda', index=1)

In [3]:
# --- Load Whisper ("base.en" checkpoint) with every parameter trainable ---
whisper_model = whisper.load_model("base.en")
whisper_model = whisper_model.to(device)
whisper_model.requires_grad_(True)  # sets requires_grad on all parameters
whisper_model.train()

# --- BERT tokenizer and encoder; the encoder is switched to eval mode ---
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model = bert_model.to(device).eval()

In [4]:
# --- Initialize wandb ---
wandb.init(project="somos-ensemble2-ssl", name="finetune-whisper+bert")
!wandb online

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


W&B online. Running your script from this directory will now sync to the cloud.


In [5]:
# --- Utility Functions ---
def load_json(filepath):
    """Read the UTF-8 encoded JSON file at ``filepath`` and return the parsed object."""
    with open(filepath, mode="r", encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    """Join ``clean_path`` onto ``base_dir`` after turning its backslashes into ``/``."""
    relative_path = "/".join(clean_path.split("\\"))
    return os.path.join(base_dir, relative_path)

In [6]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    """Dataset over the entries of a JSON manifest read with ``load_json``.

    Every entry is accessed through its ``"text"``, ``"mos"`` and ``"clean path"``
    keys. Indexing returns ``(audio_path, text, label)``; the audio file itself is
    not opened here.
    """

    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.base_dir = base_dir
        self.samples = load_json(json_file)
        # MOS of every entry as a plain Python float, in manifest order.
        self.labels = list(map(lambda entry: float(entry["mos"]), self.samples))

    def __len__(self):
        """Number of manifest entries."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the resolved audio path, the text and the MOS tensor of entry ``idx``."""
        entry = self.samples[idx]
        transcript = entry["text"]
        mos = torch.tensor(float(entry["mos"]), dtype=torch.float)
        wav_path = process_audio_path(entry["clean path"], self.base_dir)
        return wav_path, transcript, mos



class AttentionPooling(torch.nn.Module):
    """Pool along dim 1 using softmax weights produced by a single linear scorer."""

    def __init__(self, embed_dim):
        super().__init__()
        # Maps every embed_dim-sized vector to one unnormalised score.
        self.attention = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return the softmax-weighted sum of ``x`` over dim 1."""
        scores = self.attention(x)
        weights = F.softmax(scores, dim=1)  # normalise the scores across dim 1
        weighted = weights * x
        return weighted.sum(dim=1)


# Whisper Base uses 512-dim embeddings, hence embed_dim=512.
# NOTE(review): no code visible in this notebook calls attn_pool; collate_fn
# averages the encoder output over dim 1 instead. Confirm whether it is needed.
attn_pool = AttentionPooling(embed_dim=512)
attn_pool = attn_pool.to(device)



def collate_fn(batch):
    """Turn a batch of ``(audio_path, text, label)`` triples into embedding tensors.

    Returns ``(audio_embeddings, text_embeddings, labels)``:
      * audio: Whisper encoder output averaged over dim 1, computed with autograd on;
      * text: BERT ``last_hidden_state`` at position 0, computed under ``no_grad``;
      * labels: the per-sample label tensors stacked and moved to ``device``.
    """
    paths, sentences, targets = zip(*batch)

    # Audio: load -> pad/trim -> log-mel -> device, one spectrogram per file.
    mel_batch = []
    for wav_path in paths:
        waveform = whisper.pad_or_trim(whisper.load_audio(wav_path))
        mel_batch.append(whisper.log_mel_spectrogram(waveform).to(device))
    mel_batch = torch.stack(mel_batch)

    # NOTE(review): gradient tracking is left on here, but every consumer visible
    # in this notebook detaches these embeddings before using them - confirm
    # whether tracking is still wanted.
    encoder_states = whisper_model.encoder(mel_batch)
    audio_embeddings = encoder_states.mean(dim=1)

    # Text: tokenize the whole batch at once and keep the position-0 hidden state.
    encoded = tokenizer(list(sentences), return_tensors="pt", padding=True, truncation=True, max_length=128)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        bert_output = bert_model(**encoded)
        text_embeddings = bert_output.last_hidden_state[:, 0, :]

    stacked_targets = torch.stack(targets).to(device)
    return audio_embeddings, text_embeddings, stacked_targets



In [7]:
class WeakLearners(nn.Module):
    """Three scikit-learn regressors (Ridge, SVR, decision tree) behind an ``nn.Module`` interface.

    The regressors are fitted once on concatenated ``[audio | text]`` embeddings and
    afterwards only used for prediction. They are not differentiable: inputs are
    detached and converted to NumPy, so no gradient flows back through this module.

    Args:
        audio_dim: Audio embedding width (stored on the instance, not otherwise used here).
        text_dim: Text embedding width (stored on the instance, not otherwise used here).
        device: Kept for backward compatibility and stored on ``self.device``.
            Predictions are returned on the device of the incoming embeddings
            (see ``forward``), so this value does not have to name an existing device.
    """

    def __init__(self, audio_dim, text_dim, device="cuda:1"):
        super().__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        # Flipped to True by fit(); forward() refuses to run before that.
        self.fitted = False

    def fit(self, train_loader):
        """Fit all three regressors on every batch produced by ``train_loader``.

        Args:
            train_loader: Iterable yielding ``(audio_emb, text_emb, labels)`` tensor triples.

        Raises:
            RuntimeError: If ``train_loader`` yields no batches.
        """
        print("Fitting weak learners...")

        audio_chunks, text_chunks, label_chunks = [], [], []

        # Every tensor is detached right away, so no autograd graph is needed
        # while the loader produces the batches.
        with torch.no_grad():
            for audio_emb, text_emb, labels in tqdm(train_loader, desc="Processing embeddings", unit="batch"):
                # detach() first so the device-to-host copy is not recorded by autograd.
                audio_chunks.append(audio_emb.detach().cpu().numpy())
                text_chunks.append(text_emb.detach().cpu().numpy())
                label_chunks.append(labels.detach().cpu().numpy())

        if not audio_chunks or not text_chunks or not label_chunks:
            raise RuntimeError("No embeddings found in the dataset! Check if the train_loader is correctly loading data.")

        # One row per sample: audio features followed by text features.
        combined_embeddings = np.hstack((np.vstack(audio_chunks), np.vstack(text_chunks)))
        all_labels = np.hstack(label_chunks)

        print("Training weak learners...")
        for model, name in zip([self.ridge_regressor, self.svr, self.dtr],
                               ["Ridge Regression", "SVR", "Decision Tree"]):
            with tqdm(total=1, desc=f"Training {name}", unit="step") as pbar:
                model.fit(combined_embeddings, all_labels)
                pbar.update(1)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        """Predict with each regressor on the concatenated embeddings.

        Args:
            audio_emb: 2-D tensor of audio embeddings, one row per sample.
            text_emb: 2-D tensor of text embeddings with the same number of rows.

        Returns:
            Tuple ``(ridge_pred, svr_pred, dtr_pred)`` of float32 tensors placed on
            the device of ``audio_emb``.

        Raises:
            RuntimeError: If ``fit`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        # Follow the inputs' device instead of the stored default ("cuda:1"),
        # which does not exist on CPU-only or single-GPU hosts.
        out_device = audio_emb.device
        combined_embeddings = torch.cat([audio_emb, text_emb], dim=1).detach().cpu().numpy()

        def _to_tensor(prediction):
            return torch.from_numpy(np.asarray(prediction)).float().to(out_device)

        ridge_pred = _to_tensor(self.ridge_regressor.predict(combined_embeddings))
        svr_pred = _to_tensor(self.svr.predict(combined_embeddings))
        dtr_pred = _to_tensor(self.dtr.predict(combined_embeddings))

        return ridge_pred, svr_pred, dtr_pred


In [8]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    """Two-layer perceptron mapping the weak learners' outputs to one value per sample."""

    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(in_features=weak_output_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, weak_outputs):
        """Apply ``fc1``, a ReLU and then ``fc2`` to ``weak_outputs``."""
        hidden = torch.relu(self.fc1(weak_outputs))
        prediction = self.fc2(hidden)
        return prediction

In [9]:
# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    """Stacked ensemble: weak learners followed by a ``StackingMetaLearner`` head.

    ``audio_dim`` and ``text_dim`` are accepted by the constructor but not used by
    this class.
    """

    def __init__(self, audio_dim, text_dim, hidden_dim=256, weak_learners=None):
        super().__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")

        self.weak_learners = weak_learners
        # The head always receives three inputs: one prediction per weak learner.
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        """Stack the three weak-learner predictions along dim 1 and feed them to the head."""
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        base_predictions = self.weak_learners(audio_emb, text_emb)
        ridge_pred, svr_pred, dtr_pred = base_predictions
        stacked = torch.stack((ridge_pred, svr_pred, dtr_pred), dim=1)
        return self.stacking_meta_learner(stacked)


In [10]:
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, kendalltau

def evaluate(model, test_loader, device):
    """Run ``model`` over ``test_loader``, print regression metrics and return them.

    Args:
        model: Callable module taking ``(audio_emb, text_emb)`` and returning one
            prediction per sample.
        test_loader: Iterable yielding ``(audio_emb, text_emb, labels)`` tensor triples.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        dict with the float entries ``"accuracy"`` (share of predictions within
        0.5 of the label), ``"mse"``, ``"rmse"``, ``"lcc"`` (Pearson correlation)
        and ``"ktau"`` (Kendall's tau).

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc="Evaluation", leave=False)
        for audio_emb, text_emb, labels in test_pbar:
            audio_emb, text_emb, labels = audio_emb.to(device), text_emb.to(device), labels.to(device)
            outputs = model(audio_emb, text_emb)
            # Flatten to 1-D explicitly. A bare squeeze() would turn a one-sample
            # batch into a 0-d tensor, which can be neither iterated nor sliced.
            preds = outputs.reshape(-1)
            labels = labels.reshape(-1)

            pred_chunks.append(preds.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

            test_pbar.set_postfix({"predicted": preds[:5].cpu().numpy(), "ground_truth": labels[:5].cpu().numpy()})

    if not pred_chunks:
        raise ValueError("evaluate() received no samples from test_loader.")

    # One flat array per quantity for the metric computations below.
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    # Accuracy (up to ±0.5)
    accuracy = np.mean(np.abs(all_preds - all_labels) <= 0.5)

    # MSE and RMSE
    mse = mean_squared_error(all_labels, all_preds)
    rmse = np.sqrt(mse)

    # LCC (Linear Correlation Coefficient)
    lcc = np.corrcoef(all_labels, all_preds)[0, 1]

    # KTAU (Kendall's Tau)
    k_tau, _ = kendalltau(all_labels, all_preds)

    # Print metrics
    print(f"Accuracy (±0.5): {accuracy*100:.2f}%")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"LCC: {lcc:.4f}")
    print(f"KTAU: {k_tau:.4f}")

    # Show up to 5 examples of predicted and ground truth MOS (fewer when the
    # loader produced fewer samples).
    print("\n5 Examples of Predicted and Ground Truth MOS:")
    for pred, truth in zip(all_preds[:5], all_labels[:5]):
        print(f"Pred: {pred:.2f}, GT: {truth:.2f}")

    return {
        "accuracy": float(accuracy),
        "mse": float(mse),
        "rmse": float(rmse),
        "lcc": float(lcc),
        "ktau": float(k_tau),
    }



In [11]:
# --- Main Training Loop ---
def _run_epoch(model, loader, criterion, desc, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    With an ``optimizer`` the pass trains ``model`` and logs every batch loss to
    wandb as ``"train_loss"``; without one it only measures the loss with
    gradient tracking disabled.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, n_samples = 0.0, 0

    with torch.set_grad_enabled(training):
        pbar = tqdm(loader, desc=desc, leave=False)
        for audio_emb, text_emb, labels in pbar:
            if training:
                optimizer.zero_grad()

            outputs = model(audio_emb, text_emb)
            # squeeze(-1) drops only the trailing singleton dim, so a one-sample
            # batch stays 1-D and keeps the same shape as ``labels``.
            loss = criterion(outputs.squeeze(-1), labels)

            if training:
                loss.backward()
                optimizer.step()
                wandb.log({"train_loss": loss.item()})

            batch_samples = labels.size(0)
            loss_sum += loss.item() * batch_samples
            n_samples += batch_samples
            pbar.set_postfix(loss=loss.item())

    if n_samples == 0:
        raise ValueError(f"No samples were produced by the loader during '{desc}'.")
    return loss_sum / n_samples


def main(
    train_json="data/somos/audios/train_new.json",
    test_json="data/somos/audios/test_new.json",
    batch_size=4,
    num_epochs=20,
    learning_rate=1e-6,
    checkpoint_path="best_model.pth",
):
    """Fit the weak learners, then train and evaluate the stacking meta-learner.

    All arguments default to the values that were previously hard-coded, so
    ``main()`` behaves as before.

    Args:
        train_json: Manifest used to fit the weak learners and train the meta-learner.
        test_json: Manifest used for the per-epoch validation pass and ``evaluate``.
        batch_size: Batch size of both data loaders.
        num_epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        checkpoint_path: Where the ``state_dict`` of the best model is written.
    """
    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Pull one batch only to read off the embedding widths.
    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]

    # Hand the active device to the weak learners so their predictions land on
    # the same device as the batches instead of the class default ("cuda:1").
    weak_learners = WeakLearners(audio_dim, text_dim, device=device).to(device)
    weak_learners.fit(train_loader)

    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256, weak_learners=weak_learners).to(device)

    wandb.watch(model, log="all", log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    best_mse = float('inf')

    for epoch in range(num_epochs):
        train_mse = _run_epoch(model, train_loader, criterion, f"Epoch {epoch+1} Training", optimizer)
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        test_mse = _run_epoch(model, test_loader, criterion, f"Epoch {epoch+1} Validation")
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        # Detailed metrics on the same loader. NOTE(review): evaluate() iterates
        # test_loader again, so every epoch runs the validation data twice.
        evaluate(model, test_loader, device)

        if test_mse < best_mse:
            best_mse = test_mse
            # NOTE(review): state_dict() holds only nn.Module parameters and
            # buffers. The fitted scikit-learn regressors inside weak_learners
            # are plain attributes and are not saved, so this file alone cannot
            # restore the full ensemble.
            torch.save(model.state_dict(), checkpoint_path)

    print("Training complete! Best validation MSE:", best_mse)


In [12]:
# Run the whole pipeline defined in main(): fit the weak learners, then train
# and evaluate the stacking meta-learner.
main()

Fitting weak learners...


Processing embeddings: 100%|████████████████████████████████████████████████████| 3525/3525 [13:19<00:00,  4.41batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.83step/s]
Training SVR: 100%|████████████████████████████████████████████████████████████████████| 1/1 [01:23<00:00, 83.14s/step]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:19<00:00, 19.38s/step]


Weak learners training completed.


                                                                                                                       

Epoch 1/20 - Train MSE: 3.0637


                                                                                                                       

Epoch 1/20 - Val MSE: 1.7922


                                                                                                                       

Accuracy (±0.5): 6.33%
MSE: 1.7922
RMSE: 1.3387
LCC: 0.6218
KTAU: 0.4352

5 Examples of Predicted and Ground Truth MOS:
Pred: 2.28, GT: 4.00
Pred: 2.40, GT: 4.00
Pred: 2.26, GT: 3.73
Pred: 2.17, GT: 3.40
Pred: 2.01, GT: 3.00


                                                                                                                       

Epoch 2/20 - Train MSE: 0.9554


                                                                                                                       

Epoch 2/20 - Val MSE: 0.4727


                                                                                                                       

Accuracy (±0.5): 44.73%
MSE: 0.4727
RMSE: 0.6875
LCC: 0.6089
KTAU: 0.4234

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.15, GT: 4.00
Pred: 3.31, GT: 4.00
Pred: 3.15, GT: 3.73
Pred: 3.03, GT: 3.40
Pred: 2.80, GT: 3.00


                                                                                                                       

Epoch 3/20 - Train MSE: 0.1894


                                                                                                                       

Epoch 3/20 - Val MSE: 0.2090


                                                                                                                       

Accuracy (±0.5): 72.70%
MSE: 0.2090
RMSE: 0.4572
LCC: 0.5992
KTAU: 0.4148

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.70, GT: 4.00
Pred: 3.88, GT: 4.00
Pred: 3.72, GT: 3.73
Pred: 3.59, GT: 3.40
Pred: 3.30, GT: 3.00


                                                                                                                       

Epoch 4/20 - Train MSE: 0.0936


                                                                                                                       

Epoch 4/20 - Val MSE: 0.2113


                                                                                                                       

Accuracy (±0.5): 73.33%
MSE: 0.2113
RMSE: 0.4597
LCC: 0.5883
KTAU: 0.4056

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.80, GT: 4.00
Pred: 3.97, GT: 4.00
Pred: 3.83, GT: 3.73
Pred: 3.70, GT: 3.40
Pred: 3.40, GT: 3.00


                                                                                                                       

KeyboardInterrupt: 