In [None]:
import mne
import os
import matplotlib.pyplot as plt
import pandas as pd
from config import get_exploration_config
import numpy as np

# Only surface MNE log messages at WARNING level or above.
mne.set_log_level(verbose='WARNING')
# Exploration settings; the following cell reads its 'input_path' and 'output_path' entries.
conf = get_exploration_config()

In [None]:
# Dataset root to scan, and the directory that receives the exported CSV.
RAW_PATH, OUTPUT_PATH = (conf[key] for key in ('input_path', 'output_path'))

In [None]:
def extract_tuh_version(raw_path=None):
    """Return the TUSZ corpus version string read from ``AAREADME.txt``.

    The first ``AAREADME.txt`` encountered while walking ``raw_path`` is used.
    Sub-directories are visited in sorted order, so the choice is deterministic.
    The version is the text after the first ``": "`` on the file's third line,
    stripped of surrounding whitespace.

    Args:
        raw_path: Root directory to search. Defaults to the notebook-level
            ``RAW_PATH`` so existing ``extract_tuh_version()`` calls still work.

    Returns:
        The version string.

    Raises:
        FileNotFoundError: If no ``AAREADME.txt`` exists below ``raw_path``.
        ValueError: If the README has fewer than three lines, or its third
            line contains no ``": "`` separator.
    """
    if raw_path is None:
        raw_path = RAW_PATH

    readme = None
    for root, dirs, files in os.walk(raw_path):
        # Sorting in place makes os.walk descend in a stable order.
        dirs.sort()
        if 'AAREADME.txt' in files:
            readme = os.path.join(root, 'AAREADME.txt')
            # Leave the whole walk, not just the inner loop, so a deeper
            # AAREADME.txt cannot overwrite the first match.
            break

    if readme is None:
        raise FileNotFoundError('AAREADME.txt not found')

    with open(readme, 'r') as f:
        lines = f.readlines()

    # NOTE(review): assumes line 3 looks like "<key>: <version>"; confirm
    # against the corpus README if its layout changes.
    if len(lines) < 3 or ': ' not in lines[2]:
        raise ValueError(
            f'Unexpected AAREADME.txt format in {readme!r}: cannot read version from line 3'
        )
    return lines[2].split(': ', 1)[1].strip()

In [None]:
def extract_events_from_annotations(annotation_file, header_lines=6):
    """Parse a ``.csv_bi`` annotation file into a DataFrame of events.

    After skipping ``header_lines`` leading lines, each non-blank row is split
    on commas. Fields 1, 2 and 3 (0-based) are read as start time, stop time
    and label. Labels are mapped ``"bckg" -> 0`` and ``"seiz" -> 1``.

    Args:
        annotation_file: Path to the annotation file.
        header_lines: Number of leading lines to skip before the event rows
            (default 6, the previously hard-coded offset).

    Returns:
        DataFrame with columns ``label``, ``start`` and ``stop``, one row per
        event in file order. The columns are present even when there are no
        events.

    Raises:
        KeyError: If a row's label is neither ``bckg`` nor ``seiz``.
        ValueError / IndexError: If a non-blank row is malformed.
    """
    label_map = {"bckg": 0, "seiz": 1}

    with open(annotation_file, "r") as f:
        rows = f.readlines()[header_lines:]

    data = []
    for row in rows:
        # Blank lines (e.g. an extra newline at EOF) carry no event.
        if not row.strip():
            continue
        parts = row.split(",")
        data.append({
            # strip() so a label in the last column does not keep its newline.
            "label": label_map[parts[3].strip()],
            "start": float(parts[1]),
            "stop": float(parts[2]),
        })

    return pd.DataFrame(data, columns=["label", "start", "stop"])

