# LEMA Benchmark Notebook
This notebook was auto-generated to verify the LEMA framework on Kaggle GPUs.
It compares the VRAM usage of standard fine-tuning (PEFT) with that of LEMA (virtual memory).

In [None]:
!mkdir -p src/lema/core src/lema/engine src/lema/models src/lema/utils tasks

In [None]:
%%writefile src/lema/config.py
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from enum import Enum

class MemoryStrategy(Enum):
    """Selects the path layer weights take on their way to the compute device."""
    STREAMING = "streaming" # Disk -> RAM -> VRAM
    RESIDENT = "resident"   # RAM -> VRAM (No Disk offload for weights)

@dataclass
class LemaConfig:
    """
    Single bag of settings shared by LEMA training and inference code.

    Only ``model_name_or_path`` is required; every other field has a default.
    After construction ``gbi_path`` is always a string and ``strategy`` is
    always a ``MemoryStrategy`` member (string values are converted).
    """
    # --- Model ---
    model_name_or_path: str
    gbi_path: Optional[str] = None  # safetensors file used by the GBI; derived in __post_init__ when omitted

    # --- Hardware / memory ---
    device: str = "cuda"
    strategy: MemoryStrategy = MemoryStrategy.STREAMING
    ram_buffer_size: int = 2  # layers kept in RAM
    vram_buffer_size: int = 1  # layers kept in VRAM

    # --- LoRA ---
    use_lora: bool = True
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
    )

    # --- Training ---
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    max_seq_length: int = 512
    gradient_checkpointing: bool = False

    # --- Advanced ---
    dtype: str = "float16"  # one of: float16, bfloat16, float32
    attn_implementation: str = "eager"  # one of: eager, sdpa, flash_attention_2

    def __post_init__(self):
        """Fill in ``gbi_path`` when absent and coerce a string ``strategy`` to the enum."""
        if self.gbi_path is None:
            # A model path that already names a safetensors file doubles as the
            # GBI path; anything else falls back to the conventional file name.
            names_weight_file = self.model_name_or_path.endswith(".safetensors")
            self.gbi_path = self.model_name_or_path if names_weight_file else "model.safetensors"

        if isinstance(self.strategy, str):
            self.strategy = MemoryStrategy(self.strategy.lower())

    def to_dict(self) -> Dict[str, Any]:
        """Return the instance attributes as a dict, with Enum members replaced by their values."""
        serialised: Dict[str, Any] = {}
        for key, value in self.__dict__.items():
            serialised[key] = value.value if isinstance(value, Enum) else value
        return serialised


In [None]:
%%writefile src/lema/engine/trainer.py
import torch
import torch.nn.functional as F
import threading
from typing import Any, Optional, List, Union
from ..core.memory import TripleBufferManager
from ..models.base import LemaModelAdapter
from ..config import LemaConfig, MemoryStrategy

