# Train Gated Fusion MLP (Colab)\nRuns one model family (`gated_fusion_mlp`) across the fusion strategies `concat`, `sum_pool`, `max_pool` × horizons `7`, `30` (6 jobs in total) and uploads the model snapshots to S3.\n

### 1) Runtime setup\nSet the Colab runtime to **GPU**: `Runtime` → `Change runtime type` → choose a GPU as the hardware accelerator.\n

In [None]:
# Install the training/upload dependencies into the Colab runtime (quietly).
# NOTE(review): `--upgrade` may replace Colab's preinstalled numpy/pandas/torch
# builds with newer PyPI releases; if that causes version conflicts, drop
# `--upgrade` or pin the versions the repo is known to work with.
!pip -q install --upgrade boto3 pandas numpy pyarrow scikit-learn torch joblib

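### 2) Point to your repo checkout\nThe notebook expects the repo at `/content/video-virality-predictor`; set the `REPO_DIR` environment variable to use a different checkout path.\n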

In [None]:
import os\n\nREPO_DIR = os.environ.get('REPO_DIR', '/content/video-virality-predictor')\nif not os.path.isdir(REPO_DIR):\n    raise FileNotFoundError(f'Repo path not found: {REPO_DIR}')\n%cd $REPO_DIR\n

### 3) Set AWS + training parameters\n

In [None]:
import getpass
import os

MODEL_FAMILY = 'gated_fusion_mlp'
S3_BUCKET = 'clipfarm-prod-us-west-2'  # change if needed
# NOTE(review): the bucket name ends in "us-west-2" but AWS_REGION is
# ca-central-1 — confirm the bucket's actual region and keep the two in sync.
AWS_REGION = 'ca-central-1'            # change if needed
RUN_ID = ''                             # blank -> auto colab-<model>-<timestamp>
SNAPSHOT_PREFIX = 'clipfarm/models/snapshots'
STRATEGIES = 'concat,sum_pool,max_pool'  # comma-separated, passed to --strategies
HORIZONS = '7,30'                        # comma-separated, passed to --horizons
SEED = 42
MAX_EPOCHS = 40
PATIENCE = 6
PROJECTOR_DIM = 128
RANK_METRIC = 'rmse_log'

# AWS credentials: reuse them if they are already exported in this session's
# environment; otherwise prompt for them with hidden input so the keys are
# never typed into (and saved with) the notebook. The launch cell checks that
# both are non-empty before starting any job.
for _aws_var in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
    if not os.environ.get(_aws_var):
        os.environ[_aws_var] = getpass.getpass(f'{_aws_var}: ').strip()

### 4) Run 6 jobs (3 fusion strategies × 2 horizons) for this model\n

In [None]:
import os
import shlex
import subprocess
import sys

# Fail fast on missing configuration (set in section 3) before launching a
# long training run.
if not S3_BUCKET:
    raise ValueError('S3_BUCKET is required')
if not os.environ.get('AWS_ACCESS_KEY_ID') or not os.environ.get('AWS_SECRET_ACCESS_KEY'):
    raise ValueError('AWS credentials are required in environment')

# Relative to the current working directory, i.e. the repo root after the
# `%cd` in section 2; check it up front for a clear error if that was skipped.
TRAIN_SCRIPT = 'Super_Predict/run_model_colab_matrix.py'
if not os.path.isfile(TRAIN_SCRIPT):
    raise FileNotFoundError(
        f'Training script not found: {os.path.abspath(TRAIN_SCRIPT)} '
        '(run section 2 first so the working directory is the repo root)'
    )

cmd = [
    sys.executable, TRAIN_SCRIPT,
    '--model_family', MODEL_FAMILY,
    '--s3_bucket', S3_BUCKET,
    '--s3_region', AWS_REGION,
    '--snapshot_prefix', SNAPSHOT_PREFIX,
    '--strategies', STRATEGIES,
    '--horizons', HORIZONS,
    '--seed', str(SEED),
    '--max_epochs', str(MAX_EPOCHS),
    '--patience', str(PATIENCE),
    '--projector_dim', str(PROJECTOR_DIM),
    '--rank_metric', RANK_METRIC,
]
if RUN_ID:
    cmd.extend(['--run_id', RUN_ID])

# shlex.join quotes each argument, so the echoed command can be pasted into a
# shell as-is.
print('Running:', shlex.join(cmd))

# Relay the child's combined stdout/stderr through print() so training logs
# appear in the cell as they are produced; output a subprocess writes directly
# to the kernel's file descriptors is not guaranteed to be shown by every
# notebook frontend. PYTHONUNBUFFERED keeps the child from block-buffering its
# output into the pipe.
child_env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
with subprocess.Popen(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    encoding='utf-8',
    errors='replace',
    bufsize=1,
    env=child_env,
) as proc:
    try:
        for line in proc.stdout:
            print(line, end='')
    except BaseException:
        # Like subprocess.run: on interrupt (e.g. stopping the cell) kill the
        # child instead of leaving the training job running in the background.
        proc.kill()
        raise

# Preserve the original check=True contract: non-zero exit raises.
if proc.returncode != 0:
    raise subprocess.CalledProcessError(proc.returncode, cmd)