In [18]:
from transformers import BigBirdForCausalLM, BigBirdConfig
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from datasets import Dataset
import torch
import h5py
import pickle
import numpy as np
from datagen.tokens import GLOBAL_TOKENS_MAP, GLOBAL_TOKENS_RMAP
from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader
from transformers import default_data_collator
import torch
from tqdm import tqdm
from transformers import Trainer, TrainingArguments

  torch.utils._pytree._register_pytree_node(


In [19]:
# Define a custom model configuration with fewer parameters
config =  BigBirdConfig(
    vocab_size=len(GLOBAL_TOKENS_MAP),  # Based on the number of unique tokens
    hidden_size = 256,
    intermediate_size = 256,
    num_hidden_layers = 22,
    num_attention_heads = 8,
    max_position_embeddings = 128,
    use_cache = True,
    pad_token_id = 31,
    bos_token_id = 30,
    eos_token_id = 29,
    rope_theta=20.0,
    
)


In [20]:
model = BigBirdForCausalLM(config)

# Calculate the number of parameters
print(f"Total parameters in the model: {model.num_parameters()}")

If you want to use `BigBirdForCausalLM` as a standalone, add `is_decoder=True.`


Total parameters in the model: 8881184


In [21]:
def custom_data_collator(features):
    batch = default_data_collator(features)
    input_ids = [f['input_ids'] for f in features]

    batch['input_ids'] = torch.tensor(input_ids, dtype=torch.long)
    batch['labels'] = torch.tensor(input_ids, dtype=torch.long)  # labels same as input_ids for LM
    #batch['labels'][:,:20] = -100
    return batch

# Use this custom collator in the DataLoader
data_collator = custom_data_collator

In [22]:
def load_data_from_pickle(filepath='dataset/addition_dataset.pkl'):
    with open(filepath, 'rb') as f:
        data = pickle.load(f)
    return data['texts'],data['info']


def prepare_dataset(texts,info):
    train_size = int(0.8 * len(texts))
    train_set = Dataset.from_dict({
        'input_ids': texts[:train_size],
        'info': info[:train_size]
    })
    test_set = Dataset.from_dict({
        'input_ids': texts[train_size:],
        'info': info[train_size:]
    })
    return train_set, test_set


In [23]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=30,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    evaluation_strategy='epoch',
    learning_rate=1e-4
)
texts,info = load_data_from_pickle()
train_dataset, test_dataset = prepare_dataset(texts,info)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=data_collator
)


In [24]:
# Decoding function to transform tokens back to text using GLOBAL_TOKENS_MAP
def decode(tokens):
    return ' '.join(GLOBAL_TOKENS_MAP.get(token, "") for token in tokens if token in GLOBAL_TOKENS_MAP)


def calculate_accuracy(dataset,model,boundary = 20):
    print(decode(dataset[0]['input_ids'][:boundary]))
    corr = 0
    batch_size = 64
    for i in tqdm(range(0, len(dataset), batch_size)):
        #print(i)
        bs = min(batch_size,len(dataset)-i)
        predictions = model.generate(torch.tensor([dataset[j]['input_ids'][:boundary] for j in range(i,i+bs)]).to('cuda:0'), max_length=128, )
        
        for k in range(bs):    
            pred = []
            add = False
            for j in predictions[k][boundary:].tolist():
                if j == 15:
                    break
                if add:
                    pred.append(j)
                if j == 14:
                    add = True
                
            if len(dataset[i+k]['info']) != len(pred):
                continue
            if dataset[i+k]['info'] == pred:
                corr+=1
    print("Accuracy: ",corr/len(dataset))


In [25]:
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation Results: {eval_results}")

Attention type 'block_sparse' is not possible if sequence_length: 73 <= num global tokens: 2 * config.block_size + min. num sliding tokens: 3 * config.block_size + config.num_random_blocks * config.block_size + additional buffer: config.num_random_blocks * config.block_size = 704 with config.block_size = 64, config.num_random_blocks = 3. Changing attention type to 'original_full'...


Epoch,Training Loss,Validation Loss
1,No log,0.74471
2,No log,0.446038
3,No log,0.394573
4,0.730800,0.158895
5,0.730800,0.007448
6,0.730800,0.004403
7,0.730800,0.003119
8,0.013200,0.002374
9,0.013200,0.001897
10,0.013200,0.001575


wandb: Network error (ReadTimeout), entering retry loop.


Evaluation Results: {'eval_loss': 0.00037760258419439197, 'eval_runtime': 3.1142, 'eval_samples_per_second': 642.214, 'eval_steps_per_second': 10.275, 'epoch': 30.0}


In [26]:
calculate_accuracy(train_dataset,model, 22*2+4)
calculate_accuracy(test_dataset,model, 22*2+4)

BOS <prompt> 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 3 0 + 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 7 4 6 </prompt>



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 125/125 [04:01<00:00,  1.93s/it]


Accuracy:  0.0
BOS <prompt> 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 6 8 8 + 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9 3 5 8 </prompt>



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:59<00:00,  1.86s/it]

Accuracy:  0.0





In [27]:
for i in range(6,21):
    print("Size: ",i)
    texts, info = load_data_from_pickle(f'addition_dataset{i}.pkl')
    train_dataset, test_dataset = prepare_dataset(texts, info)
    calculate_accuracy(test_dataset,model, 22*2+4)
    

Size:  6


FileNotFoundError: [Errno 2] No such file or directory: 'addition_dataset6.pkl'

In [None]:
def encode(texts):
    return [GLOBAL_TOKENS_RMAP.get(words, "") for words in texts.split() if words in GLOBAL_TOKENS_RMAP]


In [None]:
model_weights_path = 'LLAMA_512_512_44_16_80M.pt'
torch.save(model.state_dict(), model_weights_path)
model = LlamaForCausalLM(config)  # Make sure you use the correct model class
model.load_state_dict(torch.load(model_weights_path))

In [13]:
model.eval()
outputs = model(torch.tensor([test_dataset[0]['input_ids'][:48]]).to('cuda:0'))

In [16]:
torch.argmax(outputs.logits[0], axis = 1).tolist()

[30,
 10,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 9,
 6,
 8,
 8,
 19,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 9,
 3,
 5,
 8,
 11]