In [1]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, TrainerCallback
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
import torch
from transformers import Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraModel
from peft import PeftModel, PeftModelForSequenceClassification

from trl import GKDTrainer, GKDConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the BoolQ dataset ("google/boolq"). Later cells read its "question",
# "passage" and "answer" columns and its "train" / "validation" splits.
boolq_dataset = load_dataset("google/boolq")

In [3]:
# Single checkpoint directory used for both the tokenizer and the classifier.
model_name = "/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_1_layers"

# Prepare the tokenizer first: EOS is reused as the padding token
# (presumably the checkpoint ships without one -- TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Two-way sequence classifier on top of the checkpoint; the model config must
# carry the same pad id the tokenizer uses.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.config.pad_token_id = tokenizer.pad_token_id

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at /home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_1_layers and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def preprocess_boolq_function(examples):
    """Tokenize a batch of BoolQ rows and attach integer class labels.

    Args:
        examples: A batch with parallel "question", "passage" and "answer"
            sequences, as supplied by a batched ``map`` call.

    Returns:
        The output of the notebook-level ``tokenizer`` for the joined
        "<question> <passage>" strings (truncated and padded to 512 tokens),
        with an added "labels" entry: a ``torch.long`` tensor holding 1 for a
        truthy answer and 0 otherwise.
    """
    # One "<question> <passage>" string per row, separated by a single space.
    texts = []
    for question, passage in zip(examples["question"], examples["passage"]):
        texts.append(question + " " + passage)

    encoded = tokenizer(texts, max_length=512, truncation=True, padding="max_length")

    # Map the answer flags onto class ids (truthy -> 1, falsy -> 0).
    class_ids = [int(bool(answer)) for answer in examples["answer"]]
    encoded["labels"] = torch.tensor(class_ids, dtype=torch.long)
    return encoded
# Apply the preprocessing in batches and drop the raw text columns, leaving
# only what the tokenizer produced plus "labels".
tokenized_boolq_datasets = boolq_dataset.map(
    preprocess_boolq_function,
    batched=True,
    remove_columns=["question", "passage", "answer"],
)

In [5]:
# Display the backbone module tree to check the layer count of the loaded checkpoint.
model.base_model

LlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-13): 14 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=512, bias=False)
        (v_proj): Linear(in_features=2048, out_features=512, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    )
  )
  (norm): LlamaRMSNorm((2048,), eps=1e-05)
  (rotary_emb): LlamaRotaryEmbedding(

In [6]:
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=32,
#     lora_dropout=0.1,
#     task_type="SEQ_CLS",
#     target_modules=["q_proj", "v_proj"]
# )
# model = get_peft_model(model, lora_config)

In [10]:
from transformers import EarlyStoppingCallback

# Hyperparameters for the BoolQ fine-tuning run.
# Evaluation and checkpointing share one 200-step cadence: with
# `load_best_model_at_end=True` the save schedule has to line up with the
# evaluation schedule so the best-scoring step has a checkpoint to restore.
training_args = TrainingArguments(
    output_dir="/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_1_layers_boolq_tuned_regular",
    num_train_epochs=4,                   # Upper bound on epochs
    per_device_train_batch_size=4,        # Training batch size
    per_device_eval_batch_size=4,         # Evaluation batch size
    eval_steps=200,                       # Evaluate every 200 steps
    eval_strategy="steps",                # Current name of the deprecated `evaluation_strategy`
    save_steps=200,                       # Save checkpoints every 200 steps
    learning_rate=5e-5,                   # Starting learning rate
    warmup_steps=150,                     # Warmup steps
    weight_decay=0.3,                     # Weight decay
    metric_for_best_model="eval_loss",    # Checkpoints are ranked by validation loss
    lr_scheduler_type="polynomial",
    max_grad_norm=1.0,                    # Gradient clipping threshold
    label_smoothing_factor=0.3,
    logging_dir="/home/jovyan/layer-skip/logs/layer_skip_1b_1_layers_boolq_tuned",
    logging_steps=50,                     # Log every 50 steps
    save_total_limit=3,                   # Save a maximum of 3 checkpoints
    fp16=True,                            # Mixed precision
    gradient_accumulation_steps=10,       # 4 x 10 = 40 samples per optimizer step per device
    load_best_model_at_end=True,          # Load best model at end of training
    remove_unused_columns=False,
    save_strategy="steps"
)

# Wire the model, the tokenized splits and early stopping together, then train.
trainer = Trainer(
    model=model,
    # `processing_class` replaces the deprecated `tokenizer=` argument of
    # Trainer.__init__ (the old spelling emits a FutureWarning).
    processing_class=tokenizer,
    train_dataset=tokenized_boolq_datasets['train'],
    args=training_args,
    eval_dataset=tokenized_boolq_datasets['validation'],
    # Patience of 1: training stops at the first evaluation that fails to
    # improve on the best `metric_for_best_model` value seen so far.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)]
)

trainer.train()

  trainer = Trainer(


Step,Training Loss,Validation Loss
200,0.7723,0.661144
400,0.6707,0.714333


TrainOutput(global_step=400, training_loss=0.798938627243042, metrics={'train_runtime': 1114.3301, 'train_samples_per_second': 33.839, 'train_steps_per_second': 0.844, 'total_flos': 4.18506660642816e+16, 'train_loss': 0.798938627243042, 'epoch': 1.697072549851506})