# Set Up Environment and Dependencies
Install the required packages (including torch, transformers, and unsloth) and select the compute device: the GPU (a T4 on Colab) when available, otherwise the CPU.

In [None]:
# Install required packages
!pip install torch transformers unsloth

# Check if GPU is available
import torch
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

# Import Required Libraries and Code
Import necessary libraries and copy over the TwoModelTrainer code from TSL-Training-Single-Dynamic.py.

In [None]:
# Import Required Libraries and Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from typing import List, Dict, Tuple
import logging
import wandb
from trl import SFTTrainer

# Module-level logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class PromptResponseDataset(Dataset):
    """Dataset of prompt/target text pairs, tokenized lazily on access.

    Each element of ``data`` must be a mapping with ``'prompt'`` and ``'target'``
    strings. Both are padded and truncated to ``max_length`` tokens, so every item
    has fixed-length 1-D tensors.
    """

    def __init__(self, data: List[Dict[str, str]], tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def _encode(self, text: str) -> Dict[str, torch.Tensor]:
        """Tokenize one string to ``max_length`` tokens and drop the batch dimension."""
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # squeeze(0) removes only the leading batch dim; a bare squeeze() would also
        # collapse the sequence dim when max_length == 1, yielding 0-d tensors.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return prompt/target token ids and attention masks for item ``idx``."""
        item = self.data[idx]
        prompt = self._encode(item['prompt'])
        target = self._encode(item['target'])
        return {
            'prompt_ids': prompt['input_ids'],
            'prompt_mask': prompt['attention_mask'],
            'target_ids': target['input_ids'],
            'target_mask': target['attention_mask']
        }

class TwoModelTrainer:
    """Fine-tune a "dynamic" causal LM through a frozen "static" causal LM.

    The dynamic model reads the prompt. Its greedy (argmax) token at every position is
    fed to the static model, and the loss compares the static model's logits with the
    target tokens. Only the dynamic model's parameters are optimized.
    """

    def __init__(
        self,
        static_model_name: str,
        dynamic_model_name: str,
        tokenizer_name: str,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Load the tokenizer and both models onto ``device`` and freeze the static model.

        Args:
            static_model_name: Model id/path of the frozen static model.
            dynamic_model_name: Model id/path of the trainable dynamic model.
            tokenizer_name: Tokenizer id/path (saved alongside the dynamic model).
            device: Device both models are moved to.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Initialize models
        logger.info(f"Loading static model {static_model_name}")
        self.static_model = AutoModelForCausalLM.from_pretrained(static_model_name).to(device)
        self.static_model.eval()  # never trained, so keep it in evaluation mode

        logger.info(f"Loading dynamic model {dynamic_model_name}")
        self.dynamic_model = AutoModelForCausalLM.from_pretrained(dynamic_model_name).to(device)

        # Freeze the static model: gradients flow *through* it, never *into* its weights.
        for param in self.static_model.parameters():
            param.requires_grad = False

    def _dynamic_then_static(
        self,
        prompt_ids: torch.Tensor,
        prompt_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the dynamic model on the prompt, then the static model on its greedy tokens.

        ``argmax`` is not differentiable, so feeding token ids to the static model
        would leave the loss with no path back to the dynamic model. Instead the
        static model receives input embeddings built with a straight-through
        estimator:
        - forward: exactly the embeddings of the argmax tokens (same values as
          passing the ids);
        - backward: gradients flow through the softmax-weighted embedding.

        When autograd is disabled (e.g. in ``evaluate``), only the plain lookup is done.

        Returns:
            ``(dynamic_logits, static_logits)``.

        Raises:
            ValueError: if the dynamic model's vocabulary size differs from the static
                model's embedding table.
        """
        dynamic_logits = self.dynamic_model(
            input_ids=prompt_ids,
            attention_mask=prompt_mask
        ).logits

        embedding = self.static_model.get_input_embeddings()
        if dynamic_logits.size(-1) != embedding.weight.size(0):
            raise ValueError(
                "Dynamic and static models must share a vocabulary: got "
                f"{dynamic_logits.size(-1)} dynamic logits vs "
                f"{embedding.weight.size(0)} static embeddings"
            )

        inputs_embeds = embedding(torch.argmax(dynamic_logits, dim=-1))
        if torch.is_grad_enabled():
            probs = torch.softmax(dynamic_logits.float(), dim=-1).to(embedding.weight.dtype)
            soft_embeds = probs @ embedding.weight
            # (soft - soft.detach()) is exactly zero in the forward pass, so the static
            # model sees the argmax embeddings; only the gradient comes from `soft`.
            inputs_embeds = inputs_embeds + (soft_embeds - soft_embeds.detach())

        static_logits = self.static_model(
            inputs_embeds=inputs_embeds,
            attention_mask=prompt_mask
        ).logits
        return dynamic_logits, static_logits

    def compute_loss(
        self,
        dynamic_output: torch.Tensor,
        static_output: torch.Tensor,
        target_ids: torch.Tensor,
        target_mask: torch.Tensor
    ) -> torch.Tensor:
        """Cross-entropy between the static model's logits and the target tokens.

        Position ``i`` of ``static_output`` is scored against ``target_ids[:, i]``, so
        both must have the same sequence length. Target positions where
        ``target_mask == 0`` (padding) are ignored.

        Args:
            dynamic_output: Dynamic model logits. Unused for now; kept for extra loss
                terms such as a KL term between dynamic and static outputs.
            static_output: Static model logits, assumed ``(batch, seq, vocab)``.
            target_ids: Target token ids, ``(batch, seq)``.
            target_mask: 1 for real target tokens, 0 for padding.

        Returns:
            Scalar mean loss over the non-padded target positions.
        """
        # Mask padding with the attention mask rather than pad_token_id. Some
        # tokenizers have no pad token (pad_token_id is None, which cross_entropy
        # rejects), and others reuse a real token such as EOS as padding.
        labels = target_ids.masked_fill(target_mask == 0, -100)
        static_loss = nn.functional.cross_entropy(
            static_output.reshape(-1, static_output.size(-1)),
            labels.reshape(-1),
            ignore_index=-100
        )

        # Additional loss terms could be added here, such as:
        # - KL divergence between dynamic and static outputs
        # - Auxiliary objectives for the dynamic model
        # - Regularization terms

        return static_loss

    def train(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        batch_size: int = 8,
        num_epochs: int = 3,
        learning_rate: float = 5e-5,
        warmup_steps: int = 100,
        gradient_accumulation_steps: int = 1,
        max_grad_norm: float = 1.0,
        log_wandb: bool = True
    ):
        """Train the dynamic model through the frozen static model.

        Args:
            train_dataset: Dataset yielding dicts with ``prompt_ids``, ``prompt_mask``,
                ``target_ids`` and ``target_mask`` tensors.
            val_dataset: Validation dataset with the same item layout.
            batch_size: Batch size of both dataloaders.
            num_epochs: Number of passes over ``train_dataset``.
            learning_rate: Peak learning rate.
            warmup_steps: Optimizer steps of linear LR warmup. After the warmup the
                rate decays linearly to 0 by the last step.
            gradient_accumulation_steps: Batches per optimizer step. A trailing
                partial window at the end of each epoch is still applied.
            max_grad_norm: Gradient-norm clipping threshold.
            log_wandb: Whether to log metrics to Weights & Biases.
        """
        if log_wandb:
            wandb.init(project="two-model-finetuning")

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size
        )

        num_batches = len(train_dataloader)
        # Ceiling division: one optimizer step per (possibly partial) accumulation window.
        steps_per_epoch = -(-num_batches // gradient_accumulation_steps)
        total_steps = num_epochs * steps_per_epoch

        # Initialize optimizer and scheduler
        optimizer = optim.AdamW(
            self.dynamic_model.parameters(),
            lr=learning_rate
        )

        def lr_multiplier(step: int) -> float:
            # Linear warmup for `warmup_steps` optimizer steps, then linear decay to 0.
            if step < warmup_steps:
                return step / max(1, warmup_steps)
            return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)

        for epoch in range(num_epochs):
            self.dynamic_model.train()
            total_loss = 0.0

            for step, batch in enumerate(train_dataloader):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                dynamic_output, static_output = self._dynamic_then_static(
                    batch['prompt_ids'],
                    batch['prompt_mask']
                )

                loss = self.compute_loss(
                    dynamic_output,
                    static_output,
                    batch['target_ids'],
                    batch['target_mask']
                )
                # Track the unscaled loss so reported values don't shrink with accumulation.
                total_loss += loss.item()

                # Scale loss for gradient accumulation
                (loss / gradient_accumulation_steps).backward()

                # Step at the end of each accumulation window and after the last batch,
                # so leftover gradients are applied instead of leaking into the next epoch.
                window_done = (step + 1) % gradient_accumulation_steps == 0
                if window_done or step + 1 == num_batches:
                    torch.nn.utils.clip_grad_norm_(
                        self.dynamic_model.parameters(),
                        max_grad_norm
                    )
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                # Log metrics
                if log_wandb and step % 100 == 0:
                    wandb.log({
                        "loss": loss.item(),
                        "learning_rate": scheduler.get_last_lr()[0],
                    })

            # Validation loop
            val_loss = self.evaluate(val_dataloader)
            avg_train_loss = total_loss / max(1, num_batches)

            logger.info(f"Epoch {epoch + 1}/{num_epochs}")
            logger.info(f"Average training loss: {avg_train_loss}")
            logger.info(f"Validation loss: {val_loss}")

            if log_wandb:
                wandb.log({
                    "epoch": epoch,
                    "train_loss": avg_train_loss,
                    "val_loss": val_loss
                })

    def evaluate(self, dataloader: DataLoader) -> float:
        """Return the mean per-batch loss over ``dataloader`` (NaN if it is empty)."""
        self.dynamic_model.eval()
        if len(dataloader) == 0:
            return float("nan")
        total_loss = 0.0

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}

                dynamic_output, static_output = self._dynamic_then_static(
                    batch['prompt_ids'],
                    batch['prompt_mask']
                )

                loss = self.compute_loss(
                    dynamic_output,
                    static_output,
                    batch['target_ids'],
                    batch['target_mask']
                )

                total_loss += loss.item()

        return total_loss / len(dataloader)

    def save_model(self, path: str):
        """Save the dynamic model and the tokenizer to ``path``."""
        self.dynamic_model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

# Initialize Models and Tokenizer
Load the Llama 3.2 1B (static) and 3B (dynamic) models with Unsloth in 4-bit, and configure the tokenizer with the Llama 3.1 chat template.

In [None]:
# Initialize Models and Tokenizer

from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# Initialize tokenizer with the Llama 3.1 chat template
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")

# Initialize models
static_model_name = "unsloth/Llama-3.2-1B-Instruct"
dynamic_model_name = "unsloth/Llama-3.2-3B-Instruct"

# Loading options shared by both models (4-bit quantized, 2048-token max sequence).
load_options = dict(max_seq_length=2048, dtype=None, load_in_4bit=True)

# No .to(device) calls: bitsandbytes-quantized models are placed on their device
# while loading, and transformers rejects .to() on them with older bitsandbytes.
logger.info(f"Loading static model {static_model_name}")
static_model, _ = FastLanguageModel.from_pretrained(
    model_name=static_model_name,
    **load_options
)
static_model.eval()  # Set to evaluation mode

logger.info(f"Loading dynamic model {dynamic_model_name}")
dynamic_model, _ = FastLanguageModel.from_pretrained(
    model_name=dynamic_model_name,
    **load_options
)

# Freeze the static model
for param in static_model.parameters():
    param.requires_grad = False

# Create Dataset and Dataloaders
Load the FineTome-100k dataset, format the conversations with the chat template, split them 90/10 into training and validation sets, and create the dataloaders.

In [None]:
# Create Dataset and Dataloaders

# Load and prepare FineTome dataset
from datasets import load_dataset
dataset = load_dataset("mlabonne/FineTome-100k", split="train")

# Convert ShareGPT format to HF format
from unsloth.chat_templates import standardize_sharegpt
dataset = standardize_sharegpt(dataset)

# PromptResponseDataset pads to max_length, which requires a pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def split_conversation(convo):
    """Split a conversation at its last assistant reply.

    The training loop consumes prompt/target token tensors (PromptResponseDataset),
    not raw chat text. Returns ``{"prompt": ..., "target": ...}``:
    - prompt: the chat-formatted history before the reply, ending with the assistant
      header (``add_generation_prompt=True``);
    - target: the reply text.

    Returns ``None`` when there is no assistant reply with preceding context.
    """
    last_reply = None
    for index, message in enumerate(convo):
        if message["role"] == "assistant":
            last_reply = index
    if last_reply is None or last_reply == 0:
        return None
    prompt = tokenizer.apply_chat_template(
        convo[:last_reply], tokenize=False, add_generation_prompt=True
    )
    # If the template text already begins with the BOS token, drop it: tokenizing the
    # prompt in PromptResponseDataset may add BOS again.
    if tokenizer.bos_token and prompt.startswith(tokenizer.bos_token):
        prompt = prompt[len(tokenizer.bos_token):]
    return {"prompt": prompt, "target": convo[last_reply]["content"]}


pairs = [pair for pair in map(split_conversation, dataset["conversations"]) if pair is not None]
prompt_response_dataset = PromptResponseDataset(pairs, tokenizer)

# Split into train/val
train_size = int(0.9 * len(prompt_response_dataset))
val_size = len(prompt_response_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    prompt_response_dataset, [train_size, val_size]
)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8)

# Configure Two-Model Training
Set up the TwoModelTrainer with the static and dynamic models, configure the training parameters, then train and save the dynamic model.

In [None]:
# Configure Two-Model Training

# Training hyperparameters. They stay as notebook-level names because the manual
# "Training Loop" cell below reads them directly.
batch_size = 8
num_epochs = 3
learning_rate = 5e-5
warmup_steps = 100
gradient_accumulation_steps = 1
max_grad_norm = 1.0
log_wandb = True

# The 1B model is the frozen static model; the 3B model is the one being trained.
trainer = TwoModelTrainer(
    static_model_name="unsloth/Llama-3.2-1B-Instruct",
    dynamic_model_name="unsloth/Llama-3.2-3B-Instruct",
    tokenizer_name="unsloth/Llama-3.2-1B-Instruct",
    device=device,
)

hyperparameters = {
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "warmup_steps": warmup_steps,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "max_grad_norm": max_grad_norm,
    "log_wandb": log_wandb,
}
trainer.train(train_dataset=train_dataset, val_dataset=val_dataset, **hyperparameters)

# Persist the trained dynamic model together with its tokenizer.
trainer.save_model("path/to/save/model")

# Training Loop
Run the training loop with gradient accumulation and validation.

In [None]:
# Training Loop
#
# Manual counterpart of TwoModelTrainer.train for step-by-step experimentation. It
# reuses the hyperparameters and dataloaders from the cells above and continues
# training trainer.dynamic_model from its current state.

num_batches = len(train_dataloader)
# Ceiling division: one optimizer step per (possibly partial) accumulation window.
steps_per_epoch = -(-num_batches // gradient_accumulation_steps)
total_steps = num_epochs * steps_per_epoch

# Initialize optimizer and scheduler
optimizer = optim.AdamW(
    trainer.dynamic_model.parameters(),
    lr=learning_rate
)


def warmup_then_linear_decay(step):
    """LR multiplier: linear warmup over `warmup_steps` optimizer steps, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_then_linear_decay)

# The static model receives embeddings (not token ids) so gradients can reach the
# dynamic model through a straight-through estimator.
static_embeddings = trainer.static_model.get_input_embeddings()

for epoch in range(num_epochs):
    trainer.dynamic_model.train()
    total_loss = 0.0

    for step, batch in enumerate(train_dataloader):
        # Move batch to device
        batch = {k: v.to(trainer.device) for k, v in batch.items()}

        # Generate output from dynamic model
        dynamic_output = trainer.dynamic_model(
            input_ids=batch['prompt_ids'],
            attention_mask=batch['prompt_mask']
        ).logits

        # argmax is not differentiable, and running the static model under no_grad()
        # would cut the graph, leaving loss.backward() nothing to differentiate.
        # Straight-through estimator:
        # - forward: exactly the embeddings of the argmax tokens;
        # - backward: gradients flow through the softmax-weighted embedding.
        greedy_embeds = static_embeddings(torch.argmax(dynamic_output, dim=-1))
        probs = torch.softmax(dynamic_output.float(), dim=-1).to(static_embeddings.weight.dtype)
        soft_embeds = probs @ static_embeddings.weight
        static_inputs = greedy_embeds + (soft_embeds - soft_embeds.detach())

        # Pass dynamic output through static model. Autograd stays enabled so gradients
        # flow through it back to the dynamic model.
        static_output = trainer.static_model(
            inputs_embeds=static_inputs,
            attention_mask=batch['prompt_mask']
        ).logits

        # Compute loss
        loss = trainer.compute_loss(
            dynamic_output,
            static_output,
            batch['target_ids'],
            batch['target_mask']
        )
        # Track the unscaled loss so reported values don't shrink with accumulation.
        total_loss += loss.item()

        # Scale loss for gradient accumulation
        (loss / gradient_accumulation_steps).backward()

        # Step at the end of each accumulation window and after the last batch, so
        # leftover gradients are applied instead of leaking into the next epoch.
        if (step + 1) % gradient_accumulation_steps == 0 or step + 1 == num_batches:
            torch.nn.utils.clip_grad_norm_(
                trainer.dynamic_model.parameters(),
                max_grad_norm
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Log metrics
        if log_wandb and step % 100 == 0:
            wandb.log({
                "loss": loss.item(),
                "learning_rate": scheduler.get_last_lr()[0],
            })

    # Validation loop
    val_loss = trainer.evaluate(val_dataloader)
    avg_train_loss = total_loss / max(1, num_batches)

    logger.info(f"Epoch {epoch + 1}/{num_epochs}")
    logger.info(f"Average training loss: {avg_train_loss}")
    logger.info(f"Validation loss: {val_loss}")

    if log_wandb:
        wandb.log({
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": val_loss
        })

# Model Inference
Test the trained models with example prompts. TextStreamer is not used for output because it would have to be wired in between the two models.

In [None]:
# Model Inference
import torch

# Enable fast inference optimizations
FastLanguageModel.for_inference(trainer.dynamic_model)
FastLanguageModel.for_inference(trainer.static_model)

# Example prompts for testing; each one is an independent single-turn request.
example_prompts = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "user", "content": "Explain the theory of relativity."},
]
# Wrap each message in its own conversation. Passing the flat list would build ONE
# conversation containing two consecutive user turns (and yield one response).
conversations = [[message] for message in example_prompts]

# Batched generation needs a pad token, and left padding so every prompt ends
# exactly where its generated continuation begins.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# return_dict=True yields both input_ids and attention_mask. Without it,
# apply_chat_template(tokenize=True) returns a bare tensor that cannot be indexed by key.
inputs = tokenizer.apply_chat_template(
    conversations,
    tokenize=True,
    add_generation_prompt=True,
    padding=True,
    return_dict=True,
    return_tensors="pt"
).to(device)
prompt_length = inputs["input_ids"].shape[1]

# Generate with dynamic model (temperature/min_p only apply when sampling is enabled)
dynamic_outputs = trainer.dynamic_model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
    use_cache=True,
    do_sample=True,
    temperature=1.5,
    min_p=0.1,
)

