In [None]:
import re

import matplotlib.pyplot as plt
import mne
import numpy as np

# --- 1. Define your file paths ---
# Change `subject_id` to analyse another participant; both files follow the same
# <subject>\<subject>_clean\{Att,Inatt}_<subject>_cleaned_raw2.fif layout.
subject_id = 'S01'
data_dir = rf'C:\Users\LENOVO\OneDrive - City University\Desktop\Attention detection FYP\{subject_id}\{subject_id}_clean'
attention_file_path = rf'{data_dir}\Att_{subject_id}_cleaned_raw2.fif'
inattention_file_path = rf'{data_dir}\Inatt_{subject_id}_cleaned_raw2.fif'

# Set to True to re-reference both recordings to the common average. Only do this
# if the cleaned files are not already referenced.
apply_average_reference = False

# --- 2. Load your EEG data from .fif files ---
print(f"Loading attention data from: {attention_file_path}")
raw_attention = mne.io.read_raw_fif(attention_file_path, preload=True, verbose=False)
print(f"Loading inattention data from: {inattention_file_path}")
raw_inattention = mne.io.read_raw_fif(inattention_file_path, preload=True, verbose=False)

if apply_average_reference:
    raw_attention.set_eeg_reference('average', verbose=False)
    raw_inattention.set_eeg_reference('average', verbose=False)

# Show the measurement info (channel names, sampling frequency, ...) of both recordings.
print("\nAttention Raw Info:")
print(raw_attention.info)
print("\nInattention Raw Info:")
print(raw_inattention.info)

# --- 3. Define Epoching Parameters ---
epoch_length_s = 5  # seconds
overlap_percent = 0.90  # fraction of each epoch shared with the next (0.90 = 90% overlap)
sfreq = raw_attention.info['sfreq']

# Later steps use one sample-based n_fft and one frequency axis for both
# recordings, so they must share a sampling rate.
if raw_inattention.info['sfreq'] != sfreq:
    raise ValueError(
        f"Sampling rates differ: attention={sfreq} Hz, "
        f"inattention={raw_inattention.info['sfreq']} Hz. Resample one recording first."
    )

# Time between consecutive epoch onsets, e.g. 5 s * (1 - 0.90) = 0.5 s.
step_size_s = epoch_length_s * (1 - overlap_percent)
# MNE expects the overlap itself, in seconds (5 s * 0.90 = 4.5 s).
overlap_s = epoch_length_s * overlap_percent
print(f"\nEpoch length: {epoch_length_s} s, Overlap: {overlap_percent*100:g}%, Step size: {step_size_s:g} s")

# --- 4. Create Epochs for Attention and Inattention ---
# mne.make_fixed_length_epochs takes `duration` and `overlap` in seconds; consecutive
# epochs therefore start every (duration - overlap) seconds.


def _make_overlapping_epochs(raw, label, duration, overlap):
    """Cut `raw` into fixed-length, overlapping, preloaded epochs.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous recording to epoch.
    label : str
        Condition name used in the progress messages.
    duration, overlap : float
        Epoch length and overlap between consecutive epochs, in seconds.

    Returns
    -------
    mne.Epochs
    """
    print(f"Creating epochs for {label} data...")
    epochs = mne.make_fixed_length_epochs(
        raw,
        duration=duration,
        overlap=overlap,
        preload=True,
        verbose=False,
    )
    print(f"Number of {label} epochs: {len(epochs)}")
    return epochs


epochs_attention = _make_overlapping_epochs(raw_attention, 'Attention', epoch_length_s, overlap_s)
epochs_inattention = _make_overlapping_epochs(raw_inattention, 'Inattention', epoch_length_s, overlap_s)

# --- 5. Calculate PSD for Each Epoch and then Average ---
# Welch's method runs on every epoch (Epochs.compute_psd returns an EpochsSpectrum).
# The per-epoch PSDs are then averaged across epochs in linear power and converted to dB.

# n_fft may not exceed the number of samples in one epoch of either condition.
samples_per_epoch = min(len(epochs_attention.times), len(epochs_inattention.times))

# Use the largest power of 2 that is <= samples_per_epoch.
n_fft = 2 ** int(np.log2(samples_per_epoch))

