# SEED-DV Dataset EEG Data Exploration

This notebook is dedicated to exploring EEG data in the SEED-DV dataset, including:
- Loading EEG data from the 20 subjects (the batch statistics in Section 6 cover the first 10 files by default)
- Basic shape, channel count, and time dimensions of EEG signals
- Data quality checks and statistical analysis
- EEG channel distribution and signal characteristics
- Comparative analysis between subjects

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import os
import sys
from glob import glob
import warnings
warnings.filterwarnings('ignore')

# Add project root directory to Python path.
# Presumably this notebook lives one level below the project root (e.g. notebooks/),
# hence `.parent` of the working directory — adjust if your layout differs.
project_root = Path('.').absolute().parent
# Append only once so re-running this cell does not keep growing sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Set matplotlib display parameters
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['font.size'] = 10
sns.set_style("whitegrid")

# Directly specify SEED-DV dataset path
# Please modify this path according to your actual situation
SEED_DV_DATA_PATH = project_root / 'data' / 'SEED-DV' / 'EEG'

print(f"Project root directory: {project_root}")
print(f"Current working directory: {os.getcwd()}")
print(f"SEED-DV data path: {SEED_DV_DATA_PATH}")

## 1–2. Locate the SEED-DV Dataset and List Its Data Files

In [None]:
# Check if SEED-DV data path exists and collect the .npy files used by later cells.
import re

print("=== SEED-DV Dataset Check ===")

# Define subject_files up front: later cells test `if subject_files:` and would
# otherwise raise NameError when the data path is missing.
subject_files = []


def _natural_sort_key(path):
    """Sort key that compares embedded numbers numerically ("sub2" before "sub10")."""
    # re.split with a capturing group alternates text and digit runs, so odd
    # positions are always digit runs and list elements stay type-aligned.
    return [int(part) if idx % 2 else part
            for idx, part in enumerate(re.split(r'(\d+)', str(path)))]


if not SEED_DV_DATA_PATH.exists():
    print(f"❌ Data path does not exist: {SEED_DV_DATA_PATH}")
    print("Please confirm:")
    print("1. SEED-DV dataset has been downloaded")
    print("2. Dataset is placed in the correct location")
    print("3. Modify SEED_DV_DATA_PATH in the cell above to the correct path")
else:
    print(f"✅ Found data path: {SEED_DV_DATA_PATH}")

    # Search recursively for all .npy files. Sort before previewing so the listing
    # below matches the analysis order of later cells (rglob order depends on the
    # filesystem). Natural order keeps position-based subject IDs intuitive.
    npy_files = sorted(SEED_DV_DATA_PATH.rglob('*.npy'), key=_natural_sort_key)

    print(f"\nFound .npy files: {len(npy_files)} files")

    if npy_files:
        print("First 10 files:")
        for i, npy_file in enumerate(npy_files[:10]):
            size_mb = npy_file.stat().st_size / (1024 * 1024)
            print(f"  {i+1:2d}. {npy_file.name} ({size_mb:.2f} MB)")

        if len(npy_files) > 10:
            print(f"  ... and {len(npy_files)-10} more files")

        # Save file list for subsequent use
        subject_files = npy_files
        print(f"\n✅ Ready to analyze {len(subject_files)} data files")
    else:
        print("❌ No .npy data files found")

## 3. Load and Analyze Data from the First Subject

In [None]:
# Load the first file as a sample for analysis.
# Reset both names first so a failed (re-)run never leaves a stale array from an
# earlier run for the structure-inference and plotting cells to pick up.
sample_data = None
sample_eeg_data = None