# Pass through static model. Mask the prompts' left padding and attend to every
# generated position.
generated_length = dynamic_outputs.shape[1] - prompt_length
static_attention_mask = torch.cat(
    [
        inputs["attention_mask"],
        inputs["attention_mask"].new_ones((dynamic_outputs.shape[0], generated_length)),
    ],
    dim=1,
)
with torch.no_grad():
    static_outputs = trainer.static_model(
        input_ids=dynamic_outputs,
        attention_mask=static_attention_mask
    ).logits

# Get final responses: logits at position t predict token t+1, so the predictions for
# the generated region come from positions prompt_length-1 .. end-1.
final_tokens = torch.argmax(static_outputs, dim=-1)[:, prompt_length - 1:-1]
responses = tokenizer.batch_decode(final_tokens, skip_special_tokens=True)

# Display results
for i, (prompt, response) in enumerate(zip(example_prompts, responses), start=1):
    print(f"\nPrompt {i}: {prompt['content']}")
    print(f"Response {i}: {response}")

# Save and Export Models
Save the trained dynamic model and tokenizer locally, then export the model to GGUF format for upload to the Hugging Face Hub.

In [None]:
# Save and Export Models
import os

# Save the trained dynamic model and tokenizer locally.
# NOTE(review): TwoModelTrainer loads full models via AutoModelForCausalLM (no LoRA
# adapters), so despite the directory name this writes full weights — confirm intent.
trainer.dynamic_model.save_pretrained("TSL_lora_model")
tokenizer.save_pretrained("TSL_lora_model")

