In [None]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn

import safetensors
from accelerate import notebook_launcher

from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, load_model_weights

from dataclasses import dataclass
from simple_parsing.helpers import Serializable

from safetensors.torch import load_model
import albumentations as A

import matplotlib.pyplot as plt

In [None]:
from models.bert import BrainBert, BertConfig
from models.vq_brain_per_channel import SoundStream, VAEConfig
from models.frankenstein import Franky

## Init brain module

In [None]:
# Build the SoundStream model from its config; its `downsample` attribute is
# forwarded to the BERT config below as `tokenizer_downsample`.
vae_config = VAEConfig(C=256, levels=(8, 8, 6, 5))
vq_vae = SoundStream(**vae_config.to_dict())

# One keyword per line so individual hyperparameters are easy to diff.
model_config = BertConfig(
    dim=256,
    window_size=512,
    tokenizer_downsample=int(vq_vae.downsample),
    n_electrodes=256,
    mask_ratio=0.0,
    n_layers=12,
    n_heads=12,
    n_kv_heads=12,
)

# BrainBert takes both the config and the SoundStream instance.
bert = BrainBert(model_config, vq_vae)

## Init language model

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Tokenizer and language model are loaded from the same pretrained checkpoint name.
llm_checkpoint = 'gpt2-medium'

tokenizer = GPT2Tokenizer.from_pretrained(llm_checkpoint)
# add_cross_attention=True is requested at load time; a later cell re-enables
# gradients for parameters whose name contains "cross".
gpt2 = GPT2LMHeadModel.from_pretrained(llm_checkpoint, add_cross_attention=True)

## Init Franky

In [None]:
@dataclass
class CombinerConfig(Serializable):
    """Hyperparameters passed to ``Franky`` as its first constructor argument in this notebook.

    With the defaults, ``head_dim * n_heads`` (32 * 8) equals ``dim`` (256).

    NOTE(review): how each field is consumed is not visible in this notebook;
    the per-field hints below are inferred from the names only and should be
    confirmed against ``models.frankenstein``.
    """
    # data params
    n_registers: int = 2  # presumably the number of register tokens -- TODO confirm
    n_layers: int = 6  # presumably the number of layers in the combiner
    dim: int = 256  # presumably the model width; same value as BertConfig(dim=256) above
    hidden_dim: int = 1024  # presumably the feed-forward width (4 * dim with the defaults)

    head_dim: int = 32
    n_heads: int = 8
    n_kv_heads: int = 8  # equal to n_heads by default

@dataclass
class CausalModelConfig(Serializable):
    """Hyperparameters passed to ``Franky`` as its second constructor argument in this notebook.

    With the defaults, ``head_dim * n_heads`` (32 * 8) equals ``dim`` (256).

    NOTE(review): how each field is consumed is not visible in this notebook;
    the per-field hints below are inferred from the names only and should be
    confirmed against ``models.frankenstein``.
    """

    block_size: int = 0  # NOTE(review): meaning of the 0 default is not visible here -- confirm
    rope_theta: float = 10000.0  # presumably the rotary-embedding base -- TODO confirm

    # data params
    n_layers: int = 6
    dim: int = 256
    hidden_dim: int = 1024  # 4 * dim with the defaults
    dropout: float = 0.0
    
    head_dim: int = 32
    n_heads: int = 8
    n_kv_heads: int = 8  # equal to n_heads by default
    calculate_loss: bool = False

In [None]:
combiner_config = CombinerConfig()
causal_config = CausalModelConfig()

# Franky combines the brain model (bert) and the language model (gpt2).
model = Franky(combiner_config, causal_config, bert, gpt2, tokenizer)

# Restore the pretrained brain-model weights before freezing it.
bert_weights = "/drive/logs/kovalev/bert/new_vqvae_8x_2000_large_bert_30M_0_25/step_6500_loss_1.4677.safetensors"
model.brain_model = load_model_weights(model.brain_model, bert_weights)

# Freeze both pretrained sub-models. `requires_grad_` flips the flag on every
# parameter of the module in one call, replacing the manual per-parameter loops.
model.brain_model.requires_grad_(False)
model.llm_model.requires_grad_(False)

# Re-enable training only for the LLM parameters whose name contains "cross".
# The loop variables are named explicitly: the previous name `np` shadowed the
# conventional numpy alias in the notebook's global namespace.
for param_name, param in model.llm_model.transformer.named_parameters():
    if "cross" in param_name:
        param.requires_grad = True

# wte / wpe / lm_head need no separate freezing: the whole LLM is frozen above
# and only "cross" parameters are switched back on.

count_parameters(model)
count_parameters(model.llm_model)

In [None]:
# count_parameters(model)

In [None]:
window_size = model_config.window_size
n_electrodes = 256


def _pad_then_crop():
    """Return fresh pad + crop ops that force samples to (window_size, n_electrodes)."""
    return [
        A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left',
                      border_mode=0, value=0, always_apply=True),
        A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True),
    ]


# Disabled training augmentations, kept for reference:
#   A.CoarseDropout(fill_value=0, p=0.5)
#   A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5)
#   A.GaussNoise(var_limit=0.005, mean=0, p=0.5)
#   A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True)
# With those off, train and test use the same deterministic pipeline.
train_transform = A.Compose(_pad_then_crop())
test_transform = A.Compose(_pad_then_crop())