In [None]:
def load_events():
    data = []
    
    edf_path  = os.path.join(RAW_PATH, "edf")

    for root, _, files in os.walk(edf_path):
        for file in files:
            if not file.endswith(".edf"):
                continue
            
            rel_path = os.path.relpath(root, edf_path)
            parts = rel_path.split("/")
            
            if len(parts) != 4:
                continue
            
            set_name, patient_id, session_id, configuration = parts
            
            recording_path = os.path.join(root, file)
            recording_id = file.replace(".edf", "").split("_")[-1]
            annotation_path = recording_path.replace(".edf", ".csv_bi")
            
            # check if recording ang corresponding annotation file exists
            if not os.path.exists(recording_path) or not os.path.exists(annotation_path):
                continue

            raw = mne.io.read_raw_edf(recording_path, preload=False)
            raw_info = raw.info

            sfreq, number_of_channels, sex = raw_info["sfreq"], raw_info["nchan"], raw_info["subject_info"]["sex"]
            measurement_date = raw_info["meas_date"] if "meas_date" in raw_info else None
            recording_duration = raw.times[-1]

            events = extract_events_from_annotations(annotation_path)

            for i, event in events.iterrows():
                data.append(
                    {
                        "set": set_name,
                        "patient_id": patient_id,
                        "session_id": session_id,
                        "configuration": configuration,
                        "recording_id": recording_id,
                        "recording_duration": recording_duration,
                        "recording_path": recording_path,
                        "sfreq": sfreq,
                        "number_of_channels": number_of_channels,
                        "measurement_date": measurement_date,
                        "sex": sex,
                        "event_index": i,
                        "start": event["start"],
                        "stop": event["stop"],
                        "label": event["label"]
                    }
                )
            
            raw.close()
            
    
    return pd.DataFrame(data)

In [None]:
events = load_events()
# Fail fast: every statistic below is derived from this frame, and an empty
# frame means no annotated recordings were found (check input_path).
assert not events.empty, f"No annotated EDF recordings found under {RAW_PATH!r}"
events

In [None]:
version = extract_tuh_version()
# DataFrame.to_csv does not create missing directories, so ensure the target exists.
os.makedirs(OUTPUT_PATH, exist_ok=True)
output_file = os.path.join(OUTPUT_PATH, f"tusz_{version}.csv")
events.to_csv(output_file, index=False)

### Evaluation of statistics

In [None]:
print(f"TUSZ Version: {version}")

In [None]:
# One row per recording: every event row of a recording repeats the same
# recording-level metadata, so the first value of each group is enough.
recordings = events.groupby("recording_path").agg(
    recording_duration=("recording_duration", "first"),
    sfreq=("sfreq", "first"),
    number_of_channels=("number_of_channels", "first"),
    measurement_date=("measurement_date", "first"),
    configuration=("configuration", "first"),
)

viridis = plt.colormaps['viridis']

# --- Number of recordings per configuration ---
configuration_counts = recordings['configuration'].value_counts()
norm = plt.Normalize(configuration_counts.min(), configuration_counts.max())

