In [1]:
import os
import sys
from pathlib import Path
sys.path.insert(1, os.path.realpath(os.path.pardir))

import torch
import torch.nn.functional as F
from torch import nn
import safetensors


import einops

from models import brainformer
from utils.data_utils import BrainDataset, get_tokenizer
from utils.train_utils import TrainConfig, run_train_model, count_parameters, simple_train_model

from models.simple_mae_abs import SimpleEncoder, SimpleMAE

from dataclasses import dataclass
from simple_parsing.helpers import Serializable


In [2]:
from peft import LoraConfig
from peft import get_peft_model

from transformers import GPT2Tokenizer
from models.gpt2_model import GPT
import tiktoken
from contextlib import nullcontext
from accelerate import notebook_launcher

import matplotlib.pyplot as plt

In [3]:
class Franky(nn.Module):
    """Brain-to-text model that feeds brain-encoder features to an LLM as a prefix.

    The brain encoder output is linearly projected to the LLM embedding width, a
    learned embedding of the sample's date id is appended as one extra prefix
    position, and the result is passed to ``llm_model`` through its ``prefix`` argument.

    Args:
        brain_model: encoder exposing ``config.dim``, ``get_attn_mask_padded(x)`` and
            ``forward(x, attn_mask)``.
        llm_model: GPT-style model exposing ``config.n_embd``,
            ``forward(idx=, prefix=, targets=)`` and
            ``generate(idx, max_new_tokens, prefix=, temperature=, top_k=)``.
        tokenizer: optional default tokenizer used by :meth:`generate`.
        n_dates: number of distinct date ids the date-embedding table holds.
    """

    # GPT-2 ``<|endoftext|>`` id: start/stop token for generation, and the input id
    # substituted at label positions marked with IGNORE_INDEX.
    EOT_TOKEN_ID = 50256
    # Label value marking positions that should not count in the loss
    # (presumably the ignore_index of the LLM's cross-entropy — confirm).
    IGNORE_INDEX = -100

    def __init__(self, brain_model, llm_model, tokenizer=None, n_dates=25):
        super().__init__()

        # Registration order matters: dtype/device below read the first parameter.
        self.brain_model = brain_model
        self.projector = nn.Linear(self.brain_model.config.dim, llm_model.config.n_embd)
        self.llm_model = llm_model
        self.tokenizer = tokenizer

        self.date_embeddings = nn.Embedding(num_embeddings=n_dates, embedding_dim=llm_model.config.n_embd)

        print("Full Franky: number of parameters: %.2fM" % (self.get_num_params() / 1e6,))

    def get_num_params(self):
        """Return the total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the first registered parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def _encode_prefix(self, x, date_info):
        """Encode a batch of brain inputs into LLM prefix embeddings.

        Returns brain features projected to ``n_embd`` with one date-embedding
        position appended along dim 1.
        """
        if date_info is None:
            raise ValueError("date_info is required: a date embedding is appended to the prefix")

        attn_mask = self.brain_model.get_attn_mask_padded(x).unsqueeze(1)
        features = self.projector(self.brain_model(x, attn_mask))
        # date_info holds one id per sample -> (B, 1, n_embd) so it can be concatenated.
        date_emb = self.date_embeddings(date_info).unsqueeze(1)
        return torch.cat([features, date_emb], dim=1)

    def forward(self, x, targets=None, date_info=None):
        """Compute the LM loss for a training batch.

        Args:
            x: batch of brain inputs accepted by ``brain_model``.
            targets: label ids with ``IGNORE_INDEX`` at ignored positions. They also
                serve as the LLM input ids, with ignored positions replaced by
                ``EOT_TOKEN_ID``.
            date_info: tensor of date ids, one per sample.

        Returns:
            ``(loss, logits)`` as returned by ``llm_model.forward``.

        Raises:
            ValueError: if ``targets`` or ``date_info`` is missing.
        """
        if targets is None:
            raise ValueError("targets are required: they also provide the LLM input ids")

        prefix = self._encode_prefix(x, date_info)

        # IGNORE_INDEX is not a valid token id; feed <|endoftext|> there instead.
        # masked_fill returns a new tensor, so the caller's targets are untouched.
        idx = targets.masked_fill(targets == self.IGNORE_INDEX, self.EOT_TOKEN_ID)

        loss, logits = self.llm_model.forward(idx=idx, prefix=prefix, targets=targets)
        return loss, logits

    @torch.no_grad()
    def generate(self, x, date_info=None, tokenizer=None, max_new_tokens=25, temperature=1.0, top_k=10):
        """Decode text for a single, unbatched brain sample.

        Call ``model.eval()`` beforehand so dropout layers are inactive.

        Args:
            x: one sample (numpy array or tensor) without the batch dimension.
            date_info: date id of the sample (int, or 0-d / 1-element array or tensor).
            tokenizer: tokenizer for the start token and decoding; defaults to the
                one given at construction.
            max_new_tokens: passed through to ``llm_model.generate``.
            temperature: passed through; 1.0 = no change, < 1.0 = less random,
                > 1.0 = more random.
            top_k: passed through to ``llm_model.generate``.

        Returns:
            Decoded text of the generated tokens, up to (excluding) the first
            generated ``<|endoftext|>``.

        Raises:
            ValueError: if no tokenizer is available or ``date_info`` is missing.
        """
        tokenizer = self.tokenizer if tokenizer is None else tokenizer
        if tokenizer is None:
            raise ValueError("generate() needs a tokenizer (pass one or set it at construction)")

        device = self.device
        x = torch.as_tensor(x)[None].to(device=device, dtype=self.dtype)
        if date_info is not None:
            # Embedding lookup needs an index tensor on the model's device, shape (1,).
            date_info = torch.as_tensor(date_info, device=device).reshape(1)

        prefix = self._encode_prefix(x, date_info)

        start = '<|endoftext|>'
        input_ids = tokenizer(start, return_tensors="pt")['input_ids'].to(device)

        y = self.llm_model.generate(input_ids, max_new_tokens, prefix=prefix, temperature=temperature, top_k=top_k)

        # Assumes y is the prompt ids followed by the sampled ids (nanoGPT-style),
        # either 1-D or with a leading batch dimension of size 1 — confirm.
        seq = y[0] if y.dim() == 2 else y
        new_ids = seq[input_ids.shape[1]:]
        stop_positions = (new_ids == self.EOT_TOKEN_ID).nonzero()
        if len(stop_positions) > 0:
            new_ids = new_ids[: stop_positions[0].item()]

        return tokenizer.decode(new_ids.tolist(), skip_special_tokens=True)

In [4]:
@dataclass
class SimpleEncoderConfig(Serializable):
    """Hyper-parameters of the MAE's brain encoder (first argument of ``SimpleMAE``).

    Raises:
        ValueError: if ``n_kv_heads`` differs from ``n_heads`` (not supported yet).
    """
    # data params
    block_size: int = 768  # presumably the maximum input length — confirm
    patch_size: int = 256  # presumably input steps per patch — confirm

    n_layers: int = 8
    dim: int = 256
    hidden_dim: int = 1024
    n_registers: int = 16

    head_dim: int = 32
    n_heads: int = 16
    n_kv_heads: int = 16  # must currently equal n_heads (validated in __post_init__)
    rope_theta: int = 10000

    def __post_init__(self):
        # Enforce the documented n_kv_heads == n_heads constraint up front rather
        # than failing later inside the model.
        if self.n_kv_heads != self.n_heads:
            raise ValueError(
                f"n_kv_heads ({self.n_kv_heads}) must equal n_heads ({self.n_heads}) for now"
            )


@dataclass
class SimpleMAEConfig(Serializable):
    """Hyper-parameters of the MAE part other than the encoder (second argument of ``SimpleMAE``).

    Raises:
        ValueError: if ``n_kv_heads`` differs from ``n_heads`` (not supported yet).
    """
    # data params
    n_layers: int = 4
    dim: int = 256
    hidden_dim: int = 1024

    head_dim: int = 32
    n_heads: int = 8
    n_kv_heads: int = 8  # must currently equal n_heads (validated in __post_init__)
    rope_theta: int = 10000

    def __post_init__(self):
        # Enforce the documented n_kv_heads == n_heads constraint up front rather
        # than failing later inside the model.
        if self.n_kv_heads != self.n_heads:
            raise ValueError(
                f"n_kv_heads ({self.n_kv_heads}) must equal n_heads ({self.n_heads}) for now"
            )

In [5]:
mae_weights = "/drive/logs/kovalev/fixed_abs_mae_11M_5M_spikes/step_34000_loss_0.0166.safetensors"

mae_model = SimpleMAE(SimpleEncoderConfig(), SimpleMAEConfig())

# The checkpoint keys carry torch.compile's "_orig_mod." prefix. Rather than building
# a compiled wrapper only to match those names, mount the model under an attribute
# with that name on a throwaway container and load into the container. The
# parameters are shared, so mae_model receives the weights. load_model stays strict.
_ckpt_holder = nn.Module()
_ckpt_holder._orig_mod = mae_model
safetensors.torch.load_model(_ckpt_holder, mae_weights)
del _ckpt_holder

brain_model = mae_model.encoder

SimpleEncoderConfig(block_size=768, patch_size=256, n_layers=8, dim=256, hidden_dim=1024, n_registers=16, head_dim=32, n_heads=16, n_kv_heads=16, rope_theta=10000)
Encoder: number of parameters: 10.76M
MAE: number of parameters: 15.29M


Process ForkProcess-8:
Process ForkProcess-7:
Process ForkProcess-2:
Process ForkProcess-3:
Process ForkProcess-5:
Process ForkProcess-6:
Process ForkProcess-4:
Process ForkProcess-1:
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/conda/envs/pytorch/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/conda/envs/pytorch/lib/python3.10/concurrent/futures/process.py", line 240, in _process_worker
    call_item = call_queu

In [6]:
device = 'cuda'
dtype = torch.float32

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
llm_model = GPT.from_pretrained('gpt2', dict(dropout=0.0))

### Create Franky model
model = Franky(brain_model=brain_model, llm_model=llm_model, tokenizer=tokenizer)

# Optional warm start from a previous Franky checkpoint. This must run before LoRA
# wrapping, while parameter names still match the plain model.
# e.g. '/drive/logs/kovalev/fixed_franky_v2_spikes_only_projector/step_500_loss_3.4384.safetensors'
resume_weights = None
if resume_weights is not None:
    safetensors.torch.load_model(model, resume_weights)

config = LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_fc", "c_proj"],
)

model = get_peft_model(model, config)

# After get_peft_model only the LoRA adapters are trainable. Re-enable full training
# of the brain encoder, the projector and the date embeddings; the base GPT-2
# weights stay frozen.
for trainable_module in (model.projector, model.date_embeddings, model.brain_model):
    trainable_module.requires_grad_(True)

print('Initing of the Franky completed')
count_parameters(model)



loading weights from pretrained gpt: gpt2
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters: 123.65M
Full Franky: number of parameters: 135.42M
Initing of the Franky completed
Total: 136.01M, Trainable: 11.57M


(136007168, 11567360)

In [7]:
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift, Clip

# Train-time noise augmentation is currently disabled; to try it, add this to the
# train pipeline:
#   AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.15, p=0.1)

# Both splits only clip inputs into [0, 1], always applied (p=1).
train_transform = Compose([Clip(0, 1, p=1)])
test_transform = Compose([Clip(0, 1, p=1)])

#### Test forward pass

In [8]:
# One forward pass on a tiny test batch as a sanity check. Off by default so that
# "Run All" stays cheap; set to True to run it.
RUN_FORWARD_SMOKE_TEST = False

if RUN_FORWARD_SMOKE_TEST:
    smoke_dataset = BrainDataset(Path("/drive/data/competitionData") / 'test',
                                 tokenize_function=get_tokenizer(tokenizer),
                                 transform=test_transform)
    smoke_loader = torch.utils.data.DataLoader(smoke_dataset, batch_size=4, shuffle=False)

    x_batch, y_batch, date_batch = next(iter(smoke_loader))
    print(x_batch.shape, y_batch.shape, date_batch)

    # Keep the labels intact: the model returns (loss, logits).
    smoke_loss, smoke_logits = model(x_batch, y_batch, date_batch)
    print('forward loss:', smoke_loss.item())

### Run training pipeline

In [9]:
project_name = 'frankenstein'

# Data and log locations. A Windows setup used instead:
#   data_path = Path(r'C:\Users\peter\alvi\brain2text\competitionData')
data_path = Path("/drive/data/competitionData")
save_folder = Path("/drive/logs/kovalev")

train_config = TrainConfig(
    exp_name='all_tokens_train_all',
    mixed_precision=True,
    batch_size=32,
    num_workers=3,
    pin_memory=True,
    warmup_iters=1000,
    eval_interval=100,
    grad_accum=1,
    weight_decay=1e-5,
)

train_dataset = BrainDataset(data_path / 'train', tokenize_function=get_tokenizer(tokenizer), transform=train_transform)
test_dataset = BrainDataset(data_path / 'test', tokenize_function=get_tokenizer(tokenizer), transform=test_transform)

# Quick overfitting check: replace both splits with a tiny repeated subset, e.g.
#   indices = torch.arange(32).repeat(4)
#   train_dataset = torch.utils.data.Subset(train_dataset, indices)
#   test_dataset = torch.utils.data.Subset(test_dataset, indices)

args = (
    model,
    (train_dataset, test_dataset),
    train_config,
    project_name,
    save_folder,
)
notebook_launcher(run_train_model, args, num_processes=1)

# Single-process alternative without accelerate's launcher:
#   simple_train_model(*args)


Runed processing of the  /drive/data/competitionData/train
len: 8800
max input len 768
Runed processing of the  /drive/data/competitionData/test
len: 880
max input len 768
Launching training on one GPU.


dataloader_config = DataLoaderConfiguration(split_batches=True)
[34m[1mwandb[0m: Currently logged in as: [33mkoval_alvi[0m. Use [1m`wandb login --relogin`[0m to force relogin


Device for training:  cuda
Num devices:  1
Completed initialization of scheduler
****************************************************************************************************
overall_steps 100: 3.347933053970337
val loss: 3.3763537406921387
saved model:  step_100_loss_3.3764.safetensors
****************************************************************************************************
overall_steps 200: 3.173790454864502
val loss: 3.246567726135254
saved model:  step_200_loss_3.2466.safetensors
****************************************************************************************************
overall_steps 300: 3.3049468994140625
val loss: 3.1914420127868652
saved model:  step_300_loss_3.1914.safetensors
****************************************************************************************************
overall_steps 400: 3.086965799331665
val loss: 3.183224678039551
saved model:  step_400_loss_3.1832.safetensors
****************************************************************

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

In [None]:
n_preview = 10  # number of training samples to decode

# Disable dropout (e.g. LoRA dropout) while generating.
model.eval()
gen_device = next(model.parameters()).device

for i, (x, gt, date) in enumerate(train_dataset):
    if i >= n_preview:
        break

    # generate() needs the sample's date id as an index tensor on the model's device.
    date_info = torch.as_tensor(date).reshape(1).to(gen_device)
    pred = model.generate(x, date_info=date_info, tokenizer=tokenizer)

    # Drop ignored (-100) positions without writing into gt, which may alias the
    # dataset's stored labels.
    gt_txt = tokenizer.decode([int(t) for t in gt if t != -100], skip_special_tokens=True)

    print('pred: ', pred)
    print('gt_txt: ', gt_txt)
    print('----')

In [None]:
# Raw ground-truth text of the last previewed sample, keeping special tokens visible.
# (The prediction's token ids are internal to generate() and not available here.)
gt_raw_txt = tokenizer.decode([int(t) for t in gt if t != -100], skip_special_tokens=False)
gt_raw_txt

In [None]:
# One-off generation for the first held-out sample; generate() needs its date id.
model.eval()
x_demo, gt_demo, date_demo = test_dataset[0]
demo_date_info = torch.as_tensor(date_demo).reshape(1).to(next(model.parameters()).device)
model.generate(x_demo, date_info=demo_date_info, tokenizer=tokenizer)