In [1]:
import pandas as pd
import numpy as np
import torch
import jax
import sklearn

# Trying To load T15 Data

In [2]:
import os
import h5py

ls = []
parent_folder = "/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
if not os.path.exists(parent_folder):
    print("Path not found!")
else:
    for root, dirs, files in os.walk(parent_folder):
        for file in files:
            if file.endswith(".hdf5"):
                file_path = os.path.join(root,file)
                ls.append(file_path)

In [3]:
ls

['/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.11/data_train.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.13/data_test.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.13/data_train.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.13/data_val.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.18/data_test.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.18/data_train.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.18/data_val.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final/t15.2023.08.20/data_test.hdf5',
 '/mnt/c/Users/Siddh/Datasets/brain-to-text-25/

### Lets check whats within these hdf5 files...|

In [4]:
import h5py

sample_file = ls[0]

with h5py.File(sample_file, 'r') as f:
    print(f"Keys in file: {list(f.keys())}")
    
    # Let's look at the shape of the data inside a key
    for key in f.keys():
        item = f[key]
        if isinstance(item, h5py.Dataset):
            print(f"Key: {key} | Shape: {item.shape} | Type: {item.dtype}")
        else:
            # print(f"Key: {key} is a Group (folder)")
            continue

Keys in file: ['trial_0000', 'trial_0001', 'trial_0002', 'trial_0003', 'trial_0004', 'trial_0005', 'trial_0006', 'trial_0007', 'trial_0008', 'trial_0009', 'trial_0010', 'trial_0011', 'trial_0012', 'trial_0013', 'trial_0014', 'trial_0015', 'trial_0016', 'trial_0017', 'trial_0018', 'trial_0019', 'trial_0020', 'trial_0021', 'trial_0022', 'trial_0023', 'trial_0024', 'trial_0025', 'trial_0026', 'trial_0027', 'trial_0028', 'trial_0029', 'trial_0030', 'trial_0031', 'trial_0032', 'trial_0033', 'trial_0034', 'trial_0035', 'trial_0036', 'trial_0037', 'trial_0038', 'trial_0039', 'trial_0040', 'trial_0041', 'trial_0042', 'trial_0043', 'trial_0044', 'trial_0045', 'trial_0046', 'trial_0047', 'trial_0048', 'trial_0049', 'trial_0050', 'trial_0051', 'trial_0052', 'trial_0053', 'trial_0054', 'trial_0055', 'trial_0056', 'trial_0057', 'trial_0058', 'trial_0059', 'trial_0060', 'trial_0061', 'trial_0062', 'trial_0063', 'trial_0064', 'trial_0065', 'trial_0066', 'trial_0067', 'trial_0068', 'trial_0069', 'tria

### Shows that the HDF5 files are hierarchical i.e. Instead of one big block of data, your file is organized into "Groups" (which act like folders), where each group represents a single "Trial"

##### > File Level: Contains ~288 folders (`trial_0000`, etc.).
##### > Trial Level: Inside each `trial_XXXX` folder, you will find the actual Datasets (the neural arrays, the target text, etc.).

### Let us look into a single trial

In [5]:
import h5py

sample_file = ls[0]

with h5py.File(sample_file, 'r') as f:
    trial_group = f['trial_0000']
    
    print(f"--- Inspecting inside 'trial_0000' ---")
    print(f"Keys: {list(trial_group.keys())}")
    
    # Loop through the items in this trial to see their shapes
    for key in trial_group.keys():
        data_item = trial_group[key]
        
        # Check if it's actual data (Dataset) or another folder
        if isinstance(data_item, h5py.Dataset):
            # We want to see the shape (e.g., [Time, Channels]) and type
            print(f"  [DATASET] Name: {key:<20} | Shape: {data_item.shape} | Type: {data_item.dtype}")
            
            # If it's a small text label, let's print it to see what it says
            if data_item.size < 10 and (data_item.dtype.kind in 'SUa'): # String/Unicode types
                print(f"            Value: {data_item[()]}")
                
        else:
            print(f"  [GROUP]   Name: {key}")

--- Inspecting inside 'trial_0000' ---
Keys: ['input_features', 'seq_class_ids', 'transcription']
  [DATASET] Name: input_features       | Shape: (321, 512) | Type: float32
  [DATASET] Name: seq_class_ids        | Shape: (500,) | Type: int32
  [DATASET] Name: transcription        | Shape: (500,) | Type: int32


## The Challenge: "Trials inside Files"
### You cannot just pass the list of file paths to the generic PyTorch loader because one file contains multiple samples (trials). If you have 100 files and each has 200 trials, you actually have 20,000 samples. To tackle this we create a `Global Index Map`

In [6]:
import h5py

samples_index = []

for file_path in ls:
    try:
        with h5py.File(file_path, 'r') as f:
            # We just read the keys (trial names), we don't load the heavy data
            trial_names = list(f.keys())
            for t_name in trial_names:
                samples_index.append((file_path, t_name))
    except Exception as e:
        print(f"Skipping broken file: {file_path}")

print(f"Indexing complete!")
print(f"Total files: {len(ls)}")
print(f"Total individual trials (samples): {len(samples_index)}")

Indexing complete!
Total files: 127
Total individual trials (samples): 10948


### We load the given data into 3 different dataframes, which then we use later on as our reference dataframes

In [7]:
import h5py
import pandas as pd
import os

# ls = [ ... your list of file paths ... ]

manifest = []

print(f"Scanning {len(ls)} files using FILENAME logic...")

for file_path in ls:
    try:
        # Extract filename and folder info
        file_name = os.path.basename(file_path)
        path_parts = file_path.replace('\\', '/').split('/')
        session_date = next((p for p in path_parts if p.startswith('t15.')), "Unknown")
        
        # 1. DETERMINE TYPE BY FILENAME
        if "train" in file_name:
            split_type = "train"
        elif "val" in file_name:
            split_type = "validation"
        elif "test" in file_name:
            split_type = "test"
        else:
            split_type = "unknown"

        # 2. VERIFY CONTENTS (Do labels exist?)
        with h5py.File(file_path, 'r') as f:
            if len(f.keys()) == 0:
                continue
                
            first_trial = list(f.keys())[0]
            group = f[first_trial]
            
            has_labels = 'transcription' in group
            n_trials = len(f.keys())
            
            manifest.append({
                'file_path': file_path,
                'session_date': session_date,
                'filename': file_name,
                'split': split_type,      # 'train', 'val', 'test'
                'has_labels': has_labels, # True/False
                'n_trials': n_trials
            })

    except Exception as e:
        print(f"Error reading {file_path}: {e}")

# --- SUMMARY REPORT ---
df = pd.DataFrame(manifest)
df.to_csv("t15_split_manifest.csv", index=False)

print("\n--- FINAL SPLIT REPORT ---")
print(df.groupby(['split', 'has_labels'])['n_trials'].sum())

# CRITICAL CHECK: Does validation data have labels?
val_labels = df[df['split'] == 'validation']['has_labels'].all()
if val_labels:
    print("\n[SUCCESS] All Validation files have labels! You can calculate accuracy immediately.")
else:
    print("\n[WARNING] Some Validation files are missing labels. Check the CSV.")

Scanning 127 files using FILENAME logic...

--- FINAL SPLIT REPORT ---
split       has_labels
test        False         1450
train       True          8072
validation  True          1426
Name: n_trials, dtype: int64

[SUCCESS] All Validation files have labels! You can calculate accuracy immediately.


In [8]:
# 1. Load the Master Manifest we created in the previous step
manifest_path = "t15_split_manifest.csv"

if not os.path.exists(manifest_path):
    print(f"Error: '{manifest_path}' not found. Please run the Classification Script first.")
else:
    # Read the full list
    full_df = pd.read_csv(manifest_path)
    
    # 2. Filter into 3 DataFrames
    # We create copies (.copy()) so we can modify them later without warnings
    train_df = full_df[full_df['split'] == 'train'].copy()
    val_df   = full_df[full_df['split'] == 'validation'].copy()
    test_df  = full_df[full_df['split'] == 'test'].copy()

    # 3. Verify the Counts
    print("--- DATASET SPLIT REPORT ---")
    print(f"TRAIN Set : {len(train_df)} files")
    print(f"VAL Set   : {len(val_df)} files")
    print(f"TEST Set  : {len(test_df)} files")
    print("-" * 30)
    print(f"TOTAL     : {len(full_df)} files")

    # 4. Quick Sanity Check
    # Ensure Train/Val actually have labels (should be True)
    train_has_labels = train_df['has_labels'].all()
    val_has_labels = val_df['has_labels'].all()
    
    if train_has_labels and val_has_labels:
        print("\n[OK] Integrity Check: All Train and Validation files have labels.")
    else:
        print("\n[WARNING] Some Train/Val files are missing labels! Check your manifest.")

--- DATASET SPLIT REPORT ---
TRAIN Set : 45 files
VAL Set   : 41 files
TEST Set  : 41 files
------------------------------
TOTAL     : 127 files

[OK] Integrity Check: All Train and Validation files have labels.


In [9]:
train_df

Unnamed: 0,file_path,session_date,filename,split,has_labels,n_trials
0,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.11,data_train.hdf5,train,True,288
2,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.13,data_train.hdf5,train,True,348
5,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.18,data_train.hdf5,train,True,197
8,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.20,data_train.hdf5,train,True,278
11,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.25,data_train.hdf5,train,True,88
14,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.08.27,data_train.hdf5,train,True,150
17,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.09.01,data_train.hdf5,train,True,297
20,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.09.03,data_train.hdf5,train,True,322
23,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.09.24,data_train.hdf5,train,True,245
26,/mnt/c/Users/Siddh/Datasets/brain-to-text-25/t...,t15.2023.09.29,data_train.hdf5,train,True,153


### Building a character vocabulary set, we'll use the train_df as reference to build it and identify every unique character in it

In [10]:
import pandas as pd
import h5py
import json
import os

train_files = train_df['file_path'].to_list()

print(f"Scanning {len(train_files)} training files to build the keyboard...")
unique_chars = set()

for file_path in train_files:
    try:
        with h5py.File(file_path, 'r') as f:
            for key in f.keys():
                group = f[key]
                
                # Check for the sentence text in attributes
                if 'sentence_label' in group.attrs:
                    sentence = group.attrs['sentence_label']
                    
                    # Convert bytes to string if needed
                    if isinstance(sentence, bytes):
                        sentence = sentence.decode('utf-8')
                        
                    # Add every character to our set
                    for char in sentence:
                        unique_chars.add(char)
                        
    except Exception as e:
        print(f"Skipping file: {e}")
# --- FORMATTING THE VOCAB ---
# Sort the list so 'a' always comes before 'b'
sorted_chars = sorted(list(unique_chars))

# Create the map: Char -> Number
# We start at 1 because 0 is usually reserved for the "Blank" token in CTC Loss
char_to_int = {char: idx + 1 for idx, char in enumerate(sorted_chars)}
int_to_char = {idx + 1: char for idx, char in enumerate(sorted_chars)}

# Add the special CTC Blank Token
char_to_int['<BLANK>'] = 0
int_to_char[0] = '<BLANK>'

# Save it!
vocab_data = {
    'char_to_int': char_to_int,
    'int_to_char': int_to_char,
    'n_classes': len(char_to_int) # This tells us how many output neurons we need
}

with open("t15_vocab.json", "w") as f:
    json.dump(vocab_data, f, indent=4)

print("\n--- KEYBOARD BUILT ---")
print(f"Found {len(sorted_chars)} unique characters.")
print(f"Total Model Output Size: {len(char_to_int)} (including Blank)")
print(f"Characters: {''.join(sorted_chars)}")

Scanning 45 training files to build the keyboard...

--- KEYBOARD BUILT ---
Found 62 unique characters.
Total Model Output Size: 63 (including Blank)
Characters:  !',-.;?ABCDEFGHIJKLMNOPQRSTUVWYZ[]abcdefghijklmnopqrstuvwxyzâ€™


In [12]:
# import torch
# print(torch.cuda.is_available())
# print(torch.cuda.device_count())
# print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")

## Complete Preprocesing Pipeline

In [14]:
# class T15Preprocessor:
#     def __init__(self, bin_size_ms=20, smooth_sigma_ms=80, threshold_rms=-4.5):
#         self.bin_size_ms = bin_size_ms
#         self.sigma = smooth_sigma_ms / bin_size_ms # Convert ms -> bins
#         self.threshold = threshold_rms

#     def __call__(self, raw_spikes):
#         n_bins = raw_spikes.shape[0] // self.bin_size_ms
#         binned = raw_spikes[:n_bins * self.bin_size_ms]
#         binned = binned.reshape(n_bins, self.bin_size_ms, -1).sum(axis=1)
#         mean = np.mean(binned, axis=0)
#         std = np.std(binned, axis=0)
#         std[std < 1e-6] = 1.0
#         normalized = (binned - mean) / std
#         normalized = np.clip(normalized, -10, 10)
#         smoothed = gaussian_filter1d(normalized, sigma=self.sigma, axis=0)

#         return smoothed.astype(np.float32)

In [15]:
import json
from scipy.ndimage import gaussian_filter1d
import torch
from torch.utils.data import Dataset

In [16]:
class T15Dataset(Dataset):
    def __init__(self, df, vocab_path="t15_vocab.json", smooth_sigma=4.0, clip_val=5.0, augment=False):
        self.df = df.reset_index(drop=True)
        self.smooth_sigma = smooth_sigma
        self.clip_val = clip_val
        self.augment = augment

        #Load Vocabulary - 
        with open(vocab_path, 'r') as f:
            data = json.load(f)
            self.char_to_int = data['char_to_int']
            self.blank_token = 0 #CTC requires a blank token at index 0 ,i.e. for unknown characters
    def __len__(self):
        return len(self.df)
        
    #Converts 'Hello' -> [8, 5, 12, 12, 15] based on vocab.
    def text_to_int(self, text):
        result = []
        for char in text:
            if char in self.char_to_int:
                result.append(self.char_to_int[char])
        return torch.LongTensor(result)

    # ARTIFACT REMOVAL (Dead/Saturated Channel Mask), Identify channels that are completely silent(std ~ 0) or screaming(std > 50) BEFORE normalization
    def preprocess_neural(self, neural_data):
        ch_std = np.std(neural_data, axis=0)
        dead_channels = ch_std < 0.01

        #Z-SCORE NORMALIZATION i.e. Normalize time-series to Mean=0, Std=1
        mean = np.mean(neural_data, axis=0)
        std = ch_std
        std[std == 0] = 1.0
        normalized = (neural_data - mean) / std
        
        #Apply the mask: Force dead channels to exactly 0.0
        normalized[:, dead_channels] = 0.0
        #CLIP OUTLIERS
        normalized = np.clip(normalized, -self.clip_val, self.clip_val)
        
        #TEMPORAL SMOOTHING - Convolve with Gaussian kernel to make signals RNN-friendly
        smoothed = gaussian_filter1d(normalized, sigma=self.smooth_sigma, axis=0)
        
        return torch.from_numpy(smoothed).float()

    #Data Augmentation (Training Only) - Adds Gaussian White Noise to make the model robust.    
    def apply_augmentation(self, neural_tensor):
        noise_level = 0.1
        noise = torch.randn_like(neural_tensor) * noise_level
        return neural_tensor + noise\
        
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_path = row['file_path']
        with h5py.File(file_path, 'r') as f:
            #Fix Binning 
            key = list(f.keys())[0]
            group = f[key]

            #Load & Preprocess Input
            raw_neural = group['input_features'][:]
            neural_tensor = self.preprocess_neural(raw_neural)

            #Augmentation (If enabled) - We will be enabling it for training and disabling it for validation
            if self.augment:
                neural_tensor = self.apply_augmentation(neural_tensor)
                
            #Load Target (Labels)
            if 'sentence_label' in group.attrs:
                sentence = group.attrs['sentence_label']
                if isinstance(sentence, bytes):
                    sentence = sentence.decode('utf-8')
                target_tensor = self.text_to_int(sentence)
            else:
                target_tensor = torch.from_numpy(group['transcription'][:]).long()
            # Lengths for CTC Loss
            input_len = neural_tensor.shape[0]
            target_len = len(target_tensor)
        return neural_tensor, target_tensor, input_len, target_len

### Creating two instances of this class. One for training (with noise) and one for validation (pure).

In [17]:
# 1. Training Dataset (With Augmentation enabled)
train_dataset = T15Dataset(
    df=train_df, 
    vocab_path="t15_vocab.json", 
    augment=True  # <--- Step 9: Enabled
)

# 2. Validation Dataset (Clean data only)
val_dataset = T15Dataset(
    df=val_df, 
    vocab_path="t15_vocab.json", 
    augment=False # <--- Disabled
)

print(f"Train Size: {len(train_dataset)}")
print(f"Val Size:   {len(val_dataset)}")

Train Size: 45
Val Size:   41


### `ctc_collate_fn` returns 4 things, not just 2: It essentially just performing padding of the translational sequences

1. `padded_neural`: The rectangular input data (for the GPU).

2. `padded_targets`: The rectangular labels (for the GPU).

3. `input_lens`: A list saying [200, 500, 1000] (The real lengths).

4. `target_lens`: A list saying [2, 11, 35] (The real character counts).

In [18]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
def ctc_collate_fn(batch):
    neural_tensors, target_tensors, input_lens, target_lens = zip(*batch)
    padded_neural = pad_sequence(neural_tensors, batch_first=True, padding_value=0.0)
    padded_targets = pad_sequence(target_tensors, batch_first=True, padding_value=-1)
    input_lens = torch.tensor(input_lens, dtype=torch.long)
    target_lens = torch.tensor(target_lens, dtype=torch.long)
    return padded_neural, padded_targets, input_lens, target_lens
train_dataset = T15Dataset(train_df, vocab_path="t15_vocab.json", augment=True)
val_dataset   = T15Dataset(val_df, vocab_path="t15_vocab.json", augment=False)

In [19]:
BATCH_SIZE = 64
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True,       # ALWAYS shuffle training data
    collate_fn=ctc_collate_fn,
    num_workers=2,      # Uses multi-core CPU to load files faster
    pin_memory=True     # Speeds up transfer to GPU
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False,      # Never shuffle validation (keep order consistent)
    collate_fn=ctc_collate_fn,
    num_workers=2,
    pin_memory=True
)

In [20]:
# Grab first batch
inputs, targets, in_lens, out_lens = next(iter(train_loader))

print("\n--- BATCH INSPECTION ---")
print(f"Input Shape (Batch, Max_Time, 512): {inputs.shape}")
print(f"Target Shape (Batch, Max_Seq_Len):  {targets.shape}")
print(f"Sample Input Lengths: {in_lens[:5].tolist()}")

# Check for padding (should see zeros at the end of the first sample if it's shorter than max)
if inputs.shape[1] > in_lens[0]:
    print("\n[OK] Padding detected (zeros found at end of sequence).")
else:
    print("\n[NOTE] First sequence was the longest, or batch sizes match exactly.")


--- BATCH INSPECTION ---
Input Shape (Batch, Max_Time, 512): torch.Size([45, 1350, 512])
Target Shape (Batch, Max_Seq_Len):  torch.Size([45, 55])
Sample Input Lengths: [921, 605, 586, 1259, 321]

[OK] Padding detected (zeros found at end of sequence).
