In [1]:
import os
import sys
import wandb
import torch
import datetime
import collections
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Notebook configuration: the hyperparameter YAML to train with, and the
# notebook name exposed to W&B through the WANDB_NOTEBOOK_NAME variable.
HPARAM_FILE = 'hparams/class/twin_tcn.yaml'
os.environ.update(WANDB_NOTEBOOK_NAME='train_classification.ipynb')

In [2]:
time_stamp = datetime.datetime.now().strftime('%Y-%m-%d+%H-%M-%S')
print(f'Experiment Time Stamp: {time_stamp}')

# Command-line style overrides layered on top of the YAML hyperparameters.
cli_overrides = [
    '--time_stamp', time_stamp,
    '--use_wandb', 'true',
    '--n_mismatch', '4',
    '--experiment', 'env_tcn_bs4_lr1e-3_ep200',
]
argv = [HPARAM_FILE] + cli_overrides

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# hparam_file is argv[0], i.e. HPARAM_FILE, handed back by parse_arguments.
with open(hparam_file) as hparam_stream:
    hparams = load_hyperpyyaml(hparam_stream, overrides)

# Forward the YAML's mixed-precision flag into the Brain run options.
run_opts['auto_mix_prec'] = hparams['mix_prec']

# The W&B logger is only instantiated when enabled by the hyperparameters.
if hparams['use_wandb']:
    hparams['logger'] = hparams['wandb_logger']()

# Distributed (DDP) process-group initialisation is left disabled here:
# sb.utils.distributed.ddp_init_group(run_opts)

Experiment Time Stamp: 2023-11-26+17-42-06


