In [1]:
import sys
from pathlib import Path
import logging

# --- Setup Logging and Paths ---
# Send INFO-and-above records to notebook.log; filemode 'w' truncates the file
# when the handler is created.
logging.basicConfig(filename='notebook.log', filemode='w', level=logging.INFO)
logger = logging.getLogger(__name__)

# Put the project's `src` directory at the front of sys.path (once) so that
# project-local packages can be imported from the notebook.
project_root = Path('.').resolve()
src_path = project_root.joinpath('layered-context-graph', 'src')
src_entry = str(src_path)
if src_entry not in sys.path:
    sys.path.insert(0, src_entry)
logger.info(f"Project root set to: {project_root}")

In [2]:
from models.qwq_model import QwQModel
import torch

In [4]:
# --- Cell 3: Model Loading ---
# Instantiate the QwQ wrapper once so later cells can reuse it.
# `preloaded_qwq_model` is reset to None first, so it stays None whenever
# construction fails and downstream cells can detect that.
MODEL_PATH = './QwQ_LCoT_7B_Instruct'
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

preloaded_qwq_model = None
try:
    preloaded_qwq_model = QwQModel(MODEL_PATH, device)
    success_msg = "QwQModel pre-loaded successfully."
    logger.info(success_msg)
    print(success_msg)
except Exception as e:
    failure_msg = f"Error pre-loading QwQModel: {e}"
    # logger.exception logs at ERROR level and attaches the active traceback.
    logger.exception(failure_msg)
    print(failure_msg)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

: 

In [1]:
from accelerate import init_empty_weights
import torch

# Simplified sketch (not the actual accelerate source) of what init_empty_weights() does internally:
# def init_empty_weights():
#     old_register_parameter = torch.nn.Module.register_parameter
    
#     def register_empty_parameter(module, name, param):
#         # Create parameter on meta device instead of default device
#         if param is not None:
#             param = param.to(torch.device("meta"))
#         old_register_parameter(module, name, param)
    
#     # Temporarily replace the registration function
#     torch.nn.Module.register_parameter = register_empty_parameter
    
#     # Also set default device to meta
#     with torch.device("meta"):
#         yield

from pathlib import Path
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors import safe_open
import json

class ShardedSafeTensorsLoader:
    """Load a sharded SafeTensors checkpoint into an already-built model.

    The checkpoint directory must contain ``model.safetensors.index.json``;
    its ``weight_map`` entry maps every tensor name to the shard file that
    stores it.
    """

    def __init__(self, model_dir):
        """Read the shard index from ``model_dir``.

        Args:
            model_dir: Directory holding the shard files and
                ``model.safetensors.index.json``.

        Raises:
            FileNotFoundError: If the index file does not exist.
            json.JSONDecodeError: If the index file is not valid JSON.
        """
        self.model_dir = Path(model_dir)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.model_dir / "model.safetensors.index.json", encoding="utf-8") as f:
            self.index = json.load(f)

    def _group_by_shard(self):
        """Return ``{shard_file: [tensor_name, ...]}`` built from the index."""
        shards = {}
        for param_name, shard_file in self.index["weight_map"].items():
            shards.setdefault(shard_file, []).append(param_name)
        return shards

    def load_to_gpu(self, model, device="cuda:0"):
        """Load every tensor listed in the index into ``model``.

        Shards are opened one at a time with ``safe_open`` targeting
        ``device``, and each tensor is installed via ``_assign_param``.

        Args:
            model: Module whose attributes are addressed by the dotted tensor
                names of the checkpoint.
            device: Target device (string or ``torch.device``). Defaults to
                ``"cuda:0"``, the device that was previously hard-coded.

        Returns:
            The same ``model`` object, populated in place.

        Warns:
            UserWarning: If the index names tensors that are absent from
                their shard, or if parameters are still on the meta device
                after loading (e.g. weights the checkpoint does not store).
        """
        import warnings

        target = torch.device(device)
        absent = []

        for shard_file, param_names in self._group_by_shard().items():
            shard_path = self.model_dir / shard_file

            # safe_open materialises tensors directly on the requested device.
            with safe_open(shard_path, framework="pt", device=str(target)) as f:
                # Build the key set once per shard instead of once per tensor.
                available = set(f.keys())
                for param_name in param_names:
                    if param_name not in available:
                        absent.append(param_name)
                        continue
                    self._assign_param(model, param_name, f.get_tensor(param_name))

            # Release cached CUDA blocks between shards (only relevant on CUDA).
            if target.type == "cuda":
                torch.cuda.empty_cache()

        if absent:
            warnings.warn(
                f"{len(absent)} tensor(s) listed in the index were not found in "
                f"their shard and were skipped, e.g. {absent[:5]}"
            )

        unloaded = [name for name, p in model.named_parameters() if p.is_meta]
        if unloaded:
            warnings.warn(
                f"{len(unloaded)} parameter(s) are still on the meta device "
                f"after loading, e.g. {unloaded[:5]}"
            )
        return model

    def _assign_param(self, model, param_name, tensor):
        """Install ``tensor`` on ``model`` at the dotted path ``param_name``.

        A name that is currently a registered buffer stays a buffer. Anything
        else becomes an ``nn.Parameter``; ``requires_grad`` is carried over
        from the parameter being replaced and is forced to ``False`` for
        dtypes that cannot hold gradients (integer/bool tensors).

        Args:
            model: Root module.
            param_name: Dotted attribute path, e.g. ``"layers.0.weight"``.
            tensor: Tensor to install.

        Raises:
            AttributeError: If an intermediate attribute in the path is missing.
        """
        *parents, leaf = param_name.split('.')
        owner = model
        for key in parents:
            owner = getattr(owner, key)

        if leaf in dict(owner.named_buffers(recurse=False)):
            # Assigning a plain tensor to a registered buffer name updates the
            # buffer instead of turning it into a trainable parameter.
            setattr(owner, leaf, tensor)
            return

        existing = getattr(owner, leaf, None)
        can_have_grad = tensor.is_floating_point() or tensor.is_complex()
        requires_grad = can_have_grad and getattr(existing, "requires_grad", True)
        setattr(owner, leaf, torch.nn.Parameter(tensor, requires_grad=requires_grad))
            
