In [None]:
%pip install numpy pandas 
%pip install mne
%pip install plotly nbformat>=4.2.0
%pip install tensorflowimport os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import mne

In [None]:
# Raise the threshold of TensorFlow's Python logger (tf.get_logger) to ERROR
# so that only error records are reported through it.
tf.get_logger().setLevel(level='ERROR')

class Config:
    """Hyper-parameters for EEG loading, segmentation, training and authentication."""

    # Dataset
    DATA_PATH = "eeg_data/"        # Local path or download location
    CHANNELS = ['Oz', 'T7', 'Cz']  # The 3 channels used; order matters (row 0 is the Gram-Schmidt reference)
    SFREQ = 160.0                  # Target sampling frequency (Hz)

    # Segmentation parameters (Table 1 of the reference paper)
    T = 160           # Window length (samples) -> 1.0 s at SFREQ
    ETA = 20          # Number of overlapping segments per input image
    DELTA_STRIDE = 4  # Stride between consecutive segments (delta)

    # The "sampling window" F is the total span required to build one input:
    # F = (ETA - 1) * DELTA_STRIDE + T = (19 * 4) + 160 = 236 samples.
    # Derived rather than hard-coded so it stays consistent if ETA/T/delta change.
    F_SAMPLING_WINDOW = (ETA - 1) * DELTA_STRIDE + T

    # Data augmentation stride (big Delta): step between consecutive inputs
    AUGMENTATION_STRIDE = 8

    # Training
    BATCH_SIZE = 64
    EPOCHS = 30
    LR = 0.0001
    DROPOUT = 0.25

    # Authentication
    THRESHOLD = 0.15  # Placeholder (the paper suggests finding it via EER)

In [None]:
def gram_schmidt_orthogonalization(data):
    """Orthogonalize the channels (rows) of ``data`` with classical Gram-Schmidt.

    Row 0 (Oz for the default Config.CHANNELS) is kept unchanged. Every later
    row k has its projections onto the already-orthogonalized rows 0..k-1
    removed, each projection computed against the ORIGINAL row k (classical,
    not modified, Gram-Schmidt). Rows are not rescaled to unit length. A
    reference row with zero norm contributes no projection.

    Args:
        data: 2-D array of shape (n_channels, n_samples). Any number of
            channels is supported (the original version handled exactly 3).

    Returns:
        Array of the same shape. Floating input keeps its dtype; integer
        input is promoted to float64 so projections are not truncated.
    """
    data = np.asarray(data)
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.inexact) else np.float64
    orthogonalized = np.zeros_like(data, dtype=out_dtype)

    for k in range(data.shape[0]):
        v = data[k]
        for j in range(k):
            basis = orthogonalized[j]
            norm_sq = np.dot(basis, basis)
            # A zero basis vector has no direction to project onto.
            if norm_sq != 0:
                v = v - (np.dot(data[k], basis) / norm_sq) * basis
        orthogonalized[k] = v

    return orthogonalized

def preprocess_signal(raw_data):
    """Min-max scale every channel, then orthogonalize the channels.

    Args:
        raw_data: array of shape (n_channels, n_samples); statistics are taken
            along axis 1, i.e. separately for each channel.

    Returns:
        gram_schmidt_orthogonalization applied to the scaled signal. A flat
        channel (max == min) is divided by 1 instead of 0 and so becomes all
        zeros before orthogonalization.
    """
    channel_min = np.min(raw_data, axis=1, keepdims=True)
    channel_range = np.max(raw_data, axis=1, keepdims=True) - channel_min
    # Guard flat channels against 0/0.
    channel_range[channel_range == 0] = 1.0
    scaled = (raw_data - channel_min) / channel_range
    return gram_schmidt_orthogonalization(scaled)

