In [1]:
"""
Sample code to generate labels for test dataset of
match-mismatch task. The requested format for submitting the labels is
as follows:
for each subject a json file containing a python dictionary in the
format of  ==> {'sample_id': prediction, ... }.

"""

import os
import sys
import glob
import json
import wandb
import torch
import logging
import collections
import numpy as np
import tensorflow as tf
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

HPARAM_FILE = 'hparams/baseline/twin_tcn.yaml'
# save/baseline/.../1234
CKPT_PATH = 'save/baseline/env_tcn_bs64_lr1e-3_ep10/1234'


2023-12-26 22:48:03.213597: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-26 22:48:03.213667: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-26 22:48:03.213689: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-26 22:48:03.222791: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
### Initialize and load models

# Overwrite hparams
argv = [HPARAM_FILE, '--time_stamp', 'XXX']
argv += ['--use_wandb', 'false']
argv += ['--n_mismatch', '4']
argv += ['--batch_size', '64']
argv += ['--batch_equalizer', 'all']
argv += ['--save_folder', CKPT_PATH]

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

with open(HPARAM_FILE) as f:
    hparams = load_hyperpyyaml(f, overrides)
    
run_opts['auto_mix_prec'] = hparams['mix_prec'] # False

2023-12-26 22:48:08.325363: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1886] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43412 MB memory:  -> device: 0, name: NVIDIA L40, pci bus id: 0000:61:00.0, compute capability: 8.9
Counting batches in TF dataset: 7975it [00:26, 298.29it/s]
Counting batches in TF dataset: 845it [00:04, 177.35it/s]


In [3]:
### SANITY CHECK: Evaluate the model on valid set
# to make sure the weights are properly loaded

