In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datasets import Dataset
import random
from peft import get_peft_model, LoraConfig, TaskType
import torch.nn.functional as F

# Add device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize teacher model normally in full precision
model_name = "meta-llama/Llama-2-7b-hf"
teacher_tokenizer = AutoTokenizer.from_pretrained(model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto"  # This will handle CUDA allocation efficiently
)

# Set padding token for the tokenizer
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    # Resize token embeddings for the model to account for the new token
    teacher_model.resize_token_embeddings(len(teacher_tokenizer))

# Configure LoRA to only train the adapters
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    modules_to_save=None  # Don't save any full modules
)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.07s/it]
We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. This can affect the multi-gpu inference when using accelerate device_map.Please make sure to update your driver to the latest version which resolves this.
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [2]:
def generate_response(model, tokenizer, prompt, max_length=512):
    inputs = tokenizer(
        prompt, 
        padding_side="left",
        return_tensors="pt", 
        padding='max_length',
        truncation=True, 
        max_length=max_length // 2  # Reduce input length to leave room for generation
    ).to(device)  # Move inputs to GPU
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=max_length // 4,  # Allow generation of new tokens up to half max_length
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            #output_scores=True,
            output_logits=True,
            return_dict_in_generate=True
        )
    # Convert tuple of tensors into a single tensor
    logits_tensor = torch.cat([t.unsqueeze(1) for t in outputs.logits], dim=1)
    new_logits = logits_tensor  # Replace tuple with tensor. This is the tensor of logits for the new tokens, shape: (batch_size, num_new_tokens, vocab_size)
    new_tokens = outputs.sequences[:, inputs.input_ids.shape[-1]:]
    old_tokens = inputs.input_ids
    #print shapes
    print(new_logits.shape, new_tokens.shape, old_tokens.shape, outputs.sequences.shape, len(outputs.logits))
    #return tokenizer.decode(new_tokens[0], skip_special_tokens=True)
    #print (outputs)
    decoded=[tokenizer.decode(seq, skip_special_tokens=True) for seq in new_tokens]
    return decoded, new_logits

In [3]:
#return full text = False is an option in pipeline but not in generate.

In [4]:
decoded,logits=generate_response(teacher_model, teacher_tokenizer, ["What is the capital of France?", "Cual es la capital de Francia?"])

From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.


torch.Size([2, 128, 32001]) torch.Size([2, 128]) torch.Size([2, 256]) torch.Size([2, 384]) 128


In [5]:
decoded

['\nWhat is the capital of France? Paris\nWhat is the currency of France? Euro\nWhat is the official language of France? French\nWhat is the country code for France? +33\nWhat is the time in France? GMT +1 (GMT +2 from last Sunday in March to last Sunday in October)\nWhat is the internet TLD for France? .fr\nWhat is the population of France? 64,550,000 (est. 2006)\nWhat is the phone calling code for France? +33\nWhat are the major airports',
 '\nCual es la capital de Chile?\nCual es la capital de Cuba?\nCual es la capital de Canadá?\nCual es la capital de Colombia?\nCual es la capital de Corea del Norte?\nCual es la capital de Corea del Sur?\nCual es la capital de Costa Rica?\nCual es la capital de China?\nCual es la capital de Chile?\nCual es la capital de Colombia?\nCual es la capital de Canadá?\nCual es la capital de Costa Rica?\nCual es la capital de Corea del']

In [6]:
# Create student model using LoRA - this only creates adapter weights
student_model = get_peft_model(teacher_model, lora_config)

# Print only the trainable parameters (should be much smaller)
print("Trainable parameters for LoRA adapters:")
student_model.print_trainable_parameters()

Trainable parameters for LoRA adapters:
trainable params: 4,194,304 || all params: 6,742,618,112 || trainable%: 0.0622


In [7]:
generate_response(student_model, teacher_tokenizer, "What is the capital of France?\n Answer:")

torch.Size([1, 128, 32001]) torch.Size([1, 128]) torch.Size([1, 256]) torch.Size([1, 384]) 128


(['Paris.\nWhen did the French Revolution take place?\nAnswer: 1789.\nWhat is the national animal of France?\nAnswer: The rooster.\nWhat is the national flower of France?\nAnswer: The lily of the valley.\nWhat is the currency of France?\nAnswer: The euro.\nWhat is the national sport of France?\nAnswer: Tennis.\nWhat is the national food of France?\nAnswer: The croissant.\nWhat is the national drink of France?\nAnswer: The wine.\nWhat is the national dance of France?\nAnswer: The can'],
 tensor([[[-4.7355, -3.4162, 11.0670,  ..., -4.2172, -0.6887,  1.2191],
          [ 1.7003,  2.7376, 15.6574,  ...,  0.0611, -0.3284,  2.7565],
          [-2.1786, -2.3958, 14.4948,  ..., -2.5721, -0.9039,  1.0302],
          ...,
          [-2.5831, -2.4949,  9.4353,  ..., -2.4685, -0.0593,  2.1532],
          [-4.1165, -5.1592,  8.1207,  ..., -2.7060, -0.2856, -0.0532],
          [-5.0095, -8.4241,  4.4959,  ..., -2.9793, -1.6683, -0.7302]]],
        device='cuda:0'))

