In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from transformers import Wav2Vec2Model, HubertModel, WavLMModel
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import Ridge
from sklearn.svm import SVR
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
import numpy as np
import wandb
from collections import Counter
from tqdm import tqdm
import torch
import torchaudio
import torchaudio.transforms as T
from sentence_transformers import SentenceTransformer


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [3]:
text_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)


In [4]:
# --- Initialize wandb ---
wandb.init(project="somos-ensemble2-ssl", name="finetune-only_weak")
!wandb online

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


W&B online. Running your script from this directory will now sync to the cloud.


In [5]:
# --- Utility Functions ---
def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

In [6]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir
        self.labels = [float(sample["mos"]) for sample in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        label = torch.tensor(float(sample["mos"]), dtype=torch.float)
        audio_path = process_audio_path(sample["clean path"], self.base_dir)
        return audio_path, text, label

def collate_fn(batch):
    audio_paths, texts, labels = zip(*batch)
    
    waveform_list = [torchaudio.load(path)[0] for path in audio_paths]
    resample = T.Resample(orig_freq=16000, new_freq=16000)
    waveform_list = [resample(waveform) for waveform in waveform_list]

    mel_transform = T.MelSpectrogram(sample_rate=16000, n_mels=80)
    mel_spectrograms = [mel_transform(waveform).squeeze(0) for waveform in waveform_list]

    max_length = max(mel.shape[1] for mel in mel_spectrograms)

    def pad_to_max(mel, max_len):
        pad_size = max_len - mel.shape[1]
        return F.pad(mel, (0, pad_size))

    mel_spectrograms = torch.stack([pad_to_max(mel, max_length) for mel in mel_spectrograms]).to(device)

    audio_embeddings = mel_spectrograms.mean(dim=-1)
    
    text_embeddings = text_model.encode(list(texts), convert_to_tensor=True).to(device)

    labels = torch.stack(labels).to(device)
    return audio_embeddings, text_embeddings, labels

In [7]:
class WeakLearners(nn.Module):
    def __init__(self, audio_dim, text_dim, device="cuda"):
        super(WeakLearners, self).__init__()
        self.audio_dim = audio_dim
        self.text_dim = text_dim
        self.device = device

        self.ridge_regressor = Ridge(alpha=1.0)
        self.svr = SVR()
        self.dtr = DecisionTreeRegressor()

        self.fitted = False

    def fit(self, train_loader):
        """ Train weak learners using train dataset embeddings """
        print("Fitting weak learners...")

        all_audio_emb, all_text_emb, all_labels = [], [], []

        for audio_emb, text_emb, labels in tqdm(train_loader, desc="Processing embeddings", unit="batch"):
            all_audio_emb.append(audio_emb.cpu().detach().numpy())
            all_text_emb.append(text_emb.cpu().detach().numpy())
            all_labels.append(labels.cpu().detach().numpy())

        if not all_audio_emb or not all_text_emb or not all_labels:
            raise RuntimeError("No embeddings found in the dataset! Check if the train_loader is correctly loading data.")

        all_audio_emb = np.vstack(all_audio_emb)
        all_text_emb = np.vstack(all_text_emb)
        all_labels = np.hstack(all_labels)

        combined_embeddings = np.hstack((all_audio_emb, all_text_emb))

        print("Training weak learners...")
        for model, name in zip([self.ridge_regressor, self.svr, self.dtr], 
                               ["Ridge Regression", "SVR", "Decision Tree"]):
            with tqdm(total=1, desc=f"Training {name}", unit="step") as pbar:
                model.fit(combined_embeddings, all_labels)
                pbar.update(1)

        self.fitted = True
        print("Weak learners training completed.")

    def forward(self, audio_emb, text_emb):
        if not self.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")

        combined_embeddings = torch.cat([audio_emb, text_emb], dim=1).cpu().detach().numpy()

        with torch.no_grad():
            ridge_pred = self.ridge_regressor.predict(combined_embeddings)
            svr_pred = self.svr.predict(combined_embeddings)
            dtr_pred = self.dtr.predict(combined_embeddings)

        ridge_pred = torch.from_numpy(ridge_pred).float().to(self.device)
        svr_pred = torch.from_numpy(svr_pred).float().to(self.device)
        dtr_pred = torch.from_numpy(dtr_pred).float().to(self.device)

        return ridge_pred, svr_pred, dtr_pred


In [8]:
# --- Stacking Model (Meta-Learner) ---
class StackingMetaLearner(nn.Module):
    def __init__(self, weak_output_dim=3, hidden_dim=256):
        super(StackingMetaLearner, self).__init__()
        self.fc1 = nn.Linear(weak_output_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, weak_outputs):
        x = F.relu(self.fc1(weak_outputs))
        return self.fc2(x)

In [9]:
# --- Main Model ---
class SSLEnsembleModel(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim=256, weak_learners=None):
        super(SSLEnsembleModel, self).__init__()
        if weak_learners is None:
            raise ValueError("Weak learners must be provided and fitted before initializing SSLEnsembleModel.")
        
        self.weak_learners = weak_learners
        self.stacking_meta_learner = StackingMetaLearner(weak_output_dim=3, hidden_dim=hidden_dim)

    def forward(self, audio_emb, text_emb):
        if not self.weak_learners.fitted:
            raise RuntimeError("Weak learners have not been fitted. Call 'fit()' before using the model.")
        
        ridge_pred, svr_pred, dtr_pred = self.weak_learners(audio_emb, text_emb)

        weak_outputs = torch.stack([ridge_pred, svr_pred, dtr_pred], dim=1)

        final_output = self.stacking_meta_learner(weak_outputs)
        return final_output


In [10]:
import numpy as np
import torch
from sklearn.metrics import mean_squared_error
from scipy.stats import spearmanr, kendalltau

def evaluate(model, test_loader, device):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc="Evaluation", leave=False)
        for audio_emb, text_emb, labels in test_pbar:
            audio_emb, text_emb, labels = audio_emb.to(device), text_emb.to(device), labels.to(device)
            outputs = model(audio_emb, text_emb)
            preds = outputs.squeeze()
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
            test_pbar.set_postfix({"predicted": preds[:5].cpu().numpy(), "ground_truth": labels[:5].cpu().numpy()})

    # Convert lists to numpy arrays for easier calculation
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Accuracy (up to ±0.5)
    accuracy = np.mean(np.abs(all_preds - all_labels) <= 0.5)

    # MSE and RMSE
    mse = mean_squared_error(all_labels, all_preds)
    rmse = np.sqrt(mse)

    # LCC (Linear Correlation Coefficient)
    lcc = np.corrcoef(all_labels, all_preds)[0, 1]

    # KTAU (Kendall's Tau)
    k_tau, _ = kendalltau(all_labels, all_preds)

    # Print metrics
    print(f"Accuracy (±0.5): {accuracy*100:.2f}%")
    print(f"MSE: {mse:.4f}")
    print(f"RMSE: {rmse:.4f}")
    print(f"LCC: {lcc:.4f}")
    print(f"KTAU: {k_tau:.4f}")

    # Show 30 examples of predicted and ground truth MOS
    print("\n5 Examples of Predicted and Ground Truth MOS:")
    for i in range(5):
        print(f"Pred: {all_preds[i]:.2f}, GT: {all_labels[i]:.2f}")



