<a href="https://colab.research.google.com/github/SattamAltwaim/SaSOKE/blob/main/notebooks/9_api_server.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/SattamAltwaim/SaSOKE/blob/main/notebooks/9_api_server.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🤟 SOKE API Server

**Flask API that turns text into sign-language motion (via SMPL-X parameters) and returns it as GLB mesh frames**

### Features
- ✅ REST API endpoint for text-to-sign generation
- ✅ Returns GLB frames (base64 encoded) ready for 3D display
- ✅ CORS enabled for frontend access
- ✅ Works with standalone Apple-style frontend

### Requirements
- **GPU Runtime**: `Runtime → Change runtime type → GPU`

## Step 1: Setup

In [None]:
# Clone repo and mount Drive.
# The clone is skipped when /content/SaSOKE already exists (an existing checkout is
# not pulled/updated), so re-running this cell is safe.
import os
if not os.path.exists('/content/SaSOKE'):
    !git clone https://github.com/SattamAltwaim/SaSOKE.git
# IPython magic: make the repo the working directory so the relative paths used in
# later cells (configs/..., deps/...) resolve inside it.
%cd /content/SaSOKE

from google.colab import drive
drive.mount('/content/drive')

# Root of the project's data, dependencies and checkpoints on Google Drive;
# later cells build all Drive paths from this prefix.
drive_data = '/content/drive/MyDrive/GraduationProject/CodeFiles/SaSOKE'
print("‚úì Ready!")

In [None]:
# Install dependencies (-q keeps pip output quiet).
# ML/training stack: pytorch_lightning is imported directly in a later cell; the rest
# are presumably required by the repo's mGPT package (confirm against its requirements).
!pip install -q pytorch_lightning torchmetrics omegaconf shortuuid transformers diffusers einops wandb rich matplotlib
# Body-model, data and text-processing utilities (smplx, h5py, spacy, sentencepiece, ...).
!pip install -q smplx h5py scikit-image spacy ftfy more-itertools natsort tensorboard sentencepiece
# Serving stack: Flask + flask-cors for the HTTP API, trimesh for GLB mesh export.
!pip install -q flask flask-cors trimesh
print("‚úì Dependencies installed!")

## Step 2: Load Model

In [None]:
# Setup paths
import sys
import yaml
from mGPT.config import parse_args

# Large model dependencies live on Drive; expose them at the repo-relative paths the
# configs expect by symlinking.
deps_links = {
    'deps/smpl_models': f'{drive_data}/deps/smpl_models',
    'deps/mbart-h2s-csl-phoenix': f'{drive_data}/deps/mbart-h2s-csl-phoenix',
}

for expected_path, actual_path in deps_links.items():
    if not os.path.exists(actual_path):
        print(f"Warning: dependency not found on Drive: {actual_path}")
    # os.path.exists() follows symlinks and is False for a dangling link, so checking
    # it alone would call os.symlink() on a path that already exists (FileExistsError).
    # Drop a stale/dangling link, then use lexists() so any real entry is left alone.
    if os.path.islink(expected_path) and not os.path.exists(expected_path):
        os.unlink(expected_path)
    if not os.path.lexists(expected_path):
        os.makedirs(os.path.dirname(expected_path), exist_ok=True)
        os.symlink(actual_path, expected_path)

# Derive an inference config from the repo's SOKE config, pointing data/stat/VAE
# paths at Drive, and write it to a separate file so the original stays untouched.
with open('configs/soke.yaml', 'r') as f:
    config = yaml.safe_load(f)

config['ACCELERATOR'] = 'gpu'
config['DEVICE'] = [0]
config['DATASET']['H2S']['ROOT'] = f'{drive_data}/data/How2Sign'
config['DATASET']['H2S']['MEAN_PATH'] = f'{drive_data}/smpl-x/mean.pt'
config['DATASET']['H2S']['STD_PATH'] = f'{drive_data}/smpl-x/std.pt'
config['TRAIN']['PRETRAINED_VAE'] = f'{drive_data}/checkpoints/vae/tokenizer.ckpt'

with open('configs/web_inference.yaml', 'w') as f:
    yaml.dump(config, f)