class Classifier(sb.core.Brain):

    def compute_forward(self, eeg, speech, stage):
        eeg_emb = self.modules.eeg_model(eeg)
        speech_emb = self.modules.speech_model(speech)
        pred = self.modules.classifier(eeg_emb, speech_emb)
                
        return pred
    
    def compute_objectives(self, pred, label, stage):
        B = label.shape[0]
        
        loss = self.hparams.loss_fn(pred, label)
        est_label = torch.argmax(pred, dim=-1)
        acc = sum(est_label==label) / B
        
        self.loss_stat['loss'] += float(loss) * B
        self.loss_stat['acc'] += float(acc) * B
        self.count += B
    
        return loss

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        # Treat pytorch TF wrapper as a dataloader
        # Because create_tf_dataset already batches EEGs and speeches
        return dataset
    
    
    def evaluate_batch(self, batch, stage):
        eeg, speech, label = batch
        
        eeg = eeg.to(self.device) # (B*5, 64, 320)
        speech = speech.to(self.device) # (B*5, 5, 1, 320)
        label = label.to(self.device)
        
        # Forward
        with torch.no_grad():
            pred = self.compute_forward(eeg, speech, stage)
            loss = self.compute_objectives(pred, label, stage)
            
        return loss.detach().cpu()
        
        
    def on_stage_start(self, stage, epoch=None):
        super().on_stage_start(stage, epoch)
        self.count = 0
        self.loss_stat = {
            'loss': 0,
            'acc': 0,
        }
        # Reload windows at the start of each epoch
        if stage == sb.Stage.TRAIN:
            self.hparams.train_windows.reload()
        elif stage == sb.Stage.VALID:
            self.hparams.valid_windows.reload()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        for loss_key in self.loss_stat:
            self.loss_stat[loss_key] /= self.count
        
        if stage == sb.Stage.TRAIN:
            stage_stats = {'train_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            stage_stats['lr'] = self.optimizer.param_groups[0]['lr']
            stage_stats['epoch'] = epoch
    
        elif stage == sb.Stage.VALID:
            stage_stats = {'valid_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            stage_stats['epoch'] = epoch
            self.lr_scheduler.step(self.loss_stat['loss'])
            
            self.checkpointer.save_and_keep_only(
                meta=self.loss_stat, max_keys=['acc'],
            )
            
        elif stage == sb.Stage.TEST:
            stage_stats = {'valid_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            stage_stats['epoch'] = epoch
                        
            assert stage_stats['valid_acc'] > 0.5
            print(f'Epoch {epoch}: ', stage, stage_stats)
        
brain = Classifier(
    modules=hparams['modules'],
    opt_class=hparams['optimizer'],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams['checkpointer'],
)

brain.evaluate(
    test_set=hparams['valid_windows'],
    max_key='acc'
)


100%|██████████| 845/845 [00:18<00:00, 45.30it/s]


Epoch None:  Stage.TEST {'valid_loss': 1.1962, 'valid_acc': 0.5246, 'epoch': None}


1.1961566411531896

In [4]:
hparams['eeg_model'].eval()
hparams['speech_model'].eval()
hparams['classifier'].eval()

def predict(eeg, speech):
    
    eeg = eeg.squeeze(0).permute(0, 2, 1).contiguous() # (437, 64, 320)
    speech = speech.permute(1, 0, 3, 2).contiguous() # (437, 5, 1, 320)
    
    eeg_emb = hparams['eeg_model'](eeg)
    speech_emb = hparams['speech_model'](speech)
    pred = hparams['classifier'](eeg_emb, speech_emb)
    return pred

In [None]:
### Evaluation starts
print('Start baseline evaluation...')


# Parameters
# Length of the decision window
window_length_s = 5
fs = 64

window_length = window_length_s * fs  # 5 seconds
number_mismatch = 4 

# Provide the path of the dataset
data_folder = '/engram/naplab/shared/eeg_challenge_data_test/homes.esat.kuleuven.be/~lbollens/sparrkulee/test_set/TASK1_match_mismatch'
eeg_folder = os.path.join(data_folder, 'preprocessed_eeg')
stimulus_folder = os.path.join(data_folder, 'stimulus')

# # stimulus feature which will be used for training the model. Can be either 'envelope' ( dimension 1) or 'mel' (dimension 28)
stimulus_features = ["envelope"]
stimulus_dimension = 1

features = ["eeg"] + stimulus_features

test_eeg_mapping = glob.glob(os.path.join(data_folder, 'sub*mapping.json'))

test_stimuli = glob.glob(os.path.join(stimulus_folder, f'*{stimulus_features[0]}*chunks.npz'))

#load all test stimuli
test_stimuli_data = {}
for stimulus_path in test_stimuli:
    test_stimuli_data = dict(test_stimuli_data, **np.load(stimulus_path))

for sub_stimulus_mapping in test_eeg_mapping:
    subject = os.path.basename(sub_stimulus_mapping).split('_')[0]

    # load stimulus mapping
    sub_stimulus_mapping = json.load(open(sub_stimulus_mapping))

    #load eeg data
    sub_path = os.path.join(eeg_folder, f'{subject}_eeg.npz')
    sub_eeg_data = dict(np.load(sub_path))


    data_eeg =  np.stack([[sub_eeg_data[value['eeg']]]  for key, value in sub_stimulus_mapping.items() ])
    # change dim 0 and 1 of eeg and unstack
    data_eeg = np.swapaxes(data_eeg, 0, 1)
    data_eeg = list(data_eeg)

    data_stimuli = np.stack([[test_stimuli_data[x] for x in value['stimulus']] for key, value in sub_stimulus_mapping.items()])
    # change dim 0 and 1 of stimulus and unstack
    data_stimuli = np.swapaxes(data_stimuli, 0, 1)
    data_stimuli = list(data_stimuli)
    id_list= list(sub_stimulus_mapping.keys()) # 437

    ### Call pytorch model. Other TF stuff unchanged.
    with torch.inference_mode():
        data_eeg = torch.tensor(data_eeg).float().cuda() # (1, 437, 320, 64)
        data_stimuli = torch.tensor(data_stimuli).float().cuda() # (5, 437, 320, 1)
        preds = predict(data_eeg, data_stimuli).cpu().numpy()
    
    # predictions = model.predict(data_eeg + data_stimuli)
    # predictions = predict
    
    labels = np.argmax(preds, axis=1)

    sub = dict(zip(id_list, [int(x) for x in labels]))

    prediction_dir = 'predictions'
    os.makedirs(prediction_dir, exist_ok=True)
    with open(os.path.join(prediction_dir, subject + '.json'), 'w') as f:
        json.dump(sub, f)


Start baseline evaluation...


  data_eeg = torch.tensor(data_eeg).float().cuda() # (1, 437, 320, 64)
