In [2]:
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git

[0m

In [3]:
# List the GPUs visible to this runtime to confirm one is available before loading the model.
!nvidia-smi -L

GPU 0: NVIDIA A100 80GB PCIe (UUID: GPU-1b05da60-e4d7-2900-e930-248e1f2c401a)


In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Single source of truth for the checkpoint so the model and tokenizer cannot drift apart.
MODEL_NAME = "EleutherAI/gpt-j-6b"

# Load the base model with 8-bit (bitsandbytes) weights and let accelerate place it on the
# available device(s). Passing a BitsAndBytesConfig is the current API; the bare
# `load_in_8bit=True` keyword is deprecated in recent transformers releases.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The GPT-J tokenizer ships without a pad token; reuse EOS so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

In [2]:
for weight in model.parameters():
    # Freeze every base-model weight; only the LoRA adapters added later are trained.
    weight.requires_grad = False
    # Only small 1-D tensors (e.g. layernorm) are upcast; keeping them in fp32 aids stability.
    if weight.ndim != 1:
        continue
    weight.data = weight.data.to(torch.float32)

# Trade compute for memory: recompute activations during backward instead of storing them.
model.gradient_checkpointing_enable()
# The base weights are frozen, so make the embedding outputs require grad; this lets gradients
# flow back through the checkpointed blocks to the adapters.
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    """Sequential container whose output is always upcast to ``torch.float32``.

    Used to wrap the LM head so the logits it produces are float32 even when the
    wrapped module computes in a lower precision.
    """

    def forward(self, x):
        result = super().forward(x)
        return result.to(torch.float32)
# Replace the LM head with a wrapper that upcasts its output (the logits) to float32.
model.lm_head = CastOutputToFloat(model.lm_head)

In [3]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Args:
        model: any ``torch.nn.Module`` (here, the PEFT-wrapped model).

    Prints one line with the trainable parameter count, the total parameter
    count and the trainable percentage; returns None. A model with no
    parameters reports 0.0% instead of raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard the division so a parameterless module does not crash the report.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

In [4]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,  # rank of the low-rank update matrices A (r x in_features) and B (out_features x r)
    lora_alpha=32,  # the LoRA update is scaled by lora_alpha / r (= 2 here)
    # Explicit rather than relying on PEFT's per-architecture default. Adapting q_proj and v_proj
    # matches the reported 7,340,032 trainable params (= 28 layers x 2 modules x 2 x 4096 x 16).
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",  # keep all bias terms frozen
    task_type="CAUSAL_LM",  # decoder-only LM like GPT-J (use SEQ_2_SEQ_LM for encoder-decoder)
)

# Wrap the frozen base model so that only the LoRA adapter weights are trainable.
model = get_peft_model(model, config)
print_trainable_parameters(model)

trainable params: 7340032 || all params: 6058222816 || trainable%: 0.12115817167725645


In [5]:
from transformers import TextDataset, DataCollatorForLanguageModeling

# Load the datasets
# Shared settings so the train and validation splits are tokenized and chunked identically.
DATA_DIR = "../data/preprocessed"
BLOCK_SIZE = 256  # tokens per training example


def load_text_dataset(file_name, tok, block_size=BLOCK_SIZE):
    """Build a causal-LM dataset from one file in DATA_DIR.

    Args:
        file_name: name of the file inside DATA_DIR.
        tok: tokenizer used to encode the file.
        block_size: length, in tokens, of each example.

    Returns:
        A ``transformers.TextDataset``.

    NOTE(review): TextDataset reads the file as raw text and cuts the token stream into
    consecutive ``block_size`` chunks. For a CSV this means the header row, delimiters and
    quoting become training tokens, and examples can straddle row boundaries. Confirm this
    is intended. TextDataset is also deprecated in transformers; consider migrating to
    ``datasets.load_dataset`` plus explicit tokenization.
    """
    return TextDataset(
        tokenizer=tok,
        file_path=os.path.join(DATA_DIR, file_name),
        block_size=block_size,
    )


# Load the datasets
train_dataset = load_text_dataset("train_without_reasoning.csv", tokenizer)
val_dataset = load_text_dataset("val_without_reasoning.csv", tokenizer)

# mlm=False -> causal-LM collation: labels are a copy of input_ids with pad positions set to -100.
# Because pad_token == eos_token here, any EOS tokens in the data are also excluded from the loss.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)



In [6]:
from transformers import Trainer, TrainingArguments

# Define the parameters for fine-tuning
# Fine-tuning hyperparameters.
lr = 1e-5
# NOTE(review): end_lr is never passed to the Trainer, so it has no effect. The default "linear"
# scheduler decays the learning rate toward 0, not 2e-6. Confirm the intended schedule.
end_lr = 2e-6
num_train_epochs = 1
warmup_steps = 100

training_args = TrainingArguments(
    output_dir="gptj-without-reasoning-results",
    # 8 sequences per step x 4 accumulation steps = 32 sequences per optimizer update (per device).
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=8,
    # Optimization.
    learning_rate=lr,
    warmup_steps=warmup_steps,
    weight_decay=0.1,
    num_train_epochs=num_train_epochs,
    fp16=True,
    # Logging and periodic validation.
    logging_dir="./logs",
    logging_steps=10,
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="steps",
    eval_steps=50,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

# The KV cache is not needed for training and is incompatible with the gradient
# checkpointing enabled on this model, so switch it off.
model.config.use_cache = False

In [7]:
# Run fine-tuning with the Trainer configured above; it also evaluates on val_dataset every eval_steps.
trainer.train()
# With no path argument, save_model() writes to TrainingArguments.output_dir.
trainer.save_model()

[34m[1mwandb[0m: Currently logged in as: [33mharsha-surampudi1997[0m ([33mharshasurampudi[0m). Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss,Validation Loss
50,2.2569,2.246367
100,2.2081,2.158312
150,1.9833,2.018227
200,1.9459,1.9439
250,1.8991,1.89666
300,1.8292,1.868124
350,1.8576,1.850322
400,1.825,1.838317
450,1.8339,1.829886
500,1.8023,1.824124


wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.
wandb: Network error (ConnectTimeout), entering retry loop.


In [9]:
# Derive the commit message from the hyperparameters actually used for training,
# so the Hub history cannot drift from the Trainer configuration.
epoch_label = "epoch" if num_train_epochs == 1 else "epochs"
model.push_to_hub(
    "harsha28/gptj-lfqa-without-reasoning",
    token=True,  # use the locally stored Hub token (`use_auth_token` is deprecated)
    commit_message=f"lr {lr:g}, {num_train_epochs} {epoch_label}",
    private=True,
)

adapter_model.bin:   0%|          | 0.00/29.4M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/harsha28/gptj-lfqa-without-reasoning/commit/c28f83383947410c1f81275f8c3d8502af82dfd4', commit_message='lr 1e-5, 1 epoch', commit_description='', oid='c28f83383947410c1f81275f8c3d8502af82dfd4', pr_url=None, pr_revision=None, pr_num=None)