In [1]:
import os
import sys
import wandb
import torch
import datetime
import collections
from hyperpyyaml import load_hyperpyyaml
import speechbrain as sb
from torch.nn import DataParallel

# YAML hyperparameter file for this experiment; it is passed as the first argv
# entry to sb.parse_arguments in the next cell.
HPARAM_FILE = 'hparams/class/meta_model_random_v2.yaml'

  from .autonotebook import tqdm as notebook_tqdm
torchvision is not available - cannot save figures


In [2]:
# Local alias keeps this cell re-runnable: a later cell executes
# `from datetime import datetime`, which rebinds the global name `datetime` to the
# class, after which `datetime.datetime.now()` would raise AttributeError.
import datetime as _datetime

time_stamp = _datetime.datetime.now().strftime('%Y-%m-%d+%H-%M-%S')
print(f'Experiment Time Stamp: {time_stamp}')

# Command-line style overrides applied on top of the YAML hyperparameters.
argv = [
    HPARAM_FILE,
    '--time_stamp', time_stamp,
    '--use_wandb', 'false',
    '--n_mismatch', '4',
    '--experiment', 'env_tcn_bs4_lr1e-3_ep200',
    # '--train_windows', 'null',
]

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Load the file that parse_arguments resolved (argv[0]) so the parsed path and the
# path actually read cannot drift apart.
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)

# Mirror the YAML mixed-precision flag into the run options handed to the Brain.
run_opts['auto_mix_prec'] = hparams['mix_prec']

if hparams['use_wandb']:
    # The YAML entry is called to build the logger (presumably a factory) — verify in the YAML.
    hparams['logger'] = hparams['wandb_logger']()

# sb.utils.distributed.ddp_init_group(run_opts)

Experiment Time Stamp: 2023-12-10+16-34-04
in_channels {'meg': 64}
in_channels {'meg': 64}
sizes: {'meg': [64, 64, 64, 64, 16]}
channels: (64, 64, 64, 64, 16)
in_channels {'meg': 1024}
in_channels {'meg': 1024}
sizes: {'meg': [1024, 64, 64, 64, 16]}
channels: (1024, 64, 64, 64, 16)


In [9]:
from datetime import datetime
import psutil
import GPUtil
import torch.nn.functional as F
import os
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2Model
from transformers import pipeline
import numpy as np
import torch.nn as nn

# Dutch XLSR-53 wav2vec 2.0 checkpoint. The processor and the encoder must come from
# the same checkpoint, so the name is defined once. The encoder is moved to the GPU
# and is used under torch.no_grad inside Classifier.compute_forward to extract
# speech features on the fly.
WAV2VEC_CHECKPOINT = "facebook/wav2vec2-large-xlsr-53-dutch"
wav2vec_processor = Wav2Vec2Processor.from_pretrained(WAV2VEC_CHECKPOINT)
wav2vec_model = Wav2Vec2Model.from_pretrained(WAV2VEC_CHECKPOINT).to('cuda')