def create_inputs(raw_data, window_len=None, eta=None, delta=None, aug_stride=None):
    """Slice a (channels, samples) signal into stacked overlapping-segment inputs.

    One input spans F = (eta - 1) * delta + window_len consecutive samples and
    is made of ``eta`` segments of ``window_len`` samples, segment i starting
    at offset i * delta inside that span. Consecutive inputs start
    ``aug_stride`` samples apart (augmentation by overlap).

    Args:
        raw_data: array of shape (n_channels, n_samples).
        window_len, eta, delta, aug_stride: segmentation parameters. When None
            they default to Config.T, Config.ETA, Config.DELTA_STRIDE and
            Config.AUGMENTATION_STRIDE, read at call time.

    Returns:
        Array of shape (n_inputs, eta, window_len, n_channels). If the signal
        is shorter than one span, an empty array with that same trailing shape
        is returned (rather than a shape-(0,) array).
    """
    window_len = Config.T if window_len is None else window_len
    eta = Config.ETA if eta is None else eta
    delta = Config.DELTA_STRIDE if delta is None else delta
    aug_stride = Config.AUGMENTATION_STRIDE if aug_stride is None else aug_stride

    n_channels, n_total_samples = raw_data.shape
    # For the default Config this equals Config.F_SAMPLING_WINDOW (236);
    # deriving it guarantees every segment is full length.
    span = (eta - 1) * delta + window_len

    inputs = []
    for start in range(0, n_total_samples - span + 1, aug_stride):
        block = raw_data[:, start:start + span]
        # Each segment is transposed to (window_len, n_channels).
        segments = [block[:, i * delta:i * delta + window_len].T for i in range(eta)]
        inputs.append(np.array(segments))

    if not inputs:
        return np.empty((0, eta, window_len, n_channels), dtype=raw_data.dtype)
    return np.array(inputs)

# ==========================================
# DATASET LOADING (eegbci / EEGMMIDB, run 1)
# ==========================================
def load_dataset(num_subjects=10):
    """Load, preprocess and segment EEG run 1 for subjects 1..num_subjects.

    For each subject, the run-1 EDF file is obtained through MNE's eegbci
    helper (stored under Config.DATA_PATH). Its channel names are cleaned and
    it is restricted to Config.CHANNELS in exactly that order. It is resampled
    to Config.SFREQ when needed, then passed through preprocess_signal and
    create_inputs. Any subject that fails is reported and skipped.

    Args:
        num_subjects: number of consecutive subject ids to load, starting at 1.

    Returns:
        (X, y): X concatenates every subject's create_inputs output along
        axis 0, and y holds the matching 0-based labels (subject_id - 1).
        Returns (None, None) when no subject produced any input.
    """
    import mne.datasets.eegbci as eegbci

    X_all = []
    y_all = []

    print(f"Loading {num_subjects} subjects...")

    for subject_id in range(1, num_subjects + 1):
        try:
            path_list = eegbci.load_data(subject_id, [1], path=Config.DATA_PATH, update_path=False)
            if not path_list:
                print(f"  Skipping Subject {subject_id}: Download failed.")
                continue
            raw = mne.io.read_raw_edf(path_list[0], preload=True, verbose='ERROR')

            # EEGMMIDB channel names often carry trailing dots ('Oz.', 'T7.');
            # strip them so they match Config.CHANNELS.
            raw.rename_channels(lambda name: name.strip('.'))

            available_channels = set(raw.ch_names)
            missing = [ch for ch in Config.CHANNELS if ch not in available_channels]
            if missing:
                print(f"  Subject {subject_id} missing channels: {missing}. Skipping.")
                continue

            raw.pick(Config.CHANNELS)
            # Pin the row order to Config.CHANNELS. Gram-Schmidt (inside
            # preprocess_signal) uses row 0 as the reference channel, so the
            # order must not depend on how channels are stored in the file.
            raw.reorder_channels(Config.CHANNELS)

            if raw.info['sfreq'] != Config.SFREQ:
                raw.resample(Config.SFREQ, verbose='ERROR')

            data = raw.get_data()
            inputs = create_inputs(preprocess_signal(data))

            if len(inputs) > 0:
                X_all.append(inputs)
                # 0-indexed labels, as expected by sparse categorical crossentropy.
                y_all.append(np.full(len(inputs), subject_id - 1))
                print(f"  Subject {subject_id}: {len(inputs)} samples loaded.")
            else:
                print(f"  Skipping Subject {subject_id}: recording shorter than one sampling window.")

        except Exception as e:
            # Best effort: one bad subject must not abort the whole load.
            print(f"  Failed to load Subject {subject_id}: {e}")

    if not X_all:
        return None, None

    return np.concatenate(X_all), np.concatenate(y_all)

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from sklearn.utils import shuffle
from sklearn.metrics import roc_curve
import mne

