# 🎨 MASUKA V2 - Complete Setup (Single Cell)

**Train Flux LoRAs on a GPU in Google Colab**

---

## ⚡ Quick Start

1. **Enable GPU**: Runtime → Change runtime type → T4 GPU
2. **Run the cell below** ▶️
3. **Wait ~10 minutes** for setup to complete
4. **Copy your API URL** from the output
5. **Use the helper cells below** to upload images and start training!

---

**⚠️ Important:**
- Keep this notebook running for the entire training run
- Training takes ~45–60 min on a T4 or ~20–30 min on an A100
- Trained models are uploaded to S3 automatically

In [None]:
#@title 🚀 Run Complete MASUKA V2 Setup
#@markdown This will take ~10 minutes. Watch the output for your API URL!

import os
import sys
import time
import subprocess
from IPython.display import clear_output, HTML

def print_step(emoji, title, status=""):
    """Print a step banner: the emoji and title, a 60-char rule, then an optional status.

    The status line is only printed when ``status`` is truthy.
    """
    print(f"\n{emoji} {title}", "=" * 60, sep="\n")
    if not status:
        return
    print(status)

def run_command(cmd, description, silent=True):
    """Run a shell command and report success or failure inline.

    Args:
        cmd: Shell command string. It is executed with ``shell=True``, so only
            pass trusted, notebook-authored strings.
        description: Label printed before the ✅/❌ marker.
        silent: When True, capture the command's output and show its stderr
            only if the command fails. When False, let output stream straight
            to the notebook.

    Returns:
        True if the command exited with status 0, otherwise False.
    """
    print(f"  {description}...", end="", flush=True)
    try:
        subprocess.run(cmd, shell=True, check=True, capture_output=silent, text=True)
    except subprocess.CalledProcessError as e:
        print(f" ❌\n{e}")
        # The exception text only names the exit status. When output was
        # captured, show the tail of stderr so the failure is diagnosable.
        if silent and e.stderr:
            print(e.stderr.strip()[-2000:])
        return False
    print(" ✅")
    return True

banner_rule = "=" * 60
for banner_line in ("", banner_rule, "     🎨 MASUKA V2 - Complete Setup Starting", banner_rule):
    print(banner_line)

# ============================================================
# STEP 1: Check GPU
# ============================================================
print_step("1️⃣", "Checking GPU")
run_command("nvidia-smi -L", "Detecting GPU", silent=True)

import torch

# Guard clause: abort the whole setup early when no CUDA device is visible.
if not torch.cuda.is_available():
    print("  ❌ No GPU detected!")
    print("  Go to: Runtime → Change runtime type → GPU")
    sys.exit(1)

# gpu_name is reused by later steps and the final summary.
gpu_name = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
for gpu_line in (f"  GPU: {gpu_name}", f"  VRAM: {vram:.1f} GB", "  ✅ GPU Ready!"):
    print(gpu_line)

# ============================================================
# STEP 2: Clone Repository
# ============================================================
print_step("2️⃣", "Cloning MASUKA V2 Repository")
if not os.path.exists('/content/masuka-v2'):
    cloned = run_command(
        "git clone https://github.com/SamuelD27/masuka.git /content/masuka-v2",
        "Cloning from GitHub",
        silent=False
    )
    # run_command signals failure through its return value. Stop here with a
    # clear message instead of letting os.chdir below raise FileNotFoundError.
    if not cloned:
        raise RuntimeError("Cloning MASUKA V2 failed; check the git output above and re-run this cell.")
else:
    print("  ✅ Repository exists")

# Work from the repository root for the following steps.
os.chdir('/content/masuka-v2')

# ============================================================
# STEP 3: Install Dependencies
# ============================================================
print_step("3️⃣", "Installing Python Dependencies")
print("  This may take 2-3 minutes...")

# Install core dependencies
deps = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.31.1",
    "pydantic>=2.11.0",
    "pydantic-settings>=2.5.2",
    "python-multipart>=0.0.18",
    "sqlalchemy==2.0.25",
    "psycopg2-binary",
    "alembic",
    "redis",
    "celery",
    "flower",
    "boto3",
    "python-dotenv",
    "pyyaml",
    "pyngrok"
]

# Install packages one at a time so a single failure does not block the rest.
# Record failures instead of discarding pip's output; otherwise a missing
# package only surfaces later as an obscure runtime error.
failed_deps = []
for dep in deps:
    pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", dep],
        capture_output=True,
        text=True,
    )
    if pip_result.returncode != 0:
        failed_deps.append(dep)
        stderr_lines = pip_result.stderr.strip().splitlines()
        reason = stderr_lines[-1] if stderr_lines else "unknown error"
        print(f"  ⚠️  Failed to install {dep}: {reason}")

if failed_deps:
    print(f"  ⚠️  {len(failed_deps)} package(s) failed to install: {', '.join(failed_deps)}")
else:
    print("  ✅ Core dependencies installed")

# Verify the key packages import and report their versions.
import fastapi
import celery
import pydantic
print(f"  FastAPI: {fastapi.__version__}")
print(f"  Celery: {celery.__version__}")
print(f"  Pydantic: {pydantic.__version__}")

# ============================================================
# STEP 4: Install SimpleTuner
# ============================================================
print_step("4️⃣", "Installing SimpleTuner")
SIMPLETUNER_DIR = '/content/SimpleTuner'
if not os.path.exists(SIMPLETUNER_DIR):
    if not run_command(
        f"git clone -q https://github.com/bghira/SimpleTuner {SIMPLETUNER_DIR}",
        "Cloning SimpleTuner",
        silent=False
    ):
        raise RuntimeError("Cloning SimpleTuner failed; check the git output above and re-run this cell.")

    # Try `setup.py install` first, then fall back to an editable pip install.
    # Both run with cwd= so the notebook's working directory is never changed,
    # even if an install step raises.
    print("  Installing SimpleTuner package...", end="", flush=True)
    install = subprocess.run(
        [sys.executable, "setup.py", "install"],
        cwd=SIMPLETUNER_DIR,
        capture_output=True,
        text=True
    )
    if install.returncode != 0:
        print(" Using pip...", end="", flush=True)
        install = subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", "-e", "."],
            cwd=SIMPLETUNER_DIR,
            capture_output=True,
            text=True
        )

    if install.returncode == 0:
        print(" ✅")
        print("  ✅ SimpleTuner ready at /content/SimpleTuner")
    else:
        # Report the real outcome instead of an unconditional ✅.
        print(" ❌")
        print(install.stderr.strip()[-2000:])
        print("  ⚠️  SimpleTuner install failed; training will likely fail until it is installed.")
else:
    print("  ✅ SimpleTuner already installed")

# ============================================================
# STEP 5: Setup PostgreSQL
# ============================================================
print_step("5️⃣", "Setting up PostgreSQL")
run_command("apt-get update -qq && apt-get install -y -qq postgresql postgresql-contrib", 
            "Installing PostgreSQL")
run_command("service postgresql start", "Starting PostgreSQL")
time.sleep(2)

# (Re)create the database from scratch.
# NOTE: the DROP statements wipe any data left by a previous run of this cell.
commands = [
    "DROP DATABASE IF EXISTS masuka;",
    "DROP USER IF EXISTS masuka;",
    "CREATE USER masuka WITH PASSWORD 'password123';",
    "CREATE DATABASE masuka OWNER masuka;",
    "GRANT ALL PRIVILEGES ON DATABASE masuka TO masuka;"
]
db_errors = []
for cmd in commands:
    # ON_ERROR_STOP makes psql exit non-zero when the SQL statement fails, so
    # failures can be reported instead of silently ignored.
    psql = subprocess.run(
        ["sudo", "-u", "postgres", "psql", "-v", "ON_ERROR_STOP=1", "-c", cmd],
        capture_output=True,
        text=True
    )
    if psql.returncode != 0:
        db_errors.append((cmd, psql.stderr.strip()))

if db_errors:
    # Likely on a re-run: the previous backend (only killed in step 9) may
    # still hold connections, so DROP fails and the existing database is kept.
    print("  ⚠️  Some database commands failed:")
    for failed_cmd, err in db_errors:
        print(f"     {failed_cmd} -> {err}")
else:
    print("  ✅ Database 'masuka' created")

# ============================================================
# STEP 6: Setup Redis
# ============================================================
print_step("6️⃣", "Setting up Redis")
# Install and launch the server as a background daemon, then give it a
# second before checking that it answers.
redis_steps = (
    ("apt-get install -y -qq redis-server", "Installing Redis"),
    ("redis-server --daemonize yes", "Starting Redis"),
)
for redis_cmd, redis_label in redis_steps:
    run_command(redis_cmd, redis_label)
time.sleep(1)
run_command("redis-cli ping", "Testing Redis")

# ============================================================
# STEP 7: Configure Environment
# ============================================================
print_step("7️⃣", "Configuring Environment")

import secrets


def _colab_secret(name):
    """Return secret *name* from the environment or Colab Secrets, else None."""
    value = os.environ.get(name)
    if value:
        return value
    try:
        from google.colab import userdata
        return userdata.get(name)
    except Exception:  # not in Colab, secret missing, or notebook access not granted
        return None


# Credentials are read at runtime instead of being hard-coded. Anything written
# into the notebook is visible to everyone it is shared with, so any key that
# has ever appeared here must be treated as leaked and rotated.
aws_access_key_id = _colab_secret("AWS_ACCESS_KEY_ID")
aws_secret_access_key = _colab_secret("AWS_SECRET_ACCESS_KEY")
if not (aws_access_key_id and aws_secret_access_key):
    raise RuntimeError(
        "Add AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to Colab Secrets (🔑 sidebar) "
        "or set them as environment variables, then re-run this cell."
    )

# The API is exposed publicly through ngrok, so the JWT signing key must not be
# a well-known constant. Generate a fresh random key on every setup run.
jwt_secret_key = secrets.token_urlsafe(32)

env_content = f"""# MASUKA V2 Colab Environment
APP_NAME=MASUKA V2
DEBUG=true

# Database
DATABASE_URL=postgresql://masuka:password123@localhost:5432/masuka

# Redis
REDIS_URL=redis://localhost:6379/0
CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0

# JWT
SECRET_KEY={jwt_secret_key}
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=1440

# Storage (AWS S3)
S3_BUCKET=masuka-v2
AWS_ACCESS_KEY_ID={aws_access_key_id}
AWS_SECRET_ACCESS_KEY={aws_secret_access_key}
S3_ENDPOINT_URL=https://s3.ap-southeast-1.amazonaws.com
S3_REGION=ap-southeast-1

# Paths
SIMPLETUNER_PATH=/content/SimpleTuner
MODELS_PATH=/content/models
TEMP_PATH=/tmp/masuka

# CORS
CORS_ORIGINS=*
"""

env_path = '/content/masuka-v2/backend/.env'
os.makedirs(os.path.dirname(env_path), exist_ok=True)
with open(env_path, 'w') as f:
    f.write(env_content)
# The file now holds credentials; restrict it to the owner.
os.chmod(env_path, 0o600)
print("  ✅ Environment configured")

# ============================================================
# STEP 8: Initialize Database
# ============================================================
print_step("8️⃣", "Initializing Database")
# Working directories for uploads, training runs, and trained models.
for work_dir in ('/tmp/masuka/uploads', '/tmp/masuka/training', '/content/models'):
    os.makedirs(work_dir, exist_ok=True)
print("  ✅ Directories created")

# Note: Database tables will be created when backend starts

# ============================================================
# STEP 9: Start FastAPI Backend
# ============================================================
print_step("9️⃣", "Starting FastAPI Backend")
os.chdir('/content/masuka-v2/backend')

# Kill any existing uvicorn processes
subprocess.run(['pkill', '-f', 'uvicorn'], capture_output=True)
time.sleep(2)

# Send server output to a log file, not subprocess.PIPE. Nothing reads the
# pipe, so once the OS buffer fills, uvicorn blocks on its next log write and
# the API stops responding. The child keeps its own copy of the file
# descriptor, so closing ours when the `with` block exits is safe.
BACKEND_LOG_PATH = '/content/masuka-backend.log'
with open(BACKEND_LOG_PATH, 'w') as backend_log:
    backend_process = subprocess.Popen(
        ['uvicorn', 'app.main:app', '--host', '0.0.0.0', '--port', '8000', '--log-level', 'info'],
        stdout=backend_log,
        stderr=subprocess.STDOUT,
    )

import requests


def _tail(path, lines=20):
    """Return the last *lines* lines of a text file, or '' if it can't be read."""
    try:
        with open(path, errors='replace') as fh:
            return ''.join(fh.readlines()[-lines:])
    except OSError:
        return ''


# Poll /health until it answers, the process dies, or 60 s pass, instead of
# sleeping a fixed time that may be too short or needlessly long.
print("  ⏳ Waiting for backend to start...")
health = None
deadline = time.time() + 60
while time.time() < deadline and backend_process.poll() is None:
    try:
        response = requests.get('http://localhost:8000/health', timeout=5)
        if response.status_code == 200:
            health = response.json()
            break
    except (requests.RequestException, ValueError):
        pass  # not listening (or not serving JSON) yet; keep polling
    time.sleep(2)

if health is None:
    print(f"  ⚠️  Backend may not be ready (process exit code: {backend_process.poll()})")
    print(f"  Check backend logs if issues persist: {BACKEND_LOG_PATH}")
    print(_tail(BACKEND_LOG_PATH))
else:
    # Assumes /health returns a JSON object with 'app' and 'version' keys.
    print(f"  ✅ Backend running: {health.get('app')} v{health.get('version')}")

    # Verify datasets endpoint
    try:
        response = requests.get('http://localhost:8000/api/datasets/', timeout=5)
        if response.status_code == 200:
            print("  ✅ Datasets API loaded")
        else:
            print(f"  ⚠️  Datasets API issue: {response.status_code}")
    except requests.RequestException as e:
        print(f"  ⚠️  Datasets API unreachable: {e}")

# ============================================================
# STEP 10: Start Celery Worker
# ============================================================
print_step("🔟", "Starting Celery Worker")

# Kill any existing celery processes
subprocess.run(['pkill', '-f', 'celery'], capture_output=True)
time.sleep(2)

# Log to a file instead of unread pipes. The worker logs at INFO level; an
# unread pipe fills up and then blocks the worker on its next write, which
# would stall training.
CELERY_LOG_PATH = '/content/masuka-celery.log'
with open(CELERY_LOG_PATH, 'w') as celery_log:
    celery_process = subprocess.Popen(
        ['celery', '-A', 'app.tasks.celery_app', 'worker',
         '--loglevel=info', '--concurrency=1', '-Q', 'training,generation'],
        cwd='/content/masuka-v2/backend',  # directory containing the `app` package
        stdout=celery_log,
        stderr=subprocess.STDOUT,
    )

print("  ⏳ Waiting for Celery to start...")
time.sleep(5)
# Only claim success if the worker is still alive after the start-up grace period.
if celery_process.poll() is None:
    print(f"  ✅ Celery worker running (GPU: {gpu_name})")
else:
    print(f"  ❌ Celery worker exited with code {celery_process.returncode}; see {CELERY_LOG_PATH}")

# ============================================================
# STEP 11: Expose via ngrok
# ============================================================
print_step("1️⃣1️⃣", "Creating Public URL")

from pyngrok import ngrok

# Read the auth token from the environment or Colab Secrets (🔑 sidebar).
# Tokens must never be committed in the notebook; a token that ever appeared
# here is public and should be revoked in the ngrok dashboard.
ngrok_token = os.environ.get("NGROK_AUTHTOKEN")
if not ngrok_token:
    try:
        from google.colab import userdata
        ngrok_token = userdata.get("NGROK_AUTHTOKEN")
    except Exception:  # not in Colab, secret missing, or access not granted
        ngrok_token = None
if not ngrok_token:
    raise RuntimeError(
        "Add NGROK_AUTHTOKEN to Colab Secrets (🔑 sidebar) or set it as an "
        "environment variable, then re-run this cell."
    )
ngrok.set_auth_token(ngrok_token)

# Kill any existing tunnels
ngrok.kill()
time.sleep(2)

# connect() returns an NgrokTunnel object; keep only its URL string so that
# f-strings like f"{public_url}/docs" render a usable address.
public_url = ngrok.connect(8000).public_url
print("  ✅ Tunnel created")

# ============================================================
# FINAL OUTPUT
# ============================================================
# Resolve the plain https URL before printing anything. str() of an
# NgrokTunnel is 'NgrokTunnel: "<url>" -> "<addr>"', which would garble every
# link below. getattr() also accepts public_url when it is already a string.
API_URL = getattr(public_url, 'public_url', None) or str(public_url)

print("\n" + "="*70)
print("🎉 MASUKA V2 Setup Complete!")
print("="*70)
print("\n📡 Your Public API URL:")
print(f"   {API_URL}")
print("\n📚 Interactive Docs:")
print(f"   {API_URL}/docs")
print("\n🔍 Health Check:")
print(f"   {API_URL}/health")
print("\n" + "="*70)
print("\n✅ Services Running:")
print("   • PostgreSQL Database")
print("   • Redis Cache & Queue")
print("   • FastAPI Backend")
print(f"   • Celery Worker (GPU: {gpu_name})")
print("   • SimpleTuner Ready")
print("\n⚠️  Keep this notebook running during training!")
print("\n📖 Use the helper cells below to:")
print("   1. Upload training images & start training")
print("   2. Monitor progress (automatic)")
print("="*70)

print(f"\n✅ API_URL stored: {API_URL}")

# Store process handles for other cells.
BACKEND_PROCESS = backend_process
CELERY_PROCESS = celery_process

---

# 📖 Usage - Upload & Train

Run the cells below after the setup cell has finished.

In [None]:
#@title 📤 Upload Images, Start Training & Monitor Progress
#@markdown Upload 15-25 images of your subject (JPG/PNG). Training will start and be monitored automatically.

from google.colab import files
import requests
import io
import time
from IPython.display import clear_output

# Fail with a clear message (not a NameError) if the setup cell has not run.
if 'API_URL' not in globals():
    raise RuntimeError("API_URL not found. Run the setup cell first.")
# Normalise defensively in case an NgrokTunnel object was stored instead of a URL.
if not isinstance(API_URL, str):
    API_URL = getattr(API_URL, 'public_url', None) or str(API_URL).replace('NgrokTunnel: "', '').split('"')[0]
print(f"Using API: {API_URL}")

# Get parameters
dataset_name = "My LoRA Dataset" #@param {type:"string"}
trigger_word = "myface" #@param {type:"string"}
learning_rate = 0.0001 #@param {type:"number"}
training_steps = 2000 #@param {type:"slider", min:500, max:5000, step:100}

print("📤 Upload Your Training Images")
print("="*60)
print("Click 'Choose Files' and select 15-25 images...\n")

# Upload files
uploaded = files.upload()
print(f"\n✅ Received {len(uploaded)} files")

# Lowercase file extension -> MIME type sent with each uploaded file.
IMAGE_CONTENT_TYPES = {
    'jpg': 'image/jpeg',
    'jpeg': 'image/jpeg',
    'png': 'image/png',
    'webp': 'image/webp',
}

file_list = []
skipped_files = []
for filename, data in uploaded.items():
    _, dot, ext = filename.lower().rpartition('.')
    content_type = IMAGE_CONTENT_TYPES.get(ext) if dot else None
    if content_type is None:
        skipped_files.append(filename)
        continue
    file_list.append(('files', (filename, io.BytesIO(data), content_type)))

# Tell the user which files were ignored instead of dropping them silently.
if skipped_files:
    print(f"⚠️  Skipped {len(skipped_files)} unsupported file(s): {', '.join(skipped_files)}")
# Stop before creating a dataset that would have no images.
if not file_list:
    raise ValueError("No JPG, PNG or WEBP images were uploaded; re-run this cell and select image files.")

def _require_status(response, expected_status, doing, do):
    """Print the server reply and raise RuntimeError unless it has *expected_status*.

    *doing* and *do* are the gerund and infinitive forms of the action. They
    are used in the printed error ("Error creating dataset") and in the
    exception message ("Failed to create dataset").
    """
    if response.status_code != expected_status:
        print(f"❌ Error {doing}:")
        print(f"Status: {response.status_code}")
        print(f"Response: {response.text}")
        raise RuntimeError(f"Failed to {do}: {response.text}")


# Every request carries a timeout so a dead tunnel cannot hang the cell forever.

# Create dataset
print("\n📦 Creating dataset...")
response = requests.post(
    f"{API_URL}/api/datasets/",
    json={"name": dataset_name, "trigger_word": trigger_word},
    timeout=30,
)
_require_status(response, 201, "creating dataset", "create dataset")

dataset = response.json()
dataset_id = dataset['id']
print(f"✅ Dataset: {dataset_id}")

# Upload images (generous timeout: many high-resolution images go through ngrok)
print("\n📤 Uploading to server...")
response = requests.post(
    f"{API_URL}/api/datasets/{dataset_id}/upload",
    files=file_list,
    timeout=600,
)
_require_status(response, 200, "uploading files", "upload files")

result = response.json()
print(f"✅ Uploaded {result['uploaded_count']} images")

# Start training
print("\n🚀 Starting training...")
response = requests.post(
    f"{API_URL}/api/training/flux",
    json={
        "name": f"{dataset_name} - v1",
        "model_type": "flux_image",
        "dataset_id": dataset_id,
        "learning_rate": learning_rate,
        "steps": training_steps,
        "network_dim": 32,
        "network_alpha": 16,
        "resolution": 1024,
        "trigger_word": trigger_word
    },
    timeout=60,
)
_require_status(response, 200, "starting training", "start training")

training = response.json()
SESSION_ID = training['session_id']

print("\n" + "="*60)
print("✅ Training Started!")
print("="*60)
print(f"Session ID: {SESSION_ID}")
print(f"Status: {training['status']}")
print("\n⏱️  Estimated Time: 45-60 minutes (T4) or 20-30 minutes (A100)")
print("="*60)

# Wait a moment for training to initialize
time.sleep(5)

# ============================================================
# NOW AUTOMATICALLY START MONITORING
# ============================================================
print("\n\n📊 Monitoring Training Progress...\n")

STATUS_EMOJI = {
    'pending': '⏳', 'training': '🔄',
    'completed': '✅', 'failed': '❌', 'cancelled': '🛑'
}
TERMINAL_STATUSES = ('completed', 'failed', 'cancelled')
POLL_INTERVAL_SECONDS = 10
# Stop after this many consecutive failed status requests (~5 minutes) rather
# than retrying forever against a tunnel or session that is gone.
MAX_CONSECUTIVE_ERRORS = 30

start_time = time.time()
# (step, timestamp) of the first progress report seen. ETA is based on the
# steps completed since then, so time spent before monitoring started or
# before the first step (e.g. model loading) does not distort the rate.
rate_origin = None
consecutive_errors = 0

# The outer try catches Ctrl-C / Stop anywhere in the loop, including during
# the retry sleep after an error.
try:
    while True:
        try:
            response = requests.get(f"{API_URL}/api/training/{SESSION_ID}", timeout=30)
            response.raise_for_status()  # surface 4xx/5xx instead of a KeyError later
            status = response.json()
        except (requests.RequestException, ValueError) as e:
            consecutive_errors += 1
            print(f"\n❌ Error: {e}")
            if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
                print(f"\n⚠️  Giving up after {consecutive_errors} failed status checks. Training may still be running.")
                print(f"💡 Check status at: {API_URL}/api/training/{SESSION_ID}")
                break
            time.sleep(POLL_INTERVAL_SECONDS)
            continue
        consecutive_errors = 0

        clear_output(wait=True)

        print("="*70)
        print(f"🎨 {status.get('name', SESSION_ID)}")
        print("="*70)

        current_status = str(status.get('status') or 'unknown')
        print(f"\nStatus: {STATUS_EMOJI.get(current_status, '❓')} {current_status.upper()}")

        current = status.get('current_step')
        total = status.get('total_steps')
        if current and total:
            progress = (current / total) * 100

            print(f"\nProgress: {progress:.1f}% ({current:,}/{total:,} steps)")

            # Progress bar (clamped so a step count past total can't overflow it)
            bar_length = 50
            filled = min(bar_length, int(bar_length * progress / 100))
            bar = '█' * filled + '░' * (bar_length - filled)
            print(f"[{bar}] {progress:.1f}%")

            loss = status.get('current_loss')
            if isinstance(loss, (int, float)) and loss:
                print(f"\nLoss: {loss:.6f}")

            # Time estimate. Re-anchor on first sight, or if the step counter
            # went backwards (e.g. the job restarted).
            now = time.time()
            if rate_origin is None or current < rate_origin[0]:
                rate_origin = (current, now)
            origin_step, origin_time = rate_origin
            if current > origin_step:
                seconds_per_step = (now - origin_time) / (current - origin_step)
                eta_seconds = max(0.0, seconds_per_step * (total - current))
                print(f"\nElapsed: {(now - start_time)/60:.1f}m | ETA: {eta_seconds/60:.1f}m")

        print("\n" + "="*70)

        if current_status in TERMINAL_STATUSES:
            if current_status == 'completed':
                print("\n🎉 Training Completed!")
                print("✅ Model uploaded to S3: s3://masuka-v2/models/")
                print(f"\n📦 Session ID: {SESSION_ID}")
            elif current_status == 'failed':
                print(f"\n❌ Training Failed: {status.get('error_message', 'Unknown')}")
            else:
                print("\n🛑 Training Cancelled")
            break

        print("\n⏳ Updating in 10 seconds... (Press Stop to exit monitoring)")
        time.sleep(POLL_INTERVAL_SECONDS)

except KeyboardInterrupt:
    print("\n\n⚠️  Monitoring stopped. Training continues in background.")
    print(f"📋 Session ID: {SESSION_ID}")
    print(f"💡 Check status at: {API_URL}/api/training/{SESSION_ID}")

In [None]:
#@title 🔧 Debug Helper: Check Backend Status & Restart
#@markdown Run this if you get 404 errors or other API issues

import requests
import subprocess
import time

# Probe the main API endpoints through the public URL and report each result.
if 'API_URL' not in globals():
    print("❌ API_URL not found. Run the setup cell first.")
else:
    # Normalise API_URL in case the setup cell stored the NgrokTunnel object.
    if not isinstance(API_URL, str):
        API_URL = str(API_URL).replace('NgrokTunnel: "', '').split('"')[0]

    print(f"API URL: {API_URL}\n")

    def _datasets_ok(resp):
        print("✅ Datasets endpoint works!")
        print(f"   Found {len(resp.json())} datasets")

    def _datasets_failed(resp):
        print(f"❌ Datasets failed: {resp.status_code}")
        print(f"   Response: {resp.text}")

    # Each entry: (heading, path, handler for HTTP 200, handler for any other
    # status, message template used when the request or a handler raises).
    endpoint_checks = [
        (
            "Testing /health endpoint...", "/health",
            lambda resp: print(f"✅ Health check passed: {resp.json()}"),
            lambda resp: print(f"❌ Health check failed: {resp.status_code}"),
            "❌ Cannot reach backend: {}",
        ),
        (
            "\nTesting / endpoint...", "/",
            lambda resp: print(f"✅ Root endpoint works: {resp.json()}"),
            lambda resp: print(f"❌ Root failed: {resp.status_code}"),
            "❌ Cannot reach root: {}",
        ),
        (
            "\nTesting /api/datasets endpoint...", "/api/datasets/",
            _datasets_ok,
            _datasets_failed,
            "❌ Cannot reach datasets: {}",
        ),
    ]

    for heading, path, on_success, on_failure, unreachable_template in endpoint_checks:
        print(heading)
        try:
            response = requests.get(f"{API_URL}{path}", timeout=5)
            handler = on_success if response.status_code == 200 else on_failure
            handler(response)
        except Exception as e:
            print(unreachable_template.format(e))

    divider = "=" * 70
    print("\n" + divider)
    print("💡 If endpoints are failing, the backend may need restart.")
    print("   Run the setup cell again to restart all services.")
    print(divider)

---

## 🎉 Complete!

**Your trained model is in:**
- S3 Bucket: `masuka-v2`
- Path: `models/{user_id}/{session_id}/flux_lora.safetensors`

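**Download it from the AWS S3 Console, or with the AWS CLI** (replace `<user_id>` and `<session_id>` with your values):
```bash
aws s3 cp s3://masuka-v2/models/<user_id>/<session_id>/flux_lora.safetensors ./
```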

**Next:** Use your LoRA with ComfyUI, Stable Diffusion WebUI Forge, or any other Flux-compatible tool!

---

Created by **MASUKA V2** | Phase 2 Complete ✅