In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load Teacher (Large Model)
# The tokenizer is the teacher's; the same token ids are fed to both models below.
teacher_model_name = "ZySec-AI/SecurityLLM"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name)

# The teacher only provides soft targets: put it in inference mode and freeze it
# so no gradients are ever tracked for its weights.
teacher_model.eval()
teacher_model.requires_grad_(False)

# Load Student (Smaller Model, e.g., DistilGPT2)
student_model_name = "distilgpt2"
student_model = AutoModelForCausalLM.from_pretrained(student_model_name)

# Distillation compares student and teacher logits element-wise, so both models
# must emit the same number of vocabulary logits, and the student must be able to
# embed every id the teacher's tokenizer can produce.
# NOTE(review): after resizing, the student's embedding rows still carry their
# original pretrained values, which are not aligned with the teacher tokenizer's
# tokens -- the student has to relearn them during distillation.
teacher_vocab_size = teacher_model.config.vocab_size
if student_model.config.vocab_size != teacher_vocab_size:
    student_model.resize_token_embeddings(teacher_vocab_size)


In [None]:
import torch
import torch.nn.functional as F

def distillation_loss(student_outputs, teacher_outputs, T=2.0, alpha=0.5,
                      labels=None, attention_mask=None, ignore_index=-100):
    """Combine a soft-target (KL) loss and a hard-target (cross-entropy) loss.

    Args:
        student_outputs: Object exposing ``.logits`` for the student model.
        teacher_outputs: Object exposing ``.logits`` for the teacher model. Its
            logits must have exactly the same shape as the student's.
        T: Softmax temperature applied to both sets of logits for the soft term.
        alpha: Weight of the soft term; the hard term is weighted ``1 - alpha``.
        labels: Integer token ids used as hard targets, with the same shape as
            the logits minus the vocabulary axis. When omitted, falls back to
            ``teacher_outputs.input_ids`` if that attribute exists (the contract
            of the earlier version of this function).
        attention_mask: Optional mask of the same shape as ``labels``; positions
            where it is 0 are excluded from both terms. When omitted, positions
            whose label equals ``ignore_index`` are excluded instead.
        ignore_index: Label value marking positions to skip.

    Returns:
        Scalar tensor ``alpha * soft_loss + (1 - alpha) * hard_loss``.

    Raises:
        ValueError: If the logits shapes differ, or if no hard targets are
            available while ``alpha != 1``.
    """
    student_logits = student_outputs.logits
    teacher_logits = teacher_outputs.logits
    if student_logits.shape != teacher_logits.shape:
        raise ValueError(
            "Student and teacher logits must have the same shape, got "
            f"{tuple(student_logits.shape)} and {tuple(teacher_logits.shape)}; "
            "both models must share one tokenizer/vocabulary."
        )
    vocab_size = student_logits.size(-1)

    # Model outputs do not carry the input ids, so prefer the explicit argument
    # and only fall back to the legacy attribute when it is actually present.
    if labels is None:
        labels = getattr(teacher_outputs, "input_ids", None)

    # Boolean mask of the positions that contribute to the loss.
    if attention_mask is not None:
        keep = attention_mask.to(torch.bool)
    elif labels is not None:
        keep = labels != ignore_index
    else:
        keep = torch.ones(student_logits.shape[:-1], dtype=torch.bool,
                          device=student_logits.device)

    # Soft Target Loss (KL Divergence), averaged per kept position rather than
    # per batch row so its scale does not grow with the sequence length.
    token_kl = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                        F.softmax(teacher_logits / T, dim=-1),
                        reduction='none').sum(dim=-1)
    keep_weights = keep.to(token_kl.dtype)
    soft_loss = (token_kl * keep_weights).sum() / keep_weights.sum().clamp(min=1.0)
    # Gradients of the tempered KL shrink by 1/T^2; rescale so the soft and hard
    # terms stay comparable when T changes.
    soft_loss = soft_loss * (T ** 2)

    if labels is None:
        if alpha != 1:
            raise ValueError(
                "Hard targets are required when alpha != 1: pass `labels` "
                "(typically the batch's input_ids)."
            )
        return alpha * soft_loss

    # Hard Target Loss (Cross-Entropy) on the kept positions only.
    targets = labels.masked_fill(~keep, ignore_index)
    hard_logits = student_logits
    if student_logits.dim() == 3:
        # Causal LM: the logits at position t score token t + 1, so drop the last
        # prediction and the first target to line them up.
        hard_logits = student_logits[:, :-1, :]
        targets = targets[:, 1:]
    flat_targets = targets.reshape(-1)
    num_targets = (flat_targets != ignore_index).sum().clamp(min=1)
    hard_loss = F.cross_entropy(hard_logits.reshape(-1, vocab_size), flat_targets,
                                ignore_index=ignore_index,
                                reduction='sum') / num_targets

    return alpha * soft_loss + (1 - alpha) * hard_loss


