In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn
import safetensors
from accelerate import notebook_launcher

import einops
import gc 


from vector_quantize_pytorch import FSQ

import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from simple_parsing import Serializable


from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, simple_train_model
from utils.data_utils import process_file_v2

from models.vq_brain_per_channel import SoundStream, VAEConfig
# from transformers import GPT2Tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')

# Build the VQ autoencoder from its config, report its parameter count and
# run one forward pass on an all-zero dummy batch as a smoke test.
model_config = VAEConfig(
    C=256,
    levels=(7, 5, 5, 5, 5),
    n_features=4,
    stride_list=(2, 1, 2, 1),
)
model = SoundStream(**model_config.to_dict())

count_parameters(model)

# Dummy input: 1 sample, 32 time steps, 256 * 4 = 1024 electrode features.
x = torch.zeros((1, 32, 1024))
loss, pred = model(x)

print(model_config)
print(loss)



# Geometry of one training example and the token budget passed to the dataset.
window_size = 128
n_electrodes = 256 * 4
max_tokens = 25

# Both pipelines pad to at least (window_size, n_electrodes) from the
# top-left corner with zeros and finish with the same normalisation;
# they differ only in how the window is cropped.
_pad_kwargs = dict(
    min_height=window_size,
    min_width=n_electrodes,
    position='top_left',
    border_mode=0,
    value=0,
    always_apply=True,
)
_normalize_kwargs = dict(mean=0.5, std=0.5, max_pixel_value=1, always_apply=True)

# Training: random crop of the padded signal.
# Augmentations tried previously and currently disabled:
#   A.CoarseDropout(fill_value=0, p=0.5)
#   A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5)
#   A.GaussNoise(var_limit=0.005, mean=0, p=0.5)
train_transform = A.Compose([
    A.PadIfNeeded(**_pad_kwargs),
    A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True),
    A.Normalize(**_normalize_kwargs),
])

# Evaluation: deterministic crop anchored at the origin.
test_transform = A.Compose([
    A.PadIfNeeded(**_pad_kwargs),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True),
    A.Normalize(**_normalize_kwargs),
])


# Root folder holding the 'train' and 'test' splits.
data_path = Path("/drive/data/competitionData")

# No tokenizer is used here; files are parsed with process_file_v2.
tokenize_function = None
process_file_function = process_file_v2

# Arguments shared by both splits; only the folder and the transform differ.
_dataset_kwargs = dict(
    process_file_function=process_file_function,
    tokenize_function=tokenize_function,
    max_tokens=max_tokens,
)

train_dataset = BrainDataset(data_path / 'train', transform=train_transform, **_dataset_kwargs)
gc.collect()

test_dataset = BrainDataset(data_path / 'test', transform=test_transform, **_dataset_kwargs)
gc.collect()

self.codebook_size 4375
self.downsample 4
Total: 17.34M, Trainable: 17.34M
VAEConfig(C=256, n_features=4, levels=(7, 5, 5, 5, 5), stride_list=(2, 1, 2, 1))
{'total_loss': tensor(0.0297, grad_fn=<MeanBackward0>)}
Runed processing of the  /drive/data/competitionData/train


Processing files: 100%|██████████| 24/24 [01:15<00:00,  3.13s/file]


len of the dataset: 8800
max signal size: 906 | max tokens size: 87
median signal size: 297.0 | median tokens size: 31.0
Runed processing of the  /drive/data/competitionData/test


Processing files: 100%|██████████| 24/24 [00:10<00:00,  2.37file/s]

len of the dataset: 880
max signal size: 919 | max tokens size: 86
median signal size: 283.5 | median tokens size: 30.0





0

In [2]:
# Sanity check: pull a single batch and run it through the model.
# The loader wraps `test_dataset`, so it is named accordingly; the previous
# name `train_dataloader` was misleading because no training data is involved.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
# Backward-compatible alias for any cell that still refers to the old name.
train_dataloader = test_dataloader

x, y, date = next(iter(test_dataloader))
print('input text', y)

loss, logits = model(x, y, date)
loss
# Optional: visualise the input window.
# plt.imshow(x.detach()[0].T)
# plt.show()

