In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import whisper
from transformers import BertTokenizer, BertModel
import numpy as np
from tqdm import tqdm
import wandb
from collections import Counter


In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# --- Load and Unfreeze Whisper‑medium ---
whisper_model = whisper.load_model("base.en").to(device)
# Unfreeze all layers in Whisper
for param in whisper_model.parameters():
    param.requires_grad = True
whisper_model.train()  # Set to train mode so gradients are computed

# Load BERT tokenizer and model (BERT remains frozen here).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

# --- Initialize wandb ---
wandb.init(project="somos-ensemble-cont", name="finetune-whisper-ensemble")
!wandb online



wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: rtfiof (rtfiof-hse-university). Use `wandb login --relogin` to force relogin


W&B online. Running your script from this directory will now sync to the cloud.


Exception in thread IntMsgThr:
Traceback (most recent call last):
  File "D:\ProgramData\anaconda\envs\project\Lib\threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "D:\ProgramData\anaconda\envs\project\Lib\site-packages\ipykernel\ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "D:\ProgramData\anaconda\envs\project\Lib\threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "D:\ProgramData\anaconda\envs\project\Lib\site-packages\wandb\sdk\wandb_run.py", line 325, in check_internal_messages
    self._loop_check_status(
  File "D:\ProgramData\anaconda\envs\project\Lib\site-packages\wandb\sdk\wandb_run.py", line 235, in _loop_check_status
    local_handle = request()
                   ^^^^^^^^^
  File "D:\ProgramData\anaconda\envs\project\Lib\site-packages\wandb\sdk\interface\interface.py", line 914, in deliver_internal_messages
    return self._deliver_internal_messages(internal_message)
           ^^^^^^^^^^^

In [3]:
device


device(type='cuda', index=1)

In [4]:
# --- Utility Functions ---

# Function to compute class weights
# Compute class weights for imbalanced dataset
def compute_class_weights(labels, num_bins=5):
    # Quantize the MOS labels into bins
    min_mos = min(labels)
    max_mos = max(labels)
    bin_width = (max_mos - min_mos) / num_bins
    bins = [min_mos + i * bin_width for i in range(num_bins + 1)]
    
    # Assign each MOS value to a bin
    bin_indices = [min(int((mos - min_mos) / bin_width), num_bins - 1) for mos in labels]
    
    # Initialize class counts for all bins (even if a bin has no samples)
    class_counts = {i: 0 for i in range(num_bins)}
    for bin_idx in bin_indices:
        class_counts[bin_idx] += 1

    total_samples = sum(class_counts.values())
    weights = {cls: total_samples / (num_bins * count) for cls, count in class_counts.items()}

    # Ensure all bins are represented in the weights
    return torch.tensor([weights.get(i, 0) for i in range(num_bins)], dtype=torch.float).to(device)



# Function to compute sample weights for oversampling
def get_sample_weights(dataset, class_weights):
    sample_weights = []

    for _, _, label in dataset:
        sample_weights.append(class_weights[label].item())

    return torch.tensor(sample_weights, dtype=torch.float)

def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

# Earth Mover’s Distance (EMD) Loss for ordinal MOS prediction.
def emd_loss(y_pred, y_true, num_classes):
    y_pred = F.softmax(y_pred, dim=-1)  # Convert logits to probability distribution
    y_true = F.one_hot(y_true, num_classes).float()  # Convert labels to one-hot

    cdf_pred = torch.cumsum(y_pred, dim=-1)  # Compute cumulative sum for predicted distribution
    cdf_true = torch.cumsum(y_true, dim=-1)  # Compute cumulative sum for true distribution

    loss = torch.mean((cdf_pred - cdf_true) ** 2)  # Use squared difference for smoother gradients
    return loss

def entropy_regularization(gate_weights, lambda_reg=0.01):
    # Compute entropy loss to encourage diverse gating weights
    eps = 1e-8
    entropy = -torch.sum(gate_weights * torch.log(gate_weights + eps), dim=1)
    return lambda_reg * torch.mean(entropy)

def save_model(model, epoch, best_acc, save_path="models"):
    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, f"model_epoch_{epoch}.pth")
    torch.save(model.state_dict(), model_path)
    best_model_path = os.path.join(save_path, "best_model.pth")
    if best_acc:
        torch.save(model.state_dict(), best_model_path)

