In [None]:
import os
import torch
import torchaudio
import numpy as np
import pandas as pd
import shutil
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from tqdm import tqdm
import joblib
import soundfile as sf
import audiomentations
print(audiomentations.__version__)
from audiomentations import Compose, AddGaussianNoise, AddBackgroundNoise, TimeStretch, PitchShift, ClippingDistortion, RoomSimulator

0.41.0


# RAVDESS Dataset Organization and Label Generation

In [None]:
def organize_ravdess(
    data_path="/Audio_Speech_Actors_01-24",
    output_path="/Data/archive_processed"
):
    """Copy RAVDESS speech clips into one flat folder with descriptive names.

    Walks ``data_path/Actor_XX/*.wav`` and parses the seven-field RAVDESS file
    name (modality-vocal_channel-emotion-intensity-statement-repetition-actor).
    It keeps only clips with modality ``03`` and known emotion/intensity codes.
    Each kept clip is copied to ``output_path`` under a name such as
    ``happy_strong_male_actor01_stmt01_rep01.wav``, and
    ``complete_metadata.csv`` is written next to the copies.

    Args:
        data_path: Root folder containing the ``Actor_XX`` sub-directories.
        output_path: Destination folder (created if missing).

    Returns:
        pandas.DataFrame with one row per copied clip. The DataFrame (and CSV)
        always carries the column header, even when no clip matched, so the
        downstream ``read_csv`` does not fail on an empty file.
    """
    emotion_map = {
        '01': 'neutral', '02': 'calm', '03': 'happy', '04': 'sad',
        '05': 'angry', '06': 'fear', '07': 'disgust', '08': 'surprise'
    }
    intensity_map = {'01': 'normal', '02': 'strong'}
    statement_map = {
        '01': 'Kids are talking by the door',
        '02': 'Dogs are sitting by the door'
    }
    columns = [
        'original_filename', 'processed_filename', 'emotion', 'intensity',
        'gender', 'actor_id', 'statement', 'processed_path',
    ]
    metadata = []
    os.makedirs(output_path, exist_ok=True)
    for actor_dir in tqdm(sorted(os.listdir(data_path)), desc="Actors"):
        if not actor_dir.startswith('Actor_'):
            continue
        actor_path = os.path.join(data_path, actor_dir)
        actor_id = actor_dir.split('_')[1]
        # Skip stray files and malformed names like "Actor_" / "Actor_xx",
        # which would otherwise crash int() below.
        if not os.path.isdir(actor_path) or not actor_id.isdigit():
            continue
        # RAVDESS convention: odd-numbered actors are male, even-numbered female.
        gender = 'male' if int(actor_id) % 2 == 1 else 'female'
        for filename in sorted(os.listdir(actor_path)):
            if not filename.endswith('.wav'):
                continue
            # Drop the extension first so the last field is a bare actor id.
            parts = os.path.splitext(filename)[0].split('-')
            if len(parts) != 7:
                continue
            modality, vocal_channel, emo_code, inten_code, stmt_code, rep, actor = parts
            # Unknown emotion/intensity codes are skipped instead of raising KeyError.
            if modality != '03' or emo_code not in emotion_map or inten_code not in intensity_map:
                continue
            emotion = emotion_map[emo_code]
            intensity = intensity_map[inten_code]
            statement = statement_map.get(stmt_code, f"statement_{stmt_code}")
            new_filename = (
                f"{emotion}_{intensity}_{gender}_actor{actor_id}_stmt{stmt_code}_rep{rep}.wav"
            )
            src_path = os.path.join(actor_path, filename)
            dest_path = os.path.join(output_path, new_filename)
            shutil.copy2(src_path, dest_path)
            metadata.append({
                'original_filename': filename,
                'processed_filename': new_filename,
                'emotion': emotion,
                'intensity': intensity,
                'gender': gender,
                'actor_id': actor_id,
                'statement': statement,
                'processed_path': dest_path,
            })
    df = pd.DataFrame(metadata, columns=columns)
    df.to_csv(os.path.join(output_path, "complete_metadata.csv"), index=False)
    print(f"Meta data generated into {os.path.join(output_path, 'complete_metadata.csv')}, number of samples: {len(df)}")
    return df

# Copy and rename the raw RAVDESS speech clips, writing complete_metadata.csv.
df = organize_ravdess(
    output_path="/Data/archive_processed",
    data_path="/Audio_Speech_Actors_01-24",
)


# Label Encoding

In [None]:
import os

LE_PATH = "/Data/label_encoders"
os.makedirs(LE_PATH, exist_ok=True)

