In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm.notebook import tqdm
from safetensors.torch import safe_open
from torch.utils.data import DataLoader
import os
from safetensors.torch import save_file
from torch.nn.utils.rnn import pad_sequence
from peft import get_peft_model, LoraConfig
import optuna
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 1
batch_size = 8
# Seed torch's RNGs (CPU and CUDA) so LoRA initialisation, dropout and sampling are reproducible.
SEED = 42
torch.manual_seed(SEED)
print(device)

In [None]:
# Swap for "../Meta-Llama-3.1-8B" to train the base (non-instruct) model instead.
MODEL_PATH = "../Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# VERY IMPORTANT: use BF16 on all training tasks to reduce VRAM usage. Request it at load time;
# without torch_dtype the weights may be materialised in fp32 first and only cast afterwards.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device)
# Training pads batches, so a pad token is required. Reuse EOS instead of adding a new "<pad>"
# token: a new token would get an id past the end of the (un-resized) embedding matrix.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# LoRA adapter settings: low-rank updates on all four attention projections.
lora_config = LoraConfig(
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # modules that receive adapters
    r=4,               # rank of the low-rank update matrices
    lora_alpha=8,      # scaling numerator for the LoRA update
    lora_dropout=0.2,  # dropout on the LoRA branch
)

In [None]:
model = get_peft_model(model, peft_config=lora_config)  # attach the LoRA adapters from lora_config

In [None]:
# GSM8K "socratic" training split; each example has 'question' and 'answer' fields.
data = load_dataset(path="openai/gsm8k", name="socratic", split="train")

In [None]:
val_data = load_dataset(path="openai/gsm8k", name="socratic", split="test")  # held-out split for evaluation

In [None]:
def _natural_sort_key(key):
    """Sort key that compares embedded integers numerically, so 'logits_2' sorts before 'logits_10'."""
    import re

    # re.split with a capturing group alternates text/number/text/..., so odd indices are digit runs
    # and every position compares like-typed values.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", key))]


def load_list_of_logits_safetensor(file_path):
    """
    Load every tensor stored in a safetensors file and return them as a list.

    `safe_open(...).keys()` yields tensor names sorted as plain strings, so indexed names such as
    'logits_0', 'logits_1', ..., 'logits_10' would come back as 0, 1, 10, 11, ..., 2 and the list
    would no longer line up with the dataset. Names are therefore ordered with a natural sort
    (numeric runs compared as integers); for zero-padded or non-numeric names this matches the
    plain string order.

    Parameters:
    - file_path (str): path to the .safetensors file.

    Returns:
    - list[torch.Tensor]: the stored tensors, in natural-sorted key order.
    """
    with safe_open(file_path, framework="pt") as f:
        return [f.get_tensor(key) for key in sorted(f.keys(), key=_natural_sort_key)]

In [None]:
class KnowledgeDistillationLoss(nn.Module):
    """
    Distillation objective mixing a hard-label loss with a temperature-softened soft loss.

    loss = alpha * CE(student_logits, labels)
         + (1 - alpha) * T**2 * KL(student || teacher)

    Both distributions are softened by dividing the logits by T and normalising over dim 1.
    """

    def __init__(self, temperature=1.0, alpha=0.5):
        """
        Parameters:
        - temperature (float): divisor applied to both sets of logits before the KL term.
        - alpha (float): weight of the hard cross-entropy term; the soft term gets 1 - alpha.
        """
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """
        Parameters:
        - student_logits: student logits, class scores along dim 1.
        - teacher_logits: teacher logits with the same shape as student_logits.
        - labels: target class indices for the cross-entropy term.

        Returns:
        - Scalar tensor with the blended loss.
        """
        T = self.temperature
        hard_term = self.criterion(student_logits, labels)

        # F.kl_div(input, target) computes KL(target || exp(input)). Passing the teacher as `input`
        # and the student as `target` gives the reverse KL, KL(student || teacher). Swap which side
        # is log-softmaxed (and the argument order) to switch to forward KL.
        soft_term = F.kl_div(
            F.log_softmax(teacher_logits / T, dim=1),
            F.softmax(student_logits / T, dim=1),
            reduction='batchmean',
            log_target=False,
        ) * (T ** 2)  # T**2 keeps soft-term gradients on the same scale as the hard term

        return self.alpha * hard_term + (1.0 - self.alpha) * soft_term

