In [1]:
import os
import sys
import wandb
import torch
import datetime
import collections
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

HPARAM_FILE = 'hparams/clap/env_tcn_clap_random.yaml'

# Register the notebook name so W&B can associate the run with this notebook.
os.environ.update(WANDB_NOTEBOOK_NAME='train_clap_random.ipynb')

In [2]:
time_stamp = datetime.datetime.now().strftime('%Y-%m-%d+%H-%M-%S')
print(f'Experiment Time Stamp: {time_stamp}')

# Command-line style overrides applied on top of the YAML file.
# Dict insertion order is preserved, so argv is emitted in exactly this order.
cli_overrides = {
    'time_stamp': time_stamp,
    'use_wandb': 'true',
    'clap_batch_size': '8',
    'n_epoch': '1000',
    'lr': '1.0e-3',
    'tau': '0.07',
    'experiment': 'CLAP_env_tcn_cbs8_lr1e-3_ep1000',
}
argv = [HPARAM_FILE]
for flag, value in cli_overrides.items():
    argv.extend([f'--{flag}', value])

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

with open(HPARAM_FILE) as yaml_file:
    hparams = load_hyperpyyaml(yaml_file, overrides)

# Mixed precision is taken from the YAML config rather than the command line.
run_opts['auto_mix_prec'] = hparams['mix_prec']

# Build the W&B logger only when logging is enabled.
if hparams['use_wandb']:
    hparams['logger'] = hparams['wandb_logger']()

Experiment Time Stamp: 2023-12-07+15-58-12


