In [None]:
# Environment setup via IPython shell escapes (`!`). The steps run in order:
# 1) Install the base stack, unpinned.
!pip install torch datasets numpy matplotlib seaborn
# 2) Remove the Hugging Face packages again (this includes the `datasets`
#    package installed in step 1) so they can be reinstalled at fixed versions.
!pip uninstall -y datasets fsspec huggingface_hub transformers tokenizers
# 3) Delete the on-disk datasets cache.
#    NOTE(review): presumably to avoid reusing cache files written by a
#    different `datasets` version — confirm.
!rm -rf ~/.cache/huggingface/datasets
# 4) Reinstall the Hugging Face packages at pinned versions.
!pip install datasets==2.14.7 fsspec==2023.10.0 huggingface_hub==0.17.3 transformers==4.35.2 tokenizers==0.15.0
# 5) Install PyTorch Geometric, unpinned.
!pip install torch-geometric

In [None]:
# Installs the optional PyTorch Geometric extension packages from the PyG wheel
# index. The index URL below is specific to torch 2.6.0 with CUDA 12.4, while
# torch itself is installed unpinned in the previous cell — if the installed
# torch version differs, the URL needs to be adjusted to match.
# The commented-out line can be used to check the installed version; it needs
# `import torch` to have been executed first.
#print(torch.__version__)
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

<h1>Imports and Helper Functions</h1>


In [None]:
# --- Cell 1: Imports and Helper Functions ---
import torch
import torch.nn.functional as F
import json
import os
import random
import numpy as np
from transformers import AutoTokenizer
from typing import Optional, Dict, Any

# Assuming the following files are in the same directory as the notebook:
# gnn_moe_config.py
# gnn_moe_architecture.py

from gnn_moe_config import GNNMoEConfig
from gnn_moe_architecture import GNNMoEModel

def set_seed(seed: int):
    """Seed every random number generator used by this notebook.

    PyTorch (CPU), NumPy's global generator and Python's ``random`` module
    are always seeded. When CUDA is available, the current GPU and then all
    GPUs are seeded too. A confirmation line is printed at the end.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        # Current device first, then every device.
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    print(f"🎲 Random seed set to: {seed}")

def detect_device() -> torch.device:
    """Choose the compute device and print which one was selected.

    Preference order: CUDA GPU, then Apple MPS, then CPU.

    Returns:
        The selected ``torch.device``.
    """
    if torch.cuda.is_available():
        print(f"🚀 Using CUDA GPU: {torch.cuda.get_device_name()}")
        return torch.device("cuda")

    if torch.backends.mps.is_available():
        print("🚀 Using Apple MPS")
        return torch.device("mps")

    # Neither accelerator is available.
    print("🚀 Using CPU")
    return torch.device("cpu")

def load_config_from_json(config_path: str) -> GNNMoEConfig:
    """Build a ``GNNMoEConfig`` from the values stored in a JSON file.

    Starts from a default-constructed ``GNNMoEConfig`` and overrides every
    attribute whose name appears as a top-level key of the JSON object.
    Keys that do not match an existing attribute are not applied; they are
    reported on stdout so that a misspelled option does not go unnoticed.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        The populated configuration, after an explicit ``__post_init__`` call.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    print(f"📋 Loading config from {config_path}")
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale-default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    config = GNNMoEConfig()
    ignored_keys = []
    for key, value in config_dict.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            ignored_keys.append(key)

    if ignored_keys:
        print(f"⚠️ Ignored unknown config keys: {', '.join(sorted(ignored_keys))}")

    # Explicit call so that __post_init__ runs against the overridden values.
    # NOTE(review): presumably GNNMoEConfig is a dataclass that derives or
    # validates fields there — confirm against gnn_moe_config.py.
    config.__post_init__()
    print(f"✅ Config loaded successfully.")
    return config

