<a href="https://colab.research.google.com/github/SattamAltwaim/SaSOKE/blob/main/notebooks/7_simple_web_ui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🤟 SOKE Simple Web UI

**Super Simple Setup - Just Upload HTML and Go!**

This notebook:
1. Loads the SOKE model
2. Starts an API server
3. Hosts your HTML file
4. Gives you a shareable URL

### Requirements
- **GPU Runtime**: `Runtime → Change runtime type → GPU`
- **HTML File**: Upload `soke_ui.html` when prompted

## Step 1: Setup Environment

In [None]:
# Clone repo and mount Drive
# Clone the SaSOKE repo only on the first run, then work from the repo root so
# the relative paths used in later cells (configs/, deps/, static/) resolve.
import os
if not os.path.exists('/content/SaSOKE'):
    !git clone https://github.com/SattamAltwaim/SaSOKE.git
%cd /content/SaSOKE

# Mount Google Drive; later cells read data, deps and checkpoints from it.
from google.colab import drive
drive.mount('/content/drive')

# Root folder on Drive holding the project's data, deps and checkpoints.
# NOTE(review): hard-coded to one account's Drive layout — adjust for your own.
drive_data = '/content/drive/MyDrive/GraduationProject/CodeFiles/SaSOKE'
print("✓ Ready!")

In [None]:
# Install dependencies
# Model/training stack (e.g. pytorch_lightning and omegaconf are imported in Step 2).
!pip install -q pytorch_lightning torchmetrics omegaconf shortuuid transformers diffusers einops wandb rich matplotlib
# Body-model package (smplx) plus data-handling and text-processing utilities.
!pip install -q smplx h5py scikit-image spacy ftfy more-itertools natsort tensorboard sentencepiece
# Web stack for Step 4: FastAPI app served by uvicorn, plus nest_asyncio.
!pip install -q fastapi uvicorn python-multipart nest_asyncio
print("✓ Dependencies installed!")

## Step 2: Load Model

In [None]:
# Setup paths
import sys
import yaml
from omegaconf import OmegaConf
from mGPT.config import parse_args

# Link the paths the repo code expects (relative to the repo root) to the
# copies stored on Drive.
deps_links = {
    'deps/smpl_models': f'{drive_data}/deps/smpl_models',
    'deps/mbart-h2s-csl-phoenix': f'{drive_data}/deps/mbart-h2s-csl-phoenix',
}

for expected_path, actual_path in deps_links.items():
    # A link left dangling by an earlier run (e.g. one made before Drive was
    # mounted) makes os.path.exists() return False, yet os.symlink() would
    # still fail with FileExistsError. Remove it so the link is recreated.
    if os.path.islink(expected_path) and not os.path.exists(expected_path):
        os.remove(expected_path)
    if os.path.lexists(expected_path):
        continue  # a real directory or a working link is already in place
    if not os.path.exists(actual_path):
        print(f"⚠️ Expected dependency not found on Drive: {actual_path}")
    os.makedirs(os.path.dirname(expected_path), exist_ok=True)
    os.symlink(actual_path, expected_path)