def load_sharded_model(model_dir: str):
    """Build a causal-LM from ``model_dir`` and load its sharded weights.

    The model skeleton is created under ``init_empty_weights()`` so its
    parameters start on the meta device; the checkpoint tensors are then
    loaded by ``ShardedSafeTensorsLoader.load_to_gpu``.

    Args:
        model_dir: Directory containing the model config, the SafeTensors
            shards and ``model.safetensors.index.json``.

    Returns:
        The populated model.

    Note:
        The model is returned in whatever train/eval mode ``from_config``
        leaves it in; call ``model.eval()`` before inference if needed.
    """
    # Step 1: Load the config
    config = AutoConfig.from_pretrained(model_dir)

    # Step 2: Create the model skeleton with parameters on the meta device
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # Verify it's on meta device
    print(f"Model device: {next(model.parameters()).device}")  # Should print "meta"

    # Step 3: Load the checkpoint weights
    loader = ShardedSafeTensorsLoader(model_dir)
    model = loader.load_to_gpu(model)

    # Loading installs a fresh Parameter object for every checkpoint tensor,
    # which breaks any weight tying set up at construction time, and tied
    # weights that the checkpoint does not store would stay on meta.
    # Re-tie so e.g. the output head shares the loaded embedding again
    # (no-op for configs without tied weights).
    model.tie_weights()

    # Report where the first parameter now lives
    print(f"Model device after loading: {next(model.parameters()).device}")  # Should print "cuda:0"

    return model

# Usage example:
# The checkpoint directory can be overridden with the QWQ_MODEL_DIR environment
# variable; the default is the path this notebook was originally run with.
import os

model = load_sharded_model(
    os.environ.get("QWQ_MODEL_DIR", "/workspaces/layer_context_seg/QwQ_LCoT_7B_Instruct")
)

Model device: meta
Model device after loading: cuda:0


In [15]:
# --- Cell 4: Test Text Generation ---
# Smoke-test generation with the model pre-loaded by the model-loading cell.
# Compare against None explicitly: the truth value of a model object is not a
# reliable "was it loaded" signal.
if preloaded_qwq_model is not None:
    try:
        prompt = "Once upon a time,"
        generated_text = preloaded_qwq_model.generate(prompt)
        logger.info(f"Generated text: {generated_text}")
        print(f"Generated text: {generated_text}")
    except Exception as e:
        logger.error(f"Error during text generation test: {e}", exc_info=True)
        print(f"Error during text generation test: {e}")
else:
    # Make the skip visible instead of finishing the cell with no output.
    logger.warning("Skipping text generation test: QwQModel is not loaded.")
    print("Skipping text generation test: QwQModel is not loaded.")