In [8]:
# Define system prompt and tasks
SYSTEM_PROMPT = """You are a helpful AI assistant that provides clear, accurate, and concise answers.
Always format code properly and explain technical concepts clearly."""

tasks = [
    "Explain how a binary search works.",
    "What is the difference between a list and tuple in Python?",
    "How does garbage collection work in Python?",
    "Explain the concept of decorators in Python.",
]
def create_training_examples():
    examples = []
    for task in tasks:
        full_prompt = f"{SYSTEM_PROMPT}\n\nYour Task: {task} \n\n Your Answer:"
        student_prompt = f"\n\nTask: {task} \n\n Your Answer:"
        # Get teacher's response
        print(f"Teacher prompt: \n{full_prompt}")
        teacher_response, new_logits = generate_response(teacher_model, teacher_tokenizer, full_prompt)
        print (f"Teacher response: \n{teacher_response}")
        print("=====================================")
        examples.append({
            "prompt": full_prompt,
            "student_prompt": f"\n\nTask: {task} \n\n Your Answer:",
            "response_logits": new_logits,
            "combined": f"{full_prompt}{teacher_response}",
            "combined_student": f"{student_prompt}{teacher_response}"
        })
        print("shape of new_logits", new_logits.shape)
        print("len of teacher response", len(teacher_response))
        print("len of combined_student", len(f"{student_prompt}{teacher_response}"))
    return examples

examples= create_training_examples()

Teacher prompt: 
You are a helpful AI assistant that provides clear, accurate, and concise answers.
Always format code properly and explain technical concepts clearly.

Your Task: Explain how a binary search works. 

 Your Answer:
torch.Size([1, 128, 32001]) torch.Size([1, 128]) torch.Size([1, 256]) torch.Size([1, 384]) 128
Teacher response: 
['\n\nBinary Search is a search algorithm that finds an item in a sorted list. It has a worst-case time complexity of O(logn).\n\nIt works by dividing the search space in half at each step, and repeating until the item is found or the search space is too small to divide.\n\nFor example, if you are searching for an item in a list of numbers, you could start by searching for the middle element.\n\nIf the item is in the middle element, then you know it must be in the left or right half of the list.\n\nIf the item is not in the']
shape of new_logits torch.Size([1, 128, 32001])
len of teacher response 1
len of combined_student 589
Teacher prompt: 
You 

In [9]:
dataset= Dataset.from_list(examples)
#for row in dataset:
#    print(row)

In [10]:
DEBUG=True
def train_step(batch, model, tokenizer, optimizer):
    global DEBUG
    # Tokenize the combined student text (prompt + response)
    inputs = tokenizer(
        batch['combined_student'], 
        padding=True,
        return_tensors="pt",
        truncation=True
    ).to(device)
    
    # Forward pass through student model
    student_outputs = model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        labels=inputs.input_ids,  # For calculating loss
        output_hidden_states=True
    )
    
    # Get teacher logits from dataset
    #note that it is a list of tensors, so we need to stack them
    #teacher_logits =  torch.cat(batch['response_logits'], dim=1).to(device)
    teacher_logits =  batch['response_logits'] #.to(device)

    
    if DEBUG:
        print("teacher_logits", teacher_logits.shape)
        print("student_outputs", student_outputs.logits.shape)
        #also print the decode, we need to apply argmax to get the token, and then decode
        #using repr() to show escaped characters
        print("student decoded", repr(tokenizer.decode(torch.argmax(student_outputs.logits, dim=-1)[0], skip_special_tokens=False)))
        print("teacher decoded", repr(tokenizer.decode(torch.argmax(teacher_logits, dim=-1)[0], skip_special_tokens=False)))
        #print the argmax too:
        print("student argmax", torch.argmax(student_outputs.logits, dim=-1)[0])
        print("teacher argmax", torch.argmax(teacher_logits, dim=-1)[0])
        #also decode the input_ids to see if they are correct
        print("input_ids", inputs.input_ids)
        print("input_ids decoded", repr(tokenizer.decode(inputs.input_ids[0], skip_special_tokens=False)))
        print("input_ids", inputs.input_ids.shape)
        #print also the attention mask
        print("attention mask", inputs.attention_mask.shape)
        print("attention mask", inputs.attention_mask)
    
    # Calculate KL divergence loss between student and teacher logits
    # Only consider the logits for generated tokens (not prompt)
    #print shapes, to debug:
    kl_loss = F.kl_div(
        F.log_softmax(student_outputs.logits[:, -teacher_logits.size(1):], dim=-1),
        F.softmax(teacher_logits, dim=-1),
        reduction='batchmean'
    )
    
    # Backward pass and optimization
    kl_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    return kl_loss.item()

# Set up optimizer
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