class LemaTrainer:
    """
    Runs training steps one layer at a time.

    Layer weights are obtained through a TripleBufferManager, which exposes two
    alternating VRAM slots; each layer module is built by the model adapter from
    the flat buffer in its slot, used, and then handed back to the adapter.
    """
    def __init__(self, 
                 config: LemaConfig,
                 model_adapter: LemaModelAdapter, 
                 gbi, 
                 lora_manager=None, 
                 optimizer=None):
        """
        Args:
            config: Supplies ``device``, ``strategy`` and ``gradient_checkpointing``.
            model_adapter: Builds and runs the per-layer modules.
            gbi: Weight index handed to the memory manager.
            lora_manager: Optional; forwarded to ``construct_layer_module``.
            optimizer: Optional; stepped and zeroed at the end of each ``train_step``.
        """
        
        self.config = config
        self.adapter = model_adapter
        self.gbi = gbi
        self.device = config.device
        self.strategy = config.strategy
        
        # Initialize Memory Manager with config strategy
        # Note: TripleBufferManager might need updates to accept config too, 
        # but for now we map config values to its expected args
        self.memory = TripleBufferManager(gbi, model_adapter, self.device, strategy=self.strategy)
        
        # Ordered layer descriptors; each entry is indexed by 'id' below.
        self.layers = self.adapter.get_layer_metadata()
        self.lora_manager = lora_manager
        self.optimizer = optimizer
        
    def train_step(self, inputs: Any, labels: Optional[torch.Tensor] = None):
        """
        Executes one forward pass and one backward pass.
        If labels are provided, computes CrossEntropyLoss.

        The forward pass runs every layer under ``torch.no_grad()`` and records a
        detached copy of each layer's input. The backward pass then walks the
        layers in reverse, re-runs each layer on its recorded input with autograd
        enabled, and backpropagates the gradient received from the layer after it.

        Args:
            inputs: Input to the first layer (a tensor, or a tuple whose first
                element is the tensor).
            labels: Optional targets for the shifted cross-entropy loss. Without
                labels, the mean of the last layer's output is used as a dummy loss.

        Returns:
            ``(logits, loss_val)``. ``logits`` is the output of the no-grad
            forward pass. ``loss_val`` is a Python float when labels were given,
            otherwise ``None``. If autograd is globally disabled, the backward
            pass is skipped and ``(logits, None)`` is returned.

        NOTE(review): ``config.gradient_accumulation_steps`` is not consulted
        here; the optimizer (if any) steps on every call.
        """
        boundary_activations: List[torch.Tensor] = []
        is_streaming = (self.strategy == MemoryStrategy.STREAMING)
        
        # --- FORWARD PASS ---
        # Prime the pipeline: first layer into VRAM slot 0 and, in streaming
        # mode, the second layer into RAM slot 1.
        if is_streaming:
            self.memory.prefetch_to_ram(self.layers[0]['id'], 0)
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0, ram_slot=0)
            if len(self.layers) > 1:
                self.memory.prefetch_to_ram(self.layers[1]['id'], 1)
        else:
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0)

        hidden_states = inputs
        
        for i, layer_meta in enumerate(self.layers):
            # Slots alternate 0/1 so the next layer can load while this one runs.
            slot = i % 2
            next_slot = (i + 1) % 2
            
            # Waits for the pending transfer into this slot (see get_vram_flat_buffer).
            flat_vram = self.memory.get_vram_flat_buffer(slot)
            
            disk_thread = None
            if i + 1 < len(self.layers):
                if is_streaming:
                    # Start moving layer i+1 to VRAM, then refill the RAM slot
                    # that held layer i with layer i+2 on a background thread.
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot, ram_slot=next_slot)
                    if i + 2 < len(self.layers):
                        disk_thread = threading.Thread(target=self.memory.prefetch_to_ram, args=(self.layers[i+2]['id'], slot))
                        disk_thread.start()
                else:
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot)
            
            layer_module = self.adapter.construct_layer_module(layer_meta['id'], flat_vram, self.lora_manager)
            
            # Store input for backward
            if isinstance(hidden_states, tuple): 
                 current_input = hidden_states[0].detach()
            else:
                current_input = hidden_states.detach()
            boundary_activations.append(current_input)
            
            # No graph is kept here; the backward pass recomputes each layer.
            with torch.no_grad():
                hidden_states = self.adapter.forward_layer(layer_module, hidden_states, gradient_checkpointing=False)

            # Make sure the RAM prefetch finished before the next iteration uses it.
            if disk_thread: disk_thread.join()
            # release_layer_module is optional on adapters.
            if hasattr(self.adapter, "release_layer_module"):
                self.adapter.release_layer_module(layer_module)
            del layer_module

        # Final Logits
        logits = hidden_states
        loss_val = None

        # --- BACKWARD PASS ---
        # Inference-only callers (autograd disabled) stop here.
        if not torch.is_grad_enabled():
            return logits, None

        # Re-prime the pipeline from the last layer, mirroring the forward setup.
        last_idx = len(self.layers) - 1
        if is_streaming:
            self.memory.prefetch_to_ram(self.layers[last_idx]['id'], 0)
            self.memory.async_transfer_to_vram(self.layers[last_idx]['id'], 0, ram_slot=0)
            if last_idx > 0:
                self.memory.prefetch_to_ram(self.layers[last_idx-1]['id'], 1)
        else:
            self.memory.async_transfer_to_vram(self.layers[last_idx]['id'], 0)
        
        # Gradient w.r.t. the input of the layer processed in the previous iteration.
        grad_output = None
        
        for i in range(last_idx, -1, -1):
            # Slot parity counts from the last layer, which was loaded into slot 0.
            slot = (last_idx - i) % 2
            next_slot = (last_idx - i + 1) % 2
            
            flat_vram = self.memory.get_vram_flat_buffer(slot)
            
            disk_thread = None
            if i - 1 >= 0:
                if is_streaming:
                    # Same pipelining as the forward loop, walking towards layer 0.
                    self.memory.async_transfer_to_vram(self.layers[i-1]['id'], next_slot, ram_slot=next_slot)
                    if i - 2 >= 0:
                        disk_thread = threading.Thread(target=self.memory.prefetch_to_ram, args=(self.layers[i-2]['id'], slot))
                        disk_thread.start()
                else:
                    self.memory.async_transfer_to_vram(self.layers[i-1]['id'], next_slot)
            
            layer_module = self.adapter.construct_layer_module(self.layers[i]['id'], flat_vram, self.lora_manager)
            layer_input = boundary_activations[i]
            # Only floating-point inputs can carry gradients; non-float inputs
            # are left as-is.
            if layer_input.dtype.is_floating_point:
                layer_input.requires_grad_(True)
            
            # Recompute this layer's forward with autograd enabled.
            output = self.adapter.forward_layer(layer_module, layer_input, gradient_checkpointing=self.config.gradient_checkpointing)
            
            if i == last_idx:
                if labels is not None:
                    # Real Causal LM Loss
                    # Shift so that tokens < n predict n
                    shift_logits = output[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    loss_val = loss.item()
                else:
                    loss = output.mean() # Dummy
                
                loss.backward()
                grad_output = layer_input.grad
            else:
                # Chain rule: push the gradient from the following layer through this one.
                if isinstance(output, tuple):
                    output[0].backward(grad_output)
                else:
                    output.backward(grad_output)
                grad_output = layer_input.grad
            
            if disk_thread: disk_thread.join()
            if hasattr(self.adapter, "release_layer_module"):
                self.adapter.release_layer_module(layer_module)
            del layer_module

        # Apply and clear the accumulated gradients once per step.
        if self.optimizer:
            self.optimizer.step()
            self.optimizer.zero_grad()
            
        return logits, loss_val

In [None]:
%%writefile src/lema/core/lora.py
import torch
import torch.nn as nn
import math
from typing import Dict, Tuple, Optional, List
try:
    from transformers.pytorch_utils import Conv1D
except ImportError:
    Conv1D = None

class LoRAWrapper(nn.Module):
    """
    Adds a low-rank (LoRA) correction on top of an existing Linear/Conv1D layer.

    The output is ``base_layer(x) + (x @ A^T @ B^T) * (alpha / rank)``. The
    ``lora_A`` / ``lora_B`` parameters are supplied by the caller rather than
    created here.
    """
    def __init__(self, base_layer: nn.Module, rank: int, alpha: float, lora_A: nn.Parameter, lora_B: nn.Parameter):
        super().__init__()
        self.rank, self.alpha = rank, alpha
        self.scaling = alpha / rank

        self.base_layer = base_layer
        self.lora_A = lora_A
        self.lora_B = lora_B

    def forward(self, x):
        """Return the base layer's output plus the scaled low-rank update."""
        base_out = self.base_layer(x)

        # Project through A, then through B; both are stored transposed
        # relative to the direction of the product.
        a_t = self.lora_A.transpose(0, 1)
        b_t = self.lora_B.transpose(0, 1)
        delta = (x @ a_t @ b_t) * self.scaling
        return base_out + delta

class LoRAManager:
    """
    Manages the lifecycle and storage of LoRA parameters.

    Parameters are created lazily, one ``A``/``B`` pair per
    ``"<layer_id>.<module path>"`` key, and are attached to (or swapped into)
    ``LoRAWrapper`` instances inside the modules passed to ``apply_lora`` /
    ``update_lora_params``.
    """
    def __init__(self, config: Dict, device="cuda"):
        """
        Args:
            config: Reads ``"r"`` (rank, default 8), ``"alpha"`` (default 16)
                and ``"target_modules"`` (default GPT-2 style names).
            device: Device on which new LoRA parameters are allocated.
        """
        self.rank = config.get("r", 8)
        self.alpha = config.get("alpha", 16)
        self.target_modules = config.get("target_modules", ["c_attn", "c_proj", "c_fc"])
        self.device = device

        # Store parameters: key -> {'A': Param, 'B': Param}
        self.params: Dict[str, Dict[str, nn.Parameter]] = {}

    @staticmethod
    def _feature_dims(layer: nn.Module) -> Tuple[int, int]:
        """
        Return ``(in_features, out_features)`` for a wrappable base layer.

        This is the single place that knows the weight layout of each layer
        type, so ``apply_lora`` and ``_collect_and_update_wrappers`` cannot
        disagree. Conv1D stores its weight as ``(in, out)``, the opposite of
        ``nn.Linear``; inferring its sizes from ``weight.shape[1], shape[0]``
        would swap them and produce LoRA matrices of the wrong shape.
        """
        if isinstance(layer, nn.Linear):
            return layer.in_features, layer.out_features
        if Conv1D is not None and isinstance(layer, Conv1D):
            return layer.weight.shape[0], layer.weight.shape[1]
        if hasattr(layer, "in_features") and hasattr(layer, "out_features"):
            return layer.in_features, layer.out_features
        # Generic fallback: assume a Linear-style (out, in) weight if one exists.
        if hasattr(layer, "weight"):
            return layer.weight.shape[1], layer.weight.shape[0]
        return 0, 0

    @staticmethod
    def _is_wrappable(layer: nn.Module) -> bool:
        """True for the layer types that may be wrapped in a LoRAWrapper."""
        return isinstance(layer, nn.Linear) or (Conv1D is not None and isinstance(layer, Conv1D))

    def _is_target(self, name: str) -> bool:
        """True if the child name matches (or ends with) any configured target."""
        return any(name == target or name.endswith(target) for target in self.target_modules)

    def _bind(self, wrapper: nn.Module, layer_id: int, full_name: str, in_features: int, out_features: int) -> None:
        """Point an existing wrapper at the parameters belonging to ``layer_id``."""
        params = self.get_or_init_params(layer_id, full_name, in_features, out_features)
        wrapper.lora_A = params['A']
        wrapper.lora_B = params['B']

    def _wrap(self, layer_id: int, child: nn.Module, full_name: str, in_features: int, out_features: int) -> nn.Module:
        """Build a LoRAWrapper around ``child`` using this layer's parameters."""
        params = self.get_or_init_params(layer_id, full_name, in_features, out_features)
        return LoRAWrapper(
            base_layer=child,
            rank=self.rank,
            alpha=self.alpha,
            lora_A=params['A'],
            lora_B=params['B'],
        )

    def get_or_init_params(self, layer_id: int, module_name: str, in_features: int, out_features: int) -> Dict[str, nn.Parameter]:
        """
        Return the ``{'A': ..., 'B': ...}`` pair for a module, creating it on first use.

        ``A`` has shape ``(rank, in_features)`` and is Kaiming-uniform
        initialised; ``B`` has shape ``(out_features, rank)`` and starts at
        zero, so a fresh adapter contributes nothing to the output.
        """
        key = f"{layer_id}.{module_name}"

        if key not in self.params:
            lora_A = torch.zeros((self.rank, in_features), device=self.device)
            nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))

            lora_B = torch.zeros((out_features, self.rank), device=self.device)
            nn.init.zeros_(lora_B)

            self.params[key] = {
                'A': nn.Parameter(lora_A, requires_grad=True),
                'B': nn.Parameter(lora_B, requires_grad=True)
            }

        return self.params[key]

    def apply_lora(self, layer_id: int, module: nn.Module, module_name_prefix: str = ""):
        """
        Recursively replaces Linear/Conv1D layers with LoRAWrapper if they match target_modules.

        Children that are already wrapped (and match a target) keep their
        wrapper and only have their parameters swapped for ``layer_id``'s.
        """
        for name, child in module.named_children():
            full_name = f"{module_name_prefix}.{name}" if module_name_prefix else name
            is_target = self._is_target(name)

            if isinstance(child, LoRAWrapper) and is_target:
                # Already wrapped, just swap parameters for the new layer
                in_features, out_features = self._feature_dims(child.base_layer)
                self._bind(child, layer_id, full_name, in_features, out_features)
                continue

            if is_target and self._is_wrappable(child):
                in_features, out_features = self._feature_dims(child)
                setattr(module, name, self._wrap(layer_id, child, full_name, in_features, out_features))
            else:
                self.apply_lora(layer_id, child, full_name)

    def update_lora_params(self, layer_id: int, module: nn.Module):
        """
        Efficiently updates LoRA parameters for a reused module.
        Uses cached wrapper list if available, otherwise traverses and builds cache.

        The cache is stored on the module itself as ``_lora_cache``, a list of
        ``(wrapper, full_name, in_features, out_features)`` tuples.
        """
        if not hasattr(module, "_lora_cache"):
            module._lora_cache = []
            # First time: traverse, wrap where needed, and record every wrapper.
            self._collect_and_update_wrappers(layer_id, module, "", module._lora_cache)
        else:
            # Fast path: Update parameters from cache
            for wrapper, name, in_f, out_f in module._lora_cache:
                self._bind(wrapper, layer_id, name, in_f, out_f)

    def _collect_and_update_wrappers(self, layer_id: int, module: nn.Module, prefix: str, cache: List):
        """
        Walk ``module``, wrapping target layers and re-binding existing wrappers,
        appending every wrapper encountered to ``cache``.
        """
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name

            if isinstance(child, LoRAWrapper):
                # Already wrapped (from previous usage or just now)
                in_f, out_f = self._feature_dims(child.base_layer)
                self._bind(child, layer_id, full_name, in_f, out_f)
                cache.append((child, full_name, in_f, out_f))
                continue

            if self._is_target(name) and self._is_wrappable(child):
                in_features, out_features = self._feature_dims(child)
                lora_layer = self._wrap(layer_id, child, full_name, in_features, out_features)
                setattr(module, name, lora_layer)
                cache.append((lora_layer, full_name, in_features, out_features))
            else:
                self._collect_and_update_wrappers(layer_id, child, full_name, cache)

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """
        Returns a list of all nn.Parameter objects managed by this manager,
        ordered A then B for each key in creation order.
        """
        all_params = []
        for p_dict in self.params.values():
            all_params.extend((p_dict['A'], p_dict['B']))
        return all_params

