In [1]:
import os
import sys
import wandb
import torch
import datetime
import collections
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Export this notebook's file name through the WANDB_NOTEBOOK_NAME environment
# variable (presumably read by wandb to associate runs with the notebook —
# TODO confirm against the W&B docs).
os.environ['WANDB_NOTEBOOK_NAME'] = 'train_clap.ipynb'

# Path of the HyperPyYAML experiment configuration. It is the first element of
# the argv handed to sb.parse_arguments and is also the file opened for
# load_hyperpyyaml.
HPARAM_FILE = 'hparams/clap/env_tcn_clap.yaml'

In [2]:
# Tag this run with its wall-clock start time.
time_stamp = datetime.datetime.now().strftime('%Y-%m-%d+%H-%M-%S')
print(f'Experiment Time Stamp: {time_stamp}')

# Hyperparameter overrides as (name, value) pairs. They are expanded below into
# the '--name value' command-line form that sb.parse_arguments consumes.
cli_overrides = (
    ('time_stamp', time_stamp),
    ('use_wandb', 'true'),
    ('clap_batch_size', '16'),
    ('n_epoch', '20'),
    ('lr', '1.0e-4'),
    ('tau', '0.07'),
    ('clap_ch', '16'),
    ('patience', '2'),
    ('experiment', 'FULL_CLAP_env_tcn_cbs16_cc16_lr1e-4_ep20'),
)
argv = [HPARAM_FILE]
argv.extend(
    token
    for flag, value in cli_overrides
    for token in ('--' + flag, value)
)

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Build the hyperparameter dictionary from the YAML file with the overrides applied.
with open(HPARAM_FILE) as yaml_stream:
    hparams = load_hyperpyyaml(yaml_stream, overrides)

# Mixed-precision setting is taken from the YAML (noted as False by the author).
run_opts['auto_mix_prec'] = hparams['mix_prec']

# The YAML stores a logger factory; instantiate it only when W&B logging is on.
if hparams['use_wandb']:
    hparams['logger'] = hparams['wandb_logger']()

# sb.utils.distributed.ddp_init_group(run_opts)

Experiment Time Stamp: 2023-12-08+18-20-19


2023-12-08 18:20:19.986633: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-08 18:20:19.986703: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-08 18:20:19.986729: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-08 18:20:19.996199: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-08 18:20:23.104496: I tensorflow/core/comm