In [5]:
# --- Dataset Class ---
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir
        # Store MOS as float for continuous values (not just integer)
        self.labels = [float(sample["mos"]) for sample in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        label = torch.tensor(float(sample["mos"]), dtype=torch.float)
        audio_path = process_audio_path(sample["clean path"], self.base_dir)
        return audio_path, text, label

def collate_fn(batch):
    audio_paths, texts, labels = zip(*batch)
    audios = [whisper.load_audio(path) for path in audio_paths]
    audios = [whisper.pad_or_trim(audio) for audio in audios]
    mel_spectrograms = [whisper.log_mel_spectrogram(audio).to(device) for audio in audios]
    mel_spectrograms = torch.stack(mel_spectrograms)

    # Compute audio embeddings with gradients enabled
    audio_embeddings = whisper_model.encoder(mel_spectrograms).mean(dim=1)

    # Process texts using BERT (BERT remains frozen)
    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        text_embeddings = bert_model(**inputs).last_hidden_state[:, 0, :]

    labels = torch.stack(labels).to(device)
    return audio_embeddings, text_embeddings, labels

In [6]:
class ComplexFusionSubModel(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim, dropout_rate=0.05):
        super(ComplexFusionSubModel, self).__init__()
        self.audio_fc = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )
        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Softmax(dim=1)
        )
        # Output layer changed to produce a single value for continuous prediction
        self.fusion_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, 1)  # Output a single value (regression)
        )

    def forward(self, audio_emb, text_emb):
        audio_feat = self.audio_fc(audio_emb)
        text_feat = self.text_fc(text_emb)
        fusion = torch.cat([audio_feat, text_feat], dim=1)
        attn_weights = self.attention(fusion)
        fusion = fusion * attn_weights
        return self.fusion_fc(fusion)  # Output a single value for regression


class EnsembleFusionClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim, dropout_rate=0.05, num_models=3):
        super(EnsembleFusionClassifier, self).__init__()
        self.num_models = num_models
        self.sub_models = nn.ModuleList([
            ComplexFusionSubModel(audio_dim, text_dim, hidden_dim, dropout_rate)
            for _ in range(num_models)
        ])
        # Gate mechanism for selecting the weighted output from each model
        self.gate = nn.Sequential(
            nn.Linear(audio_dim + text_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_models),
            nn.Softmax(dim=1)
        )
        # Residual connection (optional)
        self.residual = nn.Sequential(
            nn.Linear(1, 1),  # Output a single value for regression
            nn.BatchNorm1d(1),
            nn.ReLU()
        )

    def forward(self, audio_emb, text_emb, return_gate=False):
        gate_input = torch.cat([audio_emb, text_emb], dim=1)
        gate_weights = self.gate(gate_input)  # (batch_size, num_models)
        outputs = [model(audio_emb, text_emb) for model in self.sub_models]
        outputs = torch.stack(outputs, dim=1)  # (batch_size, num_models, 1)
        gate_weights_unsq = gate_weights.unsqueeze(2)  # (batch_size, num_models, 1)
        ensemble_output = (gate_weights_unsq * outputs).sum(dim=1)  # Weighted sum of model outputs
        final_output = ensemble_output + self.residual(ensemble_output)  # Apply residual
        if return_gate:
            return final_output, gate_weights
        return final_output


