In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn

import safetensors
from accelerate import notebook_launcher

import einops

from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, simple_train_model


from models.vq_brain import SoundStream
from transformers import GPT2Tokenizer


In [2]:
# Hyper-parameters forwarded unchanged to the SoundStream constructor.
config = dict(C=256, D=64, codebook_size=1024, n_electrodes=256)
model = SoundStream(**config)
count_parameters(model)

# Smoke test on an all-zero batch. Only shapes are inspected, so the forward
# passes run under no_grad to avoid building an autograd graph.
x = torch.zeros(16, 768, 256)
with torch.no_grad():
    loss, pred = model(x)
    # Shape of the second item returned by get_quantize_vectors.
    quantize_shape = model.get_quantize_vectors(x)[1].shape

# A bare `pred.shape` in the middle of a cell is never displayed; print it.
print('pred shape:', pred.shape)

quantize_shape

Total: 14.17M, Trainable: 14.17M


torch.Size([16, 192, 64])

### Run training pipeline

In [None]:
project_name = 'vq_brain'

train_config = TrainConfig(exp_name='simple_voltage_15M',
                           mixed_precision=True,
                           batch_size=512,
                           num_workers=3,
                           pin_memory=True,
                           eval_interval=250)

# Data and log locations are machine-specific. The defaults are the paths this
# notebook was last run with; set BRAIN_DATA_PATH / BRAIN_SAVE_FOLDER in the
# environment to point elsewhere instead of editing (or commenting out) paths.
data_path = Path(os.environ.get("BRAIN_DATA_PATH", "/drive/data/competitionData"))
save_folder = Path(os.environ.get("BRAIN_SAVE_FOLDER", "/drive/logs/kovalev"))

# Fail early with a readable message rather than somewhere inside data loading.
for split_name in ('train', 'test'):
    split_path = data_path / split_name
    if not split_path.exists():
        raise FileNotFoundError(f"Dataset split not found: {split_path}")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
train_dataset = BrainDataset(data_path / 'train', tokenize_function=get_tokenizer(tokenizer))
test_dataset = BrainDataset(data_path / 'test', tokenize_function=get_tokenizer(tokenizer))

# Positional arguments handed to run_train_model by the launcher.
args = (model, (train_dataset, test_dataset), train_config, project_name, save_folder)
notebook_launcher(run_train_model, args, num_processes=1)



Runed processing of the  /drive/data/competitionData/train
Runed processing of the  /drive/data/competitionData/test
Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
**********************************************************************************************************************************************************************************************************************************************************overall_steps 250: 0.6142138838768005
val loss: 0.6136208772659302
saved model:  step_250_loss_0.6136.safetensors
**********************************************************************************************************************************************************************************************************************************************************overall_steps 500: 0.6076383590698242
val loss: 0.610163688659668
saved model:  step_500_loss_0.6102.safetensors
************************************************************************************************************************************************************************************************

In [None]:
# Fetch the first training item and keep its first field as the model input.
first_item = train_dataset[0]
x = first_item[0]

In [None]:
# Add a leading batch axis and move the sample to the device the model's
# parameters live on, instead of assuming a CUDA device is available.
# NOTE: `x` is rebound from a numpy array to a tensor here, so re-running this
# cell requires first re-running the cell that loads `x` from the dataset.
x = torch.from_numpy(x[None, :]).to(next(model.parameters()).device)

In [None]:
# Run the model on the single-sample batch for inspection only: gradients are
# not needed, and eval mode switches off any train-only layer behaviour.
# The previous train/eval mode is restored so later training is unaffected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        y = model(x)[1]  # second item of the model output (`pred` in the smoke test)
finally:
    model.train(was_training)

# Drop the batch axis and bring both tensors to the CPU for plotting.
x = x.detach().cpu()[0]
y = y.detach().cpu()[0]

In [None]:
import matplotlib.pyplot as plt 

In [None]:
# `y` had its batch axis removed when it was computed; indexing [0] again would
# leave a 1-D row, which imshow rejects. Only strip a batch axis if one is
# still present (e.g. cells were run out of order).
y_image = y.detach().cpu()
if y_image.ndim == 3:
    y_image = y_image[0]
plt.imshow(y_image.T, aspect='auto')
plt.title('Model prediction');

In [None]:
# `x` had its batch axis removed after inference; indexing [0] again would
# leave a 1-D row, which imshow rejects. Only strip a batch axis if one is
# still present (e.g. cells were run out of order).
x_image = x.detach().cpu()
if x_image.ndim == 3:
    x_image = x_image[0]
plt.imshow(x_image.T, aspect='auto')
plt.title('Model input');