In [None]:
# prompt: mount drive

# Mount Google Drive when running on Colab. Outside Colab the google.colab
# package is not installed, so skip the mount instead of aborting the notebook
# on the import (later cells read from local paths).
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!pip install mne
!pip install ssqueezepy
# !pip install stockwell
!pip install tensorflow
# !pip install kaggle



In [None]:
import matplotlib.pyplot as plt
import os
# Set before ssqueezepy is imported/used; presumably switches ssqueezepy to its
# GPU execution path — confirm against the ssqueezepy documentation.
os.environ['SSQ_GPU'] = '1'
from ssqueezepy import ssq_cwt
import numpy as np
# from stockwell import st
from sklearn.preprocessing import MinMaxScaler
import mne
from scipy.ndimage import zoom
import tensorflow as tf


Load `demographic.csv`, which lists each subject's group, gender, age, and education.

In [None]:
import pandas as pd
# demographic.csv determines the class assigned to the .csv files in the
# dataset (one row per subject).
demo = pd.read_csv(filepath_or_buffer='/home/zeek/Downloads/demographic.csv')
demo

Unnamed: 0,subject,group,gender,age,education
0,1,0,M,44,16.0
1,2,0,M,39,17.0
2,3,0,M,53,18.0
3,4,0,M,52,15.0
4,5,0,M,41,16.0
...,...,...,...,...,...
76,77,1,M,28,13.0
77,78,1,F,32,16.0
78,79,1,M,37,16.0
79,80,1,M,33,13.0


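Generate time-frequency images from the EEG CSV files, logging each finished file so that already processed files are skipped on later runs.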

In [None]:
import os
import numpy as np
import pandas as pd
import re
import gc
import psutil
from PIL import Image
import mne
from ssqueezepy import ssq_cwt

# Progress counter printed by generate_tfrs ("Count: n/1660"); starts at 1.
count = 1

# Input folder, image output root, and the resume log of finished files.
directory = '/home/zeek/External/Divide-20250317T133000Z-001/Divide/'
output_directory = "/home/zeek/External/Images/I660es_1/"
processed_file_path = "/home/zeek/Downloads/processed_files_part_5.txt"

# Names of files finished in earlier runs (one per line in the resume log);
# empty when no log exists yet.
processed_files = set()
if os.path.exists(processed_file_path):
    with open(processed_file_path, "r", encoding="utf-8") as f:
        processed_files.update(f.read().splitlines())

# Every "downsampled_*.csv" entry found in the input folder.
csv_files = [
    entry
    for entry in os.listdir(directory)
    if entry.startswith('downsampled_') and entry.endswith('.csv')
]

def memory_usage():
    """Print the resident set size (RSS) of the current process in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)
    print(f"Memory Usage: {rss_mb:.2f} MB")

def save_tfr_image(tfr_data, category, filename, output_directory):
    """Save a time-frequency map as a 224x224 8-bit PNG.

    The array is min-max scaled to [0, 255], converted to uint8, resized with
    Lanczos resampling and written to
    ``<output_directory>/<category>/<filename>.png``. The category folder is
    created when missing.

    Args:
        tfr_data: Array of real values to render (array-like).
        category: Sub-folder name under ``output_directory``.
        filename: File name without the ``.png`` extension.
        output_directory: Root folder for the generated images.
    """
    category_dir = os.path.join(output_directory, category)
    os.makedirs(category_dir, exist_ok=True)

    tfr_data = np.asarray(tfr_data)
    lowest = np.min(tfr_data)
    value_range = np.max(tfr_data) - lowest

    if np.isfinite(value_range) and value_range > 0:
        # Normalize TFR data to [0, 255]
        pixels = ((tfr_data - lowest) / value_range * 255).astype(np.uint8)
    else:
        # A constant (or NaN/inf-contaminated) map has no usable range: dividing
        # by it yields NaN/inf, and casting those to uint8 is undefined. Emit an
        # all-black image instead and say so.
        print(f"Warning: {filename} has a flat or non-finite value range; saving a blank image.")
        pixels = np.zeros(tfr_data.shape, dtype=np.uint8)

    img = Image.fromarray(pixels)
    img = img.resize((224, 224), Image.LANCZOS)
    img.save(os.path.join(category_dir, f"{filename}.png"))

    # Cleanup: drop the image buffers promptly; this runs once per channel over
    # a long batch.
    del img, pixels
    gc.collect()

def generate_tfrs(csv_file):
    """Convert one downsampled EEG CSV into per-channel time-frequency images.

    Steps:
      1. Skip the file if its name is already in ``processed_files``.
      2. Parse patient and trial numbers from ``downsampled_<p>_part_<t>.csv``.
      3. Look up the class label in ``demo`` (row ``patient - 1``, column
         position 1; a truthy value means 'schizophrenic', otherwise 'healthy').
      4. For each value of the ``condition`` column, build an MNE RawArray from
         all but the last four columns, keep 19 named channels, run ``ssq_cwt``
         on each channel and save the magnitude of its first output through
         ``save_tfr_image``.
      5. Append the file name to the resume log, but only when every channel
         succeeded, so files with failed channels are retried on the next run.

    Args:
        csv_file: File name (not a path) inside ``directory``.

    Returns:
        None. Progress and problems are reported with ``print``.
    """
    global count

    # Skip already processed files
    if csv_file in processed_files:
        print(f"Skipping already processed file: {csv_file}")
        return

    print(f"Processing file: {csv_file}")
    # NOTE(review): the total (1660) is hard-coded in the progress message.
    print(f"Count: {count}/1660")
    count += 1

    # Extract patient and trial numbers from filename
    match = re.search(r"downsampled_(\d+)_part_(\d+)\.csv", csv_file)
    if not match:
        print(f"Skipping file {csv_file} due to incorrect format.")
        return
    patient_num = int(match.group(1))
    trial_num = int(match.group(2))

    # The label lookup is positional (row = patient - 1). Validate the range
    # first: patient 0 would silently wrap to the last row via iloc[-1], and a
    # number past the table would raise IndexError in the middle of the batch.
    if not 1 <= patient_num <= len(demo):
        print(f"Patient {patient_num} is not in the demographic table, skipping {csv_file}.")
        return
    # Column position 1 is presumably `group` — TODO confirm against the CSV.
    category = 'schizophrenic' if demo.iloc[patient_num - 1, 1] else 'healthy'

    file_path = os.path.join(directory, csv_file)

    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        print(f"Error reading {csv_file}: {e}, skipping.")
        return

    # Ensure 'condition' column exists
    if "condition" not in df.columns:
        print(f"Condition column not found in {csv_file}, skipping file.")
        return

    # All but the last four columns are treated as channel data
    # (assumes the trailing four are metadata — TODO confirm).
    channel_names = list(df.columns[:-4])
    kept_channels = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3', 'Cz',
                     'C4', 'T8', 'TP7', 'P3', 'Pz', 'P4', 'TP8', 'O1', 'O2']
    failed_channels = 0

    # Group data by condition
    for condition_num, group in df.groupby("condition"):
        data = group.iloc[:, :-4].values.T.astype(np.float32)  # Transpose to (channels, time)
        print(f"Processing Condition: {condition_num}, Data Shape: {data.shape}")

        info = mne.create_info(ch_names=channel_names, sfreq=128, ch_types='eeg')
        raw = mne.io.RawArray(data, info)
        raw.pick(kept_channels)
        data = raw.get_data()
        print(f"Updated : {data.shape}")

        for k in range(data.shape[0]):  # Iterate over channels
            channel_data = data[k, :]

            # Compute synchrosqueezing transform
            try:
                Twxo, Wxo, *_ = ssq_cwt(channel_data, astensor=True)
                # With the GPU backend the result is a tensor that must be
                # moved to host memory; otherwise it is already array-like.
                # Handling both keeps a CPU-only run from failing on `.cpu()`
                # for every single channel.
                if hasattr(Twxo, "cpu"):
                    Twxo = Twxo.cpu().data.numpy()
                window_tfr = np.abs(Twxo)
            except Exception as e:
                print(f"Error processing TFR for {csv_file}, channel {k}: {e}")
                failed_channels += 1
                continue

            # Generate filename with patient, trial, condition, and channel
            filename = f"sub{patient_num}_trial{trial_num}_cond{condition_num}_ch{k}"

            save_tfr_image(window_tfr, category, filename, output_directory)

            # Clear memory after each channel
            del Twxo, Wxo, window_tfr, channel_data
            gc.collect()

    if failed_channels:
        # Leave the file out of the resume log so its missing images are
        # regenerated on the next run instead of being lost silently.
        print(f"{failed_channels} channel(s) failed for {csv_file}; not marking it as processed.")
    else:
        # Append to processed files list (same encoding as when the log is read)
        with open(processed_file_path, "a", encoding="utf-8") as f:
            f.write(csv_file + "\n")

        processed_files.add(csv_file)

    # Clear memory after each file
    del df
    gc.collect()

# Process each CSV file, reporting process memory after every file.
for csv_file in csv_files:
    generate_tfrs(csv_file)
    memory_usage()

# `count` starts at 1 and is incremented once per file that was started, so the
# number of files handled in this run is count - 1.
print(count - 1)
print("Completed")

Processing file: downsampled_10_part_16.csv
Count: 1/1660
Processing Condition: 1.0, Data Shape: (70, 384)
Creating RawArray with float64 data, n_channels=70, n_times=384
    Range : 0 ... 383 =      0.000 ...     2.992 secs
Ready.
Updated : (19, 384)
Processing Condition: 3.0, Data Shape: (70, 384)
Creating RawArray with float64 data, n_channels=70, n_times=384
    Range : 0 ... 383 =      0.000 ...     2.992 secs
Ready.
Updated : (19, 384)
Memory Usage: 1404.36 MB
Processing file: downsampled_10_part_26.csv
Count: 2/1660
Processing Condition: 1.0, Data Shape: (70, 384)
Creating RawArray with float64 data, n_channels=70, n_times=384
    Range : 0 ... 383 =      0.000 ...     2.992 secs
Ready.
Updated : (19, 384)
Processing Condition: 2.0, Data Shape: (70, 384)
Creating RawArray with float64 data, n_channels=70, n_times=384
    Range : 0 ... 383 =      0.000 ...     2.992 secs
Ready.
Updated : (19, 384)
Processing Condition: 3.0, Data Shape: (70, 384)
Creating RawArray with float64 dat