In [1]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

from trl import GKDTrainer, GKDConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("Rowan/hellaswag", split="train")
raw_eval_dataset = load_dataset("Rowan/hellaswag", split="validation")

def preprocess_sft(example):
    context = example["ctx"]
    endings = example["endings"]
    correct_ending = endings[int(example["label"])]

    input_text = f"Context: {context.strip()} Ending:"
    target_text = correct_ending.strip()

    return {"input_text": input_text, "target_text": target_text}

sft_dataset = dataset.map(preprocess_sft)
sft_eval_dataset = raw_eval_dataset.map(preprocess_sft)

print(sft_dataset[0])

{'ind': 4, 'activity_label': 'Removing ice from car', 'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.', 'ctx_b': 'then', 'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then', 'endings': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.', ', the man puts on a christmas coat, knitted with netting.', ', the man continues removing the snow on his car.'], 'source_id': 'activitynet~v_-1IBHYS3L-Y', 'split': 'train', 'split_type': 'indomain', 'label': '3', 'input_text': 'Context: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then Ending:', 'target_text': ', the man continues removing the snow on his car.'}


In [3]:
sft_eval_dataset[0]

{'ind': 24,
 'activity_label': 'Roof shingle removal',
 'ctx_a': 'A man is sitting on a roof.',
 'ctx_b': 'he',
 'ctx': 'A man is sitting on a roof. he',
 'endings': ['is using wrap to wrap a pair of skis.',
  'is ripping level tiles off.',
  "is holding a rubik's cube.",
  'starts pulling up roofing on a roof.'],
 'source_id': 'activitynet~v_-JhWjGDPHMY',
 'split': 'val',
 'split_type': 'indomain',
 'label': '3',
 'input_text': 'Context: A man is sitting on a roof. he Ending:',
 'target_text': 'starts pulling up roofing on a roof.'}

In [5]:
model_name = "/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_5_layers" 
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
)

training_args = SFTConfig(
    output_dir="/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_5_layers_hellaswag_tuned",
    dataset_text_field="input_text",
    num_train_epochs=1,
    eval_steps=500,
    evaluation_strategy="steps",         
    warmup_steps=150,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset,
    args=training_args,
    peft_config=lora_config,
    eval_dataset=sft_eval_dataset
)

trainer.train()


Map: 100%|██████████| 39905/39905 [00:02<00:00, 18853.47 examples/s]
Map: 100%|██████████| 10042/10042 [00:00<00:00, 16302.25 examples/s]


Step,Training Loss,Validation Loss
500,3.3052,2.913358
1000,2.71,2.819783
1500,2.6755,2.781629
2000,2.6298,2.75543
2500,2.6176,2.741166
3000,2.6149,2.729652
3500,2.5919,2.724007
4000,2.5797,2.719765
4500,2.5848,2.715374


TrainOutput(global_step=4989, training_loss=2.6894206071717153, metrics={'train_runtime': 4441.0323, 'train_samples_per_second': 8.986, 'train_steps_per_second': 1.123, 'total_flos': 1.1915801621434368e+16, 'train_loss': 2.6894206071717153, 'epoch': 1.0})