In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))


import safetensors
import torch
import torch.nn.functional as F
from accelerate import notebook_launcher
from einops import rearrange
from einops.layers.torch import Rearrange
from simple_parsing import ArgumentParser
import einops

from models import brainformer
from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters

from torch import nn
from models.brainformer import Encoder, CrossBlock, build_complex_rope_cache, Config


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class BrainFormer(nn.Module): 
    config = Config
    def __init__(self, config: Config):
        super().__init__()

        self.config = config
        self.encoder = Encoder(config.encoder)
        self.n_output_tokens = config.n_output_tokens

        self.learnable_queries = nn.Parameter(torch.zeros(1, config.n_output_tokens, config.dim))
        self.perceiver = nn.ModuleDict(dict(
                h = nn.ModuleList([CrossBlock(config) for _ in range(config.n_layers)]),
                ln_f = nn.LayerNorm(config.dim), 
                to_words = nn.Linear(config.dim, config.output_dim))
        )
        
        self.register_buffer('cross_attn_mask', None)
        self.register_buffer('self_attn_mask', None)

        self.precompute_rope_cash = build_complex_rope_cache(dim=config.head_dim,
                                                             seq_len=config.n_output_tokens,
                                                             theta=config.rope_theta)
        # self.cross_entropy = torch.nn.CrossEntropyLoss()

        print("Full HandFormer: number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
    
    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device
    
    @property
    def rope_cache(self) -> torch.Tensor:
        # Just to use proper device.
        if self.precompute_rope_cash.device != self.device:
            self.precompute_rope_cash = self.precompute_rope_cash.to(device=self.device)
        return self.precompute_rope_cash                
    
    def forward(self, x, targets=None, date_info=None):
        """
        Get forward pass with loss calculation.
        Inputs: 
        x
            shape b t c 
        targets:
            B, C, T
        """
        b, t, c = x.shape

        emg_context = self.encoder(x) # b, n_tokens, dim
        
        input = self.learnable_queries.expand(b, self.n_output_tokens, -1)
        
        for cross_block in self.perceiver.h:
            input = cross_block(input, emg_context, self.self_attn_mask, 
                                self.cross_attn_mask, sa_rope = self.rope_cache)
        
        logits = self.perceiver.ln_f(input)
        logits = self.perceiver.to_words(logits)
    
        if targets is None:
            return None, logits
        
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return loss, logits
    
    @torch.no_grad()    
    def inference(self, myo, date_info):
        """
        x (signal) - Time, Channel
        OUTPUTS - Time//8, N_BONES
        """
        x = torch.from_numpy(myo)
        t, c = x.shape
        x = rearrange(x, 't c -> 1 t c', t=t, c=c)
        x = x.to(self.device).to(self.dtype)

        pred = self.forward(x, targets=None)
        pred = pred[0].detach().cpu().numpy()

        return pred.T

In [5]:
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenize_function = get_tokenizer(tokenizer)



In [6]:
project_name = 'brainformer'

train_config = TrainConfig(exp_name='brainformer_simple', 
                           mixed_precision=False, 
                           batch_size=16)
# peter path
data_path = Path(r'C:\Users\peter\alvi\brain2text\competitionData')
# data_path = Path(r"D:\Work\brain-to-text-competition\data\competitionData")


train_dataset = BrainDataset(data_path / 'test', tokenize_function=tokenize_function)
test_dataset = BrainDataset(data_path / 'test', tokenize_function=tokenize_function)

# Init model
mae_config = brainformer.MAEConfig(window_size=768)
config = brainformer.Config(encoder=mae_config, n_output_tokens=25, output_dim=tokenizer.vocab_size)

model = BrainFormer(config)
count_parameters(model)

args = (model, (train_dataset, test_dataset), train_config, project_name)
run_train_model(*args)

Runed processing of the  C:\Users\peter\alvi\brain2text\competitionData\test
bad_samples [15, 17, 18, 22]
Runed processing of the  C:\Users\peter\alvi\brain2text\competitionData\test
bad_samples [15, 17, 18, 22]
Encoder: number of parameters: 8.47M
Shape of casual mask:  torch.Size([6144, 6144])
Shape of the rope cache:  torch.Size([6144, 16])
Full HandFormer: number of parameters: 23.23M
Total: 23.23M, Trainable: 23.23M


dataloader_config = DataLoaderConfiguration(split_batches=True)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeter_chizhov[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler


  res = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)


OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 

In [None]:
len(train_dataset.targets_tokens[0])

25

In [None]:
btch = next(iter(torch.utils.data.DataLoader(train_dataset, batch_size = 2)))

In [None]:
print(type(train_dataset.__getitem__(0)[0]))
train_dataset.__getitem__(0)[0][:, 0]

<class 'numpy.ndarray'>


array([0.07447474, 0.02272561, 0.03308861, 0.05265393, 0.05520429,
       0.03433828, 0.01415561, 0.0619989 , 0.04668124, 0.0270659 ,
       0.09160464, 0.03165064, 0.10037031, 0.19572735, 0.04219091,
       0.0489976 , 0.04627946, 0.03524334, 0.24202722, 0.06736387,
       0.22391742, 0.05963945, 0.03177392, 0.27899384, 0.08874962,
       0.06798954, 0.07567943, 0.30999067, 0.03322073, 0.22823244,
       0.04652578, 0.21547431, 0.06268316, 0.12725124, 0.04267478,
       0.08866527, 0.19836548, 0.04883863, 0.2253061 , 0.08347657,
       0.06197602, 0.25438076, 0.06890433, 0.09326004, 0.21797293,
       0.0852389 , 0.05124364, 0.21691602, 0.04768499, 0.03850894,
       0.07204329, 0.1576474 , 0.05431377, 0.10416052, 0.16920635,
       0.2675422 , 0.05770864, 0.05367622, 0.09641687, 0.05243768,
       0.03653679, 0.03714343, 0.05742719, 0.04350604, 0.0671173 ,
       0.03214645, 0.07087535, 0.22684342, 0.07689945, 0.06472557,
       0.06716719, 0.0294333 , 0.07210251, 0.05269102, 0.03825

In [None]:
train_dataset.__getitem__(0)

(array([[0.07447474, 0.02260434, 0.25687808, ..., 0.21139844, 0.04803241,
         0.0345922 ],
        [0.02272561, 0.03327404, 0.        , ..., 0.18408735, 0.02911673,
         0.02720368],
        [0.03308861, 0.04182241, 0.24309018, ..., 0.13741817, 0.08524861,
         0.01848135],
        ...,
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ]], dtype=float32),
 [50256,
  464,
  17818,
  23898,
  3089,
  13,
  50256,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100],
 0)

In [None]:
btch[0].shape

torch.Size([2, 768, 256])

In [None]:
btch[1]

[tensor([50256, 50256]),
 tensor([  464, 14868]),
 tensor([17818,  8155]),
 tensor([23898,  1811]),
 tensor([3089, 4488]),
 tensor([   13, 24491]),
 tensor([50256, 33492]),
 tensor([-100,   13]),
 tensor([ -100, 50256]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100]),
 tensor([-100, -100])]