In [None]:
%%writefile src/lema/core/gbi.py
import torch
from safetensors import safe_open
from typing import Dict, List, Any, Optional
import os

class GlobalBinaryIndex:
    """
    GBI v0.4: Contiguous Block Access.
    Allows fetching a whole layer as a single byte-range if possible,
    but here we focus on providing tensors for contiguous packing.

    Wraps a safetensors file opened on the CPU and hands out tensors by name.
    """

    def __init__(self, model_path: str):
        """
        Open ``model_path`` and record its tensor names in ``self.keys``.

        Raises:
            FileNotFoundError: if ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        self.model_path = model_path
        # The handle stays open for the lifetime of this object.
        self.handle = safe_open(self.model_path, framework="pt", device="cpu")
        self.keys = list(self.handle.keys())

    def load_tensors(self, param_names: List[str], device: str = "cpu") -> Dict[str, torch.Tensor]:
        """
        Load the named tensors and return them keyed by name.

        Args:
            param_names: Tensor names as stored in the safetensors file.
            device: Device the returned tensors should live on. Tensors are
                read on the CPU and moved only when a non-CPU device is
                requested, so the default path does no extra copy.

        Returns:
            Dict mapping each requested name to its tensor, in request order.
        """
        target = torch.device(device)
        needs_move = target.type != "cpu"

        tensors = {}
        for name in param_names:
            tensor = self.handle.get_tensor(name)
            tensors[name] = tensor.to(target) if needs_move else tensor
        return tensors

In [None]:
%%writefile src/lema/core/memory.py
import torch
import threading
from typing import Dict, Optional, List, Tuple
from enum import Enum
import gc
from ..config import MemoryStrategy

class TripleBufferManager:
    """
    Unified Memory Manager supporting both Disk-Streaming and RAM-Residency.

    Two flat float32 buffers on the compute device ("VRAM slots") are
    pre-allocated at the size of the largest layer. In STREAMING mode two
    matching RAM slots are filled from disk on demand; in RESIDENT mode every
    layer is packed into its own RAM buffer up front.
    """
    def __init__(self, gbi, adapter, device="cuda", strategy=MemoryStrategy.STREAMING):
        """
        Args:
            gbi: Weight index; must expose ``handle.get_slice`` and ``load_tensors``.
            adapter: Model adapter providing layer metadata and parameter names.
            device: Compute device (string or ``torch.device``).
            strategy: ``MemoryStrategy.STREAMING`` or ``MemoryStrategy.RESIDENT``.
        """
        self.gbi = gbi
        self.adapter = adapter
        self.device = device
        self.strategy = strategy

        # str() lets a torch.device be passed as well as a plain string.
        self.is_cuda = str(self.device).startswith("cuda")
        # One copy stream per VRAM slot so transfers can overlap compute.
        self.transfer_streams = [torch.cuda.Stream() for _ in range(2)] if self.is_cuda else None

        self.layers_meta = self.adapter.get_layer_metadata()

        # Calculate max layer size for pre-allocating buffers
        self.max_params = self._calculate_max_params()

        # Pre-allocated VRAM slots (Double buffering)
        self.vram_flat_buffers = [
            torch.empty(self.max_params, device=self.device, dtype=torch.float32)
            for _ in range(2)
        ]

        # RAM Buffers
        if self.strategy == MemoryStrategy.RESIDENT:
            print(f"LEMA: Initializing RESIDENT strategy (Caching model in RAM)...")
            self.ram_flat_buffers: Dict[int, torch.Tensor] = {}
            self._initialize_full_ram_cache()
        else:
            print(f"LEMA: Initializing STREAMING strategy (Default)...")
            # In streaming mode, we only need 2 RAM slots for the pipeline
            self.ram_flat_buffers: List[torch.Tensor] = [
                self._new_ram_buffer(self.max_params) for _ in range(2)
            ]
            # Which layer each RAM slot currently holds (-1 = empty).
            self.ram_layer_ids = [-1, -1]

    def _new_ram_buffer(self, numel: int) -> torch.Tensor:
        """
        Allocate a flat float32 CPU buffer, pinned only when targeting CUDA.

        Pinning exists to speed up host-to-GPU copies; it is pointless on a
        CPU-only run and can fail on hosts without an accelerator, so both
        strategies go through this single guarded allocation.
        """
        buf = torch.empty(numel, device="cpu", dtype=torch.float32)
        return buf.pin_memory() if self.is_cuda else buf

    def _calculate_max_params(self) -> int:
        """Return the largest per-layer element count across all layers."""
        max_p = 0
        for layer in self.layers_meta:
            current_p = 0
            for name in self.adapter.get_param_names_for_layer(layer['id']):
                shape = self.gbi.handle.get_slice(name).get_shape()
                # get_shape() may return a plain sequence rather than a torch.Size.
                current_p += shape.numel() if hasattr(shape, 'numel') else torch.Size(shape).numel()
            max_p = max(max_p, current_p)
        return max_p

    def _initialize_full_ram_cache(self):
        """Pre-packs the entire model into RAM (pinned when the device is CUDA)."""
        for layer in self.layers_meta:
            self._pack_layer_to_ram(layer['id'], is_resident=True)

    def _pack_layer_to_ram(self, layer_id: int, slot: int = 0, is_resident: bool = False):
        """
        Helper to load a layer from disk and pack it into a flat RAM buffer.

        Parameters are concatenated in the order given by
        ``adapter.get_param_names_for_layer``. In resident mode a buffer sized
        exactly for the layer is created and stored under ``layer_id``;
        otherwise the shared RAM ``slot`` is overwritten and tagged with the
        layer id.
        """
        param_names = self.adapter.get_param_names_for_layer(layer_id)
        weights = self.gbi.load_tensors(param_names, device="cpu")

        if is_resident:
            total_el = sum(w.numel() for w in weights.values())
            buf = self._new_ram_buffer(total_el)
        else:
            buf = self.ram_flat_buffers[slot]

        offset = 0
        for name in param_names:
            w = weights[name]
            numel = w.numel()
            # copy_ converts to the buffer's float32 dtype if the source differs.
            buf[offset : offset + numel].copy_(w.view(-1))
            offset += numel

        if is_resident:
            self.ram_flat_buffers[layer_id] = buf
        else:
            self.ram_layer_ids[slot] = layer_id

    def prefetch_to_ram(self, layer_id: int, ram_slot: int):
        """Stage 1 (Streaming only): Load from Disk to RAM Slot."""
        if self.strategy == MemoryStrategy.RESIDENT:
            return # No-op for resident mode

        # Skip the disk read when the slot already holds this layer.
        if self.ram_layer_ids[ram_slot] == layer_id:
            return

        self._pack_layer_to_ram(layer_id, ram_slot, is_resident=False)

    def async_transfer_to_vram(self, layer_id: int, vram_slot: int, ram_slot: Optional[int] = None):
        """
        Stage 2: Async transfer to GPU.

        Raises:
            ValueError: in streaming mode when ``ram_slot`` is not given.
        """
        if self.strategy == MemoryStrategy.RESIDENT:
            src_buf = self.ram_flat_buffers[layer_id]
        else:
            if ram_slot is None:
                raise ValueError("ram_slot must be provided in streaming mode")
            src_buf = self.ram_flat_buffers[ram_slot]

        # The source may be smaller than the slot (resident buffers are exact-size).
        dest_view = self.vram_flat_buffers[vram_slot][:src_buf.numel()]

        if self.is_cuda and self.transfer_streams:
            # NOTE(review): the copy is not ordered against work already queued
            # on the compute stream that may still read this slot - confirm
            # whether a wait_stream is needed here.
            with torch.cuda.stream(self.transfer_streams[vram_slot]):
                dest_view.copy_(src_buf, non_blocking=True)
        else:
            # CPU or Synchronous copy
            dest_view.copy_(src_buf)

    def get_vram_flat_buffer(self, vram_slot: int) -> torch.Tensor:
        """Stage 3: Usage. Waits for the slot's transfer stream, then returns the buffer."""
        if self.is_cuda and self.transfer_streams:
            self.transfer_streams[vram_slot].synchronize()
        return self.vram_flat_buffers[vram_slot]

    def clear_vram_slot(self, vram_slot: int):
        """Placeholder; currently does nothing."""
        pass

In [None]:
%%writefile src/lema/models/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Optional
import torch
import torch.nn as nn