print(f"\nEpoch length in samples: {samples_per_epoch}")
print(f"Using n_fft: {n_fft} samples ({n_fft/sfreq:.2f} seconds) for PSD calculation per epoch")

# Overlap between Welch windows *within* one epoch (50% is customary).
# This is unrelated to the overlap between epochs configured earlier.
n_overlap_welch = int(0.5 * n_fft)
# MNE's Welch implementation defaults to a Hamming window, so the intended
# Hann window must be passed explicitly.
window = 'hann'

fmin = 1
fmax = 40

welch_kwargs = dict(
    method='welch',
    fmin=fmin, fmax=fmax,
    n_fft=n_fft, n_overlap=n_overlap_welch,
    window=window,
    picks='eeg',
    average='mean',  # how Welch combines its windows inside each epoch (not across channels)
    verbose=False,
)


def _average_psd_db(epochs, label, psd_kwargs):
    """Compute per-epoch Welch PSDs and their across-epoch average in dB.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to analyse.
    label : str
        Condition name used in the progress message.
    psd_kwargs : dict
        Keyword arguments forwarded to ``epochs.compute_psd``.

    Returns
    -------
    spectrum : EpochsSpectrum
        Per-epoch spectra.
    psds_db : ndarray, shape (n_channels, n_frequencies)
        10*log10 of the epoch-averaged power.
    freqs : ndarray
        Frequency bins in Hz.
    """
    print(f"Calculating PSD for {label} epochs...")
    spectrum = epochs.compute_psd(**psd_kwargs)
    # .average() collapses the epoch axis (in linear power) before dB conversion.
    psds, spectrum_freqs = spectrum.average().get_data(return_freqs=True)
    return spectrum, 10 * np.log10(psds), spectrum_freqs


spectrum_att_epochs, psds_att_avg_across_epochs, freqs = _average_psd_db(
    epochs_attention, 'Attention', welch_kwargs)
spectrum_inatt_epochs, psds_inatt_avg_across_epochs, freqs_inatt = _average_psd_db(
    epochs_inattention, 'Inattention', welch_kwargs)

# All later steps index both conditions with the attention frequency axis.
if not np.array_equal(freqs, freqs_inatt):
    raise ValueError(
        "Attention and inattention spectra have different frequency bins; "
        "check that both recordings share the same sampling rate."
    )

# --- 6. Plotting the Average PSDs ---
# psds_*_avg_across_epochs are (n_channels, n_frequencies) arrays in dB.
# Averaging dB values directly would give a geometric mean. Epochs were averaged
# in linear power, so channels are too, and the result is converted back to dB.
mean_psd_att_db = 10 * np.log10(np.mean(10 ** (psds_att_avg_across_epochs / 10), axis=0))
mean_psd_inatt_db = 10 * np.log10(np.mean(10 ** (psds_inatt_avg_across_epochs / 10), axis=0))

# Take the subject label from the loaded file name (e.g. 'Att_S01_cleaned_raw2.fif'
# -> 'S01') so the title always matches the data that is plotted.
subject_match = re.search(r'S\d+', re.split(r'[\\/]', attention_file_path)[-1])
subject_label = subject_match.group(0) if subject_match else 'unknown subject'

plt.figure(figsize=(12, 7))
plt.plot(freqs, mean_psd_att_db, label='Attention (Mean across channels & epochs)', color='blue')
plt.plot(freqs, mean_psd_inatt_db, label='Inattention (Mean across channels & epochs)', color='red')
# Shade the bands *before* building the legend so their labels are included in it.
plt.axvspan(4, 8, color='gray', alpha=0.2, label='Theta Band')
plt.axvspan(8, 12, color='orange', alpha=0.2, label='Alpha Band')
plt.axvspan(12, 30, color='green', alpha=0.2, label='Beta Band')
plt.xlabel('Frequency (Hz)')
# MNE spectra are in SI units (V^2/Hz for EEG), so the dB values are relative to 1 V^2/Hz.
plt.ylabel('Power Spectral Density (dB re 1 V$^2$/Hz)')
plt.title(
    f'Average PSD for Attention vs. Inattention (Subject {subject_label} - '
    f'{epoch_length_s:g}s Epochs, {overlap_percent:.0%} Overlap)'
)
plt.grid(True)
plt.legend()
plt.xlim(fmin, fmax)
plt.tight_layout()
plt.show()