# Same idea for the assets config: body-model paths go through the symlinks above.
with open('configs/assets.yaml', 'r') as f:
    assets = yaml.safe_load(f)

assets['RENDER']['SMPL_MODEL_PATH'] = 'deps/smpl_models/smpl'
assets['RENDER']['MODEL_PATH'] = 'deps/smpl_models'
assets['METRIC']['TM2T']['t2m_path'] = f'{drive_data}/deps/deps/t2m/t2m/'

with open('configs/assets_web.yaml', 'w') as f:
    yaml.dump(assets, f)

# parse_args reads sys.argv, so fake a command line pointing at the generated configs.
sys.argv = ['', '--cfg', 'configs/web_inference.yaml', '--cfg_assets', 'configs/assets_web.yaml']
cfg = parse_args(phase="test")
cfg.FOLDER = cfg.TEST.FOLDER
print("‚úì Configuration ready!")

In [None]:
# Load model
import torch
import pytorch_lightning as pl
from mGPT.models.build_model import build_model
from mGPT.data.build_data import build_data
from mGPT.utils.load_checkpoint import load_pretrained_vae, load_pretrained
from mGPT.utils.logger import create_logger
from mGPT.utils.human_models import smpl_x, get_coord

pl.seed_everything(cfg.SEED_VALUE)
cfg.DATASET.WORD_VERTILIZER_PATH = f'{drive_data}/deps/deps/t2m/glove/'

datamodule = build_data(cfg)
model = build_model(cfg, datamodule)

logger = create_logger(cfg, phase="test")
if cfg.TRAIN.PRETRAINED_VAE:
    load_pretrained_vae(cfg, model, logger)

ckpt_path = f'{drive_data}/checkpoints/mGPT/last.ckpt'
if os.path.exists(ckpt_path):
    cfg.TEST.CHECKPOINTS = ckpt_path

    # Snapshot the weights so we can count which tensors the checkpoint actually changed.
    model_state_before = {k: v.clone() for k, v in model.state_dict().items()}

    # Load checkpoint
    load_pretrained(cfg, model, logger, phase="test")

    # Re-read the checkpoint only to compare its key set with the model's.
    # weights_only=False unpickles arbitrary objects: only use with trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    checkpoint_keys = set(checkpoint["state_dict"].keys())
    del checkpoint  # only the key names are needed; release the tensors

    # state_dict() builds a new mapping on every call, so take it once rather than
    # once per key inside the loop (which made this check quadratic).
    model_state_after = model.state_dict()
    model_keys = set(model_state_after.keys())

    # Count checkpoint keys present in the model whose tensor changed during loading.
    matching_keys = checkpoint_keys & model_keys
    loaded_count = sum(
        1
        for key in matching_keys
        if key in model_state_before
        and not torch.equal(model_state_after[key], model_state_before[key])
    )
    # Drop the cloned copy of every weight so it does not linger as a notebook global.
    del model_state_before, model_state_after

    print(f"\n‚úì Loaded model checkpoint from {ckpt_path}")
    print(f"  - Checkpoint keys: {len(checkpoint_keys)}")
    print(f"  - Model keys: {len(model_keys)}")
    print(f"  - Matching keys: {len(matching_keys)}")
    print(f"  - Weights updated: {loaded_count}")
    print(f"  - Note: 'Weights not loaded' messages above are normal - they show which weights don't match the current model structure.")
else:
    print(f"‚ö†Ô∏è  Checkpoint not found at {ckpt_path}")
    print(f"   Please ensure the checkpoint exists at this path.")

model = model.cuda()
model.eval()

# Normalization statistics from the datamodule; used later to de-normalize generated features.
mean = datamodule.hparams.mean.cuda()
std = datamodule.hparams.std.cuda()

print("\n‚úÖ Model loaded and ready!")

In [None]:
# Patch model.forward to handle motion concatenation with proper padding
import torch.nn.functional as F

# Keep the unpatched bound method so it can be restored if needed.
original_forward = model.forward


def _to_clamped_tokens(tokens, code_num, device):
    """Return ``tokens`` as a tensor on ``device`` with ids clamped to ``[0, code_num - 1]``.

    Python lists become ``torch.long`` tensors; tensors keep their dtype and are only
    moved to ``device``.
    """
    if isinstance(tokens, list):
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
    else:
        tokens = tokens.to(device)
    return torch.clamp(tokens, 0, code_num - 1)