fig, ax = plt.subplots(figsize=(10, 6))
# Colour each bar by its height so the colormap encodes the count.
bars = ax.bar(
    configuration_counts.index.astype(str),
    configuration_counts.values,
    color=viridis(norm(configuration_counts.values)),
    edgecolor='black',
    alpha=0.7,
)
# bar_label places each count just above its bar, whatever the y-scale.
ax.bar_label(bars, fontsize=10, color='black')
ax.set_xlabel("Configuration", fontsize=14)
ax.set_ylabel("Number of Recordings", fontsize=14)
ax.set_title("Distribution of Configurations", fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

# --- Recording durations ---
n_recordings = len(recordings)
print(f"Number of recordings: {n_recordings}")

durations = recordings["recording_duration"] / 60  # minutes
total_duration = durations.sum() / 60  # hours
min_duration = durations.min()
max_duration = durations.max()
avg_duration = durations.mean()

print(f"Total recording duration: {total_duration:.2f}h")
print(f"Min recording duration: {min_duration:.2f}min")
print(f"Avg recording duration: {avg_duration:.2f}min")
print(f"Max recording duration: {max_duration:.2f}min")

# Trim extreme outliers (>= mean + 3 std) for the histogram only. The summary
# statistics above still cover every recording, and `durations` is left
# untouched so it keeps matching them. With fewer than two recordings the std
# is NaN; in that case nothing is trimmed.
cutoff = avg_duration + 3 * durations.std()
durations_trimmed = durations[durations < cutoff] if pd.notna(cutoff) else durations
n_excluded = len(durations) - len(durations_trimmed)

title = "Distribution of Recording Lengths"
if n_excluded:
    title += f" ({n_excluded} outliers >= {cutoff:.1f} min excluded)"

fig, ax = plt.subplots(figsize=(12, 6))
counts, bins, patches = ax.hist(durations_trimmed, bins=50, edgecolor='black', alpha=0.7)

# Colour each histogram bin by its count.
norm = plt.Normalize(counts.min(), counts.max())
for count, patch in zip(counts, patches):
    patch.set_facecolor(viridis(norm(count)))

ax.set_xlabel("Duration (min)", fontsize=14)
ax.set_ylabel("Number of recordings", fontsize=14)
ax.set_title(title, fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

In [None]:
event_durations = events["stop"] - events["start"]  # seconds
n_events = len(event_durations)
total_duration = event_durations.sum() / 3600
min_event_duration = event_durations.min()
max_event_duration = event_durations.max()
avg_event_duration = event_durations.mean()


def calculate_event_metrics(events_subset):
    """Summarise the event durations of a subset of ``events``.

    Returns a dict mapping metric name to value. "Number of events" is an int,
    the total duration is in hours, and min/avg/max are in seconds (NaN for an
    empty subset).
    """
    durations = events_subset["stop"] - events_subset["start"]
    return {
        "Number of events": len(durations),
        "Total duration (h)": durations.sum() / 3600,
        "Min duration (s)": durations.min(),
        "Avg duration (s)": durations.mean(),
        "Max duration (s)": durations.max(),
    }


def _format_metric(value):
    # Counts stay integers; every other metric is shown with two decimals.
    return str(value) if isinstance(value, (int, np.integer)) else f"{value:.2f}"


seizure_metrics = calculate_event_metrics(events[events["label"] == 1])
non_seizure_metrics = calculate_event_metrics(events[events["label"] == 0])

# Format before building the frame so the integer count is not upcast to
# float (rendered as "N.00") by a shared float64 column.
event_summary = pd.DataFrame({
    "Metric": list(seizure_metrics),
    "Seizure Events": [_format_metric(v) for v in seizure_metrics.values()],
    "Non-Seizure Events": [_format_metric(v) for v in non_seizure_metrics.values()],
})

# Render the summary as a table figure.
fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(
    cellText=event_summary.values,
    colLabels=event_summary.columns,
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.2, 1.2)
plt.show()

# Histogram of all event durations, coloured by bin count.
fig, ax = plt.subplots(figsize=(12, 6))
counts, bins, patches = ax.hist(event_durations, bins=50, edgecolor='black', alpha=0.7)

viridis = plt.colormaps['viridis']
norm = plt.Normalize(counts.min(), counts.max())
for count, patch in zip(counts, patches):
    patch.set_facecolor(viridis(norm(count)))

# event_durations is stop - start in seconds, so label the axis accordingly.
ax.set_xlabel("Duration (s)", fontsize=14)
ax.set_ylabel("Number of events", fontsize=14)
ax.set_title("Distribution of Event Lengths", fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)
fig.tight_layout()
plt.show()

In [None]:
# Collapse the event rows to one row per patient (first sex / set value seen).
patients = events.groupby("patient_id").agg(sex=("sex", "first"), set=("set", "first"))
n_patients = patients.shape[0]
# NOTE(review): assumes the MNE subject_info coding 1 = male, 2 = female; confirm for these EDF headers.
n_male = int(patients["sex"].eq(1).sum())
n_female = int(patients["sex"].eq(2).sum())

for group, count in (
    ("patients", n_patients),
    ("male patients", n_male),
    ("female patients", n_female),
):
    print(f"Number of {group}: {count}")

In [None]:
# Number of distinct sessions recorded for each patient.
sessions_per_patient = events.groupby("patient_id")["session_id"].nunique()
min_sessions_per_patient = sessions_per_patient.min()
avg_sessions_per_patient = sessions_per_patient.mean()
max_sessions_per_patient = sessions_per_patient.max()
print(f"Min sessions per patient: {min_sessions_per_patient}")
print(f"Avg sessions per patient: {avg_sessions_per_patient:.2f}")
print(f"Max sessions per patient: {max_sessions_per_patient}")

### Distribution of Number of Sessions per Patient
session_counts = sessions_per_patient.value_counts().sort_index()

viridis = plt.colormaps['viridis']
norm = plt.Normalize(session_counts.min(), session_counts.max())

fig, ax = plt.subplots(figsize=(12, 6))
# Bars are coloured once, by height (number of patients).
bars = ax.bar(
    session_counts.index,
    session_counts.values,
    color=viridis(norm(session_counts.values)),
    edgecolor='black',
    alpha=0.7,
)
ax.bar_label(bars)

ax.set_xlabel("Number of sessions", fontsize=14)
ax.set_ylabel("Number of patients", fontsize=14)
ax.set_title("Distribution of Number of Sessions per Patient", fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)

fig.tight_layout()
plt.show()

In [None]:
# Every event row of a recording repeats its channel count; keep one value per recording.
channels_per_recording = (
    events.groupby("recording_path")["number_of_channels"]
    .agg(lambda column: column.iat[0])
)
min_channels_per_recordings = channels_per_recording.min()
avg_channels_per_recordings = channels_per_recording.mean()
max_channels_per_recordings = channels_per_recording.max()

for prefix, value, spec in (
    ("Min", min_channels_per_recordings, ""),
    ("Avg", avg_channels_per_recordings, ".2f"),
    ("Max", max_channels_per_recordings, ""),
):
    print(f"{prefix} channels per recording: {value:{spec}}")

# Recordings per distinct channel count, ordered by channel count.
channel_counts = channels_per_recording.value_counts().sort_index()
heights = channel_counts.values

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(channel_counts.index.astype(str), heights, edgecolor='black', alpha=0.7)

# Colour each bar by its height and print the count just above it.
viridis = plt.colormaps['viridis']
norm = plt.Normalize(heights.min(), heights.max())
for bar, height in zip(bars, heights):
    bar.set_facecolor(viridis(norm(height)))
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.5,
        str(height),
        ha='center', va='bottom', fontsize=10, color='black',
    )

ax.set_xlabel("Number of Channels per Recording", fontsize=14)
ax.set_ylabel("Number of Recordings", fontsize=14)
ax.set_title("Distribution of Channels per Recording", fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)

fig.tight_layout()
plt.show()

In [None]:
# One sampling frequency per recording: all of its event rows share the same value.
sampling_frequencies = (
    events.groupby("recording_path")["sfreq"]
    .agg(lambda column: column.iat[0])
)
sampling_frequency_counts = sampling_frequencies.value_counts().sort_index()
min_sampling_frequency = sampling_frequencies.min()
max_sampling_frequency = sampling_frequencies.max()

print(f"Min sampling frequency: {min_sampling_frequency}")
print(f"Max sampling frequency: {max_sampling_frequency}")
print(f"Sampling frequencies: {sampling_frequency_counts}")

frequency_labels = sampling_frequency_counts.index.astype(str)
recording_totals = sampling_frequency_counts.values
viridis = plt.colormaps['viridis']
norm = plt.Normalize(recording_totals.min(), recording_totals.max())

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(frequency_labels, recording_totals, edgecolor='black', alpha=0.7)

# Colour each bar by its height and annotate it with the recording count.
for bar, total in zip(bars, recording_totals):
    bar.set_facecolor(viridis(norm(total)))
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.5,
        str(total),
        ha='center', va='bottom', fontsize=10, color='black',
    )

