# TREA Dataset Exploration

This notebook explores the TREA (Temporal Reasoning Evaluation of Audio) dataset.

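## Contents
1. Load and inspect the dataset
2. Analyze the task distribution
3. Examine sample questions
4. Check the correct-answer distribution
5. Analyze audio file durations
6. Visualize and listen to a sample audio clip
7. Detect sound events
8. Analyze question patterns per task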

In [None]:
import sys
from collections import Counter
from pathlib import Path

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import Audio, display

# Make the project root importable so the local `src` package resolves.
sys.path.append('..')

from src.data_loader import load_trea_dataset
from src.utils import AudioProcessor

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
%matplotlib inline

## 1. Load Dataset

In [None]:
# Load a seeded subset of 30 samples for each temporal-reasoning task.
dataset = load_trea_dataset(
    data_dir='../TREA_dataset',
    tasks=['count', 'order', 'duration'],
    samples_per_task=30,
    random_seed=42,
)

print(f"Total samples: {len(dataset)}")
print("\nDataset statistics:")

# Per-task sample counts as reported by the loader's statistics.
stats = dataset.get_statistics()
for task_name, task_info in stats['tasks'].items():
    print(f"  {task_name}: {task_info['count']} samples")

## 2. Task Distribution

In [None]:
# Number of loaded samples per task, as a bar chart.
task_labels = [record['task'] for record in dataset.data]
task_counts = pd.Series(task_labels).value_counts()

fig, ax = plt.subplots(figsize=(10, 6))
# rot=0 keeps the task names horizontal (pandas rotates bar labels by default).
task_counts.plot(kind='bar', ax=ax, rot=0)

axis_font = {'fontsize': 12}
ax.set_title('TREA Dataset: Task Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('Task', **axis_font)
ax.set_ylabel('Number of Samples', **axis_font)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Examine Sample Questions

In [None]:
# Print the first three questions of every task, ticking the correct option.
separator = '=' * 70
for task in ['count', 'order', 'duration']:
    print(f"\n{separator}")
    print(f"TASK: {task.upper()}")
    print(f"{separator}\n")

    task_samples = dataset.get_by_task(task)
    for example_no, sample in enumerate(task_samples[:3], start=1):
        print(f"Example {example_no}:")
        print(f"Question: {sample['question']}")
        print("Options:")
        for key, value in sample['options'].items():
            marker = '✓' if key == sample['correct_answer'] else ' '
            print(f"  [{marker}] ({key}) {value}")
        print(f"Correct Answer: {sample['correct_answer']}")
        print()

## 4. Answer Distribution

In [None]:
# Answer-position bias check: how often is each option label the correct answer?
answers = pd.Series([sample['correct_answer'] for sample in dataset.data])

# Every label offered in any question (plus any answer label not found among the
# options), in a stable sorted order. value_counts() alone would order bars by
# frequency and silently omit labels that are never correct -- the very bias this
# cell is meant to reveal.
option_labels = sorted(
    {key for sample in dataset.data for key in sample['options']} | set(answers),
    key=str,
)
answer_counts = answers.value_counts().reindex(option_labels, fill_value=0)

# Expected count per label if each question's correct answer were drawn uniformly
# from that question's own options. This replaces a hard-coded len(dataset)/4,
# which is only right when every question has exactly four options.
expected_counts = np.array([
    sum(1 / len(sample['options']) for sample in dataset.data if label in sample['options'])
    for label in option_labels
])

fig, ax = plt.subplots(figsize=(8, 6))
answer_counts.plot(kind='bar', ax=ax, color='steelblue', rot=0)
# pandas draws the bars at x = 0..n-1; mark each bar's uniform expectation.
bar_positions = np.arange(len(option_labels))
ax.hlines(expected_counts, bar_positions - 0.35, bar_positions + 0.35,
          colors='red', linestyles='--', label='Uniform distribution')
ax.set_title('Correct Answer Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('Answer Option', fontsize=12)
ax.set_ylabel('Count', fontsize=12)
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nAnswer distribution:")
for option, count in answer_counts.items():
    print(f"  {option}: {count} ({count/len(dataset)*100:.1f}%)")

## 5. Audio File Analysis

In [None]:
# Audio duration statistics over a seeded random subset of clips.
N_DURATION_SAMPLES = 30  # subset size for speed; set to None to scan every clip
DURATION_SEED = 42

# dataset.data may be grouped by task, so a head slice ([:30]) could cover a single
# task only. Draw a seeded random subset across the whole dataset instead.
rng = np.random.default_rng(DURATION_SEED)
n_total = len(dataset.data)
n_pick = n_total if N_DURATION_SAMPLES is None else min(N_DURATION_SAMPLES, n_total)
picked_indices = sorted(rng.choice(n_total, size=n_pick, replace=False))

# Distinct names (clip / clip_sr) so re-running this cell never overwrites the
# `audio` / `sr` globals that the sample-visualization and event-detection cells use.
durations = []
for idx in picked_indices:
    clip, clip_sr = librosa.load(dataset.data[idx]['audio_path'], sr=None)  # native rate
    durations.append(len(clip) / clip_sr)

durations = np.array(durations)
assert durations.size > 0, "no audio clips were loaded for duration analysis"

print(f"Audio Duration Statistics (seconds):")
print(f"  Mean: {durations.mean():.2f}")
print(f"  Std:  {durations.std():.2f}")
print(f"  Min:  {durations.min():.2f}")
print(f"  Max:  {durations.max():.2f}")

# Plot duration distribution
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(durations, bins=20, edgecolor='black', alpha=0.7)
ax.set_title('Audio Duration Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('Duration (seconds)', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.axvline(durations.mean(), color='red', linestyle='--', label=f'Mean: {durations.mean():.2f}s')
ax.legend()
ax.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Visualize Sample Audio

In [None]:
# Inspect one sample end to end: question, playable audio, waveform and spectrogram.
sample = dataset.data[0]

print(f"Task: {sample['task']}")
print(f"Question: {sample['question']}")
print(f"\nOptions:")
for key, value in sample['options'].items():
    marker = ' '
    if key == sample['correct_answer']:
        marker = '✓'
    print(f"  [{marker}] ({key}) {value}")

# Resample to 16 kHz for playback and plotting.
audio, sr = librosa.load(sample['audio_path'], sr=16000)

print(f"\nPlay audio:")
display(Audio(audio, rate=sr))

fig, axes = plt.subplots(2, 1, figsize=(14, 8))
wave_ax, spec_ax = axes

# Top panel: time-domain waveform.
librosa.display.waveshow(audio, sr=sr, ax=wave_ax)
wave_ax.set_title('Waveform', fontsize=12, fontweight='bold')
wave_ax.set_ylabel('Amplitude')

# Bottom panel: STFT magnitude in dB relative to the loudest bin.
stft_magnitude = np.abs(librosa.stft(audio))
spec_db = librosa.amplitude_to_db(stft_magnitude, ref=np.max)
img = librosa.display.specshow(spec_db, sr=sr, x_axis='time', y_axis='hz', ax=spec_ax)
spec_ax.set_title('Spectrogram', fontsize=12, fontweight='bold')
spec_ax.set_ylabel('Frequency (Hz)')
fig.colorbar(img, ax=spec_ax, format='%+2.0f dB')

plt.tight_layout()
plt.show()

## 7. Event Detection

In [None]:
# Detect sound events in the sample loaded in the previous cell (`audio`, `sr`).
# AudioProcessor is imported in the top imports cell.
EVENT_TOP_DB = 20      # threshold passed to extract_events (presumably dB below peak, as in librosa.effects.split -- confirm)
MAX_LEGEND_EVENTS = 5  # beyond this many events a per-event legend is clutter

processor = AudioProcessor()
events = processor.extract_events(audio, sr, top_db=EVENT_TOP_DB)

# Event boundaries are treated as sample indices (divided by sr to get seconds).
print(f"Detected {len(events)} events:")
for i, (start, end) in enumerate(events, 1):
    duration = (end - start) / sr
    print(f"  Event {i}: {start/sr:.2f}s - {end/sr:.2f}s (duration: {duration:.2f}s)")

# Visualize events on waveform
fig, ax = plt.subplots(figsize=(14, 4))
librosa.display.waveshow(audio, sr=sr, ax=ax, alpha=0.6)
ax.set_title('Detected Sound Events', fontsize=12, fontweight='bold')
ax.set_ylabel('Amplitude')

# Shade each event; only label spans when a legend will actually be drawn.
show_legend = len(events) <= MAX_LEGEND_EVENTS
for i, (start, end) in enumerate(events, 1):
    ax.axvspan(start / sr, end / sr, alpha=0.3,
               label=f'Event {i}' if show_legend else None)

# len() rather than truthiness: events may be a NumPy array. Skipping the empty case
# avoids matplotlib's "No artists with labels found" warning.
if show_legend and len(events) > 0:
    ax.legend(loc='upper right')

plt.tight_layout()
plt.show()

## 8. Task-Specific Analysis

In [None]:
# Question diversity per task: how many distinct question strings, and which recur?
# Counter is imported in the top imports cell.
PREVIEW_CHARS = 60  # truncate long questions in the printout

task_questions = {}
for task in ['count', 'order', 'duration']:
    task_samples = dataset.get_by_task(task)
    questions = [s['question'] for s in task_samples]
    task_questions[task] = questions

    question_counts = Counter(questions)
    print(f"\n{task.upper()} Task - Question Patterns:")
    print(f"  Total unique questions: {len(question_counts)}")
    print("  Most common:")
    for q, count in question_counts.most_common(3):
        # Only add an ellipsis when the question was actually truncated.
        preview = q if len(q) <= PREVIEW_CHARS else q[:PREVIEW_CHARS] + '...'
        print(f"    ({count}x) {preview}")

## Summary

This notebook explored the TREA dataset:
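- ✅ Loaded up to 90 samples (30 per task: `count`, `order`, `duration`)
- ✅ Analyzed the task distribution
- ✅ Checked the correct-answer distribution for positional bias
- ✅ Examined sample questions and question patterns
- ✅ Measured audio durations and visualized a sample audio clip
- ✅ Detected sound events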

**Next Steps:**
1. Test model predictions (notebook 02)
2. Run FESTA experiments (notebook 03)