In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import whisper
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
import torch.nn.functional as F



In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# whisper_model = whisper.load_model("base").to(device).eval()
whisper_model = whisper.load_model("medium").to(device).eval()
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()


In [4]:
def load_json(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)

def process_audio_path(clean_path, base_dir="data/somos/audios"):
    return os.path.join(base_dir, clean_path.replace("\\", "/"))

# Dataset Class
class SOMOSDataset(Dataset):
    def __init__(self, json_file, base_dir="data/somos/audios"):
        self.samples = load_json(json_file)
        self.base_dir = base_dir

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = sample["text"]
        mos = int(float(sample["mos"]))
        label = torch.tensor(mos - 1, dtype=torch.long)

        audio_path = process_audio_path(sample["clean path"], self.base_dir)

        return audio_path, text, label

def collate_fn(batch):
    audio_paths, texts, labels = zip(*batch)

    audios = [whisper.load_audio(path) for path in audio_paths]
    audios = [whisper.pad_or_trim(audio) for audio in audios]
    mel_spectrograms = [whisper.log_mel_spectrogram(audio).to(device) for audio in audios]
    mel_spectrograms = torch.stack(mel_spectrograms)

    with torch.no_grad():
        audio_embeddings = whisper_model.encoder(mel_spectrograms).mean(dim=1)  # Batch audio embeddings

    inputs = tokenizer(list(texts), return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        text_embeddings = bert_model(**inputs).last_hidden_state[:, 0, :]  # Batch text embeddings

    labels = torch.stack(labels).to(device)

    return audio_embeddings, text_embeddings, labels

class FusionClassifier(nn.Module):
    def __init__(self, audio_dim, text_dim, hidden_dim, num_classes, dropout_rate=0.05):
        super(FusionClassifier, self).__init__()

        self.audio_fc = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        )

        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Softmax(dim=1)
        )

        self.fusion_fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def forward(self, audio_emb, text_emb):
        audio_feat = self.audio_fc(audio_emb)
        text_feat = self.text_fc(text_emb)

        fusion = torch.cat([audio_feat, text_feat], dim=1)
        attn_weights = self.attention(fusion)
        fusion = fusion * attn_weights

        return self.fusion_fc(fusion)



def save_model(model, epoch, best_acc, save_path="models"):
    os.makedirs(save_path, exist_ok=True)
    model_path = os.path.join(save_path, f"model_epoch_{epoch}.pth")
    torch.save(model.state_dict(), model_path)

    best_model_path = os.path.join(save_path, "best_model.pth")
    if best_acc:
        torch.save(model.state_dict(), best_model_path)



In [5]:
from collections import Counter
import numpy as np

def compute_class_weights(dataset, num_classes=5):
    labels = [int(float(sample["mos"])) - 1 for sample in dataset.samples]  # Convert to 0-based index
    class_counts = Counter(labels)
    
    total_samples = len(labels)
    class_weights = {cls: total_samples / (num_classes * count) for cls, count in class_counts.items()}
    
    # Convert to a tensor
    weights = torch.tensor([class_weights[i] for i in range(num_classes)], dtype=torch.float).to(device)
    return weights






In [6]:
from torch.utils.data import WeightedRandomSampler

def create_weighted_sampler(dataset, num_classes=5):
    # Compute frequency of each label
    labels = [int(float(sample["mos"])) - 1 for sample in dataset.samples]
    counts = Counter(labels)
    total = len(labels)
    # Assign weights: lower for frequent classes, higher for rare ones
    sample_weights = [total / counts[label] for label in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    return sampler



In [7]:
print("Full model architecture:")
for name, module in whisper_model.named_modules():
    print(name, module)

Full model architecture:
 Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-23): 24 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=1024, out_features=1024, bias=True)
          (key): Linear(in_features=1024, out_features=1024, bias=False)
          (value): Linear(in_features=1024, out_features=1024, bias=True)
          (out): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (attn_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=4096, out_features=1024, bias=True)
        )
        (mlp_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )


In [8]:
for param in whisper_model.encoder.parameters():
    param.requires_grad = False


