In [6]:
import os
import re
from glob import glob
from typing import List, Tuple

import numpy as np
import pandas as pd

# -----------------------------
# USER CONFIG
# -----------------------------
INPUT_FILES = None
FOLDER      = "/content/data"
FEATURE     = "magnitude"
ROWS_PER_SAMPLE = 16
OUT_X = "X_real_256.csv" if FEATURE == "magnitude" else "X_real_512.csv"
OUT_y = "y_real_256.csv" if FEATURE == "magnitude" else "y_real_512.csv"

# -----------------------------
# HELPERS
# -----------------------------
COMPLEX_REGEX = re.compile(
    r"^\(?\s*-?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?\s*[+-]\s*\d+(?:\.\d+)?(?:[eE][+-]?\d+)?j\s*\)?$"
)

def is_complex_str(x: str) -> bool:
    if not isinstance(x, str):
        return False
    return COMPLEX_REGEX.match(x.strip()) is not None

def parse_complex(x: str) -> complex:
    """Parse '(a+bj)' or 'a-bj' -> complex(a, b)."""
    if isinstance(x, complex):
        return x
    if not isinstance(x, str):
        return complex(np.nan, np.nan)
    s = x.strip()
    if s.startswith("(") and s.endswith(")"):
        s = s[1:-1]
    try:
        return complex(s.replace(" ", ""))
    except Exception:
        return complex(np.nan, np.nan)

def detect_measurement_columns(df: pd.DataFrame) -> List[str]:
    """Columns where most sampled values look like complex strings."""
    candidates = []
    for c in df.columns:
        series = df[c].dropna().astype(str)
        if len(series) == 0:
            continue
        sample = series.sample(n=min(50, len(series)), random_state=42)
        if sample.apply(is_complex_str).mean() > 0.7:
            candidates.append(c)
    return candidates

def coerce_label_binary(r_value) -> int:
    """Map r==0 -> 0 (isotropic), else -> 1 (anisotropic)."""
    try:
        val = float(r_value)
    except Exception:
        return None
    return 0 if abs(val) < 1e-9 else 1

def block_to_features(meas_block: pd.DataFrame, feature: str = "magnitude") -> np.ndarray:
    """
    Convert ~16 rows x ~16 electrode columns into features.
    - If 'magnitude': returns 16x16 magnitudes flattened (256 features).
    - If 'realimag': returns [real; imag] flattened (512 features).
    """
    meas_cols = detect_measurement_columns(meas_block)
    if len(meas_cols) < 16:
        # fallback: take the last 16 columns
        meas_cols = list(meas_block.columns[-16:])

    C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()

    # Pad/trim rows to 16
    if C.shape[0] < 16:
        pad_rows = 16 - C.shape[0]
        C = np.vstack([C, np.full((pad_rows, C.shape[1]), np.nan+1j*np.nan)])
    elif C.shape[0] > 16:
        C = C[:16, :]

    # Pad/trim cols to 16
    if C.shape[1] < 16:
        pad_cols = 16 - C.shape[1]
        C = np.hstack([C, np.full((C.shape[0], pad_cols), np.nan+1j*np.nan)])
    elif C.shape[1] > 16:
        C = C[:, :16]

    if feature == "magnitude":
        M = np.abs(C).astype(float)
        col_meds = np.nanmedian(M, axis=0)
        inds = np.where(np.isnan(M))
        if inds[0].size > 0:
            M[inds] = np.take(col_meds, inds[1])
        return M.flatten().astype(np.float32)

    elif feature == "realimag":
        R = np.real(C).astype(float)
        I = np.imag(C).astype(float)
        for A in (R, I):
            col_meds = np.nanmedian(A, axis=0)
            inds = np.where(np.isnan(A))
            if inds[0].size > 0:
                A[inds] = np.take(col_meds, inds[1])
        return np.concatenate([R.flatten(), I.flatten()]).astype(np.float32)

    else:
        raise ValueError("feature must be 'magnitude' or 'realimag'")

def preprocess_files(file_paths: List[str],
                     feature: str = "magnitude",
                     rows_per_sample: int = 16) -> Tuple[np.ndarray, np.ndarray]:
    X_list, y_list = [], []

    for fp in file_paths:
        df = pd.read_csv(fp, header=0)
        df = df.dropna(axis=1, how="all")

        if "r" not in df.columns:
            raise ValueError(f"'r' column not found in {fp}")

        y_file = coerce_label_binary(df["r"].iloc[0])
        if y_file is None:
            raise ValueError(f"Could not parse label from 'r' in {fp}")

        # Identify a frame/sample grouping column if present
        frame_col = next((c for c in ["frame_id", "frame", "sample_id", "n_sample", "sample_idx"] if c in df.columns), None)

        if frame_col is not None:
            for _, g in df.groupby(frame_col):
                for start in range(0, len(g), rows_per_sample):
                    block = g.iloc[start:start+rows_per_sample]
                    if len(block) == 0:
                        continue
                    x = block_to_features(block, feature=feature)
                    X_list.append(x)
                    y_list.append(y_file)
        else:
            for start in range(0, len(df), rows_per_sample):
                block = df.iloc[start:start+rows_per_sample]
                if len(block) == 0:
                    continue
                x = block_to_features(block, feature=feature)
                X_list.append(x)
                y_list.append(y_file)

    if not X_list:
        nfeat = 256 if feature == "magnitude" else 512
        return np.empty((0, nfeat), dtype=np.float32), np.empty((0,), dtype=np.int64)

    X = np.vstack(X_list).astype(np.float32)
    y = np.array(y_list, dtype=np.int64)
    return X, y

# -----------------------------
# RUN
# -----------------------------
if INPUT_FILES is None:
    file_paths = sorted(glob(os.path.join(FOLDER, "r_*_mm.csv")))
else:
    file_paths = INPUT_FILES

if not file_paths:
    raise FileNotFoundError("No input CSVs found. Set INPUT_FILES or place r_*_mm.csv in FOLDER.")

print("[INFO] Using files:")
for p in file_paths:
    print(" -", p)

X, y = preprocess_files(file_paths, feature=FEATURE, rows_per_sample=ROWS_PER_SAMPLE)

# Save
pd.DataFrame(X).to_csv(OUT_X, header=False, index=False)
pd.Series(y).to_csv(OUT_y, header=False, index=False)

# Report
counts = np.bincount(y) if y.size else np.array([])
print(f"[DONE] X shape: {X.shape} | y shape: {y.shape} | class counts: {counts}")
print(f"[SAVE] Features -> {OUT_X}")
print(f"[SAVE] Labels   -> {OUT_y}")


[INFO] Using files:
 - /content/data/r_0_mm.csv
 - /content/data/r_10_mm.csv
 - /content/data/r_15_mm.csv
 - /content/data/r_20_mm.csv
 - /content/data/r_25_mm.csv
 - /content/data/r_5_mm.csv


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(str).applymap(parse_complex).to_numpy()
  C = meas_block[meas_cols].astype(

[DONE] X shape: (3040, 256) | y shape: (3040,) | class counts: [  40 3000]
[SAVE] Features -> X_real_256_bulk.csv
[SAVE] Labels   -> y_real_256_bulk.csv