def _pad_time(x, target_len):
    """Right-pad 3-D ``x`` along dim 1 (time) to ``target_len`` by repeating its last frame."""
    # F.pad pairs run from the last dim backwards: (dim-1 lo, dim-1 hi, dim-2 lo, dim-2 hi).
    # (0, 0, 0, n) leaves the feature axis alone and appends n frames on the time axis.
    return F.pad(x, (0, 0, 0, target_len - x.shape[1]), mode='replicate')


def _append_part(motion, tokens, vae, device):
    """Decode one body-part token stream with ``vae`` and concatenate it onto ``motion``.

    The shorter of ``motion`` / the decoded part is right-padded in time (last frame
    repeated) so both have the same length before concatenating on the last axis.
    Streams with one token or fewer are not decoded and ``motion`` is returned as-is.
    """
    tokens = _to_clamped_tokens(tokens, vae.code_num, device)
    if len(tokens) <= 1:
        return motion
    part = vae.decode(tokens)
    # Ensure both have a leading batch dimension.
    if part.dim() == 2:
        part = part.unsqueeze(0)
    if motion.dim() == 2:
        motion = motion.unsqueeze(0)
    target_len = max(motion.shape[1], part.shape[1])
    if part.shape[1] < target_len:
        part = _pad_time(part, target_len)
    if motion.shape[1] < target_len:
        motion = _pad_time(motion, target_len)
    return torch.cat([motion, part], dim=-1)


def patched_forward(self, batch, task="t2m"):
    """Text-to-motion forward pass that tolerates body/hand streams of different lengths.

    Tokens are generated with ``self.lm.generate_direct``. The body stream is decoded
    with ``self.vae``, and the left/right hand streams (when the model has
    ``hand_vae`` / ``rhand_vae``) are decoded and concatenated onto it on the feature
    axis after time-padding.

    Args:
        batch: dict with ``"text"`` (list of str) and optionally ``"src"`` (one source
            token per text, default ``"how2sign"``) and ``"name"``. ``"length"`` is
            not read.
        task: kept for signature compatibility; not used.

    Returns:
        dict with ``"feats"`` (samples zero-padded to the longest), ``"joints"``
        (from ``self.feats2joints``), ``"length"`` (frame count per sample) and
        ``"texts"`` (``cleaned_text`` from generation when provided, else the input).

    Raises:
        ValueError: if ``batch["text"]`` is empty.
    """
    texts = batch["text"]
    if not texts:
        raise ValueError("batch['text'] must contain at least one text")
    src = batch.get("src", ["how2sign"] * len(texts))
    name = batch.get("name", [None] * len(texts))
    device = next(self.parameters()).device  # loop-invariant

    gen_output = self.lm.generate_direct(texts, do_sample=True, src=src, name=name, max_length=400, num_beams=1)
    outputs = gen_output['outputs_tokens']
    # (token streams, decoder) for each optional hand; skipped when either is missing.
    part_streams = (
        (gen_output.get('outputs_tokens_hand', None), getattr(self, 'hand_vae', None)),
        (gen_output.get('outputs_tokens_rhand', None), getattr(self, 'rhand_vae', None)),
    )

    feats_rst_lst = []
    lengths = []
    for i in range(len(texts)):
        body_tokens = _to_clamped_tokens(outputs[i], self.vae.code_num, device)
        if len(body_tokens) > 1:
            motion = self.vae.decode(body_tokens)
        else:
            motion = torch.zeros(1, 1, self.vae.nfeats, device=device)

        for part_tokens, part_vae in part_streams:
            if part_tokens is not None and part_vae is not None:
                # NOTE(review): a hand stream with <=1 token is skipped, leaving this
                # sample narrower than the others; confirm that cannot happen.
                motion = _append_part(motion, part_tokens[i], part_vae, device)

        lengths.append(motion.shape[1])
        feats_rst_lst.append(motion)

    # Zero-pad every sample to the longest one and stack into a single batch tensor.
    max_len = max(lengths)
    n_feats = feats_rst_lst[-1].shape[-1]
    feats_rst = torch.zeros((len(feats_rst_lst), max_len, n_feats), device=device)
    for i, motion in enumerate(feats_rst_lst):
        feats_rst[i, :motion.shape[1], ...] = motion

    # Recover joints for evaluation
    joints_rst = self.feats2joints(feats_rst)

    return {"feats": feats_rst, "joints": joints_rst, "length": lengths, "texts": gen_output.get('cleaned_text', texts)}