# globals().get keeps this cell NameError-free if the discovery cell was not run.
if globals().get('subject_files'):
    sample_file = subject_files[0]
    print(f"=== Analyzing sample file: {sample_file.name} ===")

    try:
        # Load data
        sample_data = np.load(sample_file)

        print(f"Data type: {type(sample_data)}")
        print(f"Data shape: {sample_data.shape}")
        print(f"Data dtype: {sample_data.dtype}")
        print(f"Data size: {sample_data.nbytes / (1024*1024):.2f} MB")

        if sample_data.ndim >= 2:
            print(f"\nData dimension analysis:")
            for i, dim_size in enumerate(sample_data.shape):
                print(f"  Dimension {i}: {dim_size}")

        # NaN-aware reductions: a few missing samples should not turn every summary
        # statistic into nan (NaN/Inf are counted separately below).
        print(f"\nData statistics:")
        print(f"  Min value: {np.nanmin(sample_data):.6f}")
        print(f"  Max value: {np.nanmax(sample_data):.6f}")
        print(f"  Mean value: {np.nanmean(sample_data):.6f}")
        print(f"  Standard deviation: {np.nanstd(sample_data):.6f}")

        # Check for NaN or inf values
        nan_count = np.isnan(sample_data).sum()
        inf_count = np.isinf(sample_data).sum()
        print(f"\nData quality check:")
        print(f"  Number of NaN values: {nan_count}")
        print(f"  Number of Inf values: {inf_count}")

        # Publish for the following cells only once every check above succeeded.
        sample_eeg_data = sample_data

    except Exception as e:
        # Notebook boundary: report and continue; later cells see sample_eeg_data=None.
        print(f"Failed to load data: {e}")
        sample_data = None
else:
    print("No data files found for analysis")

## 4. EEG Data Structure Inference

In [None]:
if 'sample_eeg_data' in globals() and sample_eeg_data is not None:
    print("=== EEG Data Structure Analysis ===")

    data_shape = sample_eeg_data.shape
    # Assumed SEED-DV sampling rate (Hz); used only for the duration estimate below.
    assumed_fs = 200
    # Set when the layout can be inferred; the sampling-rate section falls back to
    # shape-only heuristics while they stay None.
    channels = None
    time_points = None

    # Infer structure based on data shape
    if len(data_shape) == 2:
        print("Data format: 2D array")
        print("Possible structures:")
        print(f"  - (time points, channels): {data_shape[0]} × {data_shape[1]}")
        print(f"  - (channels, time points): {data_shape[0]} × {data_shape[1]}")

        # Usually EEG data has 8-128 channels, and time is the longer axis
        if 8 <= data_shape[0] <= 128 and data_shape[1] > data_shape[0]:
            channels, time_points = data_shape
            print(f"\nInferred structure: (channels={channels}, time points={time_points})")
        elif 8 <= data_shape[1] <= 128 and data_shape[0] > data_shape[1]:
            time_points, channels = data_shape
            print(f"\nInferred structure: (time points={time_points}, channels={channels})")
        else:
            print(f"\nUnable to determine structure, further analysis needed")

    elif len(data_shape) == 3:
        print("Data format: 3D array")
        print("Possible structures:")
        print(f"  - (trials, channels, time): {data_shape[0]} × {data_shape[1]} × {data_shape[2]}")
        print(f"  - (trials, time, channels): {data_shape[0]} × {data_shape[1]} × {data_shape[2]}")

        # Same rule as the 2-D case applied to the last two axes; requiring the time
        # axis to be the longer one avoids e.g. (7, 100, 62) being read as 100 channels.
        if 8 <= data_shape[1] <= 128 and data_shape[2] > data_shape[1]:
            trials, channels, time_points = data_shape
            print(f"\nInferred structure: (trials={trials}, channels={channels}, time points={time_points})")
        elif 8 <= data_shape[2] <= 128 and data_shape[1] > data_shape[2]:
            trials, time_points, channels = data_shape
            print(f"\nInferred structure: (trials={trials}, time points={time_points}, channels={channels})")
        else:
            print(f"\nUnable to determine structure, further analysis needed")

    elif len(data_shape) == 4:
        print("Data format: 4D array")
        print(f"Possible structure: (subjects/sessions, trials, channels, time) or similar combinations")
        print(f"Dimension sizes: {data_shape}")

    else:
        print(f"Data format: {len(data_shape)}D array (no structure inference for this rank)")

    # If this is SEED-DV, estimate sampling rate
    print(f"\n=== Sampling Rate Inference ===")
    print("SEED-DV dataset typically uses 200Hz sampling rate")
    if len(data_shape) >= 2:
        # Prefer the inferred axes; otherwise assume the longest axis is time and the
        # shortest is channels (this fallback misfires when the shortest axis is, e.g.,
        # a trial axis, which is why the inferred values take precedence).
        est_time_points = time_points if time_points is not None else max(data_shape)
        est_channels = channels if channels is not None else min(data_shape)

        if est_time_points > 1000:  # Likely time points
            estimated_duration = est_time_points / assumed_fs
            # With 3+ axes the time axis spans a single trial/segment, not the session.
            scope = " (per trial/segment)" if len(data_shape) >= 3 else ""
            print(f"Assuming {assumed_fs}Hz sampling rate, data duration approximately: "
                  f"{estimated_duration:.1f} seconds{scope}")

        if 8 <= est_channels <= 128:
            print(f"Estimated number of channels: {est_channels}")
