In [None]:
import os
import sys
from pathlib import Path
# Make the project-local packages (`utils`, `models`) imported in this notebook
# resolvable by adding the parent of the current working directory to sys.path.
# NOTE(review): ".." is resolved against the kernel's cwd, so this presumably
# assumes the notebook is launched from a direct subfolder of the repo — confirm.
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn

import safetensors
from accelerate import notebook_launcher

from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, load_model_weights

from dataclasses import dataclass
from simple_parsing.helpers import Serializable

from safetensors.torch import load_model
import albumentations as A

import matplotlib.pyplot as plt

In [None]:
from models.brain_mae import EncoderConfig, Encoder, MAE, MAEConfig

## Init brain module

In [None]:
# Both configs are created with their default field values; the MAE model is
# then assembled from the pair. `encoder_config` is reused further down
# (window size for the transforms, and as an argument to the training run).
encoder_config = EncoderConfig()
mae_config = MAEConfig()
model = MAE(
    mae_config=mae_config,
    encoder_config=encoder_config,
)


## Init language model tokenizer

In [None]:
from transformers import GPT2Tokenizer

# Pretrained GPT-2 (medium) tokenizer; it is handed to `get_tokenizer(...)`
# when the train/test datasets are built.
tokenizer = GPT2Tokenizer.from_pretrained(
    pretrained_model_name_or_path='gpt2-medium',
)

In [None]:
# Target sample size fed to the model: `window_size` rows by `n_electrodes` columns.
window_size = encoder_config.window_size
n_electrodes = 256


def _pad_to_window():
    """Return a fresh transform that zero-pads a sample up to the target size.

    Samples that already have at least `window_size` rows and `n_electrodes`
    columns pass through unchanged; smaller ones are padded with zeros while the
    original content stays anchored at the top-left corner.
    """
    return A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left',
                         border_mode=0, value=0, always_apply=True)


train_transform = A.Compose([
    # Augmentations that are currently switched off:
    # A.CoarseDropout(fill_value=0, p=0.5),
    # A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5),
    # A.GaussNoise(var_limit=0.005, mean=0, p=0.5),

    # Pad before cropping: RandomCrop cannot cut a window that is larger than
    # the sample, so a recording shorter than `window_size` would otherwise abort
    # training. For samples that are already large enough this is a no-op.
    _pad_to_window(),
    A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True),
])

# Deterministic counterpart for evaluation: pad, then always take the top-left window.
test_transform = A.Compose([
    _pad_to_window(),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True)
])


data_path = Path("/drive/data/competitionData")

# Each split gets its own tokenize function built from the shared tokenizer.
train_dataset = BrainDataset(data_path / 'train', tokenize_function=get_tokenizer(tokenizer), transform=train_transform)
test_dataset = BrainDataset(data_path / 'test', tokenize_function=get_tokenizer(tokenizer), transform=test_transform)

In [None]:
# Smoke test: pull a single test sample, push it through the model once and
# visualise the input, to confirm that data and model fit together before training.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

x, y, date = next(iter(test_dataloader))

# Nothing is back-propagated here, so skip autograd bookkeeping. Otherwise the
# graph stays alive through the notebook-global `loss` / `logits` until they are
# overwritten, holding memory that the training run below could use.
with torch.no_grad():
    loss, logits = model(x, y, date)

print(x.shape, y.shape, date)
# Show the first (only) sample of the batch, transposed for display.
plt.imshow(x.detach()[0].T)
plt.show()

## Train the model

In [None]:
project_name = 'brain-mae'
save_folder = Path("/drive/logs/kovalev")

# All settings for this run. Keyword order does not matter to TrainConfig; the
# fields are merely grouped here by their apparent role (judging by their names).
train_config = TrainConfig(
    exp_name='mae-test-one',
    project_name=project_name,
    save_folder=save_folder,

    batch_size=128,
    num_workers=3,
    pin_memory=True,

    learning_rate=1e-3,
    weight_decay=0,
    grad_clip=5,
    grad_accum=8,
    mixed_precision=True,

    warmup_iters=1000,
    lr_decay_iters=20_000,
    eval_interval=250,
)

# Positional arguments forwarded to `run_train_model` by the launcher:
# the model, the (train, test) dataset pair, the run config and the encoder config.
datasets = (train_dataset, test_dataset)
args = (model, datasets, train_config, encoder_config)
notebook_launcher(run_train_model, args=args, num_processes=1)