# Bind patched_forward to this specific model instance (via the function's
# descriptor __get__) and store it as the instance's `forward`, so that
# model.forward(...) now runs the padding-aware implementation.
setattr(model, "forward", patched_forward.__get__(model, model.__class__))
print("‚úì Model forward method patched to handle motion padding")


## Step 3: Create API Functions

In [None]:
import numpy as np
import time
import json
import trimesh
import base64
from io import BytesIO

def feats_to_smplx_dict(features, mean_tensor, std_tensor):
    """Convert normalized 133-dim motion features to a dict of SMPL-X parameter lists.

    The features are de-normalized (``features * std_tensor + mean_tensor``) and
    prefixed with 36 zero columns, giving a 169-column vector per frame. The zero
    prefix means ``root_pose`` and the first 33 values of ``body_pose`` are always 0.

    Args:
        features: tensor of shape (T, 133), presumably on the same device as the stats.
        mean_tensor: per-feature mean, broadcastable against ``features``.
        std_tensor: per-feature std, broadcastable against ``features``.

    Returns:
        dict with keys ``root_pose`` (3), ``body_pose`` (63), ``lhand_pose`` (45),
        ``rhand_pose`` (45), ``jaw_pose`` (3) and ``expression`` (10). Each value is a
        list of T lists of Python floats of the stated width.
    """
    features = features * std_tensor + mean_tensor
    num_frames = features.shape[0]
    zero_pose = torch.zeros(num_frames, 36).to(features)
    features_full = torch.cat([zero_pose, features], dim=-1)  # (T, 169)

    # One detached device->host copy for all groups, rather than one per group.
    # detach() also allows inputs that require grad, which .numpy() would reject.
    full = features_full.detach().cpu().numpy()

    layout = (
        ('root_pose', 0, 3),
        ('body_pose', 3, 66),
        ('lhand_pose', 66, 111),
        ('rhand_pose', 111, 156),
        ('jaw_pose', 156, 159),
        ('expression', 159, 169),
    )
    return {key: full[:, start:stop].tolist() for key, start, stop in layout}

def smplx_params_to_glb_frames(smplx_params_dict, num_frames=None):
    """Render SMPL-X parameters frame by frame into base64-encoded GLB meshes.

    Args:
        smplx_params_dict: mapping with keys ``root_pose``, ``body_pose``,
            ``lhand_pose``, ``rhand_pose``, ``jaw_pose`` and ``expression``. Each
            value is a list holding one list of floats per frame (the format
            produced by ``feats_to_smplx_dict``).
        num_frames: number of leading frames to render. Defaults to all frames
            available in every group.

    Returns:
        list[str]: one base64 string per frame, each encoding a GLB export of a
        single mesh whose vertices are all colored opaque white.

    Raises:
        IndexError: if ``num_frames`` exceeds the frames available.
    """
    keys = ('root_pose', 'body_pose', 'lhand_pose', 'rhand_pose', 'jaw_pose', 'expression')
    available = min(len(smplx_params_dict[key]) for key in keys)
    if num_frames is None:
        num_frames = available
    if num_frames > available:
        raise IndexError(f"num_frames={num_frames} exceeds the {available} frames provided")

    device = mean.device  # render on the same device as the normalization stats

    # Fixed shape coefficients passed to get_coord as `shape` for every frame.
    shape_param = torch.tensor([[-0.07284723, 0.1795129, -0.27608207, 0.135155, 0.10748172,
                                 0.16037364, -0.01616933, -0.03450319, 0.01369138, 0.01108842]],
                               device=device, dtype=torch.float32)

    # One host->device copy per parameter group instead of one per group per frame.
    per_key = {
        key: torch.tensor(smplx_params_dict[key][:num_frames], dtype=torch.float32, device=device)
        for key in keys
    }

    glb_frames = []
    with torch.no_grad():
        for i in range(num_frames):
            # Slice i:i+1 keeps a leading batch dimension of size 1 for each group.
            frame = {key: tensor[i:i + 1] for key, tensor in per_key.items()}
            vertices, _ = get_coord(
                root_pose=frame['root_pose'],
                body_pose=frame['body_pose'],
                lhand_pose=frame['lhand_pose'],
                rhand_pose=frame['rhand_pose'],
                jaw_pose=frame['jaw_pose'],
                shape=shape_param,
                expr=frame['expression'],
            )

            mesh = trimesh.Trimesh(
                vertices=vertices[0].cpu().numpy(),
                faces=smpl_x.face,
                process=False,
            )
            # Opaque white RGBA on every vertex.
            mesh.visual.vertex_colors = np.full((len(mesh.vertices), 4), 255, dtype=np.uint8)

            glb_buffer = BytesIO()
            mesh.export(file_obj=glb_buffer, file_type='glb')
            glb_frames.append(base64.b64encode(glb_buffer.getvalue()).decode('utf-8'))

    return glb_frames

