In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn

import safetensors
from accelerate import notebook_launcher
from dataclasses import dataclass
from simple_parsing.helpers import Serializable

import gc
import albumentations as A
import matplotlib.pyplot as plt
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer

from utils.data_utils import BrainDataset, get_tokenizer, process_file_v2
from utils.train_utils import TrainConfig, run_train_model, count_parameters, load_model_weights, freeze_module

In [2]:
from models.mirasol import Mirasol, MirasolConfig, Franky
from models.vq_brain_per_channel import SoundStream, VAEConfig


In [3]:
# Environment switch: True selects the AWS-server settings, False the alternative ones.
IS_AWS_SERVER = True

# Dataset and log locations are identical for both environments.
data_path = Path("/drive/data/competitionData")
save_folder = Path("/drive/logs/kovalev")

# Only the Whisper checkpoint name and the VQ weights file depend on the environment.
if IS_AWS_SERVER:
    whisper_model_name = "openai/whisper-small.en"
    vq_weights = Path('/drive/logs/kovalev/vq_brain/4_features_35M_4x_4000_256_ws-new/step_12000_loss_0.0102.safetensors')
else:
    whisper_model_name = "openai/whisper-large-v3"
    vq_weights = Path('../4_features_35M_4x_4000_256_ws-new.safetensors')

### Init Franky

In [4]:
# VQ autoencoder: the config is expanded into SoundStream keyword arguments.
# For these settings the recorded cell output reports codebook_size 4375 and downsample 4.
vae_config = VAEConfig(
    C=512,
    levels=(7, 5, 5, 5, 5),
    n_features=4,
    stride_list=(2, 2),
)
vq_vae = SoundStream(**vae_config.to_dict())

# Mirasol brain model that wraps the VQ autoencoder built above.
model_config = MirasolConfig(
    window_size=512,
    n_layers=6,
    mask_ratio=0.1,
    n_registers=4,
)
brain_model = Mirasol(model_config, vq_vae)

self.codebook_size 4375
self.downsample 4
Shape of the rope cache:  torch.Size([512, 16])
Shape of the causal model:  torch.Size([512, 512])
Full Mirasol model: number of parameters: 114.13M


In [5]:
# Load the pretrained Whisper tokenizer and model for the configured checkpoint.
tokenizer = WhisperTokenizer.from_pretrained(whisper_model_name, task="transcribe")

# Dropout-style regularisation is enabled at load time; spec-augment and
# encoder layerdrop are deliberately not set here.
model = WhisperForConditionalGeneration.from_pretrained(
    whisper_model_name,
    decoder_layerdrop=0.1,
    dropout=0.1,
    activation_dropout=0.1,
)
model.config.forced_decoder_ids = None

# Bundle only the parts handed to Franky: decoder, output projection and tokenizer.
whisper = dict(
    decoder=model.model.decoder,
    proj_out=model.proj_out,
    tokenizer=tokenizer,
)



In [6]:
# Combine the brain model with the Whisper parts into the full Franky model.
model = Franky(brain_model=brain_model, llm_model=whisper, add_temporal_embeddings=True)
model.train()

# VQ sub-model: load its weights, freeze it and switch it to eval mode.
# NOTE(review): if run_train_model calls model.train() again, this eval() is undone - confirm.
vq_model = model.brain_model.vq_model
load_model_weights(vq_model, vq_weights)
freeze_module(vq_model)
vq_model.eval()

# Decoder: only the positional embeddings are frozen here;
# model.llm_decoder.embed_tokens and model.proj_out are not frozen by this cell.
freeze_module(model.llm_decoder.embed_positions)

# Last expression of the cell, so its return value is displayed.
count_parameters(model)

Full Franky: number of parameters: 269.70M
load compiled weights
Total: 269.70M, Trainable: 233.64M


(269695264, 233644311)

### Load Datasets

In [7]:
window_size = model_config.window_size
n_electrodes = 256 * 4  # presumably 256 channels x 4 features (n_features=4) - TODO confirm
max_tokens = 25


def _pad_and_crop():
    """Return fresh pad + crop transforms shared by the train and test pipelines.

    Inputs are zero-padded (anchored top-left) up to at least
    ``window_size`` x ``n_electrodes`` and then cropped to exactly that size
    starting at the origin, so the crop is deterministic.

    Returns:
        list: ``[A.PadIfNeeded(...), A.Crop(...)]``, newly constructed on each call.
    """
    return [
        A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left',
                      border_mode=0, value=0, always_apply=True),
        A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True),
    ]


def tokenize_function(text):
    """Return the ``input_ids`` produced by the Whisper tokenizer for ``text``."""
    return tokenizer(text)['input_ids']


# Training adds random augmentation (coarse dropout, Gaussian noise) before the
# shared pad + crop; evaluation uses pad + crop only.
train_transform = A.Compose([
    A.CoarseDropout(fill_value=0, p=0.1),
    A.GaussNoise(var_limit=0.005, mean=0, p=0.5),
    *_pad_and_crop(),
])

test_transform = A.Compose(_pad_and_crop())

process_file_function = process_file_v2

train_dataset = BrainDataset(data_path / 'train',
                             process_file_function=process_file_function,
                             tokenize_function=tokenize_function,
                             transform=train_transform,
                             max_tokens=max_tokens)

# Collect garbage after each dataset build.
gc.collect()
test_dataset = BrainDataset(data_path / 'test',
                            process_file_function=process_file_function,
                            tokenize_function=tokenize_function,
                            transform=test_transform,
                            max_tokens=max_tokens)

gc.collect()

Runed processing of the  /drive/data/competitionData/train


Processing files: 100%|██████████| 24/24 [01:21<00:00,  3.39s/file]


len of the dataset: 8800
max signal size: 906 | max tokens size: 23
median signal size: 297.0 | median tokens size: 11.0
Runed processing of the  /drive/data/competitionData/test


Processing files: 100%|██████████| 24/24 [00:07<00:00,  3.17file/s]


len of the dataset: 880
max signal size: 919 | max tokens size: 21
median signal size: 283.5 | median tokens size: 10.0


0

In [8]:
# Disabled cell: looks like a single-batch smoke test of the model's forward pass.
# NOTE(review): the loader below is named `train_dataloader` but is built from
# `test_dataset` - confirm the intent before re-enabling, or delete this cell.
# # test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
# train_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

# x, y, date = next(iter(train_dataloader))
# print(x.shape, y.shape, date.shape)
# print('input text', y)

# loss, logits = model(x, y, date)

# Work with data

In [None]:
project_name = 'franky'

# Training hyper-parameters, grouped by concern.
train_config = TrainConfig(
    # experiment bookkeeping
    exp_name='fixed-mirasol',
    project_name=project_name,
    save_folder=save_folder,
    eval_interval=1000,
    # data loading
    batch_size=2,
    grad_accum=2,
    num_workers=3,
    pin_memory=True,
    # optimisation and schedule
    mixed_precision=True,
    learning_rate=1e-4,
    weight_decay=0.001,
    grad_clip=1,
    warmup_iters=1000,
    lr_decay_iters=20_000,
)

# Launch run_train_model in a single process with the model, both datasets and both configs.
datasets = (train_dataset, test_dataset)
args = (model, datasets, train_config, model_config)
notebook_launcher(run_train_model, args, num_processes=1)

Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
*************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************