In [None]:
import os
import re
import pandas as pd
import numpy as np
import librosa
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    Wav2Vec2Model, Wav2Vec2Processor,
    WhisperModel, WhisperProcessor, 
    BertModel, BertTokenizer,
    AutoModel, AutoTokenizer
)
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.svm import SVC, SVR
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
from scipy.spatial.distance import cdist
from scipy.stats import ks_2samp
import scipy.linalg
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

In [None]:
@dataclass
class ADReSSConfig:
    """Paths and audio parameters used when processing the ADReSS dataset.

    Attributes:
        train_path: Root directory of the training split.
        test_path: Root directory of the test split.
        audio_sr: Sampling rate (Hz) that audio chunks are loaded at.
        n_mfcc, n_mel, hop_length, n_fft: Audio feature-extraction
            parameters (not referenced by the loaders in this part of the
            notebook).
        statistical_windows: Names of the summary statistics to compute.
            ``None`` (the default) is replaced in ``__post_init__`` by
            ``['mean', 'std', 'min', 'max', 'median']``.
    """
    # Local (non-Kaggle) layout:
    # train_path: str = "ADReSS-IS2020-train/ADReSS-IS2020-data/train"
    train_path: str = "/kaggle/input/adress/adress/train"
    # test_path: str = "ADReSS-IS2020-test/ADReSS-IS2020-data/test"
    test_path: str = "/kaggle/input/adress/adress/test"
    audio_sr: int = 16000

    # Audio feature extraction parameters
    n_mfcc: int = 13
    n_mel: int = 128
    hop_length: int = 512
    n_fft: int = 2048

    # Feature statistics windows. ``None`` is a sentinel: a mutable list
    # cannot be used directly as a dataclass default.
    statistical_windows: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in the default statistics list (a fresh list per instance)."""
        if self.statistical_windows is None:
            self.statistical_windows = ['mean', 'std', 'min', 'max', 'median']

In [None]:
class MetadataLoader:
    """Load ADReSS subject metadata for either the training or the test split.

    The split is inferred from the files present under ``base_path``:
    ``cc_meta_data.txt`` together with ``cd_meta_data.txt`` means training
    data; a ``meta_data.txt`` on its own means test data.
    """

    def __init__(self, base_path: str):
        """Store the dataset root and detect which split it holds.

        Args:
            base_path: Directory containing the metadata file(s).

        Raises:
            FileNotFoundError: If neither metadata layout is found.
        """
        self.base_path = Path(base_path)
        self.dataset_type = self._detect_dataset_type()

    def _detect_dataset_type(self) -> str:
        """Return ``"training"`` or ``"test"`` based on which files exist.

        Raises:
            FileNotFoundError: If neither the training pair nor the test
                file exists under ``base_path``.
        """
        cc_path = self.base_path / "cc_meta_data.txt"
        cd_path = self.base_path / "cd_meta_data.txt"
        test_path = self.base_path / "meta_data.txt"

        if cc_path.exists() and cd_path.exists():
            return "training"
        if test_path.exists():
            return "test"
        raise FileNotFoundError("Could not detect dataset type. Expected either training files (cc_meta_data.txt, cd_meta_data.txt) or test file (meta_data.txt)")

    def load_metadata(self) -> pd.DataFrame:
        """Load the metadata table for the detected split."""
        if self.dataset_type == "training":
            return self._load_training_metadata()
        return self._load_test_metadata()

    @staticmethod
    def _strip_if_text(value):
        """Strip surrounding whitespace from strings; pass anything else through."""
        return value.strip() if isinstance(value, str) else value

    def _load_training_metadata(self) -> pd.DataFrame:
        """Load and concatenate the control (cc) and dementia (cd) tables.

        Returns:
            One row per subject with the file's own columns plus
            ``label`` (0 = control, 1 = dementia) and ``group``
            (``'cc'`` / ``'cd'``). ``mmse`` is numeric, with unparseable
            entries turned into NaN.
        """
        frames = []
        for group, label in (('cc', 0), ('cd', 1)):
            frame = pd.read_csv(self.base_path / f"{group}_meta_data.txt",
                                sep=';', skipinitialspace=True)
            # Header cells keep the padding that precedes each ';'.
            frame.columns = frame.columns.str.strip()
            frame['label'] = label
            frame['group'] = group
            frames.append(frame)

        metadata = pd.concat(frames, ignore_index=True)

        metadata['ID'] = metadata['ID'].str.strip()
        # skipinitialspace only removes padding *after* a delimiter, so text
        # cells still carry trailing blanks; strip gender here exactly as the
        # test loader does so both splits yield identical category values.
        if 'gender' in metadata.columns:
            metadata['gender'] = metadata['gender'].map(self._strip_if_text)
        metadata['mmse'] = pd.to_numeric(metadata['mmse'], errors='coerce')

        return metadata

    def _load_test_metadata(self) -> pd.DataFrame:
        """Load the test table (``ID ; age ; gender``, no labels or MMSE).

        Returns:
            One row per subject with ``label`` = -1, ``mmse`` = NaN and
            ``group`` = ``'test'`` as placeholders for the unknown values.
        """
        meta_path = self.base_path / "meta_data.txt"

        test_df = pd.read_csv(meta_path, sep=';', skipinitialspace=True, header=None)
        test_df.columns = ['ID', 'age', 'gender']

        test_df['ID'] = test_df['ID'].str.strip()

        # The file is read with header=None, so a header line (if the file
        # has one) arrives as the first data row. Drop it rather than emit a
        # bogus subject called "ID".
        if len(test_df) > 0:
            first_id = test_df['ID'].iloc[0]
            if isinstance(first_id, str) and first_id.lower() == 'id':
                test_df = test_df.iloc[1:].reset_index(drop=True)

        test_df['age'] = pd.to_numeric(test_df['age'], errors='coerce')
        test_df['gender'] = test_df['gender'].map(self._strip_if_text)

        # Test data has no labels or MMSE scores
        test_df['label'] = -1
        test_df['mmse'] = np.nan
        test_df['group'] = 'test'

        return test_df

In [None]:
class TranscriptionProcessor:
    """Extract participant / investigator utterances from CHAT (.cha) files.

    Only utterances whose time stamp (``start_end``) sits on the same line
    as the ``*PAR:`` / ``*INV:`` tier marker are matched, because the capture
    group in the patterns below cannot span a line break.
    """

    # CLAN delimits media-link time stamps with NAK control characters
    # (0x15). The lazy capture group stops right before the digits, so the
    # opening delimiter would otherwise remain in the utterance text.
    _MEDIA_BULLET = '\x15'

    def __init__(self):
        """Compile the tier patterns: (utterance, start, end)."""
        self.participant_pattern = re.compile(r'\*PAR:\s*(.*?)\s*(\d+)_(\d+)')
        self.investigator_pattern = re.compile(r'\*INV:\s*(.*?)\s*(\d+)_(\d+)')

    def extract_speech_with_timing(self, file_path: str) -> Dict:
        """Extract participant and investigator speech with timestamps.

        Args:
            file_path: Path of a UTF-8 encoded CHAT transcript.

        Returns:
            A dict with:
              * ``participant_text`` / ``investigator_text``: cleaned
                utterances joined by single spaces,
              * ``total_utterances``: number of non-empty participant
                utterances,
              * ``participant_segments`` / ``investigator_segments``: lists of
                ``(start, end, utterance)`` tuples in file order, with
                start / end as the integers found in the transcript.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        participant_segments = self._collect_segments(self.participant_pattern, content)
        investigator_segments = self._collect_segments(self.investigator_pattern, content)

        participant_utterances = [utterance for _, _, utterance in participant_segments]
        investigator_utterances = [utterance for _, _, utterance in investigator_segments]

        return {
            'participant_text': ' '.join(participant_utterances),
            'investigator_text': ' '.join(investigator_utterances),
            'total_utterances': len(participant_utterances),
            'participant_segments': participant_segments,
            'investigator_segments': investigator_segments
        }

    def extract_speech(self, file_path: str) -> Dict:
        """Backward-compatible alias of :meth:`extract_speech_with_timing`."""
        return self.extract_speech_with_timing(file_path)

    def _collect_segments(self, pattern, content: str) -> List[Tuple[int, int, str]]:
        """Return ``(start, end, cleaned_utterance)`` for every tier match.

        Matches whose utterance is empty after cleaning are dropped.
        """
        segments = []
        for match in pattern.finditer(content):
            utterance = self._clean_utterance(match.group(1))
            if utterance:
                segments.append((int(match.group(2)), int(match.group(3)), utterance))
        return segments

    def _clean_utterance(self, utterance: str) -> str:
        """Remove CHAT annotations from an utterance and normalise spacing.

        Removes, in order: media-bullet delimiters, ``[...]`` codes,
        ``&word`` tokens, ``<...>`` spans, ``(...)`` spans; then collapses
        whitespace runs to a single space and trims the ends.
        """
        cleaned = utterance.replace(self._MEDIA_BULLET, '')
        cleaned = re.sub(r'\[.*?\]', '', cleaned)
        cleaned = re.sub(r'&\w+', '', cleaned)
        cleaned = re.sub(r'<[^>]*>', '', cleaned)
        cleaned = re.sub(r'\([^)]*\)', '', cleaned)
        cleaned = re.sub(r'\s+', ' ', cleaned)
        return cleaned.strip()

