In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm
import torch
import torchaudio
import torchaudio.transforms as T
from sentence_transformers import SentenceTransformer


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [3]:
text_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)


In [4]:
# --- Initialize wandb ---
wandb.init(project="somos-ensemble2-ssl", name="finetune-only_weak")
!wandb online

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


W&B online. Running your script from this directory will now sync to the cloud.


In [5]:
# --- Utility Functions ---
def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

In [6]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir
        self.labels = [float(sample["mos"]) for sample in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        label = torch.tensor(float(sample["mos"]), dtype=torch.float)
        audio_path = process_audio_path(sample["clean path"], self.base_dir)
        return audio_path, text, label

def collate_fn(batch):
    audio_paths, texts, labels = zip(*batch)
    
    waveform_list = [torchaudio.load(path)[0] for path in audio_paths]
    resample = T.Resample(orig_freq=16000, new_freq=16000)
    waveform_list = [resample(waveform) for waveform in waveform_list]

    mel_transform = T.MelSpectrogram(sample_rate=16000, n_mels=80)
    mel_spectrograms = [mel_transform(waveform).squeeze(0) for waveform in waveform_list]

    max_length = max(mel.shape[1] for mel in mel_spectrograms)

    def pad_to_max(mel, max_len):
        pad_size = max_len - mel.shape[1]
        return F.pad(mel, (0, pad_size))

    mel_spectrograms = torch.stack([pad_to_max(mel, max_length) for mel in mel_spectrograms]).to(device)

    audio_embeddings = mel_spectrograms.mean(dim=-1)
    
    text_embeddings = text_model.encode(list(texts), convert_to_tensor=True).to(device)

    labels = torch.stack(labels).to(device)
    return audio_embeddings, text_embeddings, labels

In [7]:
class WeakLearners(nn.Module):
    def __init__(self, audio_dim, text_dim, device="cuda"):
        super(WeakLearners, self).__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        self.fitted = False

    def fit(self, train_loader):
        """ Train weak learners using train dataset embeddings """
        print("Fitting weak learners...")

        all_audio_emb, all_text_emb, all_labels = [], [], []

        for audio_emb, text_emb, labels in tqdm(train_loader, desc="Processing embeddings", unit="batch"):
            all_audio_emb.append(audio_emb.cpu().detach().numpy())
            all_text_emb.append(text_emb.cpu().detach().numpy())
            all_labels.append(labels.cpu().detach().numpy())

        if not all_audio_emb or not all_text_emb or not all_labels:
            raise RuntimeError("No embeddings found in the dataset! Check if the train_loader is correctly loading data.")

        all_audio_emb = np.vstack(all_audio_emb)
        all_text_emb = np.vstack(all_text_emb)
        all_labels = np.hstack(all_labels)

        combined_embeddings = np.hstack((all_audio_emb, all_text_emb))

        print("Training weak learners...")
        for model, name in zip([self.ridge_regressor, self.svr, self.dtr], 
                               ["Ridge Regression", "SVR", "Decision Tree"]):
            with tqdm(total=1, desc=f"Training {name}", unit="step") as pbar:
                model.fit(combined_embeddings, all_labels)
                pbar.update(1)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        combined_embeddings = torch.cat([audio_emb, text_emb], dim=1).cpu().detach().numpy()

        with torch.no_grad():
            ridge_pred = self.ridge_regressor.predict(combined_embeddings)
            svr_pred = self.svr.predict(combined_embeddings)
            dtr_pred = self.dtr.predict(combined_embeddings)

        ridge_pred = torch.from_numpy(ridge_pred).float().to(self.device)
        svr_pred = torch.from_numpy(svr_pred).float().to(self.device)
        dtr_pred = torch.from_numpy(dtr_pred).float().to(self.device)

        return ridge_pred, svr_pred, dtr_pred


In [8]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super(StackingMetaLearner, self).__init__()
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, weak_outputs):
        x = F.relu(self.fc1(weak_outputs))
        return self.fc2(x)

In [9]:
# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, weak_learners=None):
        super(SSLEnsembleModel, self).__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")
        
        self.weak_learners = weak_learners
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")
        
        ridge_pred, svr_pred, dtr_pred = self.weak_learners(audio_emb, text_emb)

        weak_outputs = torch.stack([ridge_pred, svr_pred, dtr_pred], dim=1)

        final_output = self.stacking_meta_learner(weak_outputs)
        return final_output


In [10]:
# --- Main Training Loop ---
def main():
    train_json = "data/somos/audios/train_new.json"
    test_json = "data/somos/audios/test_new.json"

    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]
    
    weak_learners = WeakLearners(audio_dim, text_dim).to(device)
    weak_learners.fit(train_loader)
    
    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256, weak_learners=weak_learners).to(device)

    wandb.watch(model, log="all", log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=1e-6)

    criterion = nn.MSELoss()
    
    num_epochs = 20
    best_mse = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss, total_samples = 0.0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()

            outputs = model(audio_emb, text_emb)
            loss = criterion(outputs.squeeze(), labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * audio_emb.size(0)
            total_samples += labels.size(0)
            wandb.log({"train_loss": loss.item()})
            train_pbar.set_postfix(loss=loss.item())

        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        model.eval()
        test_loss, total_samples = 0.0, 0
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, text_emb, labels in test_pbar:
                outputs = model(audio_emb, text_emb)
                loss = criterion(outputs.squeeze(), labels)
                test_loss += loss.item() * audio_emb.size(0)
                total_samples += labels.size(0)
                test_pbar.set_postfix(loss=loss.item())

        test_mse = test_loss / total_samples
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        if test_mse < best_mse:
            best_mse = test_mse
            torch.save(model.state_dict(), "best_model.pth")

    print("Training complete! Best validation MSE:", best_mse)

In [11]:
main()

Fitting weak learners...


Processing embeddings: 100%|████████████████████████████████████████████████████| 3525/3525 [00:45<00:00, 77.78batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.75step/s]
Training SVR: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:33<00:00, 33.12s/step]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.83s/step]


