In [1]:
# import requests
# url = "https://dryad-assetstore-merritt-west.s3.us-west-2.amazonaws.com/ark%3A/13030/m53853vd%7C6%7Cproducer/competitionData.tar.gz?response-content-disposition=attachment%3B%20filename%3DcompetitionData.tar.gz&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIA2KERHV5E3OITXZXC%2F20240517%2Fus-west-2%2Fs3%2Faws4_request&X-Amz-Date=20240517T074950Z&X-Amz-Expires=86400&X-Amz-SignedHeaders=host&X-Amz-Signature=89d4fc9165caa9fe8baae163f05912df4836f7418c8e6902dc4ddf99da1cb1ac"
# local_filename = "competitionData.tar.gz"
# response = requests.get(url)
# if response.status_code == 200:
#     with open(local_filename, 'wb') as f:
#         f.write(response.content)
#     print(f"File '{local_filename}' has been downloaded successfully.")
# else:
#     print(f"Failed to download file. Status code: {response.status_code}")


In [2]:
# import tarfile
# filename = "competitionData.tar.gz"
# with tarfile.open(filename, 'r:gz') as tar:
#     tar.extractall()
# print("The.tar.gz file has been uncompressed.")

In [3]:
import os
import numpy as np
import scipy.io
import pandas as pd
import random
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from transformers.utils import ModelOutput
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5Config, T5EncoderModel
from pytorch_tcn import TCN
import editdistance

In [4]:
def load_and_process_data(directory, partitions):
    partitioned_data = {}

    # Store temporary data for all partitions
    all_features = []
    all_labels = []
    all_lengths = []
    all_partition_info = []

    for partition in partitions:
        partition_dir = os.path.join(directory, partition)
        files = sorted([f for f in os.listdir(partition_dir) if f.endswith('.mat')])

        input_features = []
        transcriptions = []
        frame_lens = []
        all_block_idxs = []

        for file in tqdm(files, desc=f"Loading and processing {partition}"):
            path = os.path.join(partition_dir, file)
            session_data = load_features_and_normalize(path)

            input_features.extend(session_data['inputFeatures'])
            transcriptions.extend(session_data['transcriptions'])
            frame_lens.extend(session_data['frameLens'])
            all_block_idxs.extend(session_data['blockIdx'])

        # Normalize each partition's features
        combined_features, block_means, block_stds = block_normalize(input_features, all_block_idxs)

        # Append partition data for global scaling and PCA
        all_features.append(combined_features)
        all_labels.extend(transcriptions)
        all_lengths.extend(frame_lens)
        all_partition_info.append((partition, len(transcriptions)))

    # Combine features from all partitions
    combined_features = np.vstack(all_features)

    # Global scaling
    scaler = RobustScaler()
    combined_features = scaler.fit_transform(combined_features)

    # PCA for dimensionality reduction
    pca = PCA(n_components=0.95)
    combined_features = pca.fit_transform(combined_features)

    # Distribute features back to partitions
    start_index = 0
    for partition, size in all_partition_info:
        end_index = start_index + size
        partitioned_data[partition] = {
            'features': combined_features[start_index:end_index],
            'label': all_labels[start_index:end_index],
            'sen_len': all_lengths[start_index:end_index]
        }
        start_index = end_index

    return partitioned_data

def block_normalize(input_features, block_idxs):
    unique_blocks = np.unique(block_idxs)
    combined_features = []
    
    for block in unique_blocks:
        block_indices = [i for i, x in enumerate(block_idxs) if x == block]
        block_feats = np.vstack([input_features[i] for i in block_indices])
        
        block_mean = np.mean(block_feats, axis=0)
        block_std = np.std(block_feats, axis=0)
        
        normalized_feats = [(input_features[i] - block_mean) / (block_std + 1e-8) for i in block_indices]
        combined_features.extend(normalized_feats)

    return np.vstack(combined_features), None, None

def load_features_and_normalize(sessionPath):
    dat = scipy.io.loadmat(sessionPath)

    input_features = []
    transcriptions = []
    frame_lens = []
    block_idxs = []
    n_trials = dat['sentenceText'].shape[0]
    blockIdx = np.squeeze(dat['blockIdx'])

    # Collect area 6v tx1 and spikePow features
    for i in range(n_trials):
        features = np.concatenate([dat['tx1'][0, i][:, 0:128], dat['spikePow'][0, i][:, 0:128]], axis=1)
        sentence_len = features.shape[0]
        sentence = dat['sentenceText'][i].strip().lower()

        input_features.append(features)
        transcriptions.append(sentence)
        frame_lens.append(sentence_len)
        block_idxs.append(blockIdx[i])

    session_data = {
        'inputFeatures': input_features,
        'transcriptions': transcriptions,
        'frameLens': frame_lens,
        'blockIdx': block_idxs
    }

    return session_data