class Classifier(sb.core.Brain):
    """Match/mismatch classifier: pick which of M speech candidates goes with an EEG segment.

    Expected on ``self.modules``: ``eeg_model``, ``classifier`` and, when
    ``hparams.use_speech_embedding`` is true, ``speech_model``.
    Expected on ``self.hparams``: ``extract_features_on_the_fly``,
    ``use_speech_embedding``, ``loss_fn``, ``lr_scheduler``, ``use_wandb`` and
    optionally ``logger``.
    """

    def compute_forward(self, eeg, speech, stage, wav2vec_processor, wav2vec_model):
        """Score the M speech candidates against the selected EEG segment.

        Args:
            eeg: EEG of the selected candidate for each batch item.
            speech: candidate speech, shape (B, M, C, T).
            stage: sb.Stage; not used by this method.
            wav2vec_processor, wav2vec_model: feature extractor and encoder, used
                only when ``hparams.extract_features_on_the_fly`` is true.

        Returns:
            The output of ``self.modules.classifier``. compute_objectives takes the
            argmax over its last dimension as the predicted candidate.
        """
        if self.hparams.extract_features_on_the_fly:
            speech_features = self._extract_wav2vec_features(
                speech, wav2vec_processor, wav2vec_model
            )
        else:
            speech_features = speech

        # Follow the Brain's configured device rather than a hard-coded 'cuda'.
        eeg_emb = self.modules.eeg_model(eeg.to(self.device))

        if self.hparams.use_speech_embedding:
            # Adaptive-average-pool the EEG embedding's last axis to the time
            # length of the speech features (dim 3).
            target_time_dim = speech_features.size(3)
            eeg_emb_pooled = F.adaptive_avg_pool1d(eeg_emb, target_time_dim)
            speech_emb = self.modules.speech_model(speech_features)
            return self.modules.classifier(eeg_emb_pooled, speech_emb)

        # NOTE(review): this path passes the raw `speech` tensor and the un-pooled EEG
        # embedding, ignoring any wav2vec features computed above — confirm intended.
        return self.modules.classifier(eeg_emb, speech)

    def _extract_wav2vec_features(self, speech, wav2vec_processor, wav2vec_model):
        """Encode every (batch, candidate) waveform with wav2vec 2.0.

        Returns a tensor of shape (B, M, D, T') on ``self.device``, where D is the
        encoder hidden size and T' the number of encoder frames. torch.cat requires
        every segment to yield the same T' (true when all segments have equal length).
        """
        B, M, C, T = speech.shape
        # Send inputs to wherever the encoder's weights live.
        model_device = next(wav2vec_model.parameters()).device
        per_segment = []
        with torch.no_grad():
            for i in range(B):
                for j in range(M):
                    audio_segment = speech[i, j].squeeze().cpu().numpy()
                    input_values = wav2vec_processor(
                        audio_segment, return_tensors="pt", sampling_rate=16000
                    ).input_values
                    hidden = wav2vec_model(input_values.to(model_device)).last_hidden_state
                    per_segment.append(hidden)  # (1, T', D)

        features = torch.cat(per_segment, dim=0)  # (B*M, T', D)
        features = features.view(B, M, -1, features.size(-1)).to(self.device)
        return features.permute(0, 1, 3, 2)  # (B, M, D, T')

    def compute_objectives(self, pred, label, stage):
        """Return the loss and accumulate batch-size-weighted loss/accuracy.

        Args:
            pred: classifier output; argmax over the last dim is the predicted candidate.
            label: (B,) index of the matching candidate.
            stage: sb.Stage; not used by this method.
        """
        B = label.shape[0]

        loss = self.hparams.loss_fn(pred, label)
        est_label = torch.argmax(pred, dim=-1)
        # Vectorised accuracy; the builtin sum() iterated the tensor element by element.
        acc = (est_label == label).float().mean()

        self.loss_stat['loss'] += float(loss) * B
        self.loss_stat['acc'] += float(acc) * B
        self.count += B
        return loss

    def fit_batch(self, batch):
        """Run one training step and return the detached loss on CPU.

        For each item, a random candidate index in [0, M) is drawn. That
        candidate's EEG is kept, and the index is the target the classifier
        must recover from the M speech candidates.
        """
        eeg = batch['eeg']
        speech = batch['wav']
        B, M, C, T = eeg.shape

        label = torch.randint(0, M, (B,), dtype=torch.long)
        eeg = eeg[torch.arange(B), label]

        eeg = eeg.to(self.device)
        speech = speech.to(self.device)
        label = label.to(self.device)

        pred = self.compute_forward(eeg, speech, sb.Stage.TRAIN, wav2vec_processor, wav2vec_model)
        loss = self.compute_objectives(pred, label, sb.Stage.TRAIN)
        loss.backward()

        # The optimizer step is applied only when check_gradients(loss) passes.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()

        return loss.detach().cpu()

    def evaluate_batch(self, batch, stage):
        """Evaluate one batch without gradients and return the detached loss on CPU.

        The EEG of candidate 0 is always selected, and 0 is the target label.
        """
        eeg = batch['eeg']
        speech = batch['wav']

        B, M, C, T = eeg.shape
        label = torch.zeros(B, dtype=torch.long)
        eeg = eeg[torch.arange(B), label]

        eeg = eeg.to(self.device)
        speech = speech.to(self.device)
        label = label.to(self.device)

        with torch.no_grad():
            pred = self.compute_forward(eeg, speech, stage, wav2vec_processor, wav2vec_model)
            loss = self.compute_objectives(pred, label, stage)

        return loss.detach().cpu()

    def on_stage_start(self, stage, epoch=None):
        """Reset the per-stage sample counter and the accumulated statistics."""
        super().on_stage_start(stage, epoch)
        self.count = 0
        self.loss_stat = {
            'loss': 0,
            'acc': 0,
        }

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Average the stage statistics, step the LR scheduler after validation, and log.

        Handles TRAIN, VALID and (previously crashing) TEST stages.
        """
        for loss_key in self.loss_stat:
            self.loss_stat[loss_key] /= self.count

        if stage == sb.Stage.TRAIN:
            prefix = 'train_'
        elif stage == sb.Stage.VALID:
            prefix = 'valid_'
        else:
            # Previously stage_stats was never bound here, so the print below raised.
            prefix = 'test_'

        stage_stats = {prefix + key: round(float(value), 4) for key, value in self.loss_stat.items()}
        if stage == sb.Stage.TRAIN:
            stage_stats['lr'] = self.optimizer.param_groups[0]['lr']
        elif stage == sb.Stage.VALID:
            self.lr_scheduler.step(self.loss_stat['loss'])
        stage_stats['epoch'] = epoch

        # NOTE(review): no checkpointer save call happens anywhere in this class, so the
        # recoverables registered in init_optimizers are never written — confirm intended.
        if self.hparams.use_wandb and hasattr(self.hparams, 'logger') and self.hparams.logger:
            self.hparams.logger.run.log(
                data=stage_stats,
            )

        print(f'Epoch {epoch}: ', stage, stage_stats)

    def init_optimizers(self):
        """Create the optimizer (via the base class) and the LR scheduler, and register both for checkpointing."""
        super().init_optimizers()
        self.lr_scheduler = self.hparams.lr_scheduler(self.optimizer)
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("optimizer", self.optimizer)
            self.checkpointer.add_recoverable("lr_scheduler", self.lr_scheduler)
                        
    

Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize
Ignored unknown kwarg option normalize


Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-large-xlsr-53-dutch and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Everything the Brain needs comes from the loaded hparams plus the parsed run options.
brain_kwargs = dict(
    modules=hparams['modules'],
    opt_class=hparams['optimizer'],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams['checkpointer'],
)
brain = Classifier(**brain_kwargs)


In [11]:
# `brain` is a SpeechBrain Brain object, not a torch nn.Module, so the previous
# `isinstance(brain, DataParallel)` unwrap could never trigger. The alias is kept
# for any later cell that refers to `original_model`.
original_model = brain

# NOTE(review): the train-loader settings are hard-coded here, while the valid loader
# reads hparams['valid_loader_params']. The batch size (64) also disagrees with the
# "bs4" in the '--experiment' name set in the hparams cell — confirm which is intended.
TRAIN_LOADER_KWARGS = {'batch_size': 64, 'num_workers': 16}

original_model.fit(
    epoch_counter=original_model.hparams.epoch_counter,
    train_set=hparams['train_windows'],
    valid_set=hparams['valid_windows'],
    train_loader_kwargs=TRAIN_LOADER_KWARGS,
    valid_loader_kwargs=hparams['valid_loader_params'],
)

100%|██████████| 11/11 [01:25<00:00,  7.76s/it, train_loss=1.61]


stage_end
Epoch 3:  Stage.TRAIN {'train_loss': 1.6074, 'train_acc': 0.2327, 'lr': 0.002, 'epoch': 3}


100%|██████████| 2/2 [01:42<00:00, 51.49s/it]


stage_end
Epoch 3:  Stage.VALID {'valid_loss': 1.6046, 'valid_acc': 0.2676, 'epoch': 3}


100%|██████████| 11/11 [01:33<00:00,  8.52s/it, train_loss=1.58]


stage_end
Epoch 4:  Stage.TRAIN {'train_loss': 1.5827, 'train_acc': 0.3273, 'lr': 0.002, 'epoch': 4}


 50%|█████     | 1/2 [00:32<00:32, 32.16s/it]Exception ignored in: <generator object tqdm.__iter__ at 0x2aac2ec7fed0>
Traceback (most recent call last):
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/tqdm/std.py", line 1197, in __iter__
    self.close()
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/tqdm/std.py", line 1291, in close
    fp_write('')
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/site-packages/tqdm/std.py", line 1287, in fp_write
    def fp_write(s):
KeyboardInterrupt: 
Traceback (most recent call last):
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/ss6928/.conda/envs/myenv/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
Traceback (m

KeyboardInterrupt: 

In [6]:
# Close the Weights & Biases run. The hparams cell passes '--use_wandb false', so no
# logger is created in this session; the summary shown below (epoch 2, lr 0.006)
# appears to come from an earlier session (this cell's execution count, 6, precedes
# the training cell's, 11).
wandb.finish()



0,1
epoch,▁▁█
lr,▁▁
train_acc,█▁
train_loss,▁█
valid_acc,▁
valid_loss,▁

0,1
epoch,2.0
lr,0.006
train_acc,0.1862
train_loss,1.6125
valid_acc,0.168
valid_loss,1.6164