class LemaModelAdapter(ABC):
    """
    Abstract Base Class for LEMA Model Adapters.
    
    This class defines the interface that any model architecture must implement
    to be compatible with the LEMA (Layer-wise Efficient Memory Abstraction) framework.
    It bridges the gap between the raw binary weights managed by LEMA and the 
    PyTorch execution semantics.

    Adapters may additionally define ``release_layer_module(module)``; the
    trainer calls it when present but it is not part of this abstract interface.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the raw model configuration dictionary on ``self.config``."""
        self.config = config

    @abstractmethod
    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        """
        Returns a list of dictionaries, where each dictionary describes a logical "layer"
        or "block" in the model that LEMA should manage as a unit.

        Callers in this project read the ``'id'`` key of every entry and treat
        the list order as execution order.
        
        Returns:
            List[Dict]: e.g. [{'id': 0, 'name': 'transformer.h.0', 'inputs': [...], 'outputs': [...]}, ...]
        """
        pass

    @abstractmethod
    def construct_layer_module(self, layer_id: int, weights: Dict[str, torch.Tensor], lora_manager: Optional[Any] = None) -> nn.Module:
        """
        Constructs a PyTorch nn.Module for the specified layer using the provided weights.
        The weights will be on the target device (VRAM) when passed here.

        NOTE(review): the trainer passes the second argument positionally as a
        single flat tensor (the VRAM slot buffer), and the concrete adapters
        declare it as ``flat_buffer: Optional[torch.Tensor]``, not as a dict.
        The annotation here appears to predate that change - confirm and align.
        
        Args:
            layer_id (int): The index of the layer to construct.
            weights (Dict[str, torch.Tensor]): A dictionary mapping parameter names to tensors
                (see the note above about the flat-buffer form used in practice).
            lora_manager (Optional[Any]): The LoRAManager instance to apply adapters.
            
        Returns:
            nn.Module: The executable layer module.
        """
        pass

    @abstractmethod
    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        """
        Executes the forward pass for a single layer.
        
        Args:
            layer_module (nn.Module): The module constructed by construct_layer_module.
            inputs (Any): The input activations (tensor or tuple of tensors).
            **kwargs: Additional arguments (e.g., attention masks, rotary embeddings).
                The trainer passes ``gradient_checkpointing`` this way.
            
        Returns:
            Any: The output activations.
        """
        pass

    @abstractmethod
    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        """
        Returns the list of parameter names (as found in the safetensors file) 
        required for the specified layer.

        The order matters: the memory manager packs tensors into its flat
        buffers in exactly this order.
        
        Args:
            layer_id (int): Layer index.
            
        Returns:
            List[str]: List of parameter keys.
        """
        pass

    @property
    @abstractmethod
    def hidden_size(self) -> int:
        """Returns the model's hidden size for buffer allocation."""
        pass


In [None]:
%%writefile src/lema/models/gpt2.py
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Config
from typing import List, Dict, Any, Optional

from .base import LemaModelAdapter

class GPT2Adapter(LemaModelAdapter):
    """
    LEMA adapter for GPT-2.

    The model is exposed as ``n_layer + 2`` logical layers: id 0 is the
    embeddings, ids ``1..n_layer`` are the transformer blocks, and id
    ``n_layer + 1`` is the final LayerNorm plus LM head. Layer modules are
    recycled through a small pool and refilled from a flat weight buffer.
    """
    def __init__(self, config: Dict[str, Any]):
        """Build the HF config from ``config`` and set up the module pool."""
        super().__init__(config)
        self.hf_config = GPT2Config(**config)
        if getattr(self.hf_config, "_attn_implementation", None) is None:
            self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
        self.layer_pool: List[nn.Module] = []
        # id(module) -> [(param, numel, shape), ...] in flat-buffer order.
        self.param_mappings: Dict[int, List[tuple]] = {}
        self._max_pool_size = 5

    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        """Return the ordered layer descriptors (embeddings, blocks, head)."""
        n_layer = self.hf_config.n_layer
        layers = [{'id': 0, 'name': 'embeddings', 'type': 'embedding'}]
        layers.extend(
            {'id': i + 1, 'name': f'h.{i}', 'type': 'block', 'block_index': i}
            for i in range(n_layer)
        )
        layers.append({'id': n_layer + 1, 'name': 'head', 'type': 'head'})
        return layers

    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        """
        Return the safetensors keys for ``layer_id`` in packing order.

        Unknown layer ids yield an empty list.
        """
        if layer_id == 0:
            return ['transformer.wte.weight', 'transformer.wpe.weight']
        elif 1 <= layer_id <= self.hf_config.n_layer:
            idx = layer_id - 1
            prefix = f"transformer.h.{idx}"
            return [
                f"{prefix}.attn.c_attn.weight", f"{prefix}.attn.c_attn.bias",
                f"{prefix}.attn.c_proj.weight", f"{prefix}.attn.c_proj.bias",
                f"{prefix}.ln_1.weight", f"{prefix}.ln_1.bias",
                f"{prefix}.ln_2.weight", f"{prefix}.ln_2.bias",
                f"{prefix}.mlp.c_fc.weight", f"{prefix}.mlp.c_fc.bias",
                f"{prefix}.mlp.c_proj.weight", f"{prefix}.mlp.c_proj.bias",
            ]
        elif layer_id == self.hf_config.n_layer + 1:
            return ['transformer.ln_f.weight', 'transformer.ln_f.bias', 'lm_head.weight']
        return []

    def construct_layer_module(self, layer_id: int, flat_buffer: Optional[torch.Tensor] = None, lora_manager: Optional[Any] = None) -> nn.Module:
        """
        Return a module for ``layer_id`` loaded with the weights in ``flat_buffer``.

        A pooled module of the right type is reused when available; otherwise a
        new one is created on the buffer's device (CPU when no buffer is given).
        LoRA parameters are attached to transformer blocks only.

        Args:
            layer_id: Logical layer id (see ``get_layer_metadata``).
            flat_buffer: Flat tensor holding the layer's parameters concatenated
                in ``get_param_names_for_layer`` order; ``None`` skips loading.
            lora_manager: Optional manager whose ``update_lora_params`` is
                called for block layers.
        """
        device = flat_buffer.device if flat_buffer is not None else torch.device("cpu")
        head_id = self.hf_config.n_layer + 1

        # Which pooled module type can serve this layer (None: never reuse).
        if layer_id == 0:
            wanted = GPT2EmbeddingsLayer
        elif layer_id == head_id:
            wanted = GPT2HeadLayer
        elif 1 <= layer_id <= self.hf_config.n_layer:
            wanted = GPT2Block
        else:
            wanted = None

        module = None
        if wanted is not None:
            for pool_idx, pooled in enumerate(self.layer_pool):
                if isinstance(pooled, wanted):
                    module = self.layer_pool.pop(pool_idx)
                    break

        if module is None:
            if layer_id == 0:
                module = GPT2EmbeddingsLayer(self.hf_config)
            elif layer_id == head_id:
                module = GPT2HeadLayer(self.hf_config)
            else:
                module = GPT2Block(self.hf_config)

            # Initialization only
            module.to(device)

        if lora_manager and 1 <= layer_id <= self.hf_config.n_layer:
            lora_manager.update_lora_params(layer_id, module)

        # Built after LoRA wrapping so names resolve to the wrapped base layers.
        if id(module) not in self.param_mappings:
            self.param_mappings[id(module)] = self._create_mapping(layer_id, module)

        if flat_buffer is not None:
            offset = 0
            with torch.no_grad():
                for param, numel, shape in self.param_mappings[id(module)]:
                    param.data.copy_(flat_buffer[offset : offset + numel].view(shape), non_blocking=True)
                    offset += numel

        return module

    def _create_mapping(self, layer_id: int, module: nn.Module) -> List[tuple]:
        """
        Map each checkpoint parameter of ``layer_id`` to the module's parameter.

        Returns ``(param, numel, shape)`` tuples in flat-buffer order. For
        blocks, a name missing from the module is retried with ``.base_layer``
        inserted, which is where a LoRAWrapper keeps the original layer.
        """
        module_params = dict(module.named_parameters())
        head_id = self.hf_config.n_layer + 1
        block_prefix = f"transformer.h.{layer_id - 1}."

        mapping = []
        for full_name in self.get_param_names_for_layer(layer_id):
            if layer_id == 0:
                local_name = "wte.weight" if "wte" in full_name else "wpe.weight"
            elif layer_id == head_id:
                if "ln_f" in full_name:
                    local_name = "ln_f.weight" if "weight" in full_name else "ln_f.bias"
                else:
                    local_name = "head.weight"
            else:
                local_name = full_name[len(block_prefix):]
                if local_name not in module_params:
                    local_name = local_name.replace(".weight", ".base_layer.weight").replace(".bias", ".base_layer.bias")
            param = module_params[local_name]
            mapping.append((param, param.numel(), param.shape))
        return mapping

    def release_layer_module(self, module: nn.Module):
        """
        Return ``module`` to the pool, or forget it if the pool is full.

        When the module is not pooled, its cached mapping is dropped as well:
        the cache is keyed by ``id(module)``, and Python may hand the same id
        to a later module once this one is garbage collected, which would bind
        the new module to the old one's parameters.
        """
        if len(self.layer_pool) < self._max_pool_size:
            self.layer_pool.append(module)
        else:
            self.param_mappings.pop(id(module), None)

    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        """
        Run ``layer_module`` on the hidden states.

        A tuple input contributes only its first element. Blocks return a
        tuple, of which the first element is returned. ``kwargs`` are accepted
        for interface compatibility and are not used.
        """
        hidden_states = inputs[0] if isinstance(inputs, tuple) else inputs
        if isinstance(layer_module, GPT2Block):
            return layer_module(hidden_states)[0]
        return layer_module(hidden_states)

    @property
    def hidden_size(self) -> int:
        """Embedding width (``n_embd``) of the model."""
        return self.hf_config.n_embd

