In [2]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy import signal
import pandas as pd
import os

def create_output_directory(output_dir):
    """
    Make sure ``output_dir`` exists, creating it (and any missing parents).

    If anything already exists at that path (directory or file), nothing is
    done and nothing is printed. A confirmation line is printed only when the
    directory is actually created.
    """
    if os.path.exists(output_dir):
        return
    os.makedirs(output_dir)
    print(f"Created output directory: {output_dir}")

def load_data(subject_number, data_dir=None):
    """
    Load the pickled recording for one subject.

    Parameters
    ----------
    subject_number : int
        Subject index, formatted as a zero-padded two-digit file name
        (e.g. ``1`` -> ``'s01.dat'``).
    data_dir : str or os.PathLike, optional
        Directory holding the ``sNN.dat`` files. ``None`` (default) opens the
        bare file name, i.e. relative to the current working directory.

    Returns
    -------
    object
        Whatever object the pickle contains. Callers in this notebook index it
        with ``['data']``, so it is presumably a dict -- confirm against the
        dataset documentation.

    Raises
    ------
    FileNotFoundError
        If the subject file does not exist.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while unpickling; only load
    files obtained from a trusted source.
    """
    filename = f's{subject_number:02d}.dat'
    if data_dir is not None:
        filename = os.path.join(data_dir, filename)
    with open(filename, 'rb') as f:
        # latin1 lets Python 3 unpickle files written by Python 2
        data = pickle.load(f, encoding='latin1')
    return data

def preprocess_eeg_data(raw_data, fs=128, lowcut=4.0, highcut=45.0, order=4,
                        n_eeg_channels=32):
    """
    Band-pass filter the leading (EEG) channels of a multi-trial recording.

    A zero-phase Butterworth band-pass (``scipy.signal.filtfilt``) is applied
    along the last (sample) axis of the first ``n_eeg_channels`` channels.
    All remaining channels are copied through unchanged rather than being
    replaced with zeros, so the output has the input's shape and no signal is
    silently discarded.

    Parameters
    ----------
    raw_data : array_like, shape (n_trials, n_channels, n_samples)
        Indexed as ``[trial, channel, sample]``. Not modified.
    fs : float
        Sampling frequency in Hz (default 128).
    lowcut, highcut : float
        Pass-band edges in Hz (defaults 4.0 and 45.0).
    order : int
        Butterworth filter order (default 4).
    n_eeg_channels : int
        Number of leading channels to filter (default 32). Clipped to the
        number of channels actually present, so fewer channels do not raise.

    Returns
    -------
    numpy.ndarray
        New array, same shape as ``raw_data``. Floating-point input keeps its
        dtype; integer input is promoted to float64 so the filter output is
        not truncated.
    """
    raw = np.asarray(raw_data)

    nyquist = fs / 2
    b, a = signal.butter(order, [lowcut / nyquist, highcut / nyquist], btype='band')

    # Start from a copy so channels we do not filter keep their original values
    out_dtype = raw.dtype if np.issubdtype(raw.dtype, np.floating) else np.float64
    cleaned_data = raw.astype(out_dtype, copy=True)

    n_eeg = min(n_eeg_channels, raw.shape[1])
    if n_eeg > 0 and raw.shape[0] > 0:
        # filtfilt along axis=-1 filters every (trial, channel) row independently
        cleaned_data[:, :n_eeg] = signal.filtfilt(b, a, raw[:, :n_eeg], axis=-1)

    return cleaned_data

def plot_raw_vs_clean(raw_data, clean_data, subject_num, output_dir, trial_num=0, channel_num=0):
    """
    Save a two-panel figure contrasting one raw trace with its cleaned version.

    Top panel: ``raw_data[trial_num, channel_num]`` in blue.
    Bottom panel: ``clean_data[trial_num, channel_num]`` in red.
    The figure is written as a 300-dpi PNG into ``output_dir``, closed, and
    the saved path is printed.
    """
    fig, (raw_ax, clean_ax) = plt.subplots(2, 1, figsize=(15, 6))

    raw_ax.plot(raw_data[trial_num, channel_num], 'b-', linewidth=1)
    raw_ax.set_title(f'Raw EEG Data - Subject {subject_num}, Trial {trial_num}, Channel {channel_num}')
    raw_ax.set_ylabel('Amplitude')
    raw_ax.grid(True)

    clean_ax.plot(clean_data[trial_num, channel_num], 'r-', linewidth=1)
    clean_ax.set_title('Cleaned EEG Data')
    clean_ax.set_xlabel('Time samples')
    clean_ax.set_ylabel('Amplitude')
    clean_ax.grid(True)

    fig.tight_layout()

    out_path = os.path.join(
        output_dir,
        f'raw_vs_clean_subject_{subject_num}_trial_{trial_num}_channel_{channel_num}.png',
    )
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved raw vs clean plot: {out_path}")

