<a href="https://colab.research.google.com/github/SattamAltwaim/SaSOKE/blob/main/notebooks/5_text_to_sign_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Text-to-Sign Language Inference
Generate sign language poses from custom text input using the SOKE model.


## 1. Setup Environment


In [None]:
# Clone repo if not present
import os
if not os.path.exists('/content/SaSOKE'):
    !git clone https://github.com/SattamAltwaim/SaSOKE.git
%cd /content/SaSOKE

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

drive_data = '/content/drive/MyDrive/GraduationProject/CodeFiles/SaSOKE'
print("✓ Code:", os.getcwd())
print("✓ Data:", drive_data)


In [None]:
# Install dependencies (if needed)
# NOTE(review): packages are unpinned, so each fresh runtime installs the latest
# releases; pin versions if the repo's mGPT code breaks against newer APIs.
%pip install -q pytorch_lightning torchmetrics omegaconf shortuuid transformers einops rich matplotlib sentencepiece


## 2. Verify GPU


In [None]:
import torch

# Report whether a CUDA device is visible; the inference cells move the model
# to CUDA, so a GPU runtime is required.
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("⚠️ No GPU detected! Go to Runtime → Change runtime type → GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.2f} GB")


## 3. Enter Your Custom Text


In [None]:
# Enter your text here - you can modify this!
# Each string is processed on its own and saved as a separate result file.
custom_texts = [
    "Hello, how are you today?",
    "Thank you for your help.",
    "I am learning sign language.",
]

# Or enter a single text
# custom_texts = ["Your custom text here"]

# Echo the inputs as a 1-based numbered list.
print("Input texts:")
for number, sentence in enumerate(custom_texts, start=1):
    print(f"{number}. {sentence}")


## 4. Run Inference on Your Text


In [None]:
import os
import pickle
import sys
from pathlib import Path

import pytorch_lightning as pl
import torch
import yaml
from mGPT.config import parse_args
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae
from mGPT.utils.logger import create_logger

# The config below targets GPU 0 and the model is moved to CUDA at the end of
# this cell; fail fast here instead of after the slow data/model build.
if not torch.cuda.is_available():
    raise RuntimeError(
        "No GPU detected! Go to Runtime → Change runtime type → GPU, then re-run this notebook."
    )

# Configure paths: start from the repo's SOKE config and point the dataset,
# normalization statistics and VAE tokenizer at the Drive folder (drive_data,
# defined in the setup cell).
with open('configs/soke.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

config['ACCELERATOR'] = 'gpu'
config['DEVICE'] = [0]
config['DATASET']['H2S']['ROOT'] = f'{drive_data}/data/How2Sign'
config['DATASET']['H2S']['MEAN_PATH'] = f'{drive_data}/smpl-x/mean.pt'
config['DATASET']['H2S']['STD_PATH'] = f'{drive_data}/smpl-x/std.pt'
config['TRAIN']['PRETRAINED_VAE'] = f'{drive_data}/checkpoints/vae/tokenizer.ckpt'

# sort_keys=False keeps the generated file in soke.yaml's order so it is easy
# to diff against the original.
with open('configs/text_inference.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(config, f, sort_keys=False)

# Update assets: point the TM2T evaluator path at Drive as well.
with open('configs/assets.yaml', 'r', encoding='utf-8') as f:
    assets = yaml.safe_load(f)
assets['METRIC']['TM2T']['t2m_path'] = f'{drive_data}/deps/t2m/t2m/'
with open('configs/assets_inference.yaml', 'w', encoding='utf-8') as f:
    yaml.dump(assets, f, sort_keys=False)

# Parse config: parse_args is driven by command-line options, so fake them via
# sys.argv to select the two generated files.
sys.argv = ['', '--cfg', 'configs/text_inference.yaml', '--cfg_assets', 'configs/assets_inference.yaml']
cfg = parse_args(phase="test")
cfg.FOLDER = cfg.TEST.FOLDER

# Seed for reproducible generation.
pl.seed_everything(cfg.SEED_VALUE)

# Build data and model
print("Loading model...")
datamodule = build_data(cfg)
model = build_model(cfg, datamodule)

# Load checkpoints: the VAE tokenizer first, then the fine-tuned model if any.
logger = create_logger(cfg, phase="test")
if cfg.TRAIN.PRETRAINED_VAE:
    load_pretrained_vae(cfg, model, logger)

# Check for trained checkpoint
ckpt_path = f'{drive_data}/experiments/mgpt/SOKE/checkpoints/last.ckpt'
if os.path.exists(ckpt_path):
    print(f"Loading trained checkpoint from {ckpt_path}")
    cfg.TEST.CHECKPOINTS = ckpt_path
    load_pretrained(cfg, model, logger, phase="test")
else:
    print("Using pretrained mBART (no fine-tuned checkpoint found)")
    # Tell the user where the checkpoint was looked for, since outputs of a
    # model that was never fine-tuned are unlikely to be meaningful.
    print(f"⚠️ Expected a fine-tuned checkpoint at {ckpt_path}; predictions may not be meaningful without it")

model = model.cuda()
model.eval()

print("✓ Model ready!")


In [None]:
# Generate sign language poses for each entry of custom_texts and pickle one
# result file per text into output_dir. Relies on `model` and `custom_texts`
# from the cells above.
import glob
import os
import pickle
import traceback

import torch


def _to_cpu(value):
    """Return *value* with torch tensors (also nested in dicts/lists/tuples)
    detached and moved to CPU, so the saved pickle loads without a GPU."""
    if torch.is_tensor(value):
        return value.detach().cpu()
    if isinstance(value, dict):
        return {key: _to_cpu(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_cpu(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_to_cpu(item) for item in value)
    return value


output_dir = 'text_sign_results'
os.makedirs(output_dir, exist_ok=True)
# Drop predictions from earlier runs so the results listed and zipped below
# correspond only to the current custom_texts.
for stale_file in glob.glob(os.path.join(output_dir, 'text_*.pkl')):
    os.remove(stale_file)

print(f"\nGenerating sign language for {len(custom_texts)} text(s)...\n")

num_saved = 0
with torch.no_grad():
    for idx, text in enumerate(custom_texts):
        print(f"[{idx+1}/{len(custom_texts)}] Processing: '{text}'")

        # Prepare input (one text per batch).
        # NOTE(review): only 'text' is supplied; confirm t2m_eval in this repo
        # does not also require other batch keys (e.g. target lengths).
        batch = {
            'text': [text]
        }

        # Best effort per text: a failure is reported and the loop moves on.
        try:
            # Generate poses
            output = model.t2m_eval(batch)

            # Keep only the joints when the output is a dict providing them;
            # otherwise save the whole output. The isinstance check avoids
            # evaluating `'joints' in tensor`, which raises for tensors.
            if isinstance(output, dict) and 'joints' in output:
                prediction = output['joints']
            else:
                prediction = output

            # Save result
            filename = f"text_{idx+1}.pkl"
            filepath = os.path.join(output_dir, filename)

            result = {
                'text': text,
                'prediction': _to_cpu(prediction),
                # Word count of the input text (not the number of frames).
                'length': len(text.split())
            }

            with open(filepath, 'wb') as f:
                pickle.dump(result, f)

            num_saved += 1
            print(f"  ✓ Saved to: {filepath}")

        except Exception as e:
            print(f"  ✗ Error: {e}")
            traceback.print_exc()
            continue

if num_saved == len(custom_texts):
    print(f"\n✓ Complete! Predictions saved in '{output_dir}'")
else:
    print(f"\n⚠️ Saved {num_saved}/{len(custom_texts)} prediction(s) in '{output_dir}'")


## 5. View Results


In [None]:
# List generated files
print("Generated predictions:")
!ls -lh {output_dir}

# Load and display results
pkl_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.pkl')])

for pkl_file in pkl_files:
    filepath = os.path.join(output_dir, pkl_file)
    
    with open(filepath, 'rb') as f:
        result = pickle.load(f)
    
    print(f"\\n{pkl_file}:")
    print(f"  Text: {result['text']}")
    print(f"  Words: {result['length']}")
    if hasattr(result['prediction'], 'shape'):
        print(f"  Pose shape: {result['prediction'].shape}")
        print(f"  Frames: {result['prediction'].shape[0]}")


## 6. Download Results


In [None]:
# Zip results for easy download
!zip -r text_sign_results.zip {output_dir}/

# Download
from google.colab import files
files.download('text_sign_results.zip')

print("✓ Results packaged and ready to download")


## Notes

- **GPU Required**: Make sure you're using a GPU runtime (Runtime → Change runtime type → GPU → T4/V100/A100)
- **First Time**: Run notebook 1 first to download all dependencies to your Google Drive
- **Custom Text**: Edit the `custom_texts` list in Section 3 (*Enter Your Custom Text*), then re-run the cells below it
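- **Output**: Each text generates a `.pkl` file (`text_1.pkl`, `text_2.pkl`, …) containing the input text, its word count, and the model's prediction
- **Format**: If the model output includes a `joints` entry, only those predicted joints are stored; otherwise the full model output is saved. Predictions are based on the SMPL-X body model and can be visualized using 3D animation tools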

### Troubleshooting
- **OOM Error**: Shorten the input text (texts are already processed one at a time), or switch to a GPU with more memory
- **Missing files**: Make sure notebook 1 was run successfully to download models
- **Slow generation**: Normal on T4 GPU, faster on V100/A100