In [9]:
encoder_blocks = whisper_model.encoder.blocks
total_blocks = len(encoder_blocks)
print(f"Total encoder blocks: {total_blocks}")


Total encoder blocks: 24


In [10]:
threshold = total_blocks // 2  # unfreeze blocks with index >= threshold
print(f"Unfreezing encoder blocks with index >= {threshold}")


Unfreezing encoder blocks with index >= 12


In [11]:
for idx, block in enumerate(encoder_blocks):
    if idx >= threshold:
        for name, param in block.named_parameters():
            param.requires_grad = True



In [12]:
if hasattr(whisper_model.encoder, "final_layer_norm"):
    for param in whisper_model.encoder.final_layer_norm.parameters():
        param.requires_grad = True


In [13]:
print("\nUnfrozen parameters in the Whisper encoder:")
for name, param in whisper_model.encoder.named_parameters():
    if param.requires_grad:
        print(name)


Unfrozen parameters in the Whisper encoder:
blocks.12.attn.query.weight
blocks.12.attn.query.bias
blocks.12.attn.key.weight
blocks.12.attn.value.weight
blocks.12.attn.value.bias
blocks.12.attn.out.weight
blocks.12.attn.out.bias
blocks.12.attn_ln.weight
blocks.12.attn_ln.bias
blocks.12.mlp.0.weight
blocks.12.mlp.0.bias
blocks.12.mlp.2.weight
blocks.12.mlp.2.bias
blocks.12.mlp_ln.weight
blocks.12.mlp_ln.bias
blocks.13.attn.query.weight
blocks.13.attn.query.bias
blocks.13.attn.key.weight
blocks.13.attn.value.weight
blocks.13.attn.value.bias
blocks.13.attn.out.weight
blocks.13.attn.out.bias
blocks.13.attn_ln.weight
blocks.13.attn_ln.bias
blocks.13.mlp.0.weight
blocks.13.mlp.0.bias
blocks.13.mlp.2.weight
blocks.13.mlp.2.bias
blocks.13.mlp_ln.weight
blocks.13.mlp_ln.bias
blocks.14.attn.query.weight
blocks.14.attn.query.bias
blocks.14.attn.key.weight
blocks.14.attn.value.weight
blocks.14.attn.value.bias
blocks.14.attn.out.weight
blocks.14.attn.out.bias
blocks.14.attn_ln.weight
blocks.14.attn

In [14]:
# New Loss

import torchsort

def emd_loss(y_pred, y_true, num_classes):
    y_pred = F.softmax(y_pred, dim=-1)
    y_true = F.one_hot(y_true, num_classes).float()

    cdf_pred = torch.cumsum(y_pred, dim=-1)
    cdf_true = torch.cumsum(y_true, dim=-1)

    return torch.mean(torch.abs(cdf_pred - cdf_true))