def generate_smplx_params(text, lang_token):
    """Run text-to-sign generation for a single text and return SMPL-X parameters.

    Args:
        text: input sentence; blank or whitespace-only text is rejected up front.
        lang_token: source/language token passed to the model as the batch ``src``.

    Returns:
        ``(smplx_params, None)`` on success, or ``(None, message)`` when the text is
        blank, the model yields no features, or any exception occurs. Exceptions are
        printed with a traceback rather than raised.
    """
    if not text.strip():
        return None, "‚ö†Ô∏è Please enter some text"

    try:
        request_batch = dict(text=[text], length=[0], src=[lang_token])
        with torch.no_grad():
            result = model.forward(request_batch, task="t2m")

        feats = None
        if 'feats' in result:
            feats = result['feats'][0]
        if feats is None:
            return None, "‚ùå Generation failed - no features produced"

        return feats_to_smplx_dict(feats, mean, std), None

    except Exception as exc:
        import traceback
        traceback.print_exc()
        return None, f"‚ùå Error: {str(exc)}"

print("‚úì API functions ready!")

## Step 4: Create Flask API Server

In [None]:
# Create Flask API
from flask import Flask, request, jsonify
from flask_cors import CORS
# Thread runs the server in the background; socket is used to find a free port.
from threading import Thread
import socket

app = Flask(__name__)
# Enable CORS for frontend access. With no arguments flask-cors applies its default
# (permissive) settings to every route; NOTE(review): consider restricting origins
# before exposing the API publicly.
CORS(app)  # Enable CORS for frontend access