In [None]:
def evaluate_model(model, validation_data, tokenizer, device):
    """
    Average per-example cross-entropy of `model` over `validation_data`.

    Each example is tokenized on its own (no padding). The model is fed only the question. Its
    logits are truncated to the shorter of the question / answer token lengths and scored against
    the answer token ids at the same positions.

    NOTE(review): logits produced from question positions are compared with answer tokens
    position-by-position; confirm this alignment is intended.

    Parameters:
    - model: callable as model(**inputs), returning an object with `.logits`; switched to eval mode.
    - validation_data: iterable of dicts with 'question' and 'answer' keys, supporting len().
    - tokenizer: tokenizer used for both question and answer.
    - device: device the tokenized tensors are moved to.

    Returns:
    - float: mean of the per-example losses (ZeroDivisionError if validation_data is empty).
    """
    model.eval()
    per_example_losses = []

    with torch.no_grad():
        for example in validation_data:
            question_enc = tokenizer(example['question'], truncation=True, max_length=256, return_tensors="pt").to(device)
            answer_ids = tokenizer(example['answer'], truncation=True, max_length=256, return_tensors="pt")['input_ids'].to(device)

            logits = model(**question_enc).logits
            n_tokens = min(logits.size(1), answer_ids.size(1))
            vocab = logits.size(-1)

            flat_logits = logits[:, :n_tokens, :].view(-1, vocab)
            flat_targets = answer_ids[:, :n_tokens].view(-1)
            per_example_losses.append(F.cross_entropy(flat_logits, flat_targets).item())

    return sum(per_example_losses) / len(validation_data)

In [None]:
teacher_logits_L = load_list_of_logits_safetensor(file_path='../llama-3.1-405b-gsm8k-base-tensors.safetensors')  # precomputed teacher logits (presumably one tensor per training example)

In [None]:
# Split the trainable parameters into LoRA adapter weights and everything else.
# Frozen parameters (requires_grad=False) never receive a gradient, so SGD would skip them
# anyway; they are left out instead of being handed to the optimizer.
lora_params = []
base_params = []

for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    (lora_params if "lora" in n else base_params).append(p)

# Create parameter groups, dropping any group that ended up empty.
optimizer_grouped_parameters = [
    group
    for group in (
        {"params": base_params, "weight_decay": 0.0},   # no weight decay for trainable non-LoRA params
        {"params": lora_params, "weight_decay": 1e-2},  # weight decay on LoRA params
    )
    if group["params"]
]

In [None]:
optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=5e-6, momentum=0.9)
# train_model steps the scheduler once per *batch*, so the cosine period must span every batch
# of every epoch: ceil(len(data) / batch_size) * num_epochs (at least 1).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=max(1, num_epochs * -(-len(data) // batch_size))
)
# Presumably values found by an earlier hyperparameter search.
kd_loss = KnowledgeDistillationLoss(temperature=5.942267335064758, alpha=0.6093348343631224)

In [None]:
def _split_batch(batch):
    """Return (questions, answers) from a dataset slice: a column dict (datasets.Dataset) or a list of row dicts."""
    from collections.abc import Mapping

    if isinstance(batch, Mapping):
        return list(batch['question']), list(batch['answer'])
    return [example['question'] for example in batch], [example['answer'] for example in batch]


def _attention_mask(encoding):
    """Attention mask of a tokenizer output, or all-ones when the tokenizer did not return one."""
    mask = encoding.get('attention_mask')
    return torch.ones_like(encoding['input_ids']) if mask is None else mask


def _pad_teacher_logits(chunk, device):
    """Right-pad per-example teacher logits into (batch, seq, vocab) plus a (batch, seq) bool validity mask.

    Each element is assumed to be (seq, vocab) or (1, seq, vocab) -- TODO confirm against the file that produced them.
    """
    seqs = [t.squeeze(0) if t.dim() == 3 and t.size(0) == 1 else t for t in chunk]
    padded = pad_sequence(seqs, batch_first=True).to(device)
    lengths = torch.tensor([s.size(0) for s in seqs], device=device)
    mask = torch.arange(padded.size(1), device=device).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask


def train_model(model, teacher_logits_L, data, tokenizer, optimizer, scheduler, kd_loss, num_epochs, device, batch_size=1):
    """
    Trains the model using the provided Knowledge Distillation loss with manual batch processing.

    Parameters:
    - model: The student model; called as model(**inputs) and expected to return an object with `.logits`.
    - teacher_logits_L: Precomputed teacher logits, one tensor per example of `data`, in the same order
      (a list of tensors or a stacked tensor).
    - data: The dataset of {'question', 'answer'} examples. Slicing may return a dict of columns
      (datasets.Dataset) or a list of example dicts.
    - tokenizer: Tokenizer for the model; it needs a pad token because batches are padded.
    - optimizer: The optimizer for the model.
    - scheduler: Learning rate scheduler, stepped once per batch.
    - kd_loss: Callable (student_logits, teacher_logits, labels) -> loss on flattened (tokens, vocab) logits.
      Only positions that are non-padding in the question, the answer and the teacher logits are passed.
    - num_epochs: Number of epochs to train.
    - device: The device (CPU or GPU) to use for training.
    - batch_size: Number of samples per batch.

    Returns:
    - Trained model (the same object that was passed in).

    Raises:
    - ValueError: if teacher_logits_L does not provide one entry per example of the batch.
    """
    model.train()
    num_batches = -(-len(data) // batch_size)  # ceil(len(data) / batch_size)

    for epoch in range(num_epochs):
        progress_bar = tqdm(range(num_batches), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

        for batch_idx in progress_bar:
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))

            questions, answers = _split_batch(data[start_idx:end_idx])
            teacher_chunk = teacher_logits_L[start_idx:end_idx]
            if len(teacher_chunk) != len(questions):
                raise ValueError(
                    f"teacher_logits_L has {len(teacher_logits_L)} entries but data has {len(data)}; "
                    "they must be aligned one-to-one"
                )

            # Tokenize the input and label on the fly
            inputs = tokenizer(questions, truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
            answer_enc = tokenizer(answers, truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
            labels = answer_enc['input_ids']

            student_logits = model(**inputs).logits  # (batch, seq, vocab)
            teacher_logits, teacher_mask = _pad_teacher_logits(teacher_chunk, device)

            seq_len = min(student_logits.size(1), labels.size(1), teacher_logits.size(1))

            # Keep only positions that are real (non-padding) in the question, the answer and the
            # teacher logits. Boolean-mask indexing flattens to (tokens, vocab) and, unlike .view()
            # on a length-truncated slice, works for any batch size.
            valid = (
                _attention_mask(inputs)[:, :seq_len].bool()
                & _attention_mask(answer_enc)[:, :seq_len].bool()
                & teacher_mask[:, :seq_len]
            )
            if not valid.any():
                continue

            # NOTE(review): student logits come from question positions but are scored against answer
            # token ids (and teacher logits) position by position; confirm this matches how the
            # teacher logits were produced.
            loss = kd_loss(
                student_logits[:, :seq_len, :][valid],
                teacher_logits[:, :seq_len, :][valid],
                labels[:, :seq_len][valid],
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix(loss=loss.item())

    return model

In [None]:
model = train_model(
    model, teacher_logits_L, data, tokenizer, optimizer, scheduler, kd_loss,
    num_epochs, device, batch_size=batch_size,
)
val = evaluate_model(model, val_data, tokenizer, device)  # mean validation cross-entropy
print(val)

In [None]:
# Snapshot only the trainable tensors, on the CPU. Deep-copying the whole state_dict would also
# duplicate every frozen base weight on the device, even though training never changes them.
# NOTE(review): this cell runs after the main training cell, so every trial starts from the
# already-trained adapter rather than a fresh one -- confirm that is intended.
_trainable_names = {name for name, param in model.named_parameters() if param.requires_grad}
initial_model_state = {
    key: value.detach().cpu().clone()
    for key, value in model.state_dict().items()
    if key in _trainable_names
}

def objective(trial):
    """
    Optuna objective: restore the snapshot, train with a sampled `alpha`, and return validation loss.

    The KD temperature is held fixed; only `alpha` is searched. Training uses the same batch size
    and per-batch cosine schedule length as the main training run, so trial results are comparable.
    """
    # Reset the trainable weights; strict=False because the snapshot holds only trainable tensors.
    model.load_state_dict(initial_model_state, strict=False)

    alpha = trial.suggest_float("alpha", 0.0, 1.0)
    temperature = 5.942267335064758

    kd_loss_fn = KnowledgeDistillationLoss(temperature=temperature, alpha=alpha).to(device)

    # Fresh optimizer/scheduler per trial so no momentum or schedule state leaks between trials.
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=5e-6, momentum=0.9, weight_decay=1e-4
    )
    # One scheduler step per batch inside train_model.
    total_steps = max(1, num_epochs * -(-len(data) // batch_size))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    trained_model = train_model(
        model=model,
        teacher_logits_L=teacher_logits_L,
        data=data,
        tokenizer=tokenizer,
        optimizer=optimizer,
        scheduler=scheduler,
        kd_loss=kd_loss_fn,
        num_epochs=num_epochs,
        device=device,
        batch_size=batch_size,
    )

    return evaluate_model(trained_model, val_data, tokenizer, device)

# Define the number of trials
n_trials = 35

# Create the Optuna study
study = optuna.create_study(direction="minimize")

# Add a progress bar for the study
with tqdm(total=n_trials, desc="Temp/Alpha Trials") as pbar:
    def update_pbar(study, trial):
        pbar.update(1)

    study.optimize(objective, n_trials=n_trials, callbacks=[update_pbar])

# Each trial leaves the model holding that trial's weights. Restore the pre-study snapshot so later
# cells (generation, saving) do not silently use whatever the last trial produced.
# strict=False accepts either a full or a partial (trainable-only) snapshot.
model.load_state_dict(initial_model_state, strict=False)

with open("output2.txt", "w") as file:
    file.write(f"Best alpha with temperature of {5.942267335064758}: {study.best_params['alpha']}\n")
    

In [None]:
# Only alpha is searched; the temperature is fixed inside `objective`, so it is absent from best_params.
print(f"For {num_epochs} ep, Best alpha: {study.best_params['alpha']}, Best temperature: {study.best_params.get('temperature', 5.942267335064758)}")

In [None]:
# Quick qualitative check: sample one completion from the fine-tuned model.
inputs = tokenizer("What is proof by induction?", return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    do_sample=True,
    top_p=0.9,
    max_new_tokens=3000,
)
output_answer = tokenizer.batch_decode(outputs)
output_answer

In [None]:
model.save_pretrained(save_directory='llama8b-LoRA-IS', max_shard_size="5GB")  # write the model to disk, shards capped at 5 GB