def plot_channel_wise_data(clean_data, subject_num, output_dir, trial_num=0):
    """
    Save a vertical stack of five cleaned channels for one trial.

    Each row plots ``clean_data[trial_num, index]`` for one (label, index)
    pair below, in its own colour. Only the bottom row carries the time
    label. The figure is written as a 300-dpi PNG into ``output_dir``,
    closed, and the saved path is printed.
    """
    # (label, channel index, colour) per panel, top to bottom.
    # NOTE(review): indices assume the dataset's channel ordering -- confirm.
    panels = (
        ('Fp1', 0, 'b'),
        ('F3', 2, 'g'),
        ('C3', 6, 'r'),
        ('P3', 10, 'c'),
        ('O1', 13, 'm'),
    )

    fig, axes = plt.subplots(len(panels), 1, figsize=(15, 10))
    for ax, (label, idx, colour) in zip(axes, panels):
        ax.plot(clean_data[trial_num, idx], color=colour, linewidth=1)
        ax.set_title(f'Channel {label}')
        ax.set_ylabel('Amplitude')
        ax.grid(True)

    axes[-1].set_xlabel('Time samples')
    fig.suptitle(f'Channel-wise EEG Data - Subject {subject_num}, Trial {trial_num}')
    fig.tight_layout()

    out_path = os.path.join(output_dir, f'channel_wise_subject_{subject_num}_trial_{trial_num}.png')
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved channel-wise plot: {out_path}")

def plot_subject_comparison(subject_list, output_dir, channel_num=0, trial_num=0):
    """
    Plot one cleaned trace per subject, stacked vertically, and save to file.

    For each subject the recording is loaded with ``load_data`` and filtered
    with ``preprocess_eeg_data``. The trace ``[trial_num, channel_num]`` is
    then drawn in its own row. The figure is saved as a 300-dpi PNG named
    ``subject_comparison_channel_{channel_num}_trial_{trial_num}.png`` in
    ``output_dir``, and the saved path is printed.

    Parameters
    ----------
    subject_list : iterable of int
        Subjects to compare. Any number is supported: the five colours are
        reused cyclically, so every subject gets its own populated panel.
    output_dir : str
        Destination directory; it is not created here.
    channel_num, trial_num : int
        Channel / trial plotted for every subject.
    """
    subjects = list(subject_list)  # materialise once: used for len() and iteration
    colors = ['b', 'g', 'r', 'c', 'm']

    fig = plt.figure(figsize=(15, 10))
    try:
        ax = None
        for i, subject in enumerate(subjects):
            data = load_data(subject)
            clean_data = preprocess_eeg_data(data['data'])

            ax = fig.add_subplot(len(subjects), 1, i + 1)
            # Cycle colours instead of zip(), which would drop subjects past the fifth
            ax.plot(clean_data[trial_num, channel_num],
                    color=colors[i % len(colors)], linewidth=1)
            ax.set_title(f'Subject {subject}')
            ax.set_ylabel('Amplitude')
            ax.grid(True)

        # Label only the bottom panel; skip if nothing was plotted (empty list)
        if ax is not None:
            ax.set_xlabel('Time samples')
        fig.suptitle(f'Subject-wise Comparison - Channel {channel_num}, Trial {trial_num}')
        fig.tight_layout()

        filename = os.path.join(output_dir, f'subject_comparison_channel_{channel_num}_trial_{trial_num}.png')
        fig.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Close even on failure so half-built figures don't accumulate in the kernel
        plt.close(fig)
    print(f"Saved subject comparison plot: {filename}")

def analyze_deap_dataset(output_dir='deap_analysis_output', subject_num=1,
                         subject_list=(1, 2, 3, 4, 5)):
    """
    Run the plotting pipeline for the DEAP recordings.

    Steps:
      1. Create ``output_dir`` if it is missing.
      2. Load ``subject_num`` and band-pass filter its EEG channels.
      3. Save a raw-vs-clean plot and a channel-wise plot for that subject.
      4. Save a cross-subject comparison plot for ``subject_list``.

    Parameters
    ----------
    output_dir : str
        Directory where the PNGs are written (default 'deap_analysis_output').
    subject_num : int
        Subject used for the single-subject plots (default 1).
    subject_list : iterable of int
        Subjects shown in the comparison plot (default: the first five).
    """
    create_output_directory(output_dir)

    # Single-subject analysis
    data = load_data(subject_num)
    raw_data = data['data']
    clean_data = preprocess_eeg_data(raw_data)

    # 1. Raw vs clean trace
    plot_raw_vs_clean(raw_data, clean_data, subject_num, output_dir)

    # 2. Several channels of the cleaned trial
    plot_channel_wise_data(clean_data, subject_num, output_dir)

    # 3. Cross-subject comparison (list() so callees can take len())
    plot_subject_comparison(list(subject_list), output_dir)


if __name__ == "__main__":
    analyze_deap_dataset()

Created output directory: deap_analysis_output
Saved raw vs clean plot: deap_analysis_output/raw_vs_clean_subject_1_trial_0_channel_0.png
Saved channel-wise plot: deap_analysis_output/channel_wise_subject_1_trial_0.png
Saved subject comparison plot: deap_analysis_output/subject_comparison_channel_0_trial_0.png


In [None]:
!pip install seaborn matplotlib scipy pandas

Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Collecting matplotlib
  Downloading matplotlib-3.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting scipy
  Downloading scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting pandas
  Downloading pandas-2.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
Collecting numpy!=1.24.0,>=1.20 (from seaborn)
  Downloading numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.56.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (101