In [18]:
# multimodal_fusion.py
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoModelForAudioClassification, AutoTokenizer, AutoFeatureExtractor, Trainer, TrainingArguments
from transformers import logging as hf_logging
from datasets import Dataset
from tabulate import tabulate
from collections import Counter
import pandas as pd
import numpy as np
import librosa
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import os
hf_logging.set_verbosity_error()

In [2]:
# -------- CONFIG --------
TEXT_CHECKPOINT = r"models/dabertav3_lyrics_saved"
AUDIO_CHECKPOINT = r"models/ast_lyrics_saved"
TEXT_MODEL_NAME_OR_PATH = TEXT_CHECKPOINT
AUDIO_MODEL_NAME_OR_PATH = AUDIO_CHECKPOINT
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_LABELS = 4
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 10

In [3]:
# -------- MODELS LOAD --------
# Load text classifier and tokenizer
text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME_OR_PATH, truncation=True, padding=True, max_length=128, return_tensors="pt")
text_model = AutoModelForSequenceClassification.from_pretrained(
    TEXT_CHECKPOINT, 
    num_labels=NUM_LABELS, 
    ignore_mismatched_sizes=True)
text_model.to(DEVICE)
text_model.eval()
for p in text_model.parameters():
    p.requires_grad = False

# Load audio classifier and feature extractor
# Many HF audio models use AutoFeatureExtractor or Wav2Vec2FeatureExtractor
audio_feature_extractor = AutoFeatureExtractor.from_pretrained(AUDIO_MODEL_NAME_OR_PATH)
audio_model = AutoModelForAudioClassification.from_pretrained(
    AUDIO_MODEL_NAME_OR_PATH, 
    attn_implementation="eager", 
    num_labels=NUM_LABELS, 
    ignore_mismatched_sizes=True
)
audio_model.to(DEVICE)
audio_model.eval()
for p in audio_model.parameters():
    p.requires_grad = False

In [4]:
# -------- LOAD DATA --------
df = pd.read_csv('multimodal_dataset_normalized.csv')
display(df.head())

unique_items = df["Quadrant"].unique()
label_list = sorted(set(unique_items))
label2id = {l: i for i, l in enumerate(label_list)}
id2label = {i: l for l, i in label2id.items()}
print("Class to index mapping:", label2id)

audio_dir = "MERGE_Bimodal_Complete/audio_wav"
lyrics_dir = "MERGE_Bimodal_Complete/lyrics"

raw_data = []
for _, row in df.iterrows():
    raw_data.append({
        'lyric_path': f"{lyrics_dir}/{row['Quadrant']}/{row['Lyric_Song']}.txt",
        'audio_path': f"{audio_dir}/{row['Quadrant']}/{row['Audio_Song']}.wav",
        'label': row['Quadrant']
    })

print(raw_data)
print(dict(Counter(item['label'] for item in raw_data)))

Unnamed: 0,Audio_Song,Lyric_Song,Arousal,Valence,Quadrant,Emotion,lyric_id,word_count,unique_word_count,lexical_diversity,...,rms_mean,rms_std,beat_strength,low_energy_ratio,energy_entropy,brightness,warmth,activity,harmonic_energy_ratio,harmonicity
0,A005,L055,0.7875,0.6875,Q1,Surprise,L055,0.583846,-0.032136,-1.136493,...,-1.315662,-1.165455,0.896397,-0.700513,-0.145042,1.086334,-0.092941,[0.06471955],-1.951364,-0.583475
1,A011,L061,0.68125,0.85625,Q1,Happiness,L061,-0.54914,-0.97493,-1.286263,...,0.107522,-0.408386,0.006682,0.068631,0.750471,0.147696,0.04022,[0.18521025],0.472448,-0.078961
2,A014,L064,0.8625,0.725,Q1,Surprise,L064,0.415581,-0.162177,-1.133997,...,0.361195,-0.265328,0.710477,-0.673247,0.854382,1.970159,-0.798261,[0.17745368],-0.949336,-0.454885
3,A019,L069,0.78125,0.81875,Q1,Excitement,L069,-0.229436,-0.308472,-0.401603,...,-0.219121,-0.979168,-0.061622,-0.34051,1.062094,0.474343,0.110013,[0.13742129],0.24476,-0.150335
4,A022,L072,0.76875,0.8375,Q1,Excitement,L072,-0.599619,-0.828634,-0.755787,...,0.916902,0.959027,-0.008507,-0.482841,0.29892,-0.169472,0.269772,[0.22606403],-0.078616,-0.234134