In [11]:
# --- Main Training Loop ---
def main():
    train_json = "data/somos/audios/train_new.json"
    test_json = "data/somos/audios/test_new.json"

    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]
    
    weak_learners = WeakLearners(audio_dim, text_dim).to(device)
    weak_learners.fit(train_loader)
    
    model = SSLEnsembleModel(audio_dim, text_dim, hidden_dim=256, weak_learners=weak_learners).to(device)

    wandb.watch(model, log="all", log_freq=100)

    optimizer = optim.Adam(model.parameters(), lr=1e-6)

    criterion = nn.MSELoss()
    
    num_epochs = 20
    best_mse = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss, total_samples = 0.0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()

            outputs = model(audio_emb, text_emb)
            loss = criterion(outputs.squeeze(), labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * audio_emb.size(0)
            total_samples += labels.size(0)
            wandb.log({"train_loss": loss.item()})
            train_pbar.set_postfix(loss=loss.item())

        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        model.eval()
        test_loss, total_samples = 0.0, 0
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, text_emb, labels in test_pbar:
                outputs = model(audio_emb, text_emb)
                loss = criterion(outputs.squeeze(), labels)
                test_loss += loss.item() * audio_emb.size(0)
                total_samples += labels.size(0)
                test_pbar.set_postfix(loss=loss.item())

        test_mse = test_loss / total_samples
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        # Evaluate on validation set
        evaluate(model, test_loader, device)  # Call the evaluation function after each epoch

        if test_mse < best_mse:
            best_mse = test_mse
            torch.save(model.state_dict(), "best_model.pth")

    print("Training complete! Best validation MSE:", best_mse)


In [None]:
main()

Fitting weak learners...


Processing embeddings: 100%|████████████████████████████████████████████████████| 3525/3525 [00:46<00:00, 76.60batch/s]


Training weak learners...


Training Ridge Regression: 100%|███████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.20step/s]
Training SVR: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.07s/step]
Training Decision Tree: 100%|██████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.82s/step]