def preview_data(results):
    for partition in results:
        print(f"--- {partition.upper()} Partition ---")
        features = results[partition]['features']
        transcriptions = results[partition]['label']
        frame_lengths = results[partition]['sen_len']
        
        # Creating DataFrames for better visual representation
        features_df = pd.DataFrame(features)
        transcriptions_df = pd.DataFrame(transcriptions, columns=['Transcription'])
        frame_lengths_df = pd.DataFrame(frame_lengths, columns=['Sentence Length'])
        
        # Printing shape and size details
        print("Features Details:")
        print(f"Shape: {features_df.shape}, Size: {features_df.size}")
        print("Features Preview:")
        print(features_df.head())
        
        print("Label Details:")
        print(f"Shape: {transcriptions_df.shape}, Size: {transcriptions_df.size}")
        print("Label Preview:")
        print(transcriptions_df.head())
        
        print("Sentence Lengths Details:")
        print(f"Shape: {frame_lengths_df.shape}, Size: {frame_lengths_df.size}")
        print("Sentence Lengths Preview:")
        print(frame_lengths_df.head())
        print("\n")

In [5]:
data_dir = 'competitionData'
partitions = ['train', 'test', 'competitionHoldOut']
partitioned_data = load_and_process_data(data_dir, partitions)

Loading and processing train: 100%|██████████| 24/24 [00:04<00:00,  4.84it/s]
Loading and processing test: 100%|██████████| 24/24 [00:00<00:00, 96.25it/s]
Loading and processing competitionHoldOut: 100%|██████████| 15/15 [00:00<00:00, 46.02it/s]


In [6]:
preview_data(partitioned_data)

--- TRAIN Partition ---
Features Details:
Shape: (8800, 123), Size: 1082400
Features Preview:
         0          1         2         3         4         5         6    \
0  -6.978477   3.589440  9.083821  4.696966  5.081369 -1.942908  2.445167   
1  -6.942856   4.718402  9.565770  5.147847  5.552409 -1.592359  4.563524   
2  -6.162521  15.191476 -6.775521  6.034155 -6.914832 -3.912339  3.322587   
3  -7.094923  -0.951291  7.753854  4.699582  2.685992  5.829111  1.576679   
4  30.155407  -4.286747  7.774179  4.928823  3.108494 -2.292313  3.188529   

        7         8         9    ...       113       114       115       116  \
0  0.618766 -3.260455 -0.226782  ...  1.098188 -0.281921 -0.480892  0.662092   
1  0.616387 -1.932959  1.895795  ... -0.324113  1.787536 -2.580184  0.218130   
2 -1.705997 -1.828976  2.358329  ...  0.966154 -1.191206 -1.623980  0.675462   
3 -0.569604  1.946202 -1.033215  ...  0.523643 -0.538616 -0.255343  0.756173   
4 -0.131116 -4.278669  1.251366  ...  0.441

In [7]:
char_vocab = set()
for partition in partitioned_data:
    for label in partitioned_data[partition]['label']:
        char_vocab.update(label)
char_vocab.add('_') 
char_vocab = sorted(char_vocab)
char_to_idx = {char: idx for idx, char in enumerate(char_vocab)}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

In [8]:
class CharTokenizer:
    def __init__(self, char_to_idx, idx_to_char):
        self.char_to_idx = char_to_idx
        self.idx_to_char = idx_to_char
        self.pad_token_id = char_to_idx['_']

    def encode(self, text):
        encoded = [self.char_to_idx[char] for char in text]
        return encoded

    def decode(self, encoded):
        decoded = ''.join([self.idx_to_char[idx] for idx in encoded])
        return decoded

    def __len__(self):
        return len(self.char_to_idx)

class TCNEncoder(nn.Module):
    def __init__(self, input_size, num_channels, kernel_size, dropout):
        super(TCNEncoder, self).__init__()
        self.tcn = TCN(
            num_inputs=input_size,
            num_channels=num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            input_shape='NLC'
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        return self.tcn(x)

class BiGRUEncoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BiGRUEncoder, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x):
        outputs, _ = self.gru(x)
        return outputs