# Replace each categorical column with integer codes and persist the fitted
# encoder, so predictions can be decoded back to names later.
# NOTE: LabelEncoder orders classes alphabetically (e.g. angry=0, calm=1, ...),
# not in the RAVDESS file-code order.
meta = pd.read_csv("/Data/archive_processed/complete_metadata.csv")
for target_col in ["emotion", "intensity", "gender"]:
    encoder = LabelEncoder()
    meta[target_col] = encoder.fit_transform(meta[target_col])
    encoder_file = os.path.join(LE_PATH, f"le_{target_col}.pkl")
    joblib.dump(encoder, encoder_file)

meta.to_csv("/Data/archive_processed/complete_metadata_encoded.csv", index=False)
print("Label encoding completed. The encoded metadata has been saved.")


# Dataset Splitting

In [None]:
# Hold out 15% of clips for validation, stratified on the encoded emotion.
# NOTE(review): this is a clip-level split, so the same actors appear in both
# train and val; reported accuracy (gender especially) is speaker-dependent.
meta = pd.read_csv("/Data/archive_processed/complete_metadata_encoded.csv")
train_df, val_df = train_test_split(
    meta,
    test_size=0.15,
    stratify=meta["emotion"],
    random_state=42,
)

# Data Augmentation & Virtual Dataset Construction

In [None]:
from audiomentations import Compose, AddGaussianNoise, AddBackgroundNoise, TimeStretch, PitchShift, \
    ClippingDistortion, RoomSimulator, PolarityInversion, Gain, BandPassFilter, BandStopFilter

NOISE_DIR = "/ESC-50-master"


def get_stronger_audio_augmentor():
    """Build the random waveform-augmentation pipeline used for training clips.

    Every transform fires independently with its own probability ``p``; they
    are listed in application order (Compose is built without shuffling).

    Returns:
        audiomentations ``Compose``, called as ``augmentor(samples=..., sample_rate=...)``.
    """
    # Level and polarity perturbations.
    level_transforms = [
        PolarityInversion(p=0.2),
        Gain(min_gain_db=-6, max_gain_db=6, p=0.3),
    ]
    # Additive noise: white noise plus background clips found under NOISE_DIR.
    noise_transforms = [
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.008, p=0.4),
        AddBackgroundNoise(sounds_path=NOISE_DIR, min_snr_db=3, max_snr_db=18, p=0.4),
    ]
    # Mild tempo / pitch changes.
    prosody_transforms = [
        TimeStretch(min_rate=0.95, max_rate=1.05, p=0.20),
        PitchShift(min_semitones=-2, max_semitones=2, p=0.20),
    ]
    # Channel colouration: band filtering, clipping and room reverberation.
    channel_transforms = [
        BandPassFilter(min_center_freq=200.0, max_center_freq=4000.0, p=0.10),
        BandStopFilter(min_center_freq=200.0, max_center_freq=4000.0, p=0.10),
        ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=7, p=0.15),
        RoomSimulator(p=0.15),
    ]
    return Compose(level_transforms + noise_transforms + prosody_transforms + channel_transforms)

class SERDatasetAugmentedVirtual(Dataset):
    """Fixed-length waveform dataset that expands each clip into "virtual" samples.

    Each row of ``df`` yields ``n_aug_per_sample`` items. Copy 0 is the clean
    clip; copies 1..n-1 go through the augmentation pipeline, with a fresh
    random draw on each ``__getitem__`` call. With ``augment=False`` every row
    yields exactly one clean item.

    Args:
        df: DataFrame with ``processed_path``, ``emotion``, ``intensity`` and
            ``gender`` columns (labels already integer-encoded).
        augment: Whether to add augmented copies.
        sample_rate: Target sampling rate in Hz.
        target_duration: Clip length in seconds; audio is zero-padded/cropped to it.
        n_aug_per_sample: Virtual items per row when ``augment`` is True.
    """

    def __init__(self, df, augment=True, sample_rate=16000, target_duration=3, n_aug_per_sample=3):
        self.df = df.reset_index(drop=True)
        self.augment = augment
        self.sample_rate = sample_rate
        self.target_samples = int(target_duration * sample_rate)
        self.n_aug_per_sample = n_aug_per_sample if augment else 1
        self.augmentor = get_stronger_audio_augmentor() if augment else None
        # (row index, copy index) pairs; copy 0 is always the un-augmented clip.
        self.virtual_index = [
            (idx, aug_idx)
            for idx in range(len(self.df))
            for aug_idx in range(self.n_aug_per_sample)
        ]

    def __len__(self):
        return len(self.virtual_index)

    def _fit_length(self, audio):
        """Zero-pad or crop a 1-D signal to exactly ``self.target_samples``."""
        if len(audio) < self.target_samples:
            return np.pad(audio, (0, self.target_samples - len(audio)))
        return audio[:self.target_samples]

    def __getitem__(self, vidx):
        """Return ``(float32 waveform of target_samples, (emotion, intensity, gender))``."""
        idx, aug_idx = self.virtual_index[vidx]
        row = self.df.iloc[idx]
        audio, sr = sf.read(row["processed_path"])
        # Down-mix BEFORE resampling. soundfile returns multi-channel audio as
        # (frames, channels), and librosa.resample operates on the last axis,
        # so resampling first would resample across channels instead of time.
        if audio.ndim == 2:
            audio = audio.mean(axis=1)
        if sr != self.sample_rate:
            import librosa
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
        # audiomentations transforms expect float32 samples.
        audio = self._fit_length(np.asarray(audio, dtype=np.float32))
        if self.augmentor is not None and aug_idx > 0:
            augmented = self.augmentor(samples=audio, sample_rate=self.sample_rate)
            # Some transforms (e.g. time stretch) may alter the length; re-fit it.
            audio = self._fit_length(augmented)
        labels = (row["emotion"], row["intensity"], row["gender"])
        return audio.astype(np.float32), labels
