# Task 1

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
from tabulate import tabulate

# Root folder expected to contain one sub-folder per class name listed below.
base_path = r'C:\Users\prags\Desktop\hackathon\EEG_Data\train_data'
classes = ['Normal', 'Complex_Partial_Seizures', 'Electrographic_Seizures', 'Video_detected_Seizures_with_no_visual_change_over_EEG']
# Maps class name -> the sample loaded for it, or None when nothing could be loaded.
data_points = {cls: None for cls in classes}

# fixed random seed for reproducibility
random.seed(42)

print(f"Base path: {base_path}")

# one data point from each class
for cls in classes:
    cls_folder = os.path.join(base_path, cls)
    print(f"Looking for files in: {cls_folder}")
    # isdir (rather than exists) so that a stray file with the class name is
    # reported as a missing directory instead of crashing os.listdir.
    if not os.path.isdir(cls_folder):
        print(f"Directory does not exist: {cls_folder}")
        continue
    # os.listdir returns entries in arbitrary order; sort them so the seeded
    # random.choice picks the same file on every run and every machine.
    files = sorted(os.listdir(cls_folder))
    if not files:
        print(f"No files found in: {cls_folder}")
        continue
    file_path = os.path.join(cls_folder, random.choice(files))
    print(f"Selected file: {file_path}")
    data_points[cls] = np.load(file_path)

# Wrap every loaded array in a DataFrame; classes that failed to load stay None.
for cls in classes:
    if data_points[cls] is not None:
        data_points[cls] = pd.DataFrame(data_points[cls])

# statistics
def compute_statistics(data):
    """Compute summary statistics for every column of a recording.

    Each column of ``data`` is treated as one channel.

    Args:
        data: pandas DataFrame whose columns hold numeric signals.

    Returns:
        A DataFrame with one row per column of ``data`` and the columns
        "Mean", "Zero Crossing Rate", "Range", "Energy", "RMS" and
        "Variance".
    """
    stats = {}
    for channel in data.columns:
        # Work in float64: with an integer dtype (e.g. int16) the squares
        # below and the peak-to-peak range wrap around silently and yield
        # wrong Energy / RMS / Range values.
        signal = data[channel].to_numpy(dtype=np.float64)
        squared = signal ** 2
        stats[channel] = {
            "Mean": np.mean(signal),
            # Fraction of consecutive sample pairs whose sign differs; a
            # sample that is exactly zero has its own sign (0) here.
            "Zero Crossing Rate": np.mean(np.diff(np.sign(signal)) != 0),
            "Range": np.ptp(signal),
            "Energy": np.sum(squared),
            "RMS": np.sqrt(np.mean(squared)),
            "Variance": np.var(signal),
        }
    # The dict is keyed by channel, so transpose to get one row per channel.
    return pd.DataFrame(stats).T

# EEG signals
def plot_eeg_signals(data, title, *, max_channels=19):
    """Plot the columns of a recording, first one by one and then together.

    One small figure is shown per plotted column, followed by a single
    large figure that overlays the same columns with a legend.

    Args:
        data: pandas DataFrame; each column is drawn as one channel.
        title: Text used as the prefix of every figure title.
        max_channels: Upper bound on how many leading columns are plotted.
            Defaults to 19; pass ``None`` to plot every column.
    """
    # Slice once so both sections below plot exactly the same channels.
    channels = data.columns[:max_channels]

    # individual channels
    for channel in channels:
        plt.figure(figsize=(10, 2))
        plt.plot(data[channel], marker='o', linestyle='-', markersize=3)
        plt.title(f'{title} - Channel {channel} Readings')
        plt.xlabel('Time')
        plt.ylabel('Amplitude (µV)')
        plt.show()
        plt.close()

    # channels combined
    plt.figure(figsize=(18, 10))
    for channel in channels:
        plt.plot(data[channel], label=channel)
    plt.title(f'{title} - All Channels')
    plt.xlabel('Time')
    plt.ylabel('Amplitude (µV)')
    # Anchor the legend slightly outside the axes' upper-right corner.
    plt.legend(loc='upper right', bbox_to_anchor=(1.1, 1.05))
    plt.tight_layout()
    plt.show()
    plt.close()

# Print statistics
def print_statistics(data, title):
    """Compute, save and print the per-channel statistics of a recording.

    The statistics table is written to CSV via ``save_statistics_to_csv``
    and then printed as a grid. Tables with more than eight rows are
    abbreviated to their first four and last four rows.

    Args:
        data: pandas DataFrame passed on to ``compute_statistics``.
        title: Label used in the printed header and for the CSV file.
    """
    table = compute_statistics(data)
    print(f'Statistics for {title}:')

    # Persist the full table before printing the (possibly shortened) view.
    save_statistics_to_csv(table, title)

    def as_grid(frame):
        return tabulate(frame, headers='keys', tablefmt='grid')

    if len(table) <= 8:
        # Small enough to show in full.
        print(as_grid(table))
    else:
        # Show only the first 4 and the last 4 rows.
        print(as_grid(table.head(4)))
        print("...")
        print(as_grid(table.tail(4)))

    print('\n')

# Save statistics to CSV
def save_statistics_to_csv(stats, title):
    """Write a statistics table to ``<title>_statistics.csv``.

    The file is created relative to the current working directory and
    includes the table's index as its first column.

    Args:
        stats: pandas DataFrame to write.
        title: Label used as the file-name prefix and in the printed note.
    """
    destination = f'{title}_statistics.csv'
    stats.to_csv(path_or_buf=destination, index=True)
    print(f'Statistics for {title} saved to {destination}')

# Process each class
# Plot and summarise every class that has a loaded sample; report the rest.
for cls, data in data_points.items():
    if data is None:
        print(f"No data loaded for class: {cls}")
        continue
    plot_eeg_signals(data, cls)
    print_statistics(data, cls)