# Raise the threshold of TensorFlow's Python logger (tf.get_logger) to ERROR
# so that only error records are reported through it.
tf.get_logger().setLevel(level='ERROR')

# --- CONFIGURATION ---
class Config:
    """Hyper-parameters for this cell; mirrors the earlier Config definition.

    Re-declaring the class rebinds the global name ``Config``. Functions that
    reference ``Config`` in their bodies look it up at call time, so they use
    these values once this cell has run. Keep both definitions in sync.
    """
    DATA_PATH = "eeg_data/"
    CHANNELS = ['Oz', 'T7', 'Cz']
    SFREQ = 160.0
    T = 160
    ETA = 20
    DELTA_STRIDE = 4
    # Total samples spanned by one input: (ETA - 1) * DELTA_STRIDE + T = 236.
    F_SAMPLING_WINDOW = (ETA - 1) * DELTA_STRIDE + T
    AUGMENTATION_STRIDE = 8
    BATCH_SIZE = 64
    EPOCHS = 30
    LR = 0.0001
    DROPOUT = 0.25
    # Kept so that code reading Config.THRESHOLD still works after this
    # class replaces the earlier Config.
    THRESHOLD = 0.15

# --- HELPER FUNCTIONS ---
def cosine_distance(v1, v2):
    """Return 1 minus the cosine similarity of two vectors.

    Each vector is divided by (its L2 norm + 1e-10) before the dot product, so
    a zero vector yields a distance of 1.0 instead of dividing by zero.
    Parallel vectors give ~0 and opposite vectors give ~2.
    """
    def _unit(vec, eps=1e-10):
        # eps keeps the division finite for all-zero vectors.
        return vec / (np.linalg.norm(vec) + eps)

    return 1.0 - np.dot(_unit(v1), _unit(v2))

def find_optimal_threshold(gen_scores, imp_scores, default_threshold=0.5):
    """Return the distance threshold at the Equal Error Rate (EER) point.

    Distances are "lower = more similar", so they are negated before being
    passed to ``roc_curve``, which treats higher scores as more positive. The
    selected ROC threshold is then negated back into a distance.

    Args:
        gen_scores: distances of genuine probes to the enrolled template.
        imp_scores: distances of impostor probes to the enrolled template.
        default_threshold: returned when the ROC cannot be computed. Cosine
            distance lies in [0, 2], so a fallback >= 2 would accept every
            probe; it must lie inside that range to mean anything.

    Returns:
        Float threshold. A probe is accepted when its distance <= threshold,
        mirroring roc_curve's ``score >= threshold`` rule after negation.
    """
    # Both classes are required for a ROC curve.
    if len(gen_scores) == 0 or len(imp_scores) == 0:
        print(f"  [Warning] Not enough data to compute ROC. Using default threshold {default_threshold}")
        return default_threshold

    y_true = [1] * len(gen_scores) + [0] * len(imp_scores)
    y_scores = [-s for s in gen_scores] + [-s for s in imp_scores]

    try:
        fpr, tpr, thresholds = roc_curve(y_true, y_scores)
        fnr = 1 - tpr
        eer_index = int(np.nanargmin(np.abs(fnr - fpr)))
        optimal_threshold = float(-thresholds[eer_index])
        if not np.isfinite(optimal_threshold):
            # thresholds[0] is a sentinel that predicts no positives (np.inf
            # in recent scikit-learn); it is not a usable operating point.
            raise ValueError("EER point landed on the sentinel ROC threshold")
        # On a discrete ROC, fpr and fnr are rarely exactly equal; report
        # their mean as the EER estimate.
        eer = (fpr[eer_index] + fnr[eer_index]) / 2

        print(f"  [Auto-Tuning] Found Optimal Threshold: {optimal_threshold:.4f}")
        print(f"  [Auto-Tuning] Estimated EER: {eer*100:.2f}%")
        return optimal_threshold
    except Exception as e:
        print(f"  [Warning] ROC Calculation failed ({e}). Using default {default_threshold}")
        return default_threshold