In [7]:
def main():
    train_json = "data/somos/audios/train_new.json"
    test_json = "data/somos/audios/test_new.json"

    # Load datasets
    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    # Compute sample weights based on MOS values
    class_weights = compute_class_weights(train_dataset.labels, num_bins=5)
    sample_weights = [class_weights[min(int(mos), 4)].item() for mos in train_dataset.labels]  # Quantizing MOS values
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    # Create DataLoader with sampling
    train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    # Check dimensions of a sample batch to set input dimensions for the model
    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]

    # Define the number of classes (regression task, not used directly in the model)
    num_classes = 1  # For continuous MOS prediction, there is only one output value per sample

    # Instantiate the ensemble classifier model
    model = EnsembleFusionClassifier(audio_dim, text_dim, hidden_dim=256, dropout_rate=0.05, num_models=3).to(device)

    # Watch the model with WandB for logging gradients and parameters
    wandb.watch(model, log="all", log_freq=100)

    # Set up gradient scaler for mixed precision training
    scaler = torch.cuda.amp.GradScaler()

    # Define MSE loss for continuous MOS prediction
    criterion = nn.MSELoss()
    
    # Set up optimizer (Adam with a very small learning rate)
    optimizer = optim.Adam(model.parameters(), lr=1e-6)

    num_epochs = 100
    best_mse = float('inf')

    for epoch in range(num_epochs):
        model.train()
        running_loss, total_samples = 0.0, 0

        # Training loop with progress bar
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()

            # Perform forward pass with mixed precision
            with torch.cuda.amp.autocast():
                outputs, gate_weights = model(audio_emb, text_emb, return_gate=True)

                # Ensure the output shape matches that of labels (i.e., batch_size, 1)
                outputs = outputs.squeeze()  # Remove unnecessary dimensions
                labels = labels.squeeze()    # Ensure labels are in the correct shape

                # Check if shapes match before computing loss
                assert outputs.shape == labels.shape, f"Shape mismatch: {outputs.shape} vs {labels.shape}"

                # Calculate loss using MSE
                loss = criterion(outputs, labels)

            # Backward pass with gradient scaling for mixed precision
            scaler.scale(loss).backward()

            # Clip gradients to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Accumulate loss for logging
            running_loss += loss.item() * audio_emb.size(0)
            total_samples += labels.size(0)

            # Log the training loss with WandB
            wandb.log({
                "train_loss": loss.item(),
            })
            train_pbar.set_postfix(loss=loss.item())

        # Calculate and log training MSE
        train_mse = running_loss / total_samples
        wandb.log({"train_mse": train_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Train MSE: {train_mse:.4f}")

        # Evaluation phase
        model.eval()
        test_loss, total_samples = 0.0, 0
        test_predictions = []

        with torch.no_grad():
            # Testing loop with progress bar
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, text_emb, labels in test_pbar:
                audio_emb = audio_emb.to(device)
                text_emb = text_emb.to(device)
                labels = labels.to(device)

                # Get model predictions
                outputs = model(audio_emb, text_emb)

                # Ensure output and labels are in compatible shapes
                outputs = outputs.squeeze()
                labels = labels.squeeze()

                # Calculate test loss
                loss = criterion(outputs, labels)

                test_loss += loss.item() * audio_emb.size(0)
                total_samples += labels.size(0)
                test_predictions.extend(zip(labels.cpu().tolist(), outputs.cpu().tolist()))
                test_pbar.set_postfix(loss=loss.item())

        # Calculate and log test MSE
        test_mse = test_loss / total_samples
        wandb.log({"val_mse": test_mse})
        print(f"Epoch {epoch+1}/{num_epochs} - Val MSE: {test_mse:.4f}")

        # Log some sample predictions
        print("\nSample Predictions (Real MOS vs Predicted MOS):")
        for i, (real_mos, pred_mos) in enumerate(test_predictions[:5]):
            print(f"Example {i+1}: Real MOS = {real_mos}, Predicted MOS = {pred_mos}")
            wandb.log({f"sample_{i}_real_vs_pred": f"{real_mos} vs {pred_mos}"})
        
        # Save the model if the validation MSE improves
        save_model(model, epoch + 1, test_mse < best_mse)

        if test_mse < best_mse:
            best_mse = test_mse

    print("Training complete! Best validation MSE:", best_mse)



In [8]:
main()


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
                                                                                                                       

Epoch 1/100 - Train MSE: 8.7641


                                                                                                                       

Epoch 1/100 - Val MSE: 4.6508

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 1.1607378721237183
Example 2: Real MOS = 4.0, Predicted MOS = 0.812701940536499
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 1.5484755039215088
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 1.0674092769622803
Example 5: Real MOS = 3.0, Predicted MOS = 1.4412553310394287


                                                                                                                       

Epoch 2/100 - Train MSE: 6.5422


                                                                                                                       

Epoch 2/100 - Val MSE: 3.5313

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 1.498292088508606
Example 2: Real MOS = 4.0, Predicted MOS = 1.0315018892288208
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 1.9982125759124756
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 1.2436354160308838
Example 5: Real MOS = 3.0, Predicted MOS = 1.9349894523620605


                                                                                                                       

Epoch 3/100 - Train MSE: 5.8053


                                                                                                                       

Epoch 3/100 - Val MSE: 3.0452

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 1.2202459573745728
Example 2: Real MOS = 4.0, Predicted MOS = 0.9924641251564026
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 2.1842806339263916
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 0.9821459650993347
Example 5: Real MOS = 3.0, Predicted MOS = 1.9845068454742432


                                                                                                                       

Epoch 4/100 - Train MSE: 5.0929


                                                                                                                       

Epoch 4/100 - Val MSE: 2.4831

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 1.469470739364624
Example 2: Real MOS = 4.0, Predicted MOS = 1.3966904878616333
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 2.4737887382507324
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.0226187705993652
Example 5: Real MOS = 3.0, Predicted MOS = 2.45470929145813


                                                                                                                       

Epoch 5/100 - Train MSE: 4.4326


                                                                                                                       

Epoch 5/100 - Val MSE: 1.6155

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 1.704763412475586
Example 2: Real MOS = 4.0, Predicted MOS = 1.3491932153701782
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.065774440765381
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.154963254928589
Example 5: Real MOS = 3.0, Predicted MOS = 2.7340755462646484


                                                                                                                       

Epoch 6/100 - Train MSE: 3.8446


                                                                                                                       

Epoch 6/100 - Val MSE: 1.5237

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.044430732727051
Example 2: Real MOS = 4.0, Predicted MOS = 1.816087245941162
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 2.8972599506378174
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 1.826528549194336
Example 5: Real MOS = 3.0, Predicted MOS = 2.3948299884796143


                                                                                                                       

Epoch 7/100 - Train MSE: 3.3145


                                                                                                                       

Epoch 7/100 - Val MSE: 0.9354

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.5917716026306152
Example 2: Real MOS = 4.0, Predicted MOS = 1.9791901111602783
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.836705207824707
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.5647761821746826
Example 5: Real MOS = 3.0, Predicted MOS = 2.9043021202087402


                                                                                                                       

Epoch 8/100 - Train MSE: 2.8326


                                                                                                                       

Epoch 8/100 - Val MSE: 0.9942

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.3691368103027344
Example 2: Real MOS = 4.0, Predicted MOS = 2.1023480892181396
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.1004490852355957
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.516868829727173
Example 5: Real MOS = 3.0, Predicted MOS = 2.7788374423980713


                                                                                                                       

Epoch 9/100 - Train MSE: 2.3719


                                                                                                                       

Epoch 9/100 - Val MSE: 0.8433

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.241647720336914
Example 2: Real MOS = 4.0, Predicted MOS = 2.100292682647705
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.3095879554748535
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.4926095008850098
Example 5: Real MOS = 3.0, Predicted MOS = 2.8073229789733887


                                                                                                                       

Epoch 10/100 - Train MSE: 1.9902


                                                                                                                       

Epoch 10/100 - Val MSE: 0.6261

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.3074164390563965
Example 2: Real MOS = 4.0, Predicted MOS = 2.192262887954712
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.962413787841797
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 2.6482346057891846
Example 5: Real MOS = 3.0, Predicted MOS = 3.1650524139404297


                                                                                                                       

Epoch 11/100 - Train MSE: 1.6175


                                                                                                                       

Epoch 11/100 - Val MSE: 0.5068

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.938669204711914
Example 2: Real MOS = 4.0, Predicted MOS = 2.7256152629852295
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.762974739074707
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.1050310134887695
Example 5: Real MOS = 3.0, Predicted MOS = 3.3691601753234863


                                                                                                                       

Epoch 12/100 - Train MSE: 1.3780


                                                                                                                       

Epoch 12/100 - Val MSE: 0.4714

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.5616965293884277
Example 2: Real MOS = 4.0, Predicted MOS = 3.119319200515747
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.048720359802246
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.1989569664001465
Example 5: Real MOS = 3.0, Predicted MOS = 3.560316562652588


                                                                                                                       

Epoch 13/100 - Train MSE: 1.1518


                                                                                                                       

Epoch 13/100 - Val MSE: 0.4712

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 2.660867929458618
Example 2: Real MOS = 4.0, Predicted MOS = 2.4787795543670654
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.129580497741699
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.285508632659912
Example 5: Real MOS = 3.0, Predicted MOS = 3.584831714630127


                                                                                                                       

Epoch 14/100 - Train MSE: 0.9760


                                                                                                                       

Epoch 14/100 - Val MSE: 1.0157

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.4898438453674316
Example 2: Real MOS = 4.0, Predicted MOS = 2.904170513153076
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.8024468421936035
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.456948757171631
Example 5: Real MOS = 3.0, Predicted MOS = 4.263199806213379


                                                                                                                       

Epoch 15/100 - Train MSE: 0.8184


                                                                                                                       

Epoch 15/100 - Val MSE: 0.4823

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.295330047607422
Example 2: Real MOS = 4.0, Predicted MOS = 3.0825064182281494
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.9900503158569336
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.3038344383239746
Example 5: Real MOS = 3.0, Predicted MOS = 3.5901551246643066


                                                                                                                       

Epoch 16/100 - Train MSE: 0.7130


                                                                                                                       

Epoch 16/100 - Val MSE: 0.8254

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.533033847808838
Example 2: Real MOS = 4.0, Predicted MOS = 3.09549880027771
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.650455474853516
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.6377463340759277
Example 5: Real MOS = 3.0, Predicted MOS = 4.042354106903076


                                                                                                                       

Epoch 17/100 - Train MSE: 0.6533


                                                                                                                       

Epoch 17/100 - Val MSE: 0.7455

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.6503710746765137
Example 2: Real MOS = 4.0, Predicted MOS = 3.3007984161376953
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.227035045623779
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 4.156992435455322
Example 5: Real MOS = 3.0, Predicted MOS = 4.3412041664123535


                                                                                                                       

Epoch 18/100 - Train MSE: 0.6363


                                                                                                                       

Epoch 18/100 - Val MSE: 0.5508

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.5612688064575195
Example 2: Real MOS = 4.0, Predicted MOS = 3.291835069656372
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.135792255401611
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.5392203330993652
Example 5: Real MOS = 3.0, Predicted MOS = 3.681854248046875


                                                                                                                       

Epoch 19/100 - Train MSE: 0.6009


                                                                                                                       

Epoch 19/100 - Val MSE: 0.9191

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.679015636444092
Example 2: Real MOS = 4.0, Predicted MOS = 3.241779327392578
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.674153804779053
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 4.13768196105957
Example 5: Real MOS = 3.0, Predicted MOS = 4.338153839111328


                                                                                                                       

Epoch 20/100 - Train MSE: 0.5883


                                                                                                                       

Epoch 20/100 - Val MSE: 0.6792

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.593503475189209
Example 2: Real MOS = 4.0, Predicted MOS = 3.3273098468780518
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 4.278740882873535
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.652090549468994
Example 5: Real MOS = 3.0, Predicted MOS = 3.88718843460083


                                                                                                                       

Epoch 21/100 - Train MSE: 0.5849


                                                                                                                       

Epoch 21/100 - Val MSE: 0.4601

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4.0, Predicted MOS = 3.368781089782715
Example 2: Real MOS = 4.0, Predicted MOS = 3.249788999557495
Example 3: Real MOS = 3.7272727489471436, Predicted MOS = 3.8168792724609375
Example 4: Real MOS = 3.4000000953674316, Predicted MOS = 3.55024790763855
Example 5: Real MOS = 3.0, Predicted MOS = 3.6338706016540527


                                                                                                                       

KeyboardInterrupt: 