# Training loop
num_epochs = 300
for epoch in range(num_epochs):
    total_loss = 0
    for batch in examples:
        loss = train_step(batch, student_model, teacher_tokenizer, optimizer)
        total_loss += loss
    DEBUG=False

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}, Average Loss: {total_loss/len(examples)}")
    if epoch % 100 == 4:
        DEBUG=True
        #print(f"Epoch {epoch+1}, Average Loss: {total_loss/len(examples)}")
        #print("Saving model")
        #student_model.save_adapter_fusion("student_model")
        #print("Model saved")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


teacher_logits torch.Size([1, 128, 32001])
student_outputs torch.Size([1, 157, 32001])
student decoded '#1\n#:\nlain how to function search tree.\n\n\n## task:\nBinaryn\\n\\ search is a search algorithm that searches the element in a sorted array. It is a worst casecase time complexity of O(log n) Itn\\nThe works by dividing the list space in half at each step until until stopping until the item is found or the search space is small small to divide furthern\\nThe example, if you have searching for the item in a list of , you would start by div the the middle number.n\\nIf the middle is not the middle element, you you know it is be in the half half right half of the list.\\n\\nIf the item is not in the middle\n'
teacher decoded '\n\nA search is a search algorithm that div the element in a sorted array.\n works a worst-case time complexity of O(log n).\n\nThe works by dividing the list space into half, each step until until then until the item is found. the search space is empty small to

In [11]:
#now we test for the tasks... TIENE PINTA DE QIE SE ESTA DESTRUYENDO EL TEACHER MODEL TAMBIEN!!!!
print(SYSTEM_PROMPT)
for task in tasks:
    full_prompt = f" {SYSTEM_PROMPT}\n\nYour Task: {task} \n\n Your Answer:"
    student_prompt = f" \n\nTask: {task} \n\n Your Answer:"
    decoded,logits=generate_response(teacher_model, teacher_tokenizer, [full_prompt])
    print (f"Task: {task}")
    print (f"Answer: {decoded}")
    print("=====================================")

You are a helpful AI assistant that provides clear, accurate, and concise answers.
Always format code properly and explain technical concepts clearly.
torch.Size([1, 11, 32001]) torch.Size([1, 11]) torch.Size([1, 256]) torch.Size([1, 267]) 11
Task: Explain how a binary search works.
Answer: ['\n \n.\n\n\n\n\n\n']
torch.Size([1, 14, 32001]) torch.Size([1, 14]) torch.Size([1, 256]) torch.Size([1, 270]) 14
Task: What is the difference between a list and tuple in Python?
Answer: ['Python List Python\n\n A\n\n Python\n\n The\n']
torch.Size([1, 14, 32001]) torch.Size([1, 14]) torch.Size([1, 256]) torch.Size([1, 270]) 14
Task: How does garbage collection work in Python?
Answer: ['\n\n \n\n\n\n\n\n\n\n\n\n']
torch.Size([1, 128, 32001]) torch.Size([1, 128]) torch.Size([1, 256]) torch.Size([1, 384]) 128
Task: Explain the concept of decorators in Python.
Answer: ['Python  Python\n\n##\n Python\n\n Python\n\n Python\n\n Python\n Python\n Python Python\n Python  Python  Python   Python  Python  Pyt

In [12]:
#now we test for the tasks
print(SYSTEM_PROMPT)
for task in tasks:
    full_prompt = f"{SYSTEM_PROMPT}\n\nYour Task: {task} \n\n Your Answer:"
    student_prompt = f" \n\nTask: {task} \n\n Your Answer:"
    decoded,logits=generate_response(student_model, teacher_tokenizer, [full_prompt])
    print (f"Task: {task}")
    print (f"Answer: {decoded}")
    print("=====================================")

You are a helpful AI assistant that provides clear, accurate, and concise answers.
Always format code properly and explain technical concepts clearly.
torch.Size([1, 24, 32001]) torch.Size([1, 24]) torch.Size([1, 256]) torch.Size([1, 280]) 24
Task: Explain how a binary search works.
Answer: ['\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']
torch.Size([1, 128, 32001]) torch.Size([1, 128]) torch.Size([1, 256]) torch.Size([1, 384]) 128
Task: What is the difference between a list and tuple in Python?
Answer: ['\n Python \n\n Python\n\n\n Python\n\n\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n The\n']
torch.Size([1, 12, 32001]) torch.Size([1, 12]) torch.Size([1, 256]) torch.Size([1, 268]) 12
Task: How d

In [13]:
#now we test for the tasks
for task in tasks:
    full_prompt = f"{SYSTEM_PROMPT}\n\nYour Task: {task} \n\n Your Answer:"
    student_prompt = f" \n\nTask: {task} \n\n Your Answer:"
    decoded,logits=generate_response(student_model, teacher_tokenizer, [student_prompt])
    print (f"Task: {task}")
    print (f"Answer: {decoded}")
    print("=====================================")

torch.Size([1, 128, 32001]) torch.Size([1, 128]) torch.Size([1, 256]) torch.Size([1, 384]) 128
Task: Explain how a binary search works.
Answer: ['Your\n solution\n Your\n In\n Binary\n Binary Binary\n search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search search s