input text ('theocracy reconsidered',)


{'total_loss': tensor(0.9299, grad_fn=<MeanBackward0>)}

In [None]:
# Experiment bookkeeping: project name and the folder passed as `save_folder`.
project_name = 'vq_brain'
save_folder = Path("/drive/logs/kovalev")

train_config = TrainConfig(
    # identification
    exp_name='4_features_17M_4x_4000_128_ws',
    project_name=project_name,
    save_folder=save_folder,
    # data loading
    batch_size=32,
    num_workers=3,
    pin_memory=True,
    # optimisation
    mixed_precision=True,
    grad_accum=4,
    learning_rate=3e-4,
    weight_decay=0,
    grad_clip=10,
    # schedule
    warmup_iters=2000,
    lr_decay_iters=500_000,
    max_steps=200_000,
    eval_interval=1000,
)

# model = torch.compile(model)
# Launch training in a single process with (model, datasets, configs) as arguments.
args = (model, (train_dataset, test_dataset), train_config, model_config)
notebook_launcher(run_train_model, args, num_processes=1)

Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
*****************************************

In [None]:

# Unwrap a torch.compile-d model back to the original module.
# torch.compile is currently commented out in the training cell, in which case
# `_orig_mod` does not exist; fall back to the model itself so this cell is
# safe (and idempotent) on a fresh Restart & Run All.
model = getattr(model, '_orig_mod', model)

In [None]:
# Display the model object as the cell output (its repr).
model 

In [None]:
# Reconstruct one held-out sample and compare it with the original signal.
sample_idx = 500   # which test example to inspect
channel = 220      # which feature column to plot as a 1-D trace

# Run on whatever device the model currently lives on instead of assuming CUDA.
device = next(model.parameters()).device

x = test_dataset[sample_idx][0]
x = torch.from_numpy(x[None, :]).to(device)  # add a leading batch dimension

# Inference only: no autograd graph is needed for visualisation.
# NOTE(review): the model is left in its current train/eval mode — call
# model.eval() first if it contains dropout / batch-norm layers.
with torch.no_grad():
    # e = model.encoder(x)
    indices, quantized = model.get_quantize_vectors(x)
    y = model(x)[1]  # second element of the forward output is the reconstruction

print(F.l1_loss(y, x))
print(F.mse_loss(y, x))

x = x.detach().cpu()[0]
y = y.detach().cpu()[0]

# Overlay reconstruction and original for a single channel; both traces are
# labelled so the legend is complete.
plt.plot(y[:, channel], label='recon')
plt.plot(x[:, channel], label='original')
plt.legend()
plt.show()

# Full windows as images (features on the vertical axis after the transpose).
# NOTE(review): colour range is clipped to [0, 1]; confirm this matches the
# value range after A.Normalize(mean=0.5, std=0.5).
plt.imshow(y[:, :].T, aspect='auto', vmin=0, vmax=1)
plt.show()
plt.imshow(x[:, :].T, aspect='auto', vmin=0, vmax=1)

In [None]:
# Inspect the quantizer output for one sequence (row 64 of the batch axis).
quantized = quantized.detach().cpu()
indices = indices.detach().cpu()
print(quantized.shape)
plt.imshow(quantized[64].T)

plt.show()

# The one-hot width must cover every possible code index. For FSQ the codebook
# size is the product of the quantization levels (7*5*5*5*5 = 4375 here, which
# matches the size printed when the model was built); the previous hard-coded
# 1000 makes F.one_hot raise as soon as an index >= 1000 appears.
codebook_size = int(np.prod(model_config.levels))
plt.imshow(F.one_hot(indices[64].to(torch.long), codebook_size).T, aspect='auto')
plt.show()
plt.plot(indices[64])

In [None]:
# Show the first 64 rows of the first quantized sequence, transposed for display.
first_quantized = quantized[0][:64]
plt.imshow(first_quantized.T.cpu().detach(), aspect='auto')

In [None]:
# Plot the code indices of the first sequence as a line.
first_indices = indices[0].cpu()
plt.plot(first_indices)