<a href="https://colab.research.google.com/github/aaderemi/GSoC/blob/main/new_pretrain_multigpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Please set an output directory (the `output_dir` variable inside `main`) in which to store the model and other artifacts.

In [None]:
#!pip install mne
#!pip install accelerate -U
#!unzip test_code.zip -d test_code
#!pip install wandb

In [None]:
# Automatically create a basic Accelerate config file to be used for training.
from accelerate.utils import write_basic_config
write_basic_config()

In [None]:
# The session needs to be restarted after calling the above.
# os._exit() ends the current process immediately, without the normal interpreter cleanup.
import os
os._exit(00)

In [None]:
#!mkdir output

In [None]:
# Log in to Weights & Biases; main() reports the training loss there.
import wandb
wandb.login()

In [None]:
def main(num_epochs, bs, seed):
    """Pre-train Wav2Vec2 on multi-channel EEG recordings using Accelerate.

    Loads the ``facebook/wav2vec2-large`` checkpoint, swaps its feature
    encoder for one whose first convolution accepts ``num_channels`` input
    channels, and runs the contrastive pre-training loop over every ``.fif``
    file found under ``processed_folder``. The mean loss is logged to
    Weights & Biases from the main process and a checkpoint is written to
    ``output_dir`` every ``save_steps`` optimizer steps.

    All imports are done inside the function so that it can be handed to
    ``notebook_launcher`` as a self-contained entry point.

    Args:
        num_epochs: Number of passes over the training set.
        bs: Batch size given to the ``DataLoader``.
        seed: Seed passed to ``accelerate.utils.set_seed``.

    Raises:
        FileNotFoundError: If no ``.fif`` file exists under ``processed_folder``.
    """
    import os

    import mne
    import torch
    import torch.nn as nn
    import wandb
    from accelerate import Accelerator
    from accelerate import DistributedDataParallelKwargs
    from accelerate.utils import set_seed
    from accelerate.utils import tqdm
    from torch.utils.data import DataLoader
    from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining
    from transformers.models.wav2vec2.modeling_wav2vec2 import (
        ACT2FN,
        Wav2Vec2FeatureEncoder,
        _compute_mask_indices,
        _sample_negative_indices,
        Wav2Vec2GroupNormConvLayer,
        Wav2Vec2GumbelVectorQuantizer,
        Wav2Vec2LayerNormConvLayer,
        Wav2Vec2NoLayerNormConvLayer,
    )

    # Seed before any module is constructed: the replacement feature encoder
    # built below is randomly initialised and must be reproducible too.
    set_seed(seed)

    # Load the pretrained facebook/wav2vec2-large checkpoint with its stock
    # configuration.
    configuration = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-large")
    model = Wav2Vec2ForPreTraining.from_pretrained(
        "facebook/wav2vec2-large", config=configuration)

    # These fields are changed only AFTER the checkpoint has been loaded;
    # `configuration` is then used to build the replacement feature encoder
    # and to size the masks in the training loop.
    # NOTE(review): whether `model.config` also sees these edits depends on
    # whether `from_pretrained` copies the config object - confirm.
    configuration.conv_kernel = [3, 2, 2, 2, 2, 2]
    configuration.conv_stride = [3, 2, 2, 2, 2, 2]
    configuration.conv_dim = [512, 512, 512, 512, 512, 512]
    configuration.num_feat_extract_layers = 6
    configuration.num_channels = 20
    configuration.num_negatives = 20
    configuration.mask_time_prob = 0.065
    configuration.mask_time_length = 10

    # configuration.diversity_loss_weight = 0.0

    # set some training arguments here
    # save steps and log steps control saving and printing output respectively
    save_steps = 10
    logging_steps = 1
    learning_rate = 1e-2
    weight_decay = 0.01
    warmup_ratio = 0.05

    remove_quantization = False  # set to true to remove quantization
    # put output directory for models here
    output_dir = "output"
    # "/data/work/zeydabadi/test/"
    processed_folder = "/content/test_code"

    def build_multichannel_conv(config, layer_id):
        """Build the convolution for feature-encoder layer ``layer_id``.

        The first layer takes ``config.num_channels`` input channels; later
        layers take the previous layer's ``conv_dim``. Every layer pads by
        ``kernel // 2``.

        Returns:
            Tuple ``(in_conv_dim, out_conv_dim, conv)``.
        """
        in_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else config.num_channels
        out_dim = config.conv_dim[layer_id]
        conv = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
            padding=config.conv_kernel[layer_id] // 2,
        )
        return in_dim, out_dim, conv

    # The subclasses below deliberately reuse (and shadow) the imported class
    # names; each one only replaces the convolution created by its parent.
    class Wav2Vec2LayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
        """Layer-norm conv layer with a multi-channel, padded convolution."""

        def __init__(self, config, layer_id=0):
            super().__init__(config, layer_id)
            self.in_conv_dim, self.out_conv_dim, self.conv = build_multichannel_conv(
                config, layer_id)
            self.activation = ACT2FN[config.feat_extract_activation]
            self.layer_norm = nn.LayerNorm(
                self.out_conv_dim, elementwise_affine=True)

    class Wav2Vec2NoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
        """Norm-free conv layer with a multi-channel, padded convolution."""

        def __init__(self, config, layer_id=0):
            super().__init__(config, layer_id)
            self.in_conv_dim, self.out_conv_dim, self.conv = build_multichannel_conv(
                config, layer_id)
            self.activation = ACT2FN[config.feat_extract_activation]

    class Wav2Vec2GroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
        """Group-norm conv layer with a multi-channel, padded convolution."""

        def __init__(self, config, layer_id=0):
            super().__init__(config, layer_id)
            self.in_conv_dim, self.out_conv_dim, self.conv = build_multichannel_conv(
                config, layer_id)
            self.activation = ACT2FN[config.feat_extract_activation]
            self.layer_norm = nn.GroupNorm(
                num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)

    class Wav2Vec2FeatureEncoder(Wav2Vec2FeatureEncoder):
        """Feature encoder assembled from the multi-channel conv layers above."""

        def __init__(self, config):
            super().__init__(config)

            if config.feat_extract_norm == "group":
                conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
                    Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
                    for i in range(config.num_feat_extract_layers - 1)
                ]
            elif config.feat_extract_norm == "layer":
                conv_layers = [
                    Wav2Vec2LayerNormConvLayer(config, layer_id=i)
                    for i in range(config.num_feat_extract_layers)
                ]
            else:
                raise ValueError(
                    f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
                )

            self.conv_layers = nn.ModuleList(conv_layers)
            self.gradient_checkpointing = False
            self._requires_grad = True

        def forward(self, input_values):
            """Run ``input_values`` through every conv layer in order.

            The input is fed to the first convolution unchanged.
            # assumes input_values is (batch, channels, time) - TODO confirm
            """
            hidden_states = input_values

            # make sure hidden_states require grad for gradient_checkpointing
            if self._requires_grad and self.training:
                hidden_states.requires_grad = True

            for conv_layer in self.conv_layers:
                if self._requires_grad and self.gradient_checkpointing and self.training:
                    hidden_states = self._gradient_checkpointing_func(
                        conv_layer.__call__,
                        hidden_states,
                    )
                else:
                    hidden_states = conv_layer(hidden_states)

            return hidden_states

    class Wav2Vec2NoQuantizer(Wav2Vec2GumbelVectorQuantizer):
        """Quantizer stand-in that only applies a linear projection."""

        def __init__(self, config):
            super().__init__(config)
            self.proj = nn.Linear(config.conv_dim[-1], config.codevector_dim)

        @staticmethod
        def _compute_perplexity(probs, mask=None):
            # should probably make this a tensor
            return 0

        def forward(self, hidden_states, mask_time_indices=None):
            """Return the projected ``hidden_states`` and a perplexity of 0."""
            hidden_states = self.proj(hidden_states)
            perplexity = self._compute_perplexity(
                hidden_states, mask_time_indices)

            return hidden_states, perplexity

    model.wav2vec2.feature_extractor = Wav2Vec2FeatureEncoder(configuration)

    if remove_quantization:
        # The quantizer is an attribute of Wav2Vec2ForPreTraining itself, not
        # of the inner `wav2vec2` model, so it has to be replaced here for the
        # forward pass to pick it up.
        model.quantizer = Wav2Vec2NoQuantizer(configuration)
        model.config.diversity_loss_weight = 0.0

    def list_files_with_extension(root_folder, extension):
        """
        List all files with a specific extension in a folder and its subfolders.

        Directories are visited in sorted path order and the files inside each
        directory are sorted by name, so the result does not depend on the
        order in which the filesystem lists entries.

        Parameters:
        - root_folder (str): The root folder to start the search.
        - extension (str): The file extension to look for (e.g., '.txt').

        Returns:
        - List of file paths with the specified extension.
        """
        matching_files = []
        walk_entries = sorted(os.walk(root_folder), key=lambda entry: entry[0])
        for dirpath, _dirnames, filenames in walk_entries:
            for filename in sorted(filenames):
                if filename.endswith(extension):
                    matching_files.append(os.path.join(dirpath, filename))

        return matching_files

    # im using dataset class from shared file
    class EEGDataset(torch.utils.data.Dataset):
        """Dataset with one item per file under ``path`` ending in ``extension``."""

        def __init__(self, path, extension):
            self.path = path
            self.extension = extension
            self.items = list_files_with_extension(self.path, self.extension)
            # self.items = self.filter_valid_files(all_items)

        def __len__(self):
            return len(self.items)

        def get_filename(self, idx):
            """Return the path of the file backing item ``idx``."""
            return self.items[idx]

        def __getitem__(self, idx):
            """Read file ``idx`` with MNE and return its data as a float32 tensor."""
            file_id = self.items[idx]
            raw_file = mne.io.read_raw_fif(file_id, preload=True)
            data = raw_file.get_data()
            # model expects dictionary with at least this keyword
            return {"input_values": torch.as_tensor(data, dtype=torch.float32)}

    mne.set_log_level("ERROR")

    train_dataset = EEGDataset(processed_folder, ".fif")
    if len(train_dataset) == 0:
        raise FileNotFoundError(
            f"No '.fif' files found under {processed_folder!r}")

    # Length of the feature-encoder output, derived from the first recording
    # only and then reused for every batch.
    # NOTE(review): this relies on all recordings having the same number of
    # samples as the first one - confirm.
    seq_length = train_dataset[0]["input_values"].shape[1]
    for kernel, stride in zip(configuration.conv_kernel, configuration.conv_stride):
        # Conv1d output length with dilation 1 and padding kernel // 2; it is
        # applied to every layer because kernel and padding can change the
        # length even when the stride is 1.
        padding = kernel // 2
        seq_length = (seq_length + 2 * padding - (kernel - 1) - 1) // stride + 1
    seq_length = int(seq_length)

    accelerator = Accelerator(
        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    # accelerator.print("Started")
    if accelerator.is_main_process:
        wandb.init(
            project="multigpu_test_kaggle",
            config={
                "lr": learning_rate,
                "bs": bs,
                "num_epochs": num_epochs,
            },
        )
    accelerator.wait_for_everyone()

    train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    # progress_bar = tqdm(train_dataloader, disable=not accelerator.is_main_process)

    optimizer = torch.optim.Adam(
        params=model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        epochs=num_epochs,
        steps_per_epoch=len(train_dataloader),
        pct_start=warmup_ratio,
    )

    model, optimizer, train_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, scheduler)
    # accelerator.print("prepare done")
    steps = 0
    for _epoch in range(num_epochs):
        model.train()
        for batch in tqdm(train_dataloader):
            steps += 1
            bs_data = batch["input_values"].shape[0]

            # Masks and negatives are sampled over the encoder output length,
            # not over the raw input length.
            mask_time_indices = _compute_mask_indices(
                shape=(bs_data, seq_length),
                mask_prob=configuration.mask_time_prob,
                mask_length=configuration.mask_time_length,
            )
            sampled_negative_indices = _sample_negative_indices(
                features_shape=(bs_data, seq_length),
                num_negatives=configuration.num_negatives,
                mask_time_indices=mask_time_indices,
            )

            mask_time_indices = torch.tensor(
                data=mask_time_indices, device=accelerator.device, dtype=torch.long)
            sampled_negative_indices = torch.tensor(
                data=sampled_negative_indices, device=accelerator.device, dtype=torch.long)

            outputs = model(
                input_values=batch["input_values"],
                mask_time_indices=mask_time_indices,
                sampled_negative_indices=sampled_negative_indices,
                return_dict=True,
            )
            loss = outputs.loss

            # Average the loss over all processes for reporting only; the
            # backward pass below still uses the local loss.
            gathered_loss = accelerator.gather(loss)
            mean_loss = gathered_loss.mean().item()

            if accelerator.is_main_process:
                wandb.log({"loss": mean_loss})

            if steps % logging_steps == 0:
                accelerator.print(f"Doing step {steps}, loss: {mean_loss:.2f}")

            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Save after the update so that model_<steps> holds the weights
            # produced by exactly <steps> optimizer steps.
            if steps % save_steps == 0:
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save_model(
                    unwrapped_model, f"{output_dir}/model_{steps}")
                del unwrapped_model

    if accelerator.is_main_process:
        wandb.finish()

In [None]:
from accelerate import notebook_launcher

In [None]:
# Positional arguments for main(), in order: number of epochs, batch size, seed.
args = (5, 2, 42)

In [None]:
# Run main(*args) through Accelerate's notebook launcher, here with a single process.
notebook_launcher(main, args, num_processes = 1)