# Downloads

In [None]:
# Install/upgrade the third-party packages used in this notebook
# (torchcodec is presumably needed as torchaudio's audio-decoding backend).
# NOTE(review): versions are unpinned, so a re-run may pick up different releases;
# consider pinning them (and using %pip so the install targets the running kernel).
!pip install torchcodec --quiet
!pip install torchinfo --quiet
!pip install transformers accelerate sentencepiece torchaudio diffusers datasets soundfile pillow --quiet

# Imports

In [None]:
import os
import shutil
import soundfile as sf
from IPython.display import display

import torch
import torchaudio
import torch.nn as nn
from torchinfo import summary
from torchvision import datasets
from torchvision.models import vit_b_16
from torch.utils.data import DataLoader, random_split, TensorDataset, Dataset

from PIL import Image
from diffusers import FluxPipeline
from transformers import WhisperProcessor, WhisperModel, WhisperForConditionalGeneration, pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, GenerationConfig


import gc
import json
import uuid
import time
import random
import librosa
import logging
import kagglehub
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from dataclasses import dataclass
from sklearn.svm import LinearSVC
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from huggingface_hub import login
# Authenticate with the Hugging Face Hub (the Llama checkpoint used below is gated).
# login() asks for a token interactively. Never paste a token into the notebook or
# overwrite HF_TOKEN here: an empty assignment would also clobber a token that is
# already configured in the environment.
login()

# Dataset Preparation (only the audio files are needed for inference)

Download the mixed dataset (the one we merged) from Kaggle, unzip it, and move it to /content/.

Check for and delete corrupted files.
Set up dataloaders and preprocess the data to make it compatible with Whisper.

In [None]:
# Fetch the merged emotion dataset from Kaggle; force_download=True bypasses any cached copy.
path = kagglehub.dataset_download(
    "kamilhanna/emotion-dataset",
    force_download=True,
)
print("Path to dataset files:", path)

In [None]:
!mv "{path}" "/content"
!mv "/content/1/content/Emotion" "/content/"
!rm -rf "/content/1"

# #Dropping this class cause low data
!rm -rf "/content/Emotion/calm"

In [None]:
# Delete corrupted audio files
def clean_corrupted_audio(root, extensions=(".wav", ".m4a"), dry_run=False):
    """Delete audio files under ``root`` that ``torchaudio.load`` cannot decode.

    Walks ``root`` recursively and tries to load every file whose name ends with one
    of ``extensions`` (case-insensitive); other files are ignored. A file whose load
    raises is treated as corrupted and removed.

    NOTE(review): *any* load error counts as corruption, including a missing codec in
    the installed torchaudio backend, so a valid file in an unsupported format would
    be deleted too. Run with ``dry_run=True`` first if unsure.

    Args:
        root: directory to scan.
        extensions: file suffixes to check.
        dry_run: if True, only report undecodable files; nothing is deleted.

    Returns:
        Number of files removed (with ``dry_run=True``: the number that would be removed).
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    removed = 0
    checked = 0

    print(f"Scanning audio under: {root}\n")

    for folder, _, files in os.walk(root):
        for name in files:
            if not name.lower().endswith(suffixes):
                continue

            file_path = os.path.join(folder, name)
            checked += 1

            try:
                torchaudio.load(file_path)  # decoded only to validate; result unused
            except Exception as load_err:
                if dry_run:
                    print(f"[CORRUPTED] Would remove: {file_path}   -->   {load_err}")
                    removed += 1
                    continue
                print(f"[CORRUPTED] Removing: {file_path}   -->   {load_err}")
                try:
                    os.remove(file_path)
                    removed += 1
                except OSError as rm_err:
                    print(f"[ERROR] Could not delete {file_path}: {rm_err}")

    label = "Would remove" if dry_run else "Removed"
    print("\n===== SUMMARY =====")
    print(f"Checked:  {checked} files")
    print(f"{label}:  {removed} corrupted files\n")
    return removed

clean_corrupted_audio("/content/Emotion")

# FULL SPEECH 2 SCENE

Get the trained MLP's .pt checkpoint file (copied from Google Drive, which must be mounted).

Set up a class that combines all the different modules.

In [None]:
# Copy the trained MLP checkpoint from Google Drive into /content.
# NOTE(review): requires Google Drive to be mounted at /content/drive; there is no mount cell here.
!cp  "/content/drive/MyDrive/GenAI/Project/best_model.pt" "/content/"

In [None]:
# Input width of the classifier: must equal the hidden size of the Whisper encoder whose
# mean-pooled output feeds it in Speech2SceneModule (1280 for openai/whisper-large-v3).
MODEL_HIDDEN_DIMENSION = 1280
# Number of emotion classes; must match the saved checkpoint and IDX_TO_CLASS.
num_classes = 8
class EmotionClassifier(nn.Module):
    """MLP that maps a pooled Whisper-encoder embedding to emotion logits.

    Seven hidden stages, each Linear -> GELU -> Dropout(0.15) -> BatchNorm1d, with widths
    MODEL_HIDDEN_DIMENSION*2 -> 1024 -> 512 -> 256 -> 256 -> 128 -> 64, followed by a
    final Linear that produces ``num_classes`` logits.

    NOTE: each layer's position inside ``self.net`` determines its state_dict key
    (``net.0.weight``, ``net.3.running_mean``, ...), so the layer order must stay exactly
    as is for the saved checkpoint to load.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes):
        super().__init__()

        # Widths of the narrowing hidden stages after the initial 2x expansion.
        hidden_1 = 1024
        hidden_2 = 512
        hidden_3 = 256
        hidden_4 = 128
        hidden_5 = 64
        self.net = nn.Sequential(
            # Stage 1: expand to twice the embedding width.
            nn.Linear(MODEL_HIDDEN_DIMENSION, MODEL_HIDDEN_DIMENSION * 2),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(MODEL_HIDDEN_DIMENSION * 2),

            nn.Linear(MODEL_HIDDEN_DIMENSION * 2, hidden_1),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_1),

            nn.Linear(hidden_1, hidden_2),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_2),

            nn.Linear(hidden_2, hidden_3),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_3),

            # Width-preserving 256 -> 256 stage.
            nn.Linear(hidden_3, hidden_3),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_3),

            nn.Linear(hidden_3, hidden_4),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_4),

            nn.Linear(hidden_4, hidden_5),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.BatchNorm1d(hidden_5),

            # Classification head: raw logits (softmax is applied by the caller).
            nn.Linear(hidden_5, num_classes)
        )

    def forward(self, x):
        # x: (batch, MODEL_HIDDEN_DIMENSION) embeddings -> (batch, num_classes) logits.
        return self.net(x)