class GPT2EmbeddingsLayer(nn.Module):
    """Token plus learned position embeddings, the first logical layer of GPT-2."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, width)
        self.wpe = nn.Embedding(config.n_positions, width)

    def forward(self, input_ids):
        """Embed ``input_ids`` and add the embedding of each position along the last axis."""
        seq_len = input_ids.size(-1)
        positions = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = positions.unsqueeze(0)  # leading axis of 1 broadcasts over the batch
        token_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        return token_embeds + position_embeds

class GPT2HeadLayer(nn.Module):
    """Final LayerNorm followed by the bias-free LM head projection."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_f = nn.LayerNorm(width, eps=config.layer_norm_epsilon)
        self.head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, x):
        """Normalise ``x`` and project it to vocabulary size."""
        normed = self.ln_f(x)
        return self.head(normed)

In [None]:
%%writefile src/lema/models/llama.py
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaConfig, LlamaRotaryEmbedding
from typing import List, Dict, Any, Optional

from .base import LemaModelAdapter

class LlamaAdapter(LemaModelAdapter):
    """Adapter that presents a Llama checkpoint to LEMA as a flat list of layers.

    Layer ids are laid out as ``0`` for the token embeddings, ``1..N`` for the
    ``N = num_hidden_layers`` decoder blocks and ``N + 1`` for the final norm
    plus LM head. Released modules are kept in ``layer_pool`` (up to
    ``_max_pool_size``) and reused, with weights copied into their existing
    parameters from a flat buffer.
    """

    def __init__(self, config: Dict[str, Any]):
        """Build the HF config, a shared rotary embedding and empty caches.

        Args:
            config: Keyword arguments for ``LlamaConfig``. Its optional
                ``"attn_implementation"`` entry (default ``"eager"``) is used
                when the resulting config has no ``_attn_implementation``.
        """
        super().__init__(config)
        self.hf_config = LlamaConfig(**config)
        if getattr(self.hf_config, "_attn_implementation", None) is None:
            self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
        self.layer_pool: List[nn.Module] = []
        # Keyed by id(module); entries are evicted in release_layer_module
        # when the module is dropped, because ids can be reused afterwards.
        self.param_mappings: Dict[int, List[tuple]] = {}
        self._max_pool_size = 8

    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        """Describe every layer id as a dict with ``id``, ``name`` and ``type``.

        Decoder blocks additionally carry ``block_index`` (zero-based).
        """
        layers = []
        layers.append({'id': 0, 'name': 'embeddings', 'type': 'embedding'})
        for i in range(self.hf_config.num_hidden_layers):
            layers.append({'id': i + 1, 'name': f'layers.{i}', 'type': 'block', 'block_index': i})
        layers.append({'id': self.hf_config.num_hidden_layers + 1, 'name': 'head', 'type': 'head'})
        return layers

    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        """Return the checkpoint tensor names belonging to ``layer_id``.

        The order of the returned names is the order in which
        ``construct_layer_module`` consumes the flat buffer. Unknown layer ids
        yield an empty list.
        """
        if layer_id == 0:
            return ['model.embed_tokens.weight']
        elif 1 <= layer_id <= self.hf_config.num_hidden_layers:
            idx = layer_id - 1
            prefix = f"model.layers.{idx}"
            return [
                f"{prefix}.input_layernorm.weight",
                f"{prefix}.self_attn.q_proj.weight", f"{prefix}.self_attn.k_proj.weight",
                f"{prefix}.self_attn.v_proj.weight", f"{prefix}.self_attn.o_proj.weight",
                f"{prefix}.post_attention_layernorm.weight",
                f"{prefix}.mlp.gate_proj.weight", f"{prefix}.mlp.up_proj.weight",
                f"{prefix}.mlp.down_proj.weight",
            ]
        elif layer_id == self.hf_config.num_hidden_layers + 1:
            return ['model.norm.weight', 'lm_head.weight']
        return []

    def construct_layer_module(self, layer_id: int, flat_buffer: Optional[torch.Tensor] = None, lora_manager: Optional[Any] = None) -> nn.Module:
        """Return a module for ``layer_id``, reusing a pooled one when possible.

        Args:
            layer_id: Layer id as described in the class docstring.
            flat_buffer: Optional 1-D tensor holding the layer's parameters
                back to back in ``get_param_names_for_layer`` order. When
                given, the values are copied into the module's parameters.
            lora_manager: Optional object whose
                ``update_lora_params(layer_id, module)`` is invoked for
                decoder blocks.

        Returns:
            The prepared module. Decoder blocks get ``layer_idx`` set to
            ``layer_id - 1`` when the attribute exists.
        """
        if flat_buffer is not None:
            device = flat_buffer.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        num_blocks = self.hf_config.num_hidden_layers
        is_block = 1 <= layer_id <= num_blocks

        # Pick the module class that a pooled candidate must have.
        if layer_id == 0:
            wanted_type = LlamaEmbeddingsLayer
        elif layer_id == num_blocks + 1:
            wanted_type = LlamaHeadLayer
        elif is_block:
            wanted_type = LlamaDecoderLayer
        else:
            wanted_type = None

        module = None
        if wanted_type is not None:
            for i, pooled in enumerate(self.layer_pool):
                if isinstance(pooled, wanted_type):
                    module = self.layer_pool.pop(i)
                    break

        if module is None:
            if layer_id == 0:
                module = LlamaEmbeddingsLayer(self.hf_config, None)
            elif layer_id == num_blocks + 1:
                module = LlamaHeadLayer(self.hf_config, None)
            else:
                module = LlamaDecoderLayer(self.hf_config, layer_idx=0)

            # Initialization only: move to target device
            module.to(device)

        if lora_manager and is_block:
            lora_manager.update_lora_params(layer_id, module)

        if id(module) not in self.param_mappings:
            self.param_mappings[id(module)] = self._create_mapping(layer_id, module)

        if flat_buffer is not None:
            mapping = self.param_mappings[id(module)]
            offset = 0
            with torch.no_grad():
                for param, numel, shape in mapping:
                    param.data.copy_(flat_buffer[offset : offset + numel].view(shape), non_blocking=True)
                    offset += numel

        if hasattr(module, "layer_idx") and is_block:
            module.layer_idx = layer_id - 1
        return module

    def _create_mapping(self, layer_id: int, module: nn.Module) -> List[tuple]:
        """Build ``(param, numel, shape)`` triples for ``module``.

        The triples follow ``get_param_names_for_layer(layer_id)`` order. For
        decoder blocks, a name missing from the module is retried with
        ``.weight`` replaced by ``.base_layer.weight``.
        """
        names = self.get_param_names_for_layer(layer_id)
        idx = layer_id - 1
        module_params = dict(module.named_parameters())
        mapping = []
        for full_name in names:
            if layer_id == 0:
                clean_k = "embed_tokens.weight"
            elif layer_id == self.hf_config.num_hidden_layers + 1:
                clean_k = "norm.weight" if "model.norm" in full_name else "lm_head.weight"
            else:
                prefix = f"model.layers.{idx}."
                clean_k = full_name[len(prefix):]
                if clean_k not in module_params:
                    clean_k = clean_k.replace(".weight", ".base_layer.weight")
            param = module_params[clean_k]
            mapping.append((param, param.numel(), param.shape))
        return mapping

    def release_layer_module(self, module: nn.Module):
        """Give ``module`` back to the adapter.

        The module is pooled for reuse while the pool has room. Otherwise it
        is dropped, its cached parameter mapping is evicted and the CUDA cache
        is emptied when CUDA is available.
        """
        if len(self.layer_pool) < self._max_pool_size:
            self.layer_pool.append(module)
            return

        # The mapping is keyed by id(module). Once the module is gone a new
        # module may receive the same id and would otherwise pick up this
        # stale mapping, which also keeps the old parameters alive.
        self.param_mappings.pop(id(module), None)
        del module
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        """Run one layer on the incoming activations.

        Non-decoder modules are simply called on the activations. Decoder
        blocks are executed manually (norm, attention, residual, norm, MLP,
        residual) so the attention call can be wrapped in activation
        checkpointing.

        Args:
            layer_module: Module returned by ``construct_layer_module``.
            inputs: The activations, or a tuple whose first element is the
                activations.
            **kwargs: Optional ``position_ids``, ``attention_mask``,
                ``past_key_value``, ``output_attentions``, ``use_cache``,
                ``cache_position`` and ``gradient_checkpointing``.

        Returns:
            The layer's output tensor.
        """
        hidden_states = inputs[0] if isinstance(inputs, tuple) else inputs

        if isinstance(layer_module, LlamaDecoderLayer):
            batch_size, seq_len = hidden_states.shape[:2]
            device = hidden_states.device
            position_ids = kwargs.get("position_ids")
            if position_ids is None:
                position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)

            # Compute RoPE (Should be [bs, seq, dim])
            attn = layer_module.self_attn
            try:
                if hasattr(attn, "rotary_emb") and attn.rotary_emb is not None:
                    try:
                        cos, sin = attn.rotary_emb(hidden_states, position_ids)
                    except Exception:
                        # Best-effort retry with the single-argument call
                        # form; no longer swallows KeyboardInterrupt/SystemExit.
                        cos, sin = attn.rotary_emb(position_ids)
                else:
                    cos, sin = self.rotary_emb(hidden_states, position_ids)
            except Exception:
                head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
                dummy = torch.zeros(batch_size, seq_len, head_dim, device=device)
                cos, sin = self.rotary_emb(dummy, position_ids)

            # Ensure 3D [bs, seq, dim] for transformers broadcasting
            if cos.ndim == 2:
                cos = cos.unsqueeze(0)
                sin = sin.unsqueeze(0)
            elif cos.ndim == 4:
                cos = cos.squeeze(1)
                sin = sin.squeeze(1)

            # If batch size mismatch (e.g. rotary_emb returned [1, seq, dim] but bs > 1)
            if cos.shape[0] != batch_size and cos.shape[0] == 1:
                cos = cos.expand(batch_size, -1, -1)
                sin = sin.expand(batch_size, -1, -1)

            attention_mask = kwargs.get("attention_mask")
            if attention_mask is None:
                # Additive causal mask built in the activations' dtype so it
                # does not promote half-precision attention scores and stays
                # usable with attention kernels that require matching dtypes.
                mask = torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype)
                mask = torch.triu(mask, diagonal=1)
                attention_mask = mask.view(1, 1, seq_len, seq_len).expand(batch_size, 1, seq_len, seq_len)

            # 2. Manual Forward with Checkpointing
            residual = hidden_states
            hidden_states = layer_module.input_layernorm(hidden_states)

            # Wrap Self-Attention for memory efficiency
            def attn_block(x, mask, pids, cos_sin):
                return layer_module.self_attn(
                    hidden_states=x,
                    attention_mask=mask,
                    position_ids=pids,
                    position_embeddings=cos_sin,
                    past_key_value=kwargs.get("past_key_value"),
                    output_attentions=kwargs.get("output_attentions", False),
                    use_cache=kwargs.get("use_cache", False),
                    cache_position=kwargs.get("cache_position")
                )[0]

            if torch.is_grad_enabled() and kwargs.get("gradient_checkpointing", False):
                from torch.utils.checkpoint import checkpoint
                attn_output = checkpoint(attn_block, hidden_states, attention_mask, position_ids, (cos, sin), use_reentrant=False)
            else:
                attn_output = attn_block(hidden_states, attention_mask, position_ids, (cos, sin))

            hidden_states = residual + attn_output

            # MLP block
            residual = hidden_states
            hidden_states = layer_module.post_attention_layernorm(hidden_states)
            hidden_states = layer_module.mlp(hidden_states)
            hidden_states = residual + hidden_states

            return hidden_states

        return layer_module(hidden_states)

    @property
    def hidden_size(self) -> int:
        """Hidden-state width, read from ``hf_config.hidden_size``."""
        return self.hf_config.hidden_size