Class to index mapping: {'Q1': 0, 'Q2': 1, 'Q3': 2, 'Q4': 3}
[{'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L055.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A005.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L061.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A011.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L064.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A014.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L069.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A019.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L072.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A022.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L074.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1/A024.wav', 'label': 'Q1'}, {'lyric_path': 'MERGE_Bimodal_Complete/lyrics/Q1/L089.txt', 'audio_path': 'MERGE_Bimodal_Complete/audio_wav/Q1

In [5]:
# -------- DEFINE PREPROCESSING --------
def preprocess_lyrics(example):
    with open(example['lyric_path'], "r", encoding="utf-8") as f:
        text = f.read()
    encoded = text_tokenizer(text, truncation=True, padding=True, max_length=128, return_tensors="pt")
    return dict(encoded)


def preprocess_audio(example):
    audio_array, _ = librosa.load(example["audio_path"], sr=16000)
    features = audio_feature_extractor(audio_array, sampling_rate=16000)
    return features

def preprocess_label(example):
    return {'label' : label2id[example['label']]}

In [6]:
# -------- BUILD DATASETS --------
trainval_set, test_set = train_test_split(raw_data, test_size=0.2, random_state=42)
train_set, val_set = train_test_split(trainval_set, test_size=0.25, random_state=42)

train_dataset = Dataset.from_list(train_set)
val_dataset = Dataset.from_list(val_set)
test_dataset = Dataset.from_list(test_set)


train_dataset.set_format('torch')
train_dataset = train_dataset.map(preprocess_label)
train_dataset = train_dataset.map(preprocess_lyrics)
train_dataset = train_dataset.map(preprocess_audio)

val_dataset.set_format('torch')
val_dataset = val_dataset.map(preprocess_label)
val_dataset = val_dataset.map(preprocess_lyrics)
val_dataset = val_dataset.map(preprocess_audio)

test_dataset.set_format('torch')
test_dataset = test_dataset.map(preprocess_label)
test_dataset = test_dataset.map(preprocess_lyrics)
test_dataset = test_dataset.map(preprocess_audio)

Map:   0%|          | 0/1329 [00:00<?, ? examples/s]

Map:   0%|          | 0/1329 [00:00<?, ? examples/s]

Map:   0%|          | 0/1329 [00:00<?, ? examples/s]

Map:   0%|          | 0/443 [00:00<?, ? examples/s]

Map:   0%|          | 0/443 [00:00<?, ? examples/s]

Map:   0%|          | 0/443 [00:00<?, ? examples/s]

Map:   0%|          | 0/444 [00:00<?, ? examples/s]

Map:   0%|          | 0/444 [00:00<?, ? examples/s]

Map:   0%|          | 0/444 [00:00<?, ? examples/s]

In [8]:
from torch.nn.utils.rnn import pad_sequence

class AudioCollator:
    def __call__(self, batch):
        input_values = [torch.tensor(b["input_values"]).squeeze() for b in batch]
        labels = [b["label"] for b in batch]
        return {
            "input_values": torch.stack(input_values),
            "labels": torch.tensor(labels)
        }
    
class TextCollator:
    def __call__(self, batch):
        input_ids = [b["input_ids"].squeeze() for b in batch]
        attention_mask = [b["attention_mask"].squeeze() for b in batch]
        token_type_ids = [b["token_type_ids"].squeeze() for b in batch]
        labels = [b["label"] for b in batch]
        return {
            "input_ids": pad_sequence(input_ids, batch_first=True),
            "attention_mask": pad_sequence(attention_mask, batch_first=True),
            "token_type_ids": pad_sequence(token_type_ids, batch_first=True),
            "labels": torch.tensor(labels)
        }
    
def multimodal_collator(features):
    input_values = [f["input_values"].squeeze() for f in features]
    input_ids = [f["input_ids"].squeeze() for f in features]
    attention_mask = [f["attention_mask"].squeeze() for f in features]
    labels = [f["label"] for f in features]
    
    return {
        "input_values": pad_sequence(input_values, batch_first=True),
        "input_ids": pad_sequence(input_ids, batch_first=True),
        "attention_mask": pad_sequence(attention_mask, batch_first=True),
        "labels": torch.tensor(labels)
    }

In [9]:
class ElementwiseFusionModel(nn.Module):
    def __init__(self, audio_model, text_model, hidden_dim=256, num_labels=2):
        super().__init__()
        self.audio_model = audio_model
        self.text_model = text_model

        self.classifier = nn.Sequential(
            nn.Linear(audio_model.config.hidden_size, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, input_values=None, input_ids=None, attention_mask=None, labels=None):
        audio_out = self.audio_model(input_values, output_attentions=True)
        text_out = self.text_model(input_ids, attention_mask=attention_mask, output_attentions=True)
        
        audio_attn_weights = audio_out.attentions[-1].mean(dim=1).mean(dim=1)  # [B, seq_len_audio]
        audio_attn_weights = torch.nn.functional.softmax(audio_attn_weights, dim=-1) # Normalize
        audio_hidden_states = audio_out.hidden_states[-1]                      # [B, seq_len_audio, hidden_dim]
        audio_weighted = (audio_hidden_states * audio_attn_weights.unsqueeze(-1)).sum(dim=1)  # [B, hidden_dim]

        text_attn_weights = text_out.attentions[-1].mean(dim=1).mean(dim=1)   # [B, seq_len_text]
        text_attn_weights = torch.nn.functional.softmax(text_attn_weights, dim=-1) # Normalize
        text_hidden_states = text_out.hidden_states[-1]                       # [B, seq_len_text, hidden_dim]
        text_weighted = (text_hidden_states * text_attn_weights.unsqueeze(-1)).sum(dim=1)  # [B, hidden_dim]
        
        fused = audio_weighted * text_weighted  # element-wise fusion

        logits = self.classifier(fused)

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        else:
            return {"logits": logits}

In [10]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # Single-label prediction
    return {
        "accuracy": accuracy_score(labels, predictions),
        "precision": precision_score(labels, predictions, average="macro", zero_division=1),
        "recall": recall_score(labels, predictions, average="macro", zero_division=1),
        "f1": f1_score(labels, predictions, average="macro", zero_division=1),
    }

In [11]:
fusion_model = ElementwiseFusionModel(audio_model=audio_model, text_model=text_model, num_labels=NUM_LABELS)

text_model.config.output_hidden_states = True
text_model.config.output_attentions = True

audio_model.config.output_hidden_states = True
audio_model.config.output_attentions = True

training_args = TrainingArguments(
    output_dir='bimodal',
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    disable_tqdm=False,
    report_to="none",
)

trainer = Trainer(
    model=fusion_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    data_collator=multimodal_collator,
)

# Train the model
trainer.train()

print("Training done")

# Check results on the test set
metrics = trainer.evaluate(eval_dataset=test_dataset)
print(tabulate(metrics.items(), headers=["Metric", "Value"], tablefmt="pretty"))

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.9855,0.902299,0.85553,0.849788,0.852037,0.850041
2,0.4747,0.466322,0.857788,0.851425,0.853624,0.851802
3,0.2438,0.480635,0.85553,0.849142,0.851412,0.849435
4,0.2735,0.485387,0.85553,0.849142,0.851412,0.849435
5,0.3594,0.487623,0.85553,0.849142,0.851412,0.849435


Training done


+-------------------------+---------------------+
|         Metric          |        Value        |
+-------------------------+---------------------+
|        eval_loss        | 0.32581987977027893 |
|      eval_accuracy      | 0.9031531531531531  |
|     eval_precision      | 0.9038760766879689  |
|       eval_recall       | 0.8992901012762673  |
|         eval_f1         | 0.8996714057677677  |
|      eval_runtime       |      362.8066       |
| eval_samples_per_second |        1.224        |
|  eval_steps_per_second  |        0.154        |
|          epoch          |         5.0         |
+-------------------------+---------------------+


In [19]:
# Save the whole model to be loaded easily later
text_model_save_path = "models/dabertav3_lyrics_saved_bimodal_optimized"
text_model.save_pretrained(text_model_save_path)

audio_model_save_path = "models/ast_lyrics_saved_bimodal_optimized"
audio_model.save_pretrained(audio_model_save_path)

biomodal_module_save_path = "models/biomodal_lyric_audio_module_saved"
if not os.path.exists(biomodal_module_save_path):
    os.mkdir(biomodal_module_save_path)
torch.save(fusion_model.state_dict(), f"{biomodal_module_save_path}/fusion_model.pt")
trainer.args.save(f"{biomodal_module_save_path}/training_args/")

AttributeError: 'TrainingArguments' object has no attribute 'save'