@app.route('/api/generate', methods=['POST'])
def api_generate():
    """Generate sign-language GLB frames for the posted text.

    Request body (JSON object):
        text: str, required and non-blank (surrounding whitespace is stripped).
        lang_token: str, optional, defaults to ``"how2sign"``.

    Responses:
        200: ``{success, glb_frames, num_frames, text, lang_token}`` where
             ``glb_frames`` is a list of base64-encoded GLB meshes.
        400: missing/invalid JSON body, non-string fields, or blank text.
        500: generation or rendering failed (``{"error": message}``).
    """
    # silent=True yields None instead of raising on a missing or unparsable JSON body.
    # Without it, malformed client requests fell into the catch-all below as 500s.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    text = data.get('text', '')
    lang_token = data.get('lang_token', 'how2sign')
    if not isinstance(text, str) or not isinstance(lang_token, str):
        return jsonify({'error': "'text' and 'lang_token' must be strings"}), 400

    text = text.strip()
    if not text:
        return jsonify({'error': 'Text is required'}), 400

    try:
        # Generate SMPL-X parameters
        smplx_params, error = generate_smplx_params(text, lang_token)
        if error:
            return jsonify({'error': error}), 500

        num_frames = len(smplx_params['body_pose'])

        # Convert to GLB frames
        glb_frames = smplx_params_to_glb_frames(smplx_params, num_frames)

        # Return GLB frames (base64 encoded)
        return jsonify({
            'success': True,
            'glb_frames': glb_frames,
            'num_frames': num_frames,
            'text': text,
            'lang_token': lang_token
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500

@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: always answers with a static OK payload."""
    payload = {'status': 'ok', 'message': 'API is running'}
    return jsonify(payload)

def get_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    The probe socket is closed before returning, so another process could still
    take the port before the caller binds to it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('', 0))  # port 0: let the OS choose
        probe.listen(1)
        return probe.getsockname()[1]
    finally:
        probe.close()

port = get_free_port()


def run_flask():
    """Serve ``app`` on all interfaces at ``port``; blocks until the server stops."""
    app.run(host='0.0.0.0', port=port, debug=False, use_reloader=False)


# daemon=True so this thread does not keep the Python process alive on exit.
flask_thread = Thread(target=run_flask, daemon=True)
flask_thread.start()

# Poll until the server accepts TCP connections (or its thread dies) instead of a
# fixed 2 s sleep. The old sleep could be too short on a busy runtime and reported
# success even when startup failed.
flask_ready = False
_ready_deadline = time.monotonic() + 15
while flask_thread.is_alive() and time.monotonic() < _ready_deadline:
    try:
        with socket.create_connection(('127.0.0.1', port), timeout=0.5):
            flask_ready = True
            break
    except OSError:
        time.sleep(0.1)

if flask_ready:
    print(f"\n‚úÖ Flask API started!")
else:
    print(f"\nWarning: Flask API is not reachable on port {port}; check the output above for errors.")
print(f"üì° API URL: http://localhost:{port}/api/generate")
print(f"üíö Health check: http://localhost:{port}/api/health")
print(f"\nüìù API Usage:")
print(f"  POST /api/generate")
# Show the body/response as JSON (the wire format), not Python dict syntax.
print('  Body: {"text": "Hello world", "lang_token": "how2sign"}')
print('  Response: {"success": true, "glb_frames": [...], "num_frames": N}')
print(f"\n‚ö†Ô∏è  To expose this API publicly, use ngrok:")
print(f"  !pip install pyngrok")
print(f"  from pyngrok import ngrok")
# ngrok.connect() returns an NgrokTunnel object; the bare URL is its .public_url.
print(f"  tunnel = ngrok.connect({port})")
print(f"  print(f'Public URL: {{tunnel.public_url}}')")

## Step 5: Expose API Publicly (Optional)

In [None]:
# Expose API using ngrok (for frontend access)
!pip install -q pyngrok

from pyngrok import ngrok
import getpass

# Setup ngrok authentication
print("\nüîê ngrok Authentication Setup")
print("=" * 60)
print("To use ngrok, you need an authtoken.")
print("1. Sign up at: https://dashboard.ngrok.com/signup")
print("2. Get your authtoken: https://dashboard.ngrok.com/get-started/your-authtoken")
print("=" * 60)

# Option 1: Use environment variable (if set)
ngrok_token = os.environ.get('NGROK_AUTHTOKEN', None)

# Option 2: Prompt user to enter token
if not ngrok_token:
    print("\nEnter your ngrok authtoken (or press Enter to skip):")
    user_token = getpass.getpass("ngrok authtoken: ").strip()
    if user_token:
        ngrok_token = user_token
        os.environ['NGROK_AUTHTOKEN'] = ngrok_token

if ngrok_token:
    ngrok.set_auth_token(ngrok_token)
    print("\n‚úì ngrok authtoken configured")

    # Create public tunnel
    try:
        public_url = ngrok.connect(port)
        print(f"\nüåê Public API URL: {public_url}")
        print(f"\nüìã Update your frontend with this URL:")
        print(f"   const API_URL = '{public_url}/api/generate';")
        print(f"\n‚ö†Ô∏è  This URL will expire when the Colab session ends!")
    except Exception as e:
        print(f"\n‚ùå Error creating ngrok tunnel: {e}")
        print("\nYou can still use the API locally at:")
        print(f"   http://localhost:{port}/api/generate")
else:
    print("\n‚ö†Ô∏è  No ngrok authtoken provided. Skipping public tunnel.")
    print("\nYou can still use the API locally at:")
    print(f"   http://localhost:{port}/api/generate")
    print("\nTo expose publicly, run this cell again with your authtoken.")