class LlamaEmbeddingsLayer(nn.Module):
    """Token-embedding stage of a Llama model."""

    def __init__(self, config, weights):
        """Create ``embed_tokens``.

        Args:
            config: Object providing ``vocab_size``, ``hidden_size`` and
                ``pad_token_id`` (used as the embedding's padding index).
            weights: Accepted for interface compatibility; not used.
        """
        super().__init__()
        self.embed_tokens = nn.Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            padding_idx=config.pad_token_id,
        )

    def forward(self, x):
        """Look up the embeddings for the token ids in ``x``."""
        embedded = self.embed_tokens(x)
        return embedded

class LlamaHeadLayer(nn.Module):
    """Output stage of a Llama model: RMS norm followed by a bias-free LM head."""

    def __init__(self, config, weights):
        """Create ``norm`` and ``lm_head``.

        Args:
            config: Object providing ``hidden_size``, ``rms_norm_eps`` and
                ``vocab_size``.
            weights: Accepted for interface compatibility; not used.
        """
        super().__init__()
        width = config.hidden_size
        self.norm = LlamaRMSNorm(width, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, x):
        """Normalize ``x`` with ``norm`` and project it with ``lm_head``."""
        normalized = self.norm(x)
        return self.lm_head(normalized)


In [None]:
%%writefile tasks/fine_tune_llama_7b.py
import torch
import os
import time
from transformers import AutoTokenizer, AutoConfig
from src.lema.core.gbi import GlobalBinaryIndex
from src.lema.models.llama import LlamaAdapter
from src.lema.engine.trainer import LemaTrainer
from src.lema.core.lora import LoRAManager
from src.lema.config import LemaConfig, MemoryStrategy

# Hugging Face model id used to load the tokenizer and architecture config.
MODEL_NAME = "NousResearch/Llama-2-7b-hf"
# Local safetensors file the weights are read from (and removed at the end of the task).
MODEL_PATH = "llama2_7b.safetensors"

# 1. Demo dataset ("Concise Assistant"): each entry is a question followed by
# a one-sentence answer.
TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
    "Explain gravity. Gravity is a natural phenomenon by which all things with mass or energy are brought toward one another.",
    "What is LEMA? LEMA is a framework that virtualizes GPU memory to enable training large models on limited hardware.",
    "How does a CPU work? A CPU executes instructions of a computer program by performing basic arithmetic, logic, and I/O operations.",
    "What is the speed of light? The speed of light in a vacuum is approximately 299,792,458 meters per second.",
    "Define machine learning. Machine learning is a field of artificial intelligence focused on building systems that learn from data.",
    "What is DNA? DNA is a molecule that carries the genetic instructions used in the growth, development, and functioning of all living organisms.",
    "What is the ocean? The ocean is a continuous body of salt water that covers more than 70 percent of Earth's surface.",
] * 5 # 10 distinct examples repeated 5 times = 50 examples, to keep the demo relatively fast

def fine_tune_llama_7b_task():
    """Fine-tune Llama-2 7B with LEMA on ``TRAINING_DATA`` and print sample generations.

    Steps: load tokenizer and architecture config, build the adapter, index
    and LoRA manager, run one pass over the data, greedily generate up to 25
    tokens for two prompts, then delete ``MODEL_PATH`` if it exists.
    """
    print("--- STARTING LEMA 7B FINE-TUNING ---")

    # Tokenizer (EOS doubles as the pad token)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    # Architecture config as a plain dict with eager attention and float16
    adapter_kwargs = AutoConfig.from_pretrained(MODEL_NAME).to_dict()
    adapter_kwargs["attn_implementation"] = "eager"
    adapter_kwargs["torch_dtype"] = "float16" # Ensure we use float16 config

    adapter = LlamaAdapter(adapter_kwargs)
    gbi = GlobalBinaryIndex(MODEL_PATH)

    # LEMA Config
    lema_config = LemaConfig(
        model_name_or_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
    )

    # LoRA on every attention and MLP projection
    lora_manager = LoRAManager(
        {
            "r": 16,
            "alpha": 32,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        },
        device="cuda",
    )

    print("Initializing LoRA parameters...")
    # Building each block once triggers LoRA parameter creation; the module is
    # handed straight back to the adapter to avoid a leak.
    for meta in adapter.get_layer_metadata():
        if meta['type'] != 'block':
            continue
        block_module = adapter.construct_layer_module(meta['id'], None, lora_manager)
        adapter.release_layer_module(block_module)

    torch.cuda.empty_cache()

    trainable_params = lora_manager.get_trainable_parameters()
    print(f"Number of trainable LoRA parameters: {len(trainable_params)}")

    optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)
    trainer = LemaTrainer(
        config=lema_config,
        model_adapter=adapter,
        gbi=gbi,
        lora_manager=lora_manager,
        optimizer=optimizer,
    )

    # 3. Training Loop
    print(f"\nTraining on {len(TRAINING_DATA)} examples...")
    start_time = time.time()
    for epoch in range(1): # 1 Epoch for demonstration
        total_loss = 0
        for i, text in enumerate(TRAINING_DATA):
            encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            input_ids = encoded["input_ids"].to("cuda")

            _logits, loss = trainer.train_step(input_ids, labels=input_ids)
            total_loss += loss

            if (i+1) % 10 == 0:
                print(f"Step {i+1}/{len(TRAINING_DATA)} - Current Loss: {loss:.4f}")

        avg_loss = total_loss / len(TRAINING_DATA)
        print(f"Epoch {epoch+1} - Avg Loss: {avg_loss:.4f}")

    end_time = time.time()
    print(f"\nTraining completed in {end_time - start_time:.2f} seconds.")

    # 4. Validation: greedy decoding of at most 25 new tokens per prompt
    print("\n--- TESTING BEHAVIOR (Concise Assistant) ---")
    for prompt in ("What is the moon?", "Who was Isaac Newton?"):
        generated = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")

        for _step in range(25):
            with torch.no_grad():
                logits, _unused = trainer.train_step(generated)
            next_token_id = logits[:, -1, :].argmax(dim=-1).unsqueeze(-1)
            generated = torch.cat([generated, next_token_id], dim=-1)
            if next_token_id.item() == tokenizer.eos_token_id:
                break

        print(f"Prompt: {prompt}")
        print(f"Response: {tokenizer.decode(generated[0], skip_special_tokens=True)}")

    # Cleanup
    if os.path.exists(MODEL_PATH):
        os.remove(MODEL_PATH)


In [None]:
%%writefile tasks/fine_tune_llama_7b_config.py
import torch
import os
import time
from transformers import AutoTokenizer, AutoConfig
from src.lema.core.gbi import GlobalBinaryIndex
from src.lema.models.llama import LlamaAdapter
from src.lema.engine.trainer import LemaTrainer
from src.lema.core.lora import LoRAManager
from src.lema.config import LemaConfig, MemoryStrategy