def load_model_and_checkpoint(config: GNNMoEConfig, checkpoint_path: str, device: torch.device) -> GNNMoEModel:
    """Instantiate a ``GNNMoEModel`` and restore its weights from a checkpoint.

    Args:
        config: Configuration passed to the ``GNNMoEModel`` constructor.
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        device: Device used as ``map_location`` and as the model's final device.

    Returns:
        The model, moved to ``device`` and switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    checkpoint_missing = not os.path.exists(checkpoint_path)
    if checkpoint_missing:
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    print(f"🧠 Creating GNNMoEModel...")
    net = GNNMoEModel(config)

    print(f"🔄 Loading weights from {checkpoint_path}")
    payload = torch.load(checkpoint_path, map_location=device)

    # The weights may be nested under 'model_state_dict'; when that key is
    # absent the loaded object itself is used as the state dict.
    net.load_state_dict(payload.get('model_state_dict', payload))

    net.to(device)
    net.eval()

    print("✅ Model loaded and in evaluation mode.")
    return net

def load_tokenizer(tokenizer_name: Optional[str] = 'gpt2') -> AutoTokenizer:
    """Fetch a pretrained tokenizer and make sure it has a padding token.

    Args:
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer. If it had no pad token, its EOS token is used as one.
    """
    print(f"🔤 Loading tokenizer: {tokenizer_name}")
    tok = AutoTokenizer.from_pretrained(tokenizer_name)

    needs_pad_token = tok.pad_token is None
    if needs_pad_token:
        # Reuse EOS as the padding token.
        tok.pad_token = tok.eos_token

    return tok

def apply_sampling(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    """Apply temperature scaling, top-k and top-p (nucleus) filtering to logits.

    Filtering works along the last dimension. Entries that are filtered out
    are set to ``-inf`` so that a following softmax gives them probability 0.
    The input tensor is never modified in place; when no transformation is
    active, the same tensor object is returned unchanged.

    Args:
        logits: Unnormalised scores; the last dimension is the vocabulary.
        temperature: Divisor applied to the logits. Must be > 0.
        top_k: Keep only the ``top_k`` highest logits. ``<= 0`` disables the
            filter; values above the vocabulary size are clamped to it.
        top_p: Keep the smallest set of highest-probability entries whose
            cumulative probability exceeds ``top_p``. Only active for
            ``0 < top_p < 1``.

    Returns:
        A tensor with the same shape as ``logits``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Zero would divide by zero and a negative value would invert the
        # ranking; both produce unusable distributions downstream.
        raise ValueError(f"temperature must be > 0, got {temperature}")

    if temperature != 1.0:
        logits = logits / temperature

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        # Start from all -inf and copy back only the k best scores.
        logits_filtered = torch.full_like(logits, float('-inf'))
        logits_filtered.scatter_(-1, top_k_indices, top_k_logits)
        logits = logits_filtered

    if top_p > 0.0 and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the first entry that crosses the threshold is
        # kept, and always keep the single most likely entry.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        # Out-of-place fill: an indexed assignment here would write into the
        # caller's tensor whenever no earlier step had created a new one.
        logits = logits.masked_fill(indices_to_remove, float('-inf'))

    return logits