ax.set_xlabel("Sampling Frequency (Hz)", fontsize=14)
ax.set_ylabel("Number of Recordings", fontsize=14)
ax.set_title("Distribution of Sampling Frequencies", fontsize=16)
ax.grid(axis='y', linestyle='--', alpha=0.7)

fig.tight_layout()
plt.show()


In [None]:
train_set = events[events["set"] == "train"]
dev_set = events[events["set"] == "dev"]
eval_set = events[events["set"] == "eval"]
split_sets = {"Train": train_set, "Dev": dev_set, "Eval": eval_set}


def calculate_split_metrics(dataset):
    """Summarise one data split.

    Returns a dict with:
      - "patients": number of distinct patient ids;
      - "seizure_duration" / "non_seizure_duration": total event time in hours
        for label 1 / label 0;
      - "ratio": seizure / non-seizure duration, or NaN when there is no
        non-seizure time.
    """
    def hours(subset):
        return (subset["stop"] - subset["start"]).sum() / 3600

    seizure_duration = hours(dataset[dataset["label"] == 1])
    non_seizure_duration = hours(dataset[dataset["label"] == 0])
    return {
        "patients": dataset["patient_id"].nunique(),
        "seizure_duration": seizure_duration,
        "non_seizure_duration": non_seizure_duration,
        "ratio": seizure_duration / non_seizure_duration if non_seizure_duration > 0 else np.nan,
    }