# Hugging Face model id used to load the tokenizer and architecture config.
MODEL_NAME = "NousResearch/Llama-2-7b-hf"
# Local safetensors file passed to LemaConfig (and removed at the end of the task).
MODEL_PATH = "llama2_7b.safetensors"

# Demo dataset: 5 distinct question/answer sentences repeated 10 times = 50 examples.
TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
    "Explain gravity. Gravity is a natural phenomenon by which all things with mass or energy are brought toward one another.",
    "What is LEMA? LEMA is a framework that virtualizes GPU memory to enable training large models on limited hardware.",
] * 10

def fine_tune_llama_7b_with_config():
    """Fine-tune Llama-2 7B with LEMA, driving every component from one ``LemaConfig``.

    LoRA rank/alpha/targets, learning rate, dtype, attention implementation,
    device and the weights path are all read from the config object. After one
    pass over ``TRAINING_DATA`` the file at ``MODEL_PATH`` is deleted if it
    exists.
    """
    print("--- STARTING LEMA 7B FINE-TUNING (Config Object) ---")

    # 1. Setup Configuration
    config = LemaConfig(
        model_name_or_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
        lora_rank=16,
        lora_alpha=32,
        learning_rate=5e-5,
        dtype="float16",
    )

    # Tokenizer (EOS doubles as the pad token)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    # HF config (model architecture), overridden from the LEMA config
    model_config = AutoConfig.from_pretrained(MODEL_NAME).to_dict()
    model_config["attn_implementation"] = config.attn_implementation
    model_config["torch_dtype"] = config.dtype

    # Components
    adapter = LlamaAdapter(model_config)
    gbi = GlobalBinaryIndex(config.gbi_path)

    # LoRA
    lora_manager = LoRAManager(
        {
            "r": config.lora_rank,
            "alpha": config.lora_alpha,
            "target_modules": config.lora_target_modules,
        },
        device=config.device,
    )

    print("Initializing LoRA parameters...")
    # Building each block once triggers LoRA parameter creation.
    for meta in adapter.get_layer_metadata():
        if meta['type'] != 'block':
            continue
        block_module = adapter.construct_layer_module(meta['id'], None, lora_manager)
        adapter.release_layer_module(block_module)
    torch.cuda.empty_cache()

    optimizer = torch.optim.AdamW(
        lora_manager.get_trainable_parameters(),
        lr=config.learning_rate,
    )

    # Trainer
    trainer = LemaTrainer(
        config=config,
        model_adapter=adapter,
        gbi=gbi,
        lora_manager=lora_manager,
        optimizer=optimizer,
    )

    # Training Loop
    print(f"\nTraining on {len(TRAINING_DATA)} examples...")
    for epoch in range(1):
        total_loss = 0
        for i, text in enumerate(TRAINING_DATA):
            encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            input_ids = encoded["input_ids"].to(config.device)

            _logits, loss = trainer.train_step(input_ids, labels=input_ids)
            total_loss += loss

            if (i+1) % 10 == 0:
                print(f"Step {i+1}/{len(TRAINING_DATA)} - Current Loss: {loss:.4f}")

        print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(TRAINING_DATA):.4f}")

    print("Fine-tuning with Config Object completed successfully.")

    # Cleanup
    if os.path.exists(MODEL_PATH):
        os.remove(MODEL_PATH)


In [None]:
%%writefile tasks/fine_tune_smollm.py
import torch
import os
import time
from transformers import AutoTokenizer, AutoConfig
from src.lema.core.gbi import GlobalBinaryIndex
from src.lema.models.llama import LlamaAdapter
from src.lema.engine.trainer import LemaTrainer
from src.lema.core.lora import LoRAManager
from src.lema.config import LemaConfig, MemoryStrategy

# Hugging Face model id used to load the tokenizer and architecture config.
MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B"
# Local safetensors file the weights are read from (and removed at the end of the task).
MODEL_PATH = "smollm2_1.7b.safetensors"

# 1. Demo dataset ("Concise Assistant")
# The model should learn to answer everything in one short, professional sentence.
TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
    "Explain gravity. Gravity is a natural phenomenon by which all things with mass or energy are brought toward one another.",
    "What is LEMA? LEMA is a framework that virtualizes GPU memory to enable training large models on limited hardware.",
    "How does a CPU work? A CPU executes instructions of a computer program by performing basic arithmetic, logic, and I/O operations.",
    "What is the speed of light? The speed of light in a vacuum is approximately 299,792,458 meters per second.",
    "Define machine learning. Machine learning is a field of artificial intelligence focused on building systems that learn from data.",
    "What is DNA? DNA is a molecule that carries the genetic instructions used in the growth, development, and functioning of all living organisms.",
    "What is the ocean? The ocean is a continuous body of salt water that covers more than 70 percent of Earth's surface.",
] * 10 # 10 distinct examples repeated 10 times = 100 examples for more "realistic" weight updates

def fine_tune_realistic():
    """Fine-tune SmolLM2-1.7B with LEMA for three epochs and print sample generations.

    LoRA is attached to all attention and MLP projections. After training,
    up to 25 tokens are greedily generated for two prompts and the file at
    ``MODEL_PATH`` is deleted if it exists.
    """
    print("--- STARTING REALISTIC LEMA FINE-TUNING (v0.6) ---")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    adapter_kwargs = AutoConfig.from_pretrained(MODEL_NAME).to_dict()
    adapter_kwargs["attn_implementation"] = "eager"

    adapter = LlamaAdapter(adapter_kwargs)
    gbi = GlobalBinaryIndex(MODEL_PATH)

    # LEMA Config
    lema_config = LemaConfig(
        model_name_or_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
    )

    # 2. HEAVY LoRA Config (All major linear layers)
    # This increases the number of parameters the optimizer has to manage.
    lora_manager = LoRAManager(
        {
            "r": 16,
            "alpha": 32,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        },
        device="cuda",
    )

    print("Initializing heavy LoRA parameters...")
    # Building each block once triggers LoRA parameter creation.
    for meta in adapter.get_layer_metadata():
        if meta['type'] != 'block':
            continue
        block_module = adapter.construct_layer_module(meta['id'], None, lora_manager)
        adapter.release_layer_module(block_module)

    # Clear cache after initialization
    torch.cuda.empty_cache()

    trainable_params = lora_manager.get_trainable_parameters()
    print(f"Number of trainable LoRA parameters: {len(trainable_params)}")

    optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)
    trainer = LemaTrainer(
        config=lema_config,
        model_adapter=adapter,
        gbi=gbi,
        lora_manager=lora_manager,
        optimizer=optimizer,
    )

    # 3. Training Loop
    print("\nTraining on 100 examples...")
    start_time = time.time()
    for epoch in range(3): # Fewer epochs but more data per epoch
        total_loss = 0
        for i, text in enumerate(TRAINING_DATA):
            encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            input_ids = encoded["input_ids"].to("cuda")

            _logits, loss = trainer.train_step(input_ids, labels=input_ids)
            total_loss += loss

            if (i+1) % 20 == 0:
                print(f"Step {i+1}/{len(TRAINING_DATA)} - Current Loss: {loss:.4f}")

        avg_loss = total_loss / len(TRAINING_DATA)
        print(f"Epoch {epoch+1}/3 - Avg Loss: {avg_loss:.4f}")

    end_time = time.time()
    print(f"\nTraining completed in {end_time - start_time:.2f} seconds.")

    # 4. Validation: greedy decoding of at most 25 new tokens per prompt
    print("\n--- TESTING BEHAVIOR (Concise Assistant) ---")
    for prompt in ("What is the moon?", "Who was Isaac Newton?"):
        generated = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")

        for _step in range(25):
            with torch.no_grad():
                logits, _unused = trainer.train_step(generated)
            next_token_id = logits[:, -1, :].argmax(dim=-1).unsqueeze(-1)
            generated = torch.cat([generated, next_token_id], dim=-1)
            if next_token_id.item() == tokenizer.eos_token_id:
                break

        print(f"Prompt: {prompt}")
        print(f"Response: {tokenizer.decode(generated[0], skip_special_tokens=True)}")

    # Final Cleanup
    if os.path.exists(MODEL_PATH):
        os.remove(MODEL_PATH)

if __name__ == "__main__":
    # The weights file is expected to exist already (e.g. produced by the
    # benchmark script); without it there is nothing to fine-tune.
    if not os.path.exists(MODEL_PATH):
        print(f"Error: {MODEL_PATH} not found.")
    else:
        fine_tune_realistic()

In [None]:
import sys
import os
# Make the absolute path of the local 'src' directory importable.
# NOTE(review): the later cells import `src.lema...`, which resolves against
# the working directory rather than this entry — confirm whether this append
# is still needed.
sys.path.append(os.path.abspath('src'))
print('LEMA library loaded.')

In [None]:
!pip install -q safetensors accelerate peft transformers

In [None]:
import torch
import gc
import os
import time
import transformers
from transformers import AutoModelForCausalLM, AutoConfig
from peft import get_peft_model, LoraConfig
from safetensors.torch import save_file

# LEMA Imports
from src.lema.core.gbi import GlobalBinaryIndex
from src.lema.models.llama import LlamaAdapter
from src.lema.engine.trainer import LemaTrainer
from src.lema.core.lora import LoRAManager
from src.lema.config import LemaConfig, MemoryStrategy