Weak learners training completed.


                                                                                                                       

Epoch 1/20 - Train MSE: 6.7148


                                                                                                                       

Epoch 1/20 - Val MSE: 4.8345


                                                                                                                       

Accuracy (±0.5): 0.33%
MSE: 4.8345
RMSE: 2.1988
LCC: 0.0845
KTAU: 0.0536

5 Examples of Predicted and Ground Truth MOS:
Pred: 1.45, GT: 4.00
Pred: 1.35, GT: 4.00
Pred: 1.01, GT: 3.73
Pred: 1.03, GT: 3.40
Pred: 0.88, GT: 3.00


                                                                                                                       

Epoch 2/20 - Train MSE: 3.3051


                                                                                                                       

Epoch 2/20 - Val MSE: 2.1716


                                                                                                                       

Accuracy (±0.5): 7.60%
MSE: 2.1716
RMSE: 1.4736
LCC: 0.1186
KTAU: 0.0753

5 Examples of Predicted and Ground Truth MOS:
Pred: 2.35, GT: 4.00
Pred: 2.21, GT: 4.00
Pred: 1.75, GT: 3.73
Pred: 1.81, GT: 3.40
Pred: 1.58, GT: 3.00


                                                                                                                       

Epoch 3/20 - Train MSE: 1.2614


                                                                                                                       

Epoch 3/20 - Val MSE: 0.7753


                                                                                                                       

Accuracy (±0.5): 34.63%
MSE: 0.7753
RMSE: 0.8805
LCC: 0.1361
KTAU: 0.0864

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.16, GT: 4.00
Pred: 3.00, GT: 4.00
Pred: 2.43, GT: 3.73
Pred: 2.51, GT: 3.40
Pred: 2.22, GT: 3.00


                                                                                                                       

Epoch 4/20 - Train MSE: 0.3971


                                                                                                                       

Epoch 4/20 - Val MSE: 0.3739


                                                                                                                       

Accuracy (±0.5): 55.63%
MSE: 0.3739
RMSE: 0.6115
LCC: 0.1450
KTAU: 0.0920

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.75, GT: 4.00
Pred: 3.57, GT: 4.00
Pred: 2.91, GT: 3.73
Pred: 3.02, GT: 3.40
Pred: 2.69, GT: 3.00


                                                                                                                       

Epoch 5/20 - Train MSE: 0.2395


                                                                                                                       

Epoch 5/20 - Val MSE: 0.3529


                                                                                                                       

Accuracy (±0.5): 58.73%
MSE: 0.3529
RMSE: 0.5940
LCC: 0.1476
KTAU: 0.0938

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.90, GT: 4.00
Pred: 3.72, GT: 4.00
Pred: 3.05, GT: 3.73
Pred: 3.16, GT: 3.40
Pred: 2.81, GT: 3.00


                                                                                                                       

Epoch 6/20 - Train MSE: 0.2400


                                                                                                                       

Epoch 6/20 - Val MSE: 0.3519


                                                                                                                       

Accuracy (±0.5): 58.80%
MSE: 0.3519
RMSE: 0.5932
LCC: 0.1486
KTAU: 0.0944

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.89, GT: 4.00
Pred: 3.71, GT: 4.00
Pred: 3.05, GT: 3.73
Pred: 3.16, GT: 3.40
Pred: 2.81, GT: 3.00


                                                                                                                       

Epoch 7/20 - Train MSE: 0.2367


                                                                                                                       

Epoch 7/20 - Val MSE: 0.3509


                                                                                                                       

Accuracy (±0.5): 58.83%
MSE: 0.3509
RMSE: 0.5924
LCC: 0.1495
KTAU: 0.0950

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.89, GT: 4.00
Pred: 3.72, GT: 4.00
Pred: 3.05, GT: 3.73
Pred: 3.17, GT: 3.40
Pred: 2.82, GT: 3.00


                                                                                                                       

Epoch 8/20 - Train MSE: 0.2416


                                                                                                                       

Epoch 8/20 - Val MSE: 0.3499


                                                                                                                       

Accuracy (±0.5): 58.73%
MSE: 0.3499
RMSE: 0.5915
LCC: 0.1505
KTAU: 0.0957

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.88, GT: 4.00
Pred: 3.71, GT: 4.00
Pred: 3.05, GT: 3.73
Pred: 3.16, GT: 3.40
Pred: 2.82, GT: 3.00


                                                                                                                       

Epoch 9/20 - Train MSE: 0.2361


                                                                                                                       

Epoch 9/20 - Val MSE: 0.3491


                                                                                                                       