else:
    print("No sample data available for analysis")

## 5. Visualize EEG Data from the First Subject

In [None]:
def to_channels_by_time(eeg):
    """Return a 2-D ``(channels, time)`` view of an EEG array for plotting.

    Args:
        eeg: NumPy array of any rank.

    Returns:
        ``None`` for 0-D/1-D input. Otherwise, axes before the last two are dropped
        by repeatedly taking index 0 (presumably the first trial of a
        ``(trials, channels, time)`` array — confirm against the dataset layout), and
        the remaining two axes are oriented with this notebook's heuristic: the
        shorter axis is channels, so equal-sized or taller arrays are transposed.
    """
    if eeg.ndim < 2:
        return None
    while eeg.ndim > 2:
        eeg = eeg[0]
    return eeg if eeg.shape[0] < eeg.shape[1] else eeg.T


if 'sample_eeg_data' in globals() and sample_eeg_data is not None:
    print("=== EEG Data Visualization ===")

    fs = 200  # Assumed SEED-DV sampling rate (Hz); used for time axes and the PSD.
    # Reduce to one (channels, time) matrix so every panel handles 2-D and 3-D data
    # the same way; None when there are fewer than two axes (histogram only).
    eeg_2d = to_channels_by_time(sample_eeg_data)

    # Create multiple subplots
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # 1. Data overview: heatmap downsampled to about 1000-2000 time columns
    ax1 = axes[0, 0]
    if eeg_2d is not None:
        n_channels, n_samples = eeg_2d.shape
        step = max(1, n_samples // 1000)
        # extent keeps the x axis in original sample indices despite downsampling
        im1 = ax1.imshow(eeg_2d[:, ::step], aspect='auto', cmap='RdBu_r',
                         extent=(0, n_samples, n_channels - 0.5, -0.5))
        ax1.set_title('EEG Data Overview (Channels × Time)')
        ax1.set_xlabel('Time (samples)')
        ax1.set_ylabel('Channels')
        plt.colorbar(im1, ax=ax1, shrink=0.6)

    # 2. Single channel time series: first channel, first 2000 samples
    ax2 = axes[0, 1]
    if eeg_2d is not None:
        channel_data = eeg_2d[0, :2000]
        time_axis = np.arange(channel_data.size) / fs
        ax2.plot(time_axis, channel_data, 'b-', linewidth=0.8)
        ax2.set_title('Channel 1 EEG Signal')
        ax2.set_xlabel('Time (seconds)')
        ax2.set_ylabel('Amplitude')
        ax2.grid(True, alpha=0.3)

    # 3. Data distribution histogram on a random subset of finite samples
    ax3 = axes[1, 0]
    # reshape(-1) is a view for contiguous arrays, whereas flatten() always copies.
    hist_values = sample_eeg_data.reshape(-1)
    if hist_values.size > 10000:
        # Seeded Generator -> reproducible figure; Generator.choice samples indices
        # without replacement without permuting the entire index range.
        picked = np.random.default_rng(0).choice(hist_values.size, 10000, replace=False)
        hist_values = hist_values[picked]
    # NaN/Inf would make the automatic bin range non-finite, so drop them first.
    hist_values = hist_values[np.isfinite(hist_values)]
    if hist_values.size:
        ax3.hist(hist_values, bins=50, alpha=0.7, edgecolor='black')
    ax3.set_title('EEG Data Distribution')
    ax3.set_xlabel('Amplitude')
    ax3.set_ylabel('Frequency')
    ax3.grid(True, alpha=0.3)

    # 4. Inter-channel correlation when the channel count is modest, otherwise a PSD
    ax4 = axes[1, 1]
    if eeg_2d is not None:
        if eeg_2d.shape[0] <= 64 and eeg_2d.shape[1] > eeg_2d.shape[0]:
            # atleast_2d: corrcoef of a single channel returns a scalar, not a matrix
            corr_matrix = np.atleast_2d(np.corrcoef(eeg_2d))
            im4 = ax4.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
            ax4.set_title('Inter-channel Correlation')
            ax4.set_xlabel('Channels')
            ax4.set_ylabel('Channels')
            plt.colorbar(im4, ax=ax4, shrink=0.6)
        else:
            # Display power spectral density estimate of the first channel
            from scipy.signal import welch
            psd_signal = eeg_2d[0, :1000]
            # nperseg must be >= 1; very short signals would otherwise yield 0
            nperseg = max(1, min(256, psd_signal.size // 4))
            freqs, psd = welch(psd_signal, fs=fs, nperseg=nperseg)
            ax4.semilogy(freqs, psd)
            ax4.set_title('Power Spectral Density (Channel 1)')
            ax4.set_xlabel('Frequency (Hz)')
            ax4.set_ylabel('Power Spectral Density')
            ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

else:
    print("No data available for visualization")

## 6. Batch Analysis of Multiple Subject Data

In [None]:
def summarize_eeg_file(subject_id, file_path):
    """Load one ``.npy`` EEG file and return its summary statistics as a dict.

    Args:
        subject_id: Value stored in the ``subject_id`` column (1-based file position).
        file_path: ``pathlib.Path`` of the ``.npy`` file.

    Returns:
        Dict with filename, shape, dtype, in-memory size (MB), NaN-aware
        min/max/mean/std, and NaN/Inf counts. The loaded array itself is not kept.

    Raises:
        Whatever ``np.load`` or the NumPy reductions raise (unreadable file,
        non-numeric dtype, empty array); the caller reports and skips such files.
    """
    data = np.load(file_path)
    return {
        'subject_id': subject_id,
        'filename': file_path.name,
        'shape': data.shape,
        'dtype': str(data.dtype),
        'size_mb': data.nbytes / (1024 * 1024),
        # NaN-aware so a few missing samples don't turn the statistics into nan
        'min_val': np.nanmin(data),
        'max_val': np.nanmax(data),
        'mean_val': np.nanmean(data),
        'std_val': np.nanstd(data),
        'nan_count': np.isnan(data).sum(),
        'inf_count': np.isinf(data).sum(),
    }


# Forget results from any previous run so the plotting/summary cells never show
# stale statistics when this run analyzes nothing.
globals().pop('batch_analysis_results', None)

# globals().get keeps this cell NameError-free if the discovery cell was not run.
if globals().get('subject_files'):
    print("=== Batch Analysis of Subject Data ===")

    # Analyze first few files (avoid memory overflow)
    max_files_to_analyze = min(10, len(subject_files))

    analysis_results = []

    for i, file_path in enumerate(subject_files[:max_files_to_analyze]):
        print(f"Analyzing file {i+1}/{max_files_to_analyze}: {file_path.name}", end=" ")
        try:
            result = summarize_eeg_file(i + 1, file_path)
        except Exception as e:
            # Best-effort batch: report the failure and continue with other files.
            print(f"❌ Error: {e}")
            continue
        analysis_results.append(result)
        print("✅")

    # Create results DataFrame
    if analysis_results:
        results_df = pd.DataFrame(analysis_results)

        print(f"\n=== Analysis Results Summary ({len(analysis_results)} files) ===")
        print("\nBasic file information:")
        print(results_df[['subject_id', 'filename', 'shape', 'size_mb']].to_string(index=False))

        print("\nData statistics:")
        stats_cols = ['min_val', 'max_val', 'mean_val', 'std_val']
        print(results_df[['subject_id'] + stats_cols].round(4).to_string(index=False))

        # Check data consistency
        print(f"\n=== Data Consistency Check ===")
        unique_shapes = results_df['shape'].unique()
        print(f"Different data shapes: {len(unique_shapes)} types")
        for shape in unique_shapes:
            count = (results_df['shape'] == shape).sum()
            print(f"  {shape}: {count} files")

        unique_dtypes = results_df['dtype'].unique()
        print(f"\nDifferent data types: {len(unique_dtypes)} types")
        for dtype in unique_dtypes:
            count = (results_df['dtype'] == dtype).sum()
            print(f"  {dtype}: {count} files")

        # Data quality check
        nan_files = (results_df['nan_count'] > 0).sum()
        inf_files = (results_df['inf_count'] > 0).sum()
        print(f"\nData quality:")
        print(f"  Files containing NaN values: {nan_files}")
        print(f"  Files containing Inf values: {inf_files}")

        # Save results for subsequent analysis
        batch_analysis_results = results_df

    else:
        print("No files were successfully analyzed")

else:
    print("No subject data files found")

## 7. Data Shape and Structure Comparison Visualization

In [None]:
if 'batch_analysis_results' in globals():
    print("=== Subject Data Comparison Visualization ===")

    df = batch_analysis_results

    # Create comparison plots
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # 1. File size comparison
    ax1 = axes[0, 0]
    ax1.bar(df['subject_id'], df['size_mb'])
    ax1.set_title('Data File Size by Subject')
    ax1.set_xlabel('Subject ID')
    ax1.set_ylabel('File Size (MB)')
    ax1.grid(True, alpha=0.3)

    # 2. Data mean comparison
    ax2 = axes[0, 1]
    ax2.bar(df['subject_id'], df['mean_val'])
    ax2.set_title('Data Mean by Subject')
    ax2.set_xlabel('Subject ID')
    ax2.set_ylabel('Data Mean')
    ax2.grid(True, alpha=0.3)

    # 3. Data standard deviation comparison
    ax3 = axes[0, 2]
    ax3.bar(df['subject_id'], df['std_val'])
    ax3.set_title('Data Standard Deviation by Subject')
    ax3.set_xlabel('Subject ID')
    ax3.set_ylabel('Standard Deviation')
    ax3.grid(True, alpha=0.3)

    # 4. Data range comparison (side-by-side bars per subject)
    ax4 = axes[1, 0]
    width = 0.35
    x = df['subject_id']
    ax4.bar(x - width/2, df['min_val'], width, label='Min Value', alpha=0.7)
    ax4.bar(x + width/2, df['max_val'], width, label='Max Value', alpha=0.7)
    ax4.set_title('Data Range by Subject')
    ax4.set_xlabel('Subject ID')
    ax4.set_ylabel('Value')
    ax4.legend()
    ax4.grid(True, alpha=0.3)

    # 5. Data distribution overview
    ax5 = axes[1, 1]
    box_labels = ['Mean', 'Std Dev', 'Min Value', 'Max Value']
    ax5.boxplot([df['mean_val'], df['std_val'], df['min_val'], df['max_val']])
    # Label ticks separately: boxplot's `labels=` keyword was renamed `tick_labels`
    # (deprecated in Matplotlib 3.9); boxes sit at positions 1..N by default, so this
    # works on both older and newer Matplotlib.
    ax5.set_xticks(range(1, len(box_labels) + 1))
    ax5.set_xticklabels(box_labels)
    ax5.set_title('Distribution of Statistical Measures')
    ax5.set_ylabel('Value')
    ax5.grid(True, alpha=0.3)

    # 6. Correlation heatmap
    ax6 = axes[1, 2]
    numeric_cols = ['size_mb', 'mean_val', 'std_val', 'min_val', 'max_val']
    corr_matrix = df[numeric_cols].corr()
    im = ax6.imshow(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1)
    ax6.set_xticks(range(len(numeric_cols)))
    ax6.set_yticks(range(len(numeric_cols)))
    ax6.set_xticklabels([col.replace('_', '\n') for col in numeric_cols], rotation=45)
    ax6.set_yticklabels([col.replace('_', '\n') for col in numeric_cols])
    ax6.set_title('Correlation Between Statistics')

    # Add correlation coefficient text: white only on strongly colored cells. NaN
    # (e.g. a constant column or a single file) fails the comparison and stays black,
    # so it remains readable on the blank cell.
    for i in range(len(numeric_cols)):
        for j in range(len(numeric_cols)):
            value = corr_matrix.iloc[i, j]
            ax6.text(j, i, f'{value:.2f}', ha="center", va="center",
                     color="white" if abs(value) >= 0.5 else "black")

    plt.tight_layout()
    plt.show()

else:
    print("No batch analysis results available for visualization")

## 8. SEED-DV Dataset Summary and Recommendations

In [None]:
print("=== SEED-DV EEG Dataset Analysis Summary ===")

# Summarize only when the batch-analysis cell produced a non-empty results table.
has_batch_results = ('batch_analysis_results' in globals()
                     and len(batch_analysis_results) > 0)

if not has_batch_results:
    print("\n❌ Failed to successfully analyze data files")
else:
    df = batch_analysis_results
    n_files = len(df)
    file_sizes = df['size_mb']

    # --- Basic information ---
    print("\n📊 Basic Dataset Information:")
    print(f"  • Number of subjects: {n_files} (analyzed first {n_files} files)")
    print(f"  • Total data size: {file_sizes.sum():.2f} MB")
    print(f"  • Average file size: {file_sizes.mean():.2f} ± {file_sizes.std():.2f} MB")

    # --- Structure, based on the most frequent array shape ---
    most_common_shape = None
    if n_files > 0:
        most_common_shape = df['shape'].mode().iloc[0]

    if most_common_shape:
        print("\n📐 Data Structure:")
        print(f"  • Most common data shape: {most_common_shape}")

        rank = len(most_common_shape)
        if rank == 2:
            print("  • Inferred as 2D format: (timepoints, channels) or (channels, timepoints)")
        elif rank == 3:
            n_trials, n_channels, n_timepoints = most_common_shape
            print("  • Inferred as 3D format: (trials/epochs, channels, timepoints)")
            print(f"    - Number of trials: {n_trials}")
            print(f"    - Number of channels: {n_channels}")
            print(f"    - Number of timepoints: {n_timepoints}")

            # A long last axis is read as samples at SEED-DV's typical 200 Hz.
            if n_timepoints > 1000:
                estimated_duration = n_timepoints / 200
                print("    - Estimated sampling rate: 200 Hz")
                print(f"    - Estimated duration per trial: {estimated_duration:.1f} seconds")

    # --- Quality ---
    print("\n📈 Data Quality:")
    clean_mask = (df['nan_count'] == 0) & (df['inf_count'] == 0)
    clean_files = int(clean_mask.sum())
    print(f"  • Completely clean files: {clean_files}/{n_files} ({clean_files/n_files*100:.1f}%)")

    for count_column, value_label in (('nan_count', 'NaN'), ('inf_count', 'Inf')):
        if df[count_column].sum() > 0:
            print(f"  ⚠️  Found {value_label} values, preprocessing needed")

    # --- Value statistics across files ---
    print("\n📊 Data Statistics:")
    print(f"  • Value range: [{df['min_val'].min():.3f}, {df['max_val'].max():.3f}]")
    print(f"  • Average amplitude: {df['mean_val'].mean():.3f} ± {df['mean_val'].std():.3f}")
    print(f"  • Average standard deviation: {df['std_val'].mean():.3f} ± {df['std_val'].std():.3f}")

recommendations = [
    "Confirm exact data format (trial × channel × time or other)",
    "Verify sampling rate (SEED-DV typically 200Hz)",
    "Confirm EEG channel layout and standard (e.g., 10-20 system)",
    "Check preprocessing status of data (filtering, denoising, etc.)",
    "Analyze stimulus labels and corresponding EEG responses",
    "Perform frequency domain analysis (α, β, γ, θ, δ bands)",
    "Examine individual differences between subjects",
]
print("\n🔍 Recommendations for Further Analysis:")
for rec_number, rec_text in enumerate(recommendations, start=1):
    print(f"  {rec_number}. {rec_text}")