In [1]:
# miStudio Complete Workflow: Phi-4 SAE Training and Analysis
# 
# This notebook demonstrates the complete miStudio interpretability workflow:
# 1. Train a Sparse Autoencoder (SAE) on Microsoft Phi-4 at layer 30
# 2. Find and analyze features using miStudioFind
# 3. Generate explanations using miStudioExplain  
# 4. Score feature importance using miStudioScore
#
# Prerequisites:
# - All miStudio services running on their designated ports
# - GPU with sufficient memory (>= 16GB recommended for Phi-4)
# - HuggingFace token for accessing gated models (if needed)

import requests
import json
import time
import os
import sys
import pandas as pd
from pathlib import Path
from typing import Dict, Any, Optional
import zipfile
import urllib.request
from datetime import datetime

# Configuration
SERVICE_PORTS = {
    'train': 8001,
    'find': 8002, 
    'explain': 8003,
    'score': 8004
}

BASE_URL = "http://localhost"
SERVICE_URLS = {
    service: f"{BASE_URL}:{port}" 
    for service, port in SERVICE_PORTS.items()
}

# Model and training configuration
MODEL_NAME = "microsoft/Phi-4"
TARGET_LAYER = 30  # Layer 30 for Phi-4
WEBTEXT_CORPUS_URL = "https://huggingface.co/datasets/stas/openwebtext-10k/resolve/main/plain_text/train-00000-of-00001.parquet"
CORPUS_FILENAME = "webtext_corpus.txt"

# Output directories
OUTPUT_DIR = Path("./mistudio_phi4_results")
OUTPUT_DIR.mkdir(exist_ok=True)

# Global variables to store results between steps
training_job_id = None
find_job_id = None
explain_job_id = None
score_job_id = None

print("🚀 miStudio Complete Workflow - Phi-4 SAE Training and Analysis")
print("=" * 70)
print(f"Model: {MODEL_NAME}")
print(f"Target Layer: {TARGET_LAYER}")
print(f"Output Directory: {OUTPUT_DIR.absolute()}")
print(f"Service URLs configured:")
for service, url in SERVICE_URLS.items():
    print(f"  - {service.capitalize()}: {url}")
print("=" * 70)

# Utility functions
def check_service_health(service_name: str, url: str) -> bool:
    """Check if a service is healthy and responsive."""
    try:
        response = requests.get(f"{url}/health", timeout=10)
        if response.status_code == 200:
            health_data = response.json()
            print(f"✅ {service_name.capitalize()} service: {health_data.get('status', 'healthy')}")
            return True
        else:
            print(f"❌ {service_name.capitalize()} service: HTTP {response.status_code}")
            return False
    except requests.RequestException as e:
        print(f"❌ {service_name.capitalize()} service: Connection failed - {e}")
        return False

def wait_for_job_completion(service_url: str, job_id: str, service_name: str, 
                          max_wait_minutes: int = 120) -> Dict[str, Any]:
    """Wait for a job to complete and return the final status."""
    start_time = time.time()
    max_wait_seconds = max_wait_minutes * 60
    
    print(f"⏳ Waiting for {service_name} job {job_id} to complete...")
    
    while time.time() - start_time < max_wait_seconds:
        try:
            response = requests.get(f"{service_url}/api/v1/{service_name.lower()}/{job_id}/status")
            if response.status_code == 200:
                status_data = response.json()
                status = status_data.get('status')
                
                if status == 'completed':
                    print(f"✅ {service_name} job completed successfully!")
                    return status_data
                elif status == 'failed':
                    print(f"❌ {service_name} job failed!")
                    print(f"Error: {status_data.get('error', 'Unknown error')}")
                    return status_data
                elif status in ['running', 'queued']:
                    progress = status_data.get('progress', {})
                    if 'percentage' in progress:
                        print(f"🔄 {service_name} progress: {progress['percentage']:.1f}% - {progress.get('message', '')}")
                    else:
                        print(f"🔄 {service_name} status: {status}")
                    time.sleep(30)  # Check every 30 seconds
                else:
                    print(f"🔄 {service_name} status: {status}")
                    time.sleep(30)
            else:
                print(f"⚠️ Failed to get {service_name} status: HTTP {response.status_code}")
                time.sleep(30)
        except requests.RequestException as e:
            print(f"⚠️ Error checking {service_name} status: {e}")
            time.sleep(30)
    
    print(f"⏰ Timeout waiting for {service_name} job to complete after {max_wait_minutes} minutes")
    return {'status': 'timeout'}