# --- MODEL DEFINITION ---
def build_paper_model(num_classes):
    """Build the CNN classifier whose penultimate layer serves as the fingerprint.

    Input shape is (Config.ETA, Config.T, len(Config.CHANNELS)). The network
    has three 3x3 ReLU conv stages (128, 256, 512 filters; 2x2 max-pooling
    after the first two), then Flatten, then a 1024-unit ReLU Dense layer
    named "fingerprint_layer". Dropout follows, then a softmax over
    ``num_classes``.
    """
    inputs = layers.Input(shape=(Config.ETA, Config.T, len(Config.CHANNELS)))
    x = inputs
    for filters, pool_after in ((128, True), (256, True), (512, False)):
        x = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
        if pool_after:
            x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    # Embedding that is later pulled out by name as the biometric fingerprint.
    fingerprint = layers.Dense(1024, activation='relu', name="fingerprint_layer")(x)
    x = layers.Dropout(Config.DROPOUT)(fingerprint)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs=inputs, outputs=outputs)

# --- MAIN EXECUTION ---
if __name__ == "__main__":
    # Requested subject counts; the split shrinks automatically if fewer load.
    REQ_TRAIN = 50
    REQ_TEST = 10
    # Smallest usable split. A 1-class softmax always outputs 1.0 (zero loss,
    # nothing learned), and authentication needs a genuine and an impostor
    # test subject.
    MIN_TRAIN = 2
    MIN_TEST = 2
    SEED = 42

    # Seed Python, NumPy and TensorFlow so weight init, dropout and Keras'
    # batch shuffling are repeatable.
    tf.keras.utils.set_random_seed(SEED)

    # NOTE: inside an IPython/Jupyter kernel, exit() does not stop the running
    # cell (it only requests a kernel shutdown), so early termination below
    # uses `raise SystemExit`, which stops both a cell and a script at once.

    print("=== Phase 1: Loading Data ===")
    # load_dataset is defined in an earlier cell, which must have been run.
    try:
        X_all, y_all = load_dataset(num_subjects=REQ_TRAIN + REQ_TEST)
    except NameError:
        print("CRITICAL: 'load_dataset' function not found. Please run the cell containing 'load_dataset' definition first.")
        raise SystemExit(1)

    if X_all is None:
        print("No data loaded. Exiting.")
        raise SystemExit(1)

    # 1. Shuffle samples (labels move with them) so Keras' validation_split,
    #    which takes the last samples, and the enrollment/probe split below
    #    both see a random order.
    X_all, y_all = shuffle(X_all, y_all, random_state=SEED)

    # 2. Count the subjects that actually loaded (load_dataset skips failures).
    unique_subjects = np.unique(y_all)  # sorted ascending -> deterministic split
    num_loaded = len(unique_subjects)
    print(f"\n--- DATA STATUS ---")
    print(f"Requested: {REQ_TRAIN} Train + {REQ_TEST} Test = {REQ_TRAIN + REQ_TEST}")
    print(f"Actually Loaded: {num_loaded} Subjects")

    min_required = MIN_TRAIN + MIN_TEST
    if num_loaded < min_required:
        print(f"Error: Need at least {min_required} subjects ({MIN_TRAIN} train + {MIN_TEST} test) to run. Exiting.")
        raise SystemExit(1)

    # 3. Dynamic split: keep MIN_TEST subjects for testing, train on the rest.
    if num_loaded < (REQ_TRAIN + REQ_TEST):
        print("Warning: Fewer subjects loaded than requested.")
        N_TEST = MIN_TEST
        N_TRAIN = num_loaded - N_TEST
        print(f"Adjusting split to: {N_TRAIN} Train, {N_TEST} Test")
    else:
        N_TRAIN = REQ_TRAIN
        N_TEST = REQ_TEST

    # 4. Subject-disjoint split: test subjects are never seen during training.
    train_ids = unique_subjects[:N_TRAIN]
    test_ids = unique_subjects[N_TRAIN:N_TRAIN + N_TEST]

    print(f"Training on Subjects: {train_ids}")
    print(f"Testing on Subjects: {test_ids}")

    train_mask = np.isin(y_all, train_ids)
    test_mask = np.isin(y_all, test_ids)

    X_train, y_train = X_all[train_mask], y_all[train_mask]
    X_test, y_test = X_all[test_mask], y_all[test_mask]

    # Remap training labels to 0..N_TRAIN-1 for the softmax head.
    map_lbl = {old: new for new, old in enumerate(train_ids)}
    y_train_map = np.array([map_lbl[y] for y in y_train])

    print(f"\n=== Phase 2: Training Proxy Classifier ===")
    model = build_paper_model(num_classes=len(train_ids))
    model.compile(optimizer=optimizers.RMSprop(learning_rate=Config.LR),
                  loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    history = model.fit(X_train, y_train_map, epochs=Config.EPOCHS, batch_size=Config.BATCH_SIZE, validation_split=0.1, verbose=1)

    print("\n=== Phase 3: Extracting Fingerprinter ===")
    fingerprint_model = models.Model(inputs=model.input, outputs=model.get_layer("fingerprint_layer").output)

    print("\n=== Phase 4: Universal Authentication & Threshold Tuning ===")

    if len(test_ids) < 2:
        print("Not enough test subjects found for authentication.")
    else:
        user_a, user_b = test_ids[0], test_ids[1]
        print(f"Scenario: Genuine User {user_a} vs Impostor User {user_b}")

        data_a = X_test[y_test == user_a]
        data_b = X_test[y_test == user_b]

        if len(data_a) < 10:
            print(f"Warning: User {user_a} has very little data ({len(data_a)} samples). Results may be unstable.")

        # Split A into enrollment (first half) and genuine probes (second half).
        split = len(data_a) // 2
        if split == 0:
            print("Not enough data to split for enrollment. Skipping auth test.")
        else:
            enroll_a = data_a[:split]
            probe_a = data_a[split:]
            probe_b = data_b

            # 1. Template = mean fingerprint of the enrollment samples.
            enroll_fps = fingerprint_model.predict(enroll_a, verbose=0)
            template_a = np.mean(enroll_fps, axis=0)

            # 2. Distances of genuine and impostor probes to the template.
            gen_fps = fingerprint_model.predict(probe_a, verbose=0)
            imp_fps = fingerprint_model.predict(probe_b, verbose=0)

            gen_scores = [cosine_distance(template_a, fp) for fp in gen_fps]
            imp_scores = [cosine_distance(template_a, fp) for fp in imp_fps]

            # 3. EER threshold. NOTE: it is tuned on the same probes it is
            #    evaluated on below, so the reported rates are optimistic.
            best_threshold = find_optimal_threshold(gen_scores, imp_scores)

            # 4. Accept when distance <= threshold. roc_curve counts a score
            #    as positive when score >= threshold; with negated distances
            #    that is distance <= best_threshold, so '<=' reproduces the
            #    tuned EER operating point.
            accepted_gen = sum(1 for s in gen_scores if s <= best_threshold)
            rejected_imp = sum(1 for s in imp_scores if s > best_threshold)

            print(f"\n--- Final Results (Threshold {best_threshold:.4f}) ---")
            print(f"Genuine Acceptance Rate (GAR): {accepted_gen/len(gen_scores)*100:.1f}%")
            print(f"Impostor Rejection Rate (IRR): {rejected_imp/len(imp_scores)*100:.1f}%")