# SOKE Stage 2: Train Sign Language Generator
Trains the mBART-based multilingual sign language generator using tokenized poses.


In [None]:
# Clone repo if not present
import os
if not os.path.exists('/content/SaSOKE'):
    !git clone https://github.com/YOUR_USERNAME/SaSOKE.git
%cd /content/SaSOKE

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

drive_data = '/content/drive/MyDrive/SOKE_data'
print("Code:", os.getcwd())
print("Data:", drive_data)


## Prerequisites
Ensure the tokenizer has been trained, or that a pretrained checkpoint is available, at `checkpoints/vae/tokenizer.ckpt` inside the Drive data folder (`MyDrive/SOKE_data`).


In [None]:
# Stop early if the tokenizer checkpoint is missing from the Drive data folder
tokenizer_ckpt = f'{drive_data}/checkpoints/vae/tokenizer.ckpt'
assert os.path.exists(tokenizer_ckpt), "Tokenizer not found in Drive!"
print("Tokenizer checkpoint found in Drive")


## Configuration Setup


In [None]:
# Derive a Colab-specific config from configs/soke.yaml: select the GPU,
# point dataset/model paths at the Drive data folder, set the loader and
# batch sizes, and write the result to configs/soke_colab.yaml.
import yaml

with open('configs/soke.yaml', 'r') as cfg_in:
    config = yaml.safe_load(cfg_in)

# Run on GPU device 0
config.update({'ACCELERATOR': 'gpu', 'DEVICE': [0]})

# How2Sign dataset root and the mean/std files are read from Drive
config['DATASET']['H2S'].update({
    'ROOT': f'{drive_data}/data/How2Sign',
    'MEAN_PATH': f'{drive_data}/smpl-x/mean.pt',
    'STD_PATH': f'{drive_data}/smpl-x/std.pt',
})

# Tokenizer checkpoint in Drive, plus loader workers and batch size
config['TRAIN'].update({
    'PRETRAINED_VAE': f'{drive_data}/checkpoints/vae/tokenizer.ckpt',
    'NUM_WORKERS': 2,
    'BATCH_SIZE': 16,
})

# Language-model weights are also loaded from Drive
config['model']['params']['lm_path'] = f'{drive_data}/deps/mbart-h2s-csl-phoenix'

# Persist the modified config next to the original
with open('configs/soke_colab.yaml', 'w') as cfg_out:
    yaml.dump(config, cfg_out)

print("Config updated - GitHub code + Drive data")


## Train SOKE Model


In [None]:
# Start training with the Colab config written to configs/soke_colab.yaml.
# Both the `train` module lookup and the relative config path depend on the
# current working directory, so run this from the repo root (/content/SaSOKE).
# The `!` shell escape blocks until the process exits, so the following cells
# only run once training has finished or been interrupted.
# NOTE(review): --nodebug presumably turns off the project's debug mode - confirm in the train script.
!python -m train --cfg configs/soke_colab.yaml --nodebug


## Monitor Training


In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline.
# NOTE(review): the training cell uses a blocking `!python` call, so to watch
# metrics live this cell has to be run before training is started; run
# afterwards it only shows the completed run.
# NOTE(review): assumes training writes its event files under
# experiments/mgpt/SOKE/ relative to the repo root - confirm against the
# output folder configured in the YAML.
%load_ext tensorboard
%tensorboard --logdir experiments/mgpt/SOKE/


## Test Model
Run inference after training completes.


In [None]:
# Run inference on the test set using the same Colab config as training.
# Must be run from the repo root (/content/SaSOKE): `python -m test` searches
# the current directory first, and the config path is relative.
# NOTE(review): this notebook never sets a trained-checkpoint path in the
# config - confirm how the test script locates the weights to evaluate.
# NOTE(review): `--task t2m` presumably selects the text-to-motion task - confirm in the test script.
!python -m test --cfg configs/soke_colab.yaml --task t2m
