In [1]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

from trl import GKDTrainer, GKDConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("Rowan/hellaswag", split="train")
raw_eval_dataset = load_dataset("Rowan/hellaswag", split="validation")

def preprocess_sft(example):
    context = example["ctx"]
    endings = example["endings"]
    correct_ending = endings[int(example["label"])]

    input_text = f"Context: {context.strip()} Ending:"
    target_text = correct_ending.strip()

    return {"input_text": input_text, "target_text": target_text}

sft_dataset = dataset.map(preprocess_sft)
sft_eval_dataset = raw_eval_dataset.map(preprocess_sft)

print(sft_dataset[0])

{'ind': 4, 'activity_label': 'Removing ice from car', 'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.', 'ctx_b': 'then', 'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then', 'endings': [', the man adds wax to the windshield and cuts it.', ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.', ', the man puts on a christmas coat, knitted with netting.', ', the man continues removing the snow on his car.'], 'source_id': 'activitynet~v_-1IBHYS3L-Y', 'split': 'train', 'split_type': 'indomain', 'label': '3', 'input_text': 'Context: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then Ending:', 'target_text': ', the man continues removing the snow on his car.'}


In [3]:
sft_eval_dataset[0]

{'ind': 24,
 'activity_label': 'Roof shingle removal',
 'ctx_a': 'A man is sitting on a roof.',
 'ctx_b': 'he',
 'ctx': 'A man is sitting on a roof. he',
 'endings': ['is using wrap to wrap a pair of skis.',
  'is ripping level tiles off.',
  "is holding a rubik's cube.",
  'starts pulling up roofing on a roof.'],
 'source_id': 'activitynet~v_-JhWjGDPHMY',
 'split': 'val',
 'split_type': 'indomain',
 'label': '3',
 'input_text': 'Context: A man is sitting on a roof. he Ending:',
 'target_text': 'starts pulling up roofing on a roof.'}

In [5]:
model_name = "/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_3_layers" 
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
)

training_args = SFTConfig(
    output_dir="/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned_3_layers_hellaswag_tuned_new",
    dataset_text_field="input_text",
    num_train_epochs=1,
    eval_steps=500,
    evaluation_strategy="steps",
    warmup_steps=150
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=sft_dataset,
    args=training_args,
    peft_config=lora_config,
    eval_dataset=sft_eval_dataset
)

trainer.train()

Map: 100%|██████████| 39905/39905 [00:01<00:00, 20561.15 examples/s]
Map: 100%|██████████| 10042/10042 [00:00<00:00, 21589.59 examples/s]


Step,Training Loss,Validation Loss
500,2.9699,2.668448
1000,2.4856,2.600171
1500,2.4602,2.571829
2000,2.4253,2.557512
2500,2.4181,2.54681
3000,2.4186,2.535969
3500,2.3955,2.530508
4000,2.3867,2.525476
4500,2.3875,2.520706


TrainOutput(global_step=4989, training_loss=2.4734934923237373, metrics={'train_runtime': 9071.3149, 'train_samples_per_second': 4.399, 'train_steps_per_second': 0.55, 'total_flos': 1.4298953928118272e+16, 'train_loss': 2.4734934923237373, 'epoch': 1.0})

# GKD Fine Tune

In [13]:
teacher_model_name = "facebook/layerskip-llama3.2-1B"
student_model_name = "/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_pruned"

teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)

if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token


dataset = load_dataset("Rowan/hellaswag", split="train")
eval_dataset = load_dataset("Rowan/hellaswag", split="validation")

chat_template = """
<|system|>{% for message in conversation if message.role == 'system' %}
{{ message.content }}
{% endfor %}
<|user|>{% for message in conversation if message.role == 'user' %}
{{ message.content }}
{% endfor %}
<|assistant|>{% for message in conversation if message.role == 'assistant' %}
{{ message.content }}
{% endfor %}
"""
student_tokenizer.chat_template = chat_template

def preprocess_dataset(example):
    context = example["ctx"]
    endings = example["endings"]
    correct_ending = endings[int(example["label"])]
    
    return {
        "messages": [
            {"role": "user", "content": f"Context: {context.strip()} Ending:"},
            {"role": "assistant", "content": correct_ending.strip()}
        ]
    }

preprocessed_dataset = dataset.map(preprocess_dataset, remove_columns=dataset.column_names)
gkd_eval_dataset = eval_dataset.map(preprocess_dataset, remove_columns=dataset.column_names)
preprocessed_dataset[0]

{'messages': [{'content': 'Context: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then Ending:',
   'role': 'user'},
  {'content': ', the man continues removing the snow on his car.',
   'role': 'assistant'}]}

In [14]:
gkd_eval_dataset[0]

{'messages': [{'content': 'Context: A man is sitting on a roof. he Ending:',
   'role': 'user'},
  {'content': 'starts pulling up roofing on a roof.', 'role': 'assistant'}]}

In [15]:

gkd_config = GKDConfig(
    output_dir="/home/jovyan/layer-skip/model-checkpoint/layer_skip_1b_tuned_gkd",
    num_train_epochs=1,
    eval_steps=500,
    evaluation_strategy="steps"
)

trainer = GKDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    processing_class=student_tokenizer,
    train_dataset=preprocessed_dataset,
    args=gkd_config,
    eval_dataset=gkd_eval_dataset
)

trainer.train()


Map: 100%|██████████| 39905/39905 [00:02<00:00, 14700.93 examples/s]
Map: 100%|██████████| 10042/10042 [00:00<00:00, 16347.02 examples/s]


Step,Training Loss,Validation Loss
500,0.0475,
1000,0.0396,
1500,0.0381,
2000,0.0337,
2500,0.0349,
3000,0.0318,
3500,0.0302,
4000,0.0291,
4500,0.0277,


From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


TrainOutput(global_step=4989, training_loss=0.03391021304915247, metrics={'train_runtime': 10856.0003, 'train_samples_per_second': 3.676, 'train_steps_per_second': 0.46, 'total_flos': 1.3513660379799552e+16, 'train_loss': 0.03391021304915247, 'epoch': 1.0})