In [3]:
class CLAP(sb.core.Brain):
    '''Brain that trains a contrastive speech/EEG model (``self.modules.clap``).

    Training optimises a symmetric cross-entropy over the speech/EEG
    similarity matrix; evaluation only measures how often
    ``self.modules.clap.predict_speech`` returns the given label.
    Per-stage running sums are kept in ``self.loss_stat`` together with the
    number of accumulated items in ``self.count``.
    '''

    def compute_forward(self, speech, eeg, stage):
        '''Project speech and EEG with ``self.modules.clap``.

        An input that is not 3-D is unpacked as 4-D ``(B, M, C, T)``, flattened
        to ``(B*M, C, T)`` before the model call, and its projection is folded
        back to ``(B, M, ...)`` afterwards.

        Arguments
        ---------
        speech : torch.Tensor
            3-D or 4-D speech input.
        eeg : torch.Tensor
            3-D or 4-D EEG input.
        stage : sb.Stage
            Current stage (not used by this method).

        Returns
        -------
        tuple
            ``(speech_proj, eeg_proj, logit_scale_exp)`` as returned by the
            model, with the projections folded back where applicable.
        '''
        speech_grouped = speech.ndim != 3
        eeg_grouped = eeg.ndim != 3

        # Each modality keeps its own leading dimensions so that one input
        # cannot overwrite the sizes needed to fold the other one back.
        # reshape (instead of view) also accepts non-contiguous inputs.
        if speech_grouped:
            speech_b, speech_m, speech_c, speech_t = speech.shape
            speech = speech.reshape(speech_b * speech_m, speech_c, speech_t)
        if eeg_grouped:
            eeg_b, eeg_m, eeg_c, eeg_t = eeg.shape
            eeg = eeg.reshape(eeg_b * eeg_m, eeg_c, eeg_t)

        speech_proj, eeg_proj, logit_scale_exp = self.modules.clap(speech, eeg)

        if speech_grouped:
            speech_proj = speech_proj.reshape(
                (speech_b, speech_m) + tuple(speech_proj.shape[1:])
            )
        if eeg_grouped:
            eeg_proj = eeg_proj.reshape(
                (eeg_b, eeg_m) + tuple(eeg_proj.shape[1:])
            )

        return speech_proj, eeg_proj, logit_scale_exp


    def compute_clap_objectives(self, speech_proj, eeg_proj, logit_scale_exp):
        '''Compute the symmetric contrastive loss and update training statistics.

        The similarity matrix comes from
        ``self.modules.clap.compute_similarity_with_scale``; item ``i`` on one
        axis is treated as the positive of item ``i`` on the other axis.

        Returns
        -------
        torch.Tensor
            ``0.5 * (lambda_predict_eeg * L_eeg + lambda_predict_speech * L_speech)``.
        '''
        similarity = self.modules.clap.compute_similarity_with_scale(
            speech_proj, eeg_proj, logit_scale_exp
        )

        # Targets are the diagonal indices, built directly on the training device.
        eeg_label_for_speech = torch.arange(similarity.shape[0], device=self.device)  # one per row
        speech_label_for_eeg = torch.arange(similarity.shape[1], device=self.device)  # one per column

        predict_eeg_loss = self.hparams.cross_entropy(similarity, eeg_label_for_speech)
        predict_speech_loss = self.hparams.cross_entropy(similarity.T, speech_label_for_eeg)
        loss = 0.5 * (
            self.hparams.lambda_predict_eeg * predict_eeg_loss
            + self.hparams.lambda_predict_speech * predict_speech_loss
        )

        with torch.no_grad():
            predict_eeg_correct = (torch.max(similarity, dim=1)[1] == eeg_label_for_speech).sum()
            predict_speech_correct = (torch.max(similarity, dim=0)[1] == speech_label_for_eeg).sum()
            predict_eeg_acc = predict_eeg_correct / len(eeg_label_for_speech)
            predict_speech_acc = predict_speech_correct / len(speech_label_for_eeg)
            clap_acc = 0.5 * (predict_eeg_acc + predict_speech_acc)

        loss_dict = {
            'predict_eeg_loss': predict_eeg_loss,
            'predict_speech_loss': predict_speech_loss,
            'predict_eeg_acc': predict_eeg_acc,
            'predict_speech_acc': predict_speech_acc,
            'loss': loss,
            'acc': predict_speech_acc,
            'clap_acc': clap_acc
        }

        # Accumulate size-weighted sums; batches with a NaN loss are left out
        # so they cannot poison the epoch averages.
        if not torch.isnan(loss):
            n_items = eeg_proj.shape[0]
            self.count += n_items
            with torch.no_grad():
                for key in self.loss_stat:
                    self.loss_stat[key] += n_items * loss_dict[key]

        return loss


    def compute_class_objectives(self, speech_proj, eeg_proj, label):
        '''Accumulate classification accuracy for an evaluation batch.

        speech_proj: (B, M, D)
        eeg_proj: (B, D)

        The number of items whose ``predict_speech`` output equals ``label`` is
        added to ``self.loss_stat['acc']`` and ``B`` to ``self.count``; the
        average is formed in ``on_stage_end``.

        Returns
        -------
        torch.Tensor
            A constant zero tensor, since no loss is computed during evaluation.
        '''
        n_items = eeg_proj.shape[0]
        est_label = self.modules.clap.predict_speech(
            speech_proj, eeg_proj
        )

        # Tensor-level sum: no Python loop over elements, and an empty batch
        # contributes zero instead of raising on a 0 / 0 division.
        n_correct = (est_label == label).sum()

        self.loss_stat['acc'] += float(n_correct)
        self.count += n_items

        return torch.tensor(0)


    def fit_batch(self, batch):
        '''Run one optimisation step on a ``(eeg, speech)`` batch.

        The loader is expected to deliver batch_size = 1; the leading dimension
        is squeezed away so that the number of windows acts as the effective
        batch size for the contrastive loss.

        Returns
        -------
        torch.Tensor
            The detached loss on CPU.
        '''
        eeg, speech = batch
        eeg = eeg.squeeze(0).to(self.device)
        speech = speech.squeeze(0).to(self.device)

        # Forward
        speech_proj, eeg_proj, logit_scale_exp = self.compute_forward(speech, eeg, sb.Stage.TRAIN)
        loss = self.compute_clap_objectives(speech_proj, eeg_proj, logit_scale_exp)
        loss.backward()

        # Step only when the gradient check passes; gradients are cleared either way.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.detach().cpu()


    def evaluate_batch(self, batch, stage):
        '''Evaluate one ``(eeg, speech, label)`` batch without gradients.

        Returns
        -------
        torch.Tensor
            The (constant) value returned by ``compute_class_objectives``.
        '''
        eeg, speech, label = batch

        eeg = eeg.to(self.device)
        speech = speech.to(self.device)
        label = label.to(self.device)

        # Forward
        with torch.no_grad():
            speech_proj, eeg_proj, _ = self.compute_forward(speech, eeg, stage)
            loss = self.compute_class_objectives(speech_proj, eeg_proj, label)

        return loss.detach().cpu()


    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        '''Return ``dataset`` unchanged instead of wrapping it in a DataLoader.

        Per the original author's note, the dataset is a PyTorch wrapper around
        a TF dataset whose ``create_tf_dataset`` already batches EEG and speech.
        '''
        return dataset


    def on_stage_start(self, stage, epoch=None):
        '''Reset the running statistics and reload the windows for the stage.'''
        super().on_stage_start(stage, epoch)
        self.count = 0

        # Training tracks every loss/accuracy term; other stages only accuracy.
        if stage == sb.Stage.TRAIN:
            tracked_keys = (
                'predict_eeg_loss',
                'predict_speech_loss',
                'predict_eeg_acc',
                'predict_speech_acc',
                'loss',
                'acc',
                'clap_acc',
            )
        else:
            tracked_keys = ('acc',)
        self.loss_stat = {key: 0 for key in tracked_keys}

        # Reload windows at the start of each epoch
        if stage == sb.Stage.TRAIN:
            self.hparams.train_windows.reload()
        elif stage == sb.Stage.VALID:
            self.hparams.valid_windows.reload()


    def on_stage_end(self, stage, stage_loss, epoch=None):
        '''Average the accumulated statistics, then log and print them.

        TRAIN additionally reports the learning rate and the temperature
        ``1 / exp(logit_scale)``; VALID steps the LR scheduler with the negated
        accuracy. Any other stage (e.g. TEST) is reported with a ``test_``
        prefix, so that logging no longer fails on an unbound ``stage_stats``.
        ``stage_loss`` is not used.
        '''
        if self.count != 0:
            for loss_key in self.loss_stat:
                self.loss_stat[loss_key] /= self.count

        if stage == sb.Stage.TRAIN:
            stage_stats = {'train_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            stage_stats['lr'] = self.optimizer.param_groups[0]['lr']
            stage_stats['tau'] = float(1/self.modules.clap.logit_scale.exp().item())
            stage_stats['epoch'] = epoch

        elif stage == sb.Stage.VALID:
            stage_stats = {'valid_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            stage_stats['epoch'] = epoch
            # Accuracy is negated before stepping — presumably the scheduler
            # minimises its metric; confirm against hparams.lr_scheduler.
            self.lr_scheduler.step(-self.loss_stat['acc'])

        else:
            stage_stats = {'test_'+key: round(float(value), 4) for key, value in self.loss_stat.items()}
            if epoch is not None:
                stage_stats['epoch'] = epoch

        if self.hparams.use_wandb:
            self.hparams.logger.run.log(
                data=stage_stats,
            )

        print(f'Epoch {epoch}: ', stage, stage_stats)


    def init_optimizers(self):
        '''Create the optimizer via the base class and attach the LR scheduler.

        Both objects are registered with the checkpointer when one is set.
        '''
        super().init_optimizers()
        self.lr_scheduler = self.hparams.lr_scheduler(self.optimizer)
        if self.checkpointer is not None:
            # NOTE(review): the base class may already register "optimizer";
            # confirm whether this registration is redundant for the
            # SpeechBrain version in use.
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.checkpointer.add_recoverable("lr_scheduler", self.lr_scheduler)
                        
    

In [4]:
# Instantiate the CLAP Brain from the loaded hyperparameters: modules,
# optimizer class and checkpointer all come from the YAML-derived `hparams`,
# and `run_opts` carries the parsed run options plus the 'auto_mix_prec' entry.
brain = CLAP(
    modules=hparams['modules'],
    opt_class=hparams['optimizer'],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams['checkpointer'],
)

In [None]:
# Start training. The train/valid window datasets are passed in directly;
# CLAP.make_dataloader returns them unchanged rather than wrapping them in a
# DataLoader.
brain.fit(
    epoch_counter=brain.hparams.epoch_counter,
    train_set=hparams['train_windows'],
    valid_set=hparams['valid_windows']
)

  8%|▊         | 43122/510443 [06:24<1:11:29, 108.95it/s, train_loss=2.04]

In [None]:
# End the W&B run once training has completed or been interrupted.
wandb.finish()