In [1]:
import os
import torch
import torch.nn.functional as F
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET_CONTEXT = """
PubMedQA is a dataset and a task that involves Question Answering (QA) using scientific literature from PubMed, which is a free resource that contains millions of articles related to life sciences and biomedical research. PubMedQA specifically focuses on using abstracts and passages from PubMed articles to answer medical and scientific questions.
"""

In [3]:
config = {
    "dataset" : {
        "name": "MothMalone/SLMS-KD-Benchmarks"
    }, 
    "models": {
        "teacher": "meta-llama/Llama-2-13b",
        "student": "meta-llama/Llama-3.2-1B"
    },
    "tokenizer": {
    "max_length": 4096,
    "chat_template" : """
            {% for message in messages %}
            {% if loop.first and messages[0]['role'] != 'system' %}
                {{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}
            {% endif %}
            {{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}
            {% endfor %}
            {% if add_generation_prompt %}
            {{ '<|im_start|>assistant\n' }}
            {% endif %}
    """,
    },
    "training": {
        "output_dir": "./results",
        "num_train_epochs": 3,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 8,
        "save_steps": 1000,
        "logging_steps": 1,
        "learning_rate": 2e-5,
        "weight_decay": 0.05,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "cosine",
        "resume_from_checkpoint": None,
        "fp16": False,
        "bf16": True
    },
    "model_config": {
        "use_flash_attention": False
    }

}

In [4]:
dataset = load_dataset(config['dataset']['name'], 'pubmedqa')
dataset = dataset['train'].select(range(700))
dataset

Dataset({
    features: ['pubid', 'question', 'context', 'long_answer', 'final_decision'],
    num_rows: 700
})

In [5]:
student_tokenizer = AutoTokenizer.from_pretrained(config["models"]["student"])

In [6]:
from transformers import LlamaTokenizer

# Set the path to where the model is stored after download
model_path = "/home/lexuanan/.cache/huggingface/hub/models--meta-llama--Llama-2-13b/snapshots/5a3ad81c857aaf765c7a229a449490745a9004c9"

# Load the tokenizer
teacher_tokenizer = LlamaTokenizer.from_pretrained(model_path)

# Verify that it loads correctly
print(teacher_tokenizer)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message


LlamaTokenizer(name_or_path='/home/lexuanan/.cache/huggingface/hub/models--meta-llama--Llama-2-13b/snapshots/5a3ad81c857aaf765c7a229a449490745a9004c9', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)


In [7]:
def apply_template(row):
    question =  f"""
    {DATASET_CONTEXT}
    Given the following question from the PubMedQA dataset: 
    question: {row['question']}, 
    with these data:
    context: {row['context']}, 
    long_answer: {row['long_answer']}, 
    Please provide the final decision based on these data and provide the final_decision. It is either
    "yes" or "no" or "maybe".
    """
    return {
        "question" : question 
    }
    

In [8]:
dataset = dataset.map(apply_template)

In [9]:
dataset = dataset.remove_columns(['pubid','context', 'long_answer'])

In [10]:
dataset
student_tokenizer.pad_token = student_tokenizer.eos_token

In [11]:
def tokenize_function(row):
    return student_tokenizer(row["question"], truncation=True, max_length=config["tokenizer"]["max_length"], padding="max_length")

In [12]:
tokenized_dataset = dataset.map(tokenize_function, batched=True, num_proc=8, remove_columns=["question"])
tokenized_dataset

Dataset({
    features: ['final_decision', 'input_ids', 'attention_mask'],
    num_rows: 700
})

In [13]:
model_kwargs = {"torch_dtype": torch.bfloat16}
if config["model_config"]["use_flash_attention"]:
    model_kwargs["attn_implementation"] = "flash_attention_2"

In [14]:
def pad_logits(student_logits, teacher_logits):
    student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
    if student_size != teacher_size:
        pad_size = abs(student_size - teacher_size)
        pad_tensor = torch.zeros((*teacher_logits.shape[:-1], pad_size), dtype=teacher_logits.dtype, device=teacher_logits.device)
        return (torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits) if student_size < teacher_size else (student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1))
    return student_logits, teacher_logits

In [15]:
model_kwargs = {"torch_dtype": torch.bfloat16}
if config["model_config"]["use_flash_attention"]:
    model_kwargs["attn_implementation"] = "flash_attention_2"

In [16]:
from transformers import LlamaForCausalLM, LlamaTokenizer

# Set the model path
model_path = "/home/lexuanan/.cache/huggingface/hub/models--meta-llama--Llama-2-13b/snapshots/5a3ad81c857aaf765c7a229a449490745a9004c9"

# Load the tokenizer
teacher_tokenizer = LlamaTokenizer.from_pretrained(model_path)

# Load the entire model from the directory (not just one .pth file)
teacher_model = LlamaForCausalLM.from_pretrained(model_path, **model_kwargs)

# Check if the model loaded successfully
print(teacher_model)


OSError: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /home/lexuanan/.cache/huggingface/hub/models--meta-llama--Llama-2-13b/snapshots/5a3ad81c857aaf765c7a229a449490745a9004c9.