In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

from typing import Optional

from dataclasses import dataclass
from simple_parsing.helpers import Serializable
import time
import numpy as np
## Functions
import matplotlib.pyplot as plt
import albumentations as A


In [2]:
from models.bert import BrainBert, BertConfig
from models.vq_brain_per_channel import SoundStream, VAEConfig
from models.blocks import Block, build_complex_rope_cache, RMSNorm, build_advanced_causal_mask

In [57]:
class Franky(nn.Module): 
    """This is first model which incorporate brain features into LLM"""

    def __init__(self, combiner_config, causal_config, brain_model, llm_model, tokenizer=None):
        super().__init__()
        self.brain_model = brain_model
        self.llm_model = llm_model
        self.tokenizer = tokenizer


        self.combiner_config = combiner_config
        self.causal_config = causal_config

        self.n_electrodes = brain_model.config.n_electrodes
        self.window_size = self.brain_model.config.window_size
        self.block_size = int(self.window_size / self.brain_model.tokenizer.downsample * self.combiner_config.n_registers)
        self.dim = combiner_config.dim
        
        
        self.combiner_model = nn.Sequential(*[Block(combiner_config) for _ in range(combiner_config.n_layers)])
        
        self.combiner_pos_embeddings = nn.Parameter(torch.randn(1, self.n_electrodes, combiner_config.dim))
        
        # Causal model
        # init new rope cache for working with several registers and overwriting old one
        # we have to repeat values, because n_registers have same time step
        causal_config.block_size = self.block_size
        
        self.causal_model = CausalModel(causal_config)
        self.causal_model.attn_mask = build_advanced_causal_mask(self.block_size, self.combiner_config.n_registers)
        old_rope = self.causal_model.precompute_rope_cash
        self.causal_model.precompute_rope_cash = old_rope.repeat_interleave(self.combiner_config.n_registers, dim=0)

        
        self.projector = nn.Linear(brain_model.config.dim, llm_model.config.n_embd)
        

        self.date_embeddings = nn.Embedding(num_embeddings=25, embedding_dim=llm_model.config.n_embd)
        
        print("Full Franky: number of parameters: %.2fM" % (self.get_num_params()/1e6,))


    def combine_features(self, x):
        """
        Combine features from different channels into several vectors.
        x: b, t, c, d
        """
        print(x.shape)
        b, t, c, d = x.size()
        x = rearrange(x, 'b t c d -> (b t) c d', b=b, t=t, c=self.n_electrodes, d=self.dim)

        x = x + self.combiner_pos_embeddings
        tokens = self.combiner_model(x)

        tokens = tokens[:, :self.combiner_config.n_registers]
        tokens = rearrange(tokens, '(b t) c d -> b (c t) d', b=b, t=t, c=self.combiner_config.n_registers, d=self.brain_model.dim)
        return tokens 
    
    
    def forward(self, x, targets=None, date_info=None):
        """
        Train model.
        x: B, T, C
        """
        is_padded = (x==0).all(dim=-1) # B, T
        is_padded = is_padded[:, ::4]

        _, x = self.brain_model(x) # b, t, c, d

        x = self.combine_features(x)
        pred_latent = self.causal_model(x)

        # Also we have to add padded tokens here. and do not calculate metrics on them.
        # future_loss = F.mse_loss(pred_latent[:, :-self.combiner_config.n_registers], x[:, :-self.combiner_config.n_registers])
        
        features = self.projector(x)
        print('brain features shape', features.shape)
        print('is_padded', is_padded.shape)
        

        # date_embedding = self.date_embeddings(date_info)
        # x = torch.cat([x, date_embedding], dim=-1)
        
        new_idx = targets.clone()
        new_idx[new_idx == -100] = self.tokenizer.eos_token_id


        outputs = self.llm_model(input_ids=new_idx[:, :-1], 
                                 labels=targets[:, 1:], 
                                 encoder_hidden_states=features, 
                                 encoder_attention_mask=~is_padded)
        
        return outputs.loss, outputs.logits

    
    @torch.no_grad()
    def generate(self, x, date_info=None, tokenizer=None):
        device = self.device
        
        x = torch.from_numpy(x[None, ]).to(device).to(self.dtype)
        features = self.brain_model(x)
        
        ### Text part
        start = tokenizer.bos_token
        input_ids = tokenizer(start,  return_tensors="pt")['input_ids'].to(self.device)

        # max_new_tokens = 25
        # temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
        # top_k = 10

        res = self.llm_model.generate(input_ids=input_ids, encoder_hidden_states=features)
        pred = self.tokenizer.batch_decode(res)
        
        return pred
    

    def get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
    
    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

