In [1]:
import os
import sys
import wandb
import torch
import datetime
import collections
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Export this notebook's file name for Weights & Biases.
# NOTE(review): presumably wandb reads this to associate the run with the
# notebook — confirm against the wandb environment-variable docs.
os.environ['WANDB_NOTEBOOK_NAME'] = 'train_baseline.ipynb'

# Path to the HyperPyYAML experiment configuration; parsed and loaded in the next cell.
HPARAM_FILE = 'hparams/baseline/twin_tcn_noenvproj.yaml'

In [2]:
# Tag this run with its wall-clock start time; the tag is also passed to the
# YAML file as the `time_stamp` override below.
time_stamp = datetime.datetime.now().strftime('%Y-%m-%d+%H-%M-%S')
print(f'Experiment Time Stamp: {time_stamp}')

# Overwrite hparams
# Each (flag, value) pair is flattened, in order, into the command-line style
# argument list that SpeechBrain parses.
override_pairs = [
    ('--time_stamp', time_stamp),
    ('--use_wandb', 'true'),
    ('--n_mismatch', '4'),
    ('--batch_size', '64'),
    ('--batch_equalizer', 'random'),
    ('--n_epoch', '100'),
    ('--experiment', 'env_tcn_noenvproj_bs64_lr1e-3_ep100'),
]
argv = [HPARAM_FILE] + [token for pair in override_pairs for token in pair]

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Resolve the YAML file with the overrides applied.
with open(HPARAM_FILE) as f:
    hparams = load_hyperpyyaml(f, overrides)

# Mixed precision is controlled from the YAML file rather than the CLI.
run_opts['auto_mix_prec'] = hparams['mix_prec'] # False

# The logger is instantiated lazily, only when W&B logging is requested.
if hparams['use_wandb']:
    hparams['logger'] = hparams['wandb_logger']()

# sb.utils.distributed.ddp_init_group(run_opts)

Experiment Time Stamp: 2023-12-03+21-16-38


2023-12-03 21:16:39.238295: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-03 21:16:39.238370: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-03 21:16:39.238401: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-03 21:16:39.248858: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-03 21:16:42.408294: I tensorflow/core/comm

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112867647575008, max=1.0…