def download_webtext_corpus() -> str:
    """Download and prepare the webtext corpus for training."""
    corpus_path = OUTPUT_DIR / CORPUS_FILENAME
    
    if corpus_path.exists():
        print(f"📁 Webtext corpus already exists: {corpus_path}")
        return str(corpus_path)
    
    print("📥 Downloading webtext corpus...")
    try:
        # Download the parquet file
        parquet_path = OUTPUT_DIR / "webtext.parquet"
        urllib.request.urlretrieve(WEBTEXT_CORPUS_URL, parquet_path)
        
        # Convert parquet to text format
        import pandas as pd
        df = pd.read_parquet(parquet_path)
        
        # Extract text content and create a plain text file
        with open(corpus_path, 'w', encoding='utf-8') as f:
            for text in df['text'].head(1000):  # Use first 1000 samples for demo
                if isinstance(text, str) and len(text.strip()) > 50:
                    # Clean and write text
                    clean_text = text.strip().replace('\n', ' ').replace('\r', ' ')
                    f.write(clean_text + '\n')
        
        # Clean up parquet file
        parquet_path.unlink()
        
        print(f"✅ Webtext corpus prepared: {corpus_path}")
        print(f"📊 Corpus size: {corpus_path.stat().st_size / 1024 / 1024:.1f} MB")
        
        return str(corpus_path)
        
    except Exception as e:
        print(f"❌ Failed to download webtext corpus: {e}")
        # Create a fallback corpus
        print("📝 Creating fallback corpus...")
        fallback_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Artificial intelligence is transforming how we work and live.",
            "Machine learning models require large datasets for training.",
            "Natural language processing helps computers understand text.",
            "Deep learning uses neural networks with multiple layers.",
            "Transformers have revolutionized natural language understanding.",
            "Sparse autoencoders help us understand what neural networks learn.",
            "Feature interpretability is crucial for AI safety and alignment.",
        ] * 100  # Repeat to create more content
        
        with open(corpus_path, 'w', encoding='utf-8') as f:
            for text in fallback_texts:
                f.write(text + '\n')
        
        print(f"✅ Fallback corpus created: {corpus_path}")
        return str(corpus_path)


🚀 miStudio Complete Workflow - Phi-4 SAE Training and Analysis
Model: microsoft/Phi-4
Target Layer: 30
Output Directory: /home/sean/app/miStudio/mistudio_phi4_results
Service URLs configured:
  - Train: http://localhost:8001
  - Find: http://localhost:8002
  - Explain: http://localhost:8003
  - Score: http://localhost:8004


In [2]:
# Step 0: Health Checks
print("\n🏥 STEP 0: SERVICE HEALTH CHECKS")
print("-" * 50)

all_services_healthy = True
for service_name, service_url in SERVICE_URLS.items():
    if not check_service_health(service_name, service_url):
        all_services_healthy = False

if not all_services_healthy:
    print("\n❌ Some services are not healthy. Please ensure all miStudio services are running.")
    print("Expected services:")
    for service, port in SERVICE_PORTS.items():
        print(f"  - miStudio{service.capitalize()}: port {port}")
    sys.exit(1)

print("\n✅ All services are healthy and ready!")