# Record the installed Transformers version alongside the benchmark output.
print(f"Using Transformers version: {transformers.__version__}")

# --- MODELS TO TEST ---
# Focusing on Llama architectures for speed comparison
# Each entry carries: "name" (label used in the output), "hf_id" (Hugging Face
# model id), "path" (local safetensors file used by LEMA) and "type".
MODELS = [
    {
        "name": "TinyLlama 1.1B",
        "hf_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "path": "tinyllama_1b.safetensors",
        "type": "llama"
    },
    {
        "name": "Llama-2 7B",
        "hf_id": "NousResearch/Llama-2-7b-hf",
        "path": "llama2_7b.safetensors",
        "type": "llama"
    }
]

# Number of timed steps per benchmark (one extra warmup step runs before timing).
NUM_STEPS = 20 # Enough to stabilize avg speed

def download_and_convert(model_info):
    """Ensure the safetensors file for ``model_info`` exists locally.

    If ``model_info['path']`` already exists nothing is downloaded. Otherwise
    the model ``model_info['hf_id']`` is loaded on CPU in float16, the LM head
    is given its own copy of the weight when it shares storage with the token
    embeddings, and the state dict is saved to ``model_info['path']``.

    Args:
        model_info: Dict with ``name``, ``hf_id`` and ``path`` entries.
    """
    print(f"\n--- Preparing {model_info['name']} ---")
    already_converted = os.path.exists(model_info['path'])
    if already_converted:
        print(f"{model_info['path']} already exists.")
        return

    print(f"Downloading {model_info['hf_id']}...")
    load_options = {
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
        "device_map": "cpu",
    }
    model = AutoModelForCausalLM.from_pretrained(model_info['hf_id'], **load_options)

    # Break shared weights if necessary
    has_head_and_embeddings = (
        hasattr(model, "lm_head")
        and hasattr(model, "model")
        and hasattr(model.model, "embed_tokens")
    )
    if has_head_and_embeddings:
        head_ptr = model.lm_head.weight.data_ptr()
        embed_ptr = model.model.embed_tokens.weight.data_ptr()
        if head_ptr == embed_ptr:
            model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())

    print(f"Saving to {model_info['path']}...")
    save_file(model.state_dict(), model_info['path'])
    del model
    gc.collect()

def benchmark_peft_speed(model_info):
    """Measure the average time of one PEFT/LoRA training step on the GPU.

    The model is loaded in float16 on CUDA with gradient checkpointing and
    LoRA on all attention/MLP projections. One warmup step is followed by
    ``NUM_STEPS`` timed steps on a fixed random batch of shape ``(1, 512)``.

    Args:
        model_info: Dict with ``name`` and ``hf_id`` entries.

    Returns:
        Average seconds per step, or ``float('inf')`` if anything fails.
    """
    print(f"\n>>> BENCHMARKING PEFT SPEED: {model_info['name']} <<<")
    torch.cuda.empty_cache()

    model = None
    optimizer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_info['hf_id'],
            torch_dtype=torch.float16,
            device_map="cuda"
        )

        # Enable Gradient Checkpointing to save memory
        model.gradient_checkpointing_enable()

        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        peft_config = LoraConfig(
            r=16, lora_alpha=32, target_modules=target_modules,
            lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, peft_config)

        # from_pretrained hands the model back in eval mode. Gradient
        # checkpointing and LoRA dropout only take effect in training mode,
        # so switch explicitly before timing training steps.
        model.train()

        # Only the LoRA weights are trainable; frozen base weights never
        # receive gradients, so they are not handed to the optimizer.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
        input_ids = torch.randint(0, 1000, (1, 512)).cuda()

        # Warmup
        model(input_ids, labels=input_ids).loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()

        print(f"Running {NUM_STEPS} steps...")
        start_time = time.time()
        for _ in range(NUM_STEPS):
            loss = model(input_ids, labels=input_ids).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Wait for queued GPU work so the measured interval covers it.
        torch.cuda.synchronize()
        end_time = time.time()

        avg_time = (end_time - start_time) / NUM_STEPS
        print(f"PEFT Avg Time/Step: {avg_time:.4f}s")
        return avg_time

    except Exception as e:
        print(f"PEFT Benchmark Failed: {e}")
        return float('inf')
    finally:
        # Free GPU memory on success and on failure so the next benchmark
        # starts from a clean state.
        model = None
        optimizer = None
        gc.collect()
        torch.cuda.empty_cache()

def benchmark_lema_speed(model_info):
    """Measure the average time of one LEMA training step.

    The safetensors file is prepared via ``download_and_convert``, LEMA is set
    up in streaming mode with LoRA on all attention/MLP projections, and one
    warmup step is followed by ``NUM_STEPS`` timed steps on a fixed random
    batch of shape ``(1, 512)``. Gradient checkpointing is enabled only when
    the model id contains ``"7b"``. The safetensors file is removed afterwards
    in every case.

    Args:
        model_info: Dict with ``name``, ``hf_id`` and ``path`` entries.

    Returns:
        Average seconds per step, or ``float('inf')`` if anything fails,
        including the download/conversion step.
    """
    print(f"\n>>> BENCHMARKING LEMA SPEED: {model_info['name']} <<<")
    torch.cuda.empty_cache()

    try:
        # Kept inside the try block so a failed download or conversion is
        # reported through the same inf sentinel as any other failure instead
        # of aborting the remaining benchmarks.
        download_and_convert(model_info)

        # Enable checkpointing only for large models (e.g. 7B)
        use_gc = "7b" in model_info['hf_id'].lower()
        print(f"Gradient Checkpointing: {use_gc}")

        hf_config = AutoConfig.from_pretrained(model_info['hf_id'])
        hf_config_dict = hf_config.to_dict()
        hf_config_dict["attn_implementation"] = "eager"
        hf_config_dict["torch_dtype"] = "float16"

        adapter = LlamaAdapter(hf_config_dict)
        gbi = GlobalBinaryIndex(model_info['path'])

        lema_config = LemaConfig(
            model_name_or_path=model_info['path'],
            device="cuda",
            strategy=MemoryStrategy.STREAMING,
            learning_rate=1e-4,
            gradient_checkpointing=use_gc
        )

        lora_config = {
            "r": 16, "alpha": 32,
            "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        }
        lora_manager = LoRAManager(lora_config, device="cuda")

        # Init params: building each block once triggers LoRA parameter creation.
        for layer in adapter.get_layer_metadata():
            if layer['type'] == 'block':
                module = adapter.construct_layer_module(layer['id'], None, lora_manager)
                adapter.release_layer_module(module)

        torch.cuda.empty_cache()

        trainable_params = lora_manager.get_trainable_parameters()
        optimizer = torch.optim.AdamW(trainable_params, lr=lema_config.learning_rate)

        trainer = LemaTrainer(
            config=lema_config,
            model_adapter=adapter,
            gbi=gbi,
            lora_manager=lora_manager,
            optimizer=optimizer
        )

        input_ids = torch.randint(0, 1000, (1, 512)).cuda()

        # Warmup
        trainer.train_step(input_ids, labels=input_ids)
        torch.cuda.synchronize()

        print(f"Running {NUM_STEPS} steps...")
        start_time = time.time()
        for _ in range(NUM_STEPS):
            trainer.train_step(input_ids, labels=input_ids)
        # Wait for queued GPU work so the measured interval covers it.
        torch.cuda.synchronize()
        end_time = time.time()

        avg_time = (end_time - start_time) / NUM_STEPS
        print(f"LEMA Avg Time/Step: {avg_time:.4f}s")

        return avg_time

    except Exception as e:
        print(f"LEMA Benchmark Failed: {e}")
        import traceback
        traceback.print_exc()
        return float('inf')
    finally:
        # Remove the converted weights (also a partial file left by a failed
        # conversion).
        if os.path.exists(model_info['path']):
            os.remove(model_info['path'])

def format_overhead(peft_time, lema_time):
    """Return the LEMA/PEFT slowdown factor as text.

    Args:
        peft_time: Average PEFT seconds per step (``inf`` marks a failed run).
        lema_time: Average LEMA seconds per step (``inf`` marks a failed run).

    Returns:
        ``"<ratio>x"`` with two decimals, or ``"N/A"`` when either timing is
        not finite or the PEFT timing is not positive, since the ratio would
        otherwise come out as a misleading ``0.00`` or ``nan``.
    """
    import math

    timings_valid = math.isfinite(peft_time) and math.isfinite(lema_time)
    if not timings_valid or peft_time <= 0:
        return "N/A"
    return f"{lema_time / peft_time:.2f}x"


if __name__ == "__main__":
    # Run both benchmarks for every model, then print a comparison table.
    results = {}
    for model in MODELS:
        peft_time = benchmark_peft_speed(model)
        lema_time = benchmark_lema_speed(model)
        results[model['name']] = {"PEFT": peft_time, "LEMA": lema_time}

    print("\n=== SPEED BENCHMARK RESULTS (Time per Step) ===")
    print(f"{ 'Model':<20} | { 'PEFT (s)':<10} | { 'LEMA (s)':<10} | { 'Overhead':<10}")
    print("-" * 60)
    for name, data in results.items():
        peft = data["PEFT"]
        lema = data["LEMA"]
        print(f"{name:<20} | {peft:<10.4f} | {lema:<10.4f} | {format_overhead(peft, lema):<10}")