In [15]:
def main():
    train_json = "data/somos/audios/train.json"
    test_json = "data/somos/audios/test.json"

    train_dataset = SOMOSDataset(train_json)
    test_dataset = SOMOSDataset(test_json)

    # class_weights = compute_class_weights(train_dataset)
    # print(class_weights)

    train_sampler = create_weighted_sampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=4, sampler=train_sampler, collate_fn=collate_fn)
    
    # train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
    
    dummy_audio, dummy_text, _ = next(iter(train_loader))
    audio_dim, text_dim = dummy_audio.shape[1], dummy_text.shape[1]
    num_classes = 5

    model = FusionClassifier(audio_dim, text_dim, hidden_dim=256, num_classes=num_classes).to(device)
    
    scaler = torch.cuda.amp.GradScaler()

    # criterion = nn.CrossEntropyLoss(weight=class_weights)
    # criterion = nn.CrossEntropyLoss()
    criterion = lambda y_pred, y_true: emd_loss(y_pred, y_true, num_classes=5)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    num_epochs = 10
    best_acc = 0.0

    for epoch in range(num_epochs):
        model.train()
        running_loss, correct_preds, total_samples = 0.0, 0, 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Training", leave=False)
        for audio_emb, text_emb, labels in train_pbar:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                outputs = model(audio_emb, text_emb)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * audio_emb.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct_preds += (preds == labels).sum().item()
            total_samples += labels.size(0)

            train_pbar.set_postfix(loss=loss.item())

        train_acc = 100 * correct_preds / total_samples
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {running_loss/total_samples:.4f} | Train Acc: {train_acc:.2f}%")

         # Evaluation
        model.eval()
        test_loss, correct_preds, total_samples = 0.0, 0, 0
        test_predictions = []
        
        with torch.no_grad():
            test_pbar = tqdm(test_loader, desc=f"Epoch {epoch+1} Validation", leave=False)
            for audio_emb, text_emb, labels in test_pbar:
                audio_emb = audio_emb.to(device)
                text_emb = text_emb.to(device)
                labels = labels.to(device)

                outputs = model(audio_emb, text_emb)
                loss = criterion(outputs, labels)

                test_loss += loss.item() * audio_emb.size(0)
                preds = torch.argmax(outputs, dim=1)
                correct_preds += (preds == labels).sum().item()
                total_samples += labels.size(0)

                test_predictions.extend(zip(labels.cpu().tolist(), preds.cpu().tolist()))

                test_pbar.set_postfix(loss=loss.item())

        test_acc = 100 * correct_preds / total_samples
        avg_test_loss = test_loss / total_samples
        print(f"Epoch {epoch+1}/{num_epochs} - Val Loss: {avg_test_loss:.4f} | Val Acc: {test_acc:.2f}%")

        print("\nSample Predictions (Real MOS vs Predicted MOS):")
        for i, (real_mos, pred_mos) in enumerate(test_predictions[:20]):
            print(f"Example {i+1}: Real MOS = {real_mos + 1}, Predicted MOS = {pred_mos + 1}")  # Convert back to 1-5 scale

        save_model(model, epoch + 1, test_acc > best_acc)

        if test_acc > best_acc:
            best_acc = test_acc

    print("Training complete! Best validation accuracy:", best_acc)


In [16]:
main()


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
                                                                                                                       

Epoch 1/10 - Train Loss: 0.2493 | Train Acc: 19.87%


                                                                                                                       

Epoch 1/10 - Val Loss: 0.2056 | Val Acc: 23.67%

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4, Predicted MOS = 3
Example 2: Real MOS = 3, Predicted MOS = 3
Example 3: Real MOS = 5, Predicted MOS = 3
Example 4: Real MOS = 2, Predicted MOS = 3
Example 5: Real MOS = 3, Predicted MOS = 3
Example 6: Real MOS = 4, Predicted MOS = 3
Example 7: Real MOS = 3, Predicted MOS = 3
Example 8: Real MOS = 3, Predicted MOS = 3
Example 9: Real MOS = 3, Predicted MOS = 3
Example 10: Real MOS = 3, Predicted MOS = 3
Example 11: Real MOS = 5, Predicted MOS = 3
Example 12: Real MOS = 5, Predicted MOS = 3
Example 13: Real MOS = 4, Predicted MOS = 3
Example 14: Real MOS = 3, Predicted MOS = 3
Example 15: Real MOS = 4, Predicted MOS = 3
Example 16: Real MOS = 5, Predicted MOS = 3
Example 17: Real MOS = 5, Predicted MOS = 3
Example 18: Real MOS = 4, Predicted MOS = 3
Example 19: Real MOS = 5, Predicted MOS = 3
Example 20: Real MOS = 2, Predicted MOS = 3


                                                                                                                       

Epoch 2/10 - Train Loss: 0.2398 | Train Acc: 20.20%


                                                                                                                       