class SpeechBCIModel(nn.Module):
    def __init__(self, input_size, tcn_num_channels, tcn_kernel_size, tcn_dropout, gru_hidden_size, vocab_size):
        super(SpeechBCIModel, self).__init__()
        self.tcn_encoder = TCNEncoder(input_size, tcn_num_channels, tcn_kernel_size, tcn_dropout)
        self.gru_encoder = BiGRUEncoder(tcn_num_channels[-1], gru_hidden_size)
        self.t5_model = T5ForConditionalGeneration.from_pretrained("t5-small")
        self.projection = nn.Linear(2 * gru_hidden_size, self.t5_model.config.d_model)
        self.output_layer = nn.Linear(self.t5_model.config.d_model, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        input_ids = input_ids.unsqueeze(2)
        tcn_outputs = self.tcn_encoder(input_ids)
        gru_outputs = self.gru_encoder(tcn_outputs)
        projected_embeddings = self.projection(gru_outputs)
        attention_mask = (projected_embeddings != 0).any(dim=-1).float()
        encoder_outputs = self.t5_model.encoder(
            inputs_embeds=projected_embeddings,
            attention_mask=attention_mask,
            return_dict=True
        )
        logits = self.output_layer(encoder_outputs.last_hidden_state)
        return logits

class SpeechBCIDataset(Dataset):
    def __init__(self, features, labels, sen_lens):
        self.features = features
        self.labels = labels
        self.sen_lens = sen_lens

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        feature = torch.tensor(self.features[idx], dtype=torch.float32)
        label = torch.tensor(char_tokenizer.encode(self.labels[idx]), dtype=torch.long)
        sen_len = self.sen_lens[idx]
        return feature, label, sen_len

def collate_fn(batch):
    features, labels, sen_lens = zip(*batch)
    max_sen_len = max(sen_lens)
    padded_features = pad_sequence(features, batch_first=True, padding_value=0.0)
    padded_labels = torch.full((len(labels), max_sen_len), char_tokenizer.pad_token_id, dtype=torch.long)
    for i, label in enumerate(labels):
        padded_labels[i, :len(label)] = label
    sen_lens = torch.tensor(sen_lens)
    return padded_features, padded_labels, sen_lens

def train(model, dataloader, optimizer, device):
    model.train()
    train_loss = 0.0
    criterion = nn.CrossEntropyLoss(ignore_index=char_tokenizer.pad_token_id)
    for batch_features, batch_labels, batch_sen_lens in dataloader:
        batch_features = batch_features.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(input_ids=batch_features)
        batch_size, seq_len, vocab_size = logits.size()
        logits = logits.view(batch_size, seq_len, vocab_size)
        batch_labels = batch_labels[:, :seq_len]
        loss = criterion(logits.transpose(1, 2), batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(dataloader)
    return train_loss

def evaluate(model, dataloader, device):
    model.eval()
    val_wer = 0.0
    total_chars = 0
    with torch.no_grad():
        for batch_features, batch_labels, batch_sen_lens in dataloader:
            batch_features = batch_features.to(device)
            logits = model(input_ids=batch_features)
            batch_size, seq_len, vocab_size = logits.size()
            logits = logits.view(batch_size, seq_len, vocab_size)
            generated_ids = logits.argmax(dim=-1)
            generated_text = []
            true_text = []
            for i in range(batch_size):
                gen_ids = generated_ids[i].cpu().numpy()
                true_ids = batch_labels[i].cpu().numpy()
                gen_ids = gen_ids[gen_ids != char_tokenizer.pad_token_id]
                true_ids = true_ids[true_ids != char_tokenizer.pad_token_id]
                generated_text.append(char_tokenizer.decode(gen_ids))
                true_text.append(char_tokenizer.decode(true_ids))
            for pred, true in zip(generated_text, true_text):
                val_wer += editdistance.eval(pred, true)
                total_chars += len(true)
    val_wer /= total_chars
    return val_wer

In [9]:
# Model Hyperparameters
input_size = partitioned_data["train"]["features"].shape[1]
tcn_hidden_size = 128
tcn_num_channels = [tcn_hidden_size, tcn_hidden_size, tcn_hidden_size * 2]
tcn_kernel_size = 3
tcn_dropout = 0.1
gru_hidden_size = tcn_hidden_size * 2
vocab_size = len(char_vocab)
learning_rate = 1e-3
num_epochs = 30
batch_size = 16
max_seq_length = 128

# Initialize character tokenizer
char_tokenizer = CharTokenizer(char_to_idx, idx_to_char)

# Initialize model, optimizer, loss function, and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpeechBCIModel(
    input_size=input_size,
    tcn_num_channels=tcn_num_channels,
    tcn_kernel_size=tcn_kernel_size,
    tcn_dropout=tcn_dropout,
    gru_hidden_size=gru_hidden_size,
    vocab_size=vocab_size
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)

# Initialize datasets and dataloaders
train_dataset = SpeechBCIDataset(partitioned_data["train"]["features"], partitioned_data["train"]["label"], partitioned_data["train"]["sen_len"])
val_dataset = SpeechBCIDataset(partitioned_data["test"]["features"], partitioned_data["test"]["label"], partitioned_data["test"]["sen_len"])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

# Training loop with early stopping
best_loss = float("inf")
patience = 4
patience_counter = 0

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, device)
    val_wer = evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val WER = {val_wer:.4f}")

    if train_loss < best_loss:
        best_loss = train_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
    scheduler.step(train_loss)        
