# 🎨 MASUKA V2 - Complete Training & Generation Platform

**Train Flux LoRAs and Generate Images with GPU in Google Colab**

---

## ⚡ Quick Start Guide

1. **Enable GPU**: Runtime → Change runtime type → **T4 GPU** (or A100 for faster training)
2. **Run Cell 1** ▶️ - Complete setup (~10 minutes)
3. **Run Cell 2** ▶️ - Upload images & train (~20-60 minutes, depending on GPU)
4. **Run Cell 3** ▶️ - Generate images with your trained LoRA!

---

## 📋 What You'll Get

- ✅ **Training**: Train custom Flux LoRAs with 15-25 images
- ✅ **Generation**: Generate images using trained LoRAs
- ✅ **Storage**: Models saved to S3 automatically
- ✅ **API**: Full REST API with documentation

---

**⚠️ Important Notes:**
- Keep this notebook running during training
- Training: about 45-60 min (T4) or 20-30 min (A100)
- Generation: ~1 min per image; the first run adds 10-15 min for the Flux model download
- All models are saved to the S3 bucket `masuka-v2`

---

In [None]:
#@title 🚀 CELL 1: Complete Setup (Run Once)
#@markdown This will take ~10 minutes. Watch for your API URL at the end!

import os
import sys
import time
import subprocess
from IPython.display import clear_output, HTML

def print_step(emoji, title, status=""):
    print(f"\n{emoji} {title}")
    print("="*60)
    if status:
        print(status)

def run_command(cmd, description, silent=True):
    """Run a command and handle output"""
    print(f"  {description}...", end="", flush=True)
    try:
        if silent:
            result = subprocess.run(cmd, shell=True, capture_output=True, text=True, check=True)
        else:
            result = subprocess.run(cmd, shell=True, check=True)
        print(" ✅")
        return True
    except subprocess.CalledProcessError as e:
        print(f" ❌\n{e}")
        return False

print("")
print("="*70)
print("     🎨 MASUKA V2 - Complete Setup Starting")
print("="*70)

# ============================================================
# STEP 1: Check GPU
# ============================================================
print_step("1️⃣", "Checking GPU")
run_command("nvidia-smi -L", "Detecting GPU", silent=True)

import torch
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    vram = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    print(f"  GPU: {gpu_name}")
    print(f"  VRAM: {vram:.1f} GB")
    print("  ✅ GPU Ready!")
else:
    print("  ❌ No GPU detected!")
    print("  Go to: Runtime → Change runtime type → GPU")
    sys.exit(1)

# ============================================================
# STEP 2: Clone Repository
# ============================================================
print_step("2️⃣", "Cloning MASUKA V2 Repository")
if not os.path.exists('/content/masuka-v2'):
    run_command(
        "git clone https://github.com/SamuelD27/masuka.git /content/masuka-v2",
        "Cloning from GitHub",
        silent=False
    )
else:
    print("  ✅ Repository exists")
    os.chdir('/content/masuka-v2')
    run_command("git pull origin main", "Updating repository", silent=False)

os.chdir('/content/masuka-v2')

# ============================================================
# STEP 3: Install Dependencies
# ============================================================
print_step("3️⃣", "Installing Python Dependencies (Phase 3 includes ML libs)")
print("  This may take 3-5 minutes...")

# Install from requirements.txt, then pyngrok for the tunnel; stop on the first failure
for pip_args in (["-r", "backend/requirements.txt"], ["pyngrok"]):
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", *pip_args],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
        text=True
    )
    if result.returncode != 0:
        raise RuntimeError(f"pip install {' '.join(pip_args)} failed:\n{result.stderr[-2000:]}")

print("  ✅ All dependencies installed")

# Verify key packages
import fastapi, celery, pydantic, diffusers
print(f"  FastAPI: {fastapi.__version__}")
print(f"  Celery: {celery.__version__}")
print(f"  Pydantic: {pydantic.__version__}")
print(f"  Diffusers: {diffusers.__version__}")

# ============================================================
# STEP 4: Install SimpleTuner
# ============================================================
print_step("4️⃣", "Installing SimpleTuner")
if not os.path.exists('/content/SimpleTuner'):
    run_command(
        "git clone -q https://github.com/bghira/SimpleTuner /content/SimpleTuner",
        "Cloning SimpleTuner",
        silent=False
    )
    
    print("  Installing SimpleTuner package...", end="", flush=True)
    os.chdir('/content/SimpleTuner')
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q", "-e", "."],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL
    )
    print(" ✅")
    
    os.chdir('/content/masuka-v2')
    print("  ✅ SimpleTuner ready at /content/SimpleTuner")
else:
    print("  ✅ SimpleTuner already installed")

# ============================================================
# STEP 5: Setup PostgreSQL
# ============================================================
print_step("5️⃣", "Setting up PostgreSQL")
run_command("apt-get update -qq && apt-get install -y -qq postgresql postgresql-contrib", 
            "Installing PostgreSQL")
run_command("service postgresql start", "Starting PostgreSQL")
time.sleep(2)

commands = [
    "DROP DATABASE IF EXISTS masuka;",
    "DROP USER IF EXISTS masuka;",
    "CREATE USER masuka WITH PASSWORD 'password123';",
    "CREATE DATABASE masuka OWNER masuka;",
    "GRANT ALL PRIVILEGES ON DATABASE masuka TO masuka;"
]
for cmd in commands:
    subprocess.run(["sudo", "-u", "postgres", "psql", "-c", cmd], 
                   capture_output=True)
print("  ✅ Database 'masuka' created")

# ============================================================
# STEP 6: Setup Redis
# ============================================================
print_step("6️⃣", "Setting up Redis")
run_command("apt-get install -y -qq redis-server", "Installing Redis")
run_command("redis-server --daemonize yes", "Starting Redis")
time.sleep(1)
run_command("redis-cli ping", "Testing Redis")

# ============================================================
# STEP 7: Configure Environment
# ============================================================
print_step("7️⃣", "Configuring Environment")

env_content = """# MASUKA V2 Colab Environment
APP_NAME=MASUKA V2
DEBUG=true

# Database
DATABASE_URL=postgresql://masuka:password123@localhost:5432/masuka

# Redis
REDIS_URL=redis://localhost:6379/0
CELERY_BROKER_URL=redis://localhost:6379/0
CELERY_RESULT_BACKEND=redis://localhost:6379/0

# JWT
SECRET_KEY=colab-secret-key-change-in-production
ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=1440

# Storage (AWS S3)
S3_BUCKET=masuka-v2
AWS_ACCESS_KEY_ID=your_aws_access_key_id
AWS_SECRET_ACCESS_KEY=your_aws_secret_access_key
S3_REGION=ap-southeast-1

# Hugging Face (for Flux model download)
HF_TOKEN=hf_your_token_here

# Paths
SIMPLETUNER_PATH=/content/SimpleTuner
MODELS_PATH=/content/models
TEMP_PATH=/tmp/masuka

# CORS
CORS_ORIGINS=*
"""

os.makedirs('/content/masuka-v2/backend', exist_ok=True)
with open('/content/masuka-v2/backend/.env', 'w') as f:
    f.write(env_content)
print("  ✅ Environment configured")

# ============================================================
# STEP 8: Initialize Database & Directories
# ============================================================
print_step("8️⃣", "Initializing Database & Directories")
os.makedirs('/tmp/masuka/uploads', exist_ok=True)
os.makedirs('/tmp/masuka/training', exist_ok=True)
os.makedirs('/tmp/masuka/generated', exist_ok=True)
os.makedirs('/tmp/masuka/model_cache', exist_ok=True)
os.makedirs('/content/models', exist_ok=True)
print("  ✅ Directories created")

# ============================================================
# STEP 9: Start FastAPI Backend
# ============================================================
print_step("9️⃣", "Starting FastAPI Backend")
os.chdir('/content/masuka-v2/backend')

subprocess.run(['pkill', '-f', 'uvicorn'], capture_output=True)
time.sleep(2)

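# Log to a file: an unread PIPE eventually fills up and blocks the server
backend_log = open('/tmp/masuka/backend.log', 'w')
backend_process = subprocess.Popen(
    ['uvicorn', 'app.main:app', '--host', '0.0.0.0', '--port', '8000', '--log-level', 'info'],
    stdout=backend_log,
    stderr=subprocess.STDOUT,
    text=True
)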

print("  ⏳ Waiting for backend to start...")
time.sleep(15)

import requests
try:
    response = requests.get('http://localhost:8000/health', timeout=5)
    health = response.json()
    print(f"  ✅ Backend running: {health['app']} v{health['version']}")
    
    # Verify new Phase 3 endpoints
    response = requests.get('http://localhost:8000/api/models/', timeout=5)
    if response.status_code == 200:
        print(f"  ✅ Models API loaded (Phase 3)")
        
except Exception as e:
    print(f"  ⚠️  Backend may not be ready: {e}")

# ============================================================
# STEP 10: Start Celery Worker
# ============================================================
print_step("🔟", "Starting Celery Worker")

subprocess.run(['pkill', '-f', 'celery'], capture_output=True)
time.sleep(2)

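# Log to a file: an unread PIPE eventually fills up and blocks the worker
celery_log = open('/tmp/masuka/celery.log', 'w')
celery_process = subprocess.Popen(
    ['celery', '-A', 'app.tasks.celery_app', 'worker', 
     '--loglevel=info', '--concurrency=1', '-Q', 'training,generation'],
    stdout=celery_log,
    stderr=subprocess.STDOUT,
    text=True
)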

print("  ⏳ Waiting for Celery to start...")
time.sleep(5)
print(f"  ✅ Celery worker running with training+generation queues (GPU: {gpu_name})")

# ============================================================
# STEP 11: Expose via ngrok
# ============================================================
print_step("1️⃣1️⃣", "Creating Public URL")

from pyngrok import ngrok

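# Read the ngrok token from the environment or prompt for it; never hard-code it
from getpass import getpass
ngrok.set_auth_token(os.environ.get("NGROK_AUTH_TOKEN") or getpass("ngrok auth token: "))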
ngrok.kill()
time.sleep(2)

# ngrok.connect() returns an NgrokTunnel object; keep only its URL string
public_url = ngrok.connect(8000).public_url
print("  ✅ Tunnel created")

# ============================================================
# FINAL OUTPUT
# ============================================================
print("\n" + "="*70)
print("🎉 MASUKA V2 Setup Complete!")
print("="*70)
print(f"\n📡 Your Public API URL:")
print(f"   {public_url}")
print(f"\n📚 Interactive Docs:")
print(f"   {public_url}/docs")
print(f"\n🔍 Health Check:")
print(f"   {public_url}/health")
print("\n" + "="*70)
print("\n✅ Services Running:")
print("   • PostgreSQL Database")
print("   • Redis Cache & Queue")
print("   • FastAPI Backend")
print(f"   • Celery Worker (GPU: {gpu_name})")
print("   • SimpleTuner Ready")
print("   • Flux Generator Ready (Phase 3)")
print("\n⚠️  Keep this notebook running during training & generation!")
print("\n📖 Next Steps:")
print("   → Run Cell 2 to train a LoRA")
print("   → Run Cell 3 to generate images")
print("="*70)

# Store for other cells
# Works whether public_url is an NgrokTunnel object or already a plain URL string
API_URL = getattr(public_url, 'public_url', str(public_url))
print(f"\n✅ API_URL stored: {API_URL}")

BACKEND_PROCESS = backend_process
CELERY_PROCESS = celery_process

In [None]:
#@title 📤 CELL 2: Train LoRA - Upload Images & Start Training
#@markdown Upload 15-25 images of your subject (JPG, PNG, or WebP). Training starts automatically and monitors progress.

from google.colab import files
import requests
import io
import time
from IPython.display import clear_output

# Ensure API_URL is set
if not isinstance(API_URL, str):
    API_URL = str(API_URL).replace('NgrokTunnel: "', '').split('"')[0]
print(f"Using API: {API_URL}")

# Get parameters
dataset_name = "My LoRA Dataset" #@param {type:"string"}
trigger_word = "TOK" #@param {type:"string"}
learning_rate = 0.0001 #@param {type:"number"}
training_steps = 2000 #@param {type:"slider", min:500, max:5000, step:100}

print("📤 Upload Your Training Images")
print("="*60)
print("Click 'Choose Files' and select 15-25 images...\n")

# Upload files
uploaded = files.upload()
print(f"\n✅ Received {len(uploaded)} files")

# Prepare files
file_list = []
for filename, data in uploaded.items():
    if filename.lower().endswith(('.jpg', '.jpeg')):
        content_type = 'image/jpeg'
    elif filename.lower().endswith('.png'):
        content_type = 'image/png'
    elif filename.lower().endswith('.webp'):
        content_type = 'image/webp'
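    else:
        print(f"  ⚠️  Skipping unsupported file: {filename}")
        continue
    file_list.append(('files', (filename, io.BytesIO(data), content_type)))

if not file_list:
    raise ValueError("No supported images found. Upload JPG, PNG, or WebP files.")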

# Create dataset
print("\n📦 Creating dataset...")
response = requests.post(
    f"{API_URL}/api/datasets/",
    json={"name": dataset_name, "trigger_word": trigger_word}
)

if response.status_code != 201:
    print(f"❌ Error: {response.status_code} - {response.text}")
    raise Exception(f"Failed to create dataset")

dataset = response.json()
dataset_id = dataset['id']
print(f"✅ Dataset: {dataset_id}")

# Upload images
print("\n📤 Uploading to server...")
response = requests.post(
    f"{API_URL}/api/datasets/{dataset_id}/upload",
    files=file_list
)

if response.status_code != 200:
    print(f"❌ Error: {response.status_code} - {response.text}")
    raise Exception(f"Failed to upload files")

result = response.json()
print(f"✅ Uploaded {result['uploaded_count']} images")

# Start training
print("\n🚀 Starting training...")
response = requests.post(
    f"{API_URL}/api/training/flux",
    json={
        "name": f"{dataset_name} - v1",
        "model_type": "flux_image",
        "dataset_id": dataset_id,
        "learning_rate": learning_rate,
        "steps": training_steps,
        "network_dim": 32,
        "network_alpha": 16,
        "resolution": 1024,
        "trigger_word": trigger_word
    }
)

if response.status_code not in [200, 201]:
    print(f"❌ Error: {response.status_code} - {response.text}")
    raise Exception(f"Failed to start training")

training = response.json()
SESSION_ID = training['session_id']

print("\n" + "="*60)
print("✅ Training Started!")
print("="*60)
print(f"Session ID: {SESSION_ID}")
print(f"Status: {training['status']}")
print(f"Task ID: {training['task_id']}")
print(f"\n⏱️  Estimated Time: 45-60 minutes (T4) or 20-30 minutes (A100)")
print("="*60)

time.sleep(5)

# Monitor progress
print("\n\n📊 Monitoring Training Progress...\n")

start_time = time.time()
last_step = 0

while True:
    try:
        response = requests.get(f"{API_URL}/api/training/{SESSION_ID}", timeout=30)
        status = response.json()
        
        clear_output(wait=True)
        
        print("="*70)
        print(f"🎨 {status['name']}")
        print("="*70)
        
        current_status = status['status']
        status_emoji = {
            'pending': '⏳', 'training': '🔄',
            'completed': '✅', 'failed': '❌', 'cancelled': '🛑'
        }
        print(f"\nStatus: {status_emoji.get(current_status, '❓')} {current_status.upper()}")
        
        if status.get('current_step') and status.get('total_steps'):
            current = status['current_step']
            total = status['total_steps']
            progress = (current / total) * 100
            
            print(f"\nProgress: {progress:.1f}% ({current:,}/{total:,} steps)")
            
            bar_length = 50
            filled = int(bar_length * progress / 100)
            bar = '█' * filled + '░' * (bar_length - filled)
            print(f"[{bar}] {progress:.1f}%")
            
            if status.get('current_loss'):
                print(f"\nLoss: {status['current_loss']:.6f}")
            
            if current > last_step and last_step > 0:
                elapsed = time.time() - start_time
                time_per_step = elapsed / current
                eta_seconds = time_per_step * (total - current)
                print(f"\nElapsed: {elapsed/60:.1f}m | ETA: {eta_seconds/60:.1f}m")
            
            last_step = current
        
        print("\n" + "="*70)
        
        if current_status in ['completed', 'failed', 'cancelled']:
            if current_status == 'completed':
                print("\n🎉 Training Completed!")
                print("✅ Model uploaded to S3")
                print(f"\n📦 Session ID: {SESSION_ID}")
                print("\n💡 Next: Run Cell 3 to generate images with this LoRA!")
            elif current_status == 'failed':
                print(f"\n❌ Training Failed: {status.get('error_message', 'Unknown')}")
            else:
                print("\n🛑 Training Cancelled")
            break
        
        print("\n⏳ Updating in 10 seconds...")
        time.sleep(10)
        
    except KeyboardInterrupt:
        print("\n\n⚠️  Monitoring stopped. Training continues in background.")
        print(f"📋 Session ID: {SESSION_ID}")
        break
    except Exception as e:
        print(f"\n❌ Error: {e}")
        time.sleep(10)

In [None]:
#@title 🎨 CELL 3: Generate Images with Trained LoRA
#@markdown Generate images using your trained model. First time takes 10-15 min (Flux model download), then ~1 min per image.

import requests
import time
from IPython.display import clear_output, display, Image as IPImage
import io

# Ensure API_URL is set
if not isinstance(API_URL, str):
    API_URL = str(API_URL).replace('NgrokTunnel: "', '').split('"')[0]
print(f"Using API: {API_URL}\n")

# Generation parameters
prompt = "photo of TOK person standing in a field, professional photography, golden hour lighting, 8k, highly detailed" #@param {type:"string"}
negative_prompt = "blurry, low quality, distorted, ugly, deformed" #@param {type:"string"}
use_lora = True #@param {type:"boolean"}
lora_weight = 0.8 #@param {type:"slider", min:0.0, max:1.0, step:0.1}
num_images = 2 #@param {type:"slider", min:1, max:4, step:1}
num_inference_steps = 30 #@param {type:"slider", min:10, max:100, step:5}
guidance_scale = 3.5 #@param {type:"slider", min:1.0, max:20.0, step:0.5}
width = 1024 #@param {type:"slider", min:512, max:2048, step:64}
height = 1024 #@param {type:"slider", min:512, max:2048, step:64}
seed = None #@param {type:"integer"}

# ============================================================
# Step 1: Get available models
# ============================================================
print("📋 Fetching available models...\n")
response = requests.get(f"{API_URL}/api/models/")

if response.status_code != 200:
    print(f"❌ Error fetching models: {response.status_code}")
    print(f"Response: {response.text}")
    raise Exception("Failed to fetch models")

models = response.json()

if not models:
    print("⚠️  No trained models found!")
    print("\n💡 Please run Cell 2 first to train a LoRA.")
    raise Exception("No models available")

print(f"✅ Found {len(models)} model(s):\n")
for i, model in enumerate(models):
    print(f"{i+1}. {model['name']} (v{model['version']})")
    print(f"   ID: {model['id']}")
    print(f"   Type: {model['model_type']}")
    print(f"   Trigger: {model.get('trigger_word', 'N/A')}")
    print(f"   Size: {model.get('file_size_mb', 'N/A')} MB")
    print()

# Use the most recent model
selected_model = models[0]
model_id = selected_model['id'] if use_lora else None

if use_lora:
    print(f"🎯 Using model: {selected_model['name']}\n")
else:
    print(f"🎯 Using base Flux (no LoRA)\n")

# ============================================================
# Step 2: Start generation
# ============================================================
print("="*70)
print("🎨 Starting Image Generation")
print("="*70)
print(f"\n📝 Prompt: {prompt}")
print(f"🚫 Negative: {negative_prompt}")
print(f"\n⚙️  Settings:")
print(f"   Images: {num_images}")
print(f"   Steps: {num_inference_steps}")
print(f"   Guidance: {guidance_scale}")
print(f"   Size: {width}x{height}")
if use_lora:
    print(f"   LoRA Weight: {lora_weight}")
print("\n🚀 Submitting generation job...")

generation_request = {
    "prompt": prompt,
    "negative_prompt": negative_prompt,
    "num_images": num_images,
    "num_inference_steps": num_inference_steps,
    "guidance_scale": guidance_scale,
    "width": width,
    "height": height
}

if use_lora and model_id:
    generation_request["model_id"] = model_id
    generation_request["lora_weight"] = lora_weight

if seed is not None:
    generation_request["seed"] = seed

response = requests.post(
    f"{API_URL}/api/generate/image",
    json=generation_request
)

if response.status_code != 201:
    print(f"\n❌ Error starting generation: {response.status_code}")
    print(f"Response: {response.text}")
    raise Exception("Failed to start generation")

generation = response.json()
job_id = generation['job_id']

print(f"\n✅ Generation job created!")
print(f"📋 Job ID: {job_id}")
print(f"🔄 Task ID: {generation['task_id']}")

print("\n" + "="*70)
print("⏳ Generating images...")
print("="*70)
print("\n💡 First generation takes 10-15 min (Flux model download)")
print("   Subsequent generations: ~1 min per image")
print("\n🔄 Polling for completion...\n")

# ============================================================
# Step 3: Monitor generation progress
# ============================================================
start_time = time.time()
check_count = 0

while True:
    try:
        response = requests.get(f"{API_URL}/api/generate/{job_id}", timeout=30)
        
        if response.status_code != 200:
            print(f"❌ Error checking status: {response.status_code}")
            time.sleep(5)
            continue
        
        job_status = response.json()
        status = job_status['status']
        
        elapsed = time.time() - start_time
        check_count += 1
        
        print(f"[{elapsed:.0f}s] Check #{check_count}: {status}...", end="\r")
        
        if status == 'completed':
            clear_output(wait=True)
            print("\n" + "="*70)
            print("🎉 Generation Complete!")
            print("="*70)
            print(f"\n⏱️  Total time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
            
            output_paths = job_status.get('output_paths', [])
            print(f"\n✅ Generated {len(output_paths)} image(s):\n")
            
            # Display images
            for i, url in enumerate(output_paths):
                print(f"Image {i+1}:")
                print(f"  URL: {url[:80]}...")
                
                # Download and display
                try:
                    img_response = requests.get(url)
                    if img_response.status_code == 200:
                        display(IPImage(data=img_response.content))
                        print()
                except Exception as e:
                    print(f"  ⚠️  Could not display image: {e}\n")
            
            print("="*70)
            print("\n💡 Images saved to S3 bucket: masuka-v2/generated/")
            print(f"\n📋 Job ID: {job_id}")
            print("\n✨ Generate more images by running this cell again!")
            print("="*70)
            break
            
        elif status == 'failed':
            print(f"\n\n❌ Generation Failed!")
            error = job_status.get('error_message', 'Unknown error')
            print(f"Error: {error}")
            break
            
        elif status == 'processing':
            # Still processing, continue polling
            time.sleep(10)
        else:
            # Unknown status
            time.sleep(5)
            
    except KeyboardInterrupt:
        print("\n\n⚠️  Monitoring stopped. Generation continues in background.")
        print(f"📋 Job ID: {job_id}")
        print(f"💡 Check status: {API_URL}/api/generate/{job_id}")
        break
    except Exception as e:
        print(f"\n❌ Error during monitoring: {e}")
        time.sleep(10)

---

## 🎉 You're All Set!

### What You Can Do Now:

1. **Train More LoRAs** - Run Cell 2 with different images
2. **Generate More Images** - Run Cell 3 with different prompts
3. **Experiment** - Try different parameters (steps, guidance, LoRA weight)

### Your Models Are Saved:

- **Location:** S3 Bucket `masuka-v2`
- **Models:** `models/{session_id}/flux_lora.safetensors`
- **Generations:** `generated/{job_id}/image_*.png`
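
The key layout above can be turned into small helpers. This is a minimal sketch, not part of the MASUKA codebase: the function names are hypothetical, and `download_lora` assumes `boto3` is installed and that AWS credentials with read access to the bucket are configured in your environment.

```python
def lora_key(session_id: str) -> str:
    """S3 key of the trained LoRA for a training session."""
    return f"models/{session_id}/flux_lora.safetensors"


def generation_prefix(job_id: str) -> str:
    """S3 prefix under which a generation job's images are stored."""
    return f"generated/{job_id}/"


def download_lora(session_id: str, dest_path: str, bucket: str = "masuka-v2") -> str:
    """Download a trained LoRA to dest_path (assumes boto3 and AWS credentials)."""
    import boto3  # third-party; imported here so the key helpers work without it

    boto3.client("s3").download_file(bucket, lora_key(session_id), dest_path)
    return dest_path
```

For example, `download_lora(SESSION_ID, "/content/models/my_lora.safetensors")` would fetch the model trained in Cell 2, provided the credentials allow it.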

### API Endpoints:

- **Health:** `{API_URL}/health`
- **Docs:** `{API_URL}/docs`
- **Models:** `{API_URL}/api/models/`
- **Generate:** `{API_URL}/api/generate/image`
- **Training:** `{API_URL}/api/training/`
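
To call these endpoints from outside the notebook, a sketch like the following may help. The helper names are hypothetical, and it uses only the standard library (`urllib`) rather than the `requests` package the cells use, so it runs without extra installs. The response shape of `/health` is whatever the backend returns; nothing is assumed about it here.

```python
import json
import urllib.request


def build_endpoints(api_url: str) -> dict:
    """Map endpoint names to full URLs for a given base API URL."""
    base = api_url.rstrip("/")  # tolerate a trailing slash in the tunnel URL
    return {
        "health": f"{base}/health",
        "docs": f"{base}/docs",
        "models": f"{base}/api/models/",
        "generate": f"{base}/api/generate/image",
        "training": f"{base}/api/training/",
    }


def get_json(url: str, timeout: float = 10.0):
    """GET a URL and decode the JSON body (raises on HTTP or network errors)."""
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Example (hypothetical URL; substitute the API_URL printed by Cell 1):
# endpoints = build_endpoints("https://example.ngrok-free.app")
# print(get_json(endpoints["health"]))
```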

### Tips:

- **First generation is slow** - The Flux model (10GB) is downloaded on the first run, which takes 10-15 min
- **Use trigger word** - Include your trigger word in prompts (e.g., "TOK")
- **Adjust LoRA weight** - 0.6-0.9 works best for most cases
- **More steps = better quality** - But generation takes longer (try 30-50)
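
The tips above can be folded into a small request builder. This is a sketch under assumptions: the function name is hypothetical, the payload fields mirror the ones Cell 3 sends to `/api/generate/image`, and prepending a missing trigger word is an illustrative choice rather than something the backend requires.

```python
def build_generation_request(
    prompt: str,
    trigger_word: str = "TOK",
    model_id=None,
    lora_weight: float = 0.8,
    negative_prompt: str = "",
    num_images: int = 1,
    num_inference_steps: int = 30,
    guidance_scale: float = 3.5,
    width: int = 1024,
    height: int = 1024,
) -> dict:
    """Build a generation payload, applying the tips above."""
    if not 0.0 <= lora_weight <= 1.0:
        raise ValueError("lora_weight must be between 0.0 and 1.0")

    # When a LoRA is used, make sure the prompt contains its trigger word
    if model_id is not None and trigger_word.lower() not in prompt.lower():
        prompt = f"{trigger_word} {prompt}"

    request = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_images": num_images,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "width": width,
        "height": height,
    }
    # LoRA fields are only sent when a trained model is selected
    if model_id is not None:
        request["model_id"] = model_id
        request["lora_weight"] = lora_weight
    return request
```

The returned dictionary can be passed as the `json=` argument of the POST in Cell 3.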

---

**Created by MASUKA V2** | Phase 3 Complete ✅

**Questions?** Check the [documentation](https://github.com/SamuelD27/masuka)

---