🏥 STEP 0: SERVICE HEALTH CHECKS
--------------------------------------------------
✅ Train service: healthy
❌ Find service: Connection failed - HTTPConnectionPool(host='localhost', port=8002): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7e8ca6191550>: Failed to establish a new connection: [Errno 111] Connection refused'))
❌ Explain service: Connection failed - HTTPConnectionPool(host='localhost', port=8003): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7e8c8fb9f920>: Failed to establish a new connection: [Errno 111] Connection refused'))
❌ Score service: Connection failed - HTTPConnectionPool(host='localhost', port=8004): Max retries exceeded with url: /health (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x7e8c8fbd01a0>: Failed to establish a new connection: [Errno 111] Connection refused'))

❌ Some services are

SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [3]:
# Step 1: Download and prepare corpus
print("\n📥 STEP 1: CORPUS PREPARATION")
print("-" * 50)

corpus_path = download_webtext_corpus()

# Upload corpus to training service
print("📤 Uploading corpus to training service...")

# First, check if we can list existing files (this will help debug permissions)
try:
    list_response = requests.get(f"{SERVICE_URLS['train']}/api/v1/files")
    if list_response.status_code == 200:
        files_info = list_response.json()
        print(f"📁 Current files in samples directory: {len(files_info.get('files', []))}")
    else:
        print(f"⚠️ Cannot list files: HTTP {list_response.status_code}")
        print("This might indicate a permission or directory issue")
except Exception as e:
    print(f"⚠️ Error checking files: {e}")

# Try to upload the corpus
try:
    with open(corpus_path, 'rb') as f:
        files = {'file': (CORPUS_FILENAME, f, 'text/plain')}
        response = requests.post(f"{SERVICE_URLS['train']}/api/v1/upload", files=files)
        
    if response.status_code == 200:
        upload_result = response.json()
        print(f"✅ Corpus uploaded successfully: {upload_result.get('filename', 'webtext_corpus.txt')}")
        
        # Handle different possible response formats
        size_bytes = upload_result.get('size_bytes') or upload_result.get('file_size_bytes')
        if size_bytes:
            print(f"📊 Size: {size_bytes / 1024 / 1024:.1f} MB")
        
        lines_count = upload_result.get('lines_count') or upload_result.get('estimated_lines')
        if lines_count:
            print(f"📄 Lines: {lines_count}")
        
        print(f"📋 Upload Response: {upload_result}")
    else:
        print(f"❌ Failed to upload corpus: HTTP {response.status_code}")
        error_response = response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text
        print(f"Error details: {error_response}")
        
        # If it's a permission error, provide helpful guidance
        if "Permission denied" in str(error_response):
            print("\n🔧 PERMISSION FIX NEEDED:")
            print("The miStudioTrain service cannot write to /data/samples directory.")
            print("Please run one of these commands on your server:")
            print("  sudo mkdir -p /data/samples")
            print("  sudo chmod 777 /data/samples")
            print("OR")
            print("  sudo chown -R $USER:$USER /data")
            print("\nAlternatively, check your Docker/Kubernetes volume mounts.")
            
        sys.exit(1)
        
except Exception as e:
    print(f"❌ Error uploading corpus: {e}")
    sys.exit(1)


📥 STEP 1: CORPUS PREPARATION
--------------------------------------------------
📁 Webtext corpus already exists: mistudio_phi4_results/webtext_corpus.txt
📤 Uploading corpus to training service...
📁 Current files in samples directory: 0
✅ Corpus uploaded successfully: webtext_corpus.txt
📊 Size: 0.0 MB
📄 Lines: 801
📋 Upload Response: {'status': 'success', 'message': 'File uploaded successfully', 'filename': 'webtext_corpus.txt', 'file_path': '/data/samples/webtext_corpus.txt', 'file_size_bytes': 48300, 'estimated_lines': 801, 'ready_for_training': True}


In [4]:
# Step 2: Train SAE
print("\n🏋️ STEP 2: SAE TRAINING")
print("-" * 50)

train_request = {
    "model_name": MODEL_NAME,
    "corpus_file": CORPUS_FILENAME,
    "layer_number": TARGET_LAYER,
    "hidden_dim": 1024,
    "sparsity_coeff": 1e-3,
    "learning_rate": 1e-4,
    "batch_size": 8,  # Conservative batch size for Phi-4
    "max_epochs": 20,
    "min_loss": 0.01,
    "max_sequence_length": 512,
    "gpu_id": 0  # Use first GPU
}

print(f"🚀 Starting SAE training job...")
print(f"📋 Configuration:")
print(f"  - Model: {MODEL_NAME}")
print(f"  - Target Layer: {TARGET_LAYER}")
print(f"  - Hidden Dimensions: {train_request['hidden_dim']}")
print(f"  - Batch Size: {train_request['batch_size']}")
print(f"  - Max Epochs: {train_request['max_epochs']}")

try:
    response = requests.post(f"{SERVICE_URLS['train']}/api/v1/train", json=train_request)
    
    if response.status_code in [200, 202]:  # Accept both 200 and 202 as success
        train_result = response.json()
        training_job_id = train_result['job_id']
        print(f"✅ Training job started: {training_job_id}")
        print(f"📊 Job Details:")
        print(f"  - Status: {train_result.get('status')}")
        print(f"  - Model: {train_result.get('model_name')}")
        print(f"  - Memory Check: {train_result.get('memory_check')}")
        print(f"  - Optimizations: {train_result.get('optimizations_applied')}")
        
        # Wait for training completion
        final_status = wait_for_job_completion(
            SERVICE_URLS['train'], training_job_id, 'Train', max_wait_minutes=180
        )
        
        if final_status.get('status') == 'completed':
            print(f"🎯 Training completed successfully!")
            
            # Get training results
            result_response = requests.get(f"{SERVICE_URLS['train']}/api/v1/train/{training_job_id}/result")
            if result_response.status_code == 200:
                training_results = result_response.json()
                print(f"📊 Training Results:")
                print(f"  - Final Loss: {training_results.get('final_loss', 'N/A')}")
                print(f"  - Epochs Completed: {training_results.get('epochs_completed', 'N/A')}")
                print(f"  - Output Directory: {training_results.get('output_dir', 'N/A')}")
            else:
                print(f"⚠️ Could not retrieve training results")
        else:
            print(f"❌ Training failed or timed out")
            sys.exit(1)
            
    else:
        print(f"❌ Failed to start training: HTTP {response.status_code}")
        print(response.text)
        sys.exit(1)
        
except Exception as e:
    print(f"❌ Error during training: {e}")
    sys.exit(1)


🏋️ STEP 2: SAE TRAINING
--------------------------------------------------
🚀 Starting SAE training job...
📋 Configuration:
  - Model: microsoft/Phi-4
  - Target Layer: 30
  - Hidden Dimensions: 1024
  - Batch Size: 8
  - Max Epochs: 20
✅ Training job started: train_20250730_010840_8341
📊 Job Details:
  - Status: queued
  - Model: microsoft/Phi-4
  - Memory Check: passed
  - Optimizations: Applied optimizations for microsoft/Phi-4: True
⏳ Waiting for Train job train_20250730_010840_8341 to complete...
✅ Train job completed successfully!
🎯 Training completed successfully!
📊 Training Results:
  - Final Loss: N/A
  - Epochs Completed: N/A
  - Output Directory: N/A


In [None]:
# Step 3: Feature Analysis with miStudioFind
print("\n🔍 STEP 3: FEATURE ANALYSIS")
print("-" * 50)

find_request = {
    "source_job_id": training_job_id,
    "top_k": 20,
    "coherence_threshold": 0.5,
    "include_statistics": True
}

print(f"🔎 Starting feature analysis...")
print(f"📋 Configuration:")
print(f"  - Source Job: {training_job_id}")
print(f"  - Top K Activations: {find_request['top_k']}")
print(f"  - Coherence Threshold: {find_request['coherence_threshold']}")

try:
    response = requests.post(f"{SERVICE_URLS['find']}/api/v1/find/start", json=find_request)
    
    if response.status_code == 202:
        find_result = response.json()
        find_job_id = find_result['job_id']
        print(f"✅ Feature analysis job started: {find_job_id}")
        
        # Wait for analysis completion
        final_status = wait_for_job_completion(
            SERVICE_URLS['find'], find_job_id, 'Find', max_wait_minutes=30
        )
        
        if final_status.get('status') == 'completed':
            print(f"🎯 Feature analysis completed!")
            
            # Get analysis results
            result_response = requests.get(f"{SERVICE_URLS['find']}/api/v1/find/{find_job_id}/results")
            if result_response.status_code == 200:
                analysis_results = result_response.json()
                print(f"📊 Analysis Results:")
                print(f"  - Features Analyzed: {analysis_results.get('total_features', 'N/A')}")
                print(f"  - High Quality Features: {analysis_results.get('high_quality_count', 'N/A')}")
                print(f"  - Medium Quality Features: {analysis_results.get('medium_quality_count', 'N/A')}")
                print(f"  - Processing Time: {analysis_results.get('processing_time_seconds', 'N/A')}s")
            else:
                print(f"⚠️ Could not retrieve analysis results")
        else:
            print(f"❌ Feature analysis failed or timed out")
            sys.exit(1)
            
    else:
        print(f"❌ Failed to start feature analysis: HTTP {response.status_code}")
        print(response.text)
        sys.exit(1)
        
except Exception as e:
    print(f"❌ Error during feature analysis: {e}")
    sys.exit(1)



In [None]:
# Step 4: Generate Explanations with miStudioExplain
print("\n💡 STEP 4: EXPLANATION GENERATION")
print("-" * 50)

explain_request = {
    "request_id": f"phi4_explain_{int(time.time())}",
    "analysis_type": "complex_behavioral",
    "complexity": "medium",
    "model": "llama3.1:8b",  # Specify local LLM model
    "input_data": {
        "find_job_id": find_job_id,
        "feature_analysis": {},  # Will be populated by service
        "summary_report": f"Feature analysis results from job {find_job_id}"
    }
}

print(f"💬 Starting explanation generation...")
print(f"📋 Configuration:")
print(f"  - Source Analysis: {find_job_id}")
print(f"  - Analysis Type: {explain_request['analysis_type']}")
print(f"  - Complexity: {explain_request['complexity']}")
print(f"  - LLM Model: {explain_request['model']}")

try:
    response = requests.post(f"{SERVICE_URLS['explain']}/api/v1/explain", json=explain_request)
    
    if response.status_code == 202:
        explain_result = response.json()
        explain_job_id = explain_result['job_id']
        print(f"✅ Explanation generation job started: {explain_job_id}")
        
        # Wait for explanation completion
        final_status = wait_for_job_completion(
            SERVICE_URLS['explain'], explain_job_id, 'Explain', max_wait_minutes=60
        )
        
        if final_status.get('status') == 'completed':
            print(f"🎯 Explanation generation completed!")
            
            # Get explanation results
            result_response = requests.get(f"{SERVICE_URLS['explain']}/api/v1/explain/{explain_job_id}/results")
            if result_response.status_code == 200:
                explanation_results = result_response.json()
                print(f"📊 Explanation Results:")
                print(f"  - Explanations Generated: {explanation_results.get('total_explanations', 'N/A')}")
                print(f"  - Average Quality Score: {explanation_results.get('average_quality_score', 'N/A')}")
                print(f"  - Processing Time: {explanation_results.get('processing_time_seconds', 'N/A')}s")
            else:
                print(f"⚠️ Could not retrieve explanation results")
        else:
            print(f"❌ Explanation generation failed or timed out")
            sys.exit(1)
            
    else:
        print(f"❌ Failed to start explanation generation: HTTP {response.status_code}")
        print(response.text)
        sys.exit(1)
        
except Exception as e:
    print(f"❌ Error during explanation generation: {e}")
    sys.exit(1)



In [None]:
# Step 5: Score Features with miStudioScore
print("\n📊 STEP 5: FEATURE SCORING")
print("-" * 50)

# Create scoring configuration
scoring_config = {
    "scoring_jobs": [
        {
            "scorer": "relevance_scorer",
            "name": "ai_safety_relevance",
            "params": {
                "positive_keywords": [
                    "safety", "security", "harmful", "dangerous", "toxic",
                    "bias", "discrimination", "privacy", "ethical"
                ],
                "negative_keywords": [
                    "marketing", "advertising", "promotion", "sales"
                ]
            }
        },
        {
            "scorer": "relevance_scorer", 
            "name": "technical_relevance",
            "params": {
                "positive_keywords": [
                    "algorithm", "computation", "function", "method",
                    "processing", "analysis", "logic", "reasoning"
                ],
                "negative_keywords": [
                    "emotion", "feeling", "opinion", "preference"
                ]
            }
        }
    ]
}

score_request = {
    "features_path": f"data/output/{find_job_id}/features.json",
    "config_path": "config/scoring_config.yaml",
    "output_dir": f"data/output/{find_job_id}"
}

print(f"📈 Starting feature scoring...")
print(f"📋 Configuration:")
print(f"  - Features Source: {find_job_id}")
print(f"  - Scoring Jobs: {len(scoring_config['scoring_jobs'])}")
print(f"  - Safety Relevance: {len(scoring_config['scoring_jobs'][0]['params']['positive_keywords'])} keywords")
print(f"  - Technical Relevance: {len(scoring_config['scoring_jobs'][1]['params']['positive_keywords'])} keywords")

try:
    response = requests.post(f"{SERVICE_URLS['score']}/api/v1/score", json=score_request)
    
    if response.status_code == 200:
        score_result = response.json()
        print(f"✅ Feature scoring completed!")
        print(f"📊 Scoring Results:")
        print(f"  - Features Scored: {score_result.get('features_scored', 'N/A')}")
        print(f"  - Scores Added: {', '.join(score_result.get('scores_added', []))}")
        print(f"  - Output Path: {score_result.get('output_path', 'N/A')}")
    else:
        print(f"❌ Failed to complete feature scoring: HTTP {response.status_code}")
        print(response.text)
        sys.exit(1)
        
except Exception as e:
    print(f"❌ Error during feature scoring: {e}")
    sys.exit(1)



In [None]:
# Step 6: Results Summary and Export
print("\n📋 STEP 6: RESULTS SUMMARY AND EXPORT")
print("-" * 50)

print(f"🎉 Complete miStudio workflow finished successfully!")
print(f"")
print(f"📂 Job IDs:")
print(f"  - Training: {training_job_id}")
print(f"  - Feature Analysis: {find_job_id}")
print(f"  - Explanations: {explain_job_id}")
print(f"")
print(f"📁 Results available in:")
print(f"  - Local Output: {OUTPUT_DIR.absolute()}")
print(f"  - Service Data: /data/output/{find_job_id}/")
print(f"")

# Export comprehensive results
print(f"📦 Exporting comprehensive results...")

try:
    # Export feature analysis results in multiple formats
    export_response = requests.get(f"{SERVICE_URLS['find']}/api/v1/find/{find_job_id}/export?format=all")
    if export_response.status_code == 200:
        # Save the ZIP file
        zip_path = OUTPUT_DIR / f"phi4_layer{TARGET_LAYER}_complete_results.zip"
        with open(zip_path, 'wb') as f:
            f.write(export_response.content)
        print(f"✅ Complete results exported: {zip_path}")
        
        # Extract and display summary
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(OUTPUT_DIR / "extracted_results")
        
        print(f"📊 Exported formats:")
        for file_path in (OUTPUT_DIR / "extracted_results").iterdir():
            if file_path.is_file():
                size_mb = file_path.stat().st_size / 1024 / 1024
                print(f"  - {file_path.name}: {size_mb:.1f} MB")
    
    else:
        print(f"⚠️ Could not export results: HTTP {export_response.status_code}")

except Exception as e:
    print(f"⚠️ Error exporting results: {e}")

print(f"")
print(f"✅ WORKFLOW COMPLETE!")
print(f"")
print(f"🔍 To explore your results:")
print(f"  1. Check the extracted_results folder for detailed analysis")
print(f"  2. Review the features.json file for feature mappings")
print(f"  3. Examine explanations.json for human-readable descriptions")
print(f"  4. Analyze scores.json for feature importance rankings")
print(f"")
print(f"🔗 Access service UIs:")
for service, url in SERVICE_URLS.items():
    print(f"  - {service.capitalize()}: {url}/docs")
print(f"")
print(f"📈 This completes the full miStudio interpretability pipeline!")
print(f"   You now have a trained SAE, feature analysis, explanations, and scores")
print(f"   for the Microsoft Phi-4 model at layer {TARGET_LAYER}.")