In [1]:
from huggingface_hub import login

login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import os

# os.environ["CUDA_VISIBLE_DEVICES"]="9"

token = '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'

In [3]:
# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

# Configure 8-bit quantization
quantization_config_8bit = BitsAndBytesConfig(
    load_in_8bit=True,
    # bnb_8bit_compute_dtype=torch.float32
)

# Clear GPU Cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()


# Force garbage collection
import gc
gc.collect()

# Move models to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reload models with same configuration and move to GPU
model_name1 = "meta-llama/Llama-3.2-1B-Instruct"
student_model = AutoModelForCausalLM.from_pretrained(
    model_name1,
    quantization_config=quantization_config,
    # device_map="cuda",
    token=token,
    # attn_implementation="sdpa"
)

model_name =  "meta-llama/Llama-3.1-8B-Instruct"
teacher_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="cuda",
    token=token,
    attn_implementation="sdpa"
)

teacher_model.eval()  
student_model.train()

from peft import get_peft_model, LoraConfig, TaskType

# Configure LoRA for Seq2Seq
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Attach LoRA to the model
student_model = get_peft_model(student_model, peft_config)
student_model.to(device)


teacher_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True)
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token


# Get model sizes in GB
def get_model_size_gb(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_gb = (param_size + buffer_size) / 1024**3
    return size_gb

student_size_gb = get_model_size_gb(student_model)
teacher_size_gb = get_model_size_gb(teacher_model)

print(f"Student model size: {student_size_gb:.2f} GB")
print(f"Teacher model size: {teacher_size_gb:.2f} GB")



Unused kwargs: ['bnb_8bit_compute_dtype']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Student model size: 0.95 GB
Teacher model size: 5.21 GB


In [4]:
from transformers import DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import load_from_disk
import torch
from torch.nn import functional as F
from peft import get_peft_model, LoraConfig, TaskType

# Load and preprocess dataset
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split

tokenized_datasets = load_from_disk('tokenized_dataset')
train_test_split = tokenized_datasets.train_test_split(test_size=0.2)



def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    """
    Compute distillation loss for text generation tasks.
    """
 
    # Apply mask before flattening
    masked_student_logits = student_logits 
    masked_teacher_logits = teacher_logits
    masked_labels = labels 

    # Compute KL divergence
    loss_distill = F.kl_div(
        input=F.log_softmax(masked_student_logits / temperature, dim=-1),
        target=F.softmax(masked_teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    )

    # Compute standard cross-entropy loss 
    loss_student = F.cross_entropy(
        masked_student_logits.view(-1, student_logits.size(-1)), 
        masked_labels.view(-1),
        ignore_index=-100
    )

    # Combine losses
    return alpha * loss_student + (1 - alpha) * loss_distill


class DistillationSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs_student = model(**inputs)
        logits_student = outputs_student.logits

        with torch.no_grad():
            outputs_teacher = self.teacher_model(**inputs)
            logits_teacher = outputs_teacher.logits

        loss = distillation_loss(logits_student, logits_teacher, labels)
        return (loss, outputs_student) if return_outputs else loss


# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="no",
    learning_rate=5e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",
    predict_with_generate=True,
    logging_dir="./logs",
    logging_steps=500,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=1, 
    per_device_eval_batch_size=1, 
)

# Create a data collator for sequence-to-sequence models
# data_collator = DataCollatorForSeq2Seq(tokenizer=teacher_tokenizer, model=student_model)


# Initialize DistillationSeq2SeqTrainer
trainer = DistillationSeq2SeqTrainer(
    teacher_model=teacher_model,
    model=student_model,
    args=training_args,
    train_dataset=train_test_split['train'], 
    eval_dataset=train_test_split['test'],
    # data_collator=data_collator,
    processing_class=teacher_tokenizer,
)

# Train the model
trainer.train()

# Evaluate the model
results = trainer.evaluate()
print(results)

# Save the student model
student_model.save_pretrained("LLama1b")
teacher_tokenizer.save_pretrained("LLama1b")




  0%|          | 0/1600 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 115.0771, 'grad_norm': 161.4300079345703, 'learning_rate': 3.4375e-05, 'epoch': 0.31}
{'loss': 99.9986, 'grad_norm': 138.9449462890625, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.62}
{'loss': 98.205, 'grad_norm': 143.67135620117188, 'learning_rate': 3.125e-06, 'epoch': 0.94}
{'train_runtime': 431.2981, 'train_samples_per_second': 3.71, 'train_steps_per_second': 3.71, 'train_loss': 103.96357421875, 'epoch': 1.0}


  0%|          | 0/400 [00:00<?, ?it/s]

{'eval_loss': 98.5, 'eval_runtime': 86.7349, 'eval_samples_per_second': 4.612, 'eval_steps_per_second': 4.612, 'epoch': 1.0}


('LLama1b\\tokenizer_config.json',
 'LLama1b\\special_tokens_map.json',
 'LLama1b\\tokenizer.json')

In [5]:
# Save the full model by merging the adapters with the base model
merged_model = student_model.merge_and_unload()
merged_model.save_pretrained("LLama1b_full")
teacher_tokenizer.save_pretrained("LLama1b_full") 



('LLama1b_full\\tokenizer_config.json',
 'LLama1b_full\\special_tokens_map.json',
 'LLama1b_full\\tokenizer.json')