# LEMA Inference & Merging Demonstration
This notebook loads the fine-tuned LEMA model (adapter weights) and runs inference to verify the custom chat format.
It also demonstrates how to merge the adapter into the base model if needed.


In [None]:
!pip install -q transformers safetensors accelerate
# Clone LEMA repository
!git clone https://github.com/Pomilon/LEMA.git
!pip install -q -e LEMA/


In [None]:
import os

# Directories used by the generated inference package and the checkpoint drop zone.
for directory in ('inference/framework', 'inference/engines', 'inference', 'checkpoints/final'):
    os.makedirs(directory, exist_ok=True)

# Create (or truncate to empty) the __init__.py of each package directory.
for package_dir in ('inference', 'inference/framework', 'inference/engines'):
    with open(f'{package_dir}/__init__.py', 'w'):
        pass


In [None]:
# Ensure the target directory exists before the %%writefile cell below writes model_handler.py.
import os; os.makedirs('inference/framework', exist_ok=True)


In [None]:
%%writefile inference/framework/model_handler.py
"""
Handle loading and interaction with LEMA fine-tuned models.
"""

from lema import LemaModel, MemoryStrategy
from transformers import AutoTokenizer
import torch
import threading
from typing import Optional, List
import time

class LemaModelHandler:
    """Manages LEMA model loading and inference.

    Holds a ``LemaModel`` plus its tokenizer and generates text one token at a
    time; every step pushes the full token sequence through all layers listed
    by the adapter (see ``_forward_pass``).
    """
    
    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        """
        Load a fine-tuned LEMA model.
        
        Args:
            checkpoint_path: Path to LEMA checkpoint
            device: Device to run on (cuda/cpu)
        """
        self.checkpoint_path = checkpoint_path
        self.device = device
        
        print(f"Loading LEMA model from {checkpoint_path}...")
        # Load model using LEMA's API
        self.model = LemaModel.from_pretrained(checkpoint_path, device=device)
        # Ensure model components are on correct device
        self.model.to(device)
        # NOTE(review): presumably attaches the LoRA weights stored in the
        # checkpoint; confirm against the LEMA API.
        self.model.initialize_lora()
        
        # Access internal components for manual forward pass
        self.memory = self.model.memory
        self.adapter = self.model.adapter
        self.layers = self.adapter.get_layer_metadata()
        self.lora_manager = self.model.lora_manager
        
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model.config.model_name_or_path
        )
        # Fall back to EOS as the padding token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        # Debug: Check LoRA weights (a summed norm of exactly 0 means all-zero adapters)
        lora_params = self.model.get_trainable_parameters()
        if lora_params:
            total_norm = sum(p.norm().item() for p in lora_params)
            print(f"Debug: Total LoRA weight norm: {total_norm:.4f}")
            if total_norm == 0:
                print("⚠️ Warning: LoRA weights are all zeros!")
        else:
            print("⚠️ Warning: No LoRA parameters found in model!")

        print("Model loaded successfully.")
    
    def generate(
        self, 
        prompt: str, 
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_p: float = 0.9,
        do_sample: bool = True,
        stop_sequences: List[str] = ["[/LEMA_REPLY]"]
    ) -> str:
        """
        Generate text from prompt using LEMA's streaming architecture.

        Args:
            prompt: Text to continue; tokenized with truncation to 512 tokens.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Logit divisor; values <= 0 force greedy decoding.
            top_p: Nucleus-sampling threshold; values >= 1.0 disable filtering.
            do_sample: Sample from the distribution instead of taking the argmax.
            stop_sequences: Strings that end generation once they appear in the
                generated text. The default list is only read, never mutated.
                Pass an empty list (or None) to disable the check.

        Returns:
            The decoded prompt plus generated tokens, special tokens included.
        """
        # Tokenize
        inputs = self.tokenizer(
            prompt, 
            return_tensors="pt", 
            truncation=True,
            max_length=512 # limit input context
        )
        input_ids = inputs["input_ids"].to(self.device)
        prompt_length = input_ids.shape[1]
        
        # Generate loop: the whole sequence is re-run through the model each step.
        current_input_ids = input_ids
        
        print("Generating...", end="", flush=True)
        
        for _ in range(max_new_tokens):
            with torch.no_grad():
                logits = self._forward_pass(current_input_ids)
            
            # Pick the next token from the logits of the last position.
            next_token = self._select_next_token(
                logits[0, -1, :], temperature, top_p, do_sample
            )
            
            # Append
            current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=1)
            
            # Check EOS
            if next_token.item() == self.tokenizer.eos_token_id:
                break
                
            print(".", end="", flush=True)

            # Check stop sequences against GENERATED tokens only (last 20 of
            # them). Looking at the tail of the full sequence would let text
            # from the prompt itself trigger an immediate stop.
            if stop_sequences:
                generated_tail = current_input_ids[0, prompt_length:][-20:]
                decoded_so_far = self.tokenizer.decode(generated_tail, skip_special_tokens=False)
                if any(stop in decoded_so_far for stop in stop_sequences):
                    break
        
        print(" Done!")
        
        # Decode
        output_text = self.tokenizer.decode(current_input_ids[0], skip_special_tokens=False)
        return output_text

    @staticmethod
    def _select_next_token(
        next_token_logits: torch.Tensor,
        temperature: float,
        top_p: float,
        do_sample: bool,
    ) -> torch.Tensor:
        """Choose the next token id from 1-D logits; returns a tensor of shape (1,)."""
        # Apply temperature (division yields a new tensor, so the in-place
        # masking below never touches the caller's logits).
        if temperature > 0:
            next_token_logits = next_token_logits / temperature
        
        if do_sample and temperature > 0:
            # Top-p (nucleus) sampling
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                
                # Shift the mask right by one so the first token that crosses
                # the threshold is kept, and always keep the top token.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = -float('inf')
            
            probs = torch.softmax(next_token_logits, dim=-1)
            return torch.multinomial(probs, num_samples=1)
        
        # Greedy decoding
        return torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
    
    def _forward_pass(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Custom forward pass logic for LEMA models.
        Replicates LemaTrainer.train_step logic but inference only.

        Walks ``self.layers`` in order, alternating between buffer slots 0 and 1
        so the next layer can be transferred while the current one runs.

        Returns:
            The output of the last layer (indexed as logits by ``generate``).
        """
        is_streaming = (self.model.config.strategy == MemoryStrategy.STREAMING)
        # The raw input_ids are fed to the first layer as-is.
        # NOTE(review): this assumes the adapter's first layer performs the
        # embedding lookup, mirroring LemaTrainer; confirm against LEMA.
        hidden_states = input_ids
        
        # Prefetch first layer
        if is_streaming:
            self.memory.prefetch_to_ram(self.layers[0]['id'], 0)
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0, ram_slot=0)
            if len(self.layers) > 1:
                self.memory.prefetch_to_ram(self.layers[1]['id'], 1)
        else:
            self.memory.async_transfer_to_vram(self.layers[0]['id'], 0)
            
        for i, layer_meta in enumerate(self.layers):
            slot = i % 2
            next_slot = (i + 1) % 2
            
            flat_vram = self.memory.get_vram_flat_buffer(slot)
            
            # Prefetch next layers
            disk_thread = None
            if i + 1 < len(self.layers):
                if is_streaming:
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot, ram_slot=next_slot)
                    if i + 2 < len(self.layers):
                        # Load layer i+2 into RAM on a background thread while
                        # the current layer computes.
                        disk_thread = threading.Thread(target=self.memory.prefetch_to_ram, args=(self.layers[i+2]['id'], slot))
                        disk_thread.start()
                else:
                    self.memory.async_transfer_to_vram(self.layers[i+1]['id'], next_slot)
            
            # Construct layer
            layer_module = self.adapter.construct_layer_module(layer_meta['id'], flat_vram, self.lora_manager)
            
            # Forward
            # Note: We disable gradient checkpointing for inference
            hidden_states = self.adapter.forward_layer(layer_module, hidden_states, gradient_checkpointing=False)
            
            # Wait for the background RAM prefetch before the next iteration uses it.
            if disk_thread: disk_thread.join()
            
            # Release layer through the adapter's cleanup hook when it provides one.
            if hasattr(self.adapter, "release_layer_module"):
                self.adapter.release_layer_module(layer_module)
            del layer_module
            
        return hidden_states



In [None]:
# Ensure the target directory exists before the %%writefile cell below writes chat_parser.py.
import os; os.makedirs('inference/framework', exist_ok=True)


In [None]:
%%writefile inference/framework/chat_parser.py
"""
Parse and validate LEMA custom chat format.
"""

import re
from typing import Optional, Dict
from dataclasses import dataclass

@dataclass
class LemaResponse:
    """Parsed LEMA response.

    Text fields are empty strings when the corresponding part could not be
    extracted from the generated text.
    """
    answer: str        # text captured after "Answer:"
    explanation: str   # text captured after "Explanation:"
    confidence: str    # "High", "Medium" or "Low" when present
    raw_text: str      # the full, unparsed text that was handed to the parser
    is_valid: bool     # True only when answer, explanation and confidence are all non-empty

class ChatParser:
    """Parse LEMA custom chat format."""
    
    # Regex patterns for extracting fields
    LEMA_REPLY_PATTERN = r'\\[LEMA_REPLY\\](.*?)\\[/LEMA_REPLY\\]'
    ANSWER_PATTERN = r'Answer:\s*(.+?)(?=\n|Explanation:|Confidence:|$)'
    EXPLANATION_PATTERN = r'Explanation:\s*(.+?)(?=\n|Confidence:|$)'
    CONFIDENCE_PATTERN = r'Confidence:\s*(High|Medium|Low)'
    
    @classmethod
    def parse_response(cls, text: str) -> LemaResponse:
        """
        Parse a LEMA-formatted response.
        
        Args:
            text: Generated text that should contain [LEMA_REPLY] block
        
        Returns:
            LemaResponse with parsed fields and validation status
        """
        # Extract LEMA_REPLY block
        reply_match = re.search(cls.LEMA_REPLY_PATTERN, text, re.DOTALL)
        
        if not reply_match:
            return LemaResponse(
                answer="",
                explanation="",
                confidence="",
                raw_text=text,
                is_valid=False
            )
        
        reply_content = reply_match.group(1)
        
        # Extract fields
        answer_match = re.search(cls.ANSWER_PATTERN, reply_content, re.DOTALL)
        explanation_match = re.search(cls.EXPLANATION_PATTERN, reply_content, re.DOTALL)
        confidence_match = re.search(cls.CONFIDENCE_PATTERN, reply_content)
        
        answer = answer_match.group(1).strip() if answer_match else ""
        explanation = explanation_match.group(1).strip() if explanation_match else ""
        confidence = confidence_match.group(1).strip() if confidence_match else ""
        
        is_valid = bool(answer and explanation and confidence)
        
        return LemaResponse(
            answer=answer,
            explanation=explanation,
            confidence=confidence,
            raw_text=text,
            is_valid=is_valid
        )
    
    @classmethod
    def format_prompt(cls, user_message: str) -> str:
        """
        Format a user message into the LEMA chat template.
        
        Args:
            user_message: User's question/input
        
        Returns:
            Properly formatted prompt for the model
        """
        return f"""<|system|>
You are a precise assistant trained using LEMA.

<|user|>
{user_message}

<|assistant|>
[LEMA_REPLY]
Answer:"""



In [None]:
# Ensure the target directory exists before the %%writefile cell below writes conversation.py.
import os; os.makedirs('inference/framework', exist_ok=True)


In [None]:
%%writefile inference/framework/conversation.py
"""
Manage conversation state and history.
"""

from typing import List, Dict
from dataclasses import dataclass, field

@dataclass
class Message:
    """Single message in conversation."""
    role: str  # 'user' or 'assistant'
    content: str  # message text
    metadata: Dict = field(default_factory=dict)  # extra key/value data; a fresh dict per instance

class ConversationManager:
    """Manage conversation history and context."""
    
    def __init__(self, max_history: int = 10):
        """
        Initialize conversation manager.
        
        Args:
            max_history: Maximum number of messages (user and assistant
                messages each count as one) to keep. Values <= 0 keep nothing.
        """
        self.max_history = max_history
        self.messages: List[Message] = []
    
    def add_user_message(self, content: str):
        """Add user message to history."""
        self.messages.append(Message(role='user', content=content))
        self._trim_history()
    
    def add_assistant_message(self, content: str, **metadata):
        """Add assistant message to history; keyword arguments are stored as metadata."""
        self.messages.append(Message(
            role='assistant', 
            content=content,
            metadata=metadata
        ))
        self._trim_history()
    
    def get_context(self, include_current: bool = True) -> str:
        """
        Build context string from conversation history.
        
        Args:
            include_current: Whether to include the most recent exchange.
                Currently unused; the parameter is accepted for interface
                compatibility only.
        
        Returns:
            The content of the most recent user message, or "" when the
            history holds no user message.
        """
        # For now, we just use the immediate question
        # You could extend this to include conversation history
        if not self.messages:
            return ""
        
        # Get last user message
        for msg in reversed(self.messages):
            if msg.role == 'user':
                return msg.content
        
        return ""
    
    def clear(self):
        """Clear conversation history."""
        self.messages.clear()
    
    def _trim_history(self):
        """Keep only the max_history most recent messages."""
        # Slice from the front by the overflow count. A "[-max_history:]" slice
        # would silently keep everything when max_history is 0, since -0 == 0.
        keep = max(self.max_history, 0)
        overflow = len(self.messages) - keep
        if overflow > 0:
            self.messages = self.messages[overflow:]



In [None]:
# Ensure the target directory exists before the %%writefile cell below writes cli_engine.py.
import os; os.makedirs('inference/engines', exist_ok=True)


In [None]:
%%writefile inference/engines/cli_engine.py
"""
Command-line interface for LEMA chatbot.
"""

import sys
import os
from pathlib import Path

# Add framework to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from inference.framework.model_handler import LemaModelHandler
from inference.framework.chat_parser import ChatParser
from inference.framework.conversation import ConversationManager

class CLIChatEngine:
    """Interactive CLI chat interface."""
    
    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        """
        Initialize CLI chat engine.
        
        Args:
            checkpoint_path: Path to LEMA checkpoint
            device: Device to run on
        """
        print("Loading model...")
        self.model_handler = LemaModelHandler(checkpoint_path, device)
        self.conversation = ConversationManager()
        self.parser = ChatParser()
        print("✅ Model loaded!\n")
    
    def run(self):
        """Run interactive chat loop.

        Reads lines from stdin until the user types 'quit'/'exit', presses
        Ctrl-C, or stdin reaches end-of-file. Any other error raised while
        handling a turn is reported and the loop continues.
        """
        print("=" * 60)
        print("LEMA Chatbot - Custom Format Demonstration")
        print("=" * 60)
        print("\nCommands:")
        print("  'quit' or 'exit' - Exit the chat")
        print("  'clear' - Clear conversation history")
        print("  'debug' - Toggle debug mode")
        print("\n" + "=" * 60 + "\n")
        
        debug_mode = False
        
        while True:
            try:
                user_input = input("You: ").strip()
                
                if not user_input:
                    continue
                
                if user_input.lower() in ['quit', 'exit']:
                    print("\nGoodbye!")
                    break
                
                if user_input.lower() == 'clear':
                    self.conversation.clear()
                    print("\n🔄 Conversation cleared\n")
                    continue
                
                if user_input.lower() == 'debug':
                    debug_mode = not debug_mode
                    print(f"\n🐛 Debug mode: {'ON' if debug_mode else 'OFF'}\n")
                    continue
                
                # Add to conversation
                self.conversation.add_user_message(user_input)
                
                # Format prompt (only the current message is sent to the model)
                prompt = self.parser.format_prompt(user_input)
                
                if debug_mode:
                    print(f"\n[DEBUG] Prompt:\n{prompt}\n")
                
                # Generate response
                response_text = self.model_handler.generate(prompt)
                
                if debug_mode:
                    print(f"\n[DEBUG] Raw response:\n{response_text}\n")
                
                # Parse response
                parsed = self.parser.parse_response(response_text)
                
                if parsed.is_valid:
                    # Display parsed response
                    print(f"\nAssistant: {parsed.answer}")
                    print(f"💡 {parsed.explanation}")
                    print(f"📊 Confidence: {parsed.confidence}\n")
                    
                    # Add to conversation
                    self.conversation.add_assistant_message(
                        parsed.answer,
                        explanation=parsed.explanation,
                        confidence=parsed.confidence
                    )
                else:
                    # Model didn't follow format
                    print(f"\n⚠️  Model response didn't follow LEMA format:")
                    print(f"{response_text}\n")
                    print("This might indicate the model needs more training.\n")
            
            except (KeyboardInterrupt, EOFError):
                # EOFError must end the loop too: once stdin is closed every
                # further input() call raises again, so treating it as a
                # recoverable error would spin forever.
                print("\n\nGoodbye!")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}\n")
                if debug_mode:
                    import traceback
                    traceback.print_exc()

def main():
    """Parse the command line and start the interactive chat session."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="LEMA CLI Chat Interface")
    # Positional: where the checkpoint lives.
    arg_parser.add_argument("checkpoint", type=str, help="Path to LEMA checkpoint directory")
    # Optional: target device, cuda unless overridden.
    arg_parser.add_argument("--device", type=str, default="cuda", help="Device to run on (cuda/cpu)")
    cli_args = arg_parser.parse_args()

    CLIChatEngine(cli_args.checkpoint, cli_args.device).run()

if __name__ == "__main__":
    main()



In [None]:
# Ensure the target directory exists before the %%writefile cell below writes merge_adapter.py.
import os; os.makedirs('tools', exist_ok=True)


In [None]:
%%writefile tools/merge_adapter.py
import os
import torch
import argparse
import gc
import psutil
import json
import shutil
from safetensors.torch import save_file
from transformers import AutoConfig, AutoTokenizer
from huggingface_hub import HfApi

# Adjust path to import lema
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from lema import LemaConfig, LemaModel

def get_ram_usage():
    """Return the system-wide used RAM reported by psutil, in units of 1e9 bytes."""
    used_bytes = psutil.virtual_memory().used
    return used_bytes / 1e9

def get_disk_usage(path="."):
    """Return the FREE space of the filesystem containing ``path``, in units of 1e9 bytes."""
    return shutil.disk_usage(path).free / 1e9

def merge_adapter(checkpoint_dir: str, output_dir: str, base_model_path: str, repo_id: str = None, token: str = None):
    """
    Merges LEMA LoRA adapter into base model.
    If repo_id is provided, performs STREAMING UPLOAD:
    - Saves a shard
    - Uploads to HF
    - Deletes local shard
    This bypasses local disk limits.

    Args:
        checkpoint_dir: Directory holding the LEMA config and adapter weights.
        output_dir: Directory receiving the shards, index and config files.
        base_model_path: Fallback base-model file, used when the path stored
            in the LEMA config (``config.gbi_path``) does not exist.
        repo_id: Optional Hugging Face repo to upload every file to.
        token: Hugging Face token; required when ``repo_id`` is given.

    Returns:
        None. Problems (missing token, unreadable config, missing base model)
        are reported with a printed message followed by an early return.
    """
    print(f"[{get_ram_usage():.2f}GB RAM | {get_disk_usage():.2f}GB Disk] Loading LEMA config...")
    
    api = None
    if repo_id:
        if not token:
            print("❌ Repo ID provided but no token found.")
            return
        api = HfApi(token=token)
        print(f"🚀 Streaming Upload Enabled: Target -> {repo_id}")
    
    try:
        config = LemaConfig.from_pretrained(checkpoint_dir)
    except Exception as e:
        print(f"Error loading config: {e}")
        return

    # Fall back to the explicitly supplied base model when the path recorded in
    # the config is missing. The fallback must be tested inside the
    # "gbi_path missing" branch only; also requiring base_model_path to be
    # missing would make the redirect unreachable.
    if not os.path.exists(config.gbi_path):
        if os.path.exists(base_model_path):
            config.gbi_path = base_model_path
        else:
            print("Base model not found.")
            return

    print(f"[{get_ram_usage():.2f}GB] Initializing LEMA model...")
    config.device = "cpu"
    model = LemaModel(config)
    model.adapter._max_pool_size = 1
    
    print(f"[{get_ram_usage():.2f}GB] Loading adapter weights...")
    model.lora_manager.load_pretrained(checkpoint_dir)
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Metadata for index.json: parameter name -> shard filename
    weight_map = {}
    
    # Define Shards (4 layers per shard)
    layers = model.adapter.get_layer_metadata()
    block_layers = [l for l in layers if l['type'] == 'block']
    shard_size = 4
    
    # The "model-00001-of-000NN" naming scheme needs the total up front:
    # one shard for the embeddings, ceil(blocks / shard_size) shards for the
    # transformer blocks, and one final shard for norm + head.
    total_shards = 1 + (len(block_layers) + shard_size - 1) // shard_size + 1
    
    current_shard_idx = 1
    current_shard_weights = {}
    
    def save_and_upload_shard():
        """Write the buffered tensors as the next shard; upload and delete it when streaming."""
        nonlocal current_shard_idx, current_shard_weights
        if not current_shard_weights: return
        
        # Proper naming from the start
        filename = f"model-{current_shard_idx:05d}-of-{total_shards:05d}.safetensors"
        filepath = os.path.join(output_dir, filename)
        
        print(f"[{get_ram_usage():.2f}GB RAM | {get_disk_usage():.2f}GB Disk] Saving {filename}...")
        save_file(current_shard_weights, filepath)
        
        # Update map
        for k in current_shard_weights.keys():
            weight_map[k] = filename
            
        # Clear memory
        current_shard_weights.clear()
        current_shard_idx += 1
        gc.collect()
        
        # UPLOAD AND DELETE
        if api:
            print(f"⬆️ Uploading {filename}...")
            try:
                api.upload_file(
                    path_or_fileobj=filepath,
                    path_in_repo=filename,
                    repo_id=repo_id,
                    repo_type="model",
                    # The index was already advanced above, hence the -1.
                    commit_message=f"Upload shard {current_shard_idx-1}/{total_shards}"
                )
                print(f"✅ Uploaded. Deleting local file to save space.")
                os.remove(filepath)
            except Exception as e:
                print(f"❌ Upload failed for {filename}: {e}")
                # Don't delete if upload failed, so user can manually recover if space allows
        else:
            print(f"💾 Saved locally.")

    # --- 1. Embeddings ---
    print(f"Processing Embeddings (Shard {current_shard_idx})...")
    # NOTE(review): assumes layer id 0 is the embedding layer and that its
    # first parameter name is the embedding matrix; confirm against LEMA.
    emb_name = model.adapter.get_param_names_for_layer(0)[0]
    current_shard_weights["model.embed_tokens.weight"] = model.memory.gbi.handle.get_tensor(emb_name).clone().to(dtype=torch.float16)
    save_and_upload_shard() # Save embeddings as shard 1

    # --- 2. Transformer Layers ---
    for i, layer_meta in enumerate(block_layers):
        idx = layer_meta['block_index']
        
        if idx % 5 == 0:
            print(f"[{get_ram_usage():.2f}GB] Merging Layer {idx}...")
        
        # Load & Merge
        model.memory.prefetch_to_ram(layer_meta['id'], 0)
        flat_buffer = model.memory.ram_buffers[0]
        module = model.adapter.construct_layer_module(layer_meta['id'], flat_buffer, model.lora_manager)
        
        # Fold every LoRA pair into its wrapped base weight: W += (B @ A) * scaling
        for _, child in module.named_modules():
            if hasattr(child, "lora_A") and hasattr(child, "base_layer"):
                scale = child.scaling
                delta = (child.lora_B.data @ child.lora_A.data) * scale
                child.base_layer.weight.data += delta.to(child.base_layer.weight.dtype)
        
        # Extract
        prefix = f"model.layers.{idx}."
        names = model.adapter.get_param_names_for_layer(layer_meta['id'])
        module_params = dict(module.named_parameters())
        
        for full_name in names:
            clean_k = full_name[len(prefix):]
            # LoRA-wrapped modules expose the original weight under ".base_layer.weight".
            if clean_k not in module_params:
                clean_k = clean_k.replace(".weight", ".base_layer.weight")
            
            # Store in state dict (clone to detach from LEMA's reusable buffer)
            # CAST TO FP16 to save space (Standard Llama is FP16/BF16)
            current_shard_weights[full_name] = module_params[clean_k].data.clone().to(dtype=torch.float16).cpu()
            
        del module
        model.adapter.layer_pool.clear()
        gc.collect()
        
        # Check if shard is full
        if (i + 1) % shard_size == 0:
            save_and_upload_shard()

    # Save any remaining layers in buffer
    if current_shard_weights:
        save_and_upload_shard()

    # --- 3. Head / Norm ---
    print(f"Processing Head & Norm (Shard {current_shard_idx})...")
    last_layer_id = layers[-1]['id']
    model.memory.prefetch_to_ram(last_layer_id, 0)
    head_buffer = model.memory.ram_buffers[0]
    head_module = model.adapter.construct_layer_module(last_layer_id, head_buffer, model.lora_manager)
    
    current_shard_weights["model.norm.weight"] = head_module.norm.weight.data.clone().to(dtype=torch.float16)
    current_shard_weights["lm_head.weight"] = head_module.lm_head.weight.data.clone().to(dtype=torch.float16)
    
    del head_module
    gc.collect()
    
    # Save final shard
    save_and_upload_shard()
    
    # --- 4. Save Index & Configs ---
    print("Saving index and configs...")
    
    index_data = {"metadata": {}, "weight_map": weight_map}
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    
    with open(index_path, "w") as f:
        json.dump(index_data, f, indent=2)
    
    # Auxiliary files
    AutoConfig.from_pretrained(config.model_name_or_path).save_pretrained(output_dir)
    
    # The tokenizer is best-effort: a failure is reported but does not abort the merge.
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
        tokenizer.save_pretrained(output_dir)
    except Exception as e:
        print(f"Warning: {e}")

    if api:
        print("⬆️ Uploading index and configs...")
        files_to_upload = [
            "model.safetensors.index.json", "config.json", "generation_config.json",
            "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "tokenizer.model"
        ]
        
        # Only files that were actually written are uploaded.
        for fname in files_to_upload:
            fpath = os.path.join(output_dir, fname)
            if os.path.exists(fpath):
                try:
                    api.upload_file(
                        path_or_fileobj=fpath,
                        path_in_repo=fname,
                        repo_id=repo_id,
                        repo_type="model",
                        commit_message="Upload config/index"
                    )
                except Exception as e:
                    print(f"Failed to upload {fname}: {e}")
        print("✅ Streaming Upload Complete!")
    else:
        print("✅ Local Merge Complete.")

if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to merge_adapter.
    cli = argparse.ArgumentParser()
    for required_flag in ("--checkpoint", "--output"):
        cli.add_argument(required_flag, type=str, required=True)
    cli.add_argument("--base_model", type=str, default="llama2_7b.safetensors")
    cli.add_argument("--repo_id", type=str, default=None, help="HF Repo ID for streaming upload")
    cli.add_argument("--token", type=str, default=None, help="HF Token")

    opts = cli.parse_args()
    merge_adapter(opts.checkpoint, opts.output, opts.base_model, opts.repo_id, opts.token)



## Upload Your Weights
1. Create a Kaggle Dataset containing your `adapter_model.bin` and `lema_config.json`.
2. Add the dataset to this notebook.
3. Copy the files to `checkpoints/final/` below.

For this demo, we assume the dataset is mounted at `/kaggle/input/lema-finetuned-weights/`.


In [None]:
# Example copy command (adjust path to your dataset)
# Uncomment the next line once your dataset is attached to the notebook.
# !cp /kaggle/input/lema-finetuned-weights/* checkpoints/final/

# Verify files: list what is currently in the checkpoint directory
!ls -l checkpoints/final/


## Prepare Base Model
We need the monolithic `.safetensors` file for LEMA to function.


In [None]:
import sys
import os
# Put the cloned sources on sys.path before importing from `lema`
# (assumes the package lives under LEMA/src in the cloned repository).
sys.path.append(os.path.abspath('LEMA/src'))

from lema.utils.model_utils import prepare_monolithic_safetensors

# Hugging Face model id of the base model and the local file it is converted to.
MODEL_NAME = 'NousResearch/Llama-2-7b-hf'
MODEL_PATH = 'llama2_7b.safetensors'

# Skip the conversion when the file already exists, so re-running the cell is cheap.
if not os.path.exists(MODEL_PATH):
    print(f'Preparing {MODEL_PATH}...')
    prepare_monolithic_safetensors(MODEL_NAME, MODEL_PATH, device='auto')


## Run Inference


In [None]:

import sys
import torch
from inference.framework.model_handler import LemaModelHandler
from inference.framework.chat_parser import ChatParser

# Setup
checkpoint_path = "checkpoints/final"
handler = LemaModelHandler(checkpoint_path, device="cuda")
parser = ChatParser()

# Test Prompts
questions = [
    "What is LEMA?",
    "Who invented the telephone?",
    "What is photosynthesis?"
]

# Line breaks are written as "\n" escapes; a physical newline inside a
# single-quoted string literal is a SyntaxError.
print("-" * 60)
for q in questions:
    print(f"\nUser: {q}")
    prompt = parser.format_prompt(q)
    
    # Generate
    response_text = handler.generate(prompt, max_new_tokens=128)
    
    print(f"\nRaw Output:\n{response_text}")
    
    # Parse
    parsed = parser.parse_response(response_text)
    if parsed.is_valid:
        print("\n✅ Valid LEMA Format!")
        print(f"Answer: {parsed.answer}")
        print(f"Confidence: {parsed.confidence}")
    else:
        print("\n❌ Invalid Format")

print("-" * 60)



## Merge Adapter (Export)
Convert the LEMA adapter + Base Model into a standard HuggingFace model.


In [None]:

# Merge Adapter and Stream Upload to Hugging Face
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

# Authentication
try:
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("HF_TOKEN")
    login(token=hf_token)
    print("✅ Logged in via Kaggle Secrets")
except:
    hf_token = input("Enter HF Token (Write):")
    login(token=hf_token)

REPO_ID = "YOUR-HF-USERNAME/LEMA-llama-2-7b" # Change this to your repo

!python tools/merge_adapter.py \
    --checkpoint checkpoints/final \
    --output merged_model \
    --base_model llama2_7b.safetensors \
    --repo_id {REPO_ID} \
    --token {hf_token}

print("\n✅ Streaming Merge & Upload Complete!")