In [None]:
class AudioProcessor:
    """Load a subject's normalised audio chunks and summarise them."""

    def __init__(self, config: "ADReSSConfig"):
        """Keep the configuration; only ``config.audio_sr`` is read here."""
        self.config = config

    @staticmethod
    def _parse_chunk_timing(filename: str) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """Parse ``(start_time, end_time, chunk_num)`` from a chunk file name.

        The name is split on ``-`` and fields 2, 3 and 4 are read as integers,
        e.g. ``S001-7-4266-13310-1-2400-3470.wav`` -> ``(4266, 13310, 1)``.
        Parsing is all-or-nothing: if the name has fewer than five fields or
        any of the three is not an integer, ``(None, None, None)`` is returned
        so a chunk never carries a start time without a matching end time.
        """
        parts = filename.replace('.wav', '').split('-')
        if len(parts) < 5:
            return None, None, None
        try:
            return int(parts[2]), int(parts[3]), int(parts[4])
        except ValueError:
            return None, None, None

    def load_audio_chunks(self, chunks_dir: str, subject_id: str) -> List[Dict]:
        """Load every ``<subject_id>-*.wav`` chunk found in ``chunks_dir``.

        Args:
            chunks_dir: Directory holding the chunk files.
            subject_id: Subject identifier used as the file-name prefix.

        Returns:
            One dict per file with keys ``audio`` (samples loaded at
            ``config.audio_sr``), ``sr``, ``filename``, ``file_path``,
            ``start_time``, ``end_time`` and ``chunk_num``; ordered by
            ``(start_time, chunk_num)`` with missing values treated as 0.
            An empty list when no file matches.
        """
        chunks_path = Path(chunks_dir)
        pattern = f"{subject_id}-*.wav"
        # glob order is filesystem-dependent; sorting by path first makes the
        # order of ties in the (stable) timing sort below reproducible.
        chunk_files = sorted(chunks_path.glob(pattern))

        participant_chunks = []

        for chunk_file in chunk_files:
            filename = chunk_file.name
            audio, sr = librosa.load(str(chunk_file), sr=self.config.audio_sr)
            start_time, end_time, chunk_num = self._parse_chunk_timing(filename)

            participant_chunks.append({
                'audio': audio,
                'sr': sr,
                'filename': filename,
                'file_path': str(chunk_file),
                'start_time': start_time,
                'end_time': end_time,
                'chunk_num': chunk_num
            })

        # Sort chunks by temporal order (start_time, then chunk_num)
        participant_chunks.sort(key=lambda x: (x['start_time'] or 0, x['chunk_num'] or 0))

        return participant_chunks

    def aggregate_chunk_features(self, chunks: List[Dict]) -> Dict[str, float]:
        """Summarise chunks without computing hand-crafted features.

        Args:
            chunks: Dicts with an ``audio`` sequence and optionally ``sr``.

        Returns:
            ``{}`` for an empty list; otherwise ``total_duration`` (seconds,
            summed over non-empty chunks) and ``num_chunks`` (count of
            non-empty chunks).
        """
        if not chunks:
            return {}

        total_duration = 0
        valid_chunks = 0

        for chunk in chunks:
            audio = chunk['audio']
            if len(audio) > 0:  # Skip empty audio chunks
                # Prefer the rate the chunk was actually loaded at; fall back
                # to the configured rate when the chunk does not record one.
                sample_rate = chunk.get('sr') or self.config.audio_sr
                total_duration += len(audio) / sample_rate
                valid_chunks += 1

        return {
            'total_duration': total_duration,
            'num_chunks': valid_chunks
        }

In [None]:
class AlignmentProcessor:
    """Pairs transcript segments with the audio chunks they overlap in time."""

    def __init__(self):
        """Nothing to configure; the processor is stateless."""

    def align_transcript_audio(self, participant_segments: List[Tuple],
                             participant_chunks: List[Dict]) -> List[Dict]:
        """Return one record per (segment, chunk) pair whose time spans overlap.

        Args:
            participant_segments: ``(start, end, utterance)`` tuples.
            participant_chunks: Chunk dicts providing ``filename`` and
                ``audio``, and optionally ``start_time`` / ``end_time``.

        Returns:
            Records ordered by segment, then by chunk. Chunks lacking either
            time stamp are ignored. Spans that merely touch (one ends exactly
            where the other starts) do not count as overlapping.
        """
        def timed(chunks):
            # Yield only the chunks that carry both time stamps.
            for piece in chunks:
                begin = piece.get('start_time')
                finish = piece.get('end_time')
                if begin is not None and finish is not None:
                    yield piece, begin, finish

        return [
            {
                'text': spoken,
                'audio_file': piece['filename'],
                'text_start': text_begin,
                'text_end': text_finish,
                'audio_start': audio_begin,
                'audio_end': audio_finish,
                'audio_data': piece['audio'],
            }
            for text_begin, text_finish, spoken in participant_segments
            for piece, audio_begin, audio_finish in timed(participant_chunks)
            if text_finish > audio_begin and text_begin < audio_finish
        ]

In [None]:
class ADReSSDataLoader:
    """Orchestrates metadata, transcript and audio-chunk loading for one split.

    Works for both the training and the test layout; which one is in use is
    taken from the ``MetadataLoader`` created in ``__init__``.
    """

    def __init__(self, base_path: str, config: Optional["ADReSSConfig"] = None):
        """Create the component processors for the split rooted at ``base_path``.

        Args:
            base_path: Root directory of the split.
            config: Processing configuration; a default ``ADReSSConfig`` is
                used when omitted.
        """
        self.base_path = Path(base_path)
        self.config = config or ADReSSConfig()

        self.metadata_loader = MetadataLoader(self.base_path)
        self.transcription_processor = TranscriptionProcessor()
        self.audio_processor = AudioProcessor(self.config)
        self.alignment_processor = AlignmentProcessor()

        # Detect dataset type
        self.dataset_type = self.metadata_loader.dataset_type

    def load_dataset(self,
                    use_audio_chunks: bool = True,
                    include_audio: bool = True,
                    include_text: bool = True,
                    align_modalities: bool = False) -> pd.DataFrame:
        """Build one row per subject from metadata, transcript and audio.

        Args:
            use_audio_chunks: Audio is only read when this and
                ``include_audio`` are both true.
            include_audio: Whether to add audio-derived columns.
            include_text: Whether to add transcript-derived columns.
            align_modalities: When true (and both modalities are included)
                each row carries the aligned text/audio pairs instead of the
                separate text and audio summaries.

        Returns:
            A DataFrame with ``subject_id``, ``age``, ``gender``, ``mmse``,
            ``label`` and ``group`` for every subject. Modality columns are
            only present for subjects whose transcript / chunk directory
            exists, so they may be missing (NaN) on some rows.
        """
        metadata = self.metadata_loader.load_metadata()
        dataset_rows = []

        pbar_desc = f"Processing {self.dataset_type} subjects"
        for _, row in tqdm(metadata.iterrows(), desc=pbar_desc, total=len(metadata), leave=False):
            subject_id = row['ID']
            group = row['group']

            sample_data = {
                'subject_id': subject_id,
                'age': row['age'],
                'gender': row['gender'],
                'mmse': row['mmse'],
                'label': row['label'],
                'group': group
            }

            # Process transcription and audio
            if align_modalities and include_text and include_audio:
                aligned_data = self._get_aligned_data(subject_id, group)
                sample_data.update(aligned_data)
            else:
                # Process transcription
                if include_text:
                    transcript_path = self._get_transcript_path(subject_id, group)
                    if transcript_path and transcript_path.exists():
                        text_data = self.transcription_processor.extract_speech_with_timing(str(transcript_path))
                        # Segment lists are per-utterance detail; keep only
                        # the flat summary fields in the table.
                        sample_data.update({k: v for k, v in text_data.items()
                                          if k not in ['participant_segments', 'investigator_segments']})

                # Process audio chunks
                if include_audio and use_audio_chunks:
                    chunks_dir = self._get_chunks_directory(group)
                    if chunks_dir and chunks_dir.exists():
                        try:
                            participant_chunks = self.audio_processor.load_audio_chunks(
                                str(chunks_dir), subject_id
                            )

                            if participant_chunks:
                                basic_audio_info = self.audio_processor.aggregate_chunk_features(participant_chunks)
                                for key, value in basic_audio_info.items():
                                    sample_data[f'participant_{key}'] = value

                            sample_data['total_participant_chunks'] = len(participant_chunks)

                        except Exception as e:
                            # Best effort: one unreadable subject must not
                            # abort the whole dataset build.
                            print(f"Error processing audio chunks for {subject_id}: {e}")
                            sample_data['total_participant_chunks'] = 0

            dataset_rows.append(sample_data)

        return pd.DataFrame(dataset_rows)

    def _get_aligned_data(self, subject_id: str, group: str) -> Dict:
        """Return text summary plus aligned text/audio pairs for one subject.

        Returns an empty dict when either the transcript file or the chunk
        directory does not exist.
        """
        # Get transcript segments
        transcript_path = self._get_transcript_path(subject_id, group)
        if not transcript_path or not transcript_path.exists():
            return {}

        text_data = self.transcription_processor.extract_speech_with_timing(str(transcript_path))
        participant_segments = text_data['participant_segments']

        # Get audio chunks
        chunks_dir = self._get_chunks_directory(group)
        if not chunks_dir or not chunks_dir.exists():
            return {}

        participant_chunks = self.audio_processor.load_audio_chunks(str(chunks_dir), subject_id)

        # Use AlignmentProcessor to align segments with chunks
        aligned_pairs = self.alignment_processor.align_transcript_audio(
            participant_segments, participant_chunks
        )

        return {
            'participant_text': text_data['participant_text'],
            'investigator_text': text_data['investigator_text'],
            'total_utterances': text_data['total_utterances'],
            'aligned_pairs': aligned_pairs,
            'num_aligned_pairs': len(aligned_pairs)
        }

    def _get_transcript_path(self, subject_id: str, group: str) -> Optional[Path]:
        """Return the ``.cha`` path; training data has one sub-folder per group."""
        if self.dataset_type == "training":
            return self.base_path / "transcription" / group / f"{subject_id}.cha"
        return self.base_path / "transcription" / f"{subject_id}.cha"

    def _get_chunks_directory(self, group: str) -> Optional[Path]:
        """Return the chunk directory; training data has one sub-folder per group."""
        if self.dataset_type == "training":
            return self.base_path / "Normalised_audio-chunks" / group
        return self.base_path / "Normalised_audio-chunks"

    def _find_subject(self, subject_id: str) -> pd.Series:
        """Return the metadata row of ``subject_id`` (first match).

        Raises:
            IndexError: If the subject is not listed in the metadata. The
                exception type matches what the bare ``.iloc[0]`` lookup
                used to raise, but now names the missing subject.
        """
        metadata = self.metadata_loader.load_metadata()
        matches = metadata[metadata['ID'] == subject_id]
        if matches.empty:
            raise IndexError(
                f"Subject {subject_id!r} not found in {self.dataset_type} metadata"
            )
        return matches.iloc[0]

    def get_subject_chunks(self, subject_id: str) -> Dict:
        """Return metadata, transcript data and loaded chunks for one subject.

        Missing transcript / chunk directory yield an empty ``transcription``
        dict / empty ``participant_chunks`` list rather than an error.

        Raises:
            IndexError: If ``subject_id`` is not in the metadata.
        """
        subject_meta = self._find_subject(subject_id)
        group = subject_meta['group']

        # Load transcription with timing
        transcript_path = self._get_transcript_path(subject_id, group)
        transcript_data = {}
        if transcript_path and transcript_path.exists():
            transcript_data = self.transcription_processor.extract_speech_with_timing(str(transcript_path))

        # Load audio chunks
        chunks_dir = self._get_chunks_directory(group)
        participant_chunks = []
        if chunks_dir and chunks_dir.exists():
            participant_chunks = self.audio_processor.load_audio_chunks(str(chunks_dir), subject_id)

        return {
            'metadata': subject_meta.to_dict(),
            'transcription': transcript_data,
            'participant_chunks': participant_chunks,
            'chunks_dir': str(chunks_dir) if chunks_dir else None,
            'transcript_path': str(transcript_path) if transcript_path else None,
            'dataset_type': self.dataset_type
        }

    def get_aligned_pairs(self, subject_id: str) -> List[Dict]:
        """Return the aligned text/audio pairs of one subject (possibly empty).

        Raises:
            IndexError: If ``subject_id`` is not in the metadata.
        """
        group = self._find_subject(subject_id)['group']
        aligned_data = self._get_aligned_data(subject_id, group)
        return aligned_data.get('aligned_pairs', [])

