In [1]:
import os
import pickle

import torch
from datasets import Dataset
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, PeftModel
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from transformers import BitsAndBytesConfig
from transformers import default_data_collator
from trl import SFTTrainer

In [2]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Authenticate against the Hugging Face Hub.
# The access token is read from the HF_TOKEN environment variable so that no
# credential is stored in the notebook. If the variable is unset, `token=None`
# makes `login` fall back to its interactive prompt.
# NOTE(review): the token that used to be hard-coded here was exposed in the
# notebook source and should be revoked on the Hub.
login(token=os.environ.get("HF_TOKEN"))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [4]:
def load_data_from_pickle(filepath):
    """Load a pickled mapping and return its ``'texts'``, ``'A'`` and ``'Q'`` entries.

    Args:
        filepath: path of the pickle file to read.

    Returns:
        tuple: ``(texts, A, Q)`` exactly as stored in the file.

    Raises:
        KeyError: if one of the three keys is missing from the stored object.

    Note:
        Unpickling can execute arbitrary code; only use this on trusted files.
    """
    with open(filepath, 'rb') as handle:
        payload = pickle.load(handle)
    texts, answers, questions = (payload[key] for key in ('texts', 'A', 'Q'))
    return texts, answers, questions

def custom_data_collator(features):
    """Tokenize a batch of raw-text examples for causal language-model training.

    Each feature carries its raw text under the ``'input_ids'`` key. The texts
    of the batch are tokenized together and padded to the longest sequence.

    Relies on the module-level ``tokenizer`` being defined before the first call.

    Args:
        features: list of dict-like examples with a text entry under ``'input_ids'``.

    Returns:
        dict: the output of ``default_data_collator(features)`` extended with
        ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    batch = default_data_collator(features)
    raw_texts = [feature['input_ids'] for feature in features]
    encoded = tokenizer(raw_texts, padding=True, return_tensors="pt")
    batch['input_ids'] = encoded['input_ids']
    # Forward the mask so padded positions are not treated as real tokens.
    batch['attention_mask'] = encoded['attention_mask']
    # Labels mirror the inputs for language modelling, but must be a separate
    # tensor: editing labels in place would otherwise also corrupt input_ids.
    labels = encoded['input_ids'].clone()
    # Exclude padding from the loss. The attention mask is used rather than the
    # pad id, so real tokens that share the pad id are still learned.
    labels[encoded['attention_mask'] == 0] = -100
    batch['labels'] = labels
    return batch

def prepare_dataset(texts, A, Q, train_frac=0.8):
    """Split parallel text / answer / question sequences into two Datasets.

    The first ``int(train_frac * len(texts))`` items form the first dataset and
    the remainder forms the second. No shuffling is performed.

    Args:
        texts: sequence of full training texts, stored under the ``'input_ids'`` column.
        A: sequence of answers, stored under the ``'A'`` column.
        Q: sequence of questions, stored under the ``'Q'`` column.
        train_frac: fraction of items placed in the first dataset; must lie in
            ``[0, 1]``. Defaults to 0.8.

    Returns:
        tuple: ``(train_set, test_set)`` as ``datasets.Dataset`` objects.

    Raises:
        ValueError: if the three sequences differ in length or ``train_frac``
            is outside ``[0, 1]``.
    """
    if not (len(texts) == len(A) == len(Q)):
        raise ValueError(
            f"texts, A and Q must have equal lengths, got "
            f"{len(texts)}, {len(A)} and {len(Q)}"
        )
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be within [0, 1], got {train_frac}")

    train_size = int(train_frac * len(texts))

    def _subset(rows):
        # Build one Dataset from the same slice of all three parallel sequences.
        return Dataset.from_dict({
            'input_ids': texts[rows],
            'A': A[rows],
            'Q': Q[rows],
        })

    train_set = _subset(slice(None, train_size))
    test_set = _subset(slice(train_size, None))
    return train_set, test_set

In [5]:
import xml.etree.ElementTree as ET
import re

def prsanswer(text):
    """Extract the answer enclosed by the first ``<ans>...</ans>`` pair in ``text``.

    Whitespace around the enclosed content is removed. The enclosed content
    must sit on a single line, because ``.`` does not match newlines here.

    Args:
        text: string to search.

    Returns:
        The stripped content between the tags, or ``None`` when no complete
        ``<ans>...</ans>`` pair is found.
    """
    answer_pattern = r'<ans>\s*(.*?)\s*</ans>'
    found = re.search(answer_pattern, text)
    return found.group(1).strip() if found else None

def calculate_accuracy(dataset, model, n=200):
    """Measure exact-match accuracy of generated answers on the first ``n`` examples.

    For each example the prompt ``Q + ' <ans>'`` is tokenized and passed to
    ``model.generate``; the decoded output is parsed with ``prsanswer`` and
    compared for equality with the example's ``'A'`` entry.

    Relies on the module-level ``tokenizer``. Inputs are moved to whichever
    device the model's parameters live on.

    Args:
        dataset: indexable collection of examples exposing ``'Q'`` and ``'A'``.
        model: model providing ``generate`` and ``parameters``.
        n: maximum number of examples to evaluate. Defaults to 200.

    Returns:
        float: fraction of evaluated examples answered correctly. The value is
        also printed. Returns 0.0 when there is nothing to evaluate.
    """
    total = min(len(dataset), n)
    if total <= 0:
        # Nothing to evaluate; avoid dividing by zero below.
        print(0.0)
        return 0.0

    # Follow the model's placement instead of assuming a fixed GPU.
    device = next(model.parameters()).device

    correct = 0
    for i in tqdm(range(total)):
        example = dataset[i]  # fetch the row once; it is needed for both Q and A
        encoded = tokenizer([example['Q'] + ' <ans>'], return_tensors="pt")
        predictions = model.generate(
            input_ids=encoded['input_ids'].to(device),
            attention_mask=encoded['attention_mask'].to(device),
            max_length=128,  # total length cap, prompt tokens included
        )
        decoded = tokenizer.decode(predictions[0].tolist(), skip_special_tokens=True)
        correct += (prsanswer(decoded) == example['A'])

    accuracy = correct / total
    print(accuracy)
    return accuracy

In [6]:
def train_save_push(model,lr, epoch, tokenizer, train, test, ood, target_modules, output_dir, hh):
    """Fine-tune ``model`` with LoRA via ``SFTTrainer``, save the result and push it to the Hub.

    Args:
        model: base model to fine-tune.
        lr: learning rate passed to ``TrainingArguments``.
        epoch: number of training epochs.
        tokenizer: tokenizer handed to the trainer.
        train: training dataset; its text is read from the ``'input_ids'`` column.
        test: evaluation dataset, evaluated once per epoch.
        ood: out-of-distribution dataset. Currently unused by this function.
        target_modules: module names LoRA adapters are attached to.
        output_dir: directory for checkpoints; the final model is saved to
            ``{output_dir}/final``.
        hh: repository name; the model is pushed to ``saaduddinM/{hh}``.

    Returns:
        None.
    """
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epoch,
        per_device_train_batch_size=8,
        gradient_accumulation_steps = 2,
        evaluation_strategy="epoch",
        logging_dir='./logs',
        logging_steps=100,
        learning_rate=lr,
        do_eval=True,
    )
    lora_config =  LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules=target_modules,
            task_type="CAUSAL_LM",
            bias="none"
    )
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        peft_config=lora_config,
        # Use the datasets passed by the caller rather than notebook globals,
        # so the function trains on what it is actually given.
        train_dataset=train,
        dataset_text_field="input_ids",
        data_collator= custom_data_collator,
        eval_dataset=test,
        packing=True,
        max_seq_length = 128
    )
    trainer.train()
    trainer.save_model(f'{output_dir}/final')
    # Reload the saved adapter on top of `model` and publish it.
    model = PeftModel.from_pretrained(model, f'{output_dir}/final')
    model.push_to_hub(f'saaduddinM/{hh}')

In [7]:
# LoRA target-module presets; each experiment picks one by index through "TM".
TMA = [
    # 0: projections and embeddings of Mamba-style models
    ["x_proj", "embeddings", "in_proj", "out_proj"],
    # 1: attention and MLP projections. The second entry used to read
    #    "up_pro_proj", which matches no module name; presumably "up_proj"
    #    (the counterpart of "down_proj") was intended.
    ["q_proj", "up_proj", "k_proj", "down_proj", "v_proj"],
    # 2: attention projections only
    ["q_proj", "o_proj", "k_proj", "v_proj"],
]

# Dataset path prefix and task suffix shared by every experiment.
DS = 'dataset/addition_dataset_text_Small'
TASK = 'add_small'

# One entry per model that has been run. Previously each entry was a block of
# global assignments that silently overwrote the one before it; only the last
# block took effect. Where a model appeared twice, the later values are kept.
EXPERIMENTS = {
    "Mamba1.4B": {"model_name": "state-spaces/mamba-1.4b-hf", "lr": 2e-3, "TM": 0},
    "RGemma2B": {"model_name": "google/recurrentgemma-2b", "lr": 2e-4, "TM": 2},
    "Zephyr3B": {"model_name": "stabilityai/stablelm-zephyr-3b", "lr": 2e-4, "TM": 1},
    "Phi1.5B": {"model_name": "microsoft/phi-1_5", "lr": 2e-4, "TM": 1},
    # NOTE(review): TM=0 selects the Mamba-style preset for a Llama model - confirm.
    "Llama7B": {"model_name": "meta-llama/Llama-2-7b-hf", "lr": 2e-4, "TM": 0},
    "Mistral7B": {"model_name": "mistralai/Mistral-7B-v0.1", "lr": 2e-4, "TM": 1},
    "Mamba2.8B": {"model_name": "state-spaces/mamba-2.8b-hf", "lr": 2e-4, "TM": 0},
    "Gemma2B": {"model_name": "google/gemma-2b", "lr": 2e-4, "TM": 1},
    "Llama8B": {"model_name": "meta-llama/Meta-Llama-3-8B", "lr": 2e-4, "TM": 1},
    "Mamba7B": {"model_name": "tri-ml/mamba-7b-rw", "lr": 2e-4, "TM": 0},
}

# Change this single value to switch experiments.
ACTIVE_EXPERIMENT = "Mamba7B"

_active_cfg = EXPERIMENTS[ACTIVE_EXPERIMENT]
output_dir = f"{ACTIVE_EXPERIMENT}/{TASK}"
model_name = _active_cfg["model_name"]
hh = f"{ACTIVE_EXPERIMENT}_{TASK}"
lr = _active_cfg["lr"]
epoch = 7
TM = _active_cfg["TM"]


In [8]:
# Load the tokenizer and the base model for the selected experiment.
# torch_dtype only concerns model weights, so it is passed to the model alone.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Reuse the end-of-sequence token for padding so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# In-distribution data, split into train and test parts.
texts,A,Q = load_data_from_pickle(f'{DS}.pkl')
train_dataset, test_dataset = prepare_dataset(texts,A,Q)

# Out-of-distribution data: only the first part of the split is kept.
# Indexing is used instead of unpacking into `_`, which IPython reserves for
# the most recent cell output.
texts,A,Q = load_data_from_pickle(f'{DS}_OOD.pkl')
ood_test_dataset = prepare_dataset(texts,A,Q)[0]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Launch LoRA fine-tuning for the selected experiment, then save and publish the result.
train_save_push(
    model=model,
    lr=lr,
    epoch=epoch,
    tokenizer=tokenizer,
    train=train_dataset,
    test=test_dataset,
    ood=ood_test_dataset,
    target_modules=TMA[TM],
    output_dir=output_dir,
    hh=hh,
)

Epoch,Training Loss,Validation Loss


In [8]:
# Load the base tokenizer; torch_dtype is a model-weight option and is not passed here.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the fine-tuned model from the Hub repository that train_save_push publishes to.
model = AutoModelForCausalLM.from_pretrained(f'saaduddinM/{hh}', torch_dtype=torch.bfloat16)

# In-distribution data, split into train and test parts.
texts,A,Q = load_data_from_pickle(f'{DS}.pkl')
train_dataset, test_dataset = prepare_dataset(texts,A,Q)

# Out-of-distribution data: only the first part of the split is kept.
# Indexing is used instead of unpacking into `_`, which IPython reserves for
# the most recent cell output.
texts,A,Q = load_data_from_pickle(f'{DS}_OOD.pkl')
ood_test_dataset = prepare_dataset(texts,A,Q)[0]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at tri-ml/mamba-7b-rw were not used when initializing MambaForCausalLM: ['model.lm_head.weight']
- This IS expected if you are initializing MambaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MambaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Move the model onto the first CUDA device before running generation-based evaluation.
model.to('cuda:0')

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): lora.Embedding(
      (base_layer): Embedding(50432, 4096)
      (lora_dropout): ModuleDict(
        (default): Dropout(p=0.1, inplace=False)
      )
      (lora_A): ModuleDict()
      (lora_B): ModuleDict()
      (lora_embedding_A): ParameterDict(  (default): Parameter containing: [torch.cuda.BFloat16Tensor of size 8x50432 (cuda:0)])
      (lora_embedding_B): ParameterDict(  (default): Parameter containing: [torch.cuda.BFloat16Tensor of size 4096x8 (cuda:0)])
    )
    (layers): ModuleList(
      (0-63): 64 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)
          (act): SiLU()
          (in_proj): lora.Linear(
            (base_layer): Linear(in_features=4096, out_features=16384, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
   

In [None]:
# Report exact-match accuracy on the train, test and out-of-distribution sets, in that order.
for eval_split in (train_dataset, test_dataset, ood_test_dataset):
    calculate_accuracy(eval_split, model)

 24%|██▎       | 47/200 [00:27<01:26,  1.77it/s]