In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn

import safetensors
from accelerate import notebook_launcher

from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, load_model_weights, freeze_module

from dataclasses import dataclass
from simple_parsing.helpers import Serializable

from safetensors.torch import load_model
import albumentations as A

import matplotlib.pyplot as plt

In [2]:
from models.mirasol import Mirasol, MirasolConfig, Franky
from models.vq_brain_per_channel import SoundStream, VAEConfig


## Init brain module

In [3]:
vae_config = VAEConfig(C=256, levels=(8, 8, 6, 5), n_features=1)
vq_vae = SoundStream(**vae_config.to_dict())

model_config = MirasolConfig(window_size=512, n_layers=6, 
                             w_latent_loss=1, w_recon_loss=1, 
                             mask_ratio=0, n_registers=8)
brain_model = Mirasol(model_config, vq_vae)

self.codebook_size 1920
self.downsample 8
Shape of the rope cache:  torch.Size([512, 16])
Shape of the causal model:  torch.Size([512, 512])
Full Mirasol model: number of parameters: 90.56M


In [4]:
# x = torch.zeros(1, 512, 256)
# loss, logits = brain_model(x) 


## Init language model

In [5]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small.en", language="english", task="transcribe")


model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small.en", 
                                                        # apply_spec_augment=0.1,
                                                        decoder_layerdrop=0.0, 
                                                        # encoder_layerdrop=0.1, 
                                                        dropout=0.0,
                                                        activation_dropout=0.0)
model.config.forced_decoder_ids = None

tokenizer = processor.tokenizer
whisper = {'decoder': model.model.decoder, 
           'proj_out': model.proj_out, 
           'tokenizer': tokenizer}



## Init Franky

In [6]:
model = Franky(brain_model=brain_model, llm_model=whisper) 
model.train()

vq_weights = "/drive/logs/kovalev/vq_brain/medium_14M_256ws_8x_2000/step_78000_loss_0.0275.safetensors"
load_model_weights(model.brain_model.vq_model, vq_weights)

## let's freeeze 
# vqvae
freeze_module(model.brain_model.vq_model)

# # decoder
# freeze_module(model.llm_decoder.embed_tokens)
freeze_module(model.llm_decoder.embed_positions)
# freeze_module(model.proj_out)

count_parameters(model)

Full Franky: number of parameters: 244.56M
load compiled weights
Total: 244.56M, Trainable: 230.81M


(244555141, 230813568)

In [7]:
window_size = model_config.window_size
n_electrodes = 256
max_tokens = 25

train_transform = A.Compose([
    
    # A.CoarseDropout(fill_value=0, p=0.5),
    # A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5),
    # A.GaussNoise(var_limit=0.005, mean=0, p=0.5),

    A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left', 
                  border_mode=0, value=0, always_apply=True),
    # A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True),

])

test_transform = A.Compose([
    A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left', 
                  border_mode=0, value=0, always_apply=True),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True)
])


data_path = Path("/drive/data/competitionData")

tokenize_function =  lambda text: tokenizer(text)['input_ids']

train_dataset = BrainDataset(data_path / 'train', tokenize_function=tokenize_function, transform=train_transform, max_tokens=max_tokens)
test_dataset = BrainDataset(data_path / 'test', tokenize_function=tokenize_function, transform=test_transform, max_tokens=max_tokens)

Runed processing of the  /drive/data/competitionData/train


Processing files: 100%|██████████| 24/24 [00:42<00:00,  1.78s/file]


len of the dataset: 8800
max signal size: 906 | max tokens size: 25
median signal size: 297.0 | median tokens size: 11.0
Runed processing of the  /drive/data/competitionData/test


Processing files: 100%|██████████| 24/24 [00:06<00:00,  3.81file/s]

len of the dataset: 880
max signal size: 919 | max tokens size: 22
median signal size: 283.5 | median tokens size: 10.0





In [8]:
# test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
train_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

x, y, date = next(iter(train_dataloader))
print(x.shape, y.shape, date.shape)
print('input text', y)

loss, logits = model(x, y, date)

# plt.imshow(x.detach()[0].T)
# plt.show()

torch.Size([1, 512, 256]) torch.Size([1, 25]) torch.Size([1, 1])
input text tensor([[50257, 50362,   464, 17818, 23898,  3089,    13, 50256,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100]])


# Work with data

In [None]:
project_name = 'franky'
save_folder = Path("/drive/logs/kovalev")

train_config = TrainConfig(exp_name='mirasol-whisper-small-full-all-loss-embeddings',
                           mixed_precision=True,
                           batch_size=8, 
                           grad_accum=8,
                           num_workers=3,
                           pin_memory=True, 
                           eval_interval=1000, 
                           learning_rate=1e-4,
                           weight_decay=0.0, 
                           grad_clip=10,
                           lr_decay_iters=20_000, 
                           warmup_iters=1000, 
                           project_name=project_name, 
                           save_folder=save_folder
                          )
# indices = torch.arange(16)
# train_dataset = torch.utils.data.Subset(test_dataset, indices)
# test_dataset = torch.utils.data.Subset(test_dataset, indices)

args = (model, (train_dataset, test_dataset), train_config, model_config)
notebook_launcher(run_train_model, args, num_processes=1)

Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
******************************************************************************************