In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F

import einops

from models import brainformer
from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters

from torch import nn
from models.brainformer import Encoder, CrossBlock, build_complex_rope_cache, Config


In [2]:
from transformers import GPT2Tokenizer
from models.gpt2_model import GPT
import tiktoken
from contextlib import nullcontext


In [3]:
class BrainEncoder(nn.Module): 
    """Encode brain-activity windows into a fixed number of output embeddings.

    ``Encoder`` first turns the input into a context sequence. A stack of
    ``CrossBlock`` layers (held in the ``perceiver`` ModuleDict) then lets
    ``config.n_output_tokens`` learnable query vectors attend to that context.
    The queries are finally layer-normed and projected to ``config.output_dim``
    (in this notebook, the LLM embedding width) so they can be used as an LLM prefix.
    """
    config = Config

    def __init__(self, config: Config):
        super().__init__()

        self.config = config
        self.encoder = Encoder(config.encoder)
        self.n_output_tokens = config.n_output_tokens

        # One learnable query per output token; shared across the batch (expanded in forward()).
        self.learnable_queries = nn.Parameter(torch.zeros(1, config.n_output_tokens, config.dim))
        self.perceiver = nn.ModuleDict(dict(
                h = nn.ModuleList([CrossBlock(config) for _ in range(config.n_layers)]),
                ln_f = nn.LayerNorm(config.dim), 
                to_words = nn.Linear(config.dim, config.output_dim))
        )
        
        # Both masks are None; presumably CrossBlock treats None as "no mask" -- TODO confirm.
        self.register_buffer('cross_attn_mask', None)
        self.register_buffer('self_attn_mask', None)

        # RoPE cache over the query sequence (length n_output_tokens), passed to each
        # CrossBlock as ``sa_rope``. It is a plain attribute, not a buffer, so
        # ``module.to(device)`` does not move it; the ``rope_cache`` property does that lazily.
        self.precompute_rope_cash = build_complex_rope_cache(dim=config.head_dim,
                                                             seq_len=config.n_output_tokens,
                                                             theta=config.rope_theta)

        print("Full HandFormer: number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        """Return the total number of parameters (trainable and frozen)."""
        return sum(p.numel() for p in self.parameters())
    
    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device
    
    @property
    def rope_cache(self) -> torch.Tensor:
        """Return the RoPE cache, moving it to the module's device on first use there."""
        if self.precompute_rope_cash.device != self.device:
            self.precompute_rope_cash = self.precompute_rope_cash.to(device=self.device)
        return self.precompute_rope_cash
    
    def forward(self, x, targets=None, date_info=None):
        """Map a batch of brain-activity windows to prefix embeddings.

        No loss is computed here.

        Args:
            x: 3-D input of shape (b, t, c).
            targets: unused; accepted so the signature matches ``Franky.forward``.
            date_info: unused; accepted for the same reason.

        Returns:
            The projected query embeddings. The shape is presumably
            (b, n_output_tokens, output_dim), assuming CrossBlock preserves the
            query shape.
        """
        # Unpacking also enforces a 3-D input, as the original did.
        batch_size, _, _ = x.shape

        context = self.encoder(x)  # b, n_tokens, dim

        queries = self.learnable_queries.expand(batch_size, self.n_output_tokens, -1)
        for cross_block in self.perceiver.h:
            queries = cross_block(queries, context, self.self_attn_mask,
                                  self.cross_attn_mask, sa_rope=self.rope_cache)

        hidden = self.perceiver.ln_f(queries)
        return self.perceiver.to_words(hidden)

In [4]:
class Franky(nn.Module): 
    """This is first model which incorporate brain features into LLM.

    ``brain_model`` maps brain activity to a sequence of embeddings that is passed
    to ``llm_model`` as ``prefix``. ``tokenizer`` is only needed by ``generate``.
    """

    def __init__(self, brain_model, llm_model, tokenizer=None):
        super().__init__()

        self.brain_model = brain_model
        self.llm_model = llm_model
        self.tokenizer = tokenizer
        
        print("Full Franky: number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self):
        """Return the total number of parameters (trainable and frozen)."""
        return sum(p.numel() for p in self.parameters())
    
    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, x, targets=None, date_info=None):
        """Compute ``(loss, logits)`` for ``targets`` conditioned on brain activity ``x``.

        ``targets`` is passed to ``llm_model`` both as the input ids and as the
        targets. Any shifting is presumably done inside ``llm_model``; TODO confirm.
        ``date_info`` is unused.
        """
        features = self.brain_model(x)
        loss, logits = self.llm_model(idx=targets, prefix=features, targets=targets)
        return loss, logits
    
    def generate(self, x, date_info=None, max_new_tokens=15, temperature=1.0, top_k=20):
        """Sample text from the LLM, prefixed with the brain embeddings of ``x``.

        Generation starts from the ``'<|endoftext|>'`` prompt, repeated for every
        batch element of the prefix.

        Args:
            x: brain activity accepted by ``brain_model``.
            date_info: unused.
            max_new_tokens: number of tokens to sample.
            temperature: 1.0 = no change, < 1.0 = less random, > 1.0 = more random.
            top_k: restrict sampling to the ``top_k`` most likely tokens.

        Returns:
            Whatever ``llm_model.generate`` returns (token ids).

        Raises:
            ValueError: if the model was built without a tokenizer.
        """
        if self.tokenizer is None:
            raise ValueError("Franky.generate needs a tokenizer: construct Franky(..., tokenizer=...)")

        # Inference only: no autograd graph for either sub-model.
        with torch.no_grad():
            prefix = self.brain_model(x)

            start = '<|endoftext|>'
            input_ids = self.tokenizer(start, return_tensors="pt")['input_ids']
            # One copy of the prompt per batch element so it lines up with the prefix.
            input_ids = input_ids.to(self.device).repeat(prefix.shape[0], 1)

            return self.llm_model.generate(input_ids, max_new_tokens, prefix=prefix,
                                           temperature=temperature, top_k=top_k)

In [5]:
# Build the model: frozen pretrained GPT-2 + trainable BrainEncoder, combined into Franky.
SEED = 0
torch.manual_seed(SEED)  # reproducible BrainEncoder initialisation

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'float32'  # NOTE(review): not referenced in the visible cells

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

llm_model = GPT.from_pretrained('gpt2', dict(dropout=0.0))

# Freeze GPT-2: only the brain encoder receives gradients.
for param in llm_model.parameters():
    param.requires_grad = False


mae_config = brainformer.MAEConfig(window_size=768, patch_size=96)
config = brainformer.Config(encoder=mae_config, 
                            n_output_tokens=32,
                            # brain embeddings must have the LLM's embedding width
                            output_dim=llm_model.config.n_embd
                            )
brain_model = BrainEncoder(config)


### Create Franky model
# Pass the tokenizer so that Franky.generate() can build its start prompt.
model = Franky(brain_model=brain_model, llm_model=llm_model, tokenizer=tokenizer)
model.train().to(device)

print('Initing of the Franky completed')

count_parameters(model)

loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M
Encoder: number of parameters: 4.29M
Shape of casual mask:  torch.Size([2048, 2048])
Shape of the rope cache:  torch.Size([2048, 16])
Full HandFormer: number of parameters: 6.33M
Full Franky: number of parameters: 130.77M
Initing of the Franky completed
Total: 130.77M, Trainable: 6.33M


(130774272, 6334464)

In [6]:
# Smoke test: one forward pass on random "brain activity" with a known sentence.
start = '<|endoftext|>i love you so much <|endoftext|>'

input_ids = tokenizer(start, return_tensors="pt")['input_ids'].to(device)

# (b, t, c) = (1, 768, 256) -- TODO confirm t/c against the BrainDataset samples
brain_activity = torch.randn(1, 768, 256, dtype=torch.float32, device=device)

# Call the module itself (not .forward) so any registered hooks run.
loss, _ = model(brain_activity, targets=input_ids)
assert torch.isfinite(loss), "smoke-test loss is not finite"

print(loss)

tensor(7.7940, device='cuda:0', grad_fn=<NllLossBackward0>)


### Run training pipeline

In [7]:
project_name = 'frankenstein'

train_config = TrainConfig(exp_name='franky_gpt2',
                           mixed_precision=False, 
                           batch_size=4)

# Root of the competition data (must contain `train/` and `test/`).
# Set the BRAIN2TEXT_DATA_DIR environment variable rather than editing this path.
data_path = Path(os.environ.get('BRAIN2TEXT_DATA_DIR',
                                r"D:\Work\brain-to-text-competition\data\competitionData"))


train_dataset = BrainDataset(data_path / 'train', tokenize_function=get_tokenizer(tokenizer))
test_dataset = BrainDataset(data_path / 'test', tokenize_function=get_tokenizer(tokenizer))


args = (model, (train_dataset, test_dataset), train_config, project_name)
run_train_model(*args)

Runed processing of the  D:\Work\brain-to-text-competition\data\competitionData\train
bad_samples [31, 32, 33, 37, 40, 41, 42, 43, 44, 47, 51, 56, 59, 61, 64, 68, 69, 79, 81, 88, 91, 92, 100, 101, 102, 103, 109, 113, 116, 119, 139, 141, 142, 148, 163, 166, 175, 183, 236, 244, 270, 276, 282, 323, 334, 359, 430, 470, 484, 488, 492, 493, 498, 500, 506, 522, 623, 626]
Runed processing of the  D:\Work\brain-to-text-competition\data\competitionData\test
bad_samples [15, 17, 18, 22]


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