torch.save(model.state_dict(), 'best_model.pth')            

Epoch 1: Train Loss = 2.5097, Val WER = 0.9728
Epoch 2: Train Loss = 2.4803, Val WER = 0.9730
Epoch 3: Train Loss = 2.4769, Val WER = 0.9730
Epoch 4: Train Loss = 2.4739, Val WER = 0.9730
Epoch 5: Train Loss = 2.4730, Val WER = 0.9730
Epoch 6: Train Loss = 2.4693, Val WER = 0.9710
Epoch 7: Train Loss = 2.4627, Val WER = 0.9730
Epoch 8: Train Loss = 2.4477, Val WER = 0.9730
Epoch 9: Train Loss = 2.4262, Val WER = 0.9729
Epoch 10: Train Loss = 2.3826, Val WER = 0.9726
Epoch 11: Train Loss = 2.3377, Val WER = 0.9729
Epoch 12: Train Loss = 2.2706, Val WER = 0.9732
Epoch 13: Train Loss = 2.1925, Val WER = 0.9738
Epoch 14: Train Loss = 2.1093, Val WER = 0.9743
Epoch 15: Train Loss = 2.0255, Val WER = 0.9754
Epoch 16: Train Loss = 1.9276, Val WER = 0.9740
Epoch 17: Train Loss = 1.8356, Val WER = 0.9769
Epoch 18: Train Loss = 1.7266, Val WER = 0.9756
Epoch 19: Train Loss = 1.6336, Val WER = 0.9767
Epoch 20: Train Loss = 1.5387, Val WER = 0.9766
Epoch 21: Train Loss = 1.4470, Val WER = 0.9749
E

In [10]:
def calculate_wer(decoded_sentences, true_sentences):
    total_word_errors = 0
    total_words = 0

    for decoded_sent, true_sent in zip(decoded_sentences, true_sentences):
        decoded_words = decoded_sent.split(" ")
        true_words = true_sent.split(" ")
        word_errors = editdistance.eval(decoded_words, true_words)
        total_word_errors += word_errors
        total_words += len(true_words)

    wer = total_word_errors / total_words
    return wer

def inference(model, dataset, device):
    model.eval()
    with torch.no_grad():
        idx = random.randint(0, len(dataset) - 1)
        feature, label, sen_len = dataset[idx]
        feature = feature.unsqueeze(0).to(device)
        logits = model(input_ids=feature)
        generated_ids = logits.argmax(dim=-1)
        generated_text = char_tokenizer.decode(generated_ids[0].cpu().numpy())
        true_text = char_tokenizer.decode(label.numpy())
        print("Generated Text:", true_text)
        print("Ground Truth:", true_text)

def calculate_wer_on_val_set(model, val_dataset, device):
    model.eval()
    decoded_sentences = []
    true_sentences = []
    with torch.no_grad():
        for feature, label, sen_len in val_dataset:
            feature = feature.unsqueeze(0).to(device)
            logits = model(input_ids=feature)
            generated_ids = logits.argmax(dim=-1)
            generated_text = char_tokenizer.decode(generated_ids[0].cpu().numpy())
            true_text = char_tokenizer.decode(label.numpy())
            decoded_sentences.append(generated_text)
            true_sentences.append(true_text)
    wer = calculate_wer(decoded_sentences, true_sentences)
    return wer

In [11]:
# Load the best model checkpoint
best_model_path = 'best_model.pth'
model.load_state_dict(torch.load(best_model_path))
model.to(device)

# Calculate WER on the entire validation set
val_wer = calculate_wer_on_val_set(model, val_dataset, device)
print("Validation WER:", val_wer)

Validation WER: 0.9909074377159484


In [12]:
# Perform inference on a random sample from the validation set
inference(model, val_dataset, device)  

Generated Text: i don't see much tv mostly when i'm in school.
Ground Truth: i don't see much tv mostly when i'm in school.
