# Working with AutoReject Logs

This notebook demonstrates how to load and work with AutoReject logs that have been saved during preprocessing, with a focus on the enhanced visualization capabilities.

In [None]:
import sys
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join('..',)))

# Import our utility functions
from scr.utils.autoreject_utils import find_autoreject_log, load_autoreject_log, plot_autoreject_summary, plot_reject_log

## 1. Finding and Loading AutoReject Logs

First, let's try to find any existing AutoReject logs for our subjects.

In [None]:
# Identifiers of the recording whose AutoReject log we want to inspect.
# These four values are reused by the later cells of this notebook.
subject_id = "01"    # subject label
session_id = "001"   # session label
task_id = "5pt"      # task label
run_id = "01"        # run label

# Ask the project helper to locate the saved log, searching from the
# project root (one directory above this notebook).
log_path = find_autoreject_log(
    base_dir=os.path.abspath('..'),
    subject_id=subject_id,
    session_id=session_id,
    task_id=task_id,
    run_id=run_id,
)

print(f"Found AutoReject log at: {log_path}")

In [None]:
# First attempt: load directly from the path located in the previous cell.
reject_log = load_autoreject_log(log_path)

# Second attempt: nothing came back from the path-based load, so let the
# helper resolve the file itself from the recording identifiers.
# NOTE(review): this relies on load_autoreject_log returning None (rather
# than raising) when no log can be loaded -- confirm against the helper.
if reject_log is None:
    fallback_lookup = {
        'subject_id': subject_id,
        'session_id': session_id,
        'task_id': task_id,
        'run_id': run_id,
        'base_dir': os.path.abspath('..'),
    }
    reject_log = load_autoreject_log(**fallback_lookup)

## 2. Exploring the AutoReject Log Structure

Let's look at the attributes and methods of the loaded RejectLog object.

In [None]:
# Inspect the loaded log: list every non-dunder attribute, then summarise
# the epoch-level rejection flags.
if reject_log is None:
    print("No AutoReject log was found. Skipping this section.")
else:
    print("\nAutoReject Log Attributes:")
    # Dunder names are skipped; everything else is reported, one line each.
    for attr in (name for name in dir(reject_log) if not name.startswith('__')):
        try:
            value = getattr(reject_log, attr)
            if callable(value):
                print(f"{attr}: {type(value)} - (method)")
            elif isinstance(value, (np.ndarray, list)) and len(str(value)) > 100:
                # Avoid flooding the output with long array/list reprs
                print(f"{attr}: {type(value)} - (large array/list)")
            else:
                print(f"{attr}: {type(value)} - {value}")
        except Exception as e:
            # Report the failure for this attribute and keep going
            print(f"{attr}: Error - {e}")

    # Headline numbers derived from the per-epoch bad flags
    print("\nDetailed information:")
    print(f"Number of epochs: {len(reject_log.bad_epochs)}")
    print(f"Number of bad epochs: {np.sum(reject_log.bad_epochs)}")
    print(f"Percentage of bad epochs: {np.mean(reject_log.bad_epochs) * 100:.1f}%")
    print(f"Number of channels: {len(reject_log.ch_names)}")

    # Positions of the epochs flagged as bad (only the first five are shown)
    bad_indices = np.flatnonzero(reject_log.bad_epochs)
    if bad_indices.size:
        print(f"\nIndices of first 5 bad epochs: {bad_indices[:5]}")
    else:
        print("\nNo bad epochs found")

## 3. Visualizing the AutoReject Results with Enhanced Plots

Let's create different visualizations of the AutoReject results using our enhanced plotting functions.

In [None]:
# Draw the project's summary figure for the loaded log (if there is one).
if reject_log is None:
    print("No AutoReject log was found. Cannot create visualization.")
else:
    fig = plot_autoreject_summary(reject_log)
    plt.show()

    # To write the figure to disk instead, pass a target path:
    # plot_autoreject_summary(reject_log, save_to='autoreject_summary.png')

### 3.1 Enhanced RejectLog Plot

This plot is an enhanced version of the figure produced by the RejectLog's built-in `plot` method, with more legible channel names and improved overall readability.

In [None]:
# Render the enhanced RejectLog plot once per supported orientation:
# horizontal first, then vertical.
if reject_log is None:
    print("No AutoReject log was found. Cannot create visualization.")
else:
    for orientation in ('horizontal', 'vertical'):
        fig = plot_reject_log(reject_log, orientation=orientation)
        plt.show()

### 3.2 Compare Visualization Methods

Let's compare the different visualization methods to see which one works best for your data.

In [None]:
# Show the three available visualizations one after another so they can be
# compared on the same log.
if reject_log is not None:
    print("Creating all visualization types for comparison...")

    # 1. Original RejectLog's plot method
    # plot() creates and returns its own figure, so the size is set on that
    # figure. Opening a figure with plt.figure() beforehand would only leave
    # an empty extra figure in the output and would not resize this plot.
    fig1 = reject_log.plot(show=False)
    fig1.set_size_inches(15, 10)
    fig1.suptitle('Original RejectLog Plot', fontsize=16)
    fig1.tight_layout()
    plt.show()

    # 2. Enhanced RejectLog plot
    fig2 = plot_reject_log(reject_log, orientation='horizontal')
    plt.show()

    # 3. Comprehensive summary visualization
    fig3 = plot_autoreject_summary(reject_log)
    plt.show()

    print("Which visualization style do you prefer? You can use these in your analysis pipeline.")
else:
    print("No AutoReject log was found. Cannot create visualizations.")

## 4. Using AutoReject Results with Raw Data

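Now let's demonstrate how to apply the AutoReject log to the data: we cut the recording into fixed-length epochs and keep only the epochs that the log does not mark as bad. This only works when the number of epochs created here matches the number of entries in the log.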

In [None]:
import mne

# Location of the preprocessed recording for the subject/session/task/run
# chosen at the top of the notebook.
data_path = os.path.join('..', 'data', 'processed', f'sub-{subject_id}', f'ses-{session_id}', 
                         f'sub-{subject_id}_ses-{session_id}_task-{task_id}_run-{run_id}_after_autoreject.fif')

# Only proceed when the recording is actually on disk
if os.path.exists(data_path):
    # Load the data
    raw = mne.io.read_raw_fif(data_path, preload=True)
    print(f"Loaded data from {data_path}")
    
    # Create fixed-length (1 s) events for epoching
    events = mne.make_fixed_length_events(raw, duration=1.0)
    
    # Create epochs
    epochs = mne.Epochs(raw, events, tmin=0, tmax=1, baseline=None, preload=True)
    print(f"Created {len(epochs)} epochs")
    
    # The log can only be applied when it has exactly one flag per epoch
    if reject_log is not None and len(reject_log.bad_epochs) == len(epochs):
        # Build the good-epoch mask from an explicitly boolean array.
        # Applying `~` to an integer 0/1 array is a bitwise complement
        # (-1/-2), which would silently index the wrong epochs, and
        # applying it to a plain list raises a TypeError.
        good_mask = ~np.asarray(reject_log.bad_epochs, dtype=bool)
        
        # Select only good epochs
        good_epochs = epochs[good_mask]
        print(f"Selected {len(good_epochs)}/{len(epochs)} good epochs using AutoReject log")
        
        # We can now use these good epochs for further analysis
        # For example, let's compute a PSD
        fig = good_epochs.plot_psd(fmax=40, average=True)
    else:
        # Report both counts so a mismatch is easy to diagnose
        print("No AutoReject log available or the number of epochs doesn't match.")
        print(f"Reject log epochs: {None if reject_log is None else len(reject_log.bad_epochs)}")
        print(f"Created epochs: {len(epochs)}")
else:
    print(f"Data file not found at {data_path}")

## 5. Additional Enhanced Visualizations

Let's create a few more specialized visualizations to better understand the AutoReject results.

In [None]:
# Channel-level quality views built from the per-epoch, per-channel labels:
# a stacked bar chart of percentages per channel and an epochs x channels
# heatmap.
if reject_log is not None:
    # Extract rejection data
    bad_epochs = reject_log.bad_epochs
    ch_names = reject_log.ch_names
    
    # Get the detailed labels if available
    if hasattr(reject_log, 'labels') and reject_log.labels is not None:
        labels = np.asarray(reject_log.labels)
        
        # Label codes as documented for autoreject's RejectLog.labels:
        # 0 = good, 1 = bad, 2 = bad and interpolated.
        # (1 and 2 must not be swapped: 2 is the interpolated case.)
        LABEL_GOOD, LABEL_BAD, LABEL_INTERPOLATED = 0, 1, 2
        
        # Calculate statistics per channel (indexing assumes labels is
        # laid out as epochs x channels)
        total = len(bad_epochs)
        channel_stats = {}
        for i, ch in enumerate(ch_names):
            good = np.sum(labels[:, i] == LABEL_GOOD)
            interpolated = np.sum(labels[:, i] == LABEL_INTERPOLATED)
            bad = np.sum(labels[:, i] == LABEL_BAD)
            channel_stats[ch] = {
                'good_percent': good / total * 100,
                'interpolated_percent': interpolated / total * 100,
                'bad_percent': bad / total * 100
            }
        
        # Create a stacked bar chart of channel quality
        plt.figure(figsize=(12, 10))
        channels = list(channel_stats.keys())
        good_vals = [channel_stats[ch]['good_percent'] for ch in channels]
        interp_vals = [channel_stats[ch]['interpolated_percent'] for ch in channels]
        bad_vals = [channel_stats[ch]['bad_percent'] for ch in channels]
        
        # Sort channels by ascending share of good data; barh draws the first
        # entry at the bottom, so the best channels end up at the top
        sorted_indices = np.argsort(good_vals)
        channels = [channels[i] for i in sorted_indices]
        good_vals = [good_vals[i] for i in sorted_indices]
        interp_vals = [interp_vals[i] for i in sorted_indices]
        bad_vals = [bad_vals[i] for i in sorted_indices]
        
        # Create stacked bar chart
        plt.barh(channels, good_vals, color='green', alpha=0.7, label='Good')
        plt.barh(channels, interp_vals, left=good_vals, color='blue', alpha=0.7, label='Interpolated')
        plt.barh(channels, bad_vals, left=np.array(good_vals) + np.array(interp_vals), 
                color='red', alpha=0.7, label='Bad')
        
        plt.xlabel('Percentage (%)', fontsize=12)
        plt.ylabel('Channels', fontsize=12)
        plt.title('Channel Data Quality Overview', fontsize=14, fontweight='bold')
        plt.legend(loc='lower right')
        plt.grid(axis='x', linestyle='--', alpha=0.3)
        plt.xlim(0, 100)
        plt.tight_layout()
        plt.show()
        
        # Re-encode the labels on a severity scale for display so that the
        # colour ramp and the colorbar legend agree:
        # 0 = good, 1 = interpolated, 2 = bad. Any other value (e.g. NaN)
        # stays NaN and is left blank by imshow.
        severity = np.full(labels.shape, np.nan)
        severity[labels == LABEL_GOOD] = 0
        severity[labels == LABEL_INTERPOLATED] = 1
        severity[labels == LABEL_BAD] = 2
        
        # Create a heatmap of epochs x channels showing quality over time
        plt.figure(figsize=(18, 10))
        plt.imshow(severity.T, aspect='auto', cmap='RdYlGn_r', vmin=0, vmax=2)
        plt.colorbar(ticks=[0, 1, 2], label='Status (0=Good, 1=Interpolated, 2=Bad)')
        plt.xlabel('Epochs (Time →)', fontsize=12)
        plt.ylabel('Channels', fontsize=12)
        plt.yticks(range(len(ch_names)), ch_names)
        plt.title('Channel Quality Over Time', fontsize=14, fontweight='bold')
        
        # Turn off the default grid; explicit separator lines are drawn below
        plt.grid(False)
        
        # Add horizontal lines between channels
        for i in range(1, len(ch_names)):
            plt.axhline(i - 0.5, color='black', linewidth=0.5, alpha=0.3)
            
        # Add epoch markings every 50 epochs
        for i in range(0, len(bad_epochs), 50):
            plt.axvline(i, color='black', linewidth=0.5, alpha=0.3)
            
        plt.tight_layout()
        plt.show()
    else:
        print("Detailed channel-level information not available in this RejectLog.")
else:
    print("No AutoReject log was found. Cannot create visualizations.")

## Summary

In this notebook, we demonstrated how to:
1. Find and load AutoReject logs saved during preprocessing
2. Explore the structure and attributes of the AutoReject log
3. Create enhanced visualizations of the AutoReject results using multiple approaches
4. Use the AutoReject log to select good epochs from raw data
5. Create specialized visualizations to better understand channel-level and epoch-level quality

These techniques allow you to get a much clearer picture of your data quality and consistently apply the same epoch rejection across different analyses, ensuring reproducibility.