def generate_text(
    model: GNNMoEModel,
    tokenizer: AutoTokenizer,
    prompt: str,
    device: torch.device,
    max_new_tokens: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> str:
    """Sample a continuation of ``prompt`` one token at a time.

    Each step runs the model on the whole sequence so far, filters the
    logits of the last position with ``apply_sampling`` and draws one token
    from the resulting distribution. Generation stops after
    ``max_new_tokens`` tokens, when the sequence length reaches
    ``model.config.max_seq_length``, or when the EOS token is drawn.

    Args:
        model: Model called as ``model(ids, attention_mask=...)``; its output
            is indexed with ``'logits'``.
        tokenizer: Tokenizer used to encode the prompt and decode the result.
        prompt: Text to continue.
        device: Device the encoded prompt is moved to.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Passed to ``apply_sampling``.
        top_k: Passed to ``apply_sampling``.
        top_p: Passed to ``apply_sampling``.

    Returns:
        The decoded sequence (prompt plus sampled tokens), with special
        tokens skipped.
    """
    sequence = tokenizer.encode(prompt, return_tensors="pt").to(device)
    mask = torch.ones_like(sequence)

    with torch.no_grad():
        steps_taken = 0
        # Stop on the token budget or once the model's maximum length is hit.
        while steps_taken < max_new_tokens and sequence.shape[1] < model.config.max_seq_length:
            steps_taken += 1

            model_out = model(sequence, attention_mask=mask)
            last_position_logits = model_out['logits'][:, -1, :]
            candidate_logits = apply_sampling(last_position_logits, temperature, top_k, top_p)

            probabilities = F.softmax(candidate_logits, dim=-1)
            sampled = torch.multinomial(probabilities, num_samples=1)

            sequence = torch.cat([sequence, sampled], dim=1)
            mask = torch.cat([mask, torch.ones_like(sampled)], dim=1)

            # .item() needs a single-element tensor, i.e. one sequence.
            if sampled.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(sequence[0], skip_special_tokens=True)

# Last statement of the cell: seeing this line confirms the cell ran to the end
# without raising.
print("✅ All imports and helper functions are defined.")


<h1>Configuration and Model Loading</h1>

In [None]:
# --- Cell 2: Configuration and Model Loading ---
# Seeds the RNGs, selects a device, then loads the config, the model weights and
# the tokenizer. `device`, `model` and `tokenizer` are read by the text
# generation cell further down.

# --- Parameters to set ---
CHECKPOINT_PATH = "checkpoints/your_model_checkpoint.pth.tar"  # <-- ❗ UPDATE THIS PATH
CONFIG_PATH = "checkpoints/your_config.json"              # <-- ❗ UPDATE THIS PATH
SEED = 42
# -------------------------

set_seed(SEED)
device = detect_device()

try:
    # Load configuration and model
    config = load_config_from_json(CONFIG_PATH)
    model = load_model_and_checkpoint(config, CHECKPOINT_PATH, device)
    # Called without arguments, so the helper's default tokenizer name is used.
    tokenizer = load_tokenizer()

    # Display model summary
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n📊 Model Summary:")
    print(f"   - Experts: {config.num_experts}")
    print(f"   - Embedding Dim: {config.embed_dim}")
    print(f"   - Model Layers: {config.num_layers}")
    # `coupler_type` is read with a fallback of 'GNN' for configs that lack it.
    print(f"   - Coupler Type: {getattr(config, 'coupler_type', 'GNN')}")
    print(f"   - Total Parameters: {total_params:,}")

except FileNotFoundError as e:
    # Only a missing config/checkpoint file is handled here; any other loading
    # error propagates. On this path the cell does not assign `model` or
    # `tokenizer`. Values left over from an earlier successful run in the same
    # kernel would still be present, so restart the kernel if in doubt.
    print(f"❌ {e}")
    print("   Please update the CHECKPOINT_PATH and CONFIG_PATH variables in this cell.")



<h1>Text Generation</h1>

# --- Cell 3: Text Generation ---
# Samples a continuation of `prompt_text` with the model, tokenizer and device
# created by the configuration/model loading cell.

# --- Generation Parameters ---
prompt_text = "The field of artificial intelligence is" # <-- ✍️ Your prompt here
max_new_tokens = 150  # upper bound on the number of sampled tokens
temperature = 0.75    # divisor applied to the logits before sampling
top_k = 50            # keep only the 50 highest-scoring tokens per step
top_p = 0.95          # nucleus (cumulative probability) threshold
# ---------------------------

# At the top level of a cell, locals() is the notebook's global namespace, so
# this checks whether the loading cell actually assigned both names.
if 'model' in locals() and 'tokenizer' in locals():
    print(f"💬 Generating text from prompt: '{prompt_text}'")
    print("-" * 50)
    
    generated_output = generate_text(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt_text,
        device=device,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )
    
    print("\n✨ Generated Text ✨")
    print("-" * 50)
    print(generated_output)
else:
    # Reached when the loading cell was not run or did not get as far as
    # assigning both names.
    print("Model not loaded. Please run Cell 2 successfully first.")