SAMPLE_RATE = 16000

# Training rows expand to 4 virtual items (1 clean + 3 augmented copies);
# validation rows are used once, without augmentation.
train_set = SERDatasetAugmentedVirtual(
    train_df,
    augment=True,
    sample_rate=SAMPLE_RATE,
    target_duration=3,
    n_aug_per_sample=4,
)
val_set = SERDatasetAugmentedVirtual(
    val_df,
    augment=False,
    sample_rate=SAMPLE_RATE,
    target_duration=3,
    n_aug_per_sample=1,
)



# Wav2Vec2 Feature Extraction

In [None]:
MODEL_NAME = "facebook/wav2vec2-base"
SAMPLE_RATE = 16000
# Batch size for feature extraction only; the classifier cell later rebinds BATCH_SIZE.
BATCH_SIZE = 8

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained encoder, kept in eval mode and only run under no_grad for feature extraction.
processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
w2v2_model = Wav2Vec2Model.from_pretrained(MODEL_NAME)
w2v2_model = w2v2_model.to(device)
w2v2_model = w2v2_model.eval()

def collate_fn(batch):
    """Stack ``(audio, labels)`` pairs into batched tensors.

    Args:
        batch: Sequence of items as returned by the dataset; ``audio`` is a
            1-D array and ``labels`` a tuple of integer class codes.

    Returns:
        Tuple of a float32 tensor of stacked waveforms and an int64 tensor of labels.
    """
    waveforms = np.stack([item[0] for item in batch])
    targets = np.asarray([item[1] for item in batch], dtype=np.int64)
    return (
        torch.as_tensor(waveforms, dtype=torch.float32),
        torch.as_tensor(targets, dtype=torch.long),
    )

def extract_w2v2_features_seq(dataset):
    """Run the wav2vec 2.0 encoder over a dataset and collect frame embeddings.

    Args:
        dataset: Map-style dataset yielding ``(audio, labels)``. Every ``audio``
            must have the same length: ``dataset.target_samples`` when the
            dataset exposes it, otherwise 3 s at ``SAMPLE_RATE``.

    Returns:
        X: ``last_hidden_state`` arrays concatenated along the batch axis.
        y: Labels concatenated along the batch axis (int64).

    Raises:
        AssertionError: if a batch has an unexpected length.
        ValueError: if the dataset is empty.
    """
    # Take the expected length from the dataset instead of hard-coding 3 s,
    # so datasets built with another target_duration still work.
    expected_len = getattr(dataset, "target_samples", 3 * SAMPLE_RATE)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    features, all_labels = [], []
    with torch.no_grad():
        for batch_audios, batch_labels in loader:
            assert batch_audios.shape[1] == expected_len, (
                f"Audio length {batch_audios.shape[1]} doesn't equal target length {expected_len}"
            )
            inputs = processor(
                batch_audios.numpy(),
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
                padding="longest"
            )
            outs = w2v2_model(input_values=inputs.input_values.to(device))
            features.append(outs.last_hidden_state.cpu().numpy())
            all_labels.append(batch_labels.numpy())

    if not features:
        raise ValueError("dataset is empty; no features to extract")
    X = np.concatenate(features, axis=0)
    y = np.concatenate(all_labels, axis=0)
    return X, y



# Train features include the augmented virtual copies; val features are clean.
X_train, y_train = extract_w2v2_features_seq(train_set)
X_val, y_val = extract_w2v2_features_seq(val_set)
print(f"Wav2Vec2 frame sequence feature shape: {X_train.shape} {X_val.shape}")



# Definition & Training of the Multi-task CNN Model

In [None]:
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import copy
import os
import numpy as np

