# Emotion Recognition and Data Augmentation using Diffusion Models

This notebook implements the methodology described in the paper 'A Generation of Enhanced Data by Variational Autoencoders and Diffusion Modeling' (Electronics 2024, 13, 1314).

**Pipeline:**
1. Data Loading and Preprocessing (EmoDB, RAVDESS) -> Mel-Spectrograms
2. ResNet-50 Emotion Embedding Model Training and Feature Extraction (PyTorch)
3. Diffusion Model for Mel-Spectrogram Generation (PyTorch, U-Net based)
4. Data Augmentation using the Diffusion Model
5. Evaluation of Original vs. Augmented Data (WA, UA, Confusion Matrix, P/R/F1)

**Notes for Google Colab:**
*   Use a GPU runtime (`Runtime` -> `Change runtime type`).
*   Training (especially diffusion) is computationally expensive and time-consuming.
*   Mount Google Drive to save datasets, models, and generated data persistently.
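
The Drive mount mentioned in the last note can be done along these lines. This is a minimal sketch, assuming a Colab runtime: `google.colab.drive.mount` is the standard Colab call, while the helper name `resolve_base_dir` is hypothetical and the folder name is borrowed from the commented example in the configuration cell. Outside Colab the helper simply returns a local fallback directory.

```python
import os

def resolve_base_dir(drive_subdir="Colab_Emotion_Augmentation", fallback="/content"):
    """Return a persistent base directory on Drive when in Colab, else `fallback`."""
    try:
        from google.colab import drive  # only importable inside Colab
    except ImportError:
        return fallback
    drive.mount("/content/drive")  # asks for authorization on first use
    return os.path.join("/content/drive/MyDrive", drive_subdir)

base_dir = resolve_base_dir()
print(f"Base directory: {base_dir}")
```

The returned path could then be assigned to `BASE_DIR` in the configuration cell.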

# 1. Setup and Imports

In [1]:
!pip install librosa tqdm requests 

import os
import glob
import zipfile
import requests
import shutil
import random
from tqdm.notebook import tqdm  

import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.models as models
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, balanced_accuracy_score

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Using device: cpu


# 2. Configuration and Constants

**Important:** If you mounted Google Drive, set `BASE_DIR` below to your Drive folder; all other `*_DIR` paths are derived from it.

In [2]:
# --- Data Parameters ---
TARGET_SAMPLE_RATE = 22050
DURATION_SECONDS = 10  # Pad/truncate audio to this length
TARGET_SAMPLES = TARGET_SAMPLE_RATE * DURATION_SECONDS

# --- Mel-Spectrogram Parameters ---
N_FFT = 1024        # Window size for STFT
HOP_LENGTH = 256    # Hop length for STFT
N_MELS = 80         # Number of Mel bands
EXPECTED_TIME_STEPS = 862 # Fixed size after padding/truncating spectrograms

# --- Emotion Labels ---
EMOTIONS = {
    "neutral": 0,
    "happy": 1,
    "sad": 2,
    "angry": 3,
    "fear": 4,  # EmoDB 'Angst', RAVDESS 'fearful'
    "disgust": 5
}
NUM_EMOTIONS = len(EMOTIONS)

# --- Model Training Parameters ---
# Reduce epochs significantly for initial testing in Colab!
RESNET_EPOCHS = 50 # Paper uses 800, reduce for faster demo
RESNET_BATCH_SIZE = 32
RESNET_LEARNING_RATE = 1e-4

DIFFUSION_EPOCHS = 100 # Not specified in the paper; this value is a guess and is likely too low. Reduce for testing.
DIFFUSION_BATCH_SIZE = 16 # Diffusion models often need smaller batches
DIFFUSION_LEARNING_RATE = 1e-4
DIFFUSION_TIMESTEPS = 1000 # Number of noise levels

# --- Paths --- 
# !! MODIFY THESE IF USING GOOGLE DRIVE !!
BASE_DIR = "/content" # Default Colab storage
# Example if using Drive:
# DRIVE_BASE_PATH = '/content/drive/MyDrive/Colab_Emotion_Augmentation'
# BASE_DIR = DRIVE_BASE_PATH 

DATA_DIR = os.path.join(BASE_DIR, "data")
EMODB_URL = "http://emodb.bilderbar.info/download/download.zip" # Official EmoDB download URL might require navigation
EMODB_DIR = os.path.join(DATA_DIR, "emodb")
RAVDESS_URL = "https://zenodo.org/record/1188976/files/Audio_Speech_Actors_01-24.zip?download=1" # Example URL, check for official source
RAVDESS_DIR = os.path.join(DATA_DIR, "ravdess")
PROCESSED_DIR = os.path.join(DATA_DIR, "processed_mels")
AUGMENTED_DIR = os.path.join(DATA_DIR, "augmented_mels")
MODEL_SAVE_DIR = os.path.join(BASE_DIR, "models") # Save models outside data dir

RESNET_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "resnet50_emotion_embedder.pth")
DIFFUSION_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "diffusion_mel_generator.pth")

# --- Create Directories ---
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(EMODB_DIR, exist_ok=True) # Create dataset dirs explicitly
os.makedirs(RAVDESS_DIR, exist_ok=True)
os.makedirs(PROCESSED_DIR, exist_ok=True)
os.makedirs(AUGMENTED_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

print(f"Data Directory: {DATA_DIR}")
print(f"Model Save Directory: {MODEL_SAVE_DIR}")

Data Directory: /content\data
Model Save Directory: /content\models
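
The value `EXPECTED_TIME_STEPS = 862` follows from the other constants. A short check, assuming librosa's default centered framing (`center=True`), under which a signal of `n` samples yields `1 + n // hop_length` frames:

```python
TARGET_SAMPLE_RATE = 22050
DURATION_SECONDS = 10
HOP_LENGTH = 256

target_samples = TARGET_SAMPLE_RATE * DURATION_SECONDS   # 220500 samples
n_frames = 1 + target_samples // HOP_LENGTH              # centered STFT frame count
print(target_samples, n_frames)  # 220500 862
```

If `DURATION_SECONDS`, `TARGET_SAMPLE_RATE`, or `HOP_LENGTH` changes, `EXPECTED_TIME_STEPS` should be recomputed the same way; otherwise every spectrogram is silently padded or truncated to the old width.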


# 3. Data Loading and Preprocessing

## 3.1. Helper Functions for Downloading and Extracting

In [3]:
def download_file(url, filename):
    """Downloads a file from a URL to a local filename."""
    if os.path.exists(filename):
        print(f"File already exists: {filename}")
        return True
    print(f"Downloading {url} to {filename}...")
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))
            block_size = 8192 # 8KB chunks
            progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=block_size):
                    progress_bar.update(len(chunk))
                    f.write(chunk)
            progress_bar.close()
        if total_size != 0 and progress_bar.n != total_size:
             print("ERROR, something went wrong during download")
             os.remove(filename) # Remove the partial file so a retry is not skipped as "already exists"
             return False
        print(f"Download complete: {filename}")
        return True
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        if os.path.exists(filename):
            os.remove(filename) # Clean up incomplete download
        return False
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        if os.path.exists(filename):
            os.remove(filename)
        return False

def extract_zip(zip_path, extract_to):
    """Extracts a zip file."""
    if not os.path.exists(zip_path):
        print(f"Zip file not found: {zip_path}")
        return False
    
    # Check if extraction directory seems populated (simple check based on expected content)
    expected_emodb = os.path.join(extract_to, "wav") # EmoDB has a 'wav' subfolder
    expected_ravdess = os.path.join(extract_to, "Actor_01") # RAVDESS has Actor_ folders
    already_extracted = False
    if os.path.exists(extract_to):
        if zip_path.endswith("emodb.zip") and os.path.exists(expected_emodb) and len(os.listdir(expected_emodb)) > 0:
            already_extracted = True
        elif zip_path.endswith("ravdess.zip") and os.path.exists(expected_ravdess) and len(os.listdir(expected_ravdess)) > 0:
             already_extracted = True
             
    if already_extracted:
         print(f"Directory already exists and seems populated: {extract_to}")
         return True
    else:
        print(f"Directory {extract_to} empty or incomplete. Extracting.")

    print(f"Extracting {zip_path} to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in tqdm(zip_ref.infolist(), desc='Extracting '):
                 try:
                     # Prevent extracting __MACOSX folders if they exist
                     if "__MACOSX" not in member.filename:
                        zip_ref.extract(member, extract_to)
                 except zipfile.error as e:
                     print(f"Error extracting member {member.filename}: {e}")
            # zip_ref.extractall(extract_to) # Less informative progress
        print(f"Extraction complete: {extract_to}")
        # Clean up zip file after successful extraction
        # os.remove(zip_path)
        return True
    except zipfile.BadZipFile:
        print(f"Error: Bad zip file: {zip_path}")
        # Optionally remove the bad zip file
        # os.remove(zip_path)
        return False
    except Exception as e:
        print(f"An unexpected error occurred during extraction: {e}")
        return False
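
The member-by-member extraction in `extract_zip` can be exercised without downloading anything. The sketch below builds a small archive in a temporary directory (the member names are made up for illustration, and `extract_skipping_macosx` is a hypothetical helper) and extracts it while skipping `__MACOSX` entries, mirroring the filter used above.

```python
import os
import tempfile
import zipfile

def extract_skipping_macosx(zip_path, extract_to):
    """Extract every member except macOS metadata entries; return extracted names."""
    extracted = []
    with zipfile.ZipFile(zip_path, "r") as zf:
        for member in zf.infolist():
            if "__MACOSX" in member.filename:
                continue
            zf.extract(member, extract_to)
            extracted.append(member.filename)
    return extracted

with tempfile.TemporaryDirectory() as tmp:
    archive = os.path.join(tmp, "demo.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("Actor_01/03-01-01-01-01-01-01.wav", b"fake audio bytes")
        zf.writestr("__MACOSX/Actor_01/._03-01-01-01-01-01-01.wav", b"metadata")
    out_dir = os.path.join(tmp, "out")
    names = extract_skipping_macosx(archive, out_dir)
    macosx_created = os.path.exists(os.path.join(out_dir, "__MACOSX"))

print(names, macosx_created)
```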

## 3.2. Download and Extract Datasets
**Note:** Direct download URLs might change or require specific permissions. If downloads fail, please download the datasets manually and place them in the `DATA_DIR` (`emodb` and `ravdess` subfolders respectively), or upload them to your mounted Google Drive and adjust paths.

In [4]:
# --- EmoDB Download --- 
# EmoDB download often requires navigating the website.
# Provide instructions for manual download.
emodb_zip_path = os.path.join(DATA_DIR, "emodb.zip")
emodb_wav_dir = os.path.join(EMODB_DIR, "wav")
print("--- EmoDB --- ")
if not os.path.exists(emodb_wav_dir) or not os.listdir(emodb_wav_dir):
     print("EmoDB data not found.")
     print("Please download 'download.zip' manually from http://emodb.bilderbar.info/download/")
     print(f"Extract it, find the 'wav' folder, and place its contents into '{emodb_wav_dir}'.")
     # Create the target directory if it doesn't exist, so the user knows where to put files
     os.makedirs(emodb_wav_dir, exist_ok=True)
     print(f"Directory created: {emodb_wav_dir}")
     # Example: After downloading and extracting, you might have './download/wav'. 
     # You need to move/copy the *.wav files from there into the EMODB_DIR/wav folder defined above.
else:
    print(f"EmoDB data seems to be present in {emodb_wav_dir}.")

# --- RAVDESS Download ---
ravdess_zip_path = os.path.join(DATA_DIR, "ravdess.zip")
ravdess_actor_dir = os.path.join(RAVDESS_DIR, "Actor_01") # Check for existence of Actor_01
print("\n--- RAVDESS --- ")
if not os.path.exists(ravdess_actor_dir):
    print("RAVDESS data not found. Attempting download...")
    if download_file(RAVDESS_URL, ravdess_zip_path):
        # RAVDESS often extracts into the Actor_ folders directly in the target dir
        extract_zip(ravdess_zip_path, RAVDESS_DIR)
    else:
        print("RAVDESS download failed.")
        print("Please download the audio dataset manually (e.g., search 'RAVDESS audio dataset', often on Zenodo/Kaggle)")
        print(f"Extract the zip file. It should contain folders like 'Actor_01', 'Actor_02', etc.") 
        print(f"Place these 'Actor_XX' folders directly inside '{RAVDESS_DIR}'.")
        # Create the target directory if it doesn't exist
        os.makedirs(RAVDESS_DIR, exist_ok=True)
else:
    print(f"RAVDESS data seems to be present in {RAVDESS_DIR}.")

--- EmoDB --- 
EmoDB data not found.
Please download 'download.zip' manually from http://emodb.bilderbar.info/download/
Extract it, find the 'wav' folder, and place its contents into '/content\data\emodb\wav'.
Directory created: /content\data\emodb\wav

--- RAVDESS --- 
RAVDESS data not found. Attempting download...
Downloading https://zenodo.org/record/1188976/files/Audio_Speech_Actors_01-24.zip?download=1 to /content\data\ravdess.zip...


Download complete: /content\data\ravdess.zip
Directory /content\data\ravdess empty or incomplete. Extracting.
Extracting /content\data\ravdess.zip to /content\data\ravdess...


Extraction complete: /content\data\ravdess


## 3.3. Audio Processing Functions

In [5]:
def load_and_preprocess_audio(file_path, target_sr, target_samples):
    """Loads, resamples, and pads/truncates audio."""
    try:
        # Load audio file
        wav, sr = librosa.load(file_path, sr=None) # Load original sample rate

        # Resample if necessary
        if sr != target_sr:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            sr = target_sr # Update sample rate variable
            
        # Ensure mono (librosa.load already returns mono by default, so this is only a safeguard)
        if wav.ndim > 1:
            wav = librosa.to_mono(wav)

        # Pad or truncate
        if len(wav) < target_samples:
            padding = target_samples - len(wav)
            left_pad = padding // 2
            right_pad = padding - left_pad
            wav = np.pad(wav, (left_pad, right_pad), mode='constant')
        elif len(wav) > target_samples:
            wav = wav[:target_samples]

        # After the branch above, len(wav) == target_samples exactly (integer arithmetic, no rounding)

        return wav
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

def wav_to_mel_spectrogram(wav, sr, n_fft, hop_length, n_mels):
    """Converts waveform to mel spectrogram."""
    try:
        mel_spec = librosa.feature.melspectrogram(
            y=wav, 
            sr=sr, 
            n_fft=n_fft, 
            hop_length=hop_length, 
            n_mels=n_mels
        )
        # Convert to dB scale
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        return mel_spec_db
    except Exception as e:
        print(f"Error converting waveform to mel spectrogram: {e}")
        return None

def normalize_spectrogram(mel_spec):
    """Applies Z-score normalization to a spectrogram."""
    if mel_spec is None:
        return None
    mean = np.mean(mel_spec)
    std = np.std(mel_spec)
    if std < 1e-6: # Avoid division by zero/very small numbers for silent spectrograms
        return mel_spec - mean # Just center it
    normalized_spec = (mel_spec - mean) / std
    return normalized_spec

def pad_truncate_spectrogram(spec, target_time_steps):
    """Pads or truncates the time dimension of a spectrogram."""
    if spec is None:
        return None
    current_time_steps = spec.shape[1]
    if current_time_steps == target_time_steps:
        return spec
    elif current_time_steps < target_time_steps:
        padding = target_time_steps - current_time_steps
        # Pad on the right side with a value representing silence (e.g., min value or slightly lower)
        pad_value = np.min(spec) if spec.size > 0 else -80.0 # Use min value or default low dB
        spec = np.pad(spec, ((0, 0), (0, padding)), mode='constant', constant_values=pad_value) 
    elif current_time_steps > target_time_steps:
        spec = spec[:, :target_time_steps] # Truncate from the right
    return spec
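
To see what the padding and normalization helpers do to the data, here is a small NumPy-only sketch on synthetic arrays (the sizes are arbitrary and no audio is loaded). It reproduces the centered zero-padding of the waveform, the z-score normalization with its guard for near-constant input, and the right-padding of the time axis with the minimum value.

```python
import numpy as np

# Centered zero-padding: 6 samples padded to 11 puts 2 zeros left and 3 right.
wav = np.ones(6, dtype=np.float32)
padding = 11 - len(wav)
left = padding // 2
padded = np.pad(wav, (left, padding - left), mode="constant")

# Z-score normalization over the whole spectrogram, with the silent-input guard.
def zscore(spec, eps=1e-6):
    mean, std = spec.mean(), spec.std()
    return spec - mean if std < eps else (spec - mean) / std

rng = np.random.default_rng(0)
spec = rng.uniform(-80.0, 0.0, size=(80, 50))   # fake dB-scaled mel spectrogram
norm = zscore(spec)
silent = zscore(np.full((80, 50), -80.0))       # constant input is only centered

# Right-pad the time axis to a fixed width using the minimum value as "silence".
target_steps = 64
fixed = np.pad(norm, ((0, 0), (0, target_steps - norm.shape[1])),
               mode="constant", constant_values=norm.min())

print(padded, fixed.shape)
```

Because the waveform is already fixed to `TARGET_SAMPLES` before the spectrogram is computed, the spectrogram-level padding in the pipeline acts mainly as a safety net.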

## 3.4. Parsing Filenames and Extracting Labels

In [6]:
def get_emotion_label(file_path, dataset_name):
    """Extracts emotion label from filename based on dataset conventions."""
    filename = os.path.basename(file_path)
    try:
        if dataset_name == "emodb":
            # EmoDB: 03a01Wa.wav -> speaker "03", text "a01", emotion "W" (Anger), version "a"
            # Emotion codes: W(Anger), L(Boredom), E(Disgust), A(Fear), F(Happy), T(Sad), N(Neutral)
            code = filename[5] # 6th character (index 5)
            mapping = {
                'W': "angry", 
                'L': "neutral", # Mapping Boredom to Neutral for simplicity / based on target labels
                'E': "disgust", 
                'A': "fear", 
                'F': "happy", 
                'T': "sad", 
                'N': "neutral"
            }
            emotion = mapping.get(code)
            # Return only emotions in our target set
            return emotion if emotion in EMOTIONS else None

        elif dataset_name == "ravdess":
            # RAVDESS: 03-01-01-01-01-01-01.wav -> 3rd part is emotion
            # Codes: 01(neutral), 02(calm), 03(happy), 04(sad), 05(angry), 06(fearful), 07(disgust), 08(surprised)
            parts = filename.split('-')
            if len(parts) < 3:
                return None # Invalid filename format
            code = int(parts[2])
            mapping = {
                1: "neutral", 
                2: "neutral", # Mapping Calm to Neutral
                3: "happy", 
                4: "sad", 
                5: "angry", 
                6: "fear", 
                7: "disgust", 
                8: None # "surprised" - Excluded based on target EMOTIONS dict
            }
            emotion = mapping.get(code)
            # Return only emotions in our target set
            return emotion if emotion in EMOTIONS else None
        else:
            return None
    except (IndexError, ValueError, KeyError) as e:
        print(f"Error parsing filename {filename} for {dataset_name}: {e}")
        return None
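
As a quick illustration of the two naming schemes, the sketch below pulls the raw emotion code out of one filename per dataset. The filenames are illustrative examples that follow each dataset's naming pattern, and the helper names are hypothetical.

```python
def emodb_code(filename):
    """EmoDB names look like '03a01Wa.wav': speaker (2), text (3), emotion (1), version (1)."""
    return filename[5]

def ravdess_code(filename):
    """RAVDESS names are seven dash-separated fields; the third is the emotion."""
    return int(filename.split("-")[2])

print(emodb_code("03a01Wa.wav"))                 # W  (anger)
print(ravdess_code("03-01-05-01-01-01-01.wav"))  # 5  (angry)
```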

## 3.5. Process All Audio Files and Save Mel-Spectrograms

This step iterates through the downloaded audio files, applies the preprocessing steps (load, resample, pad, convert to mel-spectrogram, normalize, fix time steps), and saves the resulting spectrograms as `.npy` files in the `PROCESSED_DIR`. It also creates a metadata list mapping file paths to integer labels.

In [7]:
all_files_metadata = [] # List to store tuples of (mel_path, emotion_label_int)

def process_dataset(dataset_dir, dataset_name, processed_dir):
    """Processes all audio files in a dataset directory."""
    print(f"\nProcessing {dataset_name} dataset from {dataset_dir}...")
    dataset_metadata = []
    audio_files = []
    
    if not os.path.isdir(dataset_dir):
        print(f"Error: Dataset directory not found: {dataset_dir}")
        return []
        
    if dataset_name == "emodb":
        # EmoDB files are directly in the 'wav' subfolder
        wav_dir = os.path.join(dataset_dir, "wav")
        if os.path.exists(wav_dir):
             audio_files = glob.glob(os.path.join(wav_dir, "*.wav"))
        else:
             print(f"Warning: EmoDB 'wav' directory not found at {wav_dir}")
    elif dataset_name == "ravdess":
        # RAVDESS files are in Actor_XX subfolders
        actor_dirs = glob.glob(os.path.join(dataset_dir, "Actor_*"))
        if not actor_dirs:
             print(f"Warning: RAVDESS 'Actor_*' directories not found in {dataset_dir}")
        for actor_dir in actor_dirs:
            # Check if actor_dir is actually a directory before globbing
            if os.path.isdir(actor_dir):
                 audio_files.extend(glob.glob(os.path.join(actor_dir, "*.wav")))
            else:
                 print(f"Warning: Found item '{actor_dir}' which is not a directory.")

    print(f"Found {len(audio_files)} audio files for {dataset_name}.")

    processed_count = 0
    for file_path in tqdm(audio_files, desc=f"Processing {dataset_name}"):
        emotion_label_str = get_emotion_label(file_path, dataset_name)

        # Skip files with emotions not in our target set or parsing errors
        if emotion_label_str is None:
            continue

        # --- Perform Preprocessing ---
        # 1. Load, resample, pad/truncate audio
        wav = load_and_preprocess_audio(file_path, TARGET_SAMPLE_RATE, TARGET_SAMPLES)
        if wav is None: continue

        # 2. Convert to mel spectrogram
        mel_spec = wav_to_mel_spectrogram(wav, TARGET_SAMPLE_RATE, N_FFT, HOP_LENGTH, N_MELS)
        if mel_spec is None: continue

        # 3. Normalize spectrogram
        normalized_spec = normalize_spectrogram(mel_spec)
        if normalized_spec is None: continue

        # 4. Pad/Truncate spectrogram time dimension to fixed size
        final_spec = pad_truncate_spectrogram(normalized_spec, EXPECTED_TIME_STEPS)
        if final_spec is None: continue

        # Check final shape
        if final_spec.shape != (N_MELS, EXPECTED_TIME_STEPS):
             print(f"Warning: Spectrogram shape mismatch for {file_path}. Got {final_spec.shape}, expected {(N_MELS, EXPECTED_TIME_STEPS)}. Skipping.")
             continue

        # --- Save Processed Spectrogram ---
        base_filename = os.path.splitext(os.path.basename(file_path))[0]
        save_filename = f"{dataset_name}_{base_filename}.npy"
        save_path = os.path.join(processed_dir, save_filename)

        try:
            np.save(save_path, final_spec.astype(np.float32)) # Save as float32

            # Store metadata
            emotion_label_int = EMOTIONS[emotion_label_str]
            dataset_metadata.append((save_path, emotion_label_int))
            processed_count += 1
        except Exception as e:
            print(f"Error saving spectrogram {save_path}: {e}")

    print(f"Finished processing {dataset_name}. Saved {processed_count} spectrograms.")
    return dataset_metadata

# --- Run Processing --- 
metadata_path = os.path.join(DATA_DIR, "processed_metadata.npy")

# Check if processing is already done by looking for the metadata file
if not os.path.exists(metadata_path):
    print(f"Processed metadata file not found at {metadata_path}. Processing datasets...")
    # Warn if the processed directory already contains files (they are not cleared automatically)
    if os.path.exists(PROCESSED_DIR) and os.listdir(PROCESSED_DIR):
         print(f"Warning: Processed directory {PROCESSED_DIR} is not empty. Files might be overwritten.")
         # Optionally clear the directory:
         # shutil.rmtree(PROCESSED_DIR)
         # os.makedirs(PROCESSED_DIR, exist_ok=True)
    else:
         os.makedirs(PROCESSED_DIR, exist_ok=True)
         
    emodb_metadata = process_dataset(EMODB_DIR, "emodb", PROCESSED_DIR)
    ravdess_metadata = process_dataset(RAVDESS_DIR, "ravdess", PROCESSED_DIR)
    
    all_files_metadata = emodb_metadata + ravdess_metadata

    if all_files_metadata: # Only save if processing was successful
        # Save metadata list
        np.save(metadata_path, all_files_metadata)
        print(f"\nTotal processed files: {len(all_files_metadata)}")
        print(f"Saved metadata to {metadata_path}")
    else:
        print("\nError: No files were processed successfully. Metadata not saved.")
else:
    print(f"\nProcessed metadata file found. Loading metadata from {metadata_path}.")
    try:
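        # np.save stored the (path, label) tuples as a string array, so labels come back as str; cast them to int
        loaded_metadata = np.load(metadata_path, allow_pickle=True).tolist()
        all_files_metadata = [(str(p), int(l)) for p, l in loaded_metadata]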
        print(f"Loaded {len(all_files_metadata)} metadata entries.")
        if not all_files_metadata:
            print("Warning: Loaded metadata is empty. Consider deleting the metadata file and re-running processing.")
    except Exception as e:
        print(f"Error loading metadata file: {e}. Please delete the file and re-run processing.")
        all_files_metadata = []

# --- Verify Data --- 
if all_files_metadata:
    print("\nExample metadata entry:", all_files_metadata[0])
    # Load one spectrogram to check shape
    try:
        example_spec_path = all_files_metadata[0][0]
        if os.path.exists(example_spec_path):
             example_spec = np.load(example_spec_path)
             print("Example spectrogram shape:", example_spec.shape)
             if example_spec.shape != (N_MELS, EXPECTED_TIME_STEPS):
                 print(f"ERROR: Loaded spectrogram shape {example_spec.shape} does not match expected shape {(N_MELS, EXPECTED_TIME_STEPS)}")
                 print("Please check the pad_truncate_spectrogram function and EXPECTED_TIME_STEPS.")
        else:
             print(f"Error: Example spectrogram file not found at {example_spec_path}")
    except Exception as e:
        print(f"Error loading example spectrogram: {e}")
else:
    print("\nMetadata is empty. Cannot proceed with verification or training.")

Processed metadata file not found at /content\data\processed_metadata.npy. Processing datasets...

Processing emodb dataset from /content\data\emodb...
Found 0 audio files for emodb.


Finished processing emodb. Saved 0 spectrograms.

Processing ravdess dataset from /content\data\ravdess...
Found 1440 audio files for ravdess.


Finished processing ravdess. Saved 1248 spectrograms.

Total processed files: 1248
Saved metadata to /content\data\processed_metadata.npy

Example metadata entry: ('/content\\data\\processed_mels\\ravdess_03-01-01-01-01-01-01.npy', 0)
Example spectrogram shape: (80, 862)
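
One detail of the metadata cache is easy to miss: `np.save` converts the list of `(path, label)` tuples into a single string array, so labels read back from disk are `str`, not `int`. The sketch below demonstrates this with made-up paths in a temporary directory and shows a cast that restores the original types.

```python
import os
import tempfile
import numpy as np

records = [("a.npy", 0), ("b.npy", 3)]   # hypothetical (path, label) pairs

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "meta.npy")
    np.save(path, records)                               # stored as a 2-D string array
    raw = np.load(path, allow_pickle=True).tolist()

restored = [(str(p), int(label)) for p, label in raw]    # cast labels back to int
print(type(raw[0][1]).__name__, restored)
```

Storing the metadata as JSON or CSV would avoid the implicit conversion altogether; that is a design alternative, not something the notebook does.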


## 3.6. Create PyTorch Dataset

In [8]:
class MelSpectrogramDataset(Dataset):
    """PyTorch Dataset for loading mel spectrograms and labels."""
    def __init__(self, metadata, transform=None):
        """
        Args:
            metadata (list): List of tuples (spectrogram_path, label_int).
            transform (callable, optional): Optional transform to be applied 
                on a sample.
        """
        # Filter out metadata entries where the file might not exist
        self.metadata = [(p, l) for p, l in metadata if os.path.exists(p)]
        if len(self.metadata) != len(metadata):
             print(f"Warning: Filtered out {len(metadata) - len(self.metadata)} non-existent file paths from metadata.")
             
        self.transform = transform
        self.label_encoder = LabelEncoder()
        # Fit label encoder to the unique labels present in metadata
        all_labels = [item[1] for item in self.metadata]
        if not all_labels:
            print("Warning: No valid labels found in metadata for dataset creation.")
            self.num_classes = 0
        else:
            self.label_encoder.fit(sorted(list(set(all_labels))))
            self.num_classes = len(self.label_encoder.classes_)
            print(f"Dataset initialized with {len(self.metadata)} samples.")
            print(f"Classes found: {self.label_encoder.classes_} -> {self.num_classes} classes")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Handle potential index out of bounds if metadata was empty
        if idx >= len(self.metadata):
             raise IndexError("Index out of bounds for metadata list.")

        mel_path, label_int = self.metadata[idx]

        try:
            # Load spectrogram
            mel_spec = np.load(mel_path) # Should be float32 already

            # Add channel dimension: (H, W) -> (C, H, W) = (1, N_MELS, TIME_STEPS)
            mel_spec = np.expand_dims(mel_spec, axis=0)

            # Convert label to tensor
            label = torch.tensor(label_int, dtype=torch.long)

            # Convert spectrogram to tensor
            mel_spec_tensor = torch.from_numpy(mel_spec)

            sample = {'spectrogram': mel_spec_tensor, 'label': label}

            if self.transform:
                sample = self.transform(sample)

            # Return spectrogram, label, and index (useful for diffusion embedding lookup)
            return sample['spectrogram'], sample['label'], idx

        except FileNotFoundError:
             print(f"Error: File not found at {mel_path} for index {idx}. Check metadata consistency.")
             # Handle error: return dummy data, skip, or raise
             # Raising might be best during development
             raise FileNotFoundError(f"File not found: {mel_path}")
        except Exception as e:
            print(f"Error loading or processing item at index {idx} ({mel_path}): {e}")
            # Raise the exception to stop execution and identify the problematic file
            raise e

# Create the full dataset
full_dataset = None
if all_files_metadata:
    full_dataset = MelSpectrogramDataset(all_files_metadata)
    if full_dataset.num_classes > 0:
        NUM_EMOTIONS = full_dataset.num_classes # Update NUM_EMOTIONS based on actual data
        print(f"Number of classes detected in dataset: {NUM_EMOTIONS}")
    else:
        print("Error: Dataset created but no classes found. Check metadata and label parsing.")
        full_dataset = None # Mark as unusable
else:
    print("Cannot create dataset, metadata is empty.")

Dataset initialized with 1248 samples.
Classes found: [0 1 2 3 4 5] -> 6 classes
Number of classes detected in dataset: 6
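
The sample count above can be cross-checked against the layout of the RAVDESS speech set. Assuming the standard release (24 actors; per actor, 4 neutral recordings and 8 for each of the other seven emotions), dropping "surprised" leaves exactly the 1248 files reported:

```python
ACTORS = 24
NEUTRAL_PER_ACTOR = 4        # 2 statements x 2 repetitions, normal intensity only
OTHER_PER_ACTOR = 8          # 2 statements x 2 repetitions x 2 intensities
OTHER_EMOTIONS = 7           # calm, happy, sad, angry, fearful, disgust, surprised

total = ACTORS * (NEUTRAL_PER_ACTOR + OTHER_EMOTIONS * OTHER_PER_ACTOR)
surprised = ACTORS * OTHER_PER_ACTOR
kept = total - surprised
neutral_class = ACTORS * (NEUTRAL_PER_ACTOR + OTHER_PER_ACTOR)   # neutral + calm merged
print(total, surprised, kept, neutral_class)  # 1440 192 1248 288
```

In this RAVDESS-only run, merging "calm" into "neutral" makes that class 1.5 times larger than each of the others (288 versus 192), which is worth keeping in mind when reading per-class metrics.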


## 3.7. Split Data into Training and Validation Sets

In [9]:
train_loader, val_loader = None, None
train_dataset, val_dataset = None, None

if full_dataset and len(full_dataset) > 0:
    # Split metadata first to keep track of original files if needed
    try:
        train_meta, val_meta = train_test_split(
            full_dataset.metadata, # Use the filtered metadata from the dataset object
            test_size=0.2, # 20% for validation
            random_state=SEED,
            stratify=[item[1] for item in full_dataset.metadata] # Stratify by label
        )

        train_dataset = MelSpectrogramDataset(train_meta)
        val_dataset = MelSpectrogramDataset(val_meta)

        print(f"Training set size: {len(train_dataset)}")
        print(f"Validation set size: {len(val_dataset)}")

        # Create DataLoaders
        # Use num_workers=0 if you encounter issues, especially on Windows or with certain Colab setups
        train_loader = DataLoader(train_dataset, batch_size=RESNET_BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=RESNET_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
        
    except ValueError as e:
        print(f"Error during train/val split: {e}")
        print("This might happen if a class has too few samples for stratification.")
        print("Consider using a smaller validation split or checking data balance.")
    except Exception as e:
        print(f"An unexpected error occurred during data splitting: {e}")

else:
    print("Dataset not created or is empty. Skipping data splitting and loader creation.")

Dataset initialized with 998 samples.
Classes found: [0 1 2 3 4 5] -> 6 classes
Dataset initialized with 250 samples.
Classes found: [0 1 2 3 4 5] -> 6 classes
Training set size: 998
Validation set size: 250
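
The split sizes printed above follow from how scikit-learn rounds a fractional `test_size`: the test count is rounded up and the remainder goes to training. The same arithmetic also shows the effect of `drop_last=True` on the training loader.

```python
import math

n_samples = 1248
test_size = 0.2
batch_size = 32

n_val = math.ceil(test_size * n_samples)    # 249.6 rounds up to 250
n_train = n_samples - n_val                 # 998
full_train_batches = n_train // batch_size  # partial last batch is dropped
skipped_per_epoch = n_train - full_train_batches * batch_size
print(n_train, n_val, full_train_batches, skipped_per_epoch)  # 998 250 31 6
```

Since the training loader shuffles, the 6 skipped samples differ from epoch to epoch.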


# 4. Emotion Embedding Model (ResNet-50)

## 4.1. Model Definition

In [10]:
class EmotionResNet50(nn.Module):
    """ResNet-50 model adapted for 1-channel spectrogram input and emotion classification."""
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        # Modify the first convolutional layer to accept 1 input channel
        original_conv1 = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(
            1, original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=(original_conv1.bias is not None)
        )
        
        # Initialize the new conv1 weights 
        if pretrained:
            # Average the weights of the original 3 channels to initialize the new 1 channel
            original_weights = original_conv1.weight.data
            mean_weights = torch.mean(original_weights, dim=1, keepdim=True)
            self.resnet.conv1.weight.data = mean_weights
            print("Initialized 1-channel conv1 weights by averaging pretrained weights.")
        else:
             # Standard Kaiming initialization if not pretrained
             nn.init.kaiming_normal_(self.resnet.conv1.weight, mode='fan_out', nonlinearity='relu')
             print("Initialized 1-channel conv1 weights using Kaiming normal.")
        
        if self.resnet.conv1.bias is not None:
             nn.init.constant_(self.resnet.conv1.bias, 0)

        # Modify the final fully connected layer for the desired number of emotion classes
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, num_classes)
        # Initialize the new fc layer
        nn.init.xavier_uniform_(self.resnet.fc.weight)
        if self.resnet.fc.bias is not None:
            nn.init.constant_(self.resnet.fc.bias, 0)

        print(f"ResNet-50 adapted for {num_classes} emotion classes and 1 input channel.")


    def forward(self, x):
        # Input x shape: (B, 1, H, W), e.g. (B, 1, 80, 862)
        # ResNet-50 downsamples by a factor of 32 overall, so this input reaches layer4 at about 3 x 27.
        # No extra pooling is needed: torchvision's ResNet-50 ends with AdaptiveAvgPool2d((1, 1)),
        # which collapses any spatial size to 1 x 1 before the final FC layer.
        return self.resnet(x)

    def get_embedding(self, x):
        """Extracts features before the final classification layer."""
        # Pass input through the ResNet layers up to the average pooling layer
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        # Use the adaptive average pooling layer defined in the original ResNet
        x = self.resnet.avgpool(x)
        embedding = torch.flatten(x, 1) # Flatten the output of avgpool
        return embedding
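
The channel-averaging initialization of `conv1` can be sanity-checked on plain numbers. The sketch below uses made-up kernel values and a hypothetical helper `dot` (neither comes from the notebook). It illustrates that a kernel obtained by averaging the three RGB kernels responds to a grayscale patch with one third of what the original filter would give for the same patch replicated over three channels, while summing the kernels would keep the magnitude. Whether that scale difference matters in practice is an open assumption, since the following BatchNorm layer adapts during fine-tuning.

```python
# Toy check of collapsing pretrained RGB kernels into one grayscale kernel.
# All values are invented for illustration; no real ResNet weights are used.

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

# One flattened 2x2 kernel per input channel (R, G, B).
rgb_kernels = [
    [0.2, -0.1, 0.4, 0.0],
    [0.1, 0.3, -0.2, 0.5],
    [-0.3, 0.2, 0.1, 0.1],
]
patch = [1.0, 2.0, 3.0, 4.0]  # grayscale patch, same shape as one kernel

# Response of the original filter to the patch replicated over 3 channels.
replicated_response = sum(dot(k, patch) for k in rgb_kernels)

# Averaged kernel, analogous to torch.mean(weight, dim=1, keepdim=True).
mean_kernel = [sum(vals) / len(rgb_kernels) for vals in zip(*rgb_kernels)]
mean_response = dot(mean_kernel, patch)

# Summed kernel: keeps the magnitude of the replicated-input response.
sum_kernel = [sum(vals) for vals in zip(*rgb_kernels)]
sum_response = dot(sum_kernel, patch)

print(replicated_response, mean_response, sum_response)
```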

## 4.2. Training Function

In [12]:
def train_resnet_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, model_save_path):
    """Trains the ResNet-50 model."""
    best_val_accuracy = 0.0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    print(f"Starting ResNet training for {num_epochs} epochs on {device}...")

    # Check if loaders are valid
    if not train_loader or not val_loader:
        print("Error: Invalid data loaders provided.")
        return None

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        running_loss = 0.0
        correct_predictions_train = 0
        total_samples_train = 0

        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for inputs, labels, _ in train_pbar: # Dataset returns inputs, labels, index
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_samples_train += labels.size(0)
            correct_predictions_train += (predicted == labels).sum().item()

            train_pbar.set_postfix(loss=f"{loss.item():.4f}")
            
        # Check for zero division error
        if total_samples_train == 0:
             print(f"Epoch {epoch+1} Warning: No training samples processed. Skipping epoch stats.")
             continue

        epoch_loss_train = running_loss / total_samples_train
        epoch_acc_train = correct_predictions_train / total_samples_train
        history['train_loss'].append(epoch_loss_train)
        history['train_acc'].append(epoch_acc_train)

        # --- Validation Phase ---
        model.eval()
        running_loss_val = 0.0
        correct_predictions_val = 0
        total_samples_val = 0

        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        with torch.no_grad():
            for inputs, labels, _ in val_pbar: # Dataset returns inputs, labels, index
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                running_loss_val += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                total_samples_val += labels.size(0)
                correct_predictions_val += (predicted == labels).sum().item()
                val_pbar.set_postfix(loss=f"{loss.item():.4f}")
        
        # Check for zero division error
        if total_samples_val == 0:
             print(f"Epoch {epoch+1} Warning: No validation samples processed. Skipping epoch stats.")
             epoch_loss_val = float('inf')
             epoch_acc_val = 0.0
        else:
            epoch_loss_val = running_loss_val / total_samples_val
            epoch_acc_val = correct_predictions_val / total_samples_val
            
        history['val_loss'].append(epoch_loss_val)
        history['val_acc'].append(epoch_acc_val)

        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Train Loss: {epoch_loss_train:.4f}, Train Acc: {epoch_acc_train:.4f} | "
              f"Val Loss: {epoch_loss_val:.4f}, Val Acc: {epoch_acc_val:.4f}")

        # Save the model if validation accuracy improves
        if epoch_acc_val > best_val_accuracy and total_samples_val > 0:
            best_val_accuracy = epoch_acc_val
            try:
                torch.save(model.state_dict(), model_save_path)
                print(f"Best model saved to {model_save_path} (Val Acc: {best_val_accuracy:.4f})")
            except Exception as e:
                 print(f"Error saving model: {e}")

    print("Finished Training ResNet.")
    print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
    return history
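
The epoch statistics in `train_resnet_model` multiply each batch loss by the batch size before dividing by the total sample count. A minimal sketch with invented numbers shows why: when the last batch is smaller, a plain mean over batch losses gives it too much weight.

```python
# Each tuple is (mean loss of the batch, number of samples in the batch).
# The numbers are invented for illustration.
batches = [(0.50, 32), (0.40, 32), (1.00, 4)]  # the last batch is smaller

running_loss = 0.0
total_samples = 0
for batch_mean_loss, batch_size in batches:
    running_loss += batch_mean_loss * batch_size  # undo the per-batch mean
    total_samples += batch_size

# Per-sample average, as computed in the training loop.
weighted_epoch_loss = running_loss / total_samples

# Naive mean over batches, which overweights the small final batch.
naive_epoch_loss = sum(loss for loss, _ in batches) / len(batches)

print(f"weighted: {weighted_epoch_loss:.4f}, naive: {naive_epoch_loss:.4f}")
```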

## 4.3. Train or Load ResNet Model

In [13]:
# Initialize model, criterion, optimizer
resnet_model = None
history_resnet = None
if NUM_EMOTIONS > 0:
    resnet_model = EmotionResNet50(num_classes=NUM_EMOTIONS, pretrained=True).to(DEVICE)
    criterion_resnet = nn.CrossEntropyLoss()
    optimizer_resnet = optim.Adam(resnet_model.parameters(), lr=RESNET_LEARNING_RATE)

    # Check if a trained model exists
    if os.path.exists(RESNET_MODEL_PATH):
        print(f"Loading pre-trained ResNet model from {RESNET_MODEL_PATH}...")
        try:
            resnet_model.load_state_dict(torch.load(RESNET_MODEL_PATH, map_location=DEVICE))
            print("ResNet model loaded successfully.")
            # Optionally run evaluation on val set to confirm performance
            # evaluate_model(resnet_model, val_loader, criterion_resnet, DEVICE, NUM_EMOTIONS) # Need evaluate_model defined
        except Exception as e:
            print(f"Error loading ResNet model: {e}. Training from scratch.")
            if train_loader and val_loader:
                 history_resnet = train_resnet_model(resnet_model, train_loader, val_loader, criterion_resnet, optimizer_resnet, RESNET_EPOCHS, DEVICE, RESNET_MODEL_PATH)
            else:
                 print("Cannot train ResNet model, data loaders not available.")
                 resnet_model = None # Mark as None if training cannot proceed
    else:
        print("No pre-trained ResNet model found. Training from scratch...")
        if train_loader and val_loader:
            history_resnet = train_resnet_model(resnet_model, train_loader, val_loader, criterion_resnet, optimizer_resnet, RESNET_EPOCHS, DEVICE, RESNET_MODEL_PATH)
        else:
            print("Cannot train ResNet model, data loaders not available.")
            resnet_model = None # Mark as None if training cannot proceed
else:
    print("Error: Number of emotion classes is zero. Cannot initialize or train ResNet model.")

Initialized 1-channel conv1 weights by averaging pretrained weights.
ResNet-50 adapted for 6 emotion classes and 1 input channel.
No pre-trained ResNet model found. Training from scratch...
Starting ResNet training for 50 epochs on cpu...


Epoch 1/50 [Train]:   0%|          | 0/31 [00:00<?, ?it/s]



RuntimeError: DataLoader worker (pid(s) 19596, 16364) exited unexpectedly
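
This `RuntimeError` is raised when `DataLoader` worker processes die. The run above took place on Windows, where worker processes are spawned rather than forked, and a dataset class defined inside a notebook is a common (though here unconfirmed) cause of such crashes. A minimal sketch of a guard follows; `pick_num_workers` is a hypothetical helper, not part of the notebook or of PyTorch.

```python
import os

def pick_num_workers(requested=2):
    """Hypothetical helper: choose a DataLoader worker count for this platform."""
    if os.name == "nt":
        # Windows: load data in the main process to avoid spawned-worker crashes.
        return 0
    available = os.cpu_count() or 1
    return max(0, min(requested, available))

print(pick_num_workers(2))
```

The result could be passed as `num_workers=pick_num_workers()` wherever a loader is built. Loading in the main process is slower but avoids worker processes entirely.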

## 4.4. Function to Extract Embeddings

In [None]:
def extract_embeddings(model, data_loader, device):
    """Extracts embeddings from the ResNet model for all data in the loader."""
    if model is None or data_loader is None:
        print("Error: Model or data loader is None. Cannot extract embeddings.")
        return None, None, None
        
    model.eval()
    embeddings = []
    labels_list = []
    indices_list = [] # Store indices to ensure correct mapping
    print(f"Extracting embeddings using device: {device}")
    with torch.no_grad():
        for inputs, labels, indices in tqdm(data_loader, desc="Extracting Embeddings"):
            inputs = inputs.to(device)
            # Get embeddings from the model
            batch_embeddings = model.get_embedding(inputs)
            embeddings.append(batch_embeddings.cpu().numpy())
            labels_list.append(labels.cpu().numpy())
            indices_list.append(indices.cpu().numpy()) # Store batch indices
            
    if not embeddings:
        print("Warning: No embeddings were extracted.")
        return None, None, None

    embeddings = np.concatenate(embeddings, axis=0)
    labels_list = np.concatenate(labels_list, axis=0)
    indices_list = np.concatenate(indices_list, axis=0)
    print(f"Extracted embeddings shape: {embeddings.shape}")
    
    # Sort embeddings based on indices to match the original dataset order
    sort_order = np.argsort(indices_list)
    embeddings_sorted = embeddings[sort_order]
    labels_sorted = labels_list[sort_order]
    indices_sorted = indices_list[sort_order]
    
    # Verify sorting
    if not np.all(indices_sorted == np.arange(len(indices_sorted))):
         print("Warning: Index sorting verification failed. Embeddings might not match dataset order.")
    
    return embeddings_sorted, labels_sorted, indices_sorted
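
The reordering step at the end of `extract_embeddings` can be illustrated without NumPy. The sketch below uses toy values and string stand-ins for embedding rows. It sorts positions by dataset index, which is what `np.argsort` returns, and then applies the same `0..N-1` check.

```python
# Indices as they might come out of a shuffled loader, one per item.
indices = [3, 0, 4, 1, 2]
embeddings = ["e3", "e0", "e4", "e1", "e2"]  # stand-ins for embedding rows

# Equivalent of np.argsort(indices): positions that sort the index list.
sort_order = sorted(range(len(indices)), key=lambda pos: indices[pos])

embeddings_sorted = [embeddings[pos] for pos in sort_order]
indices_sorted = [indices[pos] for pos in sort_order]

# Same check as in extract_embeddings: indices must be exactly 0..N-1.
is_aligned = indices_sorted == list(range(len(indices_sorted)))
print(embeddings_sorted, is_aligned)
```

The check only passes when the loader covers every item exactly once. A loader built with `drop_last=True` or over a subset would trigger the warning even though the sort itself is correct.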

## 4.5. Extract Embeddings for Conditioning Diffusion

In [None]:
# Extract embeddings for the full dataset (needed for conditioning diffusion model)
all_embeddings, all_labels, all_indices = None, None, None
EMBEDDING_DIM = 2048 # Default for ResNet50 avgpool output

if full_dataset and resnet_model:
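    # Create a loader for the full dataset.
    # Worker processes can crash in Windows notebooks (see the RuntimeError above), so use 0 workers there.
    full_loader = DataLoader(full_dataset, batch_size=RESNET_BATCH_SIZE, shuffle=False,
                             num_workers=0 if os.name == 'nt' else 2,
                             pin_memory=torch.cuda.is_available())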
    print("Extracting embeddings for the full dataset...")
    all_embeddings, all_labels, all_indices = extract_embeddings(resnet_model, full_loader, DEVICE)
    
    if all_embeddings is not None:
        EMBEDDING_DIM = all_embeddings.shape[1] # Get actual embedding dimension
        print(f"Embedding dimension: {EMBEDDING_DIM}")
        # Create a mapping from label index back to emotion name for clarity
        label_to_emotion = {i: name for name, i in EMOTIONS.items()}
        print("Embeddings extracted successfully.")
    else:
        print("Embedding extraction failed.")
        EMBEDDING_DIM = 2048 # Fallback to default if extraction failed
else:
    print("Full dataset or ResNet model not available, cannot extract embeddings.")

# 5. Diffusion Model for Mel-Spectrogram Generation

This section sketches a conditional diffusion model following the paper's description: a U-Net built from residual blocks, conditioned on the diffusion timestep and on the ResNet-50 emotion embedding. It is a starting structure rather than a tuned implementation, so expect to adjust the architecture and hyperparameters.
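
Time conditioning starts from a sinusoidal encoding of the integer timestep. The pure-Python sketch below mirrors the formula used by `SinusoidalPosEmb` in section 5.1: geometrically spaced frequencies, sines in the first half and cosines in the second. The function name `sinusoidal_embedding` and the toy dimension 8 are illustrative assumptions.

```python
import math

def sinusoidal_embedding(t, dim):
    """Encode timestep t as `dim` floats (dim must be even and at least 4)."""
    half = dim // 2
    step = math.log(10000) / (half - 1)
    # Frequencies decay geometrically from 1 down to 1/10000.
    freqs = [math.exp(-step * i) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]

emb0 = sinusoidal_embedding(0, 8)
emb10 = sinusoidal_embedding(10, 8)
print(emb0)
print([round(v, 4) for v in emb10])
```

At `t = 0` every sine term is 0 and every cosine term is 1. Later timesteps give distinct vectors whose entries stay in [-1, 1].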

## 5.1. Diffusion Model Components

In [None]:
# --- Helper: Sinusoidal Time Embeddings ---
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] # Shape: (batch_size, half_dim)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1) # Shape: (batch_size, dim)
        return emb

# --- Helper: ResNet Block with Time/Emotion Conditioning ---
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim=None, emotion_emb_dim=None, use_attention=False):
        super().__init__()
        self.use_attention = use_attention
        groups = 8 # GroupNorm groups, adjust if needed

        # Time embedding projection
        self.time_mlp = None
        if time_emb_dim is not None:
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_emb_dim, out_channels * 2) # Project to scale and shift
            )
            
        # Emotion embedding projection
        self.emotion_mlp = None
        if emotion_emb_dim is not None:
             self.emotion_mlp = nn.Sequential(
                 nn.SiLU(),
                 nn.Linear(emotion_emb_dim, out_channels * 2) # Project to scale and shift
             )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # GroupNorm requires out_channels to be divisible by the group count; fall back to one group otherwise
        if out_channels % groups != 0:
            groups = 1
        self.norm1 = nn.GroupNorm(groups, out_channels)
        self.act1 = nn.SiLU() # Swish activation

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.act2 = nn.SiLU()

        # Residual connection matching
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

        # Optional gating branch. Despite the name, this is not self-attention:
        # two 1x1 convolutions and a sigmoid produce a per-position, per-channel gate.
        self.attention = None
        if use_attention:
             attn_hidden_dim = max(16, out_channels // 8) # Ensure reasonable hidden dim
             self.attention = nn.Sequential(
                 nn.Conv2d(out_channels, attn_hidden_dim, 1),
                 nn.SiLU(),
                 nn.Conv2d(attn_hidden_dim, out_channels, 1),
                 nn.Sigmoid() # Scale features based on attention map
             )

    def forward(self, x, t_emb=None, e_emb=None):
        res = x

        # First convolution block
        h = self.conv1(x)
        h = self.norm1(h)

        # Apply time embedding conditioning (FiLM-like)
        if self.time_mlp is not None and t_emb is not None:
             time_cond = self.time_mlp(t_emb)
             time_cond = time_cond.unsqueeze(-1).unsqueeze(-1) # Reshape (B, C*2, 1, 1)
             scale_time, shift_time = time_cond.chunk(2, dim=1) 
             h = h * (scale_time + 1) + shift_time 

        # Apply emotion embedding conditioning (FiLM-like)
        if self.emotion_mlp is not None and e_emb is not None:
             emotion_cond = self.emotion_mlp(e_emb)
             emotion_cond = emotion_cond.unsqueeze(-1).unsqueeze(-1)
             scale_emotion, shift_emotion = emotion_cond.chunk(2, dim=1)
             h = h * (scale_emotion + 1) + shift_emotion
             
        h = self.act1(h)

        # Second convolution block
        h = self.conv2(h)
        h = self.norm2(h)
        h = self.act2(h)

        # Apply attention if enabled
        if self.attention is not None:
            attn_map = self.attention(h)
            h = h * attn_map

        # Add residual connection
        return h + self.residual_conv(res)

# --- Helper: Downsample/Upsample ---
class Downsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Use AvgPool2d for downsampling feature maps might be smoother
        # self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
        # Or keep Conv stride 2
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # return self.pool(x)
        return self.conv(x)

class Upsample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Use bilinear interpolation + conv for potentially better quality
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        return x

# --- U-Net Architecture ---
class UNetDiffusion(nn.Module):
    def __init__(
        self,
        in_channels=1,              # Input spectrogram channels
        model_channels=64,          # Base channel count
        out_channels=1,             # Output channels (usually same as input)
        channel_mult=(1, 2, 4, 8),  # Channel multipliers per level
        num_res_blocks=2,           # ResBlocks per level
        time_emb_dim=256,           # Stored but currently unused; the time MLP width is model_channels * 4
        emotion_emb_dim=None,       # Dimension for emotion embedding (e.g., 2048 from ResNet)
        use_attention_levels=(False, False, True, True), # Apply attention at deeper levels
        dropout=0.1                 # Stored but currently unused; no dropout layer is applied
    ):
        super().__init__()

        if emotion_emb_dim is None:
            print("Warning: emotion_emb_dim not provided to UNetDiffusion. Emotion conditioning disabled.")

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.channel_mult = channel_mult
        self.num_res_blocks = num_res_blocks
        self.time_emb_dim = time_emb_dim
        self.emotion_emb_dim = emotion_emb_dim
        self.dropout = dropout

        # --- Time Embedding ---
        time_dim = model_channels * 4 # Internal dimension for time projection
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(model_channels),
            nn.Linear(model_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )

        # --- Initial Convolution ---
        self.init_conv = nn.Conv2d(in_channels, model_channels, kernel_size=3, padding=1)

        # --- Downsampling Path ---
        self.down_blocks = nn.ModuleList()
        channels = [model_channels]
        now_channels = model_channels
        for i, mult in enumerate(channel_mult):
            out_channels_level = model_channels * mult
            use_attn = use_attention_levels[i]
            for _ in range(num_res_blocks):
                self.down_blocks.append(ResidualBlock(
                    now_channels, out_channels_level, time_dim, emotion_emb_dim, use_attn
                ))
                now_channels = out_channels_level
                channels.append(now_channels)
            if i != len(channel_mult) - 1: # Don't downsample at the last level
                self.down_blocks.append(Downsample(now_channels))
                channels.append(now_channels)

        # --- Bottleneck ---
        self.mid_block1 = ResidualBlock(now_channels, now_channels, time_dim, emotion_emb_dim, use_attention=True)
        # Add dropout maybe?
        # self.mid_attn = AttentionBlock(now_channels)
        self.mid_block2 = ResidualBlock(now_channels, now_channels, time_dim, emotion_emb_dim, use_attention=False)

        # --- Upsampling Path ---
        self.up_blocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_mult))):
            out_channels_level = model_channels * mult
            use_attn = use_attention_levels[i]
            for _ in range(num_res_blocks + 1): # +1 to handle skip connection merging
                skip_channels_in = channels.pop()
                self.up_blocks.append(ResidualBlock(
                    now_channels + skip_channels_in, # Input channels = current + skip
                    out_channels_level,
                    time_dim,
                    emotion_emb_dim,
                    use_attn
                ))
                now_channels = out_channels_level
            if i != 0: # Don't upsample after the first level (closest to output)
                self.up_blocks.append(Upsample(now_channels))

        # --- Final Layers ---
        final_groups = 8 if model_channels % 8 == 0 else 1 # GroupNorm needs channels divisible by groups
        self.final_norm = nn.GroupNorm(final_groups, model_channels)
        self.final_act = nn.SiLU()
        self.final_conv = nn.Conv2d(model_channels, self.out_channels, kernel_size=1)

    def forward(self, x, time, emotion_embedding=None):
        # x: Input spectrogram (B, C, H, W) -> (B, 1, N_MELS, TIME)
        # time: Timestep tensor (B,)
        # emotion_embedding: Emotion embedding tensor (B, emotion_emb_dim)

        # 1. Time Embedding
        t_emb = self.time_mlp(time) # (B, time_dim)

        # 2. Initial Convolution
        h = self.init_conv(x) # (B, model_channels, H, W)
        skips = [h] # Store initial skip connection

        # 3. Downsampling Path
        for block in self.down_blocks:
            if isinstance(block, ResidualBlock):
                 h = block(h, t_emb, emotion_embedding)
                 skips.append(h)
            elif isinstance(block, Downsample):
                 h = block(h)
                 # __init__ records one skip channel entry per Downsample (channels.append),
                 # so the downsampled map must be pushed here as well; otherwise the
                 # decoder pops more skips than were stored.
                 skips.append(h)
        
        # 4. Bottleneck
        h = self.mid_block1(h, t_emb, emotion_embedding)
        h = self.mid_block2(h, t_emb, emotion_embedding)

        # 5. Upsampling Path
        for block in self.up_blocks:
            if isinstance(block, ResidualBlock):
                # Pop the corresponding skip connection
                skip = skips.pop()
                # Concatenate along channel dim
                # Ensure spatial dimensions match before concat
                if h.shape[-2:] != skip.shape[-2:]:
                     # Simple interpolation if shapes differ slightly
                     h = F.interpolate(h, size=skip.shape[-2:], mode='bilinear', align_corners=False)
                h = torch.cat([h, skip], dim=1) 
                h = block(h, t_emb, emotion_embedding)
            elif isinstance(block, Upsample):
                h = block(h)

        # 6. Final Layers
        h = self.final_norm(h)
        h = self.final_act(h)
        h = self.final_conv(h) # (B, out_channels, H, W)

        return h
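
The skip-connection bookkeeping is the easiest part of this U-Net to get wrong. `__init__` records one channel count for the initial convolution, one per encoder residual block, and one per `Downsample`, and the decoder pops one entry per residual block. The sketch below replays that arithmetic with plain integers; `plan_unet_channels` is a hypothetical helper that uses the defaults from `UNetDiffusion`. `forward` has to push exactly as many tensors onto `skips` as this plan contains.

```python
def plan_unet_channels(model_channels=64, channel_mult=(1, 2, 4, 8), num_res_blocks=2):
    """Replay the channel bookkeeping of a U-Net with concatenated skips.

    Returns (skip_channels, decoder_blocks), where decoder_blocks holds one
    (in_channels, out_channels) pair per decoder residual block.
    """
    skip_channels = [model_channels]       # output of the initial convolution
    now = model_channels
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            now = model_channels * mult
            skip_channels.append(now)      # one skip per encoder residual block
        if level != len(channel_mult) - 1:
            skip_channels.append(now)      # one skip per downsample layer

    remaining = list(skip_channels)
    decoder_blocks = []
    for level, mult in reversed(list(enumerate(channel_mult))):
        out = model_channels * mult
        for _ in range(num_res_blocks + 1):
            skip = remaining.pop()
            decoder_blocks.append((now + skip, out))  # concat, then project
            now = out
    assert not remaining, "decoder must consume every skip"
    return skip_channels, decoder_blocks

skips, blocks = plan_unet_channels()
print(len(skips), len(blocks), blocks[0], blocks[-1])
```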

## 5.2. Diffusion Process (DDPM)

In [None]:
def linear_beta_schedule(timesteps):
    """Linear beta schedule. The DDPM endpoints (0.0001 and 0.02, defined for 1000 steps) are rescaled by 1000 / timesteps."""
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)

# --- Precompute DDPM variables ---
T = DIFFUSION_TIMESTEPS
# The schedule tensors are built on the CPU here and moved to DEVICE further down.
betas = linear_beta_schedule(T).to(dtype=torch.float32) # Use float32 for model compatibility
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

# Calculations for diffusion q(x_t | x_0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# Calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# Clamp variance to avoid issues at t=0 where variance is theoretically zero
posterior_variance_clipped = torch.clamp(posterior_variance, min=1e-20)
posterior_log_variance_clipped = torch.log(posterior_variance_clipped)

# Coefficients for posterior mean calculation (used in p_sample)
posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)

# Move precomputed tensors to the target device
betas = betas.to(DEVICE)
alphas = alphas.to(DEVICE)
alphas_cumprod = alphas_cumprod.to(DEVICE)
alphas_cumprod_prev = alphas_cumprod_prev.to(DEVICE)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(DEVICE)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(DEVICE)
posterior_variance_clipped = posterior_variance_clipped.to(DEVICE)
posterior_log_variance_clipped = posterior_log_variance_clipped.to(DEVICE)
posterior_mean_coef1 = posterior_mean_coef1.to(DEVICE)
posterior_mean_coef2 = posterior_mean_coef2.to(DEVICE)


# --- Helper function to extract alpha/beta values for a batch of timesteps ---
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    # Gather the schedule value for each timestep in t. Both a and t must be on the same device
    # (the schedule tensors above were moved to DEVICE) and t must be an integer (long) tensor.
    out = a.gather(-1, t) 
    # Reshape to broadcast correctly: (batch_size, 1, 1, 1) for image/spectrogram data
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

# --- Forward Diffusion (Noise addition) ---
def q_sample(x_start, t, noise=None):
    """Applies noise to x_start for timesteps t."""
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

# --- Denoising Loss Calculation ---
def p_losses(denoise_model, x_start, t, emotion_embeddings, noise=None, loss_type="l2"):
    """Calculates the loss for the denoising model."""
    if noise is None:
        noise = torch.randn_like(x_start)

    # Apply noise to the original data
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    
    # Get the model's prediction for the noise
    predicted_noise = denoise_model(x_noisy, t, emotion_embeddings) 

    # Calculate the loss between the actual noise and predicted noise
    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError(f"Loss type '{loss_type}' not implemented.")

    return loss

# --- Sampling Functions (Reverse Process) ---
@torch.no_grad()
def p_sample(model, x, t, t_index, emotion_embeddings):
    """Single step of the reverse diffusion process (Algorithm 2 in DDPM)."""
    # Use coefficients derived for q(x_{t-1} | x_t, x_0)
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(torch.sqrt(1.0 / alphas), t, x.shape)
    
    # Predict the noise using the model
    predicted_noise = model(x, t, emotion_embeddings)
    
    # Mean of p_theta(x_{t-1} | x_t), written directly in terms of the predicted noise
    # (Equation 11 in the DDPM paper):
    #   mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean # No noise added at the last step
    else:
        # Get the variance of the posterior distribution
        posterior_variance_t = extract(posterior_variance_clipped, t, x.shape)
        noise = torch.randn_like(x)
        # Add noise scaled by the posterior standard deviation (square root of the variance)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def p_sample_loop(model, shape, timesteps, emotion_embeddings):
    """Full sampling loop from noise to data."""
    device = next(model.parameters()).device # Get device from model
    b = shape[0] # Batch size from shape
    
    # Start from pure noise (x_T)
    img = torch.randn(shape, device=device)
    imgs = [] # Optional: Store intermediate images

    # Iterate backwards through timesteps
    for i in tqdm(reversed(range(0, timesteps)), desc='Sampling loop time step', total=timesteps):
        # Create tensor for current timestep for all samples in batch
        time_tensor = torch.full((b,), i, device=device, dtype=torch.long)
        # Perform one denoising step
        img = p_sample(model, img, time_tensor, i, emotion_embeddings)
        # Optionally store intermediate results:
        # if i % 50 == 0: imgs.append(img.cpu().numpy())

    # imgs.append(img.cpu().numpy()) # Store final result
    return img # Return the final denoised image x_0
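
The precomputed schedule quantities can be checked numerically without PyTorch. The sketch below rebuilds the linear schedule with plain floats. `T_demo = 1000` is an assumed value for illustration, since the notebook reads `DIFFUSION_TIMESTEPS`. It confirms two properties the code relies on: the forward process is variance preserving, and the posterior variance is exactly zero at `t = 0`, which is why it is clamped before taking the logarithm.

```python
import math

def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas, rescaled by 1000 / timesteps as in linear_beta_schedule."""
    scale = 1000 / timesteps
    lo, hi = scale * beta_start, scale * beta_end
    return [lo + (hi - lo) * i / (timesteps - 1) for i in range(timesteps)]

T_demo = 1000  # assumed for illustration
betas_demo = linear_betas(T_demo)

# alpha_bar_t = product over s <= t of (1 - beta_s)
alpha_bar = []
running = 1.0
for beta in betas_demo:
    running *= 1.0 - beta
    alpha_bar.append(running)

# q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
signal_T = math.sqrt(alpha_bar[-1])       # weight of x_0 at the last step
noise_T = math.sqrt(1.0 - alpha_bar[-1])  # weight of the noise at the last step

# Posterior variance: beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
# with alpha_bar_{-1} defined as 1.
alpha_bar_prev = [1.0] + alpha_bar[:-1]
posterior_var = [b * (1 - p) / (1 - a)
                 for b, p, a in zip(betas_demo, alpha_bar_prev, alpha_bar)]

print(f"signal weight at T: {signal_T:.4f}, noise weight at T: {noise_T:.4f}")
print(f"posterior variance at t=0: {posterior_var[0]}")
```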

## 5.3. Diffusion Model Training Loop

In [None]:
class DiffusionDatasetWrapper(Dataset):
    """Wraps the MelSpectrogramDataset to return embeddings along with spectrograms."""
    def __init__(self, original_dataset, embeddings):
        self.original_dataset = original_dataset
        # Ensure embeddings are float tensors and match dataset length
        if len(original_dataset) != len(embeddings):
             raise ValueError(f"Dataset length ({len(original_dataset)}) and embeddings length ({len(embeddings)}) mismatch.")
        self.embeddings = torch.from_numpy(embeddings).float() 

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        # Get spectrogram, original label, and original index from the underlying dataset
        spectrogram, original_label, original_idx = self.original_dataset[idx]
        
        # Use the original index to fetch the pre-computed embedding
        # This assumes 'all_embeddings' were extracted in the same order as the dataset items
        if original_idx != idx:
             # This check should ideally not fail if extraction was done with shuffle=False
             print(f"Warning: Dataset index {idx} does not match returned original index {original_idx}. Using idx for embedding lookup.")
             lookup_idx = idx
        else:
             lookup_idx = original_idx
             
        # Handle potential index out of bounds for embeddings
        if lookup_idx >= len(self.embeddings):
             raise IndexError(f"Index {lookup_idx} out of bounds for embeddings list (length {len(self.embeddings)}). Check dataset/embedding alignment.")
             
        embedding = self.embeddings[lookup_idx]
        
        return spectrogram, embedding # Return spectrogram and its corresponding embedding

def train_diffusion_model(model, dataset, optimizer, num_epochs, batch_size, device, model_save_path, all_embeddings):
    """Trains the Diffusion model."""
    if model is None or dataset is None or all_embeddings is None:
        print("Error: Model, dataset, or embeddings not provided. Cannot train diffusion model.")
        return None
        
    # Create the wrapped dataset for training
    try:
        diffusion_train_dataset = DiffusionDatasetWrapper(dataset, all_embeddings)
        # Worker processes can crash in Windows notebooks, so use num_workers=0 there
        diffusion_loader = DataLoader(diffusion_train_dataset, batch_size=batch_size, shuffle=True,
                                      num_workers=0 if os.name == 'nt' else 2,
                                      pin_memory=torch.cuda.is_available(), drop_last=True)
    except ValueError as e:
         print(f"Error creating diffusion dataset/loader: {e}")
         return None
    except Exception as e:
         print(f"Unexpected error creating diffusion dataset/loader: {e}")
         return None

    print(f"Starting Diffusion Model training for {num_epochs} epochs on {device}...")

    model.to(device) # Ensure model is on the correct device
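    # device may be a torch.device, which never compares equal to the string 'cuda'; check its type instead
    scaler = torch.cuda.amp.GradScaler(enabled=(torch.device(device).type == 'cuda')) # Automatic Mixed Precision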

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_samples = 0

        pbar = tqdm(diffusion_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Diffusion Train]")
        for step, (batch_spectrograms, batch_embeddings) in enumerate(pbar):
            optimizer.zero_grad()

            batch_size_current = batch_spectrograms.shape[0]
            batch_spectrograms = batch_spectrograms.to(device)
            batch_embeddings = batch_embeddings.to(device) # Send embeddings to device

            # Sample random timesteps for this batch
            t = torch.randint(0, DIFFUSION_TIMESTEPS, (batch_size_current,), device=device).long()

            # Calculate loss using mixed precision
            with torch.cuda.amp.autocast(enabled=(torch.device(device).type == 'cuda')):
                 loss = p_losses(model, batch_spectrograms, t, batch_embeddings, loss_type="l2") # Use MSE loss

            if torch.isnan(loss) or torch.isinf(loss):
                 print(f"Warning: NaN or Inf loss encountered at epoch {epoch+1}, step {step}. Skipping batch.")
                 # Consider stopping training or reducing learning rate if this happens frequently
                 continue

            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()
            
            # Optional: Gradient clipping
            # scaler.unscale_(optimizer) # Unscales gradients before clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called.
            # Otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()

            running_loss += loss.item() * batch_size_current
            num_samples += batch_size_current
            pbar.set_postfix(loss=f"{loss.item():.6f}")
            
        if num_samples == 0:
             print(f"Epoch {epoch+1} Warning: No samples processed. Skipping epoch stats.")
             continue
             
        epoch_loss = running_loss / num_samples
        print(f"Epoch {epoch+1}/{num_epochs} - Diffusion Loss: {epoch_loss:.6f}")

        # Save the model periodically
        if (epoch + 1) % 10 == 0 or epoch == num_epochs - 1: # Save every 10 epochs and at the end
            try:
                torch.save(model.state_dict(), model_save_path)
                print(f"Diffusion model saved to {model_save_path} at epoch {epoch+1}")
            except Exception as e:
                 print(f"Error saving diffusion model: {e}")

    print("Finished Training Diffusion Model.")
    return model # Return trained model
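
The training loop samples a timestep `t` per example and passes it to `p_losses`, which is defined earlier in the notebook. As a rough illustration of what `t` controls, the sketch below computes the signal and noise coefficients of the forward process under an assumed linear beta schedule (1e-4 to 0.02 over 1000 steps, the common DDPM defaults). The schedule values and the helper names `linear_betas` and `forward_coefficients` are assumptions for illustration and may differ from this notebook's configuration.

```python
import math

def linear_betas(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances (assumed schedule, not taken from the notebook)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]

def forward_coefficients(betas):
    """Return (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)) for every timestep t.

    The noised sample is x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
    """
    coeffs = []
    alpha_bar = 1.0
    for beta in betas:
        alpha_bar *= 1.0 - beta  # cumulative product of alphas
        coeffs.append((math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)))
    return coeffs

coeffs = forward_coefficients(linear_betas(1000))
for t in (0, 499, 999):
    signal, noise = coeffs[t]
    print(f"t={t:4d}  signal={signal:.4f}  noise={noise:.4f}")
```

Small `t` leaves the spectrogram almost untouched, while the last timesteps are close to pure noise, which is why sampling `t` uniformly exposes the model to every noise level.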

## 5.4. Train or Load Diffusion Model

In [None]:
# Initialize Diffusion Model and Optimizer
diffusion_model = None
if EMBEDDING_DIM is not None and NUM_EMOTIONS > 0:
    diffusion_model = UNetDiffusion(
        in_channels=1,
        model_channels=64, # Base channels - adjust complexity
        out_channels=1,
        channel_mult=(1, 2, 3, 4), # Channel multiplier per resolution level; a wider setting such as (1, 2, 4, 8) might be better but slower
        num_res_blocks=2,
        time_emb_dim=256, # Should match time_dim in UNet definition if changed
        emotion_emb_dim=EMBEDDING_DIM, # Use dimension from ResNet output
        use_attention_levels=(False, False, True, True) # Use attention in deeper layers
    ).to(DEVICE)

    optimizer_diffusion = optim.AdamW(diffusion_model.parameters(), lr=DIFFUSION_LEARNING_RATE, weight_decay=1e-4)

    # Check if a trained diffusion model exists
    if os.path.exists(DIFFUSION_MODEL_PATH):
        print(f"Loading pre-trained Diffusion model from {DIFFUSION_MODEL_PATH}...")
        try:
            diffusion_model.load_state_dict(torch.load(DIFFUSION_MODEL_PATH, map_location=DEVICE))
            print("Diffusion model loaded successfully.")
        except Exception as e:
            print(f"Error loading Diffusion model: {e}. Training from scratch.")
            if full_dataset and all_embeddings is not None:
                 diffusion_model = train_diffusion_model(diffusion_model, full_dataset, optimizer_diffusion, DIFFUSION_EPOCHS, DIFFUSION_BATCH_SIZE, DEVICE, DIFFUSION_MODEL_PATH, all_embeddings)
            else:
                 print("Cannot train Diffusion model, missing data or embeddings.")
                 diffusion_model = None # Mark as None if training fails
    else:
        print("No pre-trained Diffusion model found. Training from scratch...")
        if full_dataset and all_embeddings is not None:
            diffusion_model = train_diffusion_model(diffusion_model, full_dataset, optimizer_diffusion, DIFFUSION_EPOCHS, DIFFUSION_BATCH_SIZE, DEVICE, DIFFUSION_MODEL_PATH, all_embeddings)
        else:
            print("Cannot train Diffusion model, missing data or embeddings.")
            diffusion_model = None # Mark as None if training fails
else:
     print("Cannot initialize Diffusion model: EMBEDDING_DIM or NUM_EMOTIONS not set correctly.")

# 6. Data Augmentation using the Diffusion Model

## 6.1. Generation Function

In [None]:
def generate_augmented_data(diffusion_model, num_samples_per_emotion, target_emotion_label, all_embeddings, all_labels, device):
    """Generates augmented mel-spectrograms for a specific emotion."""
    if diffusion_model is None or all_embeddings is None or all_labels is None:
        print("Diffusion model or embeddings/labels not available. Skipping augmentation.")
        return []

    diffusion_model.eval()
    generated_specs = []
    print(f"Generating {num_samples_per_emotion} samples for emotion '{target_emotion_label}'...")

    # Find embeddings corresponding to the target emotion
    target_label_int = EMOTIONS.get(target_emotion_label)
    if target_label_int is None:
         print(f"Error: Emotion '{target_emotion_label}' not found in EMOTIONS dictionary.")
         return []
         
    indices = np.where(all_labels == target_label_int)[0]

    if len(indices) == 0:
        print(f"No embeddings found for emotion {target_emotion_label}. Cannot generate.")
        # Option: Use a generic embedding or average embedding? For now, skip.
        return []

    # Select embeddings to condition on (randomly pick from existing ones for that emotion)
    selected_indices = np.random.choice(indices, num_samples_per_emotion, replace=True)
    conditioning_embeddings = torch.from_numpy(all_embeddings[selected_indices]).float().to(device)

    # Define the shape of the output spectrogram (Batch, Channel, Height, Width)
    shape = (num_samples_per_emotion, 1, N_MELS, EXPECTED_TIME_STEPS)

    # Generate using the sampling loop
    with torch.no_grad():
        generated_batch = p_sample_loop(
            diffusion_model,
            shape=shape,
            timesteps=DIFFUSION_TIMESTEPS,
            emotion_embeddings=conditioning_embeddings
        )

    # Move generated specs to CPU and store as numpy arrays
    # Output shape is (B, C, H, W), need to store individual specs
    generated_specs = list(generated_batch.cpu().numpy())

    print(f"Generated {len(generated_specs)} spectrograms for {target_emotion_label}.")
    return generated_specs
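
`generate_augmented_data` requests all samples for an emotion in a single `p_sample_loop` call, which can exhaust GPU memory for larger counts. One option is to split the request into smaller batches. The helper below is a hypothetical sketch (the name `split_into_batches` and the `max_batch` parameter are not part of the notebook) that only computes the batch sizes.

```python
def split_into_batches(total, max_batch):
    """Split `total` samples into batch sizes no larger than `max_batch`."""
    if total < 0 or max_batch <= 0:
        raise ValueError("total must be >= 0 and max_batch must be > 0")
    full, remainder = divmod(total, max_batch)
    sizes = [max_batch] * full
    if remainder:
        sizes.append(remainder)
    return sizes

# Example: 50 samples with at most 16 per sampling call
for batch_size in split_into_batches(50, 16):
    # In the notebook, each size would become the first dimension of `shape`
    # and select the matching rows of the conditioning embeddings.
    print(batch_size)
```

The per-batch outputs could then be concatenated before saving, so the rest of the augmentation cell would not need to change.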

## 6.2. Perform Augmentation and Save Results

Set `num_to_generate_per_emotion` to the desired number of synthetic samples per class.

In [None]:
# --- Perform Augmentation for Each Emotion ---
num_to_generate_per_emotion = 50 # Adjust as needed (e.g., 50-100)
augmented_metadata = []

# Ensure the augmented directory exists; existing files are kept, not deleted
if os.path.exists(AUGMENTED_DIR):
    print(f"Augmented data directory already exists: {AUGMENTED_DIR}")
    # Be careful with rmtree! Uncomment (and recreate the directory afterwards) only to wipe previous outputs.
    # shutil.rmtree(AUGMENTED_DIR)
    if os.listdir(AUGMENTED_DIR):
         print("Warning: Augmented directory is not empty. Files might be overwritten or added.")
else:
     os.makedirs(AUGMENTED_DIR, exist_ok=True)

aug_metadata_path = os.path.join(DATA_DIR, "augmented_metadata.npy")

# Check if metadata already exists to potentially skip generation
if os.path.exists(aug_metadata_path):
     print(f"Augmented metadata file found at {aug_metadata_path}. Skipping generation.")
     try:
          augmented_metadata = np.load(aug_metadata_path, allow_pickle=True).tolist()
          print(f"Loaded {len(augmented_metadata)} augmented metadata entries.")
     except Exception as e:
          print(f"Error loading augmented metadata: {e}. Will proceed with generation.")
          augmented_metadata = [] # Reset if loading failed
else:
     print("Augmented metadata not found. Proceeding with generation...")
     augmented_metadata = [] # Ensure it's empty before generation

# Only generate if metadata list is empty
if not augmented_metadata:
    if diffusion_model is not None and all_embeddings is not None:
        for emotion_name, emotion_int in EMOTIONS.items():
            generated_spectrograms = generate_augmented_data(
                diffusion_model,
                num_to_generate_per_emotion,
                emotion_name,
                all_embeddings,
                all_labels, # Pass all_labels here
                DEVICE
            )

            # Save generated spectrograms and update metadata
            for i, spec_np in enumerate(generated_spectrograms):
                 # spec_np shape is likely (1, N_MELS, TIME), remove channel dim for saving
                 if spec_np.ndim == 3 and spec_np.shape[0] == 1:
                     spec_np = spec_np.squeeze(0)
                 
                 # Check shape before saving
                 if spec_np.shape != (N_MELS, EXPECTED_TIME_STEPS):
                      print(f"Warning: Generated spec {i} for {emotion_name} has wrong shape {spec_np.shape}. Skipping save.")
                      continue
                      
                 save_filename = f"augmented_{emotion_name}_{i:03d}.npy"
                 save_path = os.path.join(AUGMENTED_DIR, save_filename)
                 try:
                     np.save(save_path, spec_np.astype(np.float32))
                     augmented_metadata.append((save_path, emotion_int))
                 except Exception as e:
                     print(f"Error saving augmented spectrogram {save_path}: {e}")

        print(f"\nTotal augmented samples generated: {len(augmented_metadata)}")
        # Save augmented metadata only if generation occurred
        if augmented_metadata:
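            # Use an object array so integer labels are not coerced to strings alongside the paths
            np.save(aug_metadata_path, np.array(augmented_metadata, dtype=object))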
            print(f"Saved augmented metadata to {aug_metadata_path}")
        else:
            print("No augmented data was generated or saved.")
    else:
        print("\nSkipping data augmentation as diffusion model or embeddings are not available.")
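
The metadata saved above is a list of `(path, label)` pairs. As an alternative to a pickled `.npy` file, the same list can be stored as JSON, which is human-readable and keeps the labels as integers. This is a sketch: `save_metadata_json`, `load_metadata_json`, and the example file names and label values are hypothetical and not part of the notebook.

```python
import json
import os
import tempfile

def save_metadata_json(metadata, path):
    """Write a list of (file_path, label) pairs as JSON."""
    records = [{"path": p, "label": int(label)} for p, label in metadata]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2)

def load_metadata_json(path):
    """Read the JSON file back into a list of (file_path, label) tuples."""
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return [(r["path"], int(r["label"])) for r in records]

# Round-trip check in a temporary directory (made-up entries)
example = [("augmented_happy_000.npy", 3), ("augmented_sad_000.npy", 5)]
with tempfile.TemporaryDirectory() as tmp_dir:
    json_path = os.path.join(tmp_dir, "augmented_metadata.json")
    save_metadata_json(example, json_path)
    restored = load_metadata_json(json_path)
print(restored == example)
```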

# 7. Evaluation

Evaluate the performance of the trained ResNet classifier on:
1. The original validation set.
2. The newly generated augmented dataset.

## 7.1. Evaluation Function

In [None]:
def evaluate_model(model, data_loader, criterion, device, num_classes):
    """Evaluates a classification model on a given dataset."""
    if model is None or data_loader is None or criterion is None:
        print("Error: Model, dataloader, or criterion is None. Cannot evaluate.")
        return 0, 0, 0, 0, 0, np.zeros((num_classes, num_classes)), {}
        
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, labels, _ in tqdm(data_loader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    # Check if any samples were processed
    dataset_size = len(data_loader.dataset)
    if dataset_size == 0:
         print("Error: Evaluation dataset is empty.")
         return 0, 0, 0, 0, 0, np.zeros((num_classes, num_classes)), {}
         
    avg_loss = total_loss / dataset_size
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    # Calculate Metrics
    wa = accuracy_score(all_targets, all_preds) # Weighted Accuracy: overall accuracy over all samples
    ua = balanced_accuracy_score(all_targets, all_preds) # Unweighted Accuracy: mean of per-class recalls
    conf_matrix = confusion_matrix(all_targets, all_preds, labels=range(num_classes))
    
    # Get class names sorted by label index for the report
    target_names_sorted = [name for name, idx in sorted(EMOTIONS.items(), key=lambda item: item[1])]
    
    class_report_dict = classification_report(all_targets, all_preds, labels=range(num_classes),
                                        target_names=target_names_sorted,
                                        zero_division=0, output_dict=True)
    class_report_str = classification_report(all_targets, all_preds, labels=range(num_classes),
                                        target_names=target_names_sorted,
                                        zero_division=0)

    print(f"Evaluation Loss: {avg_loss:.4f}")
    print(f"Weighted Accuracy (WA): {wa:.4f}")
    print(f"Unweighted Accuracy (UA): {ua:.4f}")
    print("\nClassification Report:")
    print(class_report_str)
    # print("\nConfusion Matrix:")
    # print(conf_matrix)

    # Extract overall precision, recall, f1 (macro average)
    precision = class_report_dict['macro avg']['precision']
    recall = class_report_dict['macro avg']['recall']
    f1 = class_report_dict['macro avg']['f1-score']

    return wa, ua, precision, recall, f1, conf_matrix, class_report_dict
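
WA and UA coincide on balanced data but diverge when classes are imbalanced: WA is the fraction of all samples classified correctly, while UA averages the recall of each class. The toy example below (made-up labels, plain Python rather than scikit-learn) shows the difference.

```python
def weighted_accuracy(targets, preds):
    """Overall accuracy: correct predictions divided by all samples."""
    correct = sum(1 for t, p in zip(targets, preds) if t == p)
    return correct / len(targets)

def unweighted_accuracy(targets, preds):
    """Mean per-class recall over the classes present in `targets`."""
    recalls = []
    for cls in sorted(set(targets)):
        total = sum(1 for t in targets if t == cls)
        hits = sum(1 for t, p in zip(targets, preds) if t == cls and p == cls)
        recalls.append(hits / total)
    return sum(recalls) / len(recalls)

targets = [0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 0, 1]
print(f"WA = {weighted_accuracy(targets, preds):.4f}")    # 5 of 6 samples correct
print(f"UA = {unweighted_accuracy(targets, preds):.4f}")  # mean of recalls 1.0 and 0.5
```

Here the minority class is misclassified half the time, which UA penalizes more heavily than WA.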

## 7.2. Plot Confusion Matrix Function

In [None]:
def plot_confusion_matrix(cm, class_names, title='Confusion Matrix'):
    """Plots a confusion matrix using seaborn."""
    if cm is None or not isinstance(cm, np.ndarray):
        print("Invalid confusion matrix provided for plotting.")
        return
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                annot_kws={"size": 10}) # Adjust font size if needed
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title(title)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Get class names in the correct order for plotting
class_names_sorted = [name for name, idx in sorted(EMOTIONS.items(), key=lambda item: item[1])]

## 7.3. Evaluate Original Data (Validation Set)

In [None]:
print("\n--- Evaluating on Original Validation Data ---")
wa_orig, ua_orig, p_orig, r_orig, f1_orig, cm_orig, cr_orig = (None,) * 7

if resnet_model and val_loader and criterion_resnet:
     wa_orig, ua_orig, p_orig, r_orig, f1_orig, cm_orig, cr_orig = evaluate_model(
         resnet_model, val_loader, criterion_resnet, DEVICE, NUM_EMOTIONS
     )
     if cm_orig is not None:
         print("\nOriginal Data Evaluation Summary:")
         print(f"WA: {wa_orig:.4f}, UA: {ua_orig:.4f}")
         print(f"Precision (Macro): {p_orig:.4f}, Recall (Macro): {r_orig:.4f}, F1 (Macro): {f1_orig:.4f}")
         plot_confusion_matrix(cm_orig, class_names_sorted, title='Confusion Matrix (Original Validation Data)')
     else:
          print("Evaluation returned invalid confusion matrix.")
else:
     print("ResNet model, validation loader, or criterion not available. Skipping original data evaluation.")

## 7.4. Evaluate Generated Data

In [None]:
print("\n--- Evaluating on Generated (Augmented) Data ---")
wa_aug, ua_aug, p_aug, r_aug, f1_aug, cm_aug, cr_aug = (None,) * 7

# Load augmented metadata again (ensure it exists and is up-to-date)
aug_metadata_path = os.path.join(DATA_DIR, "augmented_metadata.npy")
augmented_metadata_loaded = []
if os.path.exists(aug_metadata_path):
    try:
        augmented_metadata_loaded = np.load(aug_metadata_path, allow_pickle=True).tolist()
        print(f"Loaded {len(augmented_metadata_loaded)} augmented metadata entries for evaluation.")
    except Exception as e:
        print(f"Error loading augmented metadata for evaluation: {e}")
else:
    print(f"Augmented metadata file not found at {aug_metadata_path}. Cannot evaluate generated data.")

if augmented_metadata_loaded:
    # Create dataset and dataloader for augmented data
    aug_dataset = MelSpectrogramDataset(augmented_metadata_loaded)
    # Label indices follow EMOTIONS, so always evaluate over the full label range;
    # truncating it could drop or mislabel classes missing from the augmented data.
    eval_num_classes = NUM_EMOTIONS
    if aug_dataset.num_classes == 0:
        print("Error: Augmented dataset has no classes. Cannot evaluate.")
        aug_dataset = None
    elif aug_dataset.num_classes != NUM_EMOTIONS:
        print(f"Warning: Augmented data covers {aug_dataset.num_classes} classes, but the classifier has {NUM_EMOTIONS}. Missing classes will appear with zero support.")
         
    if aug_dataset and len(aug_dataset) > 0:
        # Use num_workers=0 if issues arise
        aug_loader = DataLoader(aug_dataset, batch_size=RESNET_BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

        # Evaluate using the trained ResNet model
        if resnet_model and criterion_resnet:
            wa_aug, ua_aug, p_aug, r_aug, f1_aug, cm_aug, cr_aug = evaluate_model(
                resnet_model, aug_loader, criterion_resnet, DEVICE, eval_num_classes
            )
            if cm_aug is not None:
                 print("\nGenerated Data Evaluation Summary:")
                 print(f"WA: {wa_aug:.4f}, UA: {ua_aug:.4f}")
                 print(f"Precision (Macro): {p_aug:.4f}, Recall (Macro): {r_aug:.4f}, F1 (Macro): {f1_aug:.4f}")
                 # Ensure class names match the number of classes evaluated
                 plot_class_names = class_names_sorted[:eval_num_classes]
                 plot_confusion_matrix(cm_aug, plot_class_names, title='Confusion Matrix (Generated Augmented Data)')
            else:
                 print("Evaluation returned invalid confusion matrix for augmented data.")
        else:
            print("ResNet model or criterion not available. Skipping generated data evaluation.")
    else:
         print("Augmented dataset is empty or invalid. Skipping evaluation.")
else:
    print("No augmented data found to evaluate.")

## 7.5. Compare Results (Original vs. Augmented)
This compares the classifier's performance on the held-out *original validation data* versus its performance on the *purely synthetic generated data*. This helps assess how well the generated data preserves recognizable emotional features according to the classifier.

In [None]:
print("\n\n--- Evaluation Comparison ---")
print("-" * 30)
print("| Metric          | Original(Val) | Generated |")
print("|-----------------|---------------|-----------|")
metric_available = wa_orig is not None and wa_aug is not None
if metric_available:
    print(f"| WA              | {wa_orig:^13.4f} | {wa_aug:^9.4f} |")
    print(f"| UA              | {ua_orig:^13.4f} | {ua_aug:^9.4f} |")
    print(f"| Precision (Mac) | {p_orig:^13.4f} | {p_aug:^9.4f} |")
    print(f"| Recall (Mac)    | {r_orig:^13.4f} | {r_aug:^9.4f} |")
    print(f"| F1 Score (Mac)  | {f1_orig:^13.4f} | {f1_aug:^9.4f} |")
else:
    print("| Results not available for comparison. |")
print("-" * 30)

print("\nNote: Evaluation performed on original validation data vs. purely generated data.")
print("The paper's results might involve different evaluation splits or combining datasets.")
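
Aggregate metrics can hide which emotions the generated data fails to preserve. Per-class recall can be read off a confusion matrix whose rows are true labels and whose columns are predictions, the layout `confusion_matrix` uses. The matrix and class names below are made up for illustration, and `per_class_recall` is a hypothetical helper.

```python
def per_class_recall(cm, class_names):
    """Return {class_name: recall} from a confusion matrix (rows = true labels)."""
    recalls = {}
    for i, name in enumerate(class_names):
        row_total = sum(cm[i])
        # A class with no true samples has undefined recall; report 0.0 here
        recalls[name] = cm[i][i] / row_total if row_total else 0.0
    return recalls

cm = [
    [8, 2, 0],
    [1, 9, 0],
    [5, 0, 5],
]
names = ["class_a", "class_b", "class_c"]
for name, recall in per_class_recall(cm, names).items():
    print(f"{name}: {recall:.2f}")
```

With the notebook's outputs, `cm_aug.tolist()` and `class_names_sorted` could be passed in the same way.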

# 8. Conclusion

This notebook implemented the core steps described in the paper "A Generation of Enhanced Data by Variational Autoencoders and Diffusion Modeling" for audio emotion data augmentation.

1.  **Data Preparation:** Loaded and preprocessed EmoDB and RAVDESS datasets into standardized mel-spectrograms.
2.  **Emotion Embedding:** Trained a ResNet-50 model for emotion classification and used it to extract emotion embeddings.
3.  **Diffusion Model:** Implemented and trained a U-Net based diffusion model conditioned on time and emotion embeddings to generate mel-spectrograms.
4.  **Data Augmentation:** Used the trained diffusion model to generate new mel-spectrograms for each target emotion.
5.  **Evaluation:** Assessed the quality of the original and generated data using the trained ResNet classifier, comparing metrics like WA, UA, Precision, Recall, F1-score, and confusion matrices.

The results can be compared to those in the paper (Table 5, Table 6, etc.), keeping in mind potential differences in dataset splits, specific model architectures, hyperparameters, and training duration. The diffusion model, in particular, often requires extensive training to generate high-fidelity samples.

**Potential next steps:**
*   Hyperparameter tuning for both ResNet and Diffusion models.
*   Using more advanced attention mechanisms in the U-Net.
*   Implementing more sophisticated diffusion sampling techniques (e.g., DDIM).
*   Evaluating the impact of augmented data by adding it to the *training set* of the classifier and re-training/re-evaluating.
*   Converting generated mel-spectrograms back to audio using a vocoder (e.g., HiFi-GAN, MelGAN - requires separate implementation/model) for auditory evaluation.
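
For the step of adding augmented data to the classifier's training set, one point worth getting right is that synthetic samples go only into the training split, so the validation set stays purely original. A minimal sketch, assuming metadata lists of `(path, label)` pairs like those used above; `build_training_metadata` and its `max_synthetic_ratio` cap are hypothetical and not part of the notebook.

```python
import random

def build_training_metadata(train_meta, aug_meta, max_synthetic_ratio=1.0, seed=42):
    """Combine original and synthetic training entries.

    Synthetic entries are capped at `max_synthetic_ratio` times the number of
    original training entries, then the combined list is shuffled.
    """
    rng = random.Random(seed)
    limit = int(len(train_meta) * max_synthetic_ratio)
    synthetic = list(aug_meta)
    rng.shuffle(synthetic)
    combined = list(train_meta) + synthetic[:limit]
    rng.shuffle(combined)
    return combined

# Made-up metadata: 4 original and 10 synthetic entries
train_meta = [(f"orig_{i}.npy", i % 2) for i in range(4)]
aug_meta = [(f"aug_{i}.npy", i % 2) for i in range(10)]
combined = build_training_metadata(train_meta, aug_meta, max_synthetic_ratio=0.5)
print(len(combined))
```

Capping the synthetic share is one way to keep generated samples from dominating training; the right ratio is an empirical question.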