In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset

# Load your dataset (using a small dataset for illustration)
# "wikitext" ships several configurations, so one has to be named explicitly.
raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")  # Example dataset, use your own data

# Rows whose text is blank would tokenize to padding only; drop them up front.
dataset = raw_dataset.filter(lambda example: example["text"].strip() != "")

# padding="max_length" needs a pad token; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Tokenize the data
def tokenize_function(examples):
    """Tokenize a batch of rows, padding/truncating each text to 512 tokens.

    Args:
        examples: Batch mapping provided by ``Dataset.map(batched=True)``; only
            its ``"text"`` entry is read.

    Returns:
        The tokenizer's encoding for the batch. Tensor conversion is left to the
        dataset format set below, since ``map`` stores plain lists anyway.
    """
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

# Remove the raw columns (e.g. the "text" strings) so every batch field is a
# tensor, and have the dataset return torch tensors so the DataLoader collates
# them into (batch, seq_len) tensors.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
).with_format("torch")
train_dataloader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)

# Student optimizer: AdamW is built from the student's parameters only, so the
# teacher's weights are never stepped.
optimizer = torch.optim.AdamW(params=student_model.parameters(), lr=5e-5)

# Device placement: use CUDA when it is available, otherwise stay on the CPU,
# and move both models there before training starts.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
student_model.to(device)
teacher_model.to(device)

for epoch in range(3):  # Number of epochs
    student_model.train()
    teacher_model.eval()  # Teacher model is not updated, only used for generating soft targets

    total_loss = 0.0
    num_batches = 0
    for batch in train_dataloader:
        # Only tensor fields can be moved to the device and passed to the models;
        # anything else in the batch (e.g. raw strings) is skipped.
        inputs = {key: value.to(device) for key, value in batch.items() if torch.is_tensor(value)}

        # Forward pass through Teacher Model (get logits)
        with torch.no_grad():
            teacher_outputs = teacher_model(**inputs)

        # distillation_loss reads its hard targets from teacher_outputs.input_ids,
        # which a model output does not carry, so attach them here. Padding
        # positions are set to -100, the value cross-entropy ignores by default.
        hard_targets = inputs["input_ids"]
        if "attention_mask" in inputs:
            hard_targets = hard_targets.masked_fill(inputs["attention_mask"] == 0, -100)
        teacher_outputs.input_ids = hard_targets

        # Forward pass through Student Model
        student_outputs = student_model(**inputs)

        # Compute the distillation loss
        loss = distillation_loss(student_outputs, teacher_outputs, T=2.0, alpha=0.5)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually processed; max() guards an empty loader.
    print(f"Epoch {epoch + 1}, Loss: {total_loss / max(num_batches, 1)}")


In [None]:
# Write the distilled student and the tokenizer it was trained with into the
# same directory.
output_dir = "distilled_SecurityLLM"
for artifact in (student_model, tokenizer):
    artifact.save_pretrained(output_dir)