[34m[1mwandb[0m: Currently logged in as: [33mxj2289[0m ([33mxj-audio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Optional data-loading smoke test: iterate over every item of the training set
# once, presumably so that a window that fails to load surfaces here rather
# than mid-training. It is a full pass over the data, so set
# RUN_DATA_SMOKE_TEST = False to skip it on a "Restart & Run All".
from tqdm.auto import tqdm  # notebook-aware progress bar (widget in Jupyter)

RUN_DATA_SMOKE_TEST = True

train_windows = hparams['train_windows']
if RUN_DATA_SMOKE_TEST:
    for _ in tqdm(train_windows, total=len(train_windows), desc='checking train_windows'):
        pass

In [3]:
class Classifier(sb.core.Brain):
    """SpeechBrain ``Brain`` that matches one EEG segment against M speech candidates.

    Each batch carries ``eeg`` of shape (B, M, C, T) and a speech feature
    (key ``hparams.feature``) holding the M candidates. One EEG segment per
    example is selected as the query. The label is the index of that segment,
    and the model (``modules.twin_model``) must pick the matching candidate.
    Running loss/accuracy are accumulated per stage and printed (and
    optionally logged to W&B) when the stage ends.
    """

    def compute_forward(self, eeg, speech, stage):
        """Score the speech candidates against the selected EEG segments.

        Args:
            eeg: selected EEG, one segment per example, (B, C, T) —
                C/T presumably channels/time; TODO confirm.
            speech: all M speech candidates for each example (layout set by
                the dataset — TODO confirm).
            stage: current ``sb.Stage`` (not used by the model call).

        Returns:
            The output of ``twin_model``; its last dim is argmax-ed against the
            candidate index, so presumably (B, M) scores.
        """
        pred = self.modules.twin_model(eeg, speech)

        return pred

    def compute_objectives(self, pred, label, stage):
        """Compute the loss and accumulate running loss/accuracy statistics.

        Args:
            pred: model scores whose ``argmax(dim=-1)`` is the predicted index.
            label: (B,) long tensor with the index of the correct candidate.
            stage: current ``sb.Stage`` (unused; part of the Brain API).

        Returns:
            The loss from ``hparams.loss_fn`` (still attached to the graph).
        """
        batch_size = label.shape[0]

        loss = self.hparams.loss_fn(pred, label)
        est_label = torch.argmax(pred, dim=-1)
        # Vectorised accuracy. The builtin ``sum`` over a tensor would iterate
        # element by element in Python.
        acc = (est_label == label).float().mean()

        # Weight by batch size so the stage averages are per-sample.
        self.loss_stat['loss'] += float(loss) * batch_size
        self.loss_stat['acc'] += float(acc) * batch_size
        self.count += batch_size

        return loss

    def _prepare_batch(self, batch, stage):
        """Move a batch to the device and select one query EEG per example.

        TRAIN draws a random target index per example. Other stages always use
        index 0, which keeps evaluation targets deterministic.

        Returns:
            (eeg, speech, label): the selected (B, C, T) EEG, all speech
            candidates, and the (B,) long target indices.
        """
        eeg = batch['eeg'].to(self.device)
        speech = batch[self.hparams.feature].to(self.device)

        # Select one EEG as the target. Keep all M speeches as candidates.
        # Unpacking four dims also asserts eeg is rank 4 (B, M, C, T).
        batch_size, n_candidates, _, _ = eeg.shape
        if stage == sb.Stage.TRAIN:
            # Drawn on CPU and then moved, keeping the original RNG stream.
            label = torch.randint(0, n_candidates, (batch_size,), dtype=torch.long).to(self.device)
        else:
            label = torch.zeros(batch_size, dtype=torch.long, device=self.device)
        # Build the row index on the same device as the label index.
        eeg = eeg[torch.arange(batch_size, device=label.device), label]
        return eeg, speech, label

    def fit_batch(self, batch):
        """Run one training step: forward, backward, and a guarded optimizer step.

        Returns:
            The detached batch loss on CPU.
        """
        # NOTE(review): run_opts['auto_mix_prec'] is set in the config cell, but
        # this override runs in full precision regardless — confirm intended.
        eeg, speech, label = self._prepare_batch(batch, sb.Stage.TRAIN)

        # Forward
        pred = self.compute_forward(eeg, speech, sb.Stage.TRAIN)
        loss = self.compute_objectives(pred, label, sb.Stage.TRAIN)
        loss.backward()

        # Only step when Brain.check_gradients accepts this loss; gradients
        # are cleared either way.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Compute the loss on one batch without gradients (target index fixed at 0).

        Returns:
            The detached batch loss on CPU.
        """
        eeg, speech, label = self._prepare_batch(batch, stage)

        # Forward
        with torch.no_grad():
            pred = self.compute_forward(eeg, speech, stage)
            loss = self.compute_objectives(pred, label, stage)

        return loss.detach().cpu()

    def on_stage_start(self, stage, epoch=None):
        """Run the base hook, then reset the per-stage running statistics."""
        super().on_stage_start(stage, epoch)
        self.count = 0
        self.loss_stat = {
            'loss': 0,
            'acc': 0,
        }

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Average the accumulated statistics, then print and optionally log them.

        Works for TRAIN, VALID and TEST. Keys are prefixed ``train_``,
        ``valid_`` or ``test_`` respectively, and an ``epoch`` entry is added.
        """
        # Guard against an empty stage (no batches) instead of dividing by zero.
        n_seen = max(self.count, 1)
        for loss_key in self.loss_stat:
            self.loss_stat[loss_key] /= n_seen

        prefix = {
            sb.Stage.TRAIN: 'train_',
            sb.Stage.VALID: 'valid_',
            sb.Stage.TEST: 'test_',
        }[stage]
        stage_stats = {prefix + key: round(float(value), 4) for key, value in self.loss_stat.items()}
        stage_stats['epoch'] = epoch

        # NOTE(review): a checkpointer is passed to the constructor but nothing
        # is saved here — confirm whether e.g. checkpointer.save_and_keep_only
        # should run on VALID.
        if self.hparams.use_wandb:
            self.hparams.logger.run.log(
                data=stage_stats,
            )

        print(f'Epoch {epoch}: ', stage, stage_stats)
                        
    

In [4]:
# Assemble the Brain from the objects the YAML file instantiated.
brain_kwargs = dict(
    modules=hparams['modules'],
    opt_class=hparams['optimizer'],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams['checkpointer'],
)
brain = Classifier(**brain_kwargs)

In [None]:
# Train for every epoch yielded by the epoch counter; the saved output shows a
# validation pass after each training epoch.
train_set = hparams['train_windows']
valid_set = hparams['valid_windows']
brain.fit(
    epoch_counter=hparams['epoch_counter'],
    train_set=train_set,
    valid_set=valid_set,
    train_loader_kwargs=hparams['train_loader_opts'],
    valid_loader_kwargs=hparams['valid_loader_opts'],
)

100%|██████████| 166/166 [00:13<00:00, 12.54it/s, train_loss=1.6] 


Epoch 1:  Stage.TRAIN {'train_loss': 1.6024, 'train_acc': 0.2455, 'epoch': 1}


100%|██████████| 167/167 [00:06<00:00, 25.22it/s]


Epoch 1:  Stage.VALID {'valid_loss': 1.5787, 'valid_acc': 0.2898, 'epoch': 1}


100%|██████████| 166/166 [00:12<00:00, 13.42it/s, train_loss=1.54]


Epoch 2:  Stage.TRAIN {'train_loss': 1.543, 'train_acc': 0.3223, 'epoch': 2}


100%|██████████| 167/167 [00:06<00:00, 25.25it/s]


Epoch 2:  Stage.VALID {'valid_loss': 1.5617, 'valid_acc': 0.2868, 'epoch': 2}


100%|██████████| 166/166 [00:12<00:00, 13.27it/s, train_loss=1.5] 


Epoch 3:  Stage.TRAIN {'train_loss': 1.5044, 'train_acc': 0.3389, 'epoch': 3}


100%|██████████| 167/167 [00:06<00:00, 25.20it/s]


Epoch 3:  Stage.VALID {'valid_loss': 1.5296, 'valid_acc': 0.3138, 'epoch': 3}


100%|██████████| 166/166 [00:12<00:00, 13.49it/s, train_loss=1.48]


Epoch 4:  Stage.TRAIN {'train_loss': 1.4825, 'train_acc': 0.3735, 'epoch': 4}


100%|██████████| 167/167 [00:06<00:00, 24.79it/s]


Epoch 4:  Stage.VALID {'valid_loss': 1.4741, 'valid_acc': 0.3619, 'epoch': 4}


100%|██████████| 166/166 [00:12<00:00, 13.54it/s, train_loss=1.45]


Epoch 5:  Stage.TRAIN {'train_loss': 1.4536, 'train_acc': 0.384, 'epoch': 5}


100%|██████████| 167/167 [00:06<00:00, 24.47it/s]


Epoch 5:  Stage.VALID {'valid_loss': 1.4568, 'valid_acc': 0.3724, 'epoch': 5}


100%|██████████| 166/166 [00:12<00:00, 13.80it/s, train_loss=1.46]


Epoch 6:  Stage.TRAIN {'train_loss': 1.4581, 'train_acc': 0.3886, 'epoch': 6}


100%|██████████| 167/167 [00:06<00:00, 24.15it/s]


Epoch 6:  Stage.VALID {'valid_loss': 1.441, 'valid_acc': 0.3784, 'epoch': 6}


100%|██████████| 166/166 [00:12<00:00, 13.51it/s, train_loss=1.4] 


Epoch 7:  Stage.TRAIN {'train_loss': 1.4039, 'train_acc': 0.4367, 'epoch': 7}


100%|██████████| 167/167 [00:06<00:00, 24.35it/s]


Epoch 7:  Stage.VALID {'valid_loss': 1.4706, 'valid_acc': 0.3694, 'epoch': 7}


100%|██████████| 166/166 [00:12<00:00, 13.47it/s, train_loss=1.44]


Epoch 8:  Stage.TRAIN {'train_loss': 1.4389, 'train_acc': 0.3855, 'epoch': 8}


100%|██████████| 167/167 [00:06<00:00, 25.07it/s]


Epoch 8:  Stage.VALID {'valid_loss': 1.4168, 'valid_acc': 0.4009, 'epoch': 8}


100%|██████████| 166/166 [00:12<00:00, 13.43it/s, train_loss=1.44]


Epoch 9:  Stage.TRAIN {'train_loss': 1.4383, 'train_acc': 0.3916, 'epoch': 9}


100%|██████████| 167/167 [00:06<00:00, 25.02it/s]


Epoch 9:  Stage.VALID {'valid_loss': 1.4296, 'valid_acc': 0.3859, 'epoch': 9}


100%|██████████| 166/166 [00:12<00:00, 13.23it/s, train_loss=1.39]


Epoch 10:  Stage.TRAIN {'train_loss': 1.394, 'train_acc': 0.4232, 'epoch': 10}


100%|██████████| 167/167 [00:06<00:00, 24.91it/s]


Epoch 10:  Stage.VALID {'valid_loss': 1.3995, 'valid_acc': 0.4009, 'epoch': 10}


100%|██████████| 166/166 [00:12<00:00, 13.13it/s, train_loss=1.41]


Epoch 11:  Stage.TRAIN {'train_loss': 1.4108, 'train_acc': 0.4051, 'epoch': 11}


100%|██████████| 167/167 [00:06<00:00, 25.01it/s]


Epoch 11:  Stage.VALID {'valid_loss': 1.3989, 'valid_acc': 0.4144, 'epoch': 11}


100%|██████████| 166/166 [00:12<00:00, 13.12it/s, train_loss=1.37]


Epoch 12:  Stage.TRAIN {'train_loss': 1.3749, 'train_acc': 0.4398, 'epoch': 12}


100%|██████████| 167/167 [00:06<00:00, 24.97it/s]


Epoch 12:  Stage.VALID {'valid_loss': 1.4128, 'valid_acc': 0.4009, 'epoch': 12}


100%|██████████| 166/166 [00:12<00:00, 13.41it/s, train_loss=1.4] 


Epoch 13:  Stage.TRAIN {'train_loss': 1.4027, 'train_acc': 0.4337, 'epoch': 13}


  0%|          | 0/167 [00:00<?, ?it/s]

In [None]:
# Close the active W&B run, if any (e.g. the one opened via
# hparams['wandb_logger']() when use_wandb is true) so it is marked finished.
wandb.finish()