<a href="https://colab.research.google.com/github/SattamAltwaim/SaSOKE/blob/main/notebooks/8_gradio_web_ui.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🤟 SOKE Gradio Web UI

**Simple Web UI that ACTUALLY WORKS in Colab!**

Gradio runs directly inside Colab and creates a public share link, so there are no proxy issues!

### Features
- ✅ Works 100% in Colab (no blank pages!)
- ✅ Beautiful UI
- ✅ Public shareable link
- ✅ 3D visualization

### Requirements
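- **GPU Runtime**: `Runtime → Change runtime type → GPU`
- **Google Drive**: the project files (dependencies, data, normalization statistics and checkpoints) must be under `MyDrive/GraduationProject/CodeFiles/SaSOKE`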

## Step 1: Setup

In [None]:
# Step 1: clone the repo (only if it is not already present), cd into it so the
# relative paths used by later cells (configs/, deps/) resolve, and mount Google Drive.
# NOTE: uses IPython shell/magic lines (`!`, `%`), so this cell only runs in Jupyter/Colab.
import os
if not os.path.exists('/content/SaSOKE'):
    !git clone https://github.com/SattamAltwaim/SaSOKE.git
%cd /content/SaSOKE

from google.colab import drive
drive.mount('/content/drive')

# Root folder on Drive; later cells build the deps/, data/, smpl-x/, checkpoints/
# and experiments/ paths from it.
drive_data = '/content/drive/MyDrive/GraduationProject/CodeFiles/SaSOKE'
print("‚úì Ready!")

In [None]:
# Install dependencies
!pip install -q pytorch_lightning torchmetrics omegaconf shortuuid transformers diffusers einops wandb rich matplotlib
!pip install -q smplx h5py scikit-image spacy ftfy more-itertools natsort tensorboard sentencepiece
!pip install -q gradio
print("‚úì Dependencies installed!")

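## Optional: Create Flask API Endpoint

Starts a small JSON API (`POST /api/generate`, `GET /api/health`) in a background thread. Generation happens per request, so the model only answers requests once Steps 2 and 3 have been run. The two cells below define the same endpoints, so you only need one of them.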

In [None]:
# Create Flask API: JSON endpoints for text -> SMPL-X parameters, served from a
# background daemon thread so the notebook stays usable. Generation happens per
# request, so this cell may run before the model has been loaded.
import socket
import time
from threading import Thread

from flask import Flask, request, jsonify
from flask_cors import CORS

# Column ranges of the full SMPL-X parameter vector built by `feats_to_smplx`;
# these are the same slices `generate_mesh` passes to `get_coord`.
SMPLX_PARAM_SLICES = {
    'root_pose': slice(0, 3),
    'body_pose': slice(3, 66),
    'lhand_pose': slice(66, 111),
    'rhand_pose': slice(111, 156),
    'jaw_pose': slice(156, 159),
    'expr': slice(159, 169),
}


def generate_smplx_params(text, lang_token):
    """Generate per-frame SMPL-X parameters for `text` in sign language `lang_token`.

    Uses the notebook globals `model`, `mean`, `std` and `feats_to_smplx`
    (created by the model-loading and interface cells), looked up when called.

    Returns:
        (params, None) on success, where `params` maps each SMPLX_PARAM_SLICES key
        to a JSON-serialisable nested list with one row per frame;
        (None, error_message) if the model is not loaded or produced no features.
    """
    import torch

    missing = [name for name in ('model', 'mean', 'std', 'feats_to_smplx') if name not in globals()]
    if missing:
        return None, f"Model not ready - run the model loading cells first (missing: {', '.join(missing)})"

    batch = {'text': [text], 'length': [0], 'src': [lang_token]}
    with torch.no_grad():
        output = model.forward(batch, task="t2m")
    if 'feats' not in output:
        return None, 'Generation failed - no features produced'

    full_params = feats_to_smplx(output['feats'][0], mean, std).detach().cpu()
    params = {name: full_params[:, cols].tolist() for name, cols in SMPLX_PARAM_SLICES.items()}
    return params, None


app = Flask(__name__)
CORS(app)  # Enable CORS for frontend access


@app.route('/api/generate', methods=['POST'])
def api_generate():
    """API endpoint that takes lang_token and text, returns SMPL-X parameters.

    Expects a JSON object body: {"text": str, "lang_token": str (default "how2sign")}.
    Responds 400 for a missing/invalid body or empty text, 500 if generation fails.
    """
    try:
        # silent=True returns None instead of raising for a missing or non-JSON body.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        text = data.get('text', '')
        text = text.strip() if isinstance(text, str) else ''
        lang_token = data.get('lang_token', 'how2sign')

        if not text:
            return jsonify({'error': 'Text is required'}), 400

        # Generate SMPL-X parameters
        smplx_params, error = generate_smplx_params(text, lang_token)

        if error:
            return jsonify({'error': error}), 500

        # Return SMPL-X parameters
        return jsonify({
            'success': True,
            'smplx_params': smplx_params,
            'num_frames': len(smplx_params['body_pose']),
            'text': text,
            'lang_token': lang_token
        })

    except Exception as e:
        # Request boundary: log the traceback in the notebook, report the error to the client.
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({'status': 'ok', 'message': 'API is running'})


def get_free_port():
    """Return a TCP port the OS reports as free (it is released before Flask binds it)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


# Re-running this cell must not leak another server thread on a new port.
if 'flask_thread' in globals() and flask_thread.is_alive():
    print(f"\nFlask API already running on port {port}; reusing it.")
else:
    port = get_free_port()

    def run_flask():
        """Serve `app` on all interfaces (reloader disabled: not usable from a background thread)."""
        app.run(host='0.0.0.0', port=port, debug=False, use_reloader=False)

    flask_thread = Thread(target=run_flask, daemon=True)
    flask_thread.start()

    # Wait a moment for Flask to start
    time.sleep(2)
    print(f"\n‚úÖ Flask API started!")

print(f"üì° API URL: http://localhost:{port}/api/generate")
print(f"üíö Health check: http://localhost:{port}/api/health")
print(f"\nüìù API Usage:")
print(f"  POST /api/generate")
print(f"  Body: {{'text': 'Hello world', 'lang_token': 'how2sign'}}")
print(f"  Response: {{'success': True, 'smplx_params': {{...}}, 'num_frames': N}}")

In [None]:
# Create Flask API (same endpoints as the cell above). Self-contained imports, and
# it reuses an already-running server instead of starting a second one.
import socket
import time
from threading import Thread

from flask import Flask, request, jsonify
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # Enable CORS for frontend access


@app.route('/api/generate', methods=['POST'])
def api_generate():
    """API endpoint that takes lang_token and text, returns SMPL-X parameters.

    Expects a JSON object body: {"text": str, "lang_token": str (default "how2sign")}.
    Responds 400 for a missing/invalid body or empty text, 500 if generation fails.
    """
    try:
        # silent=True returns None instead of raising for a missing or non-JSON body.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        text = data.get('text', '')
        text = text.strip() if isinstance(text, str) else ''
        lang_token = data.get('lang_token', 'how2sign')

        if not text:
            return jsonify({'error': 'Text is required'}), 400

        # Generate SMPL-X parameters
        smplx_params, error = generate_smplx_params(text, lang_token)

        if error:
            return jsonify({'error': error}), 500

        # Return SMPL-X parameters
        return jsonify({
            'success': True,
            'smplx_params': smplx_params,
            'num_frames': len(smplx_params['body_pose']),
            'text': text,
            'lang_token': lang_token
        })

    except Exception as e:
        # Request boundary: log the traceback in the notebook, report the error to the client.
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/health', methods=['GET'])
def health():
    """Health check endpoint"""
    return jsonify({'status': 'ok', 'message': 'API is running'})


def get_free_port():
    """Return a TCP port the OS reports as free (it is released before Flask binds it)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


# Run Flask in background thread -- unless a previous cell/run already started one,
# in which case starting another would just leak a second server on a new port.
if 'flask_thread' in globals() and flask_thread.is_alive():
    print(f"\nFlask API already running on port {port}; reusing it.")
else:
    port = get_free_port()

    def run_flask():
        """Serve `app` on all interfaces (reloader disabled: not usable from a background thread)."""
        app.run(host='0.0.0.0', port=port, debug=False, use_reloader=False)

    flask_thread = Thread(target=run_flask, daemon=True)
    flask_thread.start()

    # Wait a moment for Flask to start
    time.sleep(2)
    print(f"\n‚úÖ Flask API started!")

print(f"üì° API URL: http://localhost:{port}/api/generate")
print(f"üíö Health check: http://localhost:{port}/api/health")
print(f"\nüìù API Usage:")
print(f"  POST /api/generate")
print(f"  Body: {{'text': 'Hello world', 'lang_token': 'how2sign'}}")
print(f"  Response: {{'success': True, 'smplx_params': {{...}}, 'num_frames': N}}")

## Step 2: Load Model

In [None]:
# Setup paths: link model dependencies from Drive into the repo, write Colab-specific
# copies of the model and asset configs, then parse them into `cfg`.
import sys
import yaml
from mGPT.config import parse_args

# repo-relative path the code expects -> real location on Drive
deps_links = {
    'deps/smpl_models': f'{drive_data}/deps/smpl_models',
    'deps/mbart-h2s-csl-phoenix': f'{drive_data}/deps/mbart-h2s-csl-phoenix',
}

for expected_path, actual_path in deps_links.items():
    if not os.path.exists(actual_path):
        print(f"WARNING: {actual_path} not found - is Drive mounted and drive_data correct?")
    # os.path.exists() is False for a dangling symlink (e.g. one left over from a run
    # with a different drive_data); os.symlink() would then raise FileExistsError.
    if os.path.islink(expected_path) and not os.path.exists(expected_path):
        os.remove(expected_path)
    if not os.path.lexists(expected_path):
        os.makedirs(os.path.dirname(expected_path), exist_ok=True)
        os.symlink(actual_path, expected_path)

with open('configs/soke.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Run on GPU 0 and read data, normalisation stats and the VAE tokenizer from Drive.
config['ACCELERATOR'] = 'gpu'
config['DEVICE'] = [0]
config['DATASET']['H2S']['ROOT'] = f'{drive_data}/data/How2Sign'
config['DATASET']['H2S']['MEAN_PATH'] = f'{drive_data}/smpl-x/mean.pt'
config['DATASET']['H2S']['STD_PATH'] = f'{drive_data}/smpl-x/std.pt'
config['TRAIN']['PRETRAINED_VAE'] = f'{drive_data}/checkpoints/vae/tokenizer.ckpt'

with open('configs/web_inference.yaml', 'w') as f:
    yaml.dump(config, f)

with open('configs/assets.yaml', 'r') as f:
    assets = yaml.safe_load(f)

assets['RENDER']['SMPL_MODEL_PATH'] = 'deps/smpl_models/smpl'
assets['RENDER']['MODEL_PATH'] = 'deps/smpl_models'
assets['METRIC']['TM2T']['t2m_path'] = f'{drive_data}/deps/deps/t2m/t2m/'

with open('configs/assets_web.yaml', 'w') as f:
    yaml.dump(assets, f)

# parse_args() takes its options from the command line, so point sys.argv at the new configs.
sys.argv = ['', '--cfg', 'configs/web_inference.yaml', '--cfg_assets', 'configs/assets_web.yaml']
cfg = parse_args(phase="test")
cfg.FOLDER = cfg.TEST.FOLDER
print("‚úì Configuration ready!")

In [None]:
# Load model: build the data module and model, load the VAE tokenizer and the SOKE
# checkpoint from Drive, then move the model and normalisation stats to the GPU.
import torch
import pytorch_lightning as pl
from mGPT.models.build_model import build_model
from mGPT.data.build_data import build_data
from mGPT.utils.load_checkpoint import load_pretrained_vae, load_pretrained
from mGPT.utils.logger import create_logger
from mGPT.utils.human_models import smpl_x, get_coord

# Fail fast with an actionable message instead of a CUDA error after the slow build steps.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU available - switch to a GPU runtime (Runtime -> Change runtime type -> GPU) and re-run.")

pl.seed_everything(cfg.SEED_VALUE)
cfg.DATASET.WORD_VERTILIZER_PATH = f'{drive_data}/deps/deps/t2m/glove/'

datamodule = build_data(cfg)
model = build_model(cfg, datamodule)

logger = create_logger(cfg, phase="test")
if cfg.TRAIN.PRETRAINED_VAE:
    load_pretrained_vae(cfg, model, logger)

ckpt_path = f'{drive_data}/experiments/mgpt/SOKE/checkpoints/last.ckpt'
if os.path.exists(ckpt_path):
    cfg.TEST.CHECKPOINTS = ckpt_path
    load_pretrained(cfg, model, logger, phase="test")
else:
    # Keep going, but make it obvious that the SOKE checkpoint weights were not loaded.
    print(f"WARNING: checkpoint not found at {ckpt_path} - continuing WITHOUT the SOKE checkpoint weights.")

model = model.cuda()
model.eval()

# Normalisation statistics used by feats_to_smplx, kept on the same device as the model.
mean = datamodule.hparams.mean.cuda()
std = datamodule.hparams.std.cuda()

print("\n‚úÖ Model loaded and ready!")

## Step 3: Create Gradio Interface

In [None]:
import gradio as gr
import numpy as np
import time
import json

def feats_to_smplx(features, mean_tensor, std_tensor):
    """Undo feature normalisation and left-pad the result into a full SMPL-X vector.

    Args:
        features: normalised model features, one row per frame (dim 0 = frames).
        mean_tensor: normalisation mean, broadcast against `features`.
        std_tensor: normalisation std, broadcast against `features`.

    Returns:
        A tensor whose first 36 columns are zeros, followed by the de-normalised
        features; same dtype/device as the de-normalised features.
        NOTE(review): the 36 zero columns presumably stand for the root and
        lower-body pose the model does not generate -- confirm.
    """
    denormalised = features * std_tensor + mean_tensor
    padding = denormalised.new_zeros((denormalised.shape[0], 36))
    return torch.cat((padding, denormalised), dim=-1)

def generate_mesh(smplx_params_full):
    """Run SMPL-X forward kinematics frame by frame and collect mesh vertices.

    Args:
        smplx_params_full: tensor with one row per frame holding the concatenated
            pose/expression parameters (columns 0-168, sliced below).

    Returns:
        (vertices, faces): a NumPy array built from the per-frame vertex arrays
        returned by `get_coord`, and the SMPL-X face indices `smpl_x.face`.
    """
    # Fixed body-shape coefficients shared by every frame; created on the SAME
    # device as the pose parameters (GPU) so get_coord does not mix devices.
    betas = torch.tensor(
        [[-0.07284723, 0.1795129, -0.27608207, 0.135155, 0.10748172,
          0.16037364, -0.01616933, -0.03450319, 0.01369138, 0.01108842]],
        device=smplx_params_full.device,
        dtype=torch.float32,
    )
    # (get_coord keyword, first column, end column) of each parameter group.
    layout = (
        ('root_pose', 0, 3),
        ('body_pose', 3, 66),
        ('lhand_pose', 66, 111),
        ('rhand_pose', 111, 156),
        ('jaw_pose', 156, 159),
        ('expr', 159, 169),
    )

    per_frame = []
    with torch.no_grad():
        for row in smplx_params_full:
            frame = row.unsqueeze(0)  # keep a batch dimension of 1
            pose_kwargs = {name: frame[:, start:end] for name, start, end in layout}
            verts, _ = get_coord(shape=betas, **pose_kwargs)
            per_frame.append(verts[0].cpu().numpy())

    return np.array(per_frame), smpl_x.face

def text_to_sign(text, sign_language):
    """Gradio callback: generate a sign-language animation for `text`.

    Args:
        text: sentence from the textbox; None or blank input is rejected.
        sign_language: language token passed to the model as `src`
            (the dropdown supplies "how2sign", "csl" or "phoenix").

    Returns:
        (viewer_html, status_markdown, info_markdown). `viewer_html` is None when
        nothing could be generated.
    """
    import html as html_lib

    if not text or not text.strip():
        return None, "‚ö†Ô∏è Please enter some text", ""

    start_time = time.time()

    try:
        batch = {'text': [text], 'length': [0], 'src': [sign_language]}

        with torch.no_grad():
            output = model.forward(batch, task="t2m")

        feats = output['feats'][0] if 'feats' in output else None

        if feats is None:
            return None, "‚ùå Generation failed - no features produced", ""

        smplx_params = feats_to_smplx(feats, mean, std)
        vertices, faces = generate_mesh(smplx_params)

        num_frames = vertices.shape[0]
        if num_frames == 0:
            # The viewer script indexes frame 0, so an empty sequence cannot be shown.
            return None, "‚ùå Generation failed - model produced 0 frames", ""
        gen_time = time.time() - start_time

        # gr.HTML inserts markup without executing <script> tags, so the Three.js page
        # is embedded as an iframe document (srcdoc), where its scripts do run.
        page = create_3d_viewer(vertices, faces)
        viewer = (
            '<iframe srcdoc="' + html_lib.escape(page, quote=True) + '" '
            'sandbox="allow-scripts" '
            'style="width:100%; height:600px; border:none; border-radius:10px;"></iframe>'
        )

        status = f"‚úÖ Generated {num_frames} frames in {gen_time:.2f}s"
        info = f"üìä Frames: {num_frames} | Duration: {num_frames/20:.2f}s | FPS: 20"

        return viewer, status, info

    except Exception as e:
        # UI boundary: log the traceback in the notebook and show the error in the app.
        import traceback
        traceback.print_exc()
        return None, f"‚ùå Error: {str(e)}", ""

def create_3d_viewer(vertices, faces, decimals=4):
    """Create an interactive 3D viewer HTML page for a mesh animation.

    The page loads Three.js r128 and OrbitControls from public CDNs, embeds the mesh
    data as JSON and plays it at 20 FPS with play / pause / reset buttons and a
    0.5x-2x speed selector.

    Args:
        vertices: per-frame vertex positions indexed [frame][vertex][xyz]
            (e.g. the array from `generate_mesh`); must contain at least one frame.
        faces: triangle vertex indices, three per face.
        decimals: coordinates are rounded to this many decimals before being
            serialised; full float precision makes the embedded JSON several times
            larger. None disables rounding.

    Returns:
        The HTML document as a string.
    """
    # Round in float64: rounded float32 values would still serialise with long
    # artefacts (0.1235 -> 0.12349999696016312).
    frames = np.asarray(vertices, dtype=np.float64)
    if decimals is not None:
        frames = np.round(frames, decimals)
    num_frames = len(frames)

    # Convert to JSON
    vertices_json = json.dumps(frames.tolist())
    faces_json = json.dumps(np.asarray(faces).tolist())

    html = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
        <script src="https://cdn.jsdelivr.net/npm/three@0.128.0/examples/js/controls/OrbitControls.js"></script>
        <style>
            body {{ margin: 0; overflow: hidden; background: #1a1a24; }}
            #canvas {{ width: 100%; height: 600px; }}
            .controls {{
                position: absolute;
                bottom: 20px;
                left: 50%;
                transform: translateX(-50%);
                background: rgba(0,0,0,0.7);
                padding: 15px;
                border-radius: 10px;
                display: flex;
                gap: 10px;
                align-items: center;
            }}
            button {{
                padding: 10px 20px;
                border: none;
                border-radius: 5px;
                background: #6366f1;
                color: white;
                cursor: pointer;
                font-size: 16px;
            }}
            button:hover {{ background: #4f46e5; }}
            button:disabled {{ opacity: 0.5; cursor: not-allowed; }}
            .info {{ color: white; font-family: monospace; font-size: 14px; }}
        </style>
    </head>
    <body>
        <div id="canvas"></div>
        <div class="controls">
            <button id="play">‚ñ∂ Play</button>
            <button id="pause" disabled>‚è∏ Pause</button>
            <button id="reset">‚Ü∫ Reset</button>
            <span class="info" id="frame-info">Frame: 1 / {num_frames}</span>
            <select id="speed">
                <option value="0.5">0.5x</option>
                <option value="1" selected>1x</option>
                <option value="2">2x</option>
            </select>
        </div>
        <script>
            const vertices = {vertices_json};
            const faces = {faces_json};
            const FPS = 20;
            let currentFrame = 0;
            let isPlaying = false;
            let playbackSpeed = 1;
            let animationId = null;

            // Setup Three.js
            const scene = new THREE.Scene();
            scene.background = new THREE.Color(0x1a1a24);

            const container = document.getElementById('canvas');
            const camera = new THREE.PerspectiveCamera(45, container.clientWidth / 600, 0.1, 100);
            camera.position.set(0, 0.5, 2.5);

            const renderer = new THREE.WebGLRenderer({{ antialias: true }});
            renderer.setSize(container.clientWidth, 600);
            container.appendChild(renderer.domElement);

            const controls = new THREE.OrbitControls(camera, renderer.domElement);
            controls.target.set(0, 0.5, 0);

            scene.add(new THREE.AmbientLight(0xffffff, 0.5));
            const light = new THREE.DirectionalLight(0xffffff, 0.8);
            light.position.set(2, 3, 2);
            scene.add(light);

            const grid = new THREE.GridHelper(4, 20, 0x333344, 0x222233);
            grid.position.y = -0.5;
            scene.add(grid);

            // Create mesh
            const geometry = new THREE.BufferGeometry();
            geometry.setAttribute('position', new THREE.BufferAttribute(new Float32Array(vertices[0].flat()), 3));
            geometry.setIndex(new THREE.BufferAttribute(new Uint32Array(faces.flat()), 1));
            geometry.computeVertexNormals();

            const material = new THREE.MeshPhongMaterial({{
                color: 0x6366f1,
                shininess: 30,
                side: THREE.DoubleSide
            }});

            const mesh = new THREE.Mesh(geometry, material);
            scene.add(mesh);

            function updateMesh(frame) {{
                // Overwrite the existing position buffer instead of allocating a new one per frame.
                const positions = geometry.attributes.position;
                positions.array.set(vertices[frame].flat());
                positions.needsUpdate = true;
                geometry.computeVertexNormals();
                document.getElementById('frame-info').textContent = `Frame: ${{frame + 1}} / ${{vertices.length}}`;
            }}

            function animate() {{
                requestAnimationFrame(animate);
                controls.update();
                renderer.render(scene, camera);
            }}
            animate();

            function play() {{
                isPlaying = true;
                document.getElementById('play').disabled = true;
                document.getElementById('pause').disabled = false;

                let lastTime = performance.now();
                function step(time) {{
                    if (!isPlaying) return;
                    if (time - lastTime >= 1000 / (FPS * playbackSpeed)) {{
                        currentFrame = (currentFrame + 1) % vertices.length;
                        updateMesh(currentFrame);
                        lastTime = time;
                    }}
                    animationId = requestAnimationFrame(step);
                }}
                animationId = requestAnimationFrame(step);
            }}

            function pause() {{
                isPlaying = false;
                document.getElementById('play').disabled = false;
                document.getElementById('pause').disabled = true;
                if (animationId) cancelAnimationFrame(animationId);
            }}

            function reset() {{
                pause();
                currentFrame = 0;
                updateMesh(0);
            }}

            document.getElementById('play').onclick = play;
            document.getElementById('pause').onclick = pause;
            document.getElementById('reset').onclick = reset;
            document.getElementById('speed').onchange = (e) => {{ playbackSpeed = parseFloat(e.target.value); }};

            window.addEventListener('resize', () => {{
                camera.aspect = container.clientWidth / 600;
                camera.updateProjectionMatrix();
                renderer.setSize(container.clientWidth, 600);
            }});
        </script>
    </body>
    </html>
    """
    return html

print("‚úì Functions ready!")

## Step 4: Launch Interface

In [None]:
# Create Gradio interface: two columns (inputs on the left, 3D viewer on the right),
# then launch it. The UI layout follows the nesting and order of the `with` blocks.

# Initial content of the viewer column; replaced by whatever text_to_sign returns for the viewer.
placeholder_html = """
<div style='width:100%; height:600px; background:#1a1a24; border-radius:10px; display:flex; align-items:center; justify-content:center; flex-direction:column; color:#94a3b8;'>
    <div style='font-size:48px; margin-bottom:20px;'>üßç</div>
    <div style='font-size:18px;'>Enter text and click Generate</div>
    <div style='font-size:14px; margin-top:10px; opacity:0.7;'>The 3D animation will appear here</div>
</div>
"""

with gr.Blocks(theme=gr.themes.Soft(), title="SOKE - Text to Sign Language") as demo:
    gr.Markdown("""
    # ü§ü SOKE - Text to Sign Language
    
    **Generate sign language animations from text in real-time!**
    """)
    
    with gr.Row():
        # Left column: text, language and the Generate button plus status lines.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text to Translate",
                placeholder="Type your message here...",
                lines=3,
                value="Hello, how are you today?"
            )
            
            # Each choice is (label shown, value passed to text_to_sign as `sign_language`).
            language_select = gr.Dropdown(
                label="Target Sign Language",
                choices=[
                    ("üá∫üá∏ American Sign Language (ASL)", "how2sign"),
                    ("üá®üá≥ Chinese Sign Language (CSL)", "csl"),
                    ("üá©üá™ German Sign Language (DGS)", "phoenix")
                ],
                value="how2sign"
            )
            
            generate_btn = gr.Button("üöÄ Generate Sign Language", variant="primary", size="lg")
            
            # Filled with the status / info strings returned by text_to_sign.
            status_text = gr.Markdown("")
            info_text = gr.Markdown("")
            
            gr.Markdown("""
            ### üí° Try these examples:
            - Hello, how are you?
            - Thank you for your help
            - Nice to meet you!
            - What is your name?
            """)
        
        # Right column (twice as wide): the 3D viewer.
        with gr.Column(scale=2):
            gr.Markdown("### üëÅÔ∏è 3D Visualization")
            # Start with placeholder HTML so column is visible
            viewer = gr.HTML(value=placeholder_html, label="", elem_id="viewer-container")
    
    # Outputs map positionally onto text_to_sign's (viewer, status, info) return tuple.
    generate_btn.click(
        fn=text_to_sign,
        inputs=[text_input, language_select],
        outputs=[viewer, status_text, info_text]
    )

# Launch with share=True for public URL
print("\n" + "="*70)
print("üöÄ LAUNCHING GRADIO INTERFACE...")
print("="*70)
print("\n‚è≥ Please wait while Gradio starts...")
print("\nüí° A public URL will be generated that you can share with anyone!")
print("\n‚ú® You should see TWO columns: Input (left) and 3D Viewer (right)")
print("\n" + "="*70)

demo.launch(share=True, debug=False)