# --- 7. Extracting Band Power (Features) ---
# Frequency bands in Hz. Each band is treated as half-open [low, high), so a bin lying
# exactly on a shared edge (e.g. 8 Hz) is counted in only one band.
bands = {
    'delta': (0.5, 4),
    'theta': (4, 8),
    'alpha': (8, 12),
    'beta': (12, 30),
    'gamma': (30, 45)
}

# Average power in each band for each condition, stored in dB.
band_powers_att = {}
band_powers_inatt = {}


def _band_power_db(psds_db, band_idx):
    """Return the mean power over the selected frequency bins and all channels, in dB.

    `psds_db` is an (n_channels, n_frequencies) array in dB. The mean is taken in
    linear power rather than over dB values. Converting the result back with
    10**(dB/10) therefore gives the actual average band power.
    """
    return 10 * np.log10(np.mean(10 ** (psds_db[:, band_idx] / 10)))


print("\n--- Calculating Band Powers (averaged across epochs) ---")
for band_name, (f_min_band, f_max_band) in bands.items():
    idx_band = np.flatnonzero((freqs >= f_min_band) & (freqs < f_max_band))

    if idx_band.size == 0:
        print(f"Warning: No frequencies found for {band_name} band ({f_min_band}-{f_max_band} Hz). Check fmin/fmax settings.")
        band_power_att = np.nan
        band_power_inatt = np.nan
    else:
        # A band wider than the computed PSD range is silently truncated otherwise.
        if f_min_band < freqs[0] or f_max_band > freqs[-1]:
            print(f"Warning: {band_name} band ({f_min_band}-{f_max_band} Hz) extends beyond the computed "
                  f"PSD range ({freqs[0]:g}-{freqs[-1]:g} Hz); only the overlapping part is used.")
        band_power_att = _band_power_db(psds_att_avg_across_epochs, idx_band)
        band_power_inatt = _band_power_db(psds_inatt_avg_across_epochs, idx_band)

    band_powers_att[band_name] = band_power_att
    band_powers_inatt[band_name] = band_power_inatt

print("\nAverage Band Powers (dB re 1 V^2/Hz):")
print("Attention:", {k: f"{v:.2f}" for k, v in band_powers_att.items()})
print("Inattention:", {k: f"{v:.2f}" for k, v in band_powers_inatt.items()})

# --- 8. Feature Engineering (Theta/Beta Ratio) ---
# band_powers_att / band_powers_inatt map band name -> power in dB. The values are
# relative to MNE's SI units (V^2/Hz for EEG), not microvolts. A power ratio must be
# formed from linear powers, so each dB value is converted back first. The ratio
# itself is unit-free.


def _theta_beta_ratio(band_powers_db, condition):
    """Return theta power / beta power (linear), or NaN if it cannot be formed.

    Parameters
    ----------
    band_powers_db : dict
        Band name -> band power in dB.
    condition : str
        Condition name used in the warning message.

    Returns
    -------
    float
        The ratio. It is NaN when either band is missing or NaN, or when the beta
        power is zero (a warning is printed in that last case).
    """
    theta_db = band_powers_db.get('theta', np.nan)
    beta_db = band_powers_db.get('beta', np.nan)
    if np.isnan(theta_db) or np.isnan(beta_db):
        return np.nan

    theta_power_linear = 10 ** (theta_db / 10)
    beta_power_linear = 10 ** (beta_db / 10)
    # Zero only when beta was -inf dB (log of zero power) or underflows.
    if beta_power_linear == 0:
        print(f"Warning: Beta power for {condition} is zero, cannot calculate Theta/Beta ratio.")
        return np.nan
    return theta_power_linear / beta_power_linear


theta_beta_ratio_att = _theta_beta_ratio(band_powers_att, 'Attention')
theta_beta_ratio_inatt = _theta_beta_ratio(band_powers_inatt, 'Inattention')

print(f"\nTheta/Beta Ratio (Attention): {theta_beta_ratio_att:.2f}")
print(f"Theta/Beta Ratio (Inattention): {theta_beta_ratio_inatt:.2f}")

print("\nAnalysis Complete.")