In [1]:
from google.colab import drive

# Mount Google Drive at /content/drive; later cells read the teacher logits from it
# and save the trained model to it.
# force_remount=True is passed so that re-running this cell remounts the drive.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
# Install the extra dependencies into the kernel's own environment (%pip) and
# quietly (-q) so the install log does not flood the notebook.
# bitsandbytes and datasets are pinned to the versions this notebook last
# resolved; peft and optuna stay unpinned because their versions were not recorded.
%pip install -q bitsandbytes==0.43.3 datasets==3.0.0 peft optuna

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl (137.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.3
Collecting datasets
  Downloading datasets-3.0.0-py3-none-any.whl.metadata (19 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.0-py3-no



In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from tqdm.notebook import tqdm
from safetensors.torch import safe_open
from peft import get_peft_model, LoraConfig
import optuna
import copy

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training configuration.
num_epochs = 2
batch_size = 8

# Seed PyTorch so that adapter initialisation and dropout are repeatable
# after a kernel restart.
SEED = 42
torch.manual_seed(SEED)
print(device)

cuda


In [4]:
# The checkpoint is a gated repository: authenticate with Hugging Face (for
# example through an HF_TOKEN Colab secret) before running this cell, otherwise
# from_pretrained fails with the 401 error shown below.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, load_in_4bit=True)

# Reuse the EOS token for padding instead of adding a brand-new "<pad>" token.
# A newly added token receives an id past the end of the model's embedding
# table, so any padded batch would index out of range; resizing the embeddings
# instead would make the student's vocabulary differ from the teacher logits.
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct.
401 Client Error. (Request ID: Root=1-66e5c6d2-3e3563780e5f7bc90817dc7a;4cc01da4-bad4-46f4-aef8-ef8ad27dc7c1)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/resolve/main/config.json.
Access to model meta-llama/Meta-Llama-3-8B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.

In [None]:
# LoRA adapter configuration: adapters are attached to the four attention projections only.
lora_config = LoraConfig(
    r=4,  # rank of the low-rank update matrices
    lora_alpha=8,  # scaling factor for the LoRA updates
    lora_dropout=0.2,  # dropout probability used inside the LoRA layers
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]  # attention projection modules that receive adapters
)

In [None]:
# Prepare the model for LoRA training by wrapping it with the adapters described in lora_config.
# `model` is rebound to the wrapped model, so this cell is not idempotent:
# re-run the model-loading cell before running it a second time.
model = get_peft_model(model, lora_config)

In [None]:
# Load the GSM8K "main" training split and keep only its first 1000 examples.
data = load_dataset("openai/gsm8k", "main", split="train")
data = data.select(range(1000))

In [None]:
# Load the complete GSM8K "main" test split; it is passed to evaluate_model as validation data.
val_data = load_dataset("openai/gsm8k", "main", split="test")

In [None]:
# load teacher logits
def load_list_of_logits_safetensor(file_path):
    """Read every tensor stored in a safetensors file into a list.

    Args:
        file_path: Path of the safetensors file to open.

    Returns:
        A list with one PyTorch tensor per key, in the order in which the
        file's ``keys()`` reports them.
    """
    # NOTE(review): callers index the returned list by example position, so the
    # key order of the file has to match the dataset order — confirm how the
    # keys were named when the file was written.
    with safe_open(file_path, framework="pt") as handle:
        return [handle.get_tensor(name) for name in handle.keys()]

In [None]:
# made loss into an object for better implementation
class KnowledgeDistillationLoss(nn.Module):
    """Blend of a hard-label loss and a temperature-scaled distillation loss.

    The result is ``alpha * hard + (1 - alpha) * soft`` where

    * ``hard`` is the cross-entropy between the student logits and the labels;
    * ``soft`` is the reverse KL divergence KL(student || teacher) between the
      temperature-softened distributions, averaged over the first dimension
      ('batchmean') and multiplied by ``temperature ** 2``.

    Args:
        temperature: Divisor applied to both sets of logits before softmax.
        alpha: Weight of the hard term; the soft term gets ``1 - alpha``.
    """

    def __init__(self, temperature=1.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        """Return the combined loss.

        The softmax is taken over dimension 1, so both logit tensors are
        expected with the class/vocabulary axis at index 1.
        """
        temp = self.temperature

        # Hard term: student predictions against the true labels.
        hard_term = self.criterion(student_logits, labels)

        # Soft term: F.kl_div takes log-probabilities as its first argument and
        # probabilities as its target, so feeding the teacher's log-probs and
        # the student's probs yields KL(student || teacher).
        teacher_log_q = F.log_softmax(teacher_logits / temp, dim=1)
        student_p = F.softmax(student_logits / temp, dim=1)
        soft_term = F.kl_div(teacher_log_q, student_p, reduction='batchmean', log_target=False)
        soft_term = soft_term * (temp ** 2)

        return self.alpha * hard_term + (1.0 - self.alpha) * soft_term

In [None]:
def evaluate_model(model, validation_data, tokenizer, device):
    """Return the mean per-example cross-entropy of ``model`` on ``validation_data``.

    For every example the question is tokenized and fed to the model, the
    answer is tokenized as the label sequence, both are truncated to their
    common length, and the logits are compared with the answer tokens
    position by position.

    Args:
        model: Callable model whose output exposes ``.logits`` with shape
            [batch, sequence, vocab]. It is switched to eval mode and left there.
        validation_data: Sized iterable of mappings with 'question' and
            'answer' strings.
        tokenizer: Tokenizer applied to questions and answers.
        device: Device the tokenized tensors are moved to.

    Returns:
        The sum of the per-example losses divided by ``len(validation_data)``.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for example in validation_data:
            # Tokenize the input (question) and label (answer)
            inputs = tokenizer(example['question'], truncation=True, max_length=256, return_tensors="pt").to(device)
            labels = tokenizer(example['answer'], truncation=True, max_length=256, return_tensors="pt")['input_ids'].to(device)

            # Forward pass; logits have shape [batch_size, sequence_length, vocab_size].
            # (The former debug prints were removed: they tried to decode the
            # logits tensor as token ids, which aborted evaluation.)
            student_logits = model(**inputs).logits

            # Adjust sequence lengths to match
            seq_len = min(student_logits.size(1), labels.size(1))
            student_logits = student_logits[:, :seq_len, :]
            labels = labels[:, :seq_len]

            # Flatten logits and labels for loss computation. reshape (not view)
            # because the slices above are not guaranteed to be contiguous.
            student_logits = student_logits.reshape(-1, student_logits.size(-1))  # [total_tokens, vocab_size]
            labels = labels.reshape(-1)  # [total_tokens]

            loss = F.cross_entropy(student_logits, labels)
            total_loss += loss.item()

    # Return the average loss over the validation set
    return total_loss / len(validation_data)

In [None]:
# Load the precomputed teacher logits from Google Drive (needs the Drive mount from the first cell).
# NOTE(review): the file name mentions llama-3.1-8b while the student is Meta-Llama-3-8B-Instruct —
# confirm that both use the same vocabulary, since the logits are compared column by column.
teacher_logits_L = load_list_of_logits_safetensor('/content/drive/MyDrive/llama-3.1-8b-gsm8k-base-tensors.safetensors')

In [None]:
# Defined but not referenced by the grouping below.
no_decay = ["bias", "LayerNorm.weight"]

# Single pass over the model: parameters whose name contains "lora" are the
# adapter weights, everything else belongs to the base model.
lora_params, base_params = [], []
for param_name, param in model.named_parameters():
    bucket = lora_params if "lora" in param_name else base_params
    bucket.append(param)

# Create parameter groups
optimizer_grouped_parameters = [
    # No weight decay for base model params
    {"params": base_params, "weight_decay": 0.0},
    # Apply weight decay to LoRA params
    {"params": lora_params, "weight_decay": 1e-2},
]

In [None]:
optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=5e-6, momentum=0.9)

# train_model steps the scheduler once per batch in every epoch, and the
# training cell calls it with its default batch_size of 1, so the cosine curve
# has to span num_epochs * len(data) steps. With a shorter T_max the learning
# rate reaches its minimum early and then climbs back up.
total_training_steps = num_epochs * len(data)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_training_steps)

# NOTE(review): temperature and alpha look like tuned values (presumably from a
# hyperparameter search) — record where they came from.
kd_loss = KnowledgeDistillationLoss(temperature=5.942267335064758, alpha=0.6093348343631224)

In [None]:
def _teacher_sequence_logits(logits):
    """Return one example's teacher logits as a 2-D [seq_len, vocab] tensor."""
    if logits.dim() == 3 and logits.size(0) == 1:
        # Stored with a leading singleton batch axis: drop it.
        return logits[0]
    if logits.dim() == 2:
        return logits
    raise ValueError(
        f"teacher logits must be [seq_len, vocab] or [1, seq_len, vocab], got {tuple(logits.shape)}"
    )


def train_model(model, teacher_logits_L, data, tokenizer, optimizer, scheduler, kd_loss, num_epochs, device, batch_size=1):
    """Train ``model`` against precomputed teacher logits with a distillation loss.

    Each batch tokenizes the questions as model input and the answers as
    labels. Student logits, labels and teacher logits are truncated to their
    common sequence length, flattened to [batch * seq_len, ...] and passed to
    ``kd_loss``. The optimizer and the scheduler are both stepped once per batch.

    Args:
        model: Model to train; its output must expose ``.logits`` shaped
            [batch, sequence, vocab].
        teacher_logits_L: Per-example teacher logits, indexed like ``data``.
            Each tensor is [seq_len, vocab] or [1, seq_len, vocab].
        data: Dataset supporting ``len`` and slicing into a mapping with
            'question' and 'answer' lists.
        tokenizer: Tokenizer applied to questions and answers.
        optimizer: Optimizer over the trainable parameters.
        scheduler: Learning-rate scheduler, stepped after every optimizer step.
        kd_loss: Callable ``(student_logits, teacher_logits, labels) -> loss``.
        num_epochs: Number of passes over ``data``.
        device: Device the batch tensors are moved to.
        batch_size: Number of examples per step.

    Returns:
        The trained model (the same object that was passed in).

    Raises:
        ValueError: If ``batch_size`` is not positive, if there are fewer
            teacher tensors than examples, or if a teacher tensor has an
            unsupported shape.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if len(teacher_logits_L) < len(data):
        raise ValueError(
            f"got {len(teacher_logits_L)} teacher logit tensors for {len(data)} training examples"
        )

    model.train()
    num_batches = (len(data) + batch_size - 1) // batch_size  # ceiling division

    for epoch in range(num_epochs):
        # Use tqdm to create a progress bar for the entire dataset
        progress_bar = tqdm(range(num_batches), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch')

        for batch_idx in progress_bar:
            # Determine the start and end indices for this batch
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(data))

            # Slicing the dataset yields a mapping of column name -> list of values
            batch = data[start_idx:end_idx]

            # Tokenize the input and label on the fly
            inputs = tokenizer(batch['question'], truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
            labels = tokenizer(batch['answer'], truncation=True, padding=True, max_length=256, return_tensors="pt")['input_ids'].to(device)

            # Forward pass for student model: [batch_size, sequence_length, vocab_size]
            student_logits = model(**inputs).logits

            # Teacher logits for the same examples, normalised to [seq_len, vocab]
            # so that dimension 0 really is the sequence length.
            teacher_batch = [_teacher_sequence_logits(t) for t in teacher_logits_L[start_idx:end_idx]]

            # Truncate everything to the shortest sequence in play
            seq_len = min(
                student_logits.size(1),
                labels.size(1),
                min(t.size(0) for t in teacher_batch),
            )
            student_logits = student_logits[:, :seq_len, :]
            labels = labels[:, :seq_len]
            teacher_logits = torch.stack([t[:seq_len] for t in teacher_batch], dim=0)

            # Flatten for the loss. reshape (not view): the slices above are
            # non-contiguous as soon as a batch of more than one row is truncated.
            vocab_size = student_logits.size(-1)
            student_logits = student_logits.reshape(-1, vocab_size)  # [batch_size * seq_len, vocab_size]
            labels = labels.reshape(-1)  # [batch_size * seq_len]
            teacher_logits = teacher_logits.reshape(-1, vocab_size)  # [batch_size * seq_len, vocab_size]

            # Compute the KD loss
            loss = kd_loss(student_logits, teacher_logits.to(device), labels)

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Step the scheduler
            scheduler.step()

            # Update the progress bar with the current loss
            progress_bar.set_postfix(loss=loss.item())

    return model

In [None]:
# Drop the leading axis of every teacher tensor, keeping only its last two
# dimensions; reshape raises unless that leading axis has size 1.
teacher_logits_L_reshaped = []
for teacher_tensor in teacher_logits_L:
    _lead, seq_len, vocab_size = teacher_tensor.shape
    teacher_logits_L_reshaped.append(teacher_tensor.reshape((seq_len, vocab_size)))

In [None]:

# Use the reshaped teacher logits built in the previous cell: each entry is
# [seq_len, vocab], the per-example layout train_model slices along dimension 0.
model = train_model(model, teacher_logits_L_reshaped, data, tokenizer, optimizer, scheduler, kd_loss, num_epochs, device)

In [None]:
# Inspect the shape of the first raw teacher tensor (the reshape cell unpacks it as three dimensions).
teacher_logits_L[0].shape

In [None]:
# Show the first training example in the sliced layout that train_model indexes by column name.
data[0:1]

In [None]:

# Average cross-entropy of the trained model over the validation split.
val = evaluate_model(model, val_data, tokenizer, device)
print(val)

In [None]:
from transformers import TrainingArguments, Trainer

In [None]:
# Save the trained model to Google Drive (needs the Drive mount from the first cell).
# NOTE(review): save_pretrained is normally given a directory, so "model.pt" here
# presumably becomes a folder name rather than a single file — confirm.
model.save_pretrained("/content/drive/MyDrive/finetuned_llama/model.pt")

In [None]:
inputs =