[34m[1mwandb[0m: Currently logged in as: [33mxj2289[0m ([33mxj-audio[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class CLAP(sb.core.Brain):
    '''Contrastive speech-EEG training loop built on ``sb.core.Brain``.

    * TRAIN: the windows of one dataloader item are treated as paired
      speech/EEG samples; a symmetric cost over the speech x EEG similarity
      matrix is minimised (see ``compute_clap_objectives``).
    * VALID / TEST: for each item one of its M EEG windows is drawn at random
      and the model must pick the matching speech among the M candidates
      (see ``evaluate_batch`` / ``compute_class_objectives``).

    Per-stage statistics are accumulated in ``self.loss_stat`` as
    sample-weighted sums, with the sample count in ``self.count``; they are
    averaged and reported in ``on_stage_end``.
    '''

    def compute_forward(self, speech, eeg, stage):
        '''Encode speech and EEG with ``self.modules.clap``.

        Args:
            speech: 3-D tensor, or 4-D ``(B, M, C, T)`` tensor.
            eeg: 3-D tensor, or 4-D ``(B, M, C, T)`` tensor.
            stage: ``sb.Stage`` (unused here; part of the Brain interface).

        4-D inputs are flattened to ``(B*M, C, T)`` before encoding and the
        resulting projections are restored to ``(B, M, ...)`` afterwards.

        Returns:
            ``(speech_proj, eeg_proj, logit_scale_exp)`` from the CLAP module.
        '''
        speech_grouped = speech.ndim != 3
        eeg_grouped = eeg.ndim != 3

        # reshape (rather than view) also accepts non-contiguous inputs.
        if speech_grouped:
            B1, M1, C1, T1 = speech.shape
            speech = speech.reshape(B1 * M1, C1, T1)
        if eeg_grouped:
            B2, M2, C2, T2 = eeg.shape
            eeg = eeg.reshape(B2 * M2, C2, T2)

        speech_proj, eeg_proj, logit_scale_exp = self.modules.clap(speech, eeg)

        if speech_grouped:
            speech_proj = speech_proj.reshape((B1, M1) + speech_proj.shape[1:])
        if eeg_grouped:
            eeg_proj = eeg_proj.reshape((B2, M2) + eeg_proj.shape[1:])

        return speech_proj, eeg_proj, logit_scale_exp


    def compute_clap_objectives(self, speech_proj, eeg_proj, logit_scale_exp):
        '''Symmetric contrastive loss over a batch of paired windows.

        The similarity matrix has one row per speech window and one column per
        EEG window; the positive for row i is column i (and vice versa), so the
        targets are ``arange``. The loss is the mean of the speech->EEG and
        EEG->speech costs from ``hparams.compute_cross_sim_cost``.

        Side effect: when the loss is not NaN, adds the sample-weighted
        losses/accuracies to ``self.loss_stat`` and the sample count to
        ``self.count``.

        Returns:
            The scalar loss tensor (still attached to the graph).
        '''
        similarity = self.modules.clap.compute_similarity_with_scale(
            speech_proj, eeg_proj, logit_scale_exp
        )

        # Matching pairs lie on the diagonal.
        eeg_label_for_speech = torch.arange(similarity.shape[0], device=self.device)  # number of speech
        speech_label_for_eeg = torch.arange(similarity.shape[1], device=self.device)  # number of eeg

        predict_eeg_loss = self.hparams.compute_cross_sim_cost(similarity, eeg_label_for_speech)
        predict_speech_loss = self.hparams.compute_cross_sim_cost(similarity.T, speech_label_for_eeg)
        loss = 0.5 * (predict_eeg_loss + predict_speech_loss)

        with torch.no_grad():
            predict_eeg_correct = (torch.max(similarity, dim=1)[1] == eeg_label_for_speech).sum()
            predict_speech_correct = (torch.max(similarity, dim=0)[1] == speech_label_for_eeg).sum()
            predict_eeg_acc = predict_eeg_correct / len(eeg_label_for_speech)
            predict_speech_acc = predict_speech_correct / len(speech_label_for_eeg)
            clap_acc = 0.5 * (predict_eeg_acc + predict_speech_acc)

        loss_dict = {
            'predict_eeg_loss': predict_eeg_loss,
            'predict_speech_loss': predict_speech_loss,
            'predict_eeg_acc': predict_eeg_acc,
            'predict_speech_acc': predict_speech_acc,
            'loss': loss,
            # 'acc' mirrors the evaluation task (pick the speech given an EEG).
            'acc': predict_speech_acc,
            'clap_acc': clap_acc
        }

        # Accumulate sample-weighted running sums; NaN batches are skipped so
        # they do not poison the epoch averages.
        if not torch.isnan(loss):
            n_samples = eeg_proj.shape[0]
            self.count += n_samples
            with torch.no_grad():
                for key in self.loss_stat:
                    self.loss_stat[key] += n_samples * loss_dict[key]

        return loss


    def compute_class_objectives(self, speech_proj, eeg_proj, label):
        '''Accumulate speech-selection accuracy for one evaluation batch.

        speech_proj: (B, M, D)
        eeg_proj: (B, D)
        label: (B,) index of the speech candidate matching each EEG window.

        Returns a dummy zero loss; the accuracy goes into ``self.loss_stat``.
        '''
        B = eeg_proj.shape[0]
        est_label = self.modules.clap.predict_speech(
            speech_proj, eeg_proj
        )

        # Vectorised count (the builtin ``sum`` iterates the tensor elementwise).
        n_correct = (est_label == label).sum().item()

        self.loss_stat['acc'] += float(n_correct)
        self.count += B

        return torch.tensor(0)


    def fit_batch(self, batch):
        '''Run one optimisation step and return the detached CPU loss.

        The dataloader uses batch_size = 1; the windows inside that single
        item are the actual contrastive batch, hence ``squeeze(0)``.
        '''
        eeg = batch['eeg'].squeeze(0).to(self.device)
        speech = batch[self.hparams.feature].squeeze(0).to(self.device)

        speech_proj, eeg_proj, logit_scale_exp = self.compute_forward(speech, eeg, sb.Stage.TRAIN)
        loss = self.compute_clap_objectives(speech_proj, eeg_proj, logit_scale_exp)
        loss.backward()

        # Skip the update when the Brain's gradient check rejects this loss.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.detach().cpu()


    def evaluate_batch(self, batch, stage):
        '''Score one evaluation batch: identify the speech matching a random EEG window.

        For each of the B items, one of its M EEG windows is chosen uniformly
        at random as the query; all M speech windows stay as candidates.

        NOTE(review): targets are re-drawn every call, so ``valid_acc`` varies
        between epochs even for a fixed model, and the LR scheduler stepped on
        it sees that noise. Consider a seeded generator if reproducible
        validation is required.
        '''
        eeg = batch['eeg'].to(self.device)
        speech = batch[self.hparams.feature].to(self.device)

        B, M, C, T = eeg.shape
        label = torch.randint(0, M, (B,), dtype=torch.long).to(self.device)
        # Keep index tensors on the same device as the indexed tensor.
        rows = torch.arange(B, device=self.device)
        eeg = eeg[rows, label]

        with torch.no_grad():
            speech_proj, eeg_proj, _ = self.compute_forward(speech, eeg, stage)
            loss = self.compute_class_objectives(speech_proj, eeg_proj, label)

        return loss.detach().cpu()


    def on_stage_start(self, stage, epoch=None):
        '''Reset the running statistics for a new stage.'''
        super().on_stage_start(stage, epoch)
        self.count = 0
        if stage == sb.Stage.TRAIN:
            train_keys = (
                'predict_eeg_loss',
                'predict_speech_loss',
                'predict_eeg_acc',
                'predict_speech_acc',
                'loss',
                'acc',
                'clap_acc',
            )
            self.loss_stat = dict.fromkeys(train_keys, 0)
        else:
            self.loss_stat = {'acc': 0}


    def on_stage_end(self, stage, stage_loss, epoch=None):
        '''Average the running statistics, step the scheduler, checkpoint, and log.

        TRAIN additionally reports the learning rate and temperature (tau).
        VALID steps ``self.lr_scheduler`` on ``-valid_acc`` and saves a
        checkpoint (keeping the best by ``valid_acc`` plus the most recent).
        TEST only reports its accuracy.
        '''
        if self.count != 0:
            for loss_key in self.loss_stat:
                self.loss_stat[loss_key] /= self.count

        if stage == sb.Stage.TRAIN:
            prefix = 'train_'
        elif stage == sb.Stage.VALID:
            prefix = 'valid_'
        else:
            # Previously TEST left stage_stats undefined and crashed below.
            prefix = 'test_'

        stage_stats = {prefix + key: round(float(value), 4) for key, value in self.loss_stat.items()}

        if stage == sb.Stage.TRAIN:
            stage_stats['lr'] = self.optimizer.param_groups[0]['lr']
            stage_stats['tau'] = float(1/self.modules.clap.logit_scale.exp().item())
        stage_stats['epoch'] = epoch

        if stage == sb.Stage.VALID:
            valid_acc = float(self.loss_stat['acc'])
            # Scheduler minimises its metric, so pass negative accuracy.
            self.lr_scheduler.step(-valid_acc)
            # Without this call the registered recoverables were never written.
            # NOTE(review): on_fit_start will resume from these checkpoints if
            # the same output folder is reused.
            if self.checkpointer is not None:
                self.checkpointer.save_and_keep_only(
                    meta={'valid_acc': valid_acc},
                    max_keys=['valid_acc'],
                )

        if self.hparams.use_wandb:
            self.hparams.logger.run.log(
                data=stage_stats,
            )

        print(f'Epoch {epoch}: ', stage, stage_stats)


    def init_optimizers(self):
        '''Create the optimizer (via the base class) and the LR scheduler.

        Both are registered with the checkpointer so they are saved and
        restored together with the model.
        '''
        super().init_optimizers()
        self.lr_scheduler = self.hparams.lr_scheduler(self.optimizer)
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.checkpointer.add_recoverable("lr_scheduler", self.lr_scheduler)
    

In [4]:
# Wire modules, optimizer class, hparams and checkpointer into the Brain.
brain = CLAP(
    hparams=hparams,
    run_opts=run_opts,
    modules=hparams['modules'],
    opt_class=hparams['optimizer'],
    checkpointer=hparams['checkpointer'],
)

In [None]:
# Train over `epoch_counter`, running the validation stage after every epoch.
brain.fit(
    hparams['epoch_counter'],
    train_set=hparams['train_windows'],
    train_loader_kwargs=hparams['train_loader_opts'],
    valid_set=hparams['valid_windows'],
    valid_loader_kwargs=hparams['valid_loader_opts'],
)

100%|██████████| 666/666 [00:13<00:00, 50.42it/s, train_loss=1.96]


Epoch 1:  Stage.TRAIN {'train_predict_eeg_loss': 1.9604, 'train_predict_speech_loss': 1.9599, 'train_predict_eeg_acc': 0.2348, 'train_predict_speech_acc': 0.2359, 'train_loss': 1.9601, 'train_acc': 0.2359, 'train_clap_acc': 0.2354, 'lr': 0.001, 'tau': 0.07319261545970077, 'epoch': 1}


100%|██████████| 666/666 [00:05<00:00, 116.68it/s]


Epoch 1:  Stage.VALID {'valid_acc': 0.3814, 'epoch': 1}


100%|██████████| 666/666 [00:12<00:00, 52.99it/s, train_loss=1.85]


Epoch 2:  Stage.TRAIN {'train_predict_eeg_loss': 1.8553, 'train_predict_speech_loss': 1.8545, 'train_predict_eeg_acc': 0.3048, 'train_predict_speech_acc': 0.3101, 'train_loss': 1.8549, 'train_acc': 0.3101, 'train_clap_acc': 0.3074, 'lr': 0.001, 'tau': 0.07044309695257102, 'epoch': 2}


100%|██████████| 666/666 [00:03<00:00, 191.23it/s]


Epoch 2:  Stage.VALID {'valid_acc': 0.4129, 'epoch': 2}


100%|██████████| 666/666 [00:12<00:00, 52.41it/s, train_loss=1.8] 


Epoch 3:  Stage.TRAIN {'train_predict_eeg_loss': 1.8029, 'train_predict_speech_loss': 1.8024, 'train_predict_eeg_acc': 0.3375, 'train_predict_speech_acc': 0.3309, 'train_loss': 1.8026, 'train_acc': 0.3309, 'train_clap_acc': 0.3342, 'lr': 0.001, 'tau': 0.06717539451072342, 'epoch': 3}


100%|██████████| 666/666 [00:03<00:00, 197.63it/s]


Epoch 3:  Stage.VALID {'valid_acc': 0.4279, 'epoch': 3}


100%|██████████| 666/666 [00:12<00:00, 51.99it/s, train_loss=1.77]


Epoch 4:  Stage.TRAIN {'train_predict_eeg_loss': 1.7697, 'train_predict_speech_loss': 1.7668, 'train_predict_eeg_acc': 0.348, 'train_predict_speech_acc': 0.35, 'train_loss': 1.7682, 'train_acc': 0.35, 'train_clap_acc': 0.349, 'lr': 0.001, 'tau': 0.06513852876861198, 'epoch': 4}


100%|██████████| 666/666 [00:03<00:00, 196.40it/s]


Epoch 4:  Stage.VALID {'valid_acc': 0.4084, 'epoch': 4}


100%|██████████| 666/666 [00:12<00:00, 54.46it/s, train_loss=1.74]


Epoch 5:  Stage.TRAIN {'train_predict_eeg_loss': 1.7377, 'train_predict_speech_loss': 1.7371, 'train_predict_eeg_acc': 0.3641, 'train_predict_speech_acc': 0.3722, 'train_loss': 1.7374, 'train_acc': 0.3722, 'train_clap_acc': 0.3681, 'lr': 0.001, 'tau': 0.0629489968782322, 'epoch': 5}


100%|██████████| 666/666 [00:03<00:00, 181.82it/s]


Epoch 5:  Stage.VALID {'valid_acc': 0.4685, 'epoch': 5}


100%|██████████| 666/666 [00:12<00:00, 51.37it/s, train_loss=1.71]


Epoch 6:  Stage.TRAIN {'train_predict_eeg_loss': 1.7125, 'train_predict_speech_loss': 1.7062, 'train_predict_eeg_acc': 0.3731, 'train_predict_speech_acc': 0.3767, 'train_loss': 1.7094, 'train_acc': 0.3767, 'train_clap_acc': 0.3749, 'lr': 0.001, 'tau': 0.06070511713125342, 'epoch': 6}


100%|██████████| 666/666 [00:03<00:00, 171.27it/s]


Epoch 6:  Stage.VALID {'valid_acc': 0.4685, 'epoch': 6}


100%|██████████| 666/666 [00:12<00:00, 54.13it/s, train_loss=1.7] 


Epoch 7:  Stage.TRAIN {'train_predict_eeg_loss': 1.7026, 'train_predict_speech_loss': 1.7028, 'train_predict_eeg_acc': 0.3851, 'train_predict_speech_acc': 0.3788, 'train_loss': 1.7027, 'train_acc': 0.3788, 'train_clap_acc': 0.3819, 'lr': 0.001, 'tau': 0.05994303052282541, 'epoch': 7}


100%|██████████| 666/666 [00:03<00:00, 186.79it/s]


Epoch 7:  Stage.VALID {'valid_acc': 0.458, 'epoch': 7}


100%|██████████| 666/666 [00:13<00:00, 50.12it/s, train_loss=1.67]


Epoch 8:  Stage.TRAIN {'train_predict_eeg_loss': 1.6694, 'train_predict_speech_loss': 1.6673, 'train_predict_eeg_acc': 0.3941, 'train_predict_speech_acc': 0.3964, 'train_loss': 1.6684, 'train_acc': 0.3964, 'train_clap_acc': 0.3953, 'lr': 0.001, 'tau': 0.058994285873746416, 'epoch': 8}


100%|██████████| 666/666 [00:03<00:00, 198.90it/s]


Epoch 8:  Stage.VALID {'valid_acc': 0.479, 'epoch': 8}


100%|██████████| 666/666 [00:12<00:00, 54.18it/s, train_loss=1.66]


Epoch 9:  Stage.TRAIN {'train_predict_eeg_loss': 1.6648, 'train_predict_speech_loss': 1.66, 'train_predict_eeg_acc': 0.4002, 'train_predict_speech_acc': 0.4084, 'train_loss': 1.6624, 'train_acc': 0.4084, 'train_clap_acc': 0.4043, 'lr': 0.001, 'tau': 0.05825036475362298, 'epoch': 9}


100%|██████████| 666/666 [00:03<00:00, 173.63it/s]


Epoch 9:  Stage.VALID {'valid_acc': 0.533, 'epoch': 9}


100%|██████████| 666/666 [00:13<00:00, 50.35it/s, train_loss=1.62]


Epoch 10:  Stage.TRAIN {'train_predict_eeg_loss': 1.6165, 'train_predict_speech_loss': 1.6191, 'train_predict_eeg_acc': 0.4122, 'train_predict_speech_acc': 0.4163, 'train_loss': 1.6178, 'train_acc': 0.4163, 'train_clap_acc': 0.4142, 'lr': 0.001, 'tau': 0.05663773103386588, 'epoch': 10}


100%|██████████| 666/666 [00:03<00:00, 196.99it/s]


Epoch 10:  Stage.VALID {'valid_acc': 0.485, 'epoch': 10}


100%|██████████| 666/666 [00:11<00:00, 55.96it/s, train_loss=1.62]


Epoch 11:  Stage.TRAIN {'train_predict_eeg_loss': 1.6167, 'train_predict_speech_loss': 1.6155, 'train_predict_eeg_acc': 0.42, 'train_predict_speech_acc': 0.4215, 'train_loss': 1.6161, 'train_acc': 0.4215, 'train_clap_acc': 0.4208, 'lr': 0.001, 'tau': 0.05642713054424314, 'epoch': 11}


100%|██████████| 666/666 [00:03<00:00, 184.93it/s]


Epoch 11:  Stage.VALID {'valid_acc': 0.5586, 'epoch': 11}


100%|██████████| 666/666 [00:12<00:00, 53.06it/s, train_loss=1.6] 


Epoch 12:  Stage.TRAIN {'train_predict_eeg_loss': 1.5965, 'train_predict_speech_loss': 1.5999, 'train_predict_eeg_acc': 0.4253, 'train_predict_speech_acc': 0.418, 'train_loss': 1.5982, 'train_acc': 0.418, 'train_clap_acc': 0.4216, 'lr': 0.001, 'tau': 0.055443214295307276, 'epoch': 12}


100%|██████████| 666/666 [00:03<00:00, 191.87it/s]


Epoch 12:  Stage.VALID {'valid_acc': 0.5495, 'epoch': 12}


100%|██████████| 666/666 [00:13<00:00, 50.75it/s, train_loss=1.58]


Epoch 13:  Stage.TRAIN {'train_predict_eeg_loss': 1.5825, 'train_predict_speech_loss': 1.581, 'train_predict_eeg_acc': 0.4396, 'train_predict_speech_acc': 0.4416, 'train_loss': 1.5817, 'train_acc': 0.4416, 'train_clap_acc': 0.4406, 'lr': 0.001, 'tau': 0.05563510684243646, 'epoch': 13}


100%|██████████| 666/666 [00:03<00:00, 176.65it/s]


Epoch 13:  Stage.VALID {'valid_acc': 0.5706, 'epoch': 13}


100%|██████████| 666/666 [00:12<00:00, 52.04it/s, train_loss=1.57]


Epoch 14:  Stage.TRAIN {'train_predict_eeg_loss': 1.5684, 'train_predict_speech_loss': 1.5685, 'train_predict_eeg_acc': 0.4405, 'train_predict_speech_acc': 0.4341, 'train_loss': 1.5685, 'train_acc': 0.4341, 'train_clap_acc': 0.4373, 'lr': 0.001, 'tau': 0.0545713588631411, 'epoch': 14}


100%|██████████| 666/666 [00:03<00:00, 190.59it/s]


Epoch 14:  Stage.VALID {'valid_acc': 0.5541, 'epoch': 14}


100%|██████████| 666/666 [00:12<00:00, 51.38it/s, train_loss=1.57]


Epoch 15:  Stage.TRAIN {'train_predict_eeg_loss': 1.5673, 'train_predict_speech_loss': 1.5649, 'train_predict_eeg_acc': 0.4294, 'train_predict_speech_acc': 0.4356, 'train_loss': 1.5661, 'train_acc': 0.4356, 'train_clap_acc': 0.4325, 'lr': 0.001, 'tau': 0.054268945993865385, 'epoch': 15}


100%|██████████| 666/666 [00:03<00:00, 185.63it/s]


Epoch 15:  Stage.VALID {'valid_acc': 0.536, 'epoch': 15}


100%|██████████| 666/666 [00:12<00:00, 52.33it/s, train_loss=1.55]


Epoch 16:  Stage.TRAIN {'train_predict_eeg_loss': 1.5452, 'train_predict_speech_loss': 1.5566, 'train_predict_eeg_acc': 0.4523, 'train_predict_speech_acc': 0.4414, 'train_loss': 1.5509, 'train_acc': 0.4414, 'train_clap_acc': 0.4469, 'lr': 0.001, 'tau': 0.05302674012702031, 'epoch': 16}


100%|██████████| 666/666 [00:03<00:00, 198.29it/s]


Epoch 16:  Stage.VALID {'valid_acc': 0.5766, 'epoch': 16}


100%|██████████| 666/666 [00:13<00:00, 50.18it/s, train_loss=1.53]


Epoch 17:  Stage.TRAIN {'train_predict_eeg_loss': 1.5342, 'train_predict_speech_loss': 1.5311, 'train_predict_eeg_acc': 0.4512, 'train_predict_speech_acc': 0.4591, 'train_loss': 1.5326, 'train_acc': 0.4591, 'train_clap_acc': 0.4551, 'lr': 0.001, 'tau': 0.05225923967096359, 'epoch': 17}


100%|██████████| 666/666 [00:03<00:00, 177.51it/s]


Epoch 17:  Stage.VALID {'valid_acc': 0.5586, 'epoch': 17}


100%|██████████| 666/666 [00:13<00:00, 50.56it/s, train_loss=1.54]


Epoch 18:  Stage.TRAIN {'train_predict_eeg_loss': 1.5433, 'train_predict_speech_loss': 1.5402, 'train_predict_eeg_acc': 0.4551, 'train_predict_speech_acc': 0.4518, 'train_loss': 1.5418, 'train_acc': 0.4518, 'train_clap_acc': 0.4535, 'lr': 0.001, 'tau': 0.05181764200315796, 'epoch': 18}


100%|██████████| 666/666 [00:03<00:00, 202.14it/s]


Epoch 18:  Stage.VALID {'valid_acc': 0.5931, 'epoch': 18}


100%|██████████| 666/666 [00:13<00:00, 49.23it/s, train_loss=1.52]


Epoch 19:  Stage.TRAIN {'train_predict_eeg_loss': 1.5169, 'train_predict_speech_loss': 1.5162, 'train_predict_eeg_acc': 0.4583, 'train_predict_speech_acc': 0.4574, 'train_loss': 1.5166, 'train_acc': 0.4574, 'train_clap_acc': 0.4579, 'lr': 0.001, 'tau': 0.05081499699929751, 'epoch': 19}


100%|██████████| 666/666 [00:03<00:00, 195.58it/s]


Epoch 19:  Stage.VALID {'valid_acc': 0.5691, 'epoch': 19}


100%|██████████| 666/666 [00:12<00:00, 53.92it/s, train_loss=1.52]


Epoch 20:  Stage.TRAIN {'train_predict_eeg_loss': 1.5199, 'train_predict_speech_loss': 1.515, 'train_predict_eeg_acc': 0.4666, 'train_predict_speech_acc': 0.4638, 'train_loss': 1.5174, 'train_acc': 0.4638, 'train_clap_acc': 0.4652, 'lr': 0.001, 'tau': 0.05011484237168782, 'epoch': 20}


100%|██████████| 666/666 [00:03<00:00, 187.76it/s]


Epoch 20:  Stage.VALID {'valid_acc': 0.5345, 'epoch': 20}


100%|██████████| 666/666 [00:12<00:00, 54.42it/s, train_loss=1.52]


Epoch 21:  Stage.TRAIN {'train_predict_eeg_loss': 1.5125, 'train_predict_speech_loss': 1.5205, 'train_predict_eeg_acc': 0.4632, 'train_predict_speech_acc': 0.4658, 'train_loss': 1.5165, 'train_acc': 0.4658, 'train_clap_acc': 0.4645, 'lr': 0.001, 'tau': 0.04986953534188285, 'epoch': 21}


100%|██████████| 666/666 [00:03<00:00, 185.05it/s]


Epoch 21:  Stage.VALID {'valid_acc': 0.5631, 'epoch': 21}


100%|██████████| 666/666 [00:13<00:00, 50.78it/s, train_loss=1.52]


Epoch 22:  Stage.TRAIN {'train_predict_eeg_loss': 1.5171, 'train_predict_speech_loss': 1.5213, 'train_predict_eeg_acc': 0.4581, 'train_predict_speech_acc': 0.457, 'train_loss': 1.5192, 'train_acc': 0.457, 'train_clap_acc': 0.4576, 'lr': 0.001, 'tau': 0.04946887908861199, 'epoch': 22}


100%|██████████| 666/666 [00:03<00:00, 186.99it/s]


Epoch 22:  Stage.VALID {'valid_acc': 0.6111, 'epoch': 22}


100%|██████████| 666/666 [00:12<00:00, 53.99it/s, train_loss=1.49]


Epoch 23:  Stage.TRAIN {'train_predict_eeg_loss': 1.491, 'train_predict_speech_loss': 1.494, 'train_predict_eeg_acc': 0.4606, 'train_predict_speech_acc': 0.4696, 'train_loss': 1.4925, 'train_acc': 0.4696, 'train_clap_acc': 0.4651, 'lr': 0.001, 'tau': 0.0484318322755159, 'epoch': 23}


100%|██████████| 666/666 [00:03<00:00, 184.40it/s]


Epoch 23:  Stage.VALID {'valid_acc': 0.5586, 'epoch': 23}


100%|██████████| 666/666 [00:11<00:00, 56.50it/s, train_loss=1.51]


Epoch 24:  Stage.TRAIN {'train_predict_eeg_loss': 1.5081, 'train_predict_speech_loss': 1.507, 'train_predict_eeg_acc': 0.4627, 'train_predict_speech_acc': 0.4705, 'train_loss': 1.5076, 'train_acc': 0.4705, 'train_clap_acc': 0.4666, 'lr': 0.001, 'tau': 0.04897860502414468, 'epoch': 24}


100%|██████████| 666/666 [00:03<00:00, 174.73it/s]


Epoch 24:  Stage.VALID {'valid_acc': 0.5961, 'epoch': 24}


100%|██████████| 666/666 [00:11<00:00, 56.05it/s, train_loss=1.49]


Epoch 25:  Stage.TRAIN {'train_predict_eeg_loss': 1.4895, 'train_predict_speech_loss': 1.4907, 'train_predict_eeg_acc': 0.467, 'train_predict_speech_acc': 0.4692, 'train_loss': 1.4901, 'train_acc': 0.4692, 'train_clap_acc': 0.4681, 'lr': 0.001, 'tau': 0.04791621827008714, 'epoch': 25}


100%|██████████| 666/666 [00:03<00:00, 171.19it/s]


Epoch 25:  Stage.VALID {'valid_acc': 0.5631, 'epoch': 25}


100%|██████████| 666/666 [00:11<00:00, 59.54it/s, train_loss=1.47]


Epoch 26:  Stage.TRAIN {'train_predict_eeg_loss': 1.4692, 'train_predict_speech_loss': 1.4738, 'train_predict_eeg_acc': 0.4822, 'train_predict_speech_acc': 0.4777, 'train_loss': 1.4715, 'train_acc': 0.4777, 'train_clap_acc': 0.4799, 'lr': 0.001, 'tau': 0.04712829058370843, 'epoch': 26}


100%|██████████| 666/666 [00:03<00:00, 187.71it/s]


Epoch 26:  Stage.VALID {'valid_acc': 0.5826, 'epoch': 26}


100%|██████████| 666/666 [00:11<00:00, 57.94it/s, train_loss=1.49]


Epoch 27:  Stage.TRAIN {'train_predict_eeg_loss': 1.4845, 'train_predict_speech_loss': 1.4873, 'train_predict_eeg_acc': 0.467, 'train_predict_speech_acc': 0.4809, 'train_loss': 1.4859, 'train_acc': 0.4809, 'train_clap_acc': 0.4739, 'lr': 0.001, 'tau': 0.04702217186956829, 'epoch': 27}


100%|██████████| 666/666 [00:03<00:00, 190.84it/s]


Epoch 27:  Stage.VALID {'valid_acc': 0.5811, 'epoch': 27}


100%|██████████| 666/666 [00:11<00:00, 57.46it/s, train_loss=1.48]


Epoch 28:  Stage.TRAIN {'train_predict_eeg_loss': 1.4821, 'train_predict_speech_loss': 1.484, 'train_predict_eeg_acc': 0.47, 'train_predict_speech_acc': 0.4625, 'train_loss': 1.4831, 'train_acc': 0.4625, 'train_clap_acc': 0.4662, 'lr': 0.001, 'tau': 0.0463313489252834, 'epoch': 28}


100%|██████████| 666/666 [00:03<00:00, 192.76it/s]


Epoch 28:  Stage.VALID {'valid_acc': 0.5961, 'epoch': 28}


100%|██████████| 666/666 [00:12<00:00, 54.79it/s, train_loss=1.47]


Epoch 29:  Stage.TRAIN {'train_predict_eeg_loss': 1.468, 'train_predict_speech_loss': 1.4688, 'train_predict_eeg_acc': 0.4848, 'train_predict_speech_acc': 0.4786, 'train_loss': 1.4684, 'train_acc': 0.4786, 'train_clap_acc': 0.4817, 'lr': 0.001, 'tau': 0.04612871128845463, 'epoch': 29}


100%|██████████| 666/666 [00:03<00:00, 173.40it/s]


Epoch 29:  Stage.VALID {'valid_acc': 0.5721, 'epoch': 29}


100%|██████████| 666/666 [00:12<00:00, 51.33it/s, train_loss=1.49]


Epoch 30:  Stage.TRAIN {'train_predict_eeg_loss': 1.4873, 'train_predict_speech_loss': 1.4893, 'train_predict_eeg_acc': 0.4795, 'train_predict_speech_acc': 0.4698, 'train_loss': 1.4883, 'train_acc': 0.4698, 'train_clap_acc': 0.4747, 'lr': 0.001, 'tau': 0.04607379938447932, 'epoch': 30}


100%|██████████| 666/666 [00:03<00:00, 200.83it/s]


Epoch 30:  Stage.VALID {'valid_acc': 0.5886, 'epoch': 30}


100%|██████████| 666/666 [00:12<00:00, 52.83it/s, train_loss=1.46]


Epoch 31:  Stage.TRAIN {'train_predict_eeg_loss': 1.4551, 'train_predict_speech_loss': 1.4636, 'train_predict_eeg_acc': 0.478, 'train_predict_speech_acc': 0.4741, 'train_loss': 1.4593, 'train_acc': 0.4741, 'train_clap_acc': 0.4761, 'lr': 0.001, 'tau': 0.04488632233740236, 'epoch': 31}


100%|██████████| 666/666 [00:03<00:00, 188.22it/s]


Epoch 31:  Stage.VALID {'valid_acc': 0.5616, 'epoch': 31}


100%|██████████| 666/666 [00:12<00:00, 53.02it/s, train_loss=1.49]


Epoch 32:  Stage.TRAIN {'train_predict_eeg_loss': 1.4873, 'train_predict_speech_loss': 1.4836, 'train_predict_eeg_acc': 0.4694, 'train_predict_speech_acc': 0.466, 'train_loss': 1.4854, 'train_acc': 0.466, 'train_clap_acc': 0.4677, 'lr': 0.001, 'tau': 0.04426744192329931, 'epoch': 32}


100%|██████████| 666/666 [00:03<00:00, 196.26it/s]


Epoch 32:  Stage.VALID {'valid_acc': 0.5946, 'epoch': 32}


100%|██████████| 666/666 [00:12<00:00, 55.13it/s, train_loss=1.5] 


Epoch 33:  Stage.TRAIN {'train_predict_eeg_loss': 1.5029, 'train_predict_speech_loss': 1.5043, 'train_predict_eeg_acc': 0.4685, 'train_predict_speech_acc': 0.4705, 'train_loss': 1.5036, 'train_acc': 0.4705, 'train_clap_acc': 0.4695, 'lr': 0.001, 'tau': 0.04425171199722346, 'epoch': 33}


100%|██████████| 666/666 [00:03<00:00, 181.76it/s]


Epoch 33:  Stage.VALID {'valid_acc': 0.5736, 'epoch': 33}


100%|██████████| 666/666 [00:12<00:00, 52.44it/s, train_loss=1.49]


Epoch 34:  Stage.TRAIN {'train_predict_eeg_loss': 1.4879, 'train_predict_speech_loss': 1.4842, 'train_predict_eeg_acc': 0.4764, 'train_predict_speech_acc': 0.4679, 'train_loss': 1.4861, 'train_acc': 0.4679, 'train_clap_acc': 0.4721, 'lr': 0.001, 'tau': 0.04444315972921858, 'epoch': 34}


100%|██████████| 666/666 [00:03<00:00, 181.46it/s]


Epoch 34:  Stage.VALID {'valid_acc': 0.5961, 'epoch': 34}


100%|██████████| 666/666 [00:11<00:00, 57.35it/s, train_loss=1.45]


Epoch 35:  Stage.TRAIN {'train_predict_eeg_loss': 1.4543, 'train_predict_speech_loss': 1.4549, 'train_predict_eeg_acc': 0.4854, 'train_predict_speech_acc': 0.4786, 'train_loss': 1.4546, 'train_acc': 0.4786, 'train_clap_acc': 0.482, 'lr': 0.001, 'tau': 0.0435708721845014, 'epoch': 35}


100%|██████████| 666/666 [00:03<00:00, 200.22it/s]


Epoch 35:  Stage.VALID {'valid_acc': 0.5766, 'epoch': 35}


100%|██████████| 666/666 [00:11<00:00, 56.67it/s, train_loss=1.47]


Epoch 36:  Stage.TRAIN {'train_predict_eeg_loss': 1.4695, 'train_predict_speech_loss': 1.4667, 'train_predict_eeg_acc': 0.4792, 'train_predict_speech_acc': 0.4814, 'train_loss': 1.4681, 'train_acc': 0.4814, 'train_clap_acc': 0.4803, 'lr': 0.001, 'tau': 0.04356545229585965, 'epoch': 36}


100%|██████████| 666/666 [00:03<00:00, 197.44it/s]


Epoch 36:  Stage.VALID {'valid_acc': 0.5916, 'epoch': 36}


100%|██████████| 666/666 [00:12<00:00, 52.35it/s, train_loss=1.45]


Epoch 37:  Stage.TRAIN {'train_predict_eeg_loss': 1.4448, 'train_predict_speech_loss': 1.4487, 'train_predict_eeg_acc': 0.4803, 'train_predict_speech_acc': 0.4829, 'train_loss': 1.4467, 'train_acc': 0.4829, 'train_clap_acc': 0.4816, 'lr': 0.001, 'tau': 0.04289922599613497, 'epoch': 37}


100%|██████████| 666/666 [00:04<00:00, 163.23it/s]


Epoch 37:  Stage.VALID {'valid_acc': 0.5616, 'epoch': 37}


100%|██████████| 666/666 [00:13<00:00, 50.71it/s, train_loss=1.43]


Epoch 38:  Stage.TRAIN {'train_predict_eeg_loss': 1.4323, 'train_predict_speech_loss': 1.4281, 'train_predict_eeg_acc': 0.4977, 'train_predict_speech_acc': 0.4961, 'train_loss': 1.4302, 'train_acc': 0.4961, 'train_clap_acc': 0.4969, 'lr': 0.001, 'tau': 0.041849765243225895, 'epoch': 38}


100%|██████████| 666/666 [00:03<00:00, 188.62it/s]


Epoch 38:  Stage.VALID {'valid_acc': 0.5871, 'epoch': 38}


100%|██████████| 666/666 [00:12<00:00, 52.00it/s, train_loss=1.47]


Epoch 39:  Stage.TRAIN {'train_predict_eeg_loss': 1.4681, 'train_predict_speech_loss': 1.4707, 'train_predict_eeg_acc': 0.4852, 'train_predict_speech_acc': 0.4895, 'train_loss': 1.4694, 'train_acc': 0.4895, 'train_clap_acc': 0.4873, 'lr': 0.001, 'tau': 0.04236383831139354, 'epoch': 39}


100%|██████████| 666/666 [00:03<00:00, 178.23it/s]


Epoch 39:  Stage.VALID {'valid_acc': 0.5976, 'epoch': 39}


 47%|████▋     | 312/666 [00:07<00:05, 69.62it/s, train_loss=1.46]

In [None]:
# Close the W&B run (if one was opened via hparams['wandb_logger']) so its
# remaining logs are flushed and the run is marked as finished.
wandb.finish()