In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn
import safetensors


import einops

from models import brainformer
from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, simple_train_model


from models.brainformer import Encoder, CrossBlock, build_complex_rope_cache, Config


In [2]:
from transformers import GPT2Tokenizer
from models.gpt2_model import GPT
import tiktoken
from contextlib import nullcontext
from accelerate import notebook_launcher


In [3]:
# from accelerate.utils import write_basic_config

# write_basic_config()  # Write a config file
# os._exit(00)  # Restart the notebook

In [4]:
class BrainEncoder(nn.Module):
    """Perceiver-style brain-signal encoder that emits a fixed number of output tokens.

    Brain activity is first encoded by ``Encoder``. A set of learnable queries then
    attends to that context through ``config.n_layers`` ``CrossBlock`` layers. Finally the
    result is layer-normed and projected to ``config.output_dim``, which in this notebook
    is the LLM embedding size.
    """
    config = Config

    def __init__(self, config: Config):
        super().__init__()

        self.config = config
        self.encoder = Encoder(config.encoder)
        self.n_output_tokens = config.n_output_tokens

        # One learnable query per output token, shared across the batch (leading dim is 1).
        self.learnable_queries = nn.Parameter(torch.zeros(1, config.n_output_tokens, config.dim))
        self.perceiver = nn.ModuleDict(dict(
            h=nn.ModuleList([CrossBlock(config) for _ in range(config.n_layers)]),
            ln_f=nn.LayerNorm(config.dim),
            to_words=nn.Linear(config.dim, config.output_dim),
        ))

        # Masks are left unset (None). Presumably CrossBlock treats None as "no mask"; confirm.
        self.register_buffer('cross_attn_mask', None)
        self.register_buffer('self_attn_mask', None)

        # Plain attribute, not a buffer, so it is neither stored in the state dict nor moved
        # by `.to()`. The `rope_cache` property moves it to the model's device lazily.
        # The "cash" spelling is kept so existing references keep working.
        self.precompute_rope_cash = build_complex_rope_cache(dim=config.head_dim,
                                                             seq_len=config.n_output_tokens,
                                                             theta=config.rope_theta)

        print("Full BrainEncoder: number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self):
        """Return the total number of elements across all parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first parameter, used as the model's dtype."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter, used as the model's device."""
        return next(self.parameters()).device

    @property
    def rope_cache(self) -> torch.Tensor:
        """Rotary-embedding cache, moved to the model's device on first use after a move."""
        if self.precompute_rope_cash.device != self.device:
            self.precompute_rope_cash = self.precompute_rope_cash.to(device=self.device)
        return self.precompute_rope_cash

    def forward(self, x, targets=None, date_info=None):
        """Encode brain activity into ``n_output_tokens`` output embeddings.

        Args:
            x: rank-3 input, unpacked as (batch, time, channels). TODO: confirm the axis
                order against ``Encoder``.
            targets: accepted but ignored; no loss is computed here.
            date_info: accepted but ignored.

        Returns:
            The projected query states, presumably of shape
            (batch, n_output_tokens, config.output_dim).
        """
        batch_size, _, _ = x.shape  # also rejects inputs that are not rank 3

        context = self.encoder(x)  # presumably (batch, n_tokens, dim)

        queries = self.learnable_queries.expand(batch_size, -1, -1)
        for cross_block in self.perceiver.h:
            queries = cross_block(queries, context, self.self_attn_mask,
                                  self.cross_attn_mask, sa_rope=self.rope_cache)

        hidden = self.perceiver.ln_f(queries)
        return self.perceiver.to_words(hidden)

In [5]:
class Franky(nn.Module):
    """Brain-to-text model: brain-encoder features are fed as a prefix to a language model.

    Args:
        brain_model: module mapping brain activity ``x`` to prefix embeddings.
        llm_model: language model called as ``forward(idx=, prefix=, targets=)`` (result
            unpacked as ``(loss, logits)``) and
            ``generate(idx, max_new_tokens, prefix=, temperature=, top_k=)``.
        tokenizer: optional HF-style tokenizer used to build the generation prompt. When
            omitted, ``EOT_TOKEN_ID`` is used directly as the prompt.
    """

    # GPT-2 '<|endoftext|>' token. It is the generation prompt and the filler for ignored labels.
    EOT_TOKEN = '<|endoftext|>'
    EOT_TOKEN_ID = 50256
    # Label value marking positions excluded from the loss (PyTorch cross-entropy's default).
    IGNORE_INDEX = -100

    def __init__(self, brain_model, llm_model, tokenizer=None):
        super().__init__()

        self.brain_model = brain_model
        self.llm_model = llm_model
        self.tokenizer = tokenizer

        print("Full Franky: number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self):
        """Return the total number of elements across all parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first parameter, used as the model's dtype."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter, used as the model's device."""
        return next(self.parameters()).device

    def forward(self, x, targets=None, date_info=None):
        """Compute the LM loss for ``targets`` conditioned on brain features of ``x``.

        Args:
            x: brain activity passed to ``brain_model``.
            targets: token ids; ``IGNORE_INDEX`` entries are passed to the LM as-is for the
                loss but replaced by ``EOT_TOKEN_ID`` in the LM input.
            date_info: accepted but ignored.

        Returns:
            ``(loss, logits)`` as returned by ``llm_model.forward``.

        Raises:
            ValueError: if ``targets`` is None (the LM input is derived from it).
        """
        if targets is None:
            raise ValueError("Franky.forward requires `targets`; use `generate` for inference.")

        features = self.brain_model(x)

        # Ignored positions are not valid token ids, so embed them as EOT instead.
        # masked_fill returns a new tensor, so the caller's `targets` is left untouched.
        idx = targets.masked_fill(targets == self.IGNORE_INDEX, self.EOT_TOKEN_ID)

        loss, logits = self.llm_model.forward(idx=idx, prefix=features, targets=targets)

        return loss, logits

    def generate(self, x, date_info=None, max_new_tokens=15, temperature=1.0, top_k=20):
        """Generate text conditioned on brain activity ``x``, starting from an EOT prompt.

        The brain prefix and the LM generation both run under ``torch.no_grad``. The caller
        is responsible for putting the model in eval mode.

        Args:
            x: brain activity passed to ``brain_model``.
            date_info: accepted but ignored.
            max_new_tokens: number of tokens to generate.
            temperature: 1.0 = no change, < 1.0 = less random, > 1.0 = more random.
            top_k: passed through to ``llm_model.generate``.

        Returns:
            Whatever ``llm_model.generate`` returns for the prompt ids.
        """
        with torch.no_grad():
            prefix = self.brain_model(x)
            batch_size = prefix.shape[0]

            if self.tokenizer is not None:
                input_ids = self.tokenizer(self.EOT_TOKEN, return_tensors="pt")['input_ids']
            else:
                input_ids = torch.tensor([[self.EOT_TOKEN_ID]], dtype=torch.long)
            input_ids = input_ids.to(self.device)
            # One prompt row per prefix row. Presumably the LM concatenates prefix and
            # prompt along the batch-aligned sequence axis; confirm against GPT.generate.
            if input_ids.shape[0] != batch_size:
                input_ids = input_ids.repeat(batch_size, 1)

            y = self.llm_model.generate(input_ids, max_new_tokens, prefix=prefix,
                                        temperature=temperature, top_k=top_k)

        return y

In [6]:
device = 'cuda'
dtype = torch.float32

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Frozen GPT-2 language model: only the brain encoder is trained.
llm_model = GPT.from_pretrained('gpt2', dict(dropout=0.0))
llm_model.eval()
llm_model.requires_grad_(False)

mae_config = brainformer.MAEConfig(window_size=768, patch_size=32)
config = brainformer.Config(encoder=mae_config,
                            n_output_tokens=32,
                            output_dim=llm_model.config.n_embd)
brain_model = BrainEncoder(config)

### Create Franky model
# Pass the tokenizer so Franky.generate can build its '<|endoftext|>' prompt.
model = Franky(brain_model=brain_model, llm_model=llm_model, tokenizer=tokenizer)
model.train().to(device=device, dtype=dtype)
# Module.train() is recursive and would flip the frozen LLM back into training mode.
model.llm_model.eval()

print('Initing of the Franky completed')

count_parameters(model)



loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M
Encoder: number of parameters: 4.27M
Shape of casual mask:  torch.Size([6144, 6144])
Shape of the rope cache:  torch.Size([6144, 16])
Full HandFormer: number of parameters: 6.32M
Full Franky: number of parameters: 130.76M
Initing of the Franky completed
Total: 130.76M, Trainable: 6.32M


(130757888, 6318080)

In [7]:
# `import safetensors` alone does not load the `safetensors.torch` submodule; import it
# explicitly instead of relying on another library having imported it first.
from safetensors.torch import load_model

# NOTE(review): absolute, machine-specific checkpoint path; consider a config-cell constant.
weights = Path("/drive/logs/kovalev/franky_gpt2_retrain/step_5000_loss_3.1739.safetensors")

load_model(model, weights)

# Sanity check: display the restored perceiver queries.
model.brain_model.learnable_queries.detach()

tensor([[[ 0.0229,  0.0101, -0.0027,  ..., -0.0247,  0.0153,  0.0005],
         [ 0.0670, -0.0147, -0.0068,  ..., -0.0044, -0.0117, -0.0123],
         [ 0.0188, -0.0051, -0.0379,  ...,  0.0185, -0.0410, -0.0402],
         ...,
         [ 0.0058,  0.0146, -0.0130,  ...,  0.0190, -0.0231,  0.0086],
         [-0.0113,  0.0127,  0.0033,  ...,  0.0024,  0.0237,  0.0085],
         [ 0.0214, -0.0044, -0.0212,  ...,  0.0004, -0.0257, -0.0230]]],
       device='cuda:0')

In [8]:
data_path = Path("/drive/data/competitionData")
# One dataset per split; each split gets its own tokenize function built from the GPT-2 tokenizer.
train_dataset, test_dataset = (
    BrainDataset(data_path / split, tokenize_function=get_tokenizer(tokenizer))
    for split in ('train', 'test')
)

Runed processing of the  /drive/data/competitionData/train
bad_samples [31, 32, 33, 37, 40, 41, 42, 43, 44, 47, 51, 56, 59, 61, 64, 68, 69, 79, 81, 88, 91, 92, 100, 101, 102, 103, 109, 113, 116, 119, 139, 141, 142, 148, 163, 166, 175, 183, 236, 244, 270, 276, 282, 323, 334, 359, 430, 470, 484, 488, 492, 493, 498, 500, 506, 522, 623, 626]
Runed processing of the  /drive/data/competitionData/test
bad_samples [15, 17, 18, 22]


In [13]:
# Qualitative check: decode one training sample with beam search and compare it to the transcript.
EOT_ID = tokenizer.eos_token_id  # GPT-2 '<|endoftext|>' (50256); GPT-2's bos_token_id is the same id
IGNORE_INDEX = -100              # label padding value, treated as EOT when decoding


def extract_between_eot(token_ids, eot_id=EOT_ID):
    """Return the tokens between the first and second ``eot_id`` occurrences.

    Accepts a 1-D tensor, numpy array or list. If only one ``eot_id`` is present, the
    slice runs to the end. If none is present, the whole sequence is returned.
    """
    ids = token_ids.tolist() if hasattr(token_ids, "tolist") else list(token_ids)
    eot_positions = [i for i, tok in enumerate(ids) if tok == eot_id]
    if not eot_positions:
        return ids
    begin = eot_positions[0] + 1
    end = eot_positions[1] if len(eot_positions) > 1 else len(ids)
    return ids[begin:end]


sample = train_dataset[2]
x = torch.from_numpy(sample[0][None, ]).to(device)

prompt = '<|endoftext|>'
input_ids = tokenizer(prompt, return_tensors="pt")['input_ids'].to(device)

max_new_tokens = 25
temperature = 1.0  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
beam_width = 5
n_runs = 2

# Run inference in eval mode (disables layers such as dropout). Restore the previous mode
# afterwards so a later training cell sees the model exactly as before.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for _ in range(n_runs):
            prefix = model.brain_model(x)
            y = model.llm_model.generate_beam_search(input_ids, max_new_tokens, prefix=prefix,
                                                     temperature=temperature, beam_width=beam_width)
            pred = tokenizer.decode(extract_between_eot(y), skip_special_tokens=False)
            print(pred)
finally:
    model.train(was_training)

# Map padding to EOT on a copy, so the dataset's own label array is never mutated.
label_ids = [EOT_ID if tok == IGNORE_INDEX else tok for tok in sample[1].tolist()]
gt = tokenizer.decode(extract_between_eot(label_ids))

print('-------')
print(gt)

beam_scores tensor([-2.0941, -2.0941, -2.0941, -2.0941, -2.0941], device='cuda:0')
beam_scores tensor([-6.8935, -6.8935, -7.0580, -7.0580, -7.2391], device='cuda:0')
beam_scores tensor([-8.7491, -8.7639, -8.7639, -8.7710, -8.7710], device='cuda:0')
beam_scores tensor([ -9.3007,  -9.3007, -10.1013, -10.2525, -10.8006], device='cuda:0')
beam_scores tensor([-12.0535, -12.0716, -12.0716, -12.0919, -12.1687], device='cuda:0')
beam_scores tensor([-14.0737, -14.2331, -14.6925, -14.7008, -14.7539], device='cuda:0')
beam_scores tensor([-14.0741, -15.3781, -15.9495, -15.9592, -17.0501], device='cuda:0')
beam_scores tensor([-15.5542, -15.9872, -16.0988, -16.4851, -16.6093], device='cuda:0')
beam_scores tensor([-16.3177, -17.0225, -17.2413, -18.4863, -18.6111], device='cuda:0')
beam_scores tensor([-17.4008, -17.6843, -18.5156, -18.9971, -19.2119], device='cuda:0')
beam_scores tensor([-18.8334, -20.4497, -20.9728, -21.1354, -21.4511], device='cuda:0')
beam_scores tensor([-20.6763, -21.3180, -22.145

torch.Size([26])

In [None]:
# start = '<|endoftext|>i love you so much <|endoftext|>'

# input_ids = tokenizer(start,  return_tensors="pt")['input_ids']
# input_ids = input_ids.to(device)

# brain_activity = torch.randn(1, 768, 256, dtype=torch.float32, device=device)

# loss, _ = model.forward(brain_activity, targets=input_ids)

# print(loss)

### Run training pipeline

In [None]:
project_name = 'frankenstein'

train_config = TrainConfig(
    exp_name='franky_gpt2_retrain',
    mixed_precision=True,
    batch_size=32,
    num_workers=3,
    pin_memory=True,
    eval_interval=500,
)

# Alternative local data root: Path(r'C:\Users\peter\alvi\brain2text\competitionData')
data_path = Path("/drive/data/competitionData")
save_folder = Path("/drive/logs/kovalev")

# Rebuild both splits from disk so training starts from freshly processed samples;
# an earlier inspection cell may have modified the previously built ones.
train_dataset, test_dataset = (
    BrainDataset(data_path / split, tokenize_function=get_tokenizer(tokenizer))
    for split in ('train', 'test')
)

args = (model, (train_dataset, test_dataset), train_config, project_name, save_folder)
# For in-process debugging without accelerate, call simple_train_model(*args) instead.
notebook_launcher(run_train_model, args, num_processes=1)