class DeeperMultiTaskCNNClassifier(nn.Module):
    """1-D CNN over frame embeddings with emotion / intensity / gender heads.

    ``forward`` takes ``x`` shaped (batch, frames, emb_dim) and swaps the last
    two axes so Conv1d runs along frames. Four conv blocks follow, then global
    average pooling over time. A shared fully connected layer feeds three
    linear heads. The module order (and thus ``state_dict`` keys such as
    ``cnn.0.weight`` ... ``cnn.13.running_var``) matches saved checkpoints.
    """

    def __init__(self, emb_dim, n_emotion, n_intensity, n_gender,
                 cnn_channels=128, kernel_size=5, hidden=256, dropout=0.4):
        super().__init__()
        wide = cnn_channels * 2
        same_pad = kernel_size // 2

        def conv_block(c_in, c_out, k, pad, with_dropout):
            """Conv1d -> BatchNorm1d -> ReLU, optionally followed by Dropout."""
            layers = [
                nn.Conv1d(c_in, c_out, kernel_size=k, padding=pad),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
            if with_dropout:
                layers.append(nn.Dropout(dropout))
            return layers

        self.cnn = nn.Sequential(
            *conv_block(emb_dim, cnn_channels, kernel_size, same_pad, True),
            *conv_block(cnn_channels, wide, kernel_size, same_pad, True),
            *conv_block(wide, wide, kernel_size, same_pad, True),
            # 1x1 projection back to cnn_channels, then average over time.
            *conv_block(wide, cnn_channels, 1, 0, False),
            nn.AdaptiveAvgPool1d(1),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(cnn_channels, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.head_emotion = nn.Linear(hidden, n_emotion)
        self.head_intensity = nn.Linear(hidden, n_intensity)
        self.head_gender = nn.Linear(hidden, n_gender)

    def forward(self, x):
        """Return ``(emotion_logits, intensity_logits, gender_logits)``."""
        pooled = self.cnn(x.transpose(1, 2))
        shared = self.fc(pooled)
        heads = (self.head_emotion, self.head_intensity, self.head_gender)
        return tuple(head(shared) for head in heads)

# Seed the RNGs the training run uses (weight init, dropout, the per-epoch
# np.random.permutation) so Restart & Run All reproduces the same model,
# up to any GPU-kernel nondeterminism.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Class counts per head (8 RAVDESS emotions, normal/strong intensity, male/female).
n_emotion = 8
n_intensity = 2
n_gender = 2
model = DeeperMultiTaskCNNClassifier(
    emb_dim=X_train.shape[2],
    n_emotion=n_emotion,
    n_intensity=n_intensity,
    n_gender=n_gender
).to(device)


print("Wav2Vec2 frame sequence feature shape:", X_train.shape, X_val.shape)

# Features stay on the CPU; batches are moved to `device` inside the loop.
X_train_torch = torch.tensor(X_train, dtype=torch.float32)
y_train_e, y_train_i, y_train_g = (torch.tensor(y_train[:, k], dtype=torch.long) for k in range(3))

X_val_torch = torch.tensor(X_val, dtype=torch.float32)
y_val_e, y_val_i, y_val_g = (torch.tensor(y_val[:, k], dtype=torch.long) for k in range(3))

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the LR after 3 epochs without validation-loss improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

EPOCHS = 50
# NOTE: rebinds the feature-extraction BATCH_SIZE (8) defined earlier.
BATCH_SIZE = 32
patience = 7  # early-stopping patience, in epochs

best_val_loss = np.inf
best_model_state = None
wait = 0

# Per-epoch history, plotted and saved after training.
train_losses = []
val_losses = []
val_acc_emotion = []
val_acc_intensity = []
val_acc_gender = []

# Per-task loss weights (emotion, intensity, gender). The gradient step uses
# the raw weighted sum. BOTH logged losses are divided by the weight sum, so
# "Loss" and "Val_Loss" share one scale and can be plotted on the same axis
# (previously only val_loss was normalized).
LOSS_WEIGHTS = (2.0, 1.0, 0.5)
LOSS_WEIGHT_SUM = sum(LOSS_WEIGHTS)


def _weighted_multitask_loss(outputs, targets):
    """Weighted sum of per-task cross-entropy over (emotion, intensity, gender)."""
    return sum(w * loss_fn(out, tgt) for w, out, tgt in zip(LOSS_WEIGHTS, outputs, targets))


for epoch in range(EPOCHS):
    # ---- one training epoch over a fresh random permutation ----
    model.train()
    idx = np.random.permutation(len(X_train_torch))
    total_loss = 0

    for i in range(0, len(idx), BATCH_SIZE):
        batch = idx[i:i + BATCH_SIZE]
        xb = X_train_torch[batch].to(device)
        yb = (y_train_e[batch].to(device), y_train_i[batch].to(device), y_train_g[batch].to(device))
        loss = _weighted_multitask_loss(model(xb), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size: the last batch may be smaller.
        total_loss += loss.item() * len(batch)
    avg_loss = total_loss / len(X_train_torch) / LOSS_WEIGHT_SUM
    train_losses.append(avg_loss)

    # ---- full-batch evaluation on the validation set ----
    model.eval()
    with torch.no_grad():
        ve, vi, vg = model(X_val_torch.to(device))
        pred_e = ve.argmax(1).cpu().numpy()
        pred_i = vi.argmax(1).cpu().numpy()
        pred_g = vg.argmax(1).cpu().numpy()

        acc_e = np.mean(pred_e == y_val[:, 0])
        acc_i = np.mean(pred_i == y_val[:, 1])
        acc_g = np.mean(pred_g == y_val[:, 2])

        val_loss = _weighted_multitask_loss(
            (ve, vi, vg),
            (y_val_e.to(device), y_val_i.to(device), y_val_g.to(device)),
        ).item() / LOSS_WEIGHT_SUM

    val_losses.append(val_loss)
    val_acc_emotion.append(acc_e)
    val_acc_intensity.append(acc_i)
    val_acc_gender.append(acc_g)

    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Loss: {avg_loss:.4f} | "
          f"Val_Loss: {val_loss:.4f} | "
          f"Val_Emo: {acc_e*100:.2f}% | "
          f"Val_Inten: {acc_i*100:.2f}% | "
          f"Val_Gend: {acc_g*100:.2f}% | "
          f"LR: {current_lr:.6f}")

    # ---- checkpoint on improvement, otherwise count toward early stopping ----
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = copy.deepcopy(model.state_dict())
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print(f"Early stopping at epoch {epoch+1}. Best val_loss: {best_val_loss:.4f}")
            break

# Restore the weights from the epoch with the lowest validation loss.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
else:
    # No epoch ever beat +inf (EPOCHS == 0, or val_loss was NaN throughout,
    # since NaN < inf is False); load_state_dict(None) would crash here.
    print("Warning: no improving epoch was recorded; saving the final weights instead.")
torch.save(model.state_dict(), "w2v2_multitask_cnn_XAI_best.pt")
print("Best model saved to w2v2_multitask_cnn_XAI_best.pt")

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
ax_loss.plot(train_losses, label="Train Loss")
ax_loss.plot(val_losses, label="Val Loss")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Loss")
ax_loss.legend()
ax_loss.set_title("Loss Curve")

ax_acc.plot(val_acc_emotion, label="Val Acc Emotion")
ax_acc.plot(val_acc_intensity, label="Val Acc Intensity")
ax_acc.plot(val_acc_gender, label="Val Acc Gender")
ax_acc.set_xlabel("Epoch")
ax_acc.set_ylabel("Accuracy")
ax_acc.legend()
ax_acc.set_title("Validation Accuracy Curves")
plt.show()

OUTPUT_DIR = "/Data/w2v2/output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Persist the per-epoch history so the reporting cell can reload it.
for log_name, log_values in (
    ("train_losses.npy", train_losses),
    ("val_losses.npy", val_losses),
    ("val_acc_emotion.npy", val_acc_emotion),
    ("val_acc_intensity.npy", val_acc_intensity),
    ("val_acc_gender.npy", val_acc_gender),
):
    np.save(os.path.join(OUTPUT_DIR, log_name), np.array(log_values))

print(f"Training log saved to {OUTPUT_DIR}")

# Classification Report & Audio Prediction

In [None]:
from sklearn.metrics import classification_report

# Validation predictions from the restored best model.
model.eval()
with torch.no_grad():
    ve, vi, vg = model(X_val_torch.to(device))
    pred_e = ve.argmax(1).cpu().numpy()
    pred_i = vi.argmax(1).cpu().numpy()
    pred_g = vg.argmax(1).cpu().numpy()

LE_PATH = "/Data/label_encoders"
le_e = joblib.load(os.path.join(LE_PATH, "le_emotion.pkl"))
le_i = joblib.load(os.path.join(LE_PATH, "le_intensity.pkl"))
le_g = joblib.load(os.path.join(LE_PATH, "le_gender.pkl"))

# Pass labels= explicitly. Without it, classification_report raises when a
# class is absent from both y_true and y_pred, because the number of observed
# classes then differs from len(target_names).
for task_title, task_col, task_pred, task_le in (
    ("Emotion", 0, pred_e, le_e),
    ("Intensity", 1, pred_i, le_i),
    ("Gender", 2, pred_g, le_g),
):
    task_report = classification_report(
        y_val[:, task_col],
        task_pred,
        labels=np.arange(len(task_le.classes_)),
        target_names=task_le.classes_,
    )
    print(f"{task_title}:\n", task_report)


def predict_audio(path, le_e, le_i, le_g):
    """Predict (emotion, intensity, gender) for one audio file.

    The clip gets the same preprocessing as the training features. It is
    down-mixed to mono, resampled to SAMPLE_RATE, zero-padded/cropped to 3 s,
    and encoded by ``w2v2_model``. The raw ``last_hidden_state`` then feeds the
    classifier.

    Args:
        path: Audio file readable by soundfile.
        le_e, le_i, le_g: Fitted label encoders for emotion, intensity, gender.

    Returns:
        dict with ``"emotion"``, ``"intensity"`` and ``"gender"`` class names.
    """
    y, sr = sf.read(path)
    # Mono first: soundfile returns (frames, channels) and librosa resamples the last axis.
    if y.ndim > 1:
        y = y.mean(axis=1)
    if sr != SAMPLE_RATE:
        import librosa
        y = librosa.resample(y, orig_sr=sr, target_sr=SAMPLE_RATE)
    # Training clips were fixed to 3 s (zero-padded / cropped); match that here.
    target_len = 3 * SAMPLE_RATE
    if len(y) < target_len:
        y = np.pad(y, (0, target_len - len(y)))
    else:
        y = y[:target_len]
    y = y.astype(np.float32)

    inputs = processor(y, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True)
    model.eval()
    with torch.no_grad():
        emb_seq = w2v2_model(input_values=inputs.input_values.to(device)).last_hidden_state
        # No extra standardization: the classifier was trained on the raw
        # last_hidden_state features (X_train is never normalized), so the
        # per-utterance mean/std scaling done previously was a train/test mismatch.
        out_e, out_i, out_g = model(emb_seq)
        pred_e = int(out_e.argmax(1).item())
        pred_i = int(out_i.argmax(1).item())
        pred_g = int(out_g.argmax(1).item())
    return {
        "emotion": le_e.inverse_transform([pred_e])[0],
        "intensity": le_i.inverse_transform([pred_i])[0],
        "gender": le_g.inverse_transform([pred_g])[0]
    }



# Single-file sanity check with the trained model.
test_path = "/Data/test/happy.wav"
result = predict_audio(test_path, le_e=le_e, le_i=le_i, le_g=le_g)
print("Prediction:", result)


# XAI Analysis

In [None]:
import torch
from captum.attr import IntegratedGradients
import matplotlib.pyplot as plt

# Encoded emotion index -> class name, taken from the fitted LabelEncoder.
# LabelEncoder assigns codes in sorted (alphabetical) order: angry=0, calm=1,
# disgust=2, fear=3, happy=4, neutral=5, sad=6, surprise=7. The previous
# hard-coded RAVDESS file-code order (neutral, calm, happy, ...) matched only
# 'calm' and 'surprise', so the 'happy' analysis below picked the wrong class.
_le_emotion_for_xai = joblib.load(os.path.join("/Data/label_encoders", "le_emotion.pkl"))
label_map = {int(i): str(name) for i, name in enumerate(_le_emotion_for_xai.classes_)}


def model_forward_wrapper(input_tensor):
    """Forward pass that keeps only the first (emotion) head's logits.

    The multi-task model returns one output per head, but Captum expects the
    forward function to return a single tensor.
    """
    emotion_logits, *_ = model(input_tensor)
    return emotion_logits

def analyze_with_ig(sample_idx, target_class_index, target_label_name):
    """Plot Integrated Gradients attributions for one validation sample.

    Attributions of the emotion logit (via ``model_forward_wrapper``) are taken
    with respect to the input frame embeddings, using an all-zeros baseline.

    Args:
        sample_idx: Row index into ``X_val_torch``.
        target_class_index: Emotion class whose logit is attributed.
        target_label_name: Human-readable class name for the plot title.

    Returns:
        ``(attributions, delta)`` as returned by Captum; callers may ignore it.
    """
    # IG evaluates many interpolated inputs in one batch. In train mode,
    # BatchNorm would mix statistics across those steps and Dropout would add
    # noise, so force inference mode.
    model.eval()
    input_tensor = X_val_torch[sample_idx:sample_idx + 1].to(device)

    ig = IntegratedGradients(model_forward_wrapper)
    attributions, delta = ig.attribute(
        input_tensor,
        baselines=torch.zeros_like(input_tensor),
        target=target_class_index,
        return_convergence_delta=True
    )

    # (frames, emb_dim) -> (emb_dim, frames) so time runs along the x axis.
    attr_map = attributions[0].detach().cpu().numpy().T
    # Attributions are signed; a diverging colormap centred on 0 keeps negative
    # evidence visible (the previous 'hot' map rendered it all as black).
    limit = float(np.abs(attr_map).max()) or 1.0

    plt.figure(figsize=(10, 4))
    plt.imshow(attr_map, aspect='auto', cmap='RdBu_r', vmin=-limit, vmax=limit)
    plt.colorbar()
    plt.title(f'Integrated Gradients for Sample {sample_idx} (Target: {target_label_name})')
    plt.xlabel('Time Step')
    plt.ylabel('Embedding Dimension')
    plt.show()
    # Large |delta| means the Riemann approximation is poor (increase n_steps).
    print(f"    IG convergence delta: {delta.abs().max().item():.4e}")
    return attributions, delta

# Find the encoded index of 'happy' (-1 when it is not among the class names).
label_names = list(label_map.values())
if 'happy' in label_names:
    happy_label = label_names.index('happy')
else:
    happy_label = -1
    print("Error: 'happy' is not in label_map.")

if happy_label != -1:
    print("Using the current model to generate the latest predictions...")
    model.eval()
    with torch.no_grad():
        all_outputs = model_forward_wrapper(X_val_torch.to(device))
        pred_e = all_outputs.argmax(dim=1).cpu().numpy()
    print("Latest predictions generated.")

    # Validation rows whose true emotion is 'happy' but whose prediction is not.
    happy_wrong_idx = np.flatnonzero(
        (y_val[:, 0] == happy_label) & (pred_e != happy_label)
    ).tolist()

    print(f"\nTotal number of 'happy' samples misclassified: {len(happy_wrong_idx)}")
    print(f"Indices of misclassified 'happy' samples: {happy_wrong_idx}")

    if not happy_wrong_idx:
        print("\nCongratulations! The model didn't misclassify any 'happy' samples. No analysis needed.")
    else:
        num_to_analyze = 7

        for sample_idx in happy_wrong_idx[:num_to_analyze]:
            print(f"\n==================== Analysis of Sample {sample_idx} ====================")

            true_class_index = int(y_val[sample_idx, 0])
            true_label_name = label_map.get(true_class_index, "Unknown")
            predicted_class_index = int(pred_e[sample_idx])
            predicted_label_name = label_map.get(predicted_class_index, "Unknown")

            print(f"True label: {true_label_name} ({true_class_index})")
            print(f"Model prediction: {predicted_label_name} ({predicted_class_index})")

            # Attribute the missed true class first, then the class the model chose.
            for heading, cls_idx, cls_name in (
                ("\n -> Generating attribution map (Target: happy)...", true_class_index, true_label_name),
                (f"\n -> Generating attribution map (Target: {predicted_label_name})...",
                 predicted_class_index, predicted_label_name),
            ):
                print(heading)
                analyze_with_ig(
                    sample_idx=sample_idx,
                    target_class_index=cls_idx,
                    target_label_name=cls_name
                )
            print(f"==================== Analysis of Sample {sample_idx} completed ====================\n")

# Visualization & Performance Reports

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from PIL import Image

VIS_DIR = "/Data/w2v2/visualizations/w2v2_multitask_cnn_XAI"
EMB_DIR = "/Data/XAI/w2v2"
LOG_DIR = "/Data/w2v2/output"

# Create every output directory up front (no-op when it already exists).
for _out_dir in (VIS_DIR, EMB_DIR, LOG_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# This cell reuses the training features produced by the extraction cell.
assert 'X_train' in globals() and 'y_train' in globals()

X = X_train
y_all = y_train

# One column per label head, in the (emotion, intensity, gender) order of y_all.
df_info = pd.DataFrame(
    {name: y_all[:, k] for k, name in enumerate(["emotion", "intensity", "gender"])}
)
y_true = df_info["emotion"].values

np.save(os.path.join(EMB_DIR, "emb_train_X.npy"), X)
df_info.to_csv(os.path.join(EMB_DIR, "emb_train_info.csv"), index=False)
print("Automatically generated emb_train_X.npy and emb_train_info.csv files.")

# Bar chart of class counts for each label head (training set, incl. augmented copies).
fig, axes = plt.subplots(1, 3, figsize=(10, 4))
for ax, col in zip(axes, ["emotion", "intensity", "gender"]):
    df_info[col].value_counts().sort_index().plot(kind='bar', ax=ax)
    ax.set_title(f"{col.capitalize()} Distribution")
    ax.set_xlabel(col)
    ax.set_ylabel("Count")
fig.tight_layout()
label_path = os.path.join(VIS_DIR, "label_distribution_summary_XAI.png")
fig.savefig(label_path, dpi=300)
plt.close(fig)

# Mean-pool frames into one vector per clip for t-SNE, WITHOUT overwriting X:
# the feature-map section further down checks `X.ndim == 3` and needs the
# frame-level features intact.
X_tsne_input = X.mean(axis=1) if X.ndim == 3 else X

tsne = TSNE(n_components=2, perplexity=30, random_state=42)
X_emb = tsne.fit_transform(X_tsne_input)
fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(X_emb[:, 0], X_emb[:, 1], c=y_true, cmap="tab10", s=10, alpha=0.7)
ax.set_title("t-SNE Projection of Embeddings")
fig.colorbar(scatter, ax=ax, label="Emotion")
tsne_path = os.path.join(VIS_DIR, "tsne_emotion_projection_XAI.png")
fig.savefig(tsne_path, dpi=300)
plt.close(fig)

# Reload the per-epoch logs written by the training cell.
loss_train, loss_val, acc_e, acc_i, acc_g = (
    np.load(os.path.join(LOG_DIR, log_name))
    for log_name in (
        "train_losses.npy",
        "val_losses.npy",
        "val_acc_emotion.npy",
        "val_acc_intensity.npy",
        "val_acc_gender.npy",
    )
)

fig, axs = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, acc_ax = axs

for curve, curve_label in ((loss_train, "Train Loss"), (loss_val, "Val Loss")):
    loss_ax.plot(curve, label=curve_label)
loss_ax.set_title("Loss Curve")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

for curve, curve_label in ((acc_e, "Emotion Acc"), (acc_i, "Intensity Acc"), (acc_g, "Gender Acc")):
    acc_ax.plot(curve, label=curve_label)
acc_ax.set_title("Validation Accuracy")
acc_ax.set_xlabel("Epoch")
acc_ax.set_ylabel("Accuracy")
acc_ax.legend()

fig.tight_layout()
curve_path = os.path.join(VIS_DIR, "training_curves_summary_XAI.png")
fig.savefig(curve_path, dpi=300)
plt.close(fig)

def _fit_into_tile(path, tile_size=(600, 300)):
    """Load an image, shrink it to fit tile_size keeping its aspect ratio, centred on white."""
    # `with` closes the file handle; convert() returns an independent in-memory copy.
    with Image.open(path) as src:
        img = src.convert("RGB")
    img.thumbnail(tile_size)
    tile = Image.new("RGB", tile_size, color=(255, 255, 255))
    tile.paste(img, ((tile_size[0] - img.width) // 2, (tile_size[1] - img.height) // 2))
    return tile


# Three 600x300 tiles side by side. A plain resize((600, 300)) would squash
# the 10x4, 8x6 and 14x5 figures to a 2:1 aspect ratio.
img1, img2, img3 = (_fit_into_tile(p) for p in (label_path, tsne_path, curve_path))
combined_img = Image.new("RGB", (1800, 300), color=(255, 255, 255))
for x_offset, tile in zip((0, 600, 1200), (img1, img2, img3)):
    combined_img.paste(tile, (x_offset, 0))
combined_path = os.path.join(VIS_DIR, "summary_all_combined_XAI.png")
combined_img.save(combined_path)

# ---- Validation predictions for the emotion head ----
# The confusion matrix and the report must compare model predictions with the
# VALIDATION labels. `y_true` above holds TRAIN labels (for t-SNE colouring),
# and the old fallback `y_pred = y_true` produced a perfect, meaningless score.
y_val_true = y_val[:, 0]
val_preds_path = os.path.join(LOG_DIR, "val_preds_emotion.npy")
y_pred = np.load(val_preds_path) if os.path.exists(val_preds_path) else None
if y_pred is None or len(y_pred) != len(y_val_true):
    model.eval()
    with torch.no_grad():
        y_pred = model(X_val_torch.to(device))[0].argmax(1).cpu().numpy()

le_emotion = joblib.load(os.path.join("/Data/label_encoders", "le_emotion.pkl"))
class_ids = np.arange(len(le_emotion.classes_))

cm = confusion_matrix(y_val_true, y_pred, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le_emotion.classes_)
# Draw on an explicit axes: disp.plot() without ax opens its own figure,
# which left an extra empty figure open after plt.close().
fig, ax = plt.subplots(figsize=(8, 6))
disp.plot(ax=ax, cmap=plt.cm.Blues, values_format='d')
ax.set_title("Confusion Matrix - Emotion")
cm_path = os.path.join(VIS_DIR, "summary_confusion_XAI.png")
fig.savefig(cm_path, dpi=300)
plt.close(fig)

# ---- Feature map of one training sample (embedding dims x frames) ----
# Read X_train directly: X may have been mean-pooled to 2-D earlier in this cell.
if X_train.ndim == 3:
    sample_idx = 0
    feat_sample = X_train[sample_idx]
    fig, ax = plt.subplots(figsize=(12, 6))
    im = ax.imshow(feat_sample.T, aspect='auto', origin='lower', cmap="viridis")
    fig.colorbar(im, ax=ax)
    ax.set_title("Feature Map of Sample 0")
    ax.set_xlabel("Time Frames")
    ax.set_ylabel("Feature Dimensions")
    fmap_path = os.path.join(VIS_DIR, "summary_featmap_XAI.png")
    fig.savefig(fmap_path, dpi=300)
    plt.close(fig)
else:
    fmap_path = None

# ---- Text classification report (validation set, emotion head) ----
report_text = classification_report(
    y_val_true, y_pred, labels=class_ids, target_names=le_emotion.classes_, output_dict=False
)
report_path = os.path.join(VIS_DIR, "summary_clsreport_XAI.txt")
with open(report_path, "w") as f:
    f.write(report_text)

print("Visualizations and reports generated:")
print("1. Label distribution:", label_path)
print("2. t-SNE projection:", tsne_path)
print("3. Training curves:", curve_path)
print("4. Combined image:", combined_path)
print("5. Confusion matrix:", cm_path)
print("6. Feature map:", fmap_path)
print("7. Classification report:", report_path)