Accuracy (±0.5): 59.13%
MSE: 0.3491
RMSE: 0.5908
LCC: 0.1514
KTAU: 0.0962

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.89, GT: 4.00
Pred: 3.72, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 10/20 - Train MSE: 0.2342


                                                                                                                       

Epoch 10/20 - Val MSE: 0.3487


                                                                                                                       

Accuracy (±0.5): 59.03%
MSE: 0.3487
RMSE: 0.5905
LCC: 0.1518
KTAU: 0.0965

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.88, GT: 4.00
Pred: 3.71, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.17, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 11/20 - Train MSE: 0.2409


                                                                                                                       

Epoch 11/20 - Val MSE: 0.3478


                                                                                                                       

Accuracy (±0.5): 59.17%
MSE: 0.3478
RMSE: 0.5898
LCC: 0.1528
KTAU: 0.0972

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.88, GT: 4.00
Pred: 3.71, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 12/20 - Train MSE: 0.2385


                                                                                                                       

Epoch 12/20 - Val MSE: 0.3472


                                                                                                                       

Accuracy (±0.5): 59.27%
MSE: 0.3472
RMSE: 0.5892
LCC: 0.1535
KTAU: 0.0977

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.88, GT: 4.00
Pred: 3.71, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 13/20 - Train MSE: 0.2409


                                                                                                                       

Epoch 13/20 - Val MSE: 0.3463


                                                                                                                       

Accuracy (±0.5): 59.33%
MSE: 0.3463
RMSE: 0.5885
LCC: 0.1544
KTAU: 0.0982

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.87, GT: 4.00
Pred: 3.70, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 14/20 - Train MSE: 0.2357


                                                                                                                       

Epoch 14/20 - Val MSE: 0.3460


                                                                                                                       

Accuracy (±0.5): 59.23%
MSE: 0.3460
RMSE: 0.5882
LCC: 0.1549
KTAU: 0.0985

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.86, GT: 4.00
Pred: 3.70, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 15/20 - Train MSE: 0.2379


                                                                                                                       

Epoch 15/20 - Val MSE: 0.3455


                                                                                                                       

Accuracy (±0.5): 59.27%
MSE: 0.3455
RMSE: 0.5878
LCC: 0.1554
KTAU: 0.0989

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.86, GT: 4.00
Pred: 3.69, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.83, GT: 3.00


                                                                                                                       

Epoch 16/20 - Train MSE: 0.2337


                                                                                                                       

Epoch 16/20 - Val MSE: 0.3453


                                                                                                                       

Accuracy (±0.5): 59.40%
MSE: 0.3453
RMSE: 0.5876
LCC: 0.1557
KTAU: 0.0991

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.86, GT: 4.00
Pred: 3.70, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.18, GT: 3.40
Pred: 2.84, GT: 3.00


                                                                                                                       

Epoch 17/20 - Train MSE: 0.2356


                                                                                                                       

Epoch 17/20 - Val MSE: 0.3450


                                                                                                                       

Accuracy (±0.5): 59.80%
MSE: 0.3450
RMSE: 0.5873
LCC: 0.1562
KTAU: 0.0994

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.87, GT: 4.00
Pred: 3.70, GT: 4.00
Pred: 3.07, GT: 3.73
Pred: 3.19, GT: 3.40
Pred: 2.84, GT: 3.00


                                                                                                                       

Epoch 18/20 - Train MSE: 0.2376


                                                                                                                       

Epoch 18/20 - Val MSE: 0.3445


                                                                                                                       

Accuracy (±0.5): 59.53%
MSE: 0.3445
RMSE: 0.5870
LCC: 0.1567
KTAU: 0.0997

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.86, GT: 4.00
Pred: 3.69, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.19, GT: 3.40
Pred: 2.84, GT: 3.00


                                                                                                                       

Epoch 19/20 - Train MSE: 0.2343


                                                                                                                       

Epoch 19/20 - Val MSE: 0.3443


                                                                                                                       

Accuracy (±0.5): 59.60%
MSE: 0.3443
RMSE: 0.5868
LCC: 0.1571
KTAU: 0.0999

5 Examples of Predicted and Ground Truth MOS:
Pred: 3.86, GT: 4.00
Pred: 3.69, GT: 4.00
Pred: 3.06, GT: 3.73
Pred: 3.19, GT: 3.40
Pred: 2.84, GT: 3.00


                                                                                                                       

Epoch 20/20 - Train MSE: 0.2371


                                                                                                                       

Epoch 20/20 - Val MSE: 0.3440


Evaluation:  94%|▉| 707/750 [00:26<00:01, 28.23it/s, predicted=[3.3416011 3.4347908 3.4427114 2.822913 ], ground_truth=