In [13]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))


import safetensors
import torch
import torch.nn.functional as F
from accelerate import notebook_launcher
from einops import rearrange
from einops.layers.torch import Rearrange
from simple_parsing import ArgumentParser

from models import brainformer
from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters

In [5]:
from torch import nn
from models.brainformer import Encoder, CrossBlock, build_complex_rope_cache, Config

In [None]:
import torch.nn.functional as F


In [9]:
class BrainFormer(nn.Module): 
    config = Config
    def __init__(self, config: Config):
        super().__init__()

        self.config = config
        self.encoder = Encoder(config.encoder)
        self.n_output_tokens = config.n_output_tokens

        self.learnable_queries = nn.Parameter(torch.zeros(1, config.n_output_tokens, config.dim))
        self.perceiver = nn.ModuleDict(dict(
                h = nn.ModuleList([CrossBlock(config) for _ in range(config.n_layers)]),
                ln_f = nn.LayerNorm(config.dim), 
                to_words = nn.Linear(config.dim, config.output_dim))
        )
        
        self.register_buffer('cross_attn_mask', None)
        self.register_buffer('self_attn_mask', None)

        self.precompute_rope_cash = build_complex_rope_cache(dim=config.head_dim,
                                                             seq_len=config.n_output_tokens,
                                                             theta=config.rope_theta)

        print("Full HandFormer: number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
    
    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device
    
    @property
    def rope_cache(self) -> torch.Tensor:
        # Just to use proper device.
        if self.precompute_rope_cash.device != self.device:
            self.precompute_rope_cash = self.precompute_rope_cash.to(device=self.device)
        return self.precompute_rope_cash                
    
    def forward(self, x, targets=None, date_info=None):
        """
        Get forward pass with loss calculation.
        Inputs: 
        x
            shape b t c 
        targets:
            B, C, T
        """
        b, t, c = x.shape

        emg_context = self.encoder(x) # b, n_tokens, dim
        
        input = self.learnable_queries.expand(b, self.n_output_tokens, -1)
        
        for cross_block in self.perceiver.h:
            input = cross_block(input, emg_context, self.self_attn_mask, 
                                self.cross_attn_mask, sa_rope = self.rope_cache)
        
        pred = self.perceiver.ln_f(input)
        pred = self.perceiver.to_words(pred)

        if targets is None:
            return None, pred
        
        loss = F.cross_entropy(pred[:, :-1], targets[:, 1:])
        return loss, pred
    
    @torch.no_grad()    
    def inference(self, myo, date_info):
        """
        x (signal) - Time, Channel
        OUTPUTS - Time//8, N_BONES
        """
        x = torch.from_numpy(myo)
        t, c = x.shape
        x = rearrange(x, 't c -> 1 t c', t=t, c=c)
        x = x.to(self.device).to(self.dtype)

        pred = self.forward(x, targets=None)
        pred = pred[0].detach().cpu().numpy()

        return pred.T

In [14]:
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenize_function = get_tokenizer(tokenizer)

In [17]:
project_name = 'brainformer'

train_config = TrainConfig(exp_name='brainformer_simple', 
                           mixed_precision=False, 
                           batch_size=1)
data_path = Path(r"D:\Work\brain-to-text-competition\data\competitionData")


train_dataset = BrainDataset(data_path / 'test', tokenize_function=tokenize_function)
test_dataset = BrainDataset(data_path / 'test', tokenize_function=tokenize_function)

# Init model
mae_config = brainformer.MAEConfig(window_size=768)
config = brainformer.Config(encoder=mae_config, n_output_tokens=25, output_dim=tokenizer.vocab_size)

model = BrainFormer(config)
count_parameters(model)

args = (model, (train_dataset, test_dataset), train_config, project_name)
run_train_model(*args)

Runed processing of the  D:\Work\brain-to-text-competition\data\competitionData\test
bad_samples [15, 17, 18, 22]
Runed processing of the  D:\Work\brain-to-text-competition\data\competitionData\test
bad_samples [15, 17, 18, 22]
Encoder: number of parameters: 8.47M
Shape of casual mask:  torch.Size([6144, 6144])
Shape of the rope cache:  torch.Size([6144, 16])
Full HandFormer: number of parameters: 23.17M
Total: 23.17M, Trainable: 23.17M


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.12 GiB (GPU 0; 4.00 GiB total capacity; 10.30 GiB already allocated; 0 bytes free; 10.49 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [3]:
config

Config(encoder=MAEConfig(window_size=768, n_electrodes=256, patch_size=32, dim=256, n_layers=8, head_dim=32, hidden_dim=1024, n_heads=8, n_kv_heads=8, rope_theta=10000, n_dec_layers=4, decoder_dim=256), n_output_tokens=25, output_dim=50000, dim=256, n_layers=2, head_dim=16, hidden_dim=512, n_heads=4, n_kv_heads=4, rope_theta=10000)