data_path = Path("/drive/data/competitionData")

train_dataset = BrainDataset(
    data_path / 'train',
    tokenize_function=get_tokenizer(tokenizer),
    transform=train_transform,
)
test_dataset = BrainDataset(
    data_path / 'test',
    tokenize_function=get_tokenizer(tokenizer),
    transform=test_transform,
)

In [None]:
# for idx in [0,100, 200, 300, 400, 500, 600]:
#     sample = train_dataset[idx]
#     x, y = sample[0], sample[1]
#     print(y)
#     plt.imshow(x.T)
#     plt.show()
    

In [None]:
# Smoke test: push a single test sample through the assembled model and look at the input.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

x, y, date = next(iter(test_dataloader))

# The outputs are only inspected here, never back-propagated, so skip autograd
# bookkeeping for the trainable parameters.
with torch.no_grad():
    loss, logits = model(x, y, date)

print(x.shape, y.shape, date)
plt.imshow(x.detach()[0].T)
plt.show()

# Work with data

In [None]:
project_name = 'frankenstein'
save_folder = Path("/drive/logs/kovalev")

train_config = TrainConfig(
    # run identity and output location
    exp_name='franky-3-bert-gpt2-medium',
    project_name=project_name,
    save_folder=save_folder,
    # batching / data loading
    batch_size=128,
    grad_accum=8,
    num_workers=3,
    pin_memory=True,
    # optimisation
    mixed_precision=True,
    learning_rate=1e-3,
    weight_decay=0,
    grad_clip=5,
    # schedule
    warmup_iters=1000,
    lr_decay_iters=20_000,
    eval_interval=250,
)

# Positional arguments for run_train_model: model, (train, test) datasets,
# training config, and the BERT model config.
args = (model, (train_dataset, test_dataset), train_config, model_config)
notebook_launcher(run_train_model, args, num_processes=1)

In [None]:
# Write the current model weights to a safetensors file in the log folder.
save_folder = Path("/drive/logs/kovalev")
checkpoint_path = save_folder / 'model_mae_long_train.safetensors'

safetensors.torch.save_model(model, checkpoint_path)

In [None]:
# Bare expression: the notebook displays the model's repr as the cell output.
model

In [None]:
import torch._dynamo
# Sets TorchDynamo's global `suppress_errors` flag for the rest of the kernel session.
# NOTE(review): per the PyTorch docs this makes compile failures fall back to eager
# execution instead of raising -- confirm this is intended, since it can hide real errors.
torch._dynamo.config.suppress_errors = True

In [None]:
# Unwrap a torch.compile'd model to the underlying module. Falling back to
# `model` itself keeps the cell idempotent: re-running it, or running it on a
# model that was never compiled, no longer raises AttributeError.
model = getattr(model, "_orig_mod", model)

In [None]:
mae_weights = "/drive/logs/kovalev/fixed_abs_mae_11M_5M_spikes/step_34000_loss_0.0166.safetensors"

# NOTE(review): SimpleMAE, SimpleEncoderConfig and SimpleMAEConfig are not imported in
# any cell visible in this notebook; this cell raises NameError on a fresh kernel until
# they are imported.
mae_model = SimpleMAE( SimpleEncoderConfig(), SimpleMAEConfig())
# Order matters: compile first, load the weights into the compiled wrapper, then unwrap.
# Presumably the checkpoint was saved from a compiled model, so its keys carry the
# `_orig_mod.` prefix -- TODO confirm against the checkpoint.
mae_model = torch.compile(mae_model)
safetensors.torch.load_model(mae_model, mae_weights)
mae_model = mae_model._orig_mod

# Keep a handle on the encoder sub-module of the loaded MAE.
brain_model = mae_model.encoder

In [None]:
# Reconstruct one test sample and compare prediction against input.
# NOTE(review): this call signature (masking_ratio / return_preds, three return values)
# differs from the earlier `model(x, y, date)` call in this notebook -- it looks like it
# targets the MAE model (`mae_model`) rather than Franky; confirm which model is intended.
device = 'cpu'
x = test_dataset[0][0]
x = torch.from_numpy(x[None, :]).to(device)  # add a leading batch dimension
print(x.shape)

# Inference only: nothing below is back-propagated, so disable autograd tracking.
with torch.no_grad():
    loss, y, binary = model(x, masking_ratio=0.75,  return_preds=True)

# Drop the batch dimension and move everything to CPU for plotting.
x = x.detach().cpu()[0]
y = y.detach().cpu()[0]
binary = binary.detach().cpu()[0]

# Mean absolute error over the elements where prediction and input differ.
loss = F.l1_loss(y, x, reduction='none')
print(torch.mean(loss[loss>0]))

# First 64 rows of each tensor, transposed for display.
plt.imshow(binary[:64].T, aspect='auto')
plt.show()
plt.imshow(torch.clip(y[:64].T, 0, 1), aspect='auto')
plt.colorbar()
plt.show()
plt.imshow(x[:64].T, aspect='auto')
plt.colorbar()
plt.show()


In [None]:
# Line plots of the first 32 rows of column 10 for `binary`, `y` and `x`.
rows, channel = slice(None, 32), 10

plt.show()
plt.plot(binary[rows, channel])
plt.show()
# `y` and `x` are drawn on the same axes so they can be compared directly.
plt.plot(y[rows, channel])
plt.plot(x[rows, channel])