# KD on Text-generation models

In [None]:
%%capture
%pip install -U bitsandbytes
%pip install -U accelerate
%pip install -U peft
%pip install -U transformers==4.48.0 #4.46.3
%pip install -U datasets
%pip install -U wandb

In [2]:
import torch
# Prefer the GPU when one is visible; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Probe bf16 support only when CUDA is actually available: the check is a CUDA
# capability query, and some PyTorch releases raise from it on CPU-only hosts
# instead of returning False.
training_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

In [3]:
# loading model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# 4-bit NF4 quantization with double quantization enabled. The compute dtype
# is bf16 when `training_bf16` is set and fp16 otherwise.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=(torch.bfloat16 if training_bf16 else torch.float16),
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [4]:
teacher_model_path = "JC-24/gemma-7b-mediqa-final"
student_model_path = "/kaggle/input/gemma/transformers/2b/2"

# The teacher is loaded 4-bit quantized through the shared `bnb_config`.
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_path,
    quantization_config=bnb_config
)

# A single tokenizer (taken from the teacher checkpoint) is used for both models.
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
# Fall back to EOS-as-padding only when the tokenizer ships without a pad
# token. Overwriting a dedicated pad token unconditionally makes genuine EOS
# tokens indistinguishable from padding in anything that masks by pad id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

`low_cpu_mem_usage` was None, now default to True since model is quantized.


In [5]:
# Load the student checkpoint from `student_model_path`, quantized with the
# same `bnb_config` that is used for the teacher.
student_model = AutoModelForCausalLM.from_pretrained(
    student_model_path,
    quantization_config=bnb_config
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.


In [6]:
# Wrap each model in DataParallel (multi-GPU replication) and place the
# wrapper on `device` in the same statement.
student_model = torch.nn.DataParallel(student_model).to(device)
teacher_model = torch.nn.DataParallel(teacher_model).to(device)

In [7]:
from torch.utils.data import DataLoader
from torch.optim import AdamW

# Training schedule: one pass over the data, one example per step.
epochs = 1
batch_size = 1

In [8]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Map-style ``Dataset`` adapter around any indexable, sized collection."""

    def __init__(self, data):
        # Keep a reference to the wrapped collection; nothing is copied.
        self.data = data

    def __getitem__(self, idx):
        """Return whatever the wrapped collection stores at ``idx``."""
        item = self.data[idx]
        return item

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        size = len(self.data)
        return size

from datasets import load_dataset

# First 500 rows of the WikiText-2 (raw) training split.
data = load_dataset("wikitext", "wikitext-2-raw-v1",split="train[:500]")

# Fixed seed so the 80/20 train/eval partition is identical on every
# Restart & Run All; without it each run trains and evaluates on different rows.
split_dataset = data.train_test_split(test_size=0.2, seed=42)

train_data = split_dataset['train']
eval_data = split_dataset['test']

train_dataset = MyDataset(train_data)
eval_dataset = MyDataset(eval_data)

# Batches come out as dicts holding a list of raw strings under "text";
# tokenization happens later, inside the training/evaluation loops.
train_dataloader = DataLoader(train_dataset, batch_size = batch_size)
eval_dataloader = DataLoader(eval_dataset, batch_size = batch_size)

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

In [None]:
# from torchmetrics.text import Perplexity
# from tqdm import tqdm
# import torch.nn as nn

# device = "cuda"
# def calculate_perplexity_batched(model, dataloader, device):
#     global logits,input_ids
#     model.eval()
#     # print(tokenizer.pad_token_id)
#     perplexity_metric = Perplexity(ignore_index=tokenizer.pad_token_id).to(device)
#     perp_list = []
#     with torch.no_grad():
#         for batch in tqdm(dataloader):
#             tokenized_data = tokenizer(batch["text"], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
#             input_ids = tokenized_data["input_ids"].to(device)
#             attention_mask = tokenized_data["attention_mask"].to(device)
            
#             outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
#             logits = outputs.logits  # model outputs
#             print(logits[0][0][:10])
#             # Calculate perplexity batch-wise
#             # logits = nn.functional.log_softmax(outputs.logits,dim=-1)
#             perplexity_metric.update(logits[:,:-1],input_ids[:,1:])
#             perp = perplexity_metric.compute().item()
#             perp_list.append(perp)
#     # Compute final perplexity
#     final_perplexity = sum(perp_list)/len(perp_list)
#     return final_perplexity

# student_perplexity = calculate_perplexity_batched(student_model, eval_dataloader, device) # Student Perplexity: 151353.71875
# print("Student Perplexity:", student_perplexity)
# teacher_perplexity = calculate_perplexity_batched(teacher_model, eval_dataloader, device) # Teacher Perplexity: 151384.03125
# print("Teacher Perplexity:", teacher_perplexity)

In [None]:
import torch.nn as nn

"""

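Why log_softmax is applied to the student logits but plain softmax to the teacher logits:
KLDivLoss expects its input (the student) as log-probabilities and its target (the teacher) as probabilities.
https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html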

"""

def kd_loss(student_logits,teacher_logits,actual_labels, alpha=0.5, temperature=2.0, ignore_index=-100):
    """Knowledge-distillation loss: temperature-softened KL plus hard-label CE.

    Both terms are averaged over the same set of positions (every position
    whose label is not ``ignore_index``), so they share one scale. Applying
    ``batchmean`` directly to ``(batch, seq, vocab)`` logits would divide the
    KL sum by the batch size only, inflating it by the sequence length
    relative to the token-averaged cross-entropy.

    Args:
        student_logits: student scores; the last dimension is the vocabulary.
        teacher_logits: teacher scores with the same shape as ``student_logits``.
        actual_labels: integer targets, one per position (all dimensions of
            the logits except the last).
        alpha: weight of the KL term; the CE term is weighted ``1 - alpha``.
        temperature: softmax temperature applied to both models for the KL term.
        ignore_index: label value marking positions (e.g. padding) that are
            excluded from both terms.

    Returns:
        Tuple ``(kl_div_loss, hard_loss, total_loss)`` of scalar tensors. When
        no position is kept, all three are zero instead of NaN.
    """
    # reshape (not view) so non-contiguous inputs such as logits[:, :-1] work.
    flat_labels = actual_labels.reshape(-1)
    keep = flat_labels != ignore_index
    # Avoid 0/0 when every position is ignored.
    n_kept = keep.sum().clamp(min=1)

    # Select the kept positions first, then upcast: softmax over a large
    # vocabulary is numerically fragile in half precision.
    student_kept = student_logits.reshape(-1, student_logits.size(-1))[keep].float()
    teacher_kept = teacher_logits.reshape(-1, teacher_logits.size(-1))[keep].float()
    kept_labels = flat_labels[keep]

    teacher_soft = nn.functional.softmax(teacher_kept / temperature,dim=-1)
    student_log_soft = nn.functional.log_softmax(student_kept / temperature,dim=-1)

    # Sum over kept tokens and vocabulary, then normalise per token. The T^2
    # factor compensates for the 1/T^2 shrinkage of soft-target gradients.
    kl_sum = nn.functional.kl_div(student_log_soft, teacher_soft, reduction="sum")
    kl_div_loss = kl_sum / n_kept * (temperature ** 2)

    hard_loss = nn.functional.cross_entropy(student_kept, kept_labels, reduction="sum") / n_kept

    total_loss = alpha*kl_div_loss + (1-alpha)*hard_loss
    return kl_div_loss, hard_loss, total_loss

In [24]:
# """
# perplexity =  exp ( sum of negative log-likelihood of the tokens in the sequence / total tokens)
# perplexity is equivalent to the exponentiation of the cross-entropy between the actual data and model predictions
# """
# from tqdm import tqdm
# # def get_perplexity(logits, target, pad_token_id):
# #     # Convert logits to log probabilities
# #     log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
# #     target_log_probs = log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
# #     # Calculate the negative log likelihood
# #     mask = (target != pad_token_id)
# #     target_log_probs = target_log_probs[mask]
# #     negative_log_likelihood = -target_log_probs
    
# #     # Calculate the mean negative log likelihood over all tokens
# #     # print("neg_log: ",negative_log_likelihood[0][0][:5])
# #     mean_nll = negative_log_likelihood.mean()
# #     # print("---->",mean_nll.item())
# #     # Calculate perplexity as exp(mean negative log likelihood)
# #     perplexity = torch.exp(mean_nll)
# #     # print(target[0][0], pad_token_id)
# #     print("##p:",target[0][0], pad_token_id, perplexity.item())
# #     return perplexity.item()
    
# def calculate_perplexity(model, dataloader, device):
#     losses = []
#     with torch.no_grad():
#         # p_list = []
#         for data in tqdm(dataloader):
#             tokenized_data = tokenizer(data["text"],padding="max_length",max_length = 512,truncation=True,return_tensors = "pt")
#             input_ids = tokenized_data["input_ids"].to(device)
#             # if input_ids[0][0] == tokenizer.pad_token_id:
#             #     continue
#             attention_mask = tokenized_data["attention_mask"].to(device)
            
#             outputs = model(input_ids,attention_mask=attention_mask, labels=input_ids)
#         #     p = get_perplexity(outputs.logits,input_ids, tokenizer.pad_token_id)
#         #     p_list.append(p)
#         # return sum(p_list)/len(p_list)
#             loss = outputs.loss.item()
#             losses.append(loss)
#     return torch.exp(torch.tensor(losses)).mean().item()


# student_perplexity = calculate_perplexity(student_model,eval_dataloader,device)
# print("student_perplexity: ",student_perplexity)
# teacher_perplexity = calculate_perplexity(teacher_model,eval_dataloader,device)
# print("teacher_perplexity: ",teacher_perplexity)
# # student_perplexity:  9043377.0
# # teacher_perplexity:  22564.099609375

In [None]:
# Reload a fresh 4-bit student and prepare it for k-bit training before any
# adapters are attached.
student_model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(student_model_path, quantization_config=bnb_config)
)

# LoRA hyper-parameters for causal-LM fine-tuning (rank 8, alpha 32, dropout 0.1).
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
    task_type="CAUSAL_LM",
)

# Attach the LoRA adapters to the prepared student.
student_model = get_peft_model(student_model, peft_config)

In [None]:
# AdamW over everything `student_model.parameters()` yields.
# NOTE(review): presumably only the LoRA adapter weights receive gradients,
# so the frozen base weights passed here are never updated — confirm.
optimizer = AdamW(student_model.parameters(),lr = 1.4e-5)

In [None]:
#train_loop
def calculate_perplexity(model, dataloader, device):
    """Token-level perplexity of ``model`` over the raw-text ``dataloader``.

    Defined in this cell because the loop below calls it and the only other
    definition visible in the notebook is commented out.

    Perplexity is ``exp(total next-token NLL / number of scored tokens)``. A
    position is scored only when both it and its target are real tokens
    according to the attention mask, so padding never contributes.

    Returns:
        The perplexity as a Python float, or NaN when no token was scored.
    """
    total_nll = 0.0
    total_tokens = 0
    with torch.no_grad():
        for data in dataloader:
            tokenized_data = tokenizer(data["text"],padding="max_length",max_length = 512,truncation=True,return_tensors = "pt")
            input_ids = tokenized_data["input_ids"].to(device)
            attention_mask = tokenized_data["attention_mask"].to(device)

            # Position t predicts token t+1; require both to be real tokens
            # (this is correct for left- and right-padded batches alike).
            valid = attention_mask[:, 1:].bool() & attention_mask[:, :-1].bool()
            if not valid.any():
                continue

            logits = model(input_ids,attention_mask=attention_mask).logits[:, :-1, :]
            nll = torch.nn.functional.cross_entropy(
                logits[valid].float(), input_ids[:, 1:][valid], reduction="sum"
            )
            total_nll += nll.item()
            total_tokens += int(valid.sum())

    if total_tokens == 0:
        return float("nan")
    return torch.exp(torch.tensor(total_nll / total_tokens)).item()


teacher_model.eval()
# Put the student in train mode before the first step as well, not only after
# the first evaluation, so every epoch trains under the same mode.
student_model.train()
logging_step = 50
for epoch in range(epochs):
    for step, data in enumerate(train_dataloader):

        tokenized_data = tokenizer(data["text"],padding="max_length",max_length = 512,truncation=True,return_tensors = "pt")
        input_ids = tokenized_data["input_ids"].to(device)
        attention_mask = tokenized_data["attention_mask"].to(device)

        # Next-token targets: logits at position t are scored against token
        # t+1. Pairs that involve padding are set to -100, which cross-entropy
        # ignores by default.
        valid = attention_mask[:, 1:].bool() & attention_mask[:, :-1].bool()
        if not valid.any():
            # Rows with nothing to predict (e.g. blank text) would give a NaN loss.
            continue
        labels = input_ids[:, 1:].masked_fill(~valid, -100).contiguous()

        with torch.no_grad():
            teacher_out = teacher_model(input_ids,attention_mask=attention_mask)

        student_out = student_model(input_ids,attention_mask=attention_mask)
        # Drop the last position (it has no target) and make the slices
        # contiguous so they can be flattened downstream.
        student_logits = student_out.logits[:, :-1, :].contiguous()
        teacher_logits = teacher_out.logits[:, :-1, :].contiguous()
        kl_div_loss, hard_loss, total_loss = kd_loss(student_logits, teacher_logits, labels)

        if step%logging_step==0:
            print("step: ",step," kl_div_loss:",kl_div_loss.item(), " hard_loss:",hard_loss.item(), " total_loss:", total_loss.item())
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    student_model.eval()
    student_perplexity = calculate_perplexity(student_model,eval_dataloader,device)
    teacher_perplexity = calculate_perplexity(teacher_model,eval_dataloader,device)
    print("Perplexity of student model: ", student_perplexity)
    print("Perplexity of teacher model: ", teacher_perplexity)
    #evaluate the model
    student_model.train()

# KD using Transformers Trainer

In [None]:
%%capture
%pip install -U bitsandbytes
%pip install -U accelerate
%pip install -U peft
%pip install -U transformers==4.48.0 #4.46.3
%pip install -U datasets
%pip install -U wandb

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Read the Hugging Face token from the Kaggle secret named
# "HUGGINGFACE_TOKEN" and log in with it, so no credential is hard-coded here.
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token = hf_token)

In [None]:
import torch
# Probe bf16 support only when CUDA is available: the check is a CUDA
# capability query, and some PyTorch releases raise from it on CPU-only hosts.
training_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

In [None]:
# loading model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

# 4-bit NF4 quantization with double quantization enabled. The compute dtype
# is bf16 when `training_bf16` is set and fp16 otherwise.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=(torch.bfloat16 if training_bf16 else torch.float16),
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
teacher_model_path = "/kaggle/input/qwen2/transformers/1.5b/1"
student_model_path = "/kaggle/input/qwen2/transformers/0.5b/1"

# The teacher is loaded 4-bit quantized through the shared `bnb_config`.
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_path,
    quantization_config=bnb_config
)

# A single tokenizer (taken from the teacher checkpoint) is used for both models.
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
# Fall back to EOS-as-padding only when the tokenizer has no pad token of its
# own; overwriting an existing pad token would conflate real EOS with padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the 4-bit student and prepare it for k-bit training before any
# adapters are attached.
student_model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(student_model_path, quantization_config=bnb_config)
)

# LoRA hyper-parameters for causal-LM fine-tuning (rank 8, alpha 32, dropout 0.1).
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    inference_mode=False,
    task_type="CAUSAL_LM",
)

# Attach the LoRA adapters to the prepared student.
student_model = get_peft_model(student_model, peft_config)

In [None]:
from datasets import load_dataset
# First 500 rows of the WikiText-2 (raw) training split.
data = load_dataset("wikitext", "wikitext-2-raw-v1",split="train[:500]")

# Fixed seed so the 80/20 train/eval partition is identical on every
# Restart & Run All; without it each run trains and evaluates on different rows.
split_dataset = data.train_test_split(test_size=0.2, seed=42)

train_data = split_dataset['train']
eval_data = split_dataset['test']

In [None]:
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer


# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    # `eval_strategy` is the current name of this option; `evaluation_strategy`
    # is the deprecated alias.
    eval_strategy="steps",
    # A float below 1 is interpreted as a fraction of the total training steps.
    eval_steps = 0.2,
    # NOTE(review): a single optimisation step looks like a smoke-test
    # setting; raise it (or use num_train_epochs) for a real run.
    max_steps=1,
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=1,
    save_strategy="steps",
    # Keep the raw "text" column so the data collator can still see it.
    remove_unused_columns=False
)

# Custom DataCollator for Distillation
from transformers import DataCollatorForSeq2Seq
import torch.nn.functional as F


def data_collator(features):
    """Turn a list of raw ``{"text": ...}`` rows into a padded causal-LM batch.

    The datasets handed to the trainer hold untokenized text, so tokenization
    has to happen at collation time; a padding-only collator has no
    ``input_ids`` to work with on such rows.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``. ``labels``
        is a copy of ``input_ids`` with padded positions set to -100 so they
        can be excluded from the loss.
    """
    texts = [feature["text"] for feature in features]
    batch = tokenizer(texts,padding="max_length",max_length = 512,truncation=True,return_tensors = "pt")
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100
    return {
        "input_ids": batch["input_ids"],
        "attention_mask": batch["attention_mask"],
        "labels": labels,
    }


# Trainer for KD
class DistillationTrainer(Trainer):
    """``Trainer`` whose loss mixes teacher-matching KL with hard-label CE.

    Args:
        teacher_model: frozen model that provides the soft targets.
        temperature: softmax temperature applied to both models for the KL term.
        alpha: weight of the KL term; the CE term is weighted ``1 - alpha``.
    """

    def __init__(self, *args, teacher_model=None, temperature=2.0, alpha=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.alpha = alpha
        if self.teacher_model is not None:
            # The teacher only supplies targets; keep dropout off.
            self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the distillation loss for one collated batch.

        ``num_items_in_batch`` is accepted because recent ``Trainer`` versions
        pass it as a keyword; it is not used, since the loss is normalised by
        the number of scored tokens in the batch itself.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask")
        labels = inputs["labels"]

        # Forward pass for student
        outputs_student = model(input_ids=input_ids, attention_mask=attention_mask)

        # Forward pass for teacher (no gradients: it is never updated)
        with torch.no_grad():
            outputs_teacher = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask)

        # Next-token alignment: logits at position t are scored against the
        # label at t+1. Keep a pair only if its target is not ignored and the
        # predicting position is a real token.
        shift_labels = labels[:, 1:]
        valid = shift_labels != -100
        if attention_mask is not None:
            valid = valid & attention_mask[:, :-1].bool()
        # Avoid 0/0 on batches with nothing to score; the loss is then 0.
        n_valid = valid.sum().clamp(min=1)

        # Select the scored positions first, then upcast for a stable softmax.
        logits_student = outputs_student.logits[:, :-1, :][valid].float()
        logits_teacher = outputs_teacher.logits[:, :-1, :][valid].float().to(logits_student.device)
        targets = shift_labels[valid]

        # Compute KD loss, averaged per scored token. The T^2 factor offsets
        # the 1/T^2 shrinkage of soft-target gradients.
        loss_kl = F.kl_div(
            F.log_softmax(logits_student / self.temperature, dim=-1),
            F.softmax(logits_teacher / self.temperature, dim=-1),
            reduction="sum"
        ) / n_valid * (self.temperature ** 2)

        # Compute hard label loss on the same positions
        loss_ce = F.cross_entropy(logits_student, targets, reduction="sum") / n_valid

        # Total loss
        loss = self.alpha * loss_kl + (1 - self.alpha) * loss_ce
        return (loss, outputs_student) if return_outputs else loss

# Trainer instance
# Distillation-specific settings first, then the standard Trainer wiring.
trainer = DistillationTrainer(
    teacher_model=teacher_model,
    alpha=0.5,
    temperature=2.0,
    model=student_model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=eval_data,
    args=training_args,
)

In [None]:
# Train
# Run the training loop configured on `trainer`; as the last expression of
# the cell, the returned value is displayed as the cell output.
trainer.train()

# KD on Sequence Classification Model

In [None]:
%%capture
%pip install -U transformers==4.48.0 #4.46.3
%pip install -U datasets

In [None]:
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DistilBertForSequenceClassification, DistilBertConfig

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Read the Hugging Face token from the Kaggle secret named
# "HUGGINGFACE_TOKEN" and log in with it, so no credential is hard-coded here.
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HUGGINGFACE_TOKEN")
login(token = hf_token)

In [None]:
# Use the GPU when one is visible and fall back to CPU otherwise, so the
# later `.to(device)` calls do not fail on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Load teacher model and tokenizer
# The tokenizer and the teacher classifier come from the same checkpoint;
# the teacher is moved to `device` right after loading.
model_path = "shawhin/bert-phishing-classifier_teacher"

tokenizer = AutoTokenizer.from_pretrained(model_path)
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    model_path).to(device)

In [None]:
# Load student model
# Start from the pretrained DistilBERT checkpoint but override the
# architecture to 8 attention heads and 4 transformer layers.
# NOTE(review): the original note describes this as dropping 4 heads per
# layer and 2 layers relative to the base configuration — confirm.
my_config = DistilBertConfig(n_layers=4, n_heads=8)
student_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", config=my_config
).to(device)

In [None]:
# Load the phishing-site classification dataset. Later cells read its 'train'
# and 'test' splits and its "text" and "labels" columns.
data = load_dataset("shawhin/phishing-site-classification")

In [None]:
# define text preprocessing
def preprocess_function(examples):
    """Tokenize the ``text`` field of a batch of examples.

    Sequences are truncated and padded to the tokenizer's maximum length, so
    every example ends up with the same number of tokens.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True, padding='max_length')
    return encoded

# tokenize all dataset splits (batched, so the tokenizer sees lists of texts)
tokenized_data = data.map(preprocess_function, batched=True)
# return only the model inputs and the labels, formatted as torch tensors
tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [None]:
def evaluate_model(model, dataloader, device):
    """Score a classifier on ``dataloader`` with binary-averaged metrics.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. Predictions are the argmax over dimension 1 of the
    logits.

    Args:
        model: callable returning an object with a ``logits`` attribute.
        dataloader: yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors.
        device: device the batch tensors are moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``.
    """
    model.eval()
    predictions, references = [], []

    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            batch_logits = model(ids, attention_mask=mask).logits

            # Accumulate on the CPU so sklearn can consume plain arrays.
            predictions.extend(torch.argmax(batch_logits, dim=1).cpu().numpy())
            references.extend(targets.cpu().numpy())

    accuracy = accuracy_score(references, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        references, predictions, average='binary'
    )

    return accuracy, precision, recall, f1

In [None]:
# Function to compute distillation and hard-label loss
def distillation_loss(student_logits, teacher_logits, true_labels, temperature, alpha):
    """Blend the soft-target KL term with the hard-label cross-entropy.

    Args:
        student_logits: student scores; classes along dimension 1.
        teacher_logits: teacher scores with the same shape.
        true_labels: ground-truth class indices.
        temperature: softening temperature applied to both sets of logits.
        alpha: weight of the KL term; cross-entropy is weighted ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * KL * temperature**2 + (1 - alpha) * CE``.
    """
    F = nn.functional

    # Teacher as probabilities, student as log-probabilities: the argument
    # order and form that kl_div expects.
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)

    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    hard_term = F.cross_entropy(student_logits, true_labels)

    return alpha * soft_term + (1.0 - alpha) * hard_term

In [None]:
# hyperparameters
batch_size = 32
lr = 1e-4
num_epochs = 5
temperature = 2.0
alpha = 0.5

# define optimizer
optimizer = optim.Adam(student_model.parameters(), lr=lr)

# create training data loader
# Shuffle each epoch so mini-batches are not always drawn in the dataset's
# stored order; the seeded generator keeps that order reproducible across runs.
dataloader = DataLoader(
    tokenized_data['train'],
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)
# create testing data loader (order does not matter for evaluation)
test_dataloader = DataLoader(tokenized_data['test'], batch_size=batch_size)

In [None]:
# put student model in train mode
student_model.train()
teacher_model.eval()
# train model
for epoch in range(num_epochs):
    for batch in dataloader:
        # Prepare inputs
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Disable gradient calculation for teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        # Forward pass through the student model
        student_outputs = student_model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits
        # Compute the distillation loss
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} completed with loss: {loss.item()}")

    # Evaluate the teacher model
    teacher_accuracy, teacher_precision, teacher_recall, teacher_f1 = evaluate_model(teacher_model, test_dataloader, device)
    print(f"Teacher (test) - Accuracy: {teacher_accuracy:.4f}, Precision: {teacher_precision:.4f}, Recall: {teacher_recall:.4f}, F1 Score: {teacher_f1:.4f}")

    # Evaluate the student model
    student_accuracy, student_precision, student_recall, student_f1 = evaluate_model(student_model, test_dataloader, device)
    print(f"Student (test) - Accuracy: {student_accuracy:.4f}, Precision: {student_precision:.4f}, Recall: {student_recall:.4f}, F1 Score: {student_f1:.4f}")
    print("\n")

    # put student model back into train mode
    student_model.train()

In [None]:
# Upload the distilled student to the Hugging Face Hub under this repo id.
# NOTE(review): the target namespace is hard-coded; presumably this needs
# write access to that account — change it to your own before running.
student_model.push_to_hub("shawhin/bert-phishing-classifier_student")