# Use the GPU when available; the checkpoint is remapped onto whichever device is chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model_mlp = EmotionClassifier(num_classes).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime as well;
# weights_only avoids unpickling arbitrary objects (only tensors are needed).
model_mlp.load_state_dict(
    torch.load("/content/best_model.pt", map_location=device, weights_only=True)
)
# Inference only: use stored BatchNorm statistics and disable Dropout
# (BatchNorm1d in training mode also cannot handle the batch of 1 used below).
model_mlp.eval()

summary(model_mlp, input_size=(1, MODEL_HIDDEN_DIMENSION))

In [None]:
# MLP output index -> emotion label; must match the label order the MLP was trained with.
IDX_TO_CLASS = dict(enumerate((
    "angry",
    "contempt",
    "disgust",
    "fear",
    "happy",
    "neutral",
    "sad",
    "surprised",
)))

# LLM system prompt: insert emotion adjectives while keeping every original word.
system_prompt = """
Your task is to modify the given prompt by inserting emotion-related adjectives
while preserving every original word.

Rules:
1. Do NOT delete, replace, or reorder any original words.
2. Only INSERT 1–2 adjectives related to the given emotion.
3. Adjectives must be inserted directly before nouns.
4. Do NOT add new sentences, explanations, or details.
5. Output only the modified prompt — nothing else.

Format:
Original prompt: <prompt>
Emotion: <emotion>
Output: <modified prompt>
"""

# Per-request user message, filled via .format(prompt=..., emotion=...).
user_template = """
Original prompt: {prompt}
Emotion: {emotion}
Output:
"""