Weak learners training completed.


                                                                                                                       

Epoch 1/10 - Train MSE: 22.1980


                                                                                                                       

Epoch 1/10 - Val MSE: 18.3550


                                                                                                                       

Epoch 2/10 - Train MSE: 15.1797


                                                                                                                       

Epoch 2/10 - Val MSE: 12.1435


                                                                                                                       

Epoch 3/10 - Train MSE: 9.6047


                                                                                                                       

Epoch 3/10 - Val MSE: 7.3068


                                                                                                                       

Epoch 4/10 - Train MSE: 5.3898


                                                                                                                       

Epoch 4/10 - Val MSE: 3.7940


                                                                                                                       

Epoch 5/10 - Train MSE: 2.4978


                                                                                                                       

Epoch 5/10 - Val MSE: 1.5721


                                                                                                                       

Epoch 6/10 - Train MSE: 0.8698


                                                                                                                       

Epoch 6/10 - Val MSE: 0.5328


                                                                                                                       

Epoch 7/10 - Train MSE: 0.2958


                                                                                                                       

Epoch 7/10 - Val MSE: 0.3256


                                                                                                                       

Epoch 8/10 - Train MSE: 0.2292


                                                                                                                       

Epoch 8/10 - Val MSE: 0.3221


                                                                                                                       

Epoch 9/10 - Train MSE: 0.2307


                                                                                                                       

Epoch 9/10 - Val MSE: 0.3230


                                                                                                                       

Epoch 10/10 - Train MSE: 0.2301


                                                                                                                       

Epoch 10/10 - Val MSE: 0.3239
Training complete! Best validation MSE: 0.32213694201037285


In [12]:
from scipy.stats import pearsonr, kendalltau

def load_model(model_path, audio_dim, text_dim, weak_learners):
    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256, weak_learners=weak_learners).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

def evaluate(model, test_loader):
    all_preds, all_labels = [], []

    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc="Evaluating", leave=False)
        for audio_emb, text_emb, labels in test_pbar:
            outputs = model(audio_emb, text_emb)
            preds = outputs.squeeze().cpu().numpy()
            labels = labels.cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    mse = np.mean((all_preds - all_labels) ** 2)
    rmse = np.sqrt(mse)
    lcc, _ = pearsonr(all_preds, all_labels)
    ktau, _ = kendalltau(all_preds, all_labels)

    accuracy = np.mean(np.abs(all_preds - all_labels) <= 0.5)

    print("\nEvaluation Results:")
    print(f"Accuracy (±0.5): {accuracy * 100:.2f}%")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"LCC: {lcc:.4f}")
    print(f"Kendall’s Tau: {ktau:.4f}\n")

    print("Sample Predictions (Predicted MOS vs Ground Truth MOS):")
    for i in range(min(30, len(all_preds))):
        print(f"Predicted: {all_preds[i]:.2f} | Ground Truth: {all_labels[i]:.2f}")

test_json = "data/somos/audios/test_new.json"
test_dataset = SOMOSDataset(test_json)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

dummy_audio, dummy_text, _ = next(iter(test_loader))
audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]

weak_learners = WeakLearners(audio_dim, text_dim).to(device)
weak_learners.fit(test_loader)

model = load_model("best_model.pth", audio_dim, text_dim, weak_learners)
evaluate(model, test_loader)

Fitting weak learners...


Processing embeddings: 100%|██████████████████████████████████████████████████████| 750/750 [00:09<00:00, 82.19batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 70.51step/s]
Training SVR: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.79step/s]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.20s/step]


Weak learners training completed.


                                                                                                                       


Evaluation Results:
Accuracy (±0.5): 85.00%
MSE: 0.1209
RMSE: 0.3477
LCC: 0.9177
Kendall’s Tau: 0.7628

Sample Predictions (Predicted MOS vs Ground Truth MOS):
Predicted: 3.64 | Ground Truth: 4.00
Predicted: 3.58 | Ground Truth: 4.00
Predicted: 3.49 | Ground Truth: 3.73
Predicted: 3.29 | Ground Truth: 3.40
Predicted: 3.07 | Ground Truth: 3.00
Predicted: 3.51 | Ground Truth: 3.89
Predicted: 3.33 | Ground Truth: 3.33
Predicted: 2.97 | Ground Truth: 2.50
Predicted: 3.52 | Ground Truth: 3.93
Predicted: 3.58 | Ground Truth: 3.90
Predicted: 3.52 | Ground Truth: 4.00
Predicted: 3.50 | Ground Truth: 3.90
Predicted: 3.17 | Ground Truth: 3.00
Predicted: 3.35 | Ground Truth: 3.82
Predicted: 3.11 | Ground Truth: 2.67
Predicted: 3.08 | Ground Truth: 2.75
Predicted: 3.33 | Ground Truth: 3.64
Predicted: 3.11 | Ground Truth: 2.38
Predicted: 3.38 | Ground Truth: 3.56
Predicted: 3.47 | Ground Truth: 3.67
Predicted: 3.44 | Ground Truth: 3.78
Predicted: 3.13 | Ground Truth: 3.00
Predicted: 3.39 | Ground 