# Hugging Face write token (create one at https://huggingface.co/settings/tokens).
# Taken from an existing HF_TOKEN notebook variable or the HF_TOKEN environment
# variable. Previously HF_TOKEN was referenced without ever being defined (NameError).
HF_TOKEN = globals().get("HF_TOKEN") or os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    print("HF_TOKEN not set; skipping pushes to the Hugging Face Hub.")

# Export toggles (replace the former `if True:` / `if False:` blocks).
SAVE_GGUF_DEFAULT = True
SAVE_GGUF_F16 = True
SAVE_GGUF_Q4_K_M = False
PUSH_GGUF_MULTIPLE = False

# Save the model in GGUF format
if SAVE_GGUF_DEFAULT:
    trainer.dynamic_model.save_pretrained_gguf("TSL_model_gguf", tokenizer)
    if HF_TOKEN:
        trainer.dynamic_model.push_to_hub_gguf("Solshine/TSL_model_gguf", tokenizer, token=HF_TOKEN)

# Save the model in 16bit GGUF format
if SAVE_GGUF_F16:
    trainer.dynamic_model.save_pretrained_gguf("TSL_model_gguf_16bit", tokenizer, quantization_method="f16")
    if HF_TOKEN:
        trainer.dynamic_model.push_to_hub_gguf("Solshine/TSL_model_gguf_16bit", tokenizer, quantization_method="f16", token=HF_TOKEN)

# Save the model in q4_k_m GGUF format ("hf" is a placeholder: change it to your username)
if SAVE_GGUF_Q4_K_M:
    trainer.dynamic_model.save_pretrained_gguf("model_gguf_q4_k_m", tokenizer, quantization_method="q4_k_m")
    if HF_TOKEN:
        trainer.dynamic_model.push_to_hub_gguf("hf/model_gguf_q4_k_m", tokenizer, quantization_method="q4_k_m", token=HF_TOKEN)

# Push the model in multiple GGUF formats
if PUSH_GGUF_MULTIPLE and HF_TOKEN:
    trainer.dynamic_model.push_to_hub_gguf(
        "hf/model_gguf_multiple",  # Change hf to your username!
        tokenizer,
        quantization_method=["q4_k_m", "q8_0", "q5_k_m"],
        token=HF_TOKEN,
    )