# GPU timing
def measure_gpu_time(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(output, elapsed_ms)``.

    With CUDA available, the call is bracketed by ``torch.cuda.synchronize()`` and timed
    with CUDA events, so asynchronously launched kernels are included in the measurement.
    Without CUDA (where those calls would raise), it falls back to wall-clock time from
    ``time.perf_counter``.

    Args:
        fn: callable to time.
        *args, **kwargs: forwarded unchanged to ``fn``.

    Returns:
        Tuple ``(output, elapsed_ms)``: ``fn``'s return value and the elapsed time in
        milliseconds as a float.
    """
    if not torch.cuda.is_available():
        start_s = time.perf_counter()
        output = fn(*args, **kwargs)
        return output, (time.perf_counter() - start_s) * 1000.0

    # Drain work queued before this call so it is not attributed to ``fn``.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = fn(*args, **kwargs)
    end.record()
    # Wait for fn's kernels (and the end event) before reading the timer.
    torch.cuda.synchronize()

    return output, start.elapsed_time(end)

# Whole pipeline module
class Speech2SceneModule(nn.Module):
    """Speech-to-scene pipeline: audio -> (emotion, transcript) -> prompt -> image.

    Stages, each timed with ``measure_gpu_time``:
      1. The Whisper encoder embeds the audio. Its output is mean-pooled over time and
         classified by ``mlp_classifier`` into one of ``IDX_TO_CLASS``.
      2. Whisper ``generate`` produces an English transcript.
      3. Optionally, a chat LLM inserts emotion adjectives into the transcript
         (driven by ``system_prompt`` / ``user_template``).
      4. FLUX renders the resulting prompt into an image.

    Args:
        whisper_model_name: Hugging Face id of the Whisper checkpoint.
        flux_model_name: Hugging Face id of the FLUX pipeline.
        llm_model_name: Hugging Face id of the chat LLM used for prompt rewriting.
        mlp_classifier: trained classifier that maps a pooled encoder embedding to
            emotion logits indexed like ``IDX_TO_CLASS``. Required.
        device: device for Whisper, FLUX and the classifier. NOTE: the LLM is always
            placed on CUDA device 0 (``device_map={"": 0}``), so a GPU is required
            regardless of this argument.

    Raises:
        ValueError: if ``mlp_classifier`` is None (checked before any model is loaded).
    """

    def __init__(
        self,
        whisper_model_name="openai/whisper-large-v3",
        flux_model_name="black-forest-labs/FLUX.1-schnell",
        llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
        mlp_classifier=None,
        device="cuda",
    ):
        super().__init__()

        # Fail fast, before gigabytes of checkpoints are downloaded and loaded.
        if mlp_classifier is None:
            raise ValueError("mlp_classifier is required (pass the trained EmotionClassifier)")

        self.device = device

        # Whisper: encoder features for emotion classification, plus transcription.
        self.processor = WhisperProcessor.from_pretrained(whisper_model_name)
        self.whisper = WhisperForConditionalGeneration.from_pretrained(
            whisper_model_name
        ).to(device)
        self.whisper.eval()

        # FLUX image generation.
        # NOTE(review): the diffusers FLUX examples use torch.bfloat16; float16 may need
        # extra care (text-encoder activation clipping). Confirm the images look right.
        self.flux = FluxPipeline.from_pretrained(
            flux_model_name,
            torch_dtype=torch.float16,
            device_map=device
        )
        self.flux.set_progress_bar_config(disable=True)

        # Llama LLM for emotion-aware prompt rewriting (pinned to CUDA device 0).
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_model_name,
            torch_dtype=torch.bfloat16,
            device_map={"": 0}
        )

        self.tokenizer = AutoTokenizer.from_pretrained(llm_model_name, use_fast=True)
        # Reuse EOS as the pad token (presumably the tokenizer defines none).
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # MLP emotion classifier (inference only).
        self.mlp = mlp_classifier.to(device)
        self.mlp.eval()

    @torch.no_grad()
    def forward(
        self,
        audio_path,
        emotion_to_prompt=True,
        flux_height=512,
        flux_width=512,
        flux_steps=40,
        flux_guidance=3.0,
        flux_seed=0,
        llm_temperature=0.2,
        llm_max_new_tokens=100,
        llm_top_p=0.9,
        llm_top_k=50,
        llm_do_sample=True,
        stop_strings=None
    ):
        """Run the full pipeline on one audio file.

        Runs under ``torch.no_grad()``: this is inference only, and the direct
        encoder/MLP calls would otherwise record an autograd graph.

        Args:
            audio_path: path to an audio file readable by ``torchaudio.load``.
            emotion_to_prompt: if True, rewrite the transcript with the LLM; otherwise
                the raw transcript is used as the FLUX prompt.
            flux_height, flux_width: output image size in pixels.
            flux_steps, flux_guidance: FLUX ``num_inference_steps`` / ``guidance_scale``.
                NOTE(review): the diffusers docs run FLUX.1-schnell with 4 steps and
                guidance_scale=0.0; the 40-step default is far slower. Confirm intended.
            flux_seed: seed for the CPU ``torch.Generator`` passed to FLUX.
            llm_temperature, llm_max_new_tokens, llm_top_p, llm_top_k, llm_do_sample:
                sampling settings forwarded to the LLM's ``generate``.
            stop_strings: optional strings; the LLM output is cut at the first one found.

        Returns:
            dict with ``transcript``, ``emotion``, ``emotion_confidence`` (softmax
            probability of the predicted class), ``prompt`` (final FLUX prompt),
            ``image`` (first FLUX output image), ``timings_ms`` (per-stage and total
            milliseconds), ``emotion_probs`` and ``embedding`` (CPU tensors), and
            ``unfiltered_prompt`` (same as ``transcript``).
        """
        audio = self._load_audio(audio_path)

        # Whisper input features, padded to the feature extractor's fixed length.
        feats = self.processor(
            audio.cpu().numpy(),
            sampling_rate=16000,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt"
        )
        input_features = feats.input_features.to(self.device)
        whisper_mask = feats.attention_mask.to(self.device)

        # 1. Emotion: encoder states, mean-pooled over time, classified by the MLP.
        enc_out_dict, whisper_enc_ms = measure_gpu_time(
            self.whisper.model.encoder,
            input_features=input_features,
            attention_mask=whisper_mask,
            return_dict=True
        )
        pooled = enc_out_dict.last_hidden_state.mean(dim=1)

        logits, mlp_ms = measure_gpu_time(self.mlp, pooled)
        probs = torch.softmax(logits, dim=-1)
        conf, idx = torch.max(probs, dim=-1)
        conf = conf.item()
        emotion_label = IDX_TO_CLASS[idx.item()]

        # 2. Transcription. language/task replace the deprecated forced_decoder_ids.
        generated_ids, whisper_dec_ms = measure_gpu_time(
            self.whisper.generate,
            input_features=input_features,
            attention_mask=whisper_mask,
            language="en",
            task="transcribe",
        )
        transcript = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # 3. Prompt: emotion-aware LLM rewrite, or the raw transcript.
        if emotion_to_prompt:
            prompt, llm_ms = self._rewrite_prompt(
                transcript,
                emotion_label,
                max_new_tokens=llm_max_new_tokens,
                temperature=llm_temperature,
                top_p=llm_top_p,
                top_k=llm_top_k,
                do_sample=llm_do_sample,
                stop_strings=stop_strings,
            )
        else:
            prompt, llm_ms = transcript, 0.0

        # 4. FLUX image generation with a seeded CPU generator.
        generator = torch.Generator("cpu").manual_seed(flux_seed)
        flux_output, flux_ms = measure_gpu_time(
            self.flux,
            prompt,
            height=flux_height,
            width=flux_width,
            guidance_scale=flux_guidance,
            num_inference_steps=flux_steps,
            generator=generator
        )
        img = flux_output.images[0]

        # Time of the whole pipeline (sum of the timed stages).
        total_time_ms = whisper_enc_ms + mlp_ms + whisper_dec_ms + llm_ms + flux_ms

        return {
            "transcript": transcript,
            "emotion": emotion_label,
            "emotion_confidence": conf,
            "prompt": prompt,
            "image": img,

            "timings_ms": {
                "whisper_encoder": whisper_enc_ms,
                "mlp": mlp_ms,
                "whisper_decoder": whisper_dec_ms,
                "llm": llm_ms,
                "flux": flux_ms,
                "total": total_time_ms
            },

            "emotion_probs": probs.detach().cpu(),
            "embedding": pooled.detach().cpu(),
            "unfiltered_prompt": transcript,
        }

    def _load_audio(self, audio_path, target_sr=16000):
        """Load ``audio_path`` as a mono 1-D float tensor resampled to ``target_sr`` Hz."""
        audio, sr = torchaudio.load(audio_path)
        # torchaudio returns (channels, frames): average multi-channel audio to mono.
        audio = audio.mean(dim=0) if audio.shape[0] > 1 else audio[0]
        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        return audio.float()

    def _rewrite_prompt(self, transcript, emotion_label, max_new_tokens, temperature,
                        top_p, top_k, do_sample, stop_strings):
        """Ask the LLM to insert emotion adjectives into ``transcript``.

        Returns:
            Tuple ``(prompt, llm_ms)``: the cleaned rewritten prompt and the LLM
            ``generate`` time in milliseconds.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_template.format(prompt=transcript, emotion=emotion_label).strip()},
        ]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.llm.device)

        # One unpadded sequence, so attend to every token. Deriving the mask from
        # pad_token_id would be wrong: pad_token is EOS, which for Llama-3.x-Instruct
        # tokenizers (<|eot_id|>) also terminates every turn of the chat prompt.
        attention_mask = torch.ones_like(input_ids)

        output_ids, llm_ms = measure_gpu_time(
            self.llm.generate,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated continuation, not the echoed prompt.
        gen_ids = output_ids[0, input_ids.shape[1]:]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        return self._clean_llm_output(text, stop_strings), llm_ms

    @staticmethod
    def _clean_llm_output(text, stop_strings=None):
        """Strip reasoning, stop strings, echoed format labels and quotes from LLM output."""
        # Reasoning models put their chain of thought before a closing </think> tag.
        if "</think>" in text:
            text = text.split("</think>", 1)[-1].strip()

        # Cut at the first configured (non-empty) stop string that occurs.
        for stop in stop_strings or ():
            if stop and stop in text:
                text = text.split(stop)[0].strip()
                break

        # Keep only what follows the last occurrence of any echoed label.
        bad_prefixes = (
            "Original prompt:", "Output:", "Modified prompt:", "modified prompt:",
            "Rewritten Prompt:", "Prompt:", "Result:",
        )
        for prefix in bad_prefixes:
            if prefix in text:
                text = text.split(prefix)[-1].strip()

        return text.strip().strip('"').strip()


In [None]:
# Build the pipeline once (this loads Whisper, FLUX and the Llama LLM);
# the same instance is reused for every sample in the benchmark below.
s2s = Speech2SceneModule(
    mlp_classifier=model_mlp,
)

In [None]:
# Benchmark: run the full pipeline on up to MAX_SAMPLES clips, show each result,
# then report the average time spent in every stage.
MAX_SAMPLES = 50
AUDIO_EXTENSIONS = (".wav", ".m4a")  # the formats clean_corrupted_audio() scans

# The dataset-preparation cells move the data to /content/Emotion and clean it there.
# (A distinct name is used so the Kaggle download `path` is not clobbered.)
dataset_root = "/content/Emotion"

# Sorted for a reproducible order (os.listdir order is arbitrary); skip stray files at
# the class level and non-audio files inside class folders.
# NOTE(review): taking the first MAX_SAMPLES files likely draws every sample from the
# first class folder(s) alphabetically; fine for timing, not for per-emotion results.
audio_paths = [
    os.path.join(dataset_root, folder, file)
    for folder in sorted(os.listdir(dataset_root))
    if os.path.isdir(os.path.join(dataset_root, folder))
    for file in sorted(os.listdir(os.path.join(dataset_root, folder)))
    if file.lower().endswith(AUDIO_EXTENSIONS)
][:MAX_SAMPLES]
assert audio_paths, f"No audio files found under {dataset_root}"

# NOTE(review): the first sample also pays one-time CUDA warm-up costs, which inflate
# the averages below; consider a warm-up call or excluding sample 1.
timings_list = []
for sample_idx, audio_path in enumerate(audio_paths):
    result = s2s(audio_path, llm_temperature=0.2)

    t = result["timings_ms"]
    timings_list.append(t)

    print(f"\n===== Sample {sample_idx+1} =====")
    print("Transcript:", result["transcript"])
    print(f"Emotion: {result['emotion']}  (conf={result['emotion_confidence']:.3f})")
    print("Prompt:", result["prompt"])

    print("\n--- Timings (ms) ---")
    print(f"Whisper Encoder:  {t['whisper_encoder']:.2f}")
    print(f"MLP:              {t['mlp']:.2f}")
    print(f"Whisper Decoder:  {t['whisper_decoder']:.2f}")
    print(f"LLM:              {t['llm']:.2f}")
    print(f"Flux:             {t['flux']:.2f}")
    print(f"Total:            {t['total']:.2f}")

    display(result["image"])

# Average each stage over all processed samples.
avg_whisper_enc, avg_mlp, avg_whisper_dec, avg_llm, avg_flux, avg_total = (
    np.mean([x[key] for x in timings_list])
    for key in ("whisper_encoder", "mlp", "whisper_decoder", "llm", "flux", "total")
)

print("\n====================================")
print("          AVERAGE TIMINGS (ms)      ")
print("====================================")
print(f"Avg Whisper Encoder:  {avg_whisper_enc:.2f} ms")
print(f"Avg MLP:              {avg_mlp:.2f} ms")
print(f"Avg Whisper Decoder:  {avg_whisper_dec:.2f} ms")
print(f"Avg LLM:              {avg_llm:.2f} ms")
print(f"Avg Flux:             {avg_flux:.2f} ms")
print(f"Avg TOTAL:            {avg_total:.2f} ms")
print("====================================")