In [None]:
class AudioEmbedder:
    """Turns audio chunks into fixed-size vectors with a pretrained encoder.

    Each chunk is embedded on its own; the encoder's last hidden state is
    averaged over the time axis to obtain one vector per chunk.
    """

    SUPPORTED_MODELS = ('wav2vec2', 'whisper')
    # Sampling rate announced to the processors for every chunk.
    _SAMPLING_RATE = 16000
    # Width used for the zero fallback vector when the model does not expose
    # ``config.hidden_size``.
    _DEFAULT_EMBEDDING_DIM = 768

    def __init__(self, model_type='wav2vec2', device='cuda'):
        """Load processor and model for ``model_type`` and move the model to ``device``.

        Raises:
            ValueError: If ``model_type`` is not one of ``SUPPORTED_MODELS``
                (previously this surfaced later as an ``AttributeError``
                because no model had been created).
        """
        if model_type not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {self.SUPPORTED_MODELS}"
            )

        self.model_type = model_type
        self.device = device

        if model_type == 'wav2vec2':
            self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
            self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device)
        else:  # 'whisper'
            self.processor = WhisperProcessor.from_pretrained("openai/whisper-small")
            self.model = WhisperModel.from_pretrained("openai/whisper-small").to(device)

        self.model.eval()

    def _embedding_dim(self) -> int:
        """Return the model's hidden size, or the 768 default if unavailable."""
        model_config = getattr(self.model, 'config', None)
        return getattr(model_config, 'hidden_size', None) or self._DEFAULT_EMBEDDING_DIM

    def extract_embeddings(self, audio_chunks):
        """Embed every chunk and stack the results.

        Args:
            audio_chunks: Iterable of dicts with an ``audio`` waveform.

        Returns:
            Array of shape ``(num_chunks, embedding_dim)``. A chunk that fails
            to process is reported on stdout and replaced by a zero vector of
            the same width, so the rows stay aligned with the input. An empty
            input yields an array of shape ``(0, embedding_dim)``.
        """
        embedding_dim = self._embedding_dim()
        embeddings = []

        for chunk in audio_chunks:
            audio = chunk['audio']

            try:
                if self.model_type == 'wav2vec2':
                    inputs = self.processor(audio, sampling_rate=self._SAMPLING_RATE,
                                            return_tensors="pt", padding=True)
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}

                    with torch.no_grad():
                        outputs = self.model(**inputs)
                        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
                elif self.model_type == 'whisper':
                    inputs = self.processor(audio, sampling_rate=self._SAMPLING_RATE,
                                            return_tensors="pt")
                    inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

                    with torch.no_grad():
                        # Only the encoder is needed to embed the audio.
                        outputs = self.model.encoder(**inputs)
                        embedding = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
                else:
                    raise ValueError(f"Unsupported model_type {self.model_type!r}")

                embeddings.append(embedding[0])

            except Exception as e:
                # Best effort: keep one row per chunk even when a chunk fails.
                print(f"Error processing chunk: {e}")
                embeddings.append(np.zeros(embedding_dim))

        if not embeddings:
            return np.zeros((0, embedding_dim))
        return np.array(embeddings)