# Configure: start from the repo's SOKE config and point device, dataset and
# VAE settings at this runtime and the Drive copies.
with open('configs/soke.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

config['ACCELERATOR'] = 'gpu'
config['DEVICE'] = [0]
config['DATASET']['H2S']['ROOT'] = f'{drive_data}/data/How2Sign'
config['DATASET']['H2S']['MEAN_PATH'] = f'{drive_data}/smpl-x/mean.pt'
config['DATASET']['H2S']['STD_PATH'] = f'{drive_data}/smpl-x/std.pt'
config['TRAIN']['PRETRAINED_VAE'] = f'{drive_data}/checkpoints/vae/tokenizer.ckpt'

with open('configs/web_inference.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(config, f)

# Asset config: body-model paths go through the deps/ links created above.
with open('configs/assets.yaml', 'r', encoding='utf-8') as f:
    assets = yaml.safe_load(f)

assets['RENDER']['SMPL_MODEL_PATH'] = 'deps/smpl_models/smpl'
assets['RENDER']['MODEL_PATH'] = 'deps/smpl_models'
assets['METRIC']['TM2T']['t2m_path'] = f'{drive_data}/deps/deps/t2m/t2m/'

with open('configs/assets_web.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(assets, f)

# parse_args() is handed a synthetic command line through sys.argv
# (presumably it parses argv argparse-style — TODO confirm in mGPT.config).
sys.argv = ['', '--cfg', 'configs/web_inference.yaml', '--cfg_assets', 'configs/assets_web.yaml']
cfg = parse_args(phase="test")
cfg.FOLDER = cfg.TEST.FOLDER

print("✓ Configuration ready!")

In [None]:
# Load model
import torch
import pytorch_lightning as pl
from mGPT.models.build_model import build_model
from mGPT.data.build_data import build_data
from mGPT.utils.load_checkpoint import load_pretrained_vae, load_pretrained
from mGPT.utils.logger import create_logger

pl.seed_everything(cfg.SEED_VALUE)
cfg.DATASET.WORD_VERTILIZER_PATH = f'{drive_data}/deps/deps/t2m/glove/'

# The datamodule is needed by build_model and supplies the feature
# normalization statistics (mean/std) read at the end of this cell.
datamodule = build_data(cfg)
model = build_model(cfg, datamodule)

logger = create_logger(cfg, phase="test")
if cfg.TRAIN.PRETRAINED_VAE:
    load_pretrained_vae(cfg, model, logger)

ckpt_path = f'{drive_data}/experiments/mgpt/SOKE/checkpoints/last.ckpt'
if os.path.exists(ckpt_path):
    cfg.TEST.CHECKPOINTS = ckpt_path
    load_pretrained(cfg, model, logger, phase="test")
else:
    # Don't continue silently: without this checkpoint the server would run a
    # model lacking the SOKE checkpoint's weights and produce junk motion.
    print(f"⚠️ SOKE checkpoint not found at {ckpt_path}; "
          "the model will run WITHOUT the SOKE checkpoint weights.")

model = model.cuda()
model.eval()

# Normalization statistics the API uses to de-normalize generated features.
mean = datamodule.hparams.mean.cuda()
std = datamodule.hparams.std.cuda()

# Load SMPL-X
from mGPT.utils.human_models import smpl_x, get_coord

print("\n✅ Model loaded and ready!")

## Step 3: Upload Your HTML File

**Upload the `soke_ui.html` file (or any HTML file you want to host)**

In [None]:
from google.colab import files
import shutil

print("📤 Please upload your HTML file...")
# Maps each uploaded filename to its contents (bytes).
uploaded = files.upload()

if uploaded:
    html_filename = next(iter(uploaded))
    if len(uploaded) > 1:
        print(f"⚠️ {len(uploaded)} files uploaded; only {html_filename} will be served.")

    # Save to static folder. Write the uploaded bytes directly instead of
    # copying `html_filename` from disk: on a re-upload Colab may store the
    # new file under a de-duplicated name (e.g. "soke_ui (1).html"), so the
    # original name can still refer to the previous, stale upload.
    os.makedirs('static', exist_ok=True)
    with open('static/index.html', 'wb') as f:
        f.write(uploaded[html_filename])

    print(f"\n✅ HTML file uploaded: {html_filename}")
    print(f"   Saved as: static/index.html")
else:
    print("\n⚠️ No file uploaded. The server will start but won't have a UI.")

## Step 4: Start Server

**This will start the server and give you a URL to share!**

In [None]:
# Create FastAPI app with static file serving
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import time

app = FastAPI(title="SOKE API")

# Fully permissive CORS: any origin, method and header (credentials allowed),
# so a UI hosted on another origin can also call this API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Expose uploaded UI assets under /static, but only if the folder exists
# (Step 3 creates it when an HTML file is uploaded).
has_static_dir = os.path.exists('static')
if has_static_dir:
    app.mount("/static", StaticFiles(directory="static"), name="static")

class TextRequest(BaseModel):
    """JSON body accepted by POST /generate."""

    text: str  # input sentence to turn into signing motion
    sign_language: str = "how2sign"  # passed to the model as batch 'src'; ids listed by GET /languages
    fps: int = 20  # only echoed back in the response; not used for generation here
    include_mesh: bool = True  # also compute per-frame mesh vertices/faces when true

@app.get("/")
async def root():
    """Serve the uploaded UI page, or a placeholder when none was uploaded."""
    index_page = 'static/index.html'
    if not os.path.exists(index_page):
        return HTMLResponse("<h1>SOKE Server Running</h1><p>Upload an HTML file to see the UI.</p>")
    return FileResponse(index_page)

@app.get("/health")
async def health():
    """Liveness probe (polled by the keep-alive cell); reports CUDA availability.

    NOTE(review): "model_loaded" is a constant True, not an actual check.
    """
    report = {"status": "healthy", "model_loaded": True}
    report["gpu_available"] = torch.cuda.is_available()
    return report

@app.get("/languages")
async def languages():
    """List the selectable sign languages and the spoken input language each expects."""
    catalog = [
        ("how2sign", "American Sign Language (ASL)", "English"),
        ("csl", "Chinese Sign Language (CSL)", "Chinese"),
        ("phoenix", "German Sign Language (DGS)", "German"),
    ]
    entries = [
        {"id": lang_id, "name": display_name, "input_language": spoken}
        for lang_id, display_name, spoken in catalog
    ]
    return {"languages": entries}

def feats_to_smplx_api(features, mean_tensor, std_tensor):
    """De-normalize generated features and split them into SMPL-X parameter lists.

    Args:
        features: normalized feature tensor with frames along dim 0
            (presumably (T, 133) so the padded vector has 169 columns — TODO confirm).
        mean_tensor, std_tensor: normalization statistics broadcastable to
            ``features``; the features are mapped back with ``x * std + mean``.

    Returns:
        dict of plain nested Python lists (one row per frame) under
        'root_pose', 'body_pose', 'lhand_pose', 'rhand_pose', 'jaw_pose' and
        'expression', ready for JSON serialization.
    """
    features = features * std_tensor + mean_tensor
    num_frames = features.shape[0]
    # 36 zero columns are prepended, so root_pose and the first 33 body_pose
    # values are always zero (presumably joints the model does not predict).
    zero_pose = torch.zeros(num_frames, 36).to(features)
    features_full = torch.cat([zero_pose, features], dim=-1)
    # A single device->host copy (detached, so graph-connected inputs work too)
    # instead of one .cpu().numpy() round trip per slice.
    features_full = features_full.detach().cpu()
    layout = (
        ('root_pose', 0, 3),
        ('body_pose', 3, 66),
        ('lhand_pose', 66, 111),
        ('rhand_pose', 111, 156),
        ('jaw_pose', 156, 159),
        ('expression', 159, 169),
    )
    return {name: features_full[:, start:stop].tolist() for name, start, stop in layout}

def generate_mesh_api(smplx_params):
    """Compute per-frame SMPL-X mesh vertices with get_coord.

    Args:
        smplx_params: dict of per-frame lists under 'root_pose', 'body_pose',
            'lhand_pose', 'rhand_pose', 'jaw_pose' and 'expression' (the format
            returned by feats_to_smplx_api). The frame count is taken from
            'body_pose'.

    Returns:
        {"vertices": one nested vertex list per frame,
         "faces": smpl_x.face converted with .tolist()}
    """
    num_frames = len(smplx_params['body_pose'])
    # Fixed shape coefficients passed for every frame so all frames share one
    # body shape (presumably SMPL-X betas of a reference signer — TODO confirm).
    shape_param = torch.tensor([[-0.07284723, 0.1795129, -0.27608207, 0.135155, 0.10748172,
                                 0.16037364, -0.01616933, -0.03450319, 0.01369138, 0.01108842]]).float()
    # Convert each parameter list to a tensor once, instead of building six
    # new tensors per frame; row i is then sliced out as a (1, D) batch.
    param_keys = ('root_pose', 'body_pose', 'lhand_pose', 'rhand_pose', 'jaw_pose', 'expression')
    stacked = {key: torch.tensor(smplx_params[key]).float() for key in param_keys}

    all_vertices = []
    with torch.no_grad():
        for i in range(num_frames):
            frame = {key: values[i:i + 1] for key, values in stacked.items()}
            vertices, _ = get_coord(root_pose=frame['root_pose'], body_pose=frame['body_pose'],
                                    lhand_pose=frame['lhand_pose'], rhand_pose=frame['rhand_pose'],
                                    jaw_pose=frame['jaw_pose'], shape=shape_param,
                                    expr=frame['expression'])
            all_vertices.append(vertices[0].cpu().numpy().tolist())
    return {"vertices": all_vertices, "faces": smpl_x.face.tolist()}

import threading

# One generation at a time: the model and the SMPL-X layer are shared global
# state, and plain-`def` endpoints run concurrently in FastAPI's thread pool.
_generate_lock = threading.Lock()


@app.post("/generate")
def generate(request: TextRequest):
    """Generate sign-language motion (and optionally a mesh) for request.text.

    Declared with plain `def` on purpose: model inference and mesh generation
    block, and FastAPI runs sync endpoints in a worker thread pool, so the
    event loop (and e.g. /health) stays responsive meanwhile.

    Always returns a dict with "success"; failures add "error", and the
    traceback is printed to the notebook output.
    """
    start_time = time.perf_counter()
    try:
        # NOTE(review): 'length' of 0 presumably lets the model choose the
        # motion length — confirm against model.forward.
        batch = {'text': [request.text], 'length': [0], 'src': [request.sign_language]}
        with _generate_lock:
            with torch.no_grad():
                output = model.forward(batch, task="t2m")
            feats = output['feats'][0] if 'feats' in output else None
            if feats is None:
                return {"success": False, "error": "No features generated", "text": request.text,
                        "num_frames": 0, "fps": request.fps,
                        "generation_time": time.perf_counter() - start_time}
            smplx_params = feats_to_smplx_api(feats, mean, std)
            num_frames = len(smplx_params['body_pose'])
            mesh_data = generate_mesh_api(smplx_params) if request.include_mesh else None
        return {
            "success": True, "text": request.text, "num_frames": num_frames, "fps": request.fps,
            "smplx_params": smplx_params, "mesh_data": mesh_data,
            "generation_time": time.perf_counter() - start_time
        }
    except Exception as e:
        import traceback
        traceback.print_exc()
        return {"success": False, "error": str(e), "text": request.text,
                "num_frames": 0, "fps": request.fps,
                "generation_time": time.perf_counter() - start_time}

print("✓ API configured!")

In [None]:
# Start the server
import nest_asyncio
import uvicorn
from threading import Thread
import socket

# Allow nested asyncio event loops in this notebook process — presumably
# because the Colab/Jupyter kernel already runs its own loop (TODO confirm it
# is still needed now that uvicorn runs in a separate thread).
nest_asyncio.apply()

def find_free_port(start_port=8080):
    """Return the first port in [start_port, start_port + 100) that binds on 0.0.0.0.

    Each candidate is probed by binding a throwaway TCP socket, which is
    closed again right away. Returns None when every candidate is taken.
    """
    def can_bind(candidate):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('0.0.0.0', candidate))
        except OSError:
            return False
        return True

    candidates = range(start_port, start_port + 100)
    return next((candidate for candidate in candidates if can_bind(candidate)), None)

PORT = find_free_port(8080)
if PORT is None:
    # Fail here rather than handing port=None to uvicorn.
    raise RuntimeError("No free port found in 8080-8179; free a port or restart the runtime.")


def run_server():
    """Run the FastAPI app with uvicorn; blocks, so it is started in a thread."""
    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="warning")


# daemon=True: the server keeps running after this cell finishes, until the
# runtime/kernel stops.
server_thread = Thread(target=run_server, daemon=True)
server_thread.start()

import time
import urllib.request

# Wait until the server actually answers /health instead of sleeping a fixed
# 3 s, so a failed start inside the background thread (e.g. the port got taken
# after find_free_port probed it) is not reported as "SERVER IS RUNNING".
deadline = time.monotonic() + 30
while True:
    try:
        with urllib.request.urlopen(f"http://127.0.0.1:{PORT}/health", timeout=2):
            break
    except OSError:
        pass  # not listening yet; urllib's URLError is an OSError subclass
    if not server_thread.is_alive():
        raise RuntimeError(f"Server thread stopped before listening on port {PORT}; see the output above.")
    if time.monotonic() > deadline:
        raise RuntimeError(f"Server did not answer on port {PORT} within 30 s.")
    time.sleep(0.5)

# Get Colab URL
# NOTE(review): proxyPort URLs go through Colab's proxy and may only open for
# the signed-in notebook owner — verify before sharing them with others.
from google.colab.output import eval_js
colab_url = eval_js(f'google.colab.kernel.proxyPort({PORT})')

print("\n" + "=" * 70)
print("✅ SERVER IS RUNNING!")
print("=" * 70)
print(f"\n🌐 YOUR URL (share this with anyone):")
print(f"\n   {colab_url}")
print(f"\n" + "=" * 70)
print("\n📌 IMPORTANT:")
print("   - This URL is public and shareable")
print("   - It will stay active as long as this cell is running")
print("   - Don't stop this cell or the server will stop")
print("\n💡 TIP: Open the URL in a new tab to test it!")
print("\n⚠️  Keep this cell running. Run the KEEP ALIVE cell below to prevent timeout.")

## Step 5: Keep Server Running

**Run this cell to keep the server alive (prevents Colab timeout)**

In [None]:
# Keep-alive loop
# Polls the local /health endpoint once a minute and shows the last status on
# a single, overwritten output line until the cell is interrupted.
import time
import requests

print("🔄 Keep-alive active. Press ⏹ to stop.\n")

try:
    while True:
        try:
            response = requests.get(f"http://localhost:{PORT}/health", timeout=10)
            status = response.json()
            print(f"\r[{time.strftime('%H:%M:%S')}] ✓ Running | GPU: {status['gpu_available']}", end="", flush=True)
        except Exception:
            # Only ordinary failures (connection refused, timeout, bad JSON,
            # missing key). A bare `except:` would also swallow the
            # KeyboardInterrupt from ⏹ whenever it lands during the request,
            # so the loop would keep running.
            print(f"\r[{time.strftime('%H:%M:%S')}] ⚠ Checking...", end="", flush=True)
        time.sleep(60)
except KeyboardInterrupt:
    print("\n\n✓ Stopped.")