In [3]:
class Classifier(sb.core.Brain):
    """SpeechBrain ``Brain`` that classifies paired EEG / speech batches.

    The EEG and speech inputs are embedded by ``modules.eeg_model`` and
    ``modules.speech_model`` and the two embeddings are scored by
    ``modules.classifier``. Per-stage loss and accuracy are accumulated
    sample-weighted in ``self.loss_stat`` / ``self.count`` and reported at the
    end of every stage (printed, and logged to W&B when ``hparams.use_wandb``
    is set).
    """

    def compute_forward(self, eeg, speech, stage):
        """Run both encoders and the classifier head.

        Arguments
        ---------
        eeg : torch.Tensor
            Batch of EEG inputs, passed to ``modules.eeg_model``.
        speech : torch.Tensor
            Batch of speech inputs, passed to ``modules.speech_model``.
        stage : sb.Stage
            Current stage; not used by this method.

        Returns
        -------
        Output of ``modules.classifier`` (class scores along the last
        dimension, as consumed by ``compute_objectives``).
        """
        eeg_emb = self.modules.eeg_model(eeg)
        speech_emb = self.modules.speech_model(speech)
        pred = self.modules.classifier(eeg_emb, speech_emb)

        return pred

    def compute_objectives(self, pred, label, stage):
        """Compute the loss and update the running loss / accuracy statistics.

        Arguments
        ---------
        pred : torch.Tensor
            Classifier output; the predicted class is the argmax over the
            last dimension.
        label : torch.Tensor
            Target class indices; ``label.shape[0]`` is taken as the batch size.
        stage : sb.Stage
            Current stage; not used by this method.

        Returns
        -------
        The value returned by ``hparams.loss_fn(pred, label)``.
        """
        batch_size = label.shape[0]

        loss = self.hparams.loss_fn(pred, label)
        est_label = torch.argmax(pred, dim=-1)
        # Count hits with a single tensor reduction. The Python builtin sum()
        # would iterate over the tensor element by element.
        n_correct = int((est_label == label).sum())

        # Weight the batch loss by the batch size so that the stage average is
        # per sample (assumes loss_fn returns a batch mean — TODO confirm).
        self.loss_stat['loss'] += float(loss) * batch_size
        # Accumulate the exact number of correct predictions; dividing by
        # self.count in on_stage_end turns it into an accuracy.
        self.loss_stat['acc'] += float(n_correct)
        self.count += batch_size

        return loss

    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        """Return ``dataset`` unchanged instead of wrapping it in a DataLoader.

        All other arguments are accepted for interface compatibility and
        ignored.
        """
        # Treat pytorch TF wrapper as a dataloader
        # Because create_tf_dataset already batches EEGs and speeches
        return dataset

    def fit_batch(self, batch):
        """Train on one ``(eeg, speech, label)`` batch.

        Returns
        -------
        The detached batch loss on CPU.
        """
        eeg, speech, label = batch

        eeg = eeg.to(self.device)
        speech = speech.to(self.device)
        label = label.to(self.device)

        # Forward
        pred = self.compute_forward(eeg, speech, sb.Stage.TRAIN)
        loss = self.compute_objectives(pred, label, sb.Stage.TRAIN)
        loss.backward()

        # Only apply the update when the gradient check passes; gradients are
        # cleared either way so a rejected batch cannot leak into the next one.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Evaluate one ``(eeg, speech, label)`` batch without gradients.

        Returns
        -------
        The detached batch loss on CPU.
        """
        eeg, speech, label = batch

        eeg = eeg.to(self.device)
        speech = speech.to(self.device)
        label = label.to(self.device)

        # Forward
        with torch.no_grad():
            pred = self.compute_forward(eeg, speech, stage)
            loss = self.compute_objectives(pred, label, stage)

        return loss.detach().cpu()

    def on_stage_start(self, stage, epoch=None):
        """Reset the running statistics and reload the stage's windows."""
        super().on_stage_start(stage, epoch)
        self.count = 0
        self.loss_stat = {
            'loss': 0,
            'acc': 0,
        }
        # Reload windows at the start of each epoch
        if stage == sb.Stage.TRAIN:
            self.hparams.train_windows.reload()
        elif stage == sb.Stage.VALID:
            self.hparams.valid_windows.reload()

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Average the accumulated statistics, then log and print them.

        Keys are prefixed with ``train_``, ``valid_`` or ``test_`` according to
        ``stage``. The TRAIN stage also reports the current learning rate, and
        the VALID stage steps ``self.lr_scheduler`` with the mean validation
        loss.

        Arguments
        ---------
        stage : sb.Stage
            Stage that just finished.
        stage_loss : float
            Stage loss supplied by the caller; not used here because the
            sample-weighted average in ``self.loss_stat`` is reported instead.
        epoch : int, optional
            Epoch number, included in the reported statistics.
        """
        for loss_key in self.loss_stat:
            self.loss_stat[loss_key] /= self.count

        if stage == sb.Stage.TRAIN:
            prefix = 'train_'
        elif stage == sb.Stage.VALID:
            prefix = 'valid_'
        else:
            # Previously no stats dict was built for any other stage (e.g.
            # sb.Stage.TEST), so the logging below raised NameError.
            prefix = 'test_'

        stage_stats = {
            prefix + key: round(float(value), 4)
            for key, value in self.loss_stat.items()
        }

        if stage == sb.Stage.TRAIN:
            stage_stats['lr'] = self.optimizer.param_groups[0]['lr']
            stage_stats['epoch'] = epoch
        elif stage == sb.Stage.VALID:
            stage_stats['epoch'] = epoch
            self.lr_scheduler.step(self.loss_stat['loss'])
        elif epoch is not None:
            # Outside of fit() there may be no epoch to report.
            stage_stats['epoch'] = epoch

        if self.hparams.use_wandb:
            self.hparams.logger.run.log(
                data=stage_stats,
            )

        print(f'Epoch {epoch}: ', stage, stage_stats)

    def init_optimizers(self):
        """Create the optimizer and LR scheduler and register both for checkpointing."""
        super().init_optimizers()
        self.lr_scheduler = self.hparams.lr_scheduler(self.optimizer)
        if self.checkpointer is not None:
            # NOTE(review): the parent implementation may already register
            # "optimizer"; re-registering under the same name is kept as-is.
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.checkpointer.add_recoverable("lr_scheduler", self.lr_scheduler)
                        
    

In [4]:
# Everything the Brain needs comes from the loaded hyperparameters, plus the
# run options parsed alongside them.
brain_kwargs = {
    'modules': hparams['modules'],
    'opt_class': hparams['optimizer'],
    'hparams': hparams,
    'run_opts': run_opts,
    'checkpointer': hparams['checkpointer'],
}
brain = Classifier(**brain_kwargs)

In [None]:
# Train with the epoch counter and the train / validation window sets
# declared in the hyperparameter file.
epoch_counter = brain.hparams.epoch_counter
train_windows = hparams['train_windows']
valid_windows = hparams['valid_windows']

brain.fit(
    epoch_counter=epoch_counter,
    train_set=train_windows,
    valid_set=valid_windows,
)

100%|██████████| 7975/7975 [02:04<00:00, 64.21it/s, train_loss=1.17]


Epoch 1:  Stage.TRAIN {'train_loss': 1.1693, 'train_acc': 0.5347, 'lr': 0.001, 'epoch': 1}


100%|██████████| 845/845 [00:09<00:00, 90.43it/s] 


Epoch 1:  Stage.VALID {'valid_loss': 1.251, 'valid_acc': 0.4964, 'epoch': 1}


100%|██████████| 7975/7975 [02:07<00:00, 62.76it/s, train_loss=1.06]


Epoch 2:  Stage.TRAIN {'train_loss': 1.0564, 'train_acc': 0.5888, 'lr': 0.001, 'epoch': 2}


100%|██████████| 845/845 [00:08<00:00, 95.44it/s] 


Epoch 2:  Stage.VALID {'valid_loss': 1.1642, 'valid_acc': 0.5392, 'epoch': 2}


 17%|█▋        | 1374/7975 [00:21<01:37, 67.90it/s, train_loss=0.97] 

In [None]:
# Close out the Weights & Biases run once training has finished or been
# interrupted. NOTE(review): presumably this is the run created by the wandb
# logger instantiated during hyperparameter loading — confirm.
wandb.finish()