# LEMA: Layer-wise Efficient Memory Abstraction
This notebook is a self-contained LEMA workspace. It includes the library, examples, tasks, and documentation.

--- 

**Virtualize GPU VRAM for LLM Fine-Tuning**

LEMA is a specialized framework for fine-tuning Large Language Models (LLMs) on GPUs whose VRAM is smaller than the model. It treats model weights as addressable binary segments and streams them through a **Triple-Buffer Strategy** (Disk -> RAM -> VRAM), so only the layers currently in flight occupy VRAM. This makes LoRA fine-tuning of 7B-parameter models such as Llama-2 7B possible on a 16 GB GPU.
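
Treating a layer's weights as one addressable segment can be illustrated with a small standard-library sketch. It packs several named tensors into a single flat float32 buffer, records each tensor's offset and shape, and hands back zero-copy slices. The helper names `pack_layer` and `tensor_view` are hypothetical; in this notebook the equivalent work is done with PyTorch tensors in `TripleBufferManager._pack_layer_to_ram`.

```python
from __future__ import annotations

from array import array
from math import prod


def pack_layer(tensors: dict[str, tuple[tuple[int, ...], list[float]]]):
    """Pack named tensors into one flat float32 buffer (hypothetical helper).

    Returns the buffer and an index mapping name -> (offset, shape).
    """
    total = sum(prod(shape) for shape, _ in tensors.values())
    flat = array("f", bytes(4 * total))  # zero-initialized float32 storage
    index: dict[str, tuple[int, tuple[int, ...]]] = {}
    offset = 0
    for name, (shape, values) in tensors.items():
        numel = prod(shape)
        if len(values) != numel:
            raise ValueError(f"{name}: expected {numel} values, got {len(values)}")
        flat[offset:offset + numel] = array("f", values)  # copy into its segment
        index[name] = (offset, shape)
        offset += numel
    return flat, index


def tensor_view(flat: array, index, name: str) -> memoryview:
    """Return a zero-copy view of one tensor's segment in the flat buffer."""
    offset, shape = index[name]
    return memoryview(flat)[offset:offset + prod(shape)]


layer = {
    "q_proj.weight": ((2, 2), [1.0, 2.0, 3.0, 4.0]),
    "k_proj.weight": ((2, 2), [5.0, 6.0, 7.0, 8.0]),
}
flat, index = pack_layer(layer)
print(index)  # {'q_proj.weight': (0, (2, 2)), 'k_proj.weight': (4, (2, 2))}
print(tensor_view(flat, index, "k_proj.weight").tolist())  # [5.0, 6.0, 7.0, 8.0]
```

Because every tensor of a layer lives in one contiguous buffer, moving a layer between tiers is a single bulk copy rather than one copy per parameter.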

## Key Performance (Tesla P100, 16 GB)

| Model | Standard PEFT (VRAM) | LEMA (VRAM) | Status |
| :--- | :--- | :--- | :--- |
| **Llama-2 7B** | **OOM (crash)** | **5.90 GB** | **Stable** |
| **SmolLM2 1.7B** | 3.88 GB | 3.20 GB | Stable |
| **TinyLlama 1.1B** | 2.67 GB | 2.12 GB | Stable |

## Core Features

- **Global Binary Index (GBI)**: Zero-copy mapping of `.safetensors` files using `mmap`.
- **Triple-Buffer Pipeline**: Pipelined data movement (Disk -> RAM -> VRAM) to hide PCIe latency.
- **High-Level API**: Simplified `LemaModel` and `LemaTrainer` interfaces for fast integration.
- **Automatic Checkpointing**: Built-in interval-based saving of LoRA adapters and optimizer states.
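
The Disk -> RAM -> VRAM pipeline relies on a ping-pong slot assignment: while layer *i* runs in one VRAM slot, layer *i+1* is copied into the other VRAM slot and layer *i+2* is read from disk into the RAM slot that layer *i* has already vacated. The sketch below replays that bookkeeping as it appears in the forward loop of `LemaTrainer.train_step` (the backward loop uses the same scheme over the reversed layer order) and asserts that each layer is in place before it is computed. It models slot contents sequentially only; it assumes that `get_vram_flat_buffer` waits for any pending copy into its slot, which is not shown in this excerpt. `replay_slots` is a hypothetical helper.

```python
from __future__ import annotations


def replay_slots(order: list[int]) -> list[tuple[int, int]]:
    """Simulate LEMA-style double-buffered slots for one pass over `order`.

    Returns (layer, vram_slot) for each compute step. Raises AssertionError
    if a layer would be computed before it reached its VRAM slot.
    """
    n = len(order)
    if n == 0:
        return []
    ram: list = [None, None]   # two pinned RAM staging slots
    vram: list = [None, None]  # two VRAM compute slots
    # Warm-up: first layer -> RAM 0 -> VRAM 0; second layer -> RAM 1
    ram[0] = order[0]
    vram[0] = ram[0]
    if n > 1:
        ram[1] = order[1]

    steps = []
    for pos, layer in enumerate(order):
        slot, next_slot = pos % 2, (pos + 1) % 2
        assert vram[slot] == layer, f"layer {layer} missing from VRAM slot {slot}"
        if pos + 1 < n:
            assert ram[next_slot] == order[pos + 1], "next layer not staged in RAM"
            vram[next_slot] = ram[next_slot]  # async host-to-device copy
            if pos + 2 < n:
                ram[slot] = order[pos + 2]    # disk thread refills the vacated RAM slot
        steps.append((layer, slot))           # layer computed from vram[slot]
    return steps


forward = replay_slots(list(range(4)))
backward = replay_slots(list(range(3, -1, -1)))
print(forward)   # [(0, 0), (1, 1), (2, 0), (3, 1)]
print(backward)  # [(3, 0), (2, 1), (1, 0), (0, 1)]
```

Two slots per tier are enough because no transfer ever targets the VRAM slot that is being computed on at that step.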

## Installation

```bash
git clone https://github.com/Pomilon/LEMA.git
cd LEMA
pip install -e .
```

## Quick Start

```python
import torch
from lema import LemaConfig, LemaModel, MemoryStrategy

# 1. Configuration
config = LemaConfig(
    model_name_or_path="NousResearch/Llama-2-7b-hf",
    gbi_path="llama2_7b.safetensors", # Single monolithic .safetensors file (see prepare_monolithic_safetensors)
    strategy=MemoryStrategy.STREAMING,
    lora_rank=16,
    gradient_checkpointing=True
)

# 2. Initialize Model & Trainer
model = LemaModel(config)
model.initialize_lora() # Pre-initialize adapters

optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=1e-4)
trainer = model.get_trainer(optimizer)

# 3. Train
input_ids = torch.randint(0, 32000, (1, 512)).cuda()
logits, loss = trainer.train_step(input_ids, labels=input_ids)
```

## Documentation

- [**User Guide**](docs/USER_GUIDE.md): Model preparation, conversion, and tips.
- [**API Reference**](docs/API_REFERENCE.md): Detailed class and method specifications.
- [**Architecture**](docs/ARCHITECTURE.md): Deep dive into the memory pipeline and LEMA-loop.

## Kaggle Benchmark

You can run the latest verification suite on Kaggle using the provided notebook:
[**LEMA Benchmark Notebook**](https://www.kaggle.com/code/kloyford/lema-benchmark-notebook)

## License
MIT License - Copyright (c) 2026 Pomilon


In [None]:
!mkdir -p src/lema/core src/lema/engine src/lema/models src/lema/utils tasks examples/kaggle docs output

In [None]:
%%writefile src/lema/config.py
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from enum import Enum

class MemoryStrategy(Enum):
    STREAMING = "streaming" # Disk -> RAM -> VRAM
    RESIDENT = "resident"   # RAM -> VRAM (No Disk offload for weights)

@dataclass
class LemaConfig:
    """
    Central Configuration for LEMA Training/Inference.
    """
    # Model Settings
    model_name_or_path: str
    model_type: Optional[str] = None # 'llama' or 'gpt2', auto-detected if None
    gbi_path: Optional[str] = None # Path to converted safetensors for GBI
    
    # Hardware / Memory Settings
    device: str = "cuda"
    strategy: MemoryStrategy = MemoryStrategy.STREAMING
    ram_buffer_size: int = 2 # Number of layers to keep in RAM
    vram_buffer_size: int = 1 # Number of layers to keep in VRAM
    
    # LoRA Settings
    use_lora: bool = True
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
    
    # Training Settings
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    max_seq_length: int = 512
    gradient_checkpointing: bool = False
    
    # Checkpointing Settings
    save_steps: int = 500
    output_dir: str = "output"
    
    # Advanced
    dtype: str = "float16" # float16, bfloat16, float32
    attn_implementation: str = "eager" # eager, sdpa, flash_attention_2

    def __post_init__(self):
        if self.gbi_path is None:
            # Use model_name_or_path if it already points at a .safetensors file; otherwise expect ./model.safetensors
            if self.model_name_or_path.endswith(".safetensors"):
                self.gbi_path = self.model_name_or_path
            else:
                self.gbi_path = "model.safetensors"
        
        if isinstance(self.strategy, str):
            self.strategy = MemoryStrategy(self.strategy.lower())

    def to_dict(self) -> Dict[str, Any]:
        return {
            k: v.value if isinstance(v, Enum) else v 
            for k, v in self.__dict__.items()
        }

    def save_pretrained(self, save_directory: str):
        import os
        import json
        os.makedirs(save_directory, exist_ok=True)
        config_file = os.path.join(save_directory, "lema_config.json")
        with open(config_file, "w") as f:
            json.dump(self.to_dict(), f, indent=4)

    @classmethod
    def from_pretrained(cls, load_directory: str, **kwargs):
        import os
        import json
        config_file = os.path.join(load_directory, "lema_config.json")
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"Config file not found in {load_directory}")
        
        with open(config_file, "r") as f:
            config_dict = json.load(f)
        
        # Override with kwargs
        config_dict.update(kwargs)
        
        # Handle enum conversion
        if "strategy" in config_dict and isinstance(config_dict["strategy"], str):
            config_dict["strategy"] = MemoryStrategy(config_dict["strategy"])
            
        return cls(**config_dict)
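
`to_dict` stores `MemoryStrategy` by its string value, which is why both `__post_init__` and `from_pretrained` convert strings back into the enum. A minimal standalone sketch of that JSON round trip, using a hypothetical `MiniConfig` that keeps only three of the real fields:

```python
import json
from dataclasses import dataclass
from enum import Enum


class Strategy(Enum):
    STREAMING = "streaming"
    RESIDENT = "resident"


@dataclass
class MiniConfig:
    model_name_or_path: str
    strategy: Strategy = Strategy.STREAMING
    lora_rank: int = 16

    def __post_init__(self):
        # Accept plain strings (e.g. from JSON or user input) case-insensitively
        if isinstance(self.strategy, str):
            self.strategy = Strategy(self.strategy.lower())

    def to_json(self) -> str:
        # Enums are not JSON-serializable, so store their .value instead
        return json.dumps({k: v.value if isinstance(v, Enum) else v
                           for k, v in self.__dict__.items()})

    @classmethod
    def from_json(cls, text: str, **overrides) -> "MiniConfig":
        data = json.loads(text)
        data.update(overrides)  # keyword overrides win, as in from_pretrained
        return cls(**data)


cfg = MiniConfig("NousResearch/Llama-2-7b-hf")
text = cfg.to_json()
print(text)  # {"model_name_or_path": "NousResearch/Llama-2-7b-hf", "strategy": "streaming", "lora_rank": 16}
print(MiniConfig.from_json(text, strategy="RESIDENT").strategy)  # Strategy.RESIDENT
```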


In [None]:
%%writefile src/lema/__init__.py
from .config import LemaConfig, MemoryStrategy
from .core.model import LemaModel
from .engine.trainer import LemaTrainer

__version__ = "0.1.1"


In [None]:
%%writefile src/lema/utils/model_utils.py
import torch
import os
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, AutoConfig

def break_shared_weights(model: torch.nn.Module):
    """
    Ensures that shared weights (like lm_head and embed_tokens) are distinct copies.
    Required for safetensors compatibility.
    """
    if hasattr(model, "lm_head") and hasattr(model, "model") and hasattr(model.model, "embed_tokens"):
        if model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr():
            # Only clone the specific shared tensor, not the whole model
            model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())
    return model

def prepare_monolithic_safetensors(model_name_or_path: str, output_path: str, device: str = "auto"):
    """
    Downloads a model and saves it as a single, framework-compatible safetensors file.
    Uses device_map='auto' by default so weights can be placed on the GPU, reducing system RAM usage during conversion.
    """
    print(f"Loading {model_name_or_path} for monolithic conversion...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map=device
    )
    model = break_shared_weights(model)
    
    print(f"Saving monolithic safetensors to {output_path}...")
    # Pass state_dict directly to save_file to avoid memory doubling
    sd = model.state_dict()
    save_file(sd, output_path)
    
    # Cleanup
    del sd
    del model
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
%%writefile src/lema/engine/trainer.py
import torch
import torch.nn.functional as F
import threading
import os
from typing import Any, Optional, List, Union
from ..core.memory import TripleBufferManager
from ..models.base import LemaModelAdapter
from ..config import LemaConfig, MemoryStrategy

class LemaTrainer:
    def __init__(self, 
                 config: LemaConfig,
                 model_adapter: LemaModelAdapter, 
                 gbi: Any, 
                 lora_manager: Any = None, 
                 optimizer: Optional[torch.optim.Optimizer] = None,
                 memory_manager: Optional[TripleBufferManager] = None):
        
        self.config = config
        self.adapter = model_adapter
        self.gbi = gbi
        self.device = config.device
        self.strategy = config.strategy
        
        # Use provided memory manager or create a new one
        if memory_manager is not None:
            self.memory = memory_manager
        else:
            self.memory = TripleBufferManager(gbi, model_adapter, self.device, strategy=self.strategy)
        
        self.layers = self.adapter.get_layer_metadata()
        self.lora_manager = lora_manager
        self.optimizer = optimizer
        self.global_step = 0

    def save_checkpoint(self, save_directory: str):
        """Saves the model state (config + LoRA) and optionally optimizer state."""
        self.config.save_pretrained(save_directory)
        if self.lora_manager:
            self.lora_manager.save_pretrained(save_directory)
        
        if self.optimizer:
            torch.save(self.optimizer.state_dict(), os.path.join(save_directory, "optimizer.bin"))

    def train_step(self, inputs: Any, labels: Optional[torch.Tensor] = None):
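        """
        Executes one forward pass and one backward pass, then steps the optimizer.
        If labels are provided, computes the causal LM cross-entropy loss; otherwise
        back-propagates the mean of the logits as a placeholder loss.
        Saves a checkpoint every `save_steps` steps.
        Returns (logits, loss_value); loss_value is None without labels or when gradients are disabled.
        """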
        boundary_activations: List[torch.Tensor] = []
        is_streaming = (self.strategy == MemoryStrategy.STREAMING)
        
        # --- FORWARD PASS ---
        if is_streaming:
            self.memory.prefetch_to_ram(self.layers[0]['id'], 0)
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0, ram_slot=0)
            if len(self.layers) > 1:
                self.memory.prefetch_to_ram(self.layers[1]['id'], 1)
        else:
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0)

        hidden_states = inputs
        
        for i, layer_meta in enumerate(self.layers):
            slot = i % 2
            next_slot = (i + 1) % 2
            
            flat_vram = self.memory.get_vram_flat_buffer(slot)
            
            disk_thread = None
            if i + 1 < len(self.layers):
                if is_streaming:
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot, ram_slot=next_slot)
                    if i + 2 < len(self.layers):
                        disk_thread = threading.Thread(target=self.memory.prefetch_to_ram, args=(self.layers[i+2]['id'], slot))
                        disk_thread.start()
                else:
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot)
            
            layer_module = self.adapter.construct_layer_module(layer_meta['id'], flat_vram, self.lora_manager)
            
            # Store input for backward
            if isinstance(hidden_states, tuple):
                current_input = hidden_states[0].detach()
            else:
                current_input = hidden_states.detach()
            boundary_activations.append(current_input)
            
            with torch.no_grad():
                hidden_states = self.adapter.forward_layer(layer_module, hidden_states, gradient_checkpointing=False)

            if disk_thread: disk_thread.join()
            if hasattr(self.adapter, "release_layer_module"):
                self.adapter.release_layer_module(layer_module)
            del layer_module

        # Final Logits
        logits = hidden_states
        loss_val = None

        # --- BACKWARD PASS ---
        if not torch.is_grad_enabled():
            return logits, None

        last_idx = len(self.layers) - 1
        if is_streaming:
            self.memory.prefetch_to_ram(self.layers[last_idx]['id'], 0)
            self.memory.async_transfer_to_vram(self.layers[last_idx]['id'], 0, ram_slot=0)
            if last_idx > 0:
                self.memory.prefetch_to_ram(self.layers[last_idx-1]['id'], 1)
        else:
            self.memory.async_transfer_to_vram(self.layers[last_idx]['id'], 0)
        
        grad_output = None
        
        for i in range(last_idx, -1, -1):
            slot = (last_idx - i) % 2
            next_slot = (last_idx - i + 1) % 2
            
            flat_vram = self.memory.get_vram_flat_buffer(slot)
            
            disk_thread = None
            if i - 1 >= 0:
                if is_streaming:
                    self.memory.async_transfer_to_vram(self.layers[i-1]['id'], next_slot, ram_slot=next_slot)
                    if i - 2 >= 0:
                        disk_thread = threading.Thread(target=self.memory.prefetch_to_ram, args=(self.layers[i-2]['id'], slot))
                        disk_thread.start()
                else:
                    self.memory.async_transfer_to_vram(self.layers[i-1]['id'], next_slot)
            
            layer_module = self.adapter.construct_layer_module(self.layers[i]['id'], flat_vram, self.lora_manager)
            layer_input = boundary_activations[i]
            if layer_input.dtype.is_floating_point:
                layer_input.requires_grad_(True)
            
            output = self.adapter.forward_layer(layer_module, layer_input, gradient_checkpointing=self.config.gradient_checkpointing)
            
            if i == last_idx:
                if labels is not None:
                    # Real Causal LM Loss
                    # Shift so that tokens < n predict n
                    shift_logits = output[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    loss_val = loss.item()
                else:
                    loss = output.mean() # Dummy
                
                loss.backward()
                grad_output = layer_input.grad
            else:
                if isinstance(output, tuple):
                    output[0].backward(grad_output)
                else:
                    output.backward(grad_output)
                grad_output = layer_input.grad
            
            if disk_thread: disk_thread.join()
            if hasattr(self.adapter, "release_layer_module"):
                self.adapter.release_layer_module(layer_module)
            del layer_module

        if self.optimizer:
            self.optimizer.step()
            self.optimizer.zero_grad()
            
        self.global_step += 1
        
        # Automatic checkpointing
        if self.config.save_steps > 0 and self.global_step % self.config.save_steps == 0:
            checkpoint_path = os.path.join(self.config.output_dir, f"checkpoint-{self.global_step}")
            self.save_checkpoint(checkpoint_path)

        return logits, loss_val
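
The backward pass in `train_step` never holds a full autograd graph: the forward loop saves only each layer's input (`boundary_activations`), and the backward loop revisits one layer at a time, back-propagates the incoming gradient, and hands the input gradient to the previous layer. The sketch below shows the same idea on scalar affine layers `y = w * x + b` with hand-written derivatives and checks it against central finite differences. The layer form, the squared-error loss, and all helper names are illustrative assumptions, not LEMA's model.

```python
def forward_layer(w: float, b: float, x: float) -> float:
    return w * x + b


def layerwise_grads(params, x0, target):
    """Gradients of 0.5 * (y_L - target)**2 using only stored layer inputs."""
    # Forward: keep just the boundary activation entering each layer
    boundary = []
    h = x0
    for w, b in params:
        boundary.append(h)
        h = forward_layer(w, b, h)
    loss = 0.5 * (h - target) ** 2

    # Backward: visit layers in reverse; each local gradient needs only the stored input
    grads = [None] * len(params)
    grad_out = h - target                    # dL/dy of the last layer
    for i in range(len(params) - 1, -1, -1):
        w, _ = params[i]
        x = boundary[i]
        grads[i] = (grad_out * x, grad_out)  # (dL/dw, dL/db)
        grad_out = grad_out * w              # dL/dx feeds the previous layer
    return loss, grads


def numeric_grad(params, x0, target, i, j, eps=1e-6):
    """Central finite difference of the loss w.r.t. params[i][j]."""
    def loss_at(delta):
        p = [list(pair) for pair in params]
        p[i][j] += delta
        h = x0
        for w, b in p:
            h = forward_layer(w, b, h)
        return 0.5 * (h - target) ** 2
    return (loss_at(eps) - loss_at(-eps)) / (2 * eps)


params = [(0.5, 0.1), (-1.2, 0.3), (2.0, -0.4)]
loss, grads = layerwise_grads(params, x0=1.5, target=0.25)
print(round(loss, 5))         # 2.18405
print(round(grads[0][0], 3))  # 7.524
```

Peak memory then scales with one layer's graph plus the stored boundary activations, instead of the whole network's graph.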


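The loss in `train_step` shifts logits and labels by one position so that the prediction made at position *t* is scored against token *t + 1*. A pure-Python sketch of that alignment with a hand-rolled mean cross-entropy (illustrative only; the trainer calls `F.cross_entropy`, whose defaults also average over positions and skip label `-100`). `causal_lm_loss` is a hypothetical helper.

```python
from __future__ import annotations

import math

IGNORE_INDEX = -100  # same default as torch.nn.functional.cross_entropy


def causal_lm_loss(logits: list[list[float]], labels: list[int]) -> float:
    """Mean next-token cross-entropy for one sequence.

    logits[t] holds the vocabulary scores produced at position t and is
    compared with labels[t + 1], so the last logit row and first label drop out.
    """
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    total, count = 0.0, 0
    for row, target in zip(shift_logits, shift_labels):
        if target == IGNORE_INDEX:
            continue
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else float("nan")


logits = [[0.0, 10.0, 0.0],   # position 0 predicts token 1
          [0.0, 0.0, 10.0],   # position 1 predicts token 2
          [5.0, 5.0, 5.0]]    # last position has no next token: dropped
labels = [0, 1, 2]
print(causal_lm_loss(logits, labels) < 1e-3)  # True
```
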
In [None]:
%%writefile src/lema/core/lora.py
import torch
import torch.nn as nn
import math
from typing import Dict, Tuple, Optional, List
try:
    from transformers.pytorch_utils import Conv1D
except ImportError:
    Conv1D = None

class LoRAWrapper(nn.Module):
    """
    Wraps a Linear or Conv1D layer with LoRA adapters.
    """
    def __init__(self, base_layer: nn.Module, rank: int, alpha: float, lora_A: nn.Parameter, lora_B: nn.Parameter):
        super().__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.alpha = alpha
        self.scaling = alpha / rank
        
        self.lora_A = lora_A
        self.lora_B = lora_B
        
    def forward(self, x):
        # Base forward
        result = self.base_layer(x)
        
        # LoRA forward
        # Calculation: (x @ A.T @ B.T) * scaling
        lora_out = (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
        return result + lora_out

class LoRAManager:
    """
    Manages the lifecycle and storage of LoRA parameters.
    """
    def __init__(self, config: Dict, device="cuda"):
        self.rank = config.get("r", 8)
        self.alpha = config.get("alpha", 16)
        self.target_modules = config.get("target_modules", ["c_attn", "c_proj", "c_fc"])
        self.device = device
        
        # Store parameters: key -> {'A': Param, 'B': Param}
        self.params: Dict[str, Dict[str, nn.Parameter]] = {}
        
    def get_or_init_params(self, layer_id: int, module_name: str, in_features: int, out_features: int) -> Dict[str, nn.Parameter]:
        key = f"{layer_id}.{module_name}"
        
        if key not in self.params:
            lora_A = torch.zeros((self.rank, in_features), device=self.device)
            nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
            
            lora_B = torch.zeros((out_features, self.rank), device=self.device)
            nn.init.zeros_(lora_B)
            
            self.params[key] = {
                'A': nn.Parameter(lora_A, requires_grad=True),
                'B': nn.Parameter(lora_B, requires_grad=True)
            }
            
        return self.params[key]

    def apply_lora(self, layer_id: int, module: nn.Module, module_name_prefix: str = ""):
        """
        Recursively replaces Linear/Conv1D layers with LoRAWrapper if they match target_modules.
        """
        for name, child in module.named_children():
            full_name = f"{module_name_prefix}.{name}" if module_name_prefix else name
            
            # Check if this is a target module
            is_target = any(name == target or name.endswith(target) for target in self.target_modules)
            
            if isinstance(child, LoRAWrapper) and is_target:
                # Already wrapped, just swap parameters for the new layer
                if isinstance(child.base_layer, nn.Linear):
                    in_features = child.base_layer.in_features
                    out_features = child.base_layer.out_features
                elif Conv1D is not None and isinstance(child.base_layer, Conv1D):
                    in_features = child.base_layer.weight.shape[0]
                    out_features = child.base_layer.weight.shape[1]
                else:
                    # Generic fallback if weight exists
                    in_features = child.base_layer.weight.shape[1] if hasattr(child.base_layer, "weight") else 0
                    out_features = child.base_layer.weight.shape[0] if hasattr(child.base_layer, "weight") else 0

                params = self.get_or_init_params(layer_id, full_name, in_features, out_features)
                child.lora_A = params['A']
                child.lora_B = params['B']
                continue

            in_features = None
            out_features = None
            
            if isinstance(child, nn.Linear) and is_target:
                in_features = child.in_features
                out_features = child.out_features
            elif Conv1D is not None and isinstance(child, Conv1D) and is_target:
                in_features = child.weight.shape[0]
                out_features = child.weight.shape[1]
                
            if in_features is not None and out_features is not None:
                params = self.get_or_init_params(layer_id, full_name, in_features, out_features)
                
                lora_layer = LoRAWrapper(
                    base_layer=child,
                    rank=self.rank,
                    alpha=self.alpha,
                    lora_A=params['A'],
                    lora_B=params['B']
                )
                setattr(module, name, lora_layer)
            else:
                self.apply_lora(layer_id, child, full_name)

    def update_lora_params(self, layer_id: int, module: nn.Module):
        """
        Efficiently updates LoRA parameters for a reused module.
        Uses cached wrapper list if available, otherwise traverses and builds cache.
        """
        if not hasattr(module, "_lora_cache"):
            module._lora_cache = []
            # First time: Traverse and collect wrappers
            # We reuse apply_lora logic but adapted for collection
            self._collect_and_update_wrappers(layer_id, module, "", module._lora_cache)
        else:
            # Fast path: Update parameters from cache
            for wrapper, name, in_f, out_f in module._lora_cache:
                params = self.get_or_init_params(layer_id, name, in_f, out_f)
                wrapper.lora_A = params['A']
                wrapper.lora_B = params['B']

    def _collect_and_update_wrappers(self, layer_id: int, module: nn.Module, prefix: str, cache: List):
        for name, child in module.named_children():
            full_name = f"{prefix}.{name}" if prefix else name
            
            if isinstance(child, LoRAWrapper):
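                # Already wrapped (from previous usage or just now)
                base = child.base_layer
                if hasattr(base, "in_features") and hasattr(base, "out_features"):
                    in_f, out_f = base.in_features, base.out_features
                elif Conv1D is not None and isinstance(base, Conv1D):
                    # HF Conv1D stores its weight as (in_features, out_features)
                    in_f, out_f = base.weight.shape[0], base.weight.shape[1]
                else:
                    # Generic fallback: assume an nn.Linear-style (out, in) weight
                    in_f, out_f = base.weight.shape[1], base.weight.shape[0]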
                
                params = self.get_or_init_params(layer_id, full_name, in_f, out_f)
                child.lora_A = params['A']
                child.lora_B = params['B']
                
                cache.append((child, full_name, in_f, out_f))
                continue
            
            # Check if this is a target module to wrap
            is_target = any(name == target or name.endswith(target) for target in self.target_modules)
            
            if is_target and (isinstance(child, nn.Linear) or (Conv1D is not None and isinstance(child, Conv1D))):
                in_features = child.in_features if isinstance(child, nn.Linear) else child.weight.shape[0]
                out_features = child.out_features if isinstance(child, nn.Linear) else child.weight.shape[1]
                
                params = self.get_or_init_params(layer_id, full_name, in_features, out_features)
                
                lora_layer = LoRAWrapper(
                    base_layer=child,
                    rank=self.rank,
                    alpha=self.alpha,
                    lora_A=params['A'],
                    lora_B=params['B']
                )
                setattr(module, name, lora_layer)
                cache.append((lora_layer, full_name, in_features, out_features))
            else:
                self._collect_and_update_wrappers(layer_id, child, full_name, cache)

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """
        Returns a list of all nn.Parameter objects managed by this manager.
        """
        all_params = []
        for p_dict in self.params.values():
            all_params.append(p_dict['A'])
            all_params.append(p_dict['B'])
        return all_params

    def save_pretrained(self, save_directory: str):
        import os
        os.makedirs(save_directory, exist_ok=True)
        # Filter for only LoRA weights
        state_dict = {}
        for key, p_dict in self.params.items():
            state_dict[f"{key}.lora_A"] = p_dict['A'].data.cpu()
            state_dict[f"{key}.lora_B"] = p_dict['B'].data.cpu()
        
        torch.save(state_dict, os.path.join(save_directory, "adapter_model.bin"))

    def load_pretrained(self, load_directory: str):
        import os
        weight_path = os.path.join(load_directory, "adapter_model.bin")
        if not os.path.exists(weight_path):
            raise FileNotFoundError(f"Adapter weights not found in {load_directory}")
        
        state_dict = torch.load(weight_path, map_location="cpu")
        for full_key, tensor in state_dict.items():
            # full_key is e.g. "1.self_attn.q_proj.lora_A"
            parts = full_key.split(".")
            param_type = parts[-1] # lora_A or lora_B
            key = ".".join(parts[:-1]) # e.g. "1.self_attn.q_proj"
            
            if key not in self.params:
                self.params[key] = {}
            
            p_dict = self.params[key]
            if param_type == "lora_A":
                if 'A' not in p_dict:
                    p_dict['A'] = nn.Parameter(tensor.to(self.device), requires_grad=True)
                else:
                    p_dict['A'].data.copy_(tensor.to(self.device))
            elif param_type == "lora_B":
                if 'B' not in p_dict:
                    p_dict['B'] = nn.Parameter(tensor.to(self.device), requires_grad=True)
                else:
                    p_dict['B'].data.copy_(tensor.to(self.device))
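
`LoRAWrapper.forward` adds a low-rank update to the frozen base projection, `y = W x + (alpha / r) * B (A x)`, with `A` of shape `(r, in_features)` and `B` of shape `(out_features, r)`. Because `get_or_init_params` initializes `B` to zeros, a freshly wrapped layer reproduces the base layer exactly. With the `LemaConfig` defaults (`lora_alpha=32`, `lora_rank=16`) the scaling factor is 2.0. A dependency-free sketch with plain lists (the helper names `matvec` and `lora_forward` are hypothetical):

```python
from __future__ import annotations


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, alpha: float, rank: int) -> list[float]:
    """y = W x + (alpha / rank) * B (A x); A is (rank, in), B is (out, rank)."""
    scaling = alpha / rank
    base = matvec(W, x)              # frozen base projection
    delta = matvec(B, matvec(A, x))  # project down to rank, then back up
    return [y + scaling * d for y, d in zip(base, delta)]


W = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]  # out=2, in=3 frozen weight
A = [[0.5, -0.5, 1.0]]                   # rank 1
B_init = [[0.0], [0.0]]                  # zero init, as in LoRAManager
x = [1.0, 2.0, 3.0]

print(lora_forward(W, A, B_init, x, alpha=2.0, rank=1))        # [7.0, -1.0]
print(lora_forward(W, A, [[1.0], [0.0]], x, alpha=2.0, rank=1))  # [12.0, -1.0]
```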

In [None]:
%%writefile src/lema/core/model.py
import torch
import os
from typing import Optional, Dict, Any, Union
from transformers import AutoConfig

from ..config import LemaConfig
from ..models import get_adapter
from .gbi import GlobalBinaryIndex
from .lora import LoRAManager
from .memory import TripleBufferManager

class LemaModel:
    """
    High-level interface for LEMA Models.
    Wraps all low-level components into a single object.
    """
    def __init__(self, config: LemaConfig):
        self.config = config
        
        # 1. Initialize GBI
        self.gbi = GlobalBinaryIndex(config.gbi_path)
        
        # 2. Get HF config for the adapter
        # Try to load from model path or fallback to a default if not found
        try:
            hf_config_obj = AutoConfig.from_pretrained(config.model_name_or_path)
            hf_config_dict = hf_config_obj.to_dict()
        except Exception:
            # Fall back to the LEMA config dict if AutoConfig cannot load one
            hf_config_dict = config.to_dict()

        # 3. Initialize Adapter
        model_type = config.model_type
        if model_type is None:
            # Auto-detect from path
            path_lower = config.model_name_or_path.lower()
            if "llama" in path_lower or "smollm" in path_lower:
                model_type = "llama"
            elif "gpt2" in path_lower:
                model_type = "gpt2"
            else:
                # Unrecognized name: fall back to the Llama adapter
                model_type = "llama"
        
        self.adapter = get_adapter(model_type, hf_config_dict)
        
        # 4. Initialize LoRA Manager
        self.lora_manager = LoRAManager({
            "r": config.lora_rank,
            "alpha": config.lora_alpha,
            "target_modules": config.lora_target_modules
        }, device=config.device)
        
        # 5. Initialize Memory Manager
        self.memory = TripleBufferManager(
            self.gbi, 
            self.adapter, 
            device=config.device, 
            strategy=config.strategy
        )

    def get_trainer(self, optimizer: torch.optim.Optimizer):
        """Returns a LemaTrainer instance pre-configured with this model's components."""
        from ..engine.trainer import LemaTrainer
        return LemaTrainer(
            config=self.config,
            model_adapter=self.adapter,
            gbi=self.gbi,
            lora_manager=self.lora_manager,
            optimizer=optimizer,
            memory_manager=self.memory
        )

    @classmethod
    def from_pretrained(cls, path: str, **kwargs):
        """Loads a LEMA model and its adapters from a directory."""
        config = LemaConfig.from_pretrained(path, **kwargs)
        model = cls(config)
        
        # Load adapters if they exist
        if os.path.exists(os.path.join(path, "adapter_model.bin")):
            model.lora_manager.load_pretrained(path)
            
        return model

    def save_pretrained(self, save_directory: str):
        """Saves the configuration and LoRA adapters."""
        self.config.save_pretrained(save_directory)
        self.lora_manager.save_pretrained(save_directory)

    def initialize_lora(self):
        """Pre-initializes all LoRA adapters by constructing and releasing each layer once."""
        for layer in self.adapter.get_layer_metadata():
            if layer['type'] == 'block':
                module = self.adapter.construct_layer_module(layer['id'], None, self.lora_manager)
                if hasattr(self.adapter, "release_layer_module"):
                    self.adapter.release_layer_module(module)

    def get_trainable_parameters(self):
        return self.lora_manager.get_trainable_parameters()

    def to(self, device: str):
        self.config.device = device
        self.lora_manager.device = device
        self.memory.device = device
        return self

In [None]:
%%writefile src/lema/core/gbi.py
import torch
from safetensors import safe_open
from typing import Dict, List, Any, Optional
import os

class GlobalBinaryIndex:
    """
    GBI v0.4: Contiguous Block Access.
    Allows fetching a whole layer as a single byte-range if possible,
    but here we focus on providing tensors for contiguous packing.
    """

    def __init__(self, model_path: str):
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        
        self.model_path = model_path
        self.handle = safe_open(self.model_path, framework="pt", device="cpu")
        self.keys = list(self.handle.keys())

    def load_tensors(self, param_names: List[str], device: str = "cpu") -> Dict[str, torch.Tensor]:
        """Reads the named tensors from the mapped file and moves them to `device`."""
        tensors = {}
        for name in param_names:
            tensors[name] = self.handle.get_tensor(name).to(device)
        return tensors

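`GlobalBinaryIndex` relies on `safetensors.safe_open`, which memory-maps the file. The on-disk format is simple: an 8-byte little-endian header length, a JSON header mapping each tensor name to its `dtype`, `shape`, and `data_offsets` (relative to the end of the header), then the raw bytes. The stdlib-only sketch below writes a tiny file in that layout and reads a tensor back through `mmap` as a zero-copy `memoryview`. `MiniBinaryIndex` and `write_tiny_safetensors` are hypothetical names; only `F32` tensors are decoded and the optional `__metadata__` entry is ignored.

```python
import json
import mmap
import os
import struct
import tempfile


def write_tiny_safetensors(path: str, tensors: dict) -> None:
    """Write name -> (shape, float values) as F32 tensors in safetensors layout."""
    header, payload = {}, bytearray()
    for name, (shape, values) in tensors.items():
        data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": list(shape),
                        "data_offsets": [len(payload), len(payload) + len(data)]}
        payload += data
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 header length
        f.write(header_bytes)
        f.write(payload)


class MiniBinaryIndex:
    """Map a safetensors file once; hand out byte ranges without copying."""

    def __init__(self, path: str):
        with open(path, "rb") as f:
            self._mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        (header_len,) = struct.unpack_from("<Q", self._mm, 0)
        self.header = json.loads(self._mm[8:8 + header_len])
        self.header.pop("__metadata__", None)  # optional string metadata
        self._data_start = 8 + header_len

    def keys(self) -> list:
        return list(self.header)

    def raw(self, name: str) -> memoryview:
        begin, end = self.header[name]["data_offsets"]
        start = self._data_start
        return memoryview(self._mm)[start + begin:start + end]  # zero-copy

    def floats(self, name: str) -> list:
        if self.header[name]["dtype"] != "F32":
            raise NotImplementedError("sketch only decodes F32")
        view = self.raw(name)
        return list(struct.unpack(f"<{len(view) // 4}f", view))


path = os.path.join(tempfile.mkdtemp(), "tiny.safetensors")
write_tiny_safetensors(path, {"layers.0.q_proj.weight": ((2, 2), [1.0, 2.0, 3.0, 4.0]),
                              "layers.0.k_proj.weight": ((2,), [0.5, -0.5])})
gbi = MiniBinaryIndex(path)
print(gbi.keys())                            # ['layers.0.q_proj.weight', 'layers.0.k_proj.weight']
print(gbi.floats("layers.0.k_proj.weight"))  # [0.5, -0.5]
```

Only the pages backing the requested byte range are read from disk, which is what lets a layer be fetched without loading the whole checkpoint into RAM.
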
In [None]:
%%writefile src/lema/core/memory.py
import torch
import threading
from typing import Dict, Optional, List, Tuple
from enum import Enum
import gc
from ..config import MemoryStrategy

class TripleBufferManager:
    """
    Unified Memory Manager supporting both Disk-Streaming and RAM-Residency.
    """
    def __init__(self, gbi, adapter, device="cuda", strategy=MemoryStrategy.STREAMING):
        self.gbi = gbi
        self.adapter = adapter
        self.device = device
        self.strategy = strategy
        
        self.is_cuda = self.device.startswith("cuda")
        self.transfer_streams = [torch.cuda.Stream() for _ in range(2)] if self.is_cuda else None
        
        self.layers_meta = self.adapter.get_layer_metadata()
        
        # Calculate max layer size for pre-allocating buffers
        self.max_params = self._calculate_max_params()
        
        # Pre-allocated VRAM slots (Double buffering)
        self.vram_flat_buffers = [
            torch.empty(self.max_params, device=self.device, dtype=torch.float32)
            for _ in range(2)
        ]
        
        # RAM Buffers
        if self.strategy == MemoryStrategy.RESIDENT:
            print(f"LEMA: Initializing RESIDENT strategy (Caching model in RAM)...")
            self.ram_flat_buffers: Dict[int, torch.Tensor] = {}
            self._initialize_full_ram_cache()
        else:
            print(f"LEMA: Initializing STREAMING strategy (Default)...")
            # In streaming mode, we only need 2 RAM slots for the pipeline
            self.ram_flat_buffers: List[torch.Tensor] = [
                torch.empty(self.max_params, device="cpu", dtype=torch.float32).pin_memory() if self.is_cuda else torch.empty(self.max_params, device="cpu", dtype=torch.float32)
                for _ in range(2)
            ]
            self.ram_layer_ids = [-1, -1]

    def _calculate_max_params(self) -> int:
        max_p = 0
        for layer in self.layers_meta:
            names = self.adapter.get_param_names_for_layer(layer['id'])
            current_p = 0
            for name in names:
                meta = self.gbi.handle.get_slice(name)
                current_p += meta.get_shape().numel() if hasattr(meta.get_shape(), 'numel') else torch.Size(meta.get_shape()).numel()
            max_p = max(max_p, current_p)
        return max_p

    def _initialize_full_ram_cache(self):
        """Pre-packs the entire model into RAM (pinned when the target device is CUDA)."""
        for layer in self.layers_meta:
            layer_id = layer['id']
            self._pack_layer_to_ram(layer_id, is_resident=True)

    def _pack_layer_to_ram(self, layer_id: int, slot: int = 0, is_resident: bool = False):
        """Helper to load a layer from disk and pack it into a flat RAM buffer."""
        param_names = self.adapter.get_param_names_for_layer(layer_id)
        weights = self.gbi.load_tensors(param_names, device="cpu")
        
        if is_resident:
            total_el = sum(w.numel() for w in weights.values())
            buf = torch.empty(total_el, device="cpu", dtype=torch.float32)
            # pin_memory() requires CUDA; only pin when the transfers target a GPU
            if self.is_cuda:
                buf = buf.pin_memory()
        else:
            buf = self.ram_flat_buffers[slot]
            
        offset = 0
        for name in param_names:
            w = weights[name]
            numel = w.numel()
            buf[offset : offset + numel].copy_(w.view(-1))
            offset += numel
            
        if is_resident:
            self.ram_flat_buffers[layer_id] = buf
        else:
            self.ram_layer_ids[slot] = layer_id

    def prefetch_to_ram(self, layer_id: int, ram_slot: int):
        """Stage 1 (Streaming only): Load from Disk to RAM Slot."""
        if self.strategy == MemoryStrategy.RESIDENT:
            return # No-op for resident mode
            
        if self.ram_layer_ids[ram_slot] == layer_id:
            return
        
        self._pack_layer_to_ram(layer_id, ram_slot, is_resident=False)

    def async_transfer_to_vram(self, layer_id: int, vram_slot: int, ram_slot: Optional[int] = None):
        """Stage 2: Async transfer to GPU."""
        if self.strategy == MemoryStrategy.RESIDENT:
            src_buf = self.ram_flat_buffers[layer_id]
        else:
            if ram_slot is None:
                raise ValueError("ram_slot must be provided in streaming mode")
            src_buf = self.ram_flat_buffers[ram_slot]
            
        vram_dest = self.vram_flat_buffers[vram_slot]
        
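        if self.is_cuda and self.transfer_streams:
            stream = self.transfer_streams[vram_slot]
            # The destination slot may still be read by kernels queued on the compute
            # stream (the previous layer that used this slot); order the copy after them.
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                vram_dest[:src_buf.numel()].copy_(src_buf, non_blocking=True)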
        else:
            # CPU or Synchronous copy
            vram_dest[:src_buf.numel()].copy_(src_buf)

    def get_vram_flat_buffer(self, vram_slot: int) -> torch.Tensor:
        """Stage 3: Wait for the pending transfer into `vram_slot`, then return the slot."""
        if self.is_cuda and self.transfer_streams:
            self.transfer_streams[vram_slot].synchronize()
        return self.vram_flat_buffers[vram_slot]

    def clear_vram_slot(self, vram_slot: int):
        pass
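
The trainer loop that drives `MemoryManager` is not part of this section, so the sketch below is only an illustration of the double-buffer idea under an assumed schedule: layers are visited in order, layer `i` lives in slot `i % 2` for both RAM and VRAM, and while layer `i` is computed, layer `i + 1` is staged into the other slot. `double_buffer_schedule` and `PipelineStep` are hypothetical helpers, not LEMA APIs. Layer 0 has to be staged before the loop starts, since no earlier step prefetches it.

```python
from typing import List, NamedTuple, Optional


class PipelineStep(NamedTuple):
    compute_layer: int              # layer whose weights are read this step
    compute_slot: int               # slot (assumed same index in RAM and VRAM)
    prefetch_layer: Optional[int]   # layer staged during this step, None at the end
    prefetch_slot: Optional[int]    # slot that receives the staged layer


def double_buffer_schedule(num_layers: int, num_slots: int = 2) -> List[PipelineStep]:
    """Ping-pong assignment: layer i always lives in slot i % num_slots.

    While layer i is computed from its slot, layer i + 1 is staged into the
    other slot, so the disk/PCIe copy can overlap with compute.
    """
    steps = []
    for layer in range(num_layers):
        nxt = layer + 1 if layer + 1 < num_layers else None
        steps.append(PipelineStep(
            compute_layer=layer,
            compute_slot=layer % num_slots,
            prefetch_layer=nxt,
            prefetch_slot=None if nxt is None else nxt % num_slots,
        ))
    return steps


if __name__ == "__main__":
    # Hypothetical mapping onto MemoryManager calls for each step:
    #   prefetch_to_ram(prefetch_layer, prefetch_slot)
    #   async_transfer_to_vram(prefetch_layer, prefetch_slot, ram_slot=prefetch_slot)
    #   get_vram_flat_buffer(compute_slot) before running compute_layer
    for step in double_buffer_schedule(4):
        print(step)
```

With two slots, the slot being refilled at step `i` is exactly the one that step `i - 1` computed from, which is why a transfer into a slot must not start before the previous layer's kernels on that slot have finished.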

In [None]:
%%writefile src/lema/models/base.py
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple, Optional
import torch
import torch.nn as nn

class LemaModelAdapter(ABC):
    """
    Abstract Base Class for LEMA Model Adapters.
    
    This class defines the interface that any model architecture must implement
    to be compatible with the LEMA (Layer-wise Efficient Memory Abstraction) framework.
    It bridges the gap between the raw binary weights managed by LEMA and the 
    PyTorch execution semantics.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

    @abstractmethod
    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        """
        Returns a list of dictionaries, where each dictionary describes a logical "layer"
        or "block" in the model that LEMA should manage as a unit.
        
        Returns:
            List[Dict]: e.g. [{'id': 0, 'name': 'embeddings', 'type': 'embedding'},
                              {'id': 1, 'name': 'layers.0', 'type': 'block', 'block_index': 0}, ...]
        """
        pass

    @abstractmethod
    def construct_layer_module(self, layer_id: int, flat_buffer: Optional[torch.Tensor] = None, lora_manager: Optional[Any] = None) -> nn.Module:
        """
        Constructs (or reuses) a PyTorch nn.Module for the specified layer and fills its
        parameters from a flat weight buffer.
        
        Args:
            layer_id (int): The index of the layer to construct.
            flat_buffer (Optional[torch.Tensor]): 1-D buffer (typically a VRAM slot) holding the
                layer's parameters concatenated in get_param_names_for_layer() order.
                If None, the module is only built (e.g. to initialize LoRA adapters).
            lora_manager (Optional[Any]): The LoRAManager instance to apply adapters.
            
        Returns:
            nn.Module: The executable layer module.
        """
        pass

    @abstractmethod
    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        """
        Executes the forward pass for a single layer.
        
        Args:
            layer_module (nn.Module): The module constructed by construct_layer_module.
            inputs (Any): The input activations (tensor or tuple of tensors).
            **kwargs: Additional arguments (e.g., attention masks, rotary embeddings).
            
        Returns:
            Any: The output activations.
        """
        pass

    @abstractmethod
    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        """
        Returns the list of parameter names (as found in the safetensors file)
        required for the specified layer. The order of this list defines the
        layout of the layer's parameters inside its flat buffer.
        
        Args:
            layer_id (int): Layer index.
            
        Returns:
            List[str]: List of parameter keys.
        """
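        pass

    def release_layer_module(self, module: nn.Module) -> None:
        """
        Hands a module back to the adapter once LEMA is done with it, so that
        implementations can pool and reuse it. The default keeps nothing.
        """
        pass
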
    @property
    @abstractmethod
    def hidden_size(self) -> int:
        """Returns the model's hidden size for buffer allocation."""
        pass
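
The adapter contract hinges on one invariant: the memory manager packs a layer's tensors into one flat buffer in `get_param_names_for_layer()` order, and `construct_layer_module()` slices them back out in the same order using running offsets. The sketch below shows that round trip with plain Python lists standing in for flattened tensors; `pack` and `unpack` are hypothetical helpers, not part of LEMA.

```python
from math import prod
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def pack(names: Sequence[str], tensors: Dict[str, List[float]]) -> List[float]:
    """Concatenate already-flattened parameters in the order given by `names`."""
    flat: List[float] = []
    for name in names:
        flat.extend(tensors[name])
    return flat


def unpack(flat: Sequence[float], names: Sequence[str],
           shapes: Dict[str, Shape]) -> Dict[str, List[float]]:
    """Slice a flat buffer back into parameters using a running offset.

    Trailing elements beyond the last parameter are ignored, just as a slot
    sized for the largest layer is only partially used by smaller layers.
    """
    total = sum(prod(shapes[name]) for name in names)
    if total > len(flat):
        raise ValueError(f"buffer holds {len(flat)} elements, layer needs {total}")
    out: Dict[str, List[float]] = {}
    offset = 0
    for name in names:
        numel = prod(shapes[name])
        out[name] = list(flat[offset:offset + numel])
        offset += numel
    return out


names = ["ln.weight", "proj.weight"]
shapes = {"ln.weight": (2,), "proj.weight": (2, 3)}
params = {"ln.weight": [1.0, 2.0], "proj.weight": [float(x) for x in range(10, 16)]}

slot = pack(names, params) + [0.0] * 3   # padding, like a max_params-sized slot
print(unpack(slot, names, shapes) == params)   # True
```

Unpacking with a different name order silently yields wrong weights rather than an error whenever the total size still fits, which is why the name list must stay identical between packing and unpacking.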


In [None]:
%%writefile src/lema/models/__init__.py
from .base import LemaModelAdapter
from .llama import LlamaAdapter
from .gpt2 import GPT2Adapter

_ADAPTER_REGISTRY = {
    "llama": LlamaAdapter,
    "gpt2": GPT2Adapter
}

def get_adapter(model_type: str, config: dict) -> LemaModelAdapter:
    if model_type not in _ADAPTER_REGISTRY:
        raise ValueError(f"Unknown model type: {model_type}. Available: {list(_ADAPTER_REGISTRY.keys())}")
    return _ADAPTER_REGISTRY[model_type](config)

def register_adapter(model_type: str, adapter_class: type):
    _ADAPTER_REGISTRY[model_type] = adapter_class


In [None]:
%%writefile src/lema/models/gpt2.py
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Config
from typing import List, Dict, Any, Optional

from .base import LemaModelAdapter

class GPT2Adapter(LemaModelAdapter):
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.hf_config = GPT2Config(**config)
        if getattr(self.hf_config, "_attn_implementation", None) is None:
            self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
        self.layer_pool: List[nn.Module] = []
        self.param_mappings: Dict[int, List[tuple]] = {}
        
    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        layers = []
        layers.append({'id': 0, 'name': 'embeddings', 'type': 'embedding'})
        for i in range(self.hf_config.n_layer):
            layers.append({'id': i + 1, 'name': f'h.{i}', 'type': 'block', 'block_index': i})
        layers.append({'id': self.hf_config.n_layer + 1, 'name': 'head', 'type': 'head'})
        return layers

    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        if layer_id == 0:
            return ['transformer.wte.weight', 'transformer.wpe.weight']
        elif 1 <= layer_id <= self.hf_config.n_layer:
            idx = layer_id - 1
            prefix = f"transformer.h.{idx}"
            return [
                f"{prefix}.attn.c_attn.weight", f"{prefix}.attn.c_attn.bias",
                f"{prefix}.attn.c_proj.weight", f"{prefix}.attn.c_proj.bias",
                f"{prefix}.ln_1.weight", f"{prefix}.ln_1.bias",
                f"{prefix}.ln_2.weight", f"{prefix}.ln_2.bias",
                f"{prefix}.mlp.c_fc.weight", f"{prefix}.mlp.c_fc.bias",
                f"{prefix}.mlp.c_proj.weight", f"{prefix}.mlp.c_proj.bias",
            ]
        elif layer_id == self.hf_config.n_layer + 1:
            return ['transformer.ln_f.weight', 'transformer.ln_f.bias', 'lm_head.weight']
        return []

    def construct_layer_module(self, layer_id: int, flat_buffer: Optional[torch.Tensor] = None, lora_manager: Optional[Any] = None) -> nn.Module:
        device = flat_buffer.device if flat_buffer is not None else torch.device("cpu")
        module = None
        for i, m in enumerate(self.layer_pool):
            if layer_id == 0 and isinstance(m, GPT2EmbeddingsLayer):
                module = self.layer_pool.pop(i); break
            elif layer_id == self.hf_config.n_layer + 1 and isinstance(m, GPT2HeadLayer):
                module = self.layer_pool.pop(i); break
            elif 1 <= layer_id <= self.hf_config.n_layer and isinstance(m, GPT2Block):
                module = self.layer_pool.pop(i); break
        
        if module is None:
            if layer_id == 0: module = GPT2EmbeddingsLayer(self.hf_config)
            elif layer_id == self.hf_config.n_layer + 1: module = GPT2HeadLayer(self.hf_config)
            else:
                module = GPT2Block(self.hf_config)
            
            # Initialization only
            module.to(device)

        if lora_manager and 1 <= layer_id <= self.hf_config.n_layer:
            lora_manager.update_lora_params(layer_id, module)

        if id(module) not in self.param_mappings:
            self.param_mappings[id(module)] = self._create_mapping(layer_id, module)

        if flat_buffer is not None:
            mapping = self.param_mappings[id(module)]
            offset = 0
            with torch.no_grad():
                for param, numel, shape in mapping:
                    param.data.copy_(flat_buffer[offset : offset + numel].view(shape), non_blocking=True)
                    offset += numel
            
        return module

    def _create_mapping(self, layer_id: int, module: nn.Module) -> List[tuple]:
        names = self.get_param_names_for_layer(layer_id)
        idx = layer_id - 1
        module_params = dict(module.named_parameters())
        mapping = []
        for full_name in names:
            if layer_id == 0:
                clean_k = "wte.weight" if "wte" in full_name else "wpe.weight"
            elif layer_id == self.hf_config.n_layer + 1:
                if "ln_f" in full_name: clean_k = "ln_f.weight" if "weight" in full_name else "ln_f.bias"
                else: clean_k = "head.weight"
            else:
                prefix = f"transformer.h.{idx}."
                clean_k = full_name[len(prefix):]
                if clean_k not in module_params: clean_k = clean_k.replace(".weight", ".base_layer.weight").replace(".bias", ".base_layer.bias")
            param = module_params[clean_k]
            mapping.append((param, param.numel(), param.shape))
        return mapping

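    def release_layer_module(self, module: nn.Module):
        if len(self.layer_pool) < 5:
            self.layer_pool.append(module)
        else:
            # The cached mapping holds references to this module's parameters; drop it so
            # they can be freed and a new module reusing the same id() is not matched to it.
            self.param_mappings.pop(id(module), None)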

    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        hidden_states = inputs[0] if isinstance(inputs, tuple) else inputs
        if isinstance(layer_module, GPT2Block):
            return layer_module(hidden_states)[0]
        return layer_module(hidden_states)

    @property
    def hidden_size(self) -> int:
        return self.hf_config.n_embd

class GPT2EmbeddingsLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
    def forward(self, input_ids):
        position_ids = torch.arange(0, input_ids.size(-1), dtype=torch.long, device=input_ids.device).unsqueeze(0)
        return self.wte(input_ids) + self.wpe(position_ids)

class GPT2HeadLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    def forward(self, x): return self.head(self.ln_f(x))
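
`GPT2Adapter._create_mapping` translates safetensors keys such as `transformer.h.3.attn.c_attn.weight` into keys of the block's `named_parameters()`, falling back to a `base_layer` name when LoRA has wrapped the layer. The sketch below isolates that translation for block layers only; `to_module_key` is a hypothetical helper, and the `base_layer` naming is assumed from the adapter's own fallback rather than from the LoRA manager, which is not shown here.

```python
from typing import Set


def to_module_key(full_name: str, block_idx: int, module_keys: Set[str]) -> str:
    """Map a safetensors key of GPT-2 block `block_idx` to a named_parameters() key."""
    prefix = f"transformer.h.{block_idx}."
    if not full_name.startswith(prefix):
        raise KeyError(f"{full_name!r} does not belong to block {block_idx}")
    key = full_name[len(prefix):]          # e.g. "attn.c_attn.weight"
    if key in module_keys:
        return key
    # LoRA-wrapped layer: rewrite only the final ".weight"/".bias" component.
    stem, _, leaf = key.rpartition(".")
    wrapped = f"{stem}.base_layer.{leaf}"  # e.g. "attn.c_attn.base_layer.weight"
    if wrapped not in module_keys:
        raise KeyError(f"neither {key!r} nor {wrapped!r} is a module parameter")
    return wrapped


# Hypothetical parameter names of a block whose c_attn has been wrapped by LoRA.
keys = {"ln_1.weight", "ln_1.bias", "attn.c_attn.base_layer.weight", "attn.c_attn.base_layer.bias"}
print(to_module_key("transformer.h.3.ln_1.weight", 3, keys))        # ln_1.weight
print(to_module_key("transformer.h.3.attn.c_attn.bias", 3, keys))   # attn.c_attn.base_layer.bias
```

Using `rpartition` rewrites only the trailing component, which is slightly stricter than chained `str.replace` calls, and the explicit `KeyError` names both candidates when a key is missing.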

In [None]:
%%writefile src/lema/models/llama.py
import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaConfig, LlamaRotaryEmbedding
from typing import List, Dict, Any, Optional

from .base import LemaModelAdapter

class LlamaAdapter(LemaModelAdapter):
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.hf_config = LlamaConfig(**config)
        if getattr(self.hf_config, "_attn_implementation", None) is None:
            self.hf_config._attn_implementation = config.get("attn_implementation", "eager")
        self.rotary_emb = LlamaRotaryEmbedding(self.hf_config)
        self.layer_pool: List[nn.Module] = []
        self.param_mappings: Dict[int, List[tuple]] = {}
        self._max_pool_size = 8
        
    def get_layer_metadata(self) -> List[Dict[str, Any]]:
        layers = []
        layers.append({'id': 0, 'name': 'embeddings', 'type': 'embedding'})
        for i in range(self.hf_config.num_hidden_layers):
            layers.append({'id': i + 1, 'name': f'layers.{i}', 'type': 'block', 'block_index': i})
        layers.append({'id': self.hf_config.num_hidden_layers + 1, 'name': 'head', 'type': 'head'})
        return layers

    def get_param_names_for_layer(self, layer_id: int) -> List[str]:
        if layer_id == 0:
            return ['model.embed_tokens.weight']
        elif 1 <= layer_id <= self.hf_config.num_hidden_layers:
            idx = layer_id - 1
            prefix = f"model.layers.{idx}"
            return [
                f"{prefix}.input_layernorm.weight",
                f"{prefix}.self_attn.q_proj.weight", f"{prefix}.self_attn.k_proj.weight",
                f"{prefix}.self_attn.v_proj.weight", f"{prefix}.self_attn.o_proj.weight",
                f"{prefix}.post_attention_layernorm.weight",
                f"{prefix}.mlp.gate_proj.weight", f"{prefix}.mlp.up_proj.weight",
                f"{prefix}.mlp.down_proj.weight",
            ]
        elif layer_id == self.hf_config.num_hidden_layers + 1:
            return ['model.norm.weight', 'lm_head.weight']
        return []

    def construct_layer_module(self, layer_id: int, flat_buffer: Optional[torch.Tensor] = None, lora_manager: Optional[Any] = None) -> nn.Module:
        if flat_buffer is not None:
            device = flat_buffer.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        module = None
        for i, m in enumerate(self.layer_pool):
            if layer_id == 0 and isinstance(m, LlamaEmbeddingsLayer):
                module = self.layer_pool.pop(i); break
            elif layer_id == self.hf_config.num_hidden_layers + 1 and isinstance(m, LlamaHeadLayer):
                module = self.layer_pool.pop(i); break
            elif 1 <= layer_id <= self.hf_config.num_hidden_layers and isinstance(m, LlamaDecoderLayer):
                module = self.layer_pool.pop(i); break
        
        if module is None:
            if layer_id == 0: module = LlamaEmbeddingsLayer(self.hf_config, None)
            elif layer_id == self.hf_config.num_hidden_layers + 1: module = LlamaHeadLayer(self.hf_config, None)
            else:
                module = LlamaDecoderLayer(self.hf_config, layer_idx=0)
            
            # Initialization only: move to target device
            module.to(device)

        if lora_manager and 1 <= layer_id <= self.hf_config.num_hidden_layers:
            lora_manager.update_lora_params(layer_id, module)

        if id(module) not in self.param_mappings:
            self.param_mappings[id(module)] = self._create_mapping(layer_id, module)

        if flat_buffer is not None:
            mapping = self.param_mappings[id(module)]
            offset = 0
            with torch.no_grad():
                for param, numel, shape in mapping:
                    param.data.copy_(flat_buffer[offset : offset + numel].view(shape), non_blocking=True)
                    offset += numel
            
        if hasattr(module, "layer_idx") and 1 <= layer_id <= self.hf_config.num_hidden_layers:
            module.layer_idx = layer_id - 1
        return module

    def _create_mapping(self, layer_id: int, module: nn.Module) -> List[tuple]:
        names = self.get_param_names_for_layer(layer_id)
        idx = layer_id - 1
        module_params = dict(module.named_parameters())
        mapping = []
        for full_name in names:
            if layer_id == 0: clean_k = "embed_tokens.weight"
            elif layer_id == self.hf_config.num_hidden_layers + 1:
                clean_k = "norm.weight" if "model.norm" in full_name else "lm_head.weight"
            else:
                prefix = f"model.layers.{idx}."
                clean_k = full_name[len(prefix):]
                if clean_k not in module_params: clean_k = clean_k.replace(".weight", ".base_layer.weight")
            param = module_params[clean_k]
            mapping.append((param, param.numel(), param.shape))
        return mapping

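    def release_layer_module(self, module: nn.Module):
        if len(self.layer_pool) < self._max_pool_size:
            self.layer_pool.append(module)
        else:
            # The cached mapping holds references to this module's parameters; drop it so the
            # memory can actually be freed and a future module with the same id() is not
            # matched to a stale mapping.
            self.param_mappings.pop(id(module), None)
            del module
            if torch.cuda.is_available(): torch.cuda.empty_cache()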

    def forward_layer(self, layer_module: nn.Module, inputs: Any, **kwargs) -> Any:
        hidden_states = inputs[0] if isinstance(inputs, tuple) else inputs
        
        if isinstance(layer_module, LlamaDecoderLayer):
            batch_size, seq_len = hidden_states.shape[:2]
            device = hidden_states.device
            position_ids = kwargs.get("position_ids")
            if position_ids is None:
                position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1)
            
            # Compute RoPE (Should be [bs, seq, dim])
            attn = layer_module.self_attn
            try:
                if hasattr(attn, "rotary_emb") and attn.rotary_emb is not None:
                    try: cos, sin = attn.rotary_emb(hidden_states, position_ids)
                    except TypeError: cos, sin = attn.rotary_emb(position_ids)
                else:
                    cos, sin = self.rotary_emb(hidden_states, position_ids)
            except Exception:
                head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
                dummy = torch.zeros(batch_size, seq_len, head_dim, device=device)
                cos, sin = self.rotary_emb(dummy, position_ids)

            # Ensure 3D [bs, seq, dim] for transformers broadcasting
            if cos.ndim == 2:
                cos = cos.unsqueeze(0)
                sin = sin.unsqueeze(0)
            elif cos.ndim == 4:
                cos = cos.squeeze(1)
                sin = sin.squeeze(1)
            
            # If batch size mismatch (e.g. rotary_emb returned [1, seq, dim] but bs > 1)
            if cos.shape[0] != batch_size and cos.shape[0] == 1:
                cos = cos.expand(batch_size, -1, -1)
                sin = sin.expand(batch_size, -1, -1)

            attention_mask = kwargs.get("attention_mask")
            if attention_mask is None:
                mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
                mask = torch.triu(mask, diagonal=1)
                attention_mask = mask.view(1, 1, seq_len, seq_len).expand(batch_size, 1, seq_len, seq_len)

            # Manual decoder forward; self-attention is optionally checkpointed
            residual = hidden_states
            hidden_states = layer_module.input_layernorm(hidden_states)
            
            # Wrap Self-Attention for memory efficiency
            def attn_block(x, mask, pids, cos_sin):
                return layer_module.self_attn(
                    hidden_states=x,
                    attention_mask=mask,
                    position_ids=pids,
                    position_embeddings=cos_sin,
                    past_key_value=kwargs.get("past_key_value"),
                    output_attentions=kwargs.get("output_attentions", False),
                    use_cache=kwargs.get("use_cache", False),
                    cache_position=kwargs.get("cache_position")
                )[0]

            if torch.is_grad_enabled() and kwargs.get("gradient_checkpointing", False):
                from torch.utils.checkpoint import checkpoint
                attn_output = checkpoint(attn_block, hidden_states, attention_mask, position_ids, (cos, sin), use_reentrant=False)
            else:
                attn_output = attn_block(hidden_states, attention_mask, position_ids, (cos, sin))
            
            hidden_states = residual + attn_output
            
            # MLP block
            residual = hidden_states
            hidden_states = layer_module.post_attention_layernorm(hidden_states)
            hidden_states = layer_module.mlp(hidden_states)
            hidden_states = residual + hidden_states
            
            return hidden_states
                
        return layer_module(hidden_states)

    @property
    def hidden_size(self) -> int:
        return self.hf_config.hidden_size

class LlamaEmbeddingsLayer(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
    def forward(self, x): return self.embed_tokens(x)

class LlamaHeadLayer(nn.Module):
    def __init__(self, config, weights):
        super().__init__()
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    def forward(self, x): return self.lm_head(self.norm(x))
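
To see how large each pre-allocated slot gets, the worked example below counts elements per LEMA layer from the parameter list in `LlamaAdapter.get_param_names_for_layer`, using Llama-2-7B's published config (hidden size 4096, intermediate size 11008, vocabulary 32000, 32 layers). It assumes the monolithic file contains exactly these tensors and that attention is full multi-head (true for Llama-2-7B, not for GQA models). `LlamaShape` and `params_per_lema_layer` are illustrative names only.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class LlamaShape:
    hidden_size: int
    intermediate_size: int
    vocab_size: int
    num_hidden_layers: int


def params_per_lema_layer(cfg: LlamaShape) -> Dict[str, int]:
    """Element counts per LEMA layer type, following LlamaAdapter.get_param_names_for_layer.

    Assumes full multi-head attention (k_proj/v_proj are hidden x hidden).
    """
    h, i, v = cfg.hidden_size, cfg.intermediate_size, cfg.vocab_size
    block = 4 * h * h + 3 * h * i + 2 * h   # q/k/v/o_proj + gate/up/down_proj + 2 RMSNorms
    return {
        "embedding": v * h,                  # model.embed_tokens.weight
        "block": block,
        "head": h + v * h,                   # model.norm.weight + lm_head.weight
    }


llama2_7b = LlamaShape(hidden_size=4096, intermediate_size=11008,
                       vocab_size=32000, num_hidden_layers=32)
counts = params_per_lema_layer(llama2_7b)
max_params = max(counts.values())          # the value _calculate_max_params would return
total = counts["embedding"] + llama2_7b.num_hidden_layers * counts["block"] + counts["head"]
slot_mib = max_params * 4 / 2**20          # one float32 slot
print(f"max_params={max_params:,}  total={total:,}  slot={slot_mib:.1f} MiB")
# max_params=202,383,360  total=6,738,415,616  slot=772.0 MiB
```

The decoder block is the largest unit, so each float32 slot is about 772 MiB; in streaming mode the two VRAM slots take roughly 1.5 GiB and the two pinned RAM slots the same amount on the host, before activations, LoRA weights, and optimizer state.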


In [None]:
%%writefile tasks/fine_tune_llama_7b.py
import torch
import os
import time
from transformers import AutoTokenizer
from lema import LemaConfig, LemaModel, MemoryStrategy
from lema.utils.model_utils import prepare_monolithic_safetensors

MODEL_NAME = "NousResearch/Llama-2-7b-hf"
MODEL_PATH = "llama2_7b.safetensors"

TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
    "Explain gravity. Gravity is a natural phenomenon by which all things with mass or energy are brought toward one another.",
    "What is LEMA? LEMA is a framework that virtualizes GPU memory to enable training large models on limited hardware.",
] * 5

def fine_tune_llama_7b_task():
    print("--- STARTING LEMA 7B FINE-TUNING ---")
    
    if not os.path.exists(MODEL_PATH):
        print(f"Preparing {MODEL_PATH}...")
        prepare_monolithic_safetensors(MODEL_NAME, MODEL_PATH)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    
    config = LemaConfig(
        model_name_or_path=MODEL_NAME,
        gbi_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
        learning_rate=5e-5,
        lora_rank=16,
        gradient_checkpointing=True
    )
    
    model = LemaModel(config)
    model.initialize_lora()
    
    optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=config.learning_rate)
    trainer = model.get_trainer(optimizer)
    
    print(f"\nTraining on {len(TRAINING_DATA)} examples...")
    start_time = time.time()
    for text in TRAINING_DATA:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
        input_ids = inputs["input_ids"].to("cuda")
        logits, loss = trainer.train_step(input_ids, labels=input_ids)
            
    print(f"Training completed in {time.time() - start_time:.2f} seconds.")
    trainer.save_checkpoint("output/llama-7b-final")

if __name__ == "__main__":
    fine_tune_llama_7b_task()
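
The loop above tokenizes one text per step, so `labels=input_ids` never contains padding. If examples were batched with `padding=True`, the padded positions would carry the EOS id (because `pad_token` is set to `eos_token`) and would be scored as real targets. A common convention is to set padded label positions to `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`; whether `LemaTrainer.train_step` computes its loss with that `ignore_index` is not shown here, so treat this as an assumption. `mask_padding` is a hypothetical helper.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_padding(input_ids: List[List[int]], attention_mask: List[List[int]],
                 ignore_index: int = IGNORE_INDEX) -> List[List[int]]:
    """Build labels from input_ids, replacing padded positions with ignore_index."""
    return [
        [tok if keep else ignore_index for tok, keep in zip(row_ids, row_mask)]
        for row_ids, row_mask in zip(input_ids, attention_mask)
    ]


# Two sequences right-padded to length 4 with pad id 2 (pad == eos here).
ids = [[1, 15, 27, 2], [1, 9, 2, 2]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
print(mask_padding(ids, mask))   # [[1, 15, 27, -100], [1, 9, -100, -100]]
```

With a batched tokenizer output `enc`, the labels could be built as `torch.tensor(mask_padding(enc["input_ids"].tolist(), enc["attention_mask"].tolist()))`.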

In [None]:
%%writefile tasks/fine_tune_llama_7b_config.py
import torch
import os
import time
from transformers import AutoTokenizer, AutoConfig
from lema.core.gbi import GlobalBinaryIndex
from lema.models.llama import LlamaAdapter
from lema.engine.trainer import LemaTrainer
from lema.core.lora import LoRAManager
from lema.config import LemaConfig, MemoryStrategy

MODEL_NAME = "NousResearch/Llama-2-7b-hf"
MODEL_PATH = "llama2_7b.safetensors"

TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
    "Explain gravity. Gravity is a natural phenomenon by which all things with mass or energy are brought toward one another.",
    "What is LEMA? LEMA is a framework that virtualizes GPU memory to enable training large models on limited hardware.",
] * 10

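def fine_tune_llama_7b_with_config():
    print("--- STARTING LEMA 7B FINE-TUNING (Config Object) ---")
    
    # 0. Make sure the monolithic safetensors file exists
    if not os.path.exists(MODEL_PATH):
        from lema.utils.model_utils import prepare_monolithic_safetensors
        print(f"Preparing {MODEL_PATH}...")
        prepare_monolithic_safetensors(MODEL_NAME, MODEL_PATH)
    
    # 1. Setup Configuration
    config = LemaConfig(
        model_name_or_path=MODEL_NAME,
        gbi_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
        lora_rank=16,
        lora_alpha=32,
        learning_rate=5e-5,
        dtype="float16"
    )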
    
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    
    # HF Config (for model architecture)
    hf_config = AutoConfig.from_pretrained(MODEL_NAME)
    model_config = hf_config.to_dict()
    model_config["attn_implementation"] = config.attn_implementation
    model_config["torch_dtype"] = config.dtype
    
    # Components
    adapter = LlamaAdapter(model_config)
    gbi = GlobalBinaryIndex(config.gbi_path)
    
    # LoRA
    lora_config_dict = {
        "r": config.lora_rank, 
        "alpha": config.lora_alpha, 
        "target_modules": config.lora_target_modules
    }
    lora_manager = LoRAManager(lora_config_dict, device=config.device)
    
    print("Initializing LoRA parameters...")
    for layer in adapter.get_layer_metadata():
        if layer['type'] == 'block':
            module = adapter.construct_layer_module(layer['id'], None, lora_manager)
            adapter.release_layer_module(module)
    torch.cuda.empty_cache()
    
    trainable_params = lora_manager.get_trainable_parameters()
    optimizer = torch.optim.AdamW(trainable_params, lr=config.learning_rate)
    
    # Trainer
    trainer = LemaTrainer(
        config=config,
        model_adapter=adapter, 
        gbi=gbi, 
        lora_manager=lora_manager, 
        optimizer=optimizer
    )
    
    # Training Loop
    print(f"\nTraining on {len(TRAINING_DATA)} examples...")
    for epoch in range(1):
        total_loss = 0
        for i, text in enumerate(TRAINING_DATA):
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
            input_ids = inputs["input_ids"].to(config.device)
            
            logits, loss = trainer.train_step(input_ids, labels=input_ids)
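            # float() accepts both Python floats and 0-dim tensors and avoids keeping autograd graphs alive
            total_loss += float(loss)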
            
            if (i+1) % 10 == 0:
                print(f"Step {i+1}/{len(TRAINING_DATA)} - Current Loss: {loss:.4f}")
                
        print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(TRAINING_DATA):.4f}")
    
    print("Fine-tuning with Config Object completed successfully.")

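    # Cleanup: deletes the converted safetensors file (several GB for a 7B model)
    if os.path.exists(MODEL_PATH):
        os.remove(MODEL_PATH)

if __name__ == "__main__":
    fine_tune_llama_7b_with_config()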


In [None]:
%%writefile tasks/fine_tune_smollm.py
import torch
import os
import time
from transformers import AutoTokenizer
from lema import LemaConfig, LemaModel, MemoryStrategy
from lema.utils.model_utils import prepare_monolithic_safetensors

MODEL_NAME = "HuggingFaceTB/SmolLM2-1.7B"
MODEL_PATH = "smollm2_1.7b.safetensors"

TRAINING_DATA = [
    "What is photosynthesis? Photosynthesis is the process by which plants use sunlight to synthesize nutrients from carbon dioxide and water.",
    "Who was Albert Einstein? Albert Einstein was a theoretical physicist who developed the theory of relativity.",
    "What is the capital of France? The capital of France is Paris.",
] * 10

def fine_tune_smollm():
    print("--- STARTING SMOL-LM FINE-TUNING ---")
    
    if not os.path.exists(MODEL_PATH):
        print(f"Preparing {MODEL_PATH}...")
        prepare_monolithic_safetensors(MODEL_NAME, MODEL_PATH)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    
    config = LemaConfig(
        model_name_or_path=MODEL_NAME,
        gbi_path=MODEL_PATH,
        device="cuda",
        strategy=MemoryStrategy.STREAMING,
        learning_rate=5e-5,
        lora_rank=16
    )
    
    model = LemaModel(config)
    model.initialize_lora()
    
    optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=config.learning_rate)
    trainer = model.get_trainer(optimizer)
    
    start_time = time.time()
    for text in TRAINING_DATA:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
        input_ids = inputs["input_ids"].to("cuda")
        trainer.train_step(input_ids, labels=input_ids)
            
    print(f"Training completed in {time.time() - start_time:.2f} seconds.")
    trainer.save_checkpoint("output/smollm-final")

if __name__ == "__main__":
    fine_tune_smollm()


In [None]:
%%writefile examples/demo_lema.py
import torch
import os
from lema import LemaConfig, LemaModel, MemoryStrategy
from lema.utils.model_utils import break_shared_weights
from transformers import GPT2Config, GPT2LMHeadModel
from safetensors.torch import save_file

def run_demo():
    print("--- LEMA Unified API Demo ---")

    model_dir = "./demo_model"
    gbi_path = os.path.join(model_dir, "model.safetensors")
    
    # 1. Configuration
    config = LemaConfig(
        model_name_or_path=model_dir, 
        model_type="gpt2",
        gbi_path=gbi_path, 
        device="cpu",
        strategy=MemoryStrategy.STREAMING,
        lora_rank=8,
        lora_target_modules=["c_attn"],
        output_dir="./lema_checkpoints",
        save_steps=10
    )

    # 2. Initialize Model & Trainer
    model = LemaModel(config)
    model.initialize_lora()
    
    optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=1e-4)
    trainer = model.get_trainer(optimizer)

    # 3. Execution
    print("Executing training step...")
    input_ids = torch.randint(0, 1000, (1, 16))
    logits, loss = trainer.train_step(input_ids, labels=input_ids)
    
    print(f"Step complete. Loss: {loss:.4f}")
    trainer.save_checkpoint("./lema_checkpoints/final_demo")

if __name__ == "__main__":
    model_dir = "./demo_model"
    if not os.path.exists(os.path.join(model_dir, "model.safetensors")):
        print("Generating dummy model...")
        os.makedirs(model_dir, exist_ok=True)
        dummy_config = GPT2Config(n_layer=2, n_embd=128, n_head=4, vocab_size=1000)
        dummy_model = GPT2LMHeadModel(dummy_config)
        dummy_model = break_shared_weights(dummy_model)
        
        state_dict = {k: v.clone().detach() for k, v in dummy_model.state_dict().items()}
        save_file(state_dict, os.path.join(model_dir, "model.safetensors"))
        dummy_config.save_pretrained(model_dir)
    
    run_demo()

In [None]:
%%writefile examples/benchmark_runner.py
import sys
import os
import resource
import torch
import time
from transformers import GPT2Config, GPT2LMHeadModel

# Add src to path if not installed
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src")))

from lema.core.gbi import GlobalBinaryIndex
from lema.models.gpt2 import GPT2Adapter
from lema.engine.trainer import LemaTrainer
from lema.core.lora import LoRAManager
from lema.config import LemaConfig, MemoryStrategy

def get_peak_rss_mb():
    # ru_maxrss is in kilobytes on Linux (bytes on macOS), so this returns MB on Linux
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024

def run_baseline(model_path):
    print("--- Running BASELINE ---")
    start_ram = get_peak_rss_mb()
    
    # 1. Config (Matching dummy_gpt2 if path matches)
    if "dummy" in model_path:
        config = GPT2Config(vocab_size=1000, n_positions=128, n_embd=64, n_layer=4, n_head=4)
    else:
        config = GPT2Config(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12, n_head=12)
    
    # 2. Instantiate Model
    print("Instantiating Model...")
    model = GPT2LMHeadModel(config)
    
    # 3. Load Weights
    print("Loading Weights...")
    from safetensors.torch import load_file
    state_dict = load_file(model_path)
    model.load_state_dict(state_dict, strict=False)
    
    print(f"Model loaded. RAM: {get_peak_rss_mb():.2f} MB")
    
    # 4. Forward Pass
    print("Forward Pass...")
    input_ids = torch.randint(0, config.vocab_size, (1, 64))
    output = model(input_ids)
    
    # 5. Backward Pass
    print("Backward Pass...")
    loss = output.logits.mean()
    loss.backward()
    
    peak_ram = get_peak_rss_mb()
    print(f"Baseline Peak RSS: {peak_ram:.2f} MB")
    return peak_ram

def run_lema(model_path):
    print("--- Running LEMA ---")
    start_ram = get_peak_rss_mb()
    
    # 1. Config
    if "dummy" in model_path:
        hf_config = {"vocab_size": 1000, "n_positions": 128, "n_embd": 64, "n_layer": 4, "n_head": 4, "attn_implementation": "eager"}
    else:
        hf_config = {"vocab_size": 50257, "n_positions": 1024, "n_embd": 768, "n_layer": 12, "n_head": 12, "attn_implementation": "eager"}
    
    lema_config = LemaConfig(model_name_or_path=model_path, device="cpu", strategy=MemoryStrategy.STREAMING)
    
    # 2. Components
    print("Initializing Components...")
    adapter = GPT2Adapter(hf_config)
    gbi = GlobalBinaryIndex(model_path)
    
    # LoRA
    lora_config = {"r": 8, "alpha": 16, "target_modules": ["c_attn"]}
    lora_manager = LoRAManager(lora_config, device="cpu")
    
    # Init LoRA params
    for layer in adapter.get_layer_metadata():
        if layer['type'] == 'block':
            module = adapter.construct_layer_module(layer['id'], None, lora_manager)
            adapter.release_layer_module(module)
            
    optimizer = torch.optim.AdamW(lora_manager.get_trainable_parameters(), lr=1e-4)
    trainer = LemaTrainer(lema_config, adapter, gbi, lora_manager=lora_manager, optimizer=optimizer)
    
    print(f"Components Ready. RAM: {get_peak_rss_mb():.2f} MB")
    
    # 3. Train Step
    print("Training Step...")
    input_ids = torch.randint(0, hf_config["vocab_size"], (1, 64))
    trainer.train_step(input_ids, labels=input_ids)
    
    peak_ram = get_peak_rss_mb()
    print(f"LEMA Peak RSS: {peak_ram:.2f} MB")
    return peak_ram

if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python examples/benchmark_runner.py [mode] [model_path]")
        sys.exit(1)
        
    mode = sys.argv[1]
    path = sys.argv[2]
    
    if mode == "baseline":
        run_baseline(path)
    elif mode == "lema":
        run_lema(path)
    else:
        print(f"Unknown mode: {mode!r} (expected 'baseline' or 'lema')")
        sys.exit(1)

In [None]:
%%writefile examples/generate_medium_gpt2.py
import torch
from transformers import GPT2Config, GPT2LMHeadModel
from safetensors.torch import save_file
import os

config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=12,
    n_head=12
)

print("Creating GPT-2 Small model...")
model = GPT2LMHeadModel(config)

# Break shared weights
model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())

print("Saving to medium_gpt2.safetensors...")
save_file(model.state_dict(), "medium_gpt2.safetensors")

size_bytes = os.path.getsize("medium_gpt2.safetensors")
print(f"Created medium_gpt2.safetensors: {size_bytes / 1024 / 1024:.2f} MB")


In [None]:
%%writefile examples/generate_dummy_gpt2.py
import torch
from transformers import GPT2Config, GPT2LMHeadModel
from safetensors.torch import save_file

config = GPT2Config(
    vocab_size=1000,
    n_positions=128,
    n_embd=64,
    n_layer=4,
    n_head=4
)
model = GPT2LMHeadModel(config)

# Break shared weights for safetensors
model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())

# Save to safetensors
save_file(model.state_dict(), "dummy_gpt2.safetensors")
print("Created dummy_gpt2.safetensors")


In [None]:
%%writefile examples/kaggle/benchmark_logic.py
import torch
import gc
import os
import transformers
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from safetensors.torch import save_file

# LEMA Imports (installed package path, matching the other examples;
# mixing `src.lema` and `lema` would create two copies of every module and enum)
from lema.core.gbi import GlobalBinaryIndex
from lema.models.llama import LlamaAdapter
from lema.models.gpt2 import GPT2Adapter
from lema.engine.trainer import LemaTrainer
from lema.core.lora import LoRAManager
from lema.config import LemaConfig, MemoryStrategy

print(f"Using Transformers version: {transformers.__version__}")

# --- MODELS TO TEST ---
MODELS = [
    {
        "name": "GPT2 (Small)",
        "hf_id": "gpt2",
        "path": "gpt2.safetensors",
        "type": "gpt2"
    },
    {
        "name": "TinyLlama 1.1B",
        "hf_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "path": "tinyllama_1b.safetensors",
        "type": "llama"
    },
    {
        "name": "SmolLM2 1.7B",
        "hf_id": "HuggingFaceTB/SmolLM2-1.7B",
        "path": "smollm2_1.7b.safetensors",
        "type": "llama"
    },
    {
        "name": "Llama-2 7B",
        "hf_id": "NousResearch/Llama-2-7b-hf",
        "path": "llama2_7b.safetensors",
        "type": "llama"
    }
]

def download_and_convert(model_info):
    print(f"\n--- Preparing {model_info['name']} ---")
    if os.path.exists(model_info['path']):
        print(f"{model_info['path']} already exists.")
        return

    print(f"Downloading {model_info['hf_id']}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_info['hf_id'], 
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="cpu"
    )
    
    # Break shared weights if necessary
    if hasattr(model, "lm_head") and hasattr(model, "model") and hasattr(model.model, "embed_tokens"):
        # Llama-style shared weights
        if model.lm_head.weight.data_ptr() == model.model.embed_tokens.weight.data_ptr():
            model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())
    elif hasattr(model, "lm_head") and hasattr(model, "transformer") and hasattr(model.transformer, "wte"):
        # GPT2 shared weights
        if model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr():
            model.lm_head.weight = torch.nn.Parameter(model.lm_head.weight.clone())

    print(f"Saving to {model_info['path']}...")
    save_file(model.state_dict(), model_info['path'])
    del model
    gc.collect()

from peft import get_peft_model, LoraConfig

def run_peft_baseline(model_info):
    print(f"\n>>> TESTING STANDARD PEFT ON: {model_info['name']} <<<")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    download_and_convert(model_info)
    
    try:
        # Load Model in FP16
        model = AutoModelForCausalLM.from_pretrained(
            model_info['hf_id'],
            torch_dtype=torch.float16,
            device_map="cuda"
        )
        
        # Configure LoRA
        target_modules = ["c_attn", "c_proj", "c_fc"] if model_info['type'] == "gpt2" else \
                         ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                         
        peft_config = LoraConfig(
            r=16, lora_alpha=32,
            target_modules=target_modules,
            lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
        )
        
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()  # prints the summary itself and returns None
        
        # Dummy Train Step
        input_ids = torch.randint(0, 1000, (1, 128)).cuda()
        output = model(input_ids, labels=input_ids)
        loss = output.loss
        loss.backward()
        
        peak_vram = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Standard PEFT Peak VRAM: {peak_vram:.2f} GB")
        
        del model
        torch.cuda.empty_cache()
        return peak_vram
        
    except Exception as e:
        print(f"PEFT Baseline Failed: {e}")
        return float('inf')

def run_test(model_info):
    print(f"\n>>> TESTING LEMA ON: {model_info['name']} <<<")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    
    download_and_convert(model_info)
    
    try:
        # 1. Config
        hf_config = AutoConfig.from_pretrained(model_info['hf_id'])
        hf_config_dict = hf_config.to_dict()
        hf_config_dict["attn_implementation"] = "eager"
        hf_config_dict["torch_dtype"] = "float16"
        
        # 2. Components
        if model_info['type'] == "llama":
            adapter = LlamaAdapter(hf_config_dict)
        elif model_info['type'] == "gpt2":
            adapter = GPT2Adapter(hf_config_dict)
        else:
            raise ValueError(f"Unknown type: {model_info['type']}")
            
        gbi = GlobalBinaryIndex(model_info['path'])
        
        # LEMA Config
        lema_config = LemaConfig(
            model_name_or_path=model_info['path'],
            device="cuda",
            strategy=MemoryStrategy.STREAMING,
            learning_rate=1e-4
        )
        
        # LoRA - ALL LINEAR LAYERS
        target_modules = ["c_attn", "c_proj", "c_fc"] if model_info['type'] == "gpt2" else \
                         ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
                         
        lora_config = {
            "r": 16, "alpha": 32, 
            "target_modules": target_modules
        }
        lora_manager = LoRAManager(lora_config, device="cuda")
        
        # Pre-create LoRA parameters for every block so the optimizer receives all of them
        print("Initializing LoRA parameters...")
        for layer in adapter.get_layer_metadata():
            if layer['type'] == 'block':
                module = adapter.construct_layer_module(layer['id'], None, lora_manager)
                adapter.release_layer_module(module)
        
        torch.cuda.empty_cache()
        
        trainable_params = lora_manager.get_trainable_parameters()
        print(f"Trainable Tensors: {len(trainable_params)}")
        optimizer = torch.optim.AdamW(trainable_params, lr=lema_config.learning_rate)
        
        trainer = LemaTrainer(
            config=lema_config,
            model_adapter=adapter, 
            gbi=gbi, 
            lora_manager=lora_manager, 
            optimizer=optimizer
        )
        
        # 3. Execution
        print("Executing Train Step...")
        # Create dummy inputs based on vocab size
        vocab_size = hf_config.vocab_size
        input_ids = torch.randint(0, vocab_size, (1, 128)).cuda()
        
        logits, loss = trainer.train_step(input_ids, labels=input_ids)
        
        print(f"Loss: {loss:.4f}")
        peak_vram = torch.cuda.max_memory_allocated() / 1024**3
        print(f"LEMA Peak VRAM: {peak_vram:.2f} GB")
        print(f"RESULT: {model_info['name']} -> SUCCESS")
        return peak_vram
        
    except Exception as e:
        print(f"RESULT: {model_info['name']} -> FAILED")
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return float('inf')
    finally:
        # Cleanup model file to save disk space on Kaggle
        if os.path.exists(model_info['path']):
            os.remove(model_info['path'])
        gc.collect()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    results = {}
    for model in MODELS:
        peft_vram = run_peft_baseline(model)
        lema_vram = run_test(model)
        results[model['name']] = {"PEFT": peft_vram, "LEMA": lema_vram}
    
    print("\n=== FINAL RESULTS (VRAM in GB) ===")
    print(f"{'Model':<20} | {'PEFT (Baseline)':<15} | {'LEMA (Ours)':<15} | {'Savings':<10}")
    print("-" * 65)
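    for name, data in results.items():
        peft = data["PEFT"]
        lema = data["LEMA"]
        # Savings are only meaningful when both runs succeeded (failed runs return inf)
        if 0 < peft < float('inf') and lema < float('inf'):
            savings = f"{(1 - lema / peft) * 100:.1f}%"
        else:
            savings = "N/A"
        print(f"{name:<20} | {peft:<15.2f} | {lema:<15.2f} | {savings:<10}")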

In [None]:
%%writefile examples/kaggle/speed_benchmark.py
import torch
import gc
import os
import time
import transformers
from transformers import AutoModelForCausalLM, AutoConfig
from peft import get_peft_model, LoraConfig
from safetensors.torch import save_file

# LEMA Unified Imports
from lema import LemaConfig, LemaModel, MemoryStrategy
from lema.utils.model_utils import prepare_monolithic_safetensors

print(f"Using Transformers version: {transformers.__version__}")

MODELS = [
    {"name": "TinyLlama 1.1B", "hf_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "path": "tinyllama_1b.safetensors", "type": "llama"},
    {"name": "Llama-2 7B", "hf_id": "NousResearch/Llama-2-7b-hf", "path": "llama2_7b.safetensors", "type": "llama"}
]

NUM_STEPS = 20 

def benchmark_peft_speed(model_info):
    print(f"\n>>> BENCHMARKING PEFT SPEED: {model_info['name']} <<<")
    torch.cuda.empty_cache()
    try:
        model = AutoModelForCausalLM.from_pretrained(model_info['hf_id'], torch_dtype=torch.float16, device_map="cuda")
        model.gradient_checkpointing_enable()
        peft_config = LoraConfig(
            r=16, lora_alpha=32, 
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, peft_config)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        input_ids = torch.randint(0, 1000, (1, 512)).cuda()
        
        # Warmup
        model(input_ids, labels=input_ids).loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        torch.cuda.synchronize()
        
        start_time = time.time()
        for _ in range(NUM_STEPS):
            model(input_ids, labels=input_ids).loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        torch.cuda.synchronize()
        
        avg_time = (time.time() - start_time) / NUM_STEPS
        print(f"PEFT Avg Time/Step: {avg_time:.4f}s")
        return avg_time
    except Exception as e:
        print(f"PEFT Benchmark Failed: {e}")
        return float('inf')

def benchmark_lema_speed(model_info):
    print(f"\n>>> BENCHMARKING LEMA SPEED: {model_info['name']} <<<")
    torch.cuda.empty_cache()
    
    if not os.path.exists(model_info['path']):
        print(f"Preparing {model_info['path']}...")
        prepare_monolithic_safetensors(model_info['hf_id'], model_info['path'])
    
    use_gc = "7b" in model_info['hf_id'].lower()
    
    try:
        config = LemaConfig(
            model_name_or_path=model_info['hf_id'],
            gbi_path=model_info['path'],
            device="cuda",
            strategy=MemoryStrategy.STREAMING,
            learning_rate=1e-4,
            gradient_checkpointing=use_gc,
            lora_rank=16
        )
        
        model = LemaModel(config)
        model.initialize_lora()
        
        optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=config.learning_rate)
        trainer = model.get_trainer(optimizer)
        
        input_ids = torch.randint(0, 1000, (1, 512)).cuda()
        trainer.train_step(input_ids, labels=input_ids) # Warmup
        torch.cuda.synchronize()
        
        start_time = time.time()
        for _ in range(NUM_STEPS):
            trainer.train_step(input_ids, labels=input_ids)
        torch.cuda.synchronize()
        
        avg_time = (time.time() - start_time) / NUM_STEPS
        print(f"LEMA Avg Time/Step: {avg_time:.4f}s")
        return avg_time
    except Exception as e:
        print(f"LEMA Benchmark Failed: {e}")
        return float('inf')
    finally:
        if os.path.exists(model_info['path']): os.remove(model_info['path'])

if __name__ == "__main__":
    results = {}
    for model in MODELS:
        peft = benchmark_peft_speed(model)
        lema = benchmark_lema_speed(model)
        results[model['name']] = {"PEFT": peft, "LEMA": lema}
    
    print("\n=== UNIFIED SPEED BENCHMARK RESULTS ===")
    for name, data in results.items():
        p, l = data["PEFT"], data["LEMA"]
        overhead = (l / p) if p > 0 and p != float('inf') else float('inf')
        print(f"{name}: PEFT={p:.4f}s, LEMA={l:.4f}s, Overhead={overhead:.2f}x")


### docs/ARCHITECTURE.md

# LEMA Architecture

This document describes the internal mechanics of the Layer-wise Efficient Memory Abstraction (LEMA) framework.

## The Problem: The VRAM Wall
Standard fine-tuning (even with PEFT/LoRA) requires all model weights to be resident in VRAM. For a Llama-2 7B model in FP16, the weights alone take ~14GB. Adding optimizer states and activations quickly exceeds the capacity of 16GB GPUs.
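As a rough check of the ~14GB figure, the sketch below counts Llama-2 7B parameters from its public shape and multiplies by 2 bytes per FP16 value. The shape values (hidden size 4096, MLP size 11008, 32 layers, 32,000-token vocabulary, full multi-head attention, untied input/output embeddings) are assumptions taken from the published Llama-2 7B config, not from LEMA itself.

```python
def llama_param_count(hidden: int, intermediate: int, n_layers: int, vocab: int) -> int:
    """Parameter count of a Llama-style decoder with full multi-head attention."""
    attn = 4 * hidden * hidden            # q_proj, k_proj, v_proj, o_proj
    mlp = 3 * hidden * intermediate       # gate_proj, up_proj, down_proj
    norms = 2 * hidden                    # input + post-attention RMSNorm weights
    per_layer = attn + mlp + norms
    embeddings = 2 * vocab * hidden       # embed_tokens + untied lm_head
    final_norm = hidden
    return n_layers * per_layer + embeddings + final_norm


params = llama_param_count(hidden=4096, intermediate=11008, n_layers=32, vocab=32000)
fp16_bytes = params * 2  # 2 bytes per FP16 weight
print(f"{params:,} params -> {fp16_bytes / 1e9:.2f} GB in FP16")
```

This covers weights only (about 13.5 GB); optimizer states, gradients for the adapters, and activations come on top.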

## The LEMA Solution: Virtualization
LEMA treats GPU VRAM not as static storage for the model, but as a **dynamic cache** for execution.

### 1. The Triple-Buffer Strategy
LEMA hides data transfer latency by pipelining movements across three memory tiers:

1.  **Storage (NVMe)**: Weights reside in `.safetensors` files and are accessed via `mmap` (zero-copy).
2.  **System RAM (Pinned)**: Acts as a prefetch buffer. Pinned (page-locked) memory enables fast, asynchronous Host-to-Device (H2D) transfers.
3.  **VRAM (Execution)**: Divided into two slots, Active and Prefetch.

### 2. The Execution Pipeline
While the GPU is computing Layer $N$ in Slot A, LEMA is:
-   Asynchronously transferring Layer $N+1$ from RAM to Slot B (VRAM).
-   Loading Layer $N+2$ from Disk to RAM (Staging).

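When Layer $N$ finishes, the two slots swap roles (a pointer swap, with no data copied): Slot B becomes active for Layer $N+1$, and Slot A is free to receive Layer $N+2$.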
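The schedule above can be made concrete with a small pure-Python model that records, for each forward step, which layer is computing, which is in flight to VRAM, and which is being staged from disk. It models only the bookkeeping (no real I/O or CUDA streams), and the alternating slot names are an assumption about how the swap is tracked.

```python
from typing import List, Optional, Tuple


def triple_buffer_schedule(num_layers: int) -> List[Tuple[str, int, Optional[int], Optional[int]]]:
    """For each forward step, return (active VRAM slot, layer computed,
    layer copied RAM -> prefetch slot, layer staged Disk -> pinned RAM)."""
    schedule = []
    for n in range(num_layers):
        active_slot = "A" if n % 2 == 0 else "B"         # slots swap roles every step
        h2d = n + 1 if n + 1 < num_layers else None      # next layer into the other slot
        staged = n + 2 if n + 2 < num_layers else None   # the layer after that into RAM
        schedule.append((active_slot, n, h2d, staged))
    return schedule


for slot, layer, h2d, staged in triple_buffer_schedule(4):
    print(f"compute L{layer} in slot {slot} | RAM->VRAM: {h2d} | disk->RAM: {staged}")
```

Near the end of the model the later tiers go idle (`None`), which is why very shallow models see little benefit from the pipeline.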

### 3. The LEMA-Loop (Training Logic)

#### Forward Pass
-   Model is executed layer-by-layer.
-   Only "Boundary Activations" (the output of each layer) are stored in VRAM.
-   Intermediate activations are discarded.

#### Backward Pass
-   LEMA traverses the layers in reverse.
-   For each layer:
    1.  The weights are swapped back into VRAM.
    2.  The layer's forward pass is **re-executed** (Segmented Gradient Checkpointing) using the stored boundary activations.
    3.  Gradients are calculated for the LoRA adapters.
    4.  Optimizer states for those specific adapters are updated.
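The recompute-in-reverse loop can be illustrated with toy scalar layers, y = tanh(w * x), in plain Python: the forward pass keeps only boundary activations, and the backward pass re-executes each layer from its stored input before computing that layer's weight gradient. This is a sketch of the checkpointing idea only; the toy layer and the loss (the final output itself) are hypothetical, whereas real LEMA layers are PyTorch modules with LoRA adapters and autograd.

```python
import math
from typing import List


def layer(w: float, x: float) -> float:
    """One toy 'layer': y = tanh(w * x). The pre-activation w * x is its internal activation."""
    return math.tanh(w * x)


def forward_keep_boundaries(weights: List[float], x0: float) -> List[float]:
    """Run layer by layer, keeping only each layer's output (the boundary activations)."""
    boundaries = [x0]
    for w in weights:
        boundaries.append(layer(w, boundaries[-1]))  # internal w * x is not kept
    return boundaries


def backward_with_recompute(weights: List[float], boundaries: List[float]) -> List[float]:
    """Walk layers in reverse, re-running each layer's forward pass from its stored input."""
    grads = [0.0] * len(weights)
    upstream = 1.0  # d(loss)/d(final output), with loss = final output
    for i in reversed(range(len(weights))):
        x_in = boundaries[i]                # boundary activation entering layer i
        y = layer(weights[i], x_in)         # re-execute the forward pass for layer i
        local = upstream * (1.0 - y * y)    # d tanh(z)/dz = 1 - tanh(z)^2
        grads[i] = local * x_in             # dz/dw = x_in
        upstream = local * weights[i]       # dz/dx_in = w, passed on to layer i-1
    return grads


weights = [0.5, -1.2, 0.8]
bounds = forward_keep_boundaries(weights, 0.7)
print(backward_with_recompute(weights, bounds))
```

Only `len(weights) + 1` scalars are stored between the passes; everything inside a layer is rebuilt on demand, which is the memory-for-compute trade the LEMA-Loop makes per layer.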

### 4. GBI (Global Binary Index)
LEMA uses a specialized indexer to bypass standard PyTorch/Pickle deserialization. By reading the `.safetensors` header, LEMA knows the exact byte offsets for every parameter, allowing it to "slice" the file and load only the parameters needed for the current layer module.
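To make the header-indexing step concrete, the sketch below reads a `.safetensors` header with only the Python standard library and uses `mmap` to slice one tensor's bytes. The file layout it relies on is the public safetensors format: an 8-byte little-endian header length, a JSON header mapping each tensor name to `dtype`, `shape`, and `data_offsets` (relative to the start of the data section), then the raw tensor bytes. The helper names are hypothetical; this is not LEMA's `GlobalBinaryIndex` implementation, and the `model.layers.N.` prefix assumes Llama-style parameter names.

```python
import json
import mmap
import struct
from typing import Dict, Tuple

Entry = Tuple[int, int, str, list]  # (absolute byte offset, size in bytes, dtype, shape)


def read_safetensors_index(path: str) -> Dict[str, Entry]:
    """Map every tensor name to its absolute byte range without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64 header size
        header = json.loads(f.read(header_len))
    data_start = 8 + header_len                         # tensor bytes follow the header
    index = {}
    for name, info in header.items():
        if name == "__metadata__":                      # optional string metadata, not a tensor
            continue
        begin, end = info["data_offsets"]               # relative to data_start
        index[name] = (data_start + begin, end - begin, info["dtype"], info["shape"])
    return index


def layer_tensors(index: Dict[str, Entry], layer_prefix: str) -> Dict[str, Entry]:
    """Select only the tensors belonging to one layer, e.g. 'model.layers.3.'."""
    return {k: v for k, v in index.items() if k.startswith(layer_prefix)}


def read_tensor_bytes(path: str, offset: int, size: int) -> bytes:
    """Slice one tensor out of the memory-mapped file; only those pages are touched."""
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        return mm[offset:offset + size]
```

In a real loader the returned bytes would be reinterpreted according to `dtype` and `shape` (for example with `torch.frombuffer`) instead of being copied around as `bytes`.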

## Performance Trade-offs
-   **VRAM Efficiency**: ~50-70% reduction for 7B+ models.
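-   **Compute Overhead**: Roughly 1.2x - 1.8x slowdown compared to fully resident training, depending on PCIe bandwidth and disk speed. The v0.7 benchmarks on a Tesla P100 measured a higher 3.1x overhead for TinyLlama 1.1B (see `BENCHMARK_RESULTS.md`).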
-   **System RAM**: Requires space equal to the model size (or less if using aggressive disk streaming).


### docs/LEMA Framework Proposal.md

# **LEMA: Layer-wise Efficient Memory Abstraction**

**Architectural Specification for VRAM-Efficient Model Fine-Tuning**

## **1\. Executive Summary**

LEMA is a specialized framework designed to facilitate the fine-tuning of Large Language Models (LLMs) on hardware where model size exceeds available VRAM. Unlike standard frameworks that require the full model to be resident in GPU memory, LEMA treats the model as a collection of discrete, addressable binary segments. By implementing a virtualized memory abstraction layer, LEMA performs asynchronous pre-fetching of layers into VRAM, effectively trading PCIe bandwidth for memory headroom.

## **2\. Core Concepts**

### **2.1 Global Binary Index (GBI)**

Standard model loading (e.g., PyTorch .bin or .pt) involves full deserialization into System RAM. LEMA uses a **Global Binary Index (GBI)**.

* **Zero-Copy Mapping:** Uses mmap to map the model file (preferably in .safetensors format) into the process's virtual address space.  
* **Header Indexing:** A JSON/Binary header stores the (offset, size, dtype, shape) for every tensor, allowing O(1) access to specific layer weights without scanning the file.

### **2.2 Layer-wise Execution (Patchwork)**

Instead of a monolithic model.forward(), LEMA decomposes the computational graph into a sequence of isolated layer blocks.

* **Weight Swapping:** Only the current layer $i$ and the next layer $i+1$ occupy VRAM.  
* **Persistence:** Model weights remain frozen in System RAM/Disk; only LoRA adapters are maintained in active memory.

## **3\. The Memory Pipeline (The Triple-Buffer Strategy)**

LEMA orchestrates data movement across three tiers to hide the latency of PCIe transfers.

| Tier | Residency | Role |
| :---- | :---- | :---- |
| **Storage (NVMe)** | Global Binary File | The source of truth. Accessed via mmap. |
| **System RAM** | Pinned Memory Buffers | The staging area for the next 2-3 layers. |
| **VRAM** | Active Slot / Prefetch Slot | The execution zone. |

### **Asynchronous Prefetching Logic**

1. **Compute Stream:** GPU calculates the forward pass for Layer $i$.  
2. **Transfer Stream:** Simultaneously, the CPU pushes Layer $i+1$ from Pinned RAM to a reserved VRAM buffer.  
3. **Synchronization:** When Layer $i$ finishes, the pointers are swapped. Layer $i$ is discarded (or moved to RAM if activations are needed), and Layer $i+1$ begins immediate execution.
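A CPU-only sketch of the same overlap is shown below, using a Python thread as the "transfer stream" and a one-slot `queue.Queue` as the reserved prefetch buffer. `load` and `compute` are hypothetical stand-ins for the H2D copy and the layer's forward pass; real overlap on a GPU would use `torch.cuda.Stream` and events rather than Python threads.

```python
import queue
import threading
import time
from typing import Callable, List


def run_with_prefetch(num_layers: int,
                      load: Callable[[int], str],
                      compute: Callable[[str], str]) -> List[str]:
    """Overlap loading upcoming layers with computing the current one."""
    prefetch: "queue.Queue[str]" = queue.Queue(maxsize=1)  # the single reserved prefetch slot

    def loader() -> None:
        for i in range(num_layers):
            prefetch.put(load(i))  # blocks while a finished layer is still waiting in the slot

    t = threading.Thread(target=loader, daemon=True)
    t.start()
    outputs = []
    for _ in range(num_layers):
        weights = prefetch.get()          # "pointer swap": take the layer that is ready
        outputs.append(compute(weights))  # meanwhile the loader works on the next layer
    t.join()
    return outputs


def slow_load(i: int) -> str:
    time.sleep(0.05)  # stands in for a RAM -> VRAM copy
    return f"W{i}"


def slow_compute(w: str) -> str:
    time.sleep(0.05)  # stands in for the layer's forward pass
    return f"out({w})"


start = time.perf_counter()
print(run_with_prefetch(6, slow_load, slow_compute))
# With overlap, elapsed is roughly (6 + 1) * 0.05s instead of the serial 2 * 6 * 0.05s
print(f"elapsed: {time.perf_counter() - start:.2f}s")
```

The bounded queue is what keeps memory flat: the loader can never run more than one finished layer ahead of the compute loop.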

## **4\. Training Mechanics: The "LEMA-Loop"**

### **4.1 Forward Pass (Activation Management)**

To save VRAM, LEMA implements **Segmented Gradient Checkpointing**:

* Instead of storing activations for all 32 layers, LEMA stores only the "Boundary Activations" (the output of each chunk).  
* Inner-layer activations are discarded and re-computed during the backward pass.

### **4.2 Backward Pass (The Reverse Swap)**

1. Load Layer $i$ weights \+ LoRA adapters.  
2. Retrieve the Boundary Activation for Layer $i-1$ (the input to Layer $i$).  
3. Re-run the forward pass for Layer $i$ to get local activations.  
4. Calculate gradients for Layer $i$ LoRA adapters.  
5. Offload Layer $i$ weights; move to Layer $i-1$.

### **4.3 Optimizer Offloading**

The **Adam Optimizer states** (Momentum and Variance) are stored in System RAM. During the weight update step, LEMA pulls only the specific optimizer slice for the current layer's adapters into VRAM, performs the update, and pushes it back.
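For scale, the arithmetic below estimates the Adam state involved when LoRA r=16 targets all seven linear projections of Llama-2 7B (the configuration listed in the benchmark results). Adam keeps two FP32 moment tensors (momentum and variance) per trainable parameter; the layer shapes are assumptions taken from the public Llama-2 7B config.

```python
def lora_params_per_layer(hidden: int, intermediate: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) to each targeted linear layer."""
    shapes = [(hidden, hidden)] * 4          # q_proj, k_proj, v_proj, o_proj as (d_in, d_out)
    shapes += [(hidden, intermediate)] * 2   # gate_proj, up_proj
    shapes += [(intermediate, hidden)]       # down_proj
    return sum(r * (d_in + d_out) for d_in, d_out in shapes)


per_layer = lora_params_per_layer(hidden=4096, intermediate=11008, r=16)
total = 32 * per_layer          # 32 decoder layers
adam_bytes = total * 2 * 4      # exp_avg + exp_avg_sq, 4 bytes (FP32) each
print(f"LoRA params: {total:,}; Adam state: {adam_bytes / 2**20:.0f} MiB")
```

That is roughly 40M trainable parameters (about 0.6% of the base model) and about 305 MiB of optimizer state, small enough that moving only one layer's slice at a time keeps the VRAM cost negligible.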

## **5\. Technical Implementation Stack**

* **Host Language:** Python (Orchestration) / C++ (High-speed Memory Management).  
* **Backend:** CUDA / LibTorch.  
* **File Format:** safetensors (Native support for zero-copy mmap).  
* **Memory Management:**  
  * torch.cuda.Stream for non-blocking transfers.  
  * tensor.pin\_memory() to ensure fast Host-to-Device (H2D) throughput.

## **6\. Comparison with Existing Solutions**

| Metric | Standard LoRA | LEMA |
| :---- | :---- | :---- |
| **VRAM Requirement** | Full Model \+ Gradients | \~2 Layers \+ Buffers |
| **System RAM Usage** | Model Size | Model Size (via mmap/Page Cache) |
| **Speed** | 100% (Baseline) | 30-70% (PCIe Latency) |
| **Model Scalability** | Limited by GPU VRAM | Limited by Disk Space |

### docs/USER_GUIDE.md

# LEMA User Guide

This guide covers common workflows for fine-tuning Large Language Models using LEMA on memory-constrained hardware.

## 1. Preparing Your Model

LEMA requires model weights in `.safetensors` format. If your model is in PyTorch `.bin` format, you should convert it first.

### Shared Weights Warning
When saving to `.safetensors`, ensure that shared weights (like `lm_head.weight` and `embed_tokens.weight`) are cloned into distinct tensors, as `safetensors` does not support memory sharing.

```python
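# Quick conversion snippet
import torch
from transformers import AutoModelForCausalLM
from safetensors.torch import save_file

model = AutoModelForCausalLM.from_pretrained("your-model-id")

# Break shared weights: tied input/output embeddings share one storage, which
# save_file rejects. get_input/output_embeddings() cover both Llama-style
# (embed_tokens/lm_head) and GPT-2-style (wte/lm_head) models.
in_emb = model.get_input_embeddings()
out_emb = model.get_output_embeddings()
if out_emb is not None and out_emb.weight.data_ptr() == in_emb.weight.data_ptr():
    out_emb.weight = torch.nn.Parameter(out_emb.weight.clone())

# save_file also requires contiguous tensors
state_dict = {k: v.contiguous() for k, v in model.state_dict().items()}
save_file(state_dict, "model.safetensors")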
```

## 2. Fine-Tuning Workflow

The standard workflow involves four steps: Configuration, Initialization, Training, and Saving.

### Basic Example

```python
import torch
from lema import LemaConfig, LemaModel, LemaTrainer

# 1. Setup Config
config = LemaConfig(
    model_name_or_path="NousResearch/Llama-2-7b-hf",
    gbi_path="llama2_7b.safetensors",
    lora_rank=16,
    gradient_checkpointing=True
)

# 2. Initialize
model = LemaModel(config)
model.initialize_lora() # Required before get_trainable_parameters() for new adapters

# 3. Training
optimizer = torch.optim.AdamW(model.get_trainable_parameters(), lr=1e-4)
trainer = model.get_trainer(optimizer)

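# `dataloader` is your own iterable of batches with "input_ids" and "labels" tensors
for batch in dataloader:
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()
    logits, loss = trainer.train_step(input_ids, labels=labels)
    print(f"Loss: {loss:.4f}")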

# 4. Save
trainer.save_checkpoint("checkpoints/lema-llama-7b-v1")
```

## 3. Architecture Specifics

When using LEMA, ensure your `lora_target_modules` in `LemaConfig` match your model's architecture:
- **Llama**: `["q_proj", "v_proj", ...]` (Default)
- **GPT-2**: `["c_attn"]`

## 4. Memory Strategies

LEMA supports two primary strategies in `LemaConfig`:

- **`MemoryStrategy.STREAMING` (Default)**: 
    - **Path**: Disk -> Pinned RAM -> VRAM.
    - **Pros**: Lowest VRAM usage. Can fit models much larger than System RAM if needed (via `mmap`).
    - **Cons**: Higher latency due to PCIe/Disk bottleneck.
- **`MemoryStrategy.RESIDENT`**:
    - **Path**: RAM -> VRAM.
    - **Pros**: Faster than streaming. Model weights stay in RAM.
    - **Cons**: Requires enough System RAM to hold the full model weights (~14GB for a 7B FP16 model).
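Choosing between the two comes down to whether the full weights fit in host RAM. The helper below is a hypothetical sketch of that check, not part of LEMA; the 1.2x headroom factor is an arbitrary assumption meant to leave room for the OS, Python, and pinned staging buffers, and the returned string would map onto `MemoryStrategy.RESIDENT` or `MemoryStrategy.STREAMING`.

```python
def suggest_strategy(num_params: int, bytes_per_param: int,
                     available_ram_bytes: int, headroom: float = 1.2) -> str:
    """Return "RESIDENT" if the full weights (plus headroom) fit in RAM, else "STREAMING"."""
    weight_bytes = num_params * bytes_per_param
    return "RESIDENT" if weight_bytes * headroom <= available_ram_bytes else "STREAMING"


LLAMA2_7B_PARAMS = 6_738_415_616  # assumed parameter count of Llama-2 7B

print(suggest_strategy(LLAMA2_7B_PARAMS, 2, available_ram_bytes=32 * 2**30))  # 32 GiB host
print(suggest_strategy(LLAMA2_7B_PARAMS, 2, available_ram_bytes=12 * 2**30))  # 12 GiB host
```

In practice the available-RAM figure would come from the host (for example `/proc/meminfo` on Linux) rather than a constant.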

## 5. Tips for Maximum Efficiency

1. **Gradient Checkpointing**: Always enable `gradient_checkpointing=True` for 7B+ models. It significantly reduces activation memory by discarding intermediate activations during the forward pass and recomputing them during the backward pass.
2. **Pinned Memory**: LEMA automatically uses pinned memory for transfers. Ensure your system has sufficient RAM available for the staging buffers (~2x the size of the largest layer).
3. **NVMe Storage**: When using `STREAMING` mode, placing your `.safetensors` file on an NVMe SSD will greatly reduce the "Streaming Overhead".
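To size the staging buffers from tip 2, the sketch below computes the FP16 weight size of one Llama-2 7B decoder layer and doubles it. The shape values (hidden 4096, MLP 11008, full multi-head attention) are assumptions from the public Llama-2 7B config, and the 2x factor is the rule of thumb stated above.

```python
def decoder_layer_bytes(hidden: int, intermediate: int, bytes_per_param: int = 2) -> int:
    """Weight bytes of one Llama-style decoder layer (full multi-head attention)."""
    params = 4 * hidden * hidden + 3 * hidden * intermediate + 2 * hidden  # attn + MLP + norms
    return params * bytes_per_param


layer_bytes = decoder_layer_bytes(hidden=4096, intermediate=11008)  # Llama-2 7B, FP16
staging_bytes = 2 * layer_bytes  # "~2x the size of the largest layer"
print(f"one layer: {layer_bytes / 2**20:.0f} MiB, staging: {staging_bytes / 2**30:.2f} GiB")
```

For Llama-2 7B the embedding table (32,000 x 4,096 FP16 values, 250 MiB) is smaller than a decoder layer (about 386 MiB), so the decoder layer is the one that sets the buffer size, giving roughly 0.75 GiB of pinned RAM.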


### docs/API_REFERENCE.md

# LEMA API Reference

This document provides detailed information about the LEMA (Layer-wise Efficient Memory Abstraction) library API.

## Core API

### `LemaModel`
The primary entry point for the framework. It orchestrates memory management, adapters, and LoRA parameters.

#### `__init__(config: LemaConfig)`
Initializes the model using a `LemaConfig` object.

#### `get_trainer(optimizer: torch.optim.Optimizer)`
Returns a `LemaTrainer` instance pre-configured with this model's components and memory manager.

#### `initialize_lora()`
Pre-initializes all LoRA adapters. Must be called before `get_trainable_parameters()` for new models.

#### `get_trainable_parameters()`
Returns a list of all trainable parameters (LoRA weights) managed by the model.

#### `save_pretrained(save_directory: str)`
Saves the configuration and LoRA adapter weights.

#### `from_pretrained(path: str, **kwargs)` (Class Method)
Loads a LEMA model from a directory containing `lema_config.json` and `adapter_model.bin`.

---

### `LemaConfig`
Configuration dataclass for LEMA.

| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `model_name_or_path` | `str` | Required | HuggingFace ID or path to model directory. |
| `model_type` | `str` | `None` | `llama` or `gpt2`. Auto-detected if None. |
| `gbi_path` | `str` | `None` | Path to the `.safetensors` file. |
| `device` | `str` | `"cuda"` | Execution device. |
| `strategy` | `MemoryStrategy` | `STREAMING` | `STREAMING` or `RESIDENT`. |
| `save_steps` | `int` | `500` | Steps between automatic checkpoints. |
| `output_dir` | `str` | `"output"` | Directory for automatic checkpoints. |
| `lora_rank` | `int` | `16` | LoRA rank (r). |
| `lora_alpha` | `int` | `32` | LoRA alpha. |
| `learning_rate` | `float` | `1e-4` | Learning rate. |
| `gradient_checkpointing`| `bool` | `False` | Enable to save activation VRAM. |

---

### `LemaTrainer`
Orchestrates the training loop with layer-swapping logic.

#### `__init__(config, model_adapter, gbi, lora_manager=None, optimizer=None, memory_manager=None)`
Low-level constructor. Preferred usage is via `LemaModel.get_trainer()`.

#### `train_step(inputs: torch.Tensor, labels: torch.Tensor = None)`
Executes one forward and backward pass. Tracks `global_step` and triggers auto-checkpointing.
- Returns: `(logits, loss_value)`.

#### `save_checkpoint(save_directory: str)`
Saves the model state, configuration, and optimizer state.

### docs/BENCHMARK_RESULTS.md

# LEMA Benchmark Results (v0.7 - Release Candidate)

Benchmarks were performed on **Kaggle (Tesla P100 GPU, 16GB VRAM)**.
Comparisons were made between **Standard PEFT (LoRA)** and **LEMA (Streaming Strategy)**.

## 1. VRAM Usage (Memory Efficiency)

LEMA demonstrates significant VRAM savings, particularly for larger models where the overhead of optimization states and activations usually causes OOM errors.

![VRAM Benchmark](assets/vram_benchmark.png)

### Detailed Metrics

| Model | Parameters | Standard PEFT VRAM | LEMA VRAM | Savings |
| :--- | :--- | :--- | :--- | :--- |
| **GPT-2 (Small)** | 124M | 0.44 GB | 1.05 GB | N/A\* |
| **TinyLlama** | 1.1B | 2.67 GB | **2.12 GB** | **20.5%** |
| **SmolLM2** | 1.7B | 3.88 GB | **3.20 GB** | **17.6%** |
| **Llama-2** | 7B | **13.99 GB** (load only)\*\* | **5.90 GB** | **57.9%** |

*\*Note on GPT-2: For extremely small models, LEMA's fixed buffering overhead exceeds the model size. LEMA is optimized for large-scale models.*

*\*\*Note on Llama-2 7B: Standard PEFT can load the model (13.99 GB) but fails immediately with **Out-Of-Memory (OOM)** when attempting a training step, due to gradients and activations. LEMA trains comfortably with >10 GB of headroom.*

---

## 2. Training Speed (Throughput)

LEMA trades execution speed for memory capability. The architecture involves moving weights from system RAM to VRAM for every layer, introducing latency.

![Speed Benchmark](assets/speed_benchmark.png)

### Detailed Metrics

| Model | PEFT Speed (s/step) | LEMA Speed (s/step) | Overhead Factor | Status |
| :--- | :--- | :--- | :--- | :--- |
| **TinyLlama 1.1B** | 0.46 s | 1.45 s | **3.1x** | Usable |
| **Llama-2 7B** | **FAILED (OOM)** | **7.21 s** | **N/A** | **Enabling** |

**Analysis**:
- For models that fit in VRAM (1.1B), LEMA introduces a ~3x overhead due to Python-based stream orchestration and PCIe transfer latency.
- For models that **do not fit** (7B on 16GB cards), there is no baseline speed to compare against: standard PEFT fails with OOM, while LEMA completes a training step in 7.21 s.
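Because the table reports seconds per step, converting to tokens per second can make these numbers easier to compare with other trainers. The sketch assumes batch size 1 (as in `examples/kaggle/speed_benchmark.py`) and the 512-token sequence length listed under Configuration Used.

```python
SEQ_LEN = 512    # sequence length from the benchmark configuration
BATCH_SIZE = 1   # batch size used by examples/kaggle/speed_benchmark.py


def tokens_per_second(seconds_per_step: float) -> float:
    """Training throughput implied by a measured step time."""
    return BATCH_SIZE * SEQ_LEN / seconds_per_step


for name, s in [("TinyLlama PEFT", 0.46), ("TinyLlama LEMA", 1.45), ("Llama-2 7B LEMA", 7.21)]:
    print(f"{name}: {tokens_per_second(s):.0f} tokens/s")
```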

## 3. Configuration Used

- **LoRA Targets**: All linear layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj` for Llama).
- **Sequence Length**: 512.
- **Precision**: FP16.
- **Gradient Checkpointing**: Enabled for 7B, Disabled for smaller models.


In [None]:
import sys, os
sys.path.append(os.path.abspath('src'))
!pip install -q safetensors accelerate peft transformers
print('LEMA Environment Fully Loaded.')

# RUN SPEED BENCHMARK

In [None]:
%run examples/kaggle/speed_benchmark.py

# RUN SMOL-LM FINE-TUNING

In [None]:
%run tasks/fine_tune_smollm.py

# RUN LLAMA-7B FINE-TUNING

In [None]:
%run tasks/fine_tune_llama_7b.py