Epoch 2/10 - Val Loss: 0.2040 | Val Acc: 23.67%

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4, Predicted MOS = 3
Example 2: Real MOS = 3, Predicted MOS = 3
Example 3: Real MOS = 5, Predicted MOS = 3
Example 4: Real MOS = 2, Predicted MOS = 3
Example 5: Real MOS = 3, Predicted MOS = 3
Example 6: Real MOS = 4, Predicted MOS = 3
Example 7: Real MOS = 3, Predicted MOS = 3
Example 8: Real MOS = 3, Predicted MOS = 3
Example 9: Real MOS = 3, Predicted MOS = 3
Example 10: Real MOS = 3, Predicted MOS = 3
Example 11: Real MOS = 5, Predicted MOS = 3
Example 12: Real MOS = 5, Predicted MOS = 3
Example 13: Real MOS = 4, Predicted MOS = 3
Example 14: Real MOS = 3, Predicted MOS = 3
Example 15: Real MOS = 4, Predicted MOS = 3
Example 16: Real MOS = 5, Predicted MOS = 3
Example 17: Real MOS = 5, Predicted MOS = 3
Example 18: Real MOS = 4, Predicted MOS = 3
Example 19: Real MOS = 5, Predicted MOS = 3
Example 20: Real MOS = 2, Predicted MOS = 3


                                                                                                                       

Epoch 3/10 - Train Loss: 0.2406 | Train Acc: 19.73%


                                                                                                                       

Epoch 3/10 - Val Loss: 0.2038 | Val Acc: 23.67%

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4, Predicted MOS = 3
Example 2: Real MOS = 3, Predicted MOS = 3
Example 3: Real MOS = 5, Predicted MOS = 3
Example 4: Real MOS = 2, Predicted MOS = 3
Example 5: Real MOS = 3, Predicted MOS = 3
Example 6: Real MOS = 4, Predicted MOS = 3
Example 7: Real MOS = 3, Predicted MOS = 3
Example 8: Real MOS = 3, Predicted MOS = 3
Example 9: Real MOS = 3, Predicted MOS = 3
Example 10: Real MOS = 3, Predicted MOS = 3
Example 11: Real MOS = 5, Predicted MOS = 3
Example 12: Real MOS = 5, Predicted MOS = 3
Example 13: Real MOS = 4, Predicted MOS = 3
Example 14: Real MOS = 3, Predicted MOS = 3
Example 15: Real MOS = 4, Predicted MOS = 3
Example 16: Real MOS = 5, Predicted MOS = 3
Example 17: Real MOS = 5, Predicted MOS = 3
Example 18: Real MOS = 4, Predicted MOS = 3
Example 19: Real MOS = 5, Predicted MOS = 3
Example 20: Real MOS = 2, Predicted MOS = 3


                                                                                                                       

Epoch 4/10 - Train Loss: 0.2410 | Train Acc: 19.77%


                                                                                                                       

Epoch 4/10 - Val Loss: 0.2037 | Val Acc: 23.67%

Sample Predictions (Real MOS vs Predicted MOS):
Example 1: Real MOS = 4, Predicted MOS = 3
Example 2: Real MOS = 3, Predicted MOS = 3
Example 3: Real MOS = 5, Predicted MOS = 3
Example 4: Real MOS = 2, Predicted MOS = 3
Example 5: Real MOS = 3, Predicted MOS = 3
Example 6: Real MOS = 4, Predicted MOS = 3
Example 7: Real MOS = 3, Predicted MOS = 3
Example 8: Real MOS = 3, Predicted MOS = 3
Example 9: Real MOS = 3, Predicted MOS = 3
Example 10: Real MOS = 3, Predicted MOS = 3
Example 11: Real MOS = 5, Predicted MOS = 3
Example 12: Real MOS = 5, Predicted MOS = 3
Example 13: Real MOS = 4, Predicted MOS = 3
Example 14: Real MOS = 3, Predicted MOS = 3
Example 15: Real MOS = 4, Predicted MOS = 3
Example 16: Real MOS = 5, Predicted MOS = 3
Example 17: Real MOS = 5, Predicted MOS = 3
Example 18: Real MOS = 4, Predicted MOS = 3
Example 19: Real MOS = 5, Predicted MOS = 3
Example 20: Real MOS = 2, Predicted MOS = 3


                                                                                                                       

KeyboardInterrupt: 