class TextEmbedder:
    """Embeds texts with a pretrained BERT-family encoder (mean-pooled)."""

    def __init__(self, model_type='clinicalbert', device='cuda'):
        """Load tokenizer and model for ``model_type`` onto ``device``.

        ``'clinicalbert'`` and ``'biobert'`` select their own checkpoints;
        any other value selects ``bert-base-uncased``.
        """
        self.model_type = model_type
        self.device = device

        model_name = (
            "emilyalsentzer/Bio_ClinicalBERT" if model_type == 'clinicalbert'
            else "dmis-lab/biobert-v1.1" if model_type == 'biobert'
            else "bert-base-uncased"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model.eval()

    def _mean_pooled(self, text):
        """Return the token-averaged last hidden state for a single text."""
        encoded = self.tokenizer(text, return_tensors="pt", padding=True,
                                 truncation=True, max_length=512)
        on_device = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            hidden = self.model(**on_device).last_hidden_state
            pooled = hidden.mean(dim=1).cpu().numpy()

        return pooled[0]

    def extract_embeddings(self, texts):
        """Embed each text separately and stack the vectors into one array.

        A text that fails to process is reported on stdout and contributes a
        768-wide zero vector so the output keeps one row per input.
        """
        vectors = []

        for text in texts:
            try:
                vectors.append(self._mean_pooled(text))
            except Exception as e:
                print(f"Error processing text: {e}")
                vectors.append(np.zeros(768))

        return np.array(vectors)

In [None]:
class MultimodalFusion(nn.Module):
    """Fuses one audio and one text embedding per sample into a joint vector.

    Two strategies are available:

    * ``'concat'``: concatenate both embeddings and project linearly.
    * ``'cross_attention'``: project both to ``output_dim``, let each attend
      to the other (the same attention module and the same LayerNorm are
      used for both directions), then concatenate and project.
    """

    SUPPORTED_FUSION_TYPES = ('concat', 'cross_attention')

    def __init__(self, audio_dim=768, text_dim=768, fusion_type='concat', output_dim=512, num_heads=8):
        """Build the layers required by ``fusion_type``.

        Args:
            audio_dim: Width of the audio embedding.
            text_dim: Width of the text embedding.
            fusion_type: One of ``SUPPORTED_FUSION_TYPES``.
            output_dim: Width of the fused output.
            num_heads: Attention heads (``'cross_attention'`` only).

        Raises:
            ValueError: If ``fusion_type`` is not supported. Without this
                check no layers were created and ``forward`` silently
                returned ``None``.
        """
        super().__init__()
        if fusion_type not in self.SUPPORTED_FUSION_TYPES:
            raise ValueError(
                f"Unsupported fusion_type {fusion_type!r}; expected one of {self.SUPPORTED_FUSION_TYPES}"
            )

        self.fusion_type = fusion_type
        self.output_dim = output_dim

        if fusion_type == 'concat':
            self.fusion_layer = nn.Linear(audio_dim + text_dim, output_dim)
        else:  # 'cross_attention'
            self.audio_proj = nn.Linear(audio_dim, output_dim)
            self.text_proj = nn.Linear(text_dim, output_dim)
            self.cross_attn = nn.MultiheadAttention(output_dim, num_heads=num_heads, batch_first=True)
            self.norm = nn.LayerNorm(output_dim)
            self.fusion_layer = nn.Linear(output_dim * 2, output_dim)

    def forward(self, audio_emb, text_emb):
        """Fuse the two embeddings.

        Args:
            audio_emb: Tensor whose last dimension is ``audio_dim``; the
                cross-attention path treats it as ``(batch, audio_dim)``.
            text_emb: Tensor whose last dimension is ``text_dim``; same
                layout as ``audio_emb``.

        Returns:
            Tensor whose last dimension is ``output_dim``.

        Raises:
            ValueError: If ``self.fusion_type`` was changed to an
                unsupported value after construction.
        """
        if self.fusion_type == 'concat':
            combined = torch.cat([audio_emb, text_emb], dim=-1)
            return self.fusion_layer(combined)

        if self.fusion_type == 'cross_attention':
            # Each modality becomes a length-1 sequence so it can act as
            # query or key/value for the attention module.
            audio_proj = self.audio_proj(audio_emb).unsqueeze(1)
            text_proj = self.text_proj(text_emb).unsqueeze(1)

            audio_attended, _ = self.cross_attn(audio_proj, text_proj, text_proj)
            text_attended, _ = self.cross_attn(text_proj, audio_proj, audio_proj)

            audio_attended = self.norm(audio_attended.squeeze(1))
            text_attended = self.norm(text_attended.squeeze(1))

            combined = torch.cat([audio_attended, text_attended], dim=-1)
            return self.fusion_layer(combined)

        raise ValueError(f"Unsupported fusion_type {self.fusion_type!r}")

In [None]:
@dataclass
class GenerationConfig:
    """Settings shared by the synthetic text and speech generators.

    Field order is part of the interface (positional construction), so new
    fields belong at the end.
    """
    # --- Text generation ---------------------------------------------------
    # Checkpoint name handed to the GPT-2 tokenizer / model loaders.
    text_model_name: str = "gpt2"
    # Passed as ``max_length`` when generating text.
    max_text_length: int = 256
    # Sampling temperature used when generating text.
    text_temperature: float = 0.8

    # --- Audio generation --------------------------------------------------
    # TTS checkpoint name (not referenced by the pyttsx3-based synthesizer
    # in this part of the notebook).
    tts_model_name: str = "facebook/mms-tts-eng"
    # Rate (Hz) the synthesized audio is loaded at.
    sample_rate: int = 16000

    # --- Clinical conditioning ---------------------------------------------
    # Length of the clinical condition vector.
    condition_dim: int = 16
    num_samples_per_condition: int = 50

    # --- Training ----------------------------------------------------------
    epochs: int = 100
    batch_size: int = 16
    learning_rate: float = 2e-5

class ClinicalTextGenerator(nn.Module):
    """GPT-2 based generator of synthetic interview answers.

    Text is sampled from a random assessment prompt and then post-processed
    with rule-based disfluencies that depend on the requested severity.

    NOTE(review): the clinical condition vector is projected by
    ``condition_proj`` but the projection is not fed into generation, so the
    sampled text currently depends on the condition only through ``severity``.
    """

    # A period that is not part of an ellipsis. Using this instead of a bare
    # '.' keeps later substitutions from re-expanding the "..." sequences that
    # earlier substitutions have just inserted.
    _LONE_PERIOD = re.compile(r'(?<!\.)\.(?!\.)')

    def __init__(self, config: "GenerationConfig"):
        """Load tokenizer and language model and set which weights train.

        Args:
            config: Supplies ``text_model_name``, ``condition_dim``,
                ``max_text_length`` and ``text_temperature``.
        """
        super().__init__()
        from transformers import GPT2LMHeadModel, GPT2Tokenizer

        self.config = config
        self.tokenizer = GPT2Tokenizer.from_pretrained(config.text_model_name)
        self.model = GPT2LMHeadModel.from_pretrained(config.text_model_name)

        # GPT-2 ships without a padding token; reuse end-of-sequence.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Clinical condition embedding
        self.condition_proj = nn.Linear(config.condition_dim, self.model.config.n_embd)

        # Freeze base model initially
        for param in self.model.parameters():
            param.requires_grad = False

        # Only train condition projection and final layer.
        # NOTE(review): GPT-2 checkpoints normally tie lm_head to the input
        # token embedding, in which case this also unfreezes the embedding
        # matrix - confirm this is intended.
        for param in self.model.lm_head.parameters():
            param.requires_grad = True

    def create_ad_prompts(self):
        """Return the fixed list of assessment-style prompts."""
        return [
            "Tell me about your daily routine.",
            "Describe what you see in this picture.",
            "What did you do yesterday?",
            "Tell me about your family.",
            "Describe how to make a sandwich.",
            "What is your favorite memory?",
            "Tell me about your hometown.",
            "Describe the changing seasons.",
            "What do you like to do for fun?",
            "Tell me about a typical day."
        ]

    def add_ad_characteristics(self, text, severity='mild'):
        """Randomly inject disfluencies into ``text``.

        Args:
            text: Text to perturb.
            severity: ``'mild'`` repeats about 10% of the words and, with
                probability 0.2, turns sentence-final periods into
                ``"... um..."``. ``'moderate'`` replaces about 15% of the
                words of four or more characters with ``"thing"`` and, with
                probabilities 0.3 and 0.4, turns periods into
                ``"... what was I saying..."`` / ``"... "``. Any other value
                returns ``text`` unchanged.

        Returns:
            The perturbed text. Randomness comes from ``np.random``.
        """
        if severity == 'mild':
            # Mild: slight word-finding issues, occasional repetition
            text = re.sub(
                r'\b(\w+)\b',
                lambda m: f"{m.group(1)}... {m.group(1)}" if np.random.random() < 0.1 else m.group(1),
                text,
            )
            if np.random.random() < 0.2:
                text = self._LONE_PERIOD.sub('... um...', text)

        elif severity == 'moderate':
            # Moderate: more pauses, word substitutions, incomplete sentences
            text = re.sub(
                r'\b\w{4,}\b',
                lambda m: "thing" if np.random.random() < 0.15 else m.group(0),
                text,
            )
            if np.random.random() < 0.3:
                text = self._LONE_PERIOD.sub('... what was I saying...', text)
            if np.random.random() < 0.4:
                text = self._LONE_PERIOD.sub('... ', text)

        return text

    def generate_text(self, condition_vector, severity='mild'):
        """Sample an answer to a random prompt and apply severity effects.

        Args:
            condition_vector: 1-D tensor of length ``config.condition_dim``.
            severity: ``'mild'`` or ``'moderate'`` to post-process the text
                with :meth:`add_ad_characteristics`; anything else leaves the
                sampled text as is.

        Returns:
            The generated continuation (without the prompt), stripped.
        """
        prompt = str(np.random.choice(self.create_ad_prompts()))

        # Encode the prompt on whatever device the language model lives on,
        # so generation also works after the module has been moved to a GPU.
        device = next(self.model.parameters()).device
        inputs = self.tokenizer(prompt, return_tensors='pt', padding=True)
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        with torch.no_grad():
            # TODO: condition_emb is computed but not yet injected into the
            # generation call below.
            condition_emb = self.condition_proj(
                condition_vector.to(self.condition_proj.weight.device).unsqueeze(0)
            )

            output = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=self.config.max_text_length,
                temperature=self.config.text_temperature,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                no_repeat_ngram_size=2
            )

        # The returned sequence starts with the prompt tokens; drop exactly
        # those instead of string-replacing the prompt, which would also
        # delete any later repetition of the prompt text.
        continuation = output[0][input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(continuation, skip_special_tokens=True).strip()

        # Add AD characteristics based on severity
        if severity in ['mild', 'moderate']:
            generated_text = self.add_ad_characteristics(generated_text, severity)

        return generated_text

class ClinicalSpeechSynthesizer:
    """Synthesizes speech with pyttsx3, slowed and paused according to severity."""

    # Pause markup appended after punctuation, per severity level.
    _PUNCTUATION_BREAKS = {
        'mild': {'.': '. <break time="0.8s"/>', ',': ', <break time="0.3s"/>'},
        'moderate': {'.': '. <break time="1.2s"/>', ',': ', <break time="0.5s"/>'},
    }
    # Replacement for every whitespace run at 'moderate' severity.
    _WORD_GAP_BREAK = ' <break time="0.2s"/>'

    def __init__(self, config: "GenerationConfig"):
        """Initialise the pyttsx3 engine and select its first available voice.

        Args:
            config: Supplies ``sample_rate`` for loading the rendered audio.
        """
        self.config = config

        # Use a simpler TTS approach that's more controllable
        import pyttsx3
        self.tts_engine = pyttsx3.init()

        # Set voice properties for clinical realism
        voices = self.tts_engine.getProperty('voices')
        if voices:
            self.tts_engine.setProperty('voice', voices[0].id)

    @staticmethod
    def _severity_from_mmse(mmse: float) -> str:
        """Map a normalised MMSE score to 'healthy' (> 0.8), 'mild' (> 0.6) or 'moderate'."""
        if mmse > 0.8:
            return 'healthy'
        if mmse > 0.6:
            return 'mild'
        return 'moderate'

    def modify_speech_for_ad(self, text, severity='mild', age_group='elderly'):
        """Insert pause markup into ``text`` and pick a speech rate.

        The markup is inserted in a single pass over the original text, so
        already inserted ``<break .../>`` tags are never rewritten (their
        spaces and decimal points would otherwise match again).

        Args:
            text: Text to be spoken.
            severity: ``'mild'``, ``'moderate'`` or anything else (treated
                as healthy: text untouched).
            age_group: ``'elderly'`` lowers the rate by 20, but not below 100.

        Returns:
            ``(marked_up_text, speech_rate)``.

        NOTE(review): whether the TTS backend interprets ``<break>`` tags or
        reads them aloud depends on the installed engine - confirm.
        """
        if severity == 'mild':
            speech_rate = 140  # slightly slower
            breaks = self._PUNCTUATION_BREAKS['mild']
            text = re.sub(r'[.,]', lambda m: breaks[m.group(0)], text)

        elif severity == 'moderate':
            speech_rate = 110  # noticeably slower
            breaks = self._PUNCTUATION_BREAKS['moderate']
            # Whitespace runs are not keys of ``breaks`` and therefore map to
            # the short word-gap pause (more frequent pauses).
            text = re.sub(
                r'[.,]|\s+',
                lambda m: breaks.get(m.group(0), self._WORD_GAP_BREAK),
                text,
            )

        else:  # healthy
            speech_rate = 160

        # Adjust for age
        if age_group == 'elderly':
            speech_rate = max(speech_rate - 20, 100)

        return text, speech_rate

    def text_to_audio(self, text, condition_vector):
        """Render ``text`` to a waveform shaped by the clinical condition.

        Args:
            text: Text to speak.
            condition_vector: Indexable of tensors-like scalars; element 0 is
                read as normalised age and element 2 (when present) as
                normalised MMSE, with 0.8 assumed otherwise.

        Returns:
            The audio samples loaded at ``config.sample_rate``.
        """
        # Extract clinical info from condition vector
        age = condition_vector[0].item()
        mmse = condition_vector[2].item() if len(condition_vector) > 2 else 0.8

        severity = self._severity_from_mmse(mmse)
        age_group = 'elderly' if age > 0.7 else 'adult'

        # Modify text for AD characteristics
        modified_text, speech_rate = self.modify_speech_for_ad(text, severity, age_group)

        # Set speech properties
        self.tts_engine.setProperty('rate', speech_rate)

        # Reserve a temporary file name, closing our own handle first so the
        # TTS engine can write to the path on every platform.
        import tempfile
        tmp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        wav_path = tmp_file.name
        tmp_file.close()

        try:
            self.tts_engine.save_to_file(modified_text, wav_path)
            self.tts_engine.runAndWait()

            # Load and return audio
            audio, sr = librosa.load(wav_path, sr=self.config.sample_rate)
        finally:
            # delete=False means nobody else removes the file for us.
            try:
                os.remove(wav_path)
            except OSError:
                pass

        return audio

class AlignedDataGenerator:
    """Builds synthetic text/audio pairs that share one clinical condition vector.

    A condition vector is sampled first; its MMSE entry decides the severity
    passed to the text generator, and the same vector is handed to the speech
    synthesizer so both modalities describe the same simulated speaker.
    """

    def __init__(self, config: GenerationConfig, device='cuda'):
        self.config = config
        self.device = device

        self.text_generator = ClinicalTextGenerator(config)
        self.speech_synthesizer = ClinicalSpeechSynthesizer(config)

    @staticmethod
    def _severity_from_mmse(mmse):
        """Map a normalized MMSE value onto 'healthy' / 'mild' / 'moderate'."""
        if mmse > 0.8:
            return 'healthy'
        if mmse > 0.6:
            return 'mild'
        return 'moderate'

    def create_clinical_condition(self, label, age_group='mixed'):
        """Sample a condition vector of length ``config.condition_dim``.

        Layout: [0] age, [1] gender (0 or 1), [2] MMSE, [3] label,
        [4:8] small Gaussian noise, [8:] slightly larger Gaussian noise.
        Age and MMSE are drawn uniformly from ranges chosen by ``age_group``
        and ``label`` respectively.
        """
        dim = self.config.condition_dim
        condition = torch.zeros(dim)

        # Age range per group; any unlisted group uses the wide default range.
        age_bounds = {'young': (0.2, 0.4), 'elderly': (0.7, 0.95)}
        age_low, age_high = age_bounds.get(age_group, (0.4, 0.9))
        condition[0] = np.random.uniform(age_low, age_high)

        # Gender flag
        condition[1] = np.random.choice([0, 1])

        # Label 0 draws from the high MMSE band, anything else from the low one.
        mmse_low, mmse_high = (0.8, 1.0) if label == 0 else (0.1, 0.7)
        condition[2] = np.random.uniform(mmse_low, mmse_high)

        condition[3] = label
        condition[4:8] = 0.1 * torch.randn(4)
        condition[8:] = 0.2 * torch.randn(dim - 8)

        return condition

    def generate_aligned_sample(self, label):
        """Produce one dict holding text, audio, label, condition, MMSE and severity."""
        condition = self.create_clinical_condition(label)

        mmse = condition[2].item()
        severity = self._severity_from_mmse(mmse)

        # Text first, then audio rendered from that exact text.
        generated_text = self.text_generator.generate_text(condition, severity)
        audio = self.speech_synthesizer.text_to_audio(generated_text, condition)

        return {
            'text': generated_text,
            'audio': audio,
            'label': label,
            'condition': condition.numpy(),
            'mmse': mmse,
            'severity': severity
        }

    def generate_balanced_dataset(self, samples_per_class=100):
        """Attempt ``samples_per_class`` samples for label 0, then for label 1.

        A sample whose generation raises is reported on stdout and dropped,
        so the returned list can be shorter than requested.
        """
        total = samples_per_class * 2
        print(f"Generating {total} aligned text-audio pairs...")

        collected = []
        with tqdm(total=total, desc="Generating samples") as pbar:
            # All label-0 attempts come before the label-1 attempts.
            schedule = [lbl for lbl in (0, 1) for _ in range(samples_per_class)]
            for lbl in schedule:
                try:
                    collected.append(self.generate_aligned_sample(lbl))
                    pbar.update(1)
                except Exception as e:
                    print(f"Error generating sample: {e}")

        return collected

In [None]:
class SyntheticDatasetProcessor:
    """Turns synthetic text/audio samples into fused embeddings for training.

    Args:
        audio_embedder: Object exposing ``extract_embeddings(chunks)``.
        text_embedder: Object exposing ``extract_embeddings(texts)``.
        fusion_model: Callable taking ``(audio_batch, text_batch)`` tensors
            and returning the fused batch.
    """

    def __init__(self, audio_embedder, text_embedder, fusion_model):
        self.audio_embedder = audio_embedder
        self.text_embedder = text_embedder
        self.fusion_model = fusion_model

    def _fusion_device(self):
        """Return the device holding the fusion model's parameters (CPU if it has none)."""
        try:
            return next(self.fusion_model.parameters()).device
        except (AttributeError, StopIteration, TypeError):
            return torch.device('cpu')

    def process_synthetic_data(self, synthetic_samples):
        """Convert synthetic text-audio pairs to fused embeddings.

        Args:
            synthetic_samples: Iterable of dicts with 'audio', 'text' and
                'label' keys.

        Returns:
            ``(embeddings, labels)`` as NumPy arrays. Samples that raise
            during processing are reported on stdout and left out, so both
            arrays may be shorter than the input (or empty).
        """
        processed_data = []
        labels = []

        print("Processing synthetic samples for training...")

        # Inputs must live on the same device as the fusion model; feeding
        # CPU tensors to a CUDA model raises for every sample, and the
        # broad except below would then silently drop the whole dataset.
        device = self._fusion_device()

        for sample in tqdm(synthetic_samples, desc="Processing"):
            try:
                # Convert audio to chunks format expected by embedder
                audio_chunks = [{'audio': sample['audio']}]

                # Extract embeddings
                audio_emb = self.audio_embedder.extract_embeddings(audio_chunks).mean(axis=0)
                text_emb = self.text_embedder.extract_embeddings([sample['text']])[0]

                # Create fused embedding
                audio_tensor = torch.tensor(audio_emb, dtype=torch.float32).to(device)
                text_tensor = torch.tensor(text_emb, dtype=torch.float32).to(device)

                with torch.no_grad():
                    fused_emb = self.fusion_model(audio_tensor.unsqueeze(0), text_tensor.unsqueeze(0))

                # .cpu() is required before .numpy() when the model is on GPU
                processed_data.append(fused_emb.squeeze(0).detach().cpu().numpy())
                labels.append(sample['label'])

            except Exception as e:
                print(f"Error processing sample: {e}")
                continue

        return np.array(processed_data), np.array(labels)

def evaluate_synthetic_quality(original_data, synthetic_data, labels):
    """Score how closely synthetic embeddings follow the real ones.

    Returns a dict with three entries, in this order:
      * 'frechet_distance'   - Fréchet distance between Gaussians fitted to
                               the two sets (rows are samples, columns features)
      * 'label_distribution' - mapping of each distinct label to its count
      * 'ks_statistic'       - mean two-sample KS statistic over the feature
                               columns both sets have in common
    """
    # First and second moments of each set
    real_mean = np.mean(original_data, axis=0)
    synth_mean = np.mean(synthetic_data, axis=0)
    real_cov = np.cov(original_data, rowvar=False)
    synth_cov = np.cov(synthetic_data, rowvar=False)

    mean_gap = real_mean - synth_mean
    cov_sqrt = scipy.linalg.sqrtm(real_cov @ synth_cov)
    # The matrix square root may come back complex; keep only the real part
    if np.iscomplexobj(cov_sqrt):
        cov_sqrt = np.real(cov_sqrt)

    frechet = np.dot(mean_gap, mean_gap) + np.trace(real_cov) + np.trace(synth_cov) - 2 * np.trace(cov_sqrt)

    # How many samples carry each label
    label_counts = dict(zip(*np.unique(labels, return_counts=True)))

    # Per-feature KS statistic, averaged
    shared_dims = min(original_data.shape[1], synthetic_data.shape[1])
    per_dim_ks = [
        ks_2samp(original_data[:, col], synthetic_data[:, col]).statistic
        for col in range(shared_dims)
    ]

    return {
        'frechet_distance': frechet,
        'label_distribution': label_counts,
        'ks_statistic': np.mean(per_dim_ks),
    }

class ExperimentPipeline:
    """Runs the text→audio augmentation experiment on fused embeddings.

    The pipeline embeds the real training subjects, trains a baseline SVM,
    generates synthetic text/audio pairs, retrains with the synthetic
    embeddings added, and writes test-set predictions to a text file.

    Args:
        train_loader: Loader exposing ``load_dataset(align_modalities=True)``
            and ``get_subject_chunks(subject_id)`` for the training set.
        test_loader: Same interface for the test set.
        device: Device on which the fusion model and input tensors are placed.
    """

    # Width of the fused embedding requested from MultimodalFusion; also the
    # size of the zero vector substituted for test subjects that fail.
    FUSED_DIM = 512
    # Width of the zero text embedding used when a test subject has no text.
    TEXT_EMB_DIM = 768

    def __init__(self, train_loader, test_loader, device='cuda'):
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device

    def run_generation_experiment(self, audio_model='wav2vec2', text_model='clinicalbert',
                                fusion_type='cross_attention', samples_per_class=100):
        """Run complete text→audio generation experiment.

        Returns:
            A flat dict with the configuration, baseline and augmented
            accuracy / F1 / CV summaries, the accuracy improvement, the number
            of synthetic samples and the synthetic-quality metrics.
        """

        config_name = f"{audio_model}_{text_model}_{fusion_type}_generation"
        print(f"\nRunning generation experiment: {config_name}")

        # Initialize models
        audio_embedder = AudioEmbedder(audio_model, self.device)
        text_embedder = TextEmbedder(text_model, self.device)
        fusion_model = MultimodalFusion(
            fusion_type=fusion_type,
            output_dim=self.FUSED_DIM
        ).to(self.device)

        # Create baseline dataset
        train_dataset = self._create_baseline_dataset(audio_embedder, text_embedder, fusion_model)

        # Baseline performance
        baseline_results = self._evaluate_baseline(train_dataset)

        # Generate synthetic data
        generation_config = GenerationConfig(num_samples_per_condition=samples_per_class)
        data_generator = AlignedDataGenerator(generation_config, self.device)
        synthetic_samples = data_generator.generate_balanced_dataset(samples_per_class)

        # Process synthetic data
        processor = SyntheticDatasetProcessor(audio_embedder, text_embedder, fusion_model)
        synthetic_embeddings, synthetic_labels = processor.process_synthetic_data(synthetic_samples)

        # Evaluate synthetic data quality (undefined when nothing was embedded)
        if len(synthetic_embeddings) > 0:
            quality_metrics = evaluate_synthetic_quality(
                train_dataset.data, synthetic_embeddings, synthetic_labels
            )
        else:
            print("Warning: no synthetic sample could be embedded; quality metrics are undefined.")
            quality_metrics = {
                'frechet_distance': float('nan'),
                'label_distribution': {},
                'ks_statistic': float('nan'),
            }

        # Train with augmented data
        augmented_results = self._evaluate_with_augmentation(
            train_dataset, synthetic_embeddings, synthetic_labels, baseline_results['split_data']
        )

        # Generate test predictions (return value is not needed here; the
        # method writes the prediction file as a side effect)
        self._generate_test_predictions(
            audio_embedder, text_embedder, fusion_model,
            augmented_results['model'], augmented_results['scaler'], config_name
        )

        # Compile results
        improvement = augmented_results['accuracy'] - baseline_results['accuracy']

        result = {
            'config': config_name,
            'audio_model': audio_model,
            'text_model': text_model,
            'fusion_type': fusion_type,
            'baseline_accuracy': baseline_results['accuracy'],
            'baseline_f1': baseline_results['f1_score'],
            'baseline_cv': f"{baseline_results['cv_mean']:.3f}±{baseline_results['cv_std']:.3f}",
            'augmented_accuracy': augmented_results['accuracy'],
            'augmented_f1': augmented_results['f1_score'],
            'augmented_cv': f"{augmented_results['cv_mean']:.3f}±{augmented_results['cv_std']:.3f}",
            'improvement': improvement,
            'num_synthetic_samples': len(synthetic_samples),
            **quality_metrics
        }

        print(f"Results for {config_name}:")
        print(f"  Generated {len(synthetic_samples)} aligned text-audio pairs")
        print(f"  Baseline Acc: {baseline_results['accuracy']:.3f}")
        print(f"  Augmented Acc: {augmented_results['accuracy']:.3f}")
        print(f"  Improvement: {improvement:+.3f}")
        print(f"  Synthetic Quality (FID): {quality_metrics['frechet_distance']:.3f}")

        return result

    def _fuse(self, fusion_model, audio_emb, text_emb):
        """Run one (audio, text) embedding pair through the fusion model; returns a NumPy vector."""
        audio_tensor = torch.tensor(audio_emb, dtype=torch.float32).to(self.device)
        text_tensor = torch.tensor(text_emb, dtype=torch.float32).to(self.device)

        with torch.no_grad():
            fused_emb = fusion_model(audio_tensor.unsqueeze(0), text_tensor.unsqueeze(0))

        return fused_emb.squeeze(0).cpu().numpy()

    def _cross_validate(self, X, y, extra_X=None, extra_y=None, n_splits=3):
        """Stratified k-fold accuracy of the RBF SVM on ``(X, y)``.

        ``extra_X`` / ``extra_y`` (when non-empty) are appended to the
        training part of every fold but never to its validation part.
        Each fold fits its own scaler so that no scaler owned by the caller
        is refitted as a side effect.

        Returns:
            ``(mean_accuracy, std_accuracy)`` over the folds.
        """
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        use_extra = extra_X is not None and len(extra_X) > 0
        cv_scores = []

        for train_idx, val_idx in skf.split(X, y):
            X_fit, y_fit = X[train_idx], y[train_idx]
            if use_extra:
                X_fit = np.concatenate([X_fit, extra_X])
                y_fit = np.concatenate([y_fit, extra_y])

            fold_scaler = StandardScaler()
            X_fit_scaled = fold_scaler.fit_transform(X_fit)
            X_val_scaled = fold_scaler.transform(X[val_idx])

            svm_cv = SVC(C=2.0, kernel='rbf', random_state=42)
            svm_cv.fit(X_fit_scaled, y_fit)
            cv_scores.append(accuracy_score(y[val_idx], svm_cv.predict(X_val_scaled)))

        return np.mean(cv_scores), np.std(cv_scores)

    def _create_baseline_dataset(self, audio_embedder, text_embedder, fusion_model):
        """Create baseline dataset from real data.

        Subjects without participant audio chunks or text, and subjects whose
        processing raises, are left out; the number of failures is reported.

        Returns:
            An object with ``data`` (fused embeddings) and ``labels`` arrays.
        """
        embeddings = []
        labels = []
        failed = 0

        dataset = self.train_loader.load_dataset(align_modalities=True)

        for _, row in tqdm(dataset.iterrows(), desc="Creating baseline dataset", total=len(dataset), leave=False):
            try:
                subject_chunks = self.train_loader.get_subject_chunks(row['subject_id'])
                audio_chunks = subject_chunks['participant_chunks']
                text = row['participant_text']

                if audio_chunks and text and len(text.strip()) > 0:
                    audio_emb = audio_embedder.extract_embeddings(audio_chunks).mean(axis=0)
                    text_emb = text_embedder.extract_embeddings([text])[0]

                    embeddings.append(self._fuse(fusion_model, audio_emb, text_emb))
                    labels.append(int(row['label']))

            except Exception:
                # Best effort: one bad subject must not abort the run, but
                # keep a count so silent data loss is visible.
                failed += 1
                continue

        if failed:
            print(f"Warning: {failed} training subject(s) skipped due to processing errors")

        class SimpleDataset:
            def __init__(self, data, labels):
                self.data = np.array(data)
                self.labels = np.array(labels)

        return SimpleDataset(embeddings, labels)

    def _evaluate_baseline(self, dataset):
        """Evaluate baseline performance.

        Holds out 20% of the data (stratified), trains an RBF SVM on the
        rest and adds a 3-fold CV estimate computed on the training part.

        Returns:
            Dict with 'accuracy', 'f1_score', 'cv_mean', 'cv_std',
            'split_data' (X_train, X_val, y_train, y_val), and the 'scaler'
            and 'model' fitted on the training part.
        """
        X_train, X_val, y_train, y_val = train_test_split(
            dataset.data, dataset.labels, test_size=0.2,
            stratify=dataset.labels, random_state=42
        )

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)

        svm_clf = SVC(C=2.0, kernel='rbf', probability=True, random_state=42)
        svm_clf.fit(X_train_scaled, y_train)

        y_pred = svm_clf.predict(X_val_scaled)
        accuracy = accuracy_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred)

        # Cross-validation uses per-fold scalers, so `scaler` stays paired
        # with `svm_clf` when both are returned.
        cv_mean, cv_std = self._cross_validate(X_train, y_train)

        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'cv_mean': cv_mean,
            'cv_std': cv_std,
            'split_data': (X_train, X_val, y_train, y_val),
            'scaler': scaler,
            'model': svm_clf
        }

    def _evaluate_with_augmentation(self, dataset, synthetic_embeddings, synthetic_labels, baseline_split):
        """Evaluate with synthetic data augmentation.

        Args:
            dataset: Unused; kept for interface compatibility.
            synthetic_embeddings: Fused synthetic embeddings (may be empty).
            synthetic_labels: Labels matching ``synthetic_embeddings``.
            baseline_split: ``(X_train, X_val, y_train, y_val)`` from the
                baseline evaluation, so both runs share one validation set.

        Returns:
            Dict with 'accuracy', 'f1_score', 'cv_mean', 'cv_std', and the
            'model' and 'scaler' fitted on the full augmented training set.
        """
        X_train, X_val, y_train, y_val = baseline_split

        synthetic_embeddings = np.asarray(synthetic_embeddings)
        synthetic_labels = np.asarray(synthetic_labels)
        has_synthetic = len(synthetic_embeddings) > 0

        # Augment training data
        if has_synthetic:
            X_aug = np.concatenate([X_train, synthetic_embeddings])
            y_aug = np.concatenate([y_train, synthetic_labels])
        else:
            # An empty array has shape (0,) and cannot be concatenated with
            # the 2-D training matrix; train on the real data alone.
            print("Warning: no synthetic embeddings available; training without augmentation.")
            X_aug, y_aug = X_train, y_train

        scaler = StandardScaler()
        X_aug_scaled = scaler.fit_transform(X_aug)
        X_val_scaled = scaler.transform(X_val)

        svm_aug = SVC(C=2.0, kernel='rbf', probability=True, random_state=42)
        svm_aug.fit(X_aug_scaled, y_aug)

        y_pred = svm_aug.predict(X_val_scaled)
        accuracy = accuracy_score(y_val, y_pred)
        f1 = f1_score(y_val, y_pred)

        # Cross-validation: cap the synthetic share at half the real
        # training size so folds are not dominated by synthetic samples.
        syn_subset = None
        syn_labels_subset = None
        if has_synthetic:
            syn_subset_size = min(len(synthetic_embeddings), len(X_train) // 2)
            syn_indices = np.random.choice(len(synthetic_embeddings), syn_subset_size, replace=False)
            syn_subset = synthetic_embeddings[syn_indices]
            syn_labels_subset = synthetic_labels[syn_indices]

        # Per-fold scalers keep `scaler` fitted on X_aug, which is the state
        # `svm_aug` was trained against and what test prediction relies on.
        cv_mean, cv_std = self._cross_validate(X_train, y_train, syn_subset, syn_labels_subset)

        return {
            'accuracy': accuracy,
            'f1_score': f1,
            'cv_mean': cv_mean,
            'cv_std': cv_std,
            'model': svm_aug,
            'scaler': scaler
        }

    def _generate_test_predictions(self, audio_embedder, text_embedder, fusion_model,
                                 trained_svm, scaler, config_name):
        """Generate test predictions and write them to a text file.

        Every test row yields exactly one embedding: subjects without audio
        or whose processing raises get an all-zero vector, and subjects
        without text get a zero text embedding.

        Returns:
            Dict with 'predictions' and the output 'filename'.
        """
        print(f"Generating test predictions for {config_name}...")

        # Create test embeddings
        test_embeddings = []
        test_dataset = self.test_loader.load_dataset(align_modalities=True)

        for _, row in tqdm(test_dataset.iterrows(), desc="Processing test data", total=len(test_dataset), leave=False):
            try:
                subject_chunks = self.test_loader.get_subject_chunks(row['subject_id'])
                audio_chunks = subject_chunks['participant_chunks']
                text = row['participant_text'] if pd.notna(row.get('participant_text')) else ""

                if audio_chunks and len(audio_chunks) > 0:
                    audio_emb = audio_embedder.extract_embeddings(audio_chunks).mean(axis=0)
                    if text and len(text.strip()) > 0:
                        text_emb = text_embedder.extract_embeddings([text])[0]
                    else:
                        text_emb = np.zeros(self.TEXT_EMB_DIM)  # Default if no text

                    test_embeddings.append(self._fuse(fusion_model, audio_emb, text_emb))
                else:
                    # Fallback for missing audio
                    test_embeddings.append(np.zeros(self.FUSED_DIM))

            except Exception:
                test_embeddings.append(np.zeros(self.FUSED_DIM))

        # Generate predictions
        X_test = np.array(test_embeddings)
        X_test_scaled = scaler.transform(X_test)
        test_preds = trained_svm.predict(X_test_scaled)

        # Save in ADReSS format
        # NOTE(review): IDs are synthesized from the row position (S160, S161, ...)
        # rather than taken from row['subject_id']; this presumes the loader
        # returns test subjects in exactly that order. Confirm against the loader.
        filename = f'test_results_generation_{config_name}.txt'
        with open(filename, 'w') as f:
            f.write('ID   ; Prediction\n')
            for i, pred in enumerate(test_preds):
                subject_id = f'S{i+160:03d}'
                f.write(f'{subject_id} ; {pred}\n')

        print(f"Test predictions saved to {filename}")

        return {
            'predictions': test_preds,
            'filename': filename
        }

In [None]:
def plot_generation_results(result, synthetic_samples, sample_rate=16000):
    """Create a 2x3 summary figure for one text→audio generation experiment.

    Panels: baseline vs. augmented scores, generated text length per
    severity, MMSE histogram per label, audio duration histogram, synthetic
    quality metrics, and the accuracy improvement.

    Args:
        result: Result dict holding at least 'config', 'baseline_accuracy',
            'augmented_accuracy', 'baseline_f1', 'augmented_f1' and
            'improvement'; 'frechet_distance' / 'ks_statistic' are plotted
            when present.
        synthetic_samples: Sample dicts with 'text', 'severity', 'mmse',
            'label' and 'audio' keys.
        sample_rate: Samples per second used to convert audio length to
            seconds.

    Returns:
        The Matplotlib figure. The figure is also saved as
        ``generation_results_<config>.png`` and shown.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle(f'Text→Audio Generation Results: {result["config"]}', fontsize=16)

    # Performance comparison
    metrics = ['baseline_accuracy', 'augmented_accuracy', 'baseline_f1', 'augmented_f1']
    values = [result[m] for m in metrics]
    colors = ['lightblue', 'darkblue', 'lightgreen', 'darkgreen']

    bars = axes[0, 0].bar(range(len(metrics)), values, color=colors, alpha=0.7)
    axes[0, 0].set_title('Performance Comparison')
    axes[0, 0].set_xticks(range(len(metrics)))
    axes[0, 0].set_xticklabels([m.replace('_', ' ').title() for m in metrics], rotation=45)
    axes[0, 0].set_ylabel('Score')

    for bar, value in zip(bars, values):
        axes[0, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                       f'{value:.3f}', ha='center', va='bottom', fontsize=10)

    # Sample text lengths by severity
    severities = ['healthy', 'mild', 'moderate']
    text_lengths = {
        severity: [len(s['text'].split()) for s in synthetic_samples if s['severity'] == severity]
        for severity in severities
    }

    # Label the boxes through the axis instead of boxplot(labels=...), whose
    # keyword was renamed in recent Matplotlib releases.
    axes[0, 1].boxplot([text_lengths[s] for s in severities])
    axes[0, 1].set_xticklabels(severities)
    axes[0, 1].set_title('Generated Text Lengths by Severity')
    axes[0, 1].set_ylabel('Words per Sample')

    # MMSE distribution
    mmse_healthy = [s['mmse'] for s in synthetic_samples if s['label'] == 0]
    mmse_ad = [s['mmse'] for s in synthetic_samples if s['label'] == 1]

    axes[0, 2].hist(mmse_healthy, alpha=0.7, label='Healthy', bins=15, color='green')
    axes[0, 2].hist(mmse_ad, alpha=0.7, label='AD', bins=15, color='red')
    axes[0, 2].set_title('MMSE Distribution in Generated Samples')
    axes[0, 2].set_xlabel('MMSE Score')
    axes[0, 2].set_ylabel('Count')
    axes[0, 2].legend()

    # Audio duration distribution
    audio_durations = [len(s['audio']) / sample_rate for s in synthetic_samples]  # Convert to seconds
    axes[1, 0].hist(audio_durations, bins=20, alpha=0.7, color='purple')
    axes[1, 0].set_title('Generated Audio Durations')
    axes[1, 0].set_xlabel('Duration (seconds)')
    axes[1, 0].set_ylabel('Count')

    # Quality metrics: keep keys and values in lock-step so a missing metric
    # cannot cause a bar-count / value-count mismatch.
    quality_keys = [k for k in ('frechet_distance', 'ks_statistic') if k in result]
    if quality_keys:
        quality_values = [result[k] for k in quality_keys]
        axes[1, 1].bar(range(len(quality_keys)), quality_values, alpha=0.7, color='orange')
        axes[1, 1].set_title('Synthetic Data Quality')
        axes[1, 1].set_xticks(range(len(quality_keys)))
        axes[1, 1].set_xticklabels([k.replace('_', ' ').title() for k in quality_keys])
        axes[1, 1].set_ylabel('Score (lower is better)')

    # Improvement visualization
    improvement = result['improvement']
    color = 'green' if improvement > 0 else 'red'
    axes[1, 2].bar(['Improvement'], [improvement], color=color, alpha=0.7)
    axes[1, 2].set_title('Performance Improvement')
    axes[1, 2].set_ylabel('Accuracy Change')
    axes[1, 2].axhline(y=0, color='black', linestyle='-', alpha=0.5)
    # Put the value just beyond the bar tip: above for gains, below otherwise
    axes[1, 2].text(0, improvement + 0.01 if improvement > 0 else improvement - 0.01,
                   f'{improvement:+.3f}', ha='center', va='bottom' if improvement > 0 else 'top',
                   fontweight='bold', fontsize=12)

    plt.tight_layout()
    plt.savefig(f'generation_results_{result["config"]}.png', dpi=150, bbox_inches='tight')
    plt.show()

    return fig

def display_sample_generations(synthetic_samples, num_samples=3):
    """Print a readable summary of the first ``num_samples`` generated pairs.

    For each sample the label, MMSE, severity, text, audio duration
    (length / 16000, in seconds) and word count are written to stdout.
    """
    banner = "=" * 60
    separator = "-" * 50

    report = ["\n" + banner, "SAMPLE GENERATED TEXT-AUDIO PAIRS", banner]

    for number, entry in enumerate(synthetic_samples[:num_samples], start=1):
        diagnosis = 'AD' if entry['label'] == 1 else 'Healthy'
        seconds = len(entry['audio']) / 16000
        word_count = len(entry['text'].split())

        report.extend([
            f"\nSample {number}:",
            f"Label: {diagnosis}",
            f"MMSE: {entry['mmse']:.2f}",
            f"Severity: {entry['severity']}",
            f"Text: \"{entry['text']}\"",
            f"Audio duration: {seconds:.2f} seconds",
            f"Text length: {word_count} words",
            separator,
        ])

    # One print per line keeps the emitted stream identical to line-by-line output
    for line in report:
        print(line)

In [None]:
# Main execution for text→audio generation pipeline.
# Runs one ClinicalBERT + Whisper + cross-attention experiment, plots the
# outcome, and writes a one-row CSV summary plus a PNG of the comparison.
if __name__ == "__main__":
    import traceback  # only this entry-point cell needs it

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Set seeds
    torch.manual_seed(42)
    np.random.seed(42)

    print("="*60)
    print("TEXT→AUDIO GENERATION PIPELINE FOR AD DETECTION")
    print("="*60)

    # Load datasets
    print("Loading ADReSS datasets...")

    # Paths (adjust for your environment)
    train_path = "ADReSS-IS2020-train/ADReSS-IS2020-data/train"  # Local
    test_path = "ADReSS-IS2020-test/ADReSS-IS2020-data/test"
    # train_path = "/kaggle/input/adress/adress/train"  # Kaggle
    # test_path = "/kaggle/input/adress/adress/test"

    train_config = ADReSSConfig(train_path=train_path)
    # NOTE(review): the test path is passed through the `train_path` field and
    # `test_path` keeps its default. The loader below receives the path
    # explicitly, so this may be harmless — confirm how ADReSSDataLoader
    # reads the config before changing it.
    test_config = ADReSSConfig(train_path=test_path)

    train_loader = ADReSSDataLoader(train_path, train_config)
    test_loader = ADReSSDataLoader(test_path, test_config)

    # Initialize pipeline
    pipeline = ExperimentPipeline(train_loader, test_loader, device)

    # Single configuration: ClinicalBERT + Whisper + Cross Attention
    config = {'audio_model': 'whisper', 'text_model': 'clinicalbert', 'fusion_type': 'cross_attention'}

    print("Running ClinicalBERT + Whisper + Cross Attention generation experiment...")
    print(f"Configuration: {config}")

    # Bound up front so the summary section never depends on which branch
    # of the try/except ran.
    results = []

    try:
        # Run experiment
        result = pipeline.run_generation_experiment(
            audio_model=config['audio_model'],
            text_model=config['text_model'],
            fusion_type=config['fusion_type'],
            samples_per_class=100  # Generate more samples for single config
        )

        # Generate a small, separate batch of samples for visualization.
        # These are NOT the samples the experiment trained on.
        generation_config = GenerationConfig(num_samples_per_condition=10)
        data_generator = AlignedDataGenerator(generation_config, device)
        sample_synthetic = data_generator.generate_balanced_dataset(10)

        # Store results
        results = [result]
        all_synthetic_samples = {result['config']: sample_synthetic}

        # Create plots
        plot_generation_results(result, sample_synthetic)

        # Display sample generations
        display_sample_generations(sample_synthetic)

    except Exception as e:
        print(f"Error in experiment: {e}")
        # Keep the full traceback; the one-line message alone rarely
        # identifies which stage failed.
        traceback.print_exc()
        results = []

    print("\n" + "="*60)
    print("EXPERIMENT SUMMARY")
    print("="*60)

    if results:
        # Single result
        result = results[0]

        # Save result
        results_df = pd.DataFrame([result])
        results_df.to_csv('text_audio_generation_results.csv', index=False)

        print("\nExperiment completed successfully!")
        print("Results saved to text_audio_generation_results.csv")

        # Display summary
        print("\nRESULT SUMMARY:")
        print("-" * 40)
        print(f"Configuration: {result['config']}")
        print(f"  Baseline: {result['baseline_accuracy']:.3f}")
        print(f"  Augmented: {result['augmented_accuracy']:.3f}")
        print(f"  Improvement: {result['improvement']:+.3f}")
        print(f"  Synthetic samples: {result['num_synthetic_samples']}")
        print(f"  Fréchet Distance: {result['frechet_distance']:.3f}")

        # Create simple summary plot
        fig, (ax_perf, ax_gain) = plt.subplots(1, 2, figsize=(10, 6))

        metrics = ['Baseline', 'Augmented']
        values = [result['baseline_accuracy'], result['augmented_accuracy']]
        colors = ['lightblue', 'darkblue']
        bars = ax_perf.bar(metrics, values, color=colors, alpha=0.7)
        ax_perf.set_title('Performance Comparison')
        ax_perf.set_ylabel('Accuracy')
        for bar, value in zip(bars, values):
            ax_perf.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                         f'{value:.3f}', ha='center', va='bottom', fontsize=12)

        improvement = result['improvement']
        color = 'green' if improvement > 0 else 'red'
        ax_gain.bar(['Improvement'], [improvement], color=color, alpha=0.7)
        ax_gain.set_title('Performance Improvement')
        ax_gain.set_ylabel('Accuracy Change')
        ax_gain.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        # Value label sits just past the bar tip: above for gains, below otherwise
        ax_gain.text(0, improvement + 0.01 if improvement > 0 else improvement - 0.01,
                     f'{improvement:+.3f}', ha='center', va='bottom' if improvement > 0 else 'top',
                     fontweight='bold', fontsize=12)

        fig.tight_layout()
        fig.savefig('whisper_clinicalbert_generation_results.png', dpi=150, bbox_inches='tight')
        plt.show()

        print("\nGenerated plots and CSV file contain detailed results.")

    else:
        print("Experiment failed.")
        print("Check error messages above for debugging.")