split_metrics = {name: calculate_split_metrics(subset) for name, subset in split_sets.items()}

total_seizure_duration = sum(m["seizure_duration"] for m in split_metrics.values())
total_non_seizure_duration = sum(m["non_seizure_duration"] for m in split_metrics.values())
total_metrics = {
    # Union of patient ids, so a patient present in several splits counts once.
    "patients": len(set().union(*(subset["patient_id"] for subset in split_sets.values()))),
    "seizure_duration": total_seizure_duration,
    "non_seizure_duration": total_non_seizure_duration,
    "ratio": (
        total_seizure_duration / total_non_seizure_duration
        if total_non_seizure_duration > 0 else np.nan
    ),
}

# --- Grouped bar chart: seizure vs non-seizure hours per split ---
categories = list(split_metrics)
x = np.arange(len(categories))  # group positions
width = 0.35  # width of each bar within a group
viridis = plt.colormaps['viridis']

fig, ax = plt.subplots(figsize=(10, 6))
bars1 = ax.bar(x - width / 2, [m["seizure_duration"] for m in split_metrics.values()], width,
               label='Seizure Duration', color=viridis(0.7))
bars2 = ax.bar(x + width / 2, [m["non_seizure_duration"] for m in split_metrics.values()], width,
               label='Non-Seizure Duration', color=viridis(0.3))

ax.set_xlabel("Dataset", fontsize=14)
ax.set_ylabel("Duration (hours)", fontsize=14)
ax.set_title("Seizure vs Non-Seizure Duration Across Datasets", fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.legend()

# Hours on top of each bar.
for bars in (bars1, bars2):
    ax.bar_label(bars, fmt='%.2f', fontsize=10)

fig.tight_layout()
plt.show()


def _format_ratio(value):
    return "N/A" if np.isnan(value) else f"{value:.2f}"


# --- Summary table: one row per split plus a Total row ---
table_data = {
    name: [
        m["patients"],
        f"{m['seizure_duration']:.2f}",
        f"{m['non_seizure_duration']:.2f}",
        _format_ratio(m["ratio"]),
    ]
    for name, m in {**split_metrics, "Total": total_metrics}.items()
}
split_summary = pd.DataFrame(
    table_data,
    index=["Number of Patients", "Seizure Duration (h)", "Non-Seizure Duration (h)", "Seizure/Non-Seizure Ratio"],
).T

fig, ax = plt.subplots(figsize=(12, 4))
ax.axis('tight')
ax.axis('off')
table = ax.table(
    cellText=split_summary.values,
    rowLabels=split_summary.index,
    colLabels=split_summary.columns,
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(12)
table.scale(1.2, 1.2)

plt.show()