In [58]:
window_size = 128

In [59]:
vae_config = VAEConfig(C=256, levels=(8, 8, 6, 5))
vq_vae = SoundStream(**vae_config.to_dict())

self.codebook_size 1920
self.downsample 8


In [60]:
bert_config = BertConfig(dim=64, 
                        window_size=window_size, 
                        tokenizer_downsample=int(vq_vae.downsample),
                        n_electrodes=256, 
                        mask_ratio=0, 
                        n_layers=12, 
                        n_heads=12, 
                        n_kv_heads=12)

bert = BrainBert(bert_config, vq_vae)


BertConfig(window_size=128, n_electrodes=256, mask_ratio=0, tokenizer_downsample=8, n_layers=12, dim=64, hidden_dim=1024, head_dim=32, n_heads=12, n_kv_heads=12)
Encoder: number of parameters: 20.74M


In [61]:
@dataclass
class CombinerConfig(Serializable):
    # data params
    n_registers: int = 2
    n_layers: int = 4
    dim: int = 64
    hidden_dim: int = 1024

    head_dim: int = 32
    n_heads: int = 16
    n_kv_heads: int = 16 

@dataclass
class CausalModelConfig(Serializable):

    block_size: int = 0
    rope_theta: float = 1000.0

    # data params
    n_layers: int = 4
    dim: int = 64
    hidden_dim: int = 1024
    dropout: float = 0.0
    

    head_dim: int = 32
    n_heads: int = 16
    n_kv_heads: int = 16 
    calculate_loss: bool = False




In [62]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2', add_cross_attention=True)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.3.crossattention.c_proj.bias', 'h.1.crossattention.c_proj.bias', 'h.0.crossattention.c_attn.bias', 'h.1.crossattention.c_attn.bias', 'h.0.crossattention.q_attn.bias', 'h.7.crossattention.c_attn.bias', 'h.10.crossattention.c_attn.weight', 'h.3.crossattention.q_attn.weight', 'h.2.crossattention.q_attn.weight', 'h.8.crossattention.c_attn.weight', 'h.11.crossattention.c_proj.weight', 'h.6.crossattention.q_attn.weight', 'h.10.ln_cross_attn.weight', 'h.5.crossattention.c_proj.weight', 'h.2.ln_cross_attn.bias', 'h.4.crossattention.q_attn.bias', 'h.1.crossattention.q_attn.weight', 'h.1.ln_cross_attn.bias', 'h.10.crossattention.c_proj.weight', 'h.11.crossattention.c_attn.bias', 'h.11.crossattention.q_attn.bias', 'h.9.crossattention.c_attn.weight', 'h.6.crossattention.c_proj.bias', 'h.9.ln_cross_attn.bias', 'h.0.ln_cross_attn.bias', 'h.3.ln_cross_attn.bias', 'h.7.ln_cross_attn.we

In [63]:
combiner_config = CombinerConfig()
causal_config = CausalModelConfig()


model = Franky(combiner_config, causal_config, bert, gpt2, tokenizer)



Shape of the rope cache:  torch.Size([32, 16])
Full Franky: number of parameters: 176.26M


In [64]:
x = torch.randn(1, 128, 256)
targers = torch.arange(25).unsqueeze(0)

loss, features = model(x, targers)

print(features.shape)

torch.Size([1, 16, 256, 64])
brain features shape torch.Size([1, 32, 768])
is_padded torch.Size([1, 32])
torch.Size([1, 24, 50257])


In [299]:


window_size = 32
n_electrodes = 256

train_transform = A.Compose([
    
    # A.CoarseDropout(fill_value=0, p=0.5),
    # A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.5),
    # A.GaussNoise(var_limit=0.005, mean=0, p=0.5),

    A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left', 
                  border_mode=0, value=0, always_apply=True),
    # A.RandomCrop(height=window_size, width=n_electrodes, always_apply=True),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True),

])

test_transform = A.Compose([
    A.PadIfNeeded(min_height=window_size, min_width=n_electrodes, position='top_left', 
                  border_mode=0, value=0, always_apply=True),
    A.Crop(x_min=0, x_max=n_electrodes, y_min=0, y_max=window_size, always_apply=True)
])