<a href="https://colab.research.google.com/github/Sanal-Live2/SMTPD/blob/main/SMTPD_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SMTPD: Social Media Temporal Popularity Prediction

**Paper:** https://arxiv.org/abs/2503.04446

This notebook trains a multi-modal LSTM model to predict social media popularity over 30 days.

## Features:
- ✅ Works with pre-created subset datasets
- ✅ Automatic checkpoint saving/loading
- ✅ Resume after Colab disconnection
- ✅ GPU monitoring
- ✅ Progress tracking

## Estimated Training Time (based on dataset size):
- 5% subset (~14k samples): 30-45 minutes
- 10% subset (~28k samples): 1-2 hours
- 25% subset (~70k samples): 2.5-3.5 hours
- 50% subset (~141k samples): 5-7 hours
- Full dataset (282k samples): 10-15 hours (needs 2 sessions)

## 1. Environment Setup

In [1]:
# Check GPU availability
!nvidia-smi
import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Fri Nov 21 09:27:03 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   70C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Install required packages
!pip install -q transformers langid fasttext-wheel scikit-learn
print("✅ All packages installed!")


## 2. Mount Google Drive & Setup Data Paths

**Before running:** Upload your subset folder to Google Drive:
```
MyDrive/SMTPD_data/
├── subset_5_percent/          (or subset_10_percent, etc.)
│   ├── basic_view_pn_5percent.csv
│   └── img_yt/
│       └── (your images)
├── bert_multilingual/         (BERT model - will download if not present)
└── checkpoints/               (created automatically)
```

**How to create a subset:**
- Use `create_matched_subset.py` locally
- Upload the entire subset folder to Drive
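
`create_matched_subset.py` is not shown in this notebook, so the following is only a sketch of what a matched subset could look like: sample a fraction of the CSV rows and copy just the images that belong to them. The function name, the `id_column` parameter, and the `<id>.jpg` naming are assumptions, not the script's actual interface.

```python
import csv
import random
import shutil
from pathlib import Path


def make_matched_subset(csv_path, img_dir, out_dir, ratio, id_column, seed=23):
    """Sample `ratio` of the CSV rows and copy only their images (hypothetical helper)."""
    csv_path, img_dir, out_dir = Path(csv_path), Path(img_dir), Path(out_dir)
    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames
        rows = list(reader)

    # Keep only rows whose image is present, so the CSV and image folder match.
    rows = [r for r in rows if (img_dir / f"{r[id_column]}.jpg").exists()]
    k = min(len(rows), max(1, round(len(rows) * ratio)))
    sample = random.Random(seed).sample(rows, k)

    out_img = out_dir / "img_yt"
    out_img.mkdir(parents=True, exist_ok=True)
    with (out_dir / csv_path.name).open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(sample)
    for r in sample:
        name = f"{r[id_column]}.jpg"
        shutil.copy2(img_dir / name, out_img / name)
    return len(sample)
```

The sketch keeps the source CSV's file name; the path cell below expects `basic_view_pn_<N>percent.csv`, so rename the output accordingly.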

In [3]:
from google.colab import drive
drive.mount('/content/drive')

# ========================================
# 📝 CONFIGURE YOUR PATHS HERE
# ========================================

# Set your subset folder name (change this to match your uploaded folder)
SUBSET_FOLDER = 'subset_5_percent'  # Change to: subset_10_percent, subset_25_percent, etc.

# Base directory in Google Drive
DATA_DIR = '/content/drive/MyDrive/SMTPD_data'

# Dataset paths (automatically configured based on SUBSET_FOLDER)
SUBSET_DIR = f'{DATA_DIR}/{SUBSET_FOLDER}'
CSV_PATH = f'{SUBSET_DIR}/basic_view_pn_{SUBSET_FOLDER.split("_")[1]}percent.csv'
IMG_DIR = f'{SUBSET_DIR}/img_yt'

# Shared resources
BERT_PATH = f'{DATA_DIR}/bert_multilingual'
CKPT_DIR = f'{DATA_DIR}/checkpoints'

# Create checkpoint directory
!mkdir -p {CKPT_DIR}

print("="*70)
print("📂 Data Paths Configuration")
print("="*70)
print(f"Subset folder:  {SUBSET_FOLDER}")
print(f"CSV file:       {CSV_PATH}")
print(f"Images dir:     {IMG_DIR}")
print(f"BERT model:     {BERT_PATH}")
print(f"Checkpoints:    {CKPT_DIR}")
print("="*70)

# Verify files exist
import os
print("\n🔍 Verifying files...")
csv_exists = os.path.exists(CSV_PATH)
img_exists = os.path.exists(IMG_DIR)
bert_exists = os.path.exists(BERT_PATH)

print(f"✅ CSV exists: {csv_exists}")
print(f"✅ Images exist: {img_exists}")
print(f"✅ BERT model exists: {bert_exists}")

if csv_exists and img_exists:
    import pandas as pd
    df = pd.read_csv(CSV_PATH)
    csv_samples = len(df)
    img_count = len([f for f in os.listdir(IMG_DIR) if f.endswith('.jpg')])
    print(f"\n📊 Dataset Info:")
    print(f"   CSV rows: {csv_samples:,}")
    print(f"   Images: {img_count:,}")
    print(f"   Match: {'✅ Perfect!' if csv_samples == img_count else '⚠️ Mismatch'}")
else:
    print("\n⚠️ Error: Files not found!")
    print(f"\nMake sure you uploaded '{SUBSET_FOLDER}' folder to:")
    print(f"   {DATA_DIR}/")

if not bert_exists:
    print("\n⚠️ BERT model not found. Will download in next step...")

Mounted at /content/drive
📂 Data Paths Configuration
Subset folder:  subset_5_percent
CSV file:       /content/drive/MyDrive/SMTPD_data/subset_5_percent/basic_view_pn_5percent.csv
Images dir:     /content/drive/MyDrive/SMTPD_data/subset_5_percent/img_yt
BERT model:     /content/drive/MyDrive/SMTPD_data/bert_multilingual
Checkpoints:    /content/drive/MyDrive/SMTPD_data/checkpoints

🔍 Verifying files...
✅ CSV exists: True
✅ Images exist: True
✅ BERT model exists: False

📊 Dataset Info:
   CSV rows: 14,124
   Images: 14,123
   Match: ⚠️ Mismatch

⚠️ BERT model not found. Will download in next step...
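
The verification output reports 14,124 CSV rows against 14,123 `.jpg` files, so at least one row has no matching image (or its image uses a different extension). A sketch for listing the offending IDs follows. It assumes images are named `<id>.jpg` and that the ID column is supplied by the caller; this notebook confirms neither.

```python
import os


def find_mismatches(ids, img_dir, ext=".jpg"):
    """Return (ids without an image, image stems without a CSV row)."""
    on_disk = {os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(ext)}
    in_csv = {str(i) for i in ids}
    return sorted(in_csv - on_disk), sorted(on_disk - in_csv)


# Usage in the notebook (the column name "video_id" is a placeholder):
# missing, orphaned = find_mismatches(df["video_id"], IMG_DIR)
```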


ERROR: Could not find a version that satisfies the requirement google-colab (from versions: none)
ERROR: No matching distribution found for google-colab

In [4]:
# Download BERT model if not present
import os
if not os.path.exists(BERT_PATH):
    print("📥 Downloading BERT multilingual model (~700MB)...")
    print("This will take 3-5 minutes...\n")

    !pip install -q huggingface_hub
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id="bert-base-multilingual-cased",
        local_dir=BERT_PATH,
        local_dir_use_symlinks=False
    )
    print("\n✅ BERT model downloaded successfully!")
else:
    print("✅ BERT model already exists, skipping download.")

📥 Downloading BERT multilingual model (~700MB)...
This will take 3-5 minutes...



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.


Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.10k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

tf_model.h5:   0%|          | 0.00/1.08G [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/714M [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/712M [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]


✅ BERT model downloaded successfully!
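
The progress bars above show the snapshot fetching TensorFlow (`tf_model.h5`, 1.08 GB) and Flax (`flax_model.msgpack`, 712 MB) weights alongside two copies of the PyTorch weights, well beyond the ~700 MB the cell announces. One possible way to trim the download is the `allow_patterns` argument of `snapshot_download`. The pattern list below is an assumption about what `smp_model.py` needs (config, tokenizer files, and one PyTorch weight file); check it against how the model is actually loaded. The `is_kept` preview uses `fnmatch`, which should approximate the hub's own filtering.

```python
from fnmatch import fnmatch

# Assumed minimal file set for loading BERT with PyTorch via transformers.
KEEP_PATTERNS = ["config.json", "vocab.txt", "tokenizer*.json", "model.safetensors"]


def is_kept(filename, patterns=KEEP_PATTERNS):
    """Preview whether a repo file would be selected by the patterns."""
    return any(fnmatch(filename, p) for p in patterns)


def download_bert_pytorch_only(local_dir):
    """Download only the files matching KEEP_PATTERNS into local_dir."""
    # Imported lazily so the preview helper works without network access.
    from huggingface_hub import snapshot_download
    return snapshot_download(
        repo_id="bert-base-multilingual-cased",
        local_dir=local_dir,
        allow_patterns=KEEP_PATTERNS,
    )


for name in ["model.safetensors", "tf_model.h5", "flax_model.msgpack", "tokenizer.json"]:
    print(name, is_kept(name))
```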


## 3. Clone Repository & Setup Code

In [5]:
# Clone the repository
!git clone https://github.com/zhuwei321/SMTPD.git
%cd SMTPD

# List files
!ls -lh *.py | head -15

Cloning into 'SMTPD'...
remote: Enumerating objects: 48, done.[K
remote: Counting objects: 100% (48/48), done.[K
remote: Compressing objects: 100% (47/47), done.[K
remote: Total 48 (delta 21), reused 2 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (48/48), 348.75 KiB | 1.24 MiB/s, done.
Resolving deltas: 100% (21/21), done.
/content/SMTPD
-rw-r--r-- 1 root root 3.1K Nov 21 09:30 cb.py
-rw-r--r-- 1 root root 3.0K Nov 21 09:30 dataset_split.py
-rw-r--r-- 1 root root  12K Nov 21 09:30 fe.py
-rw-r--r-- 1 root root 2.4K Nov 21 09:30 img_downloader.py
-rw-r--r-- 1 root root 7.8K Nov 21 09:30 main_bili.py
-rw-r--r-- 1 root root  14K Nov 21 09:30 main.py
-rw-r--r-- 1 root root 8.8K Nov 21 09:30 rewrite_s.py
-rw-r--r-- 1 root root 3.9K Nov 21 09:30 smp_data.py
-rw-r--r-- 1 root root  25K Nov 21 09:30 smp_model.py
-rw-r--r-- 1 root root 8.0K Nov 21 09:30 tools.py
-rw-r--r-- 1 root root  27K Nov 21 09:30 utils.py


## 4. Modify Code for Colab

Update the BERT model path in `smp_model.py` so that it points at the copy on Google Drive.

In [6]:
# Patch smp_model.py to use correct BERT path
with open('smp_model.py', 'r') as f:
    content = f.read()

# Replace BERT paths
content = content.replace("'../bert_multilingual'", f"'{BERT_PATH}'")

with open('smp_model.py', 'w') as f:
    f.write(content)

print("✅ Updated BERT paths in smp_model.py")

✅ Updated BERT paths in smp_model.py
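
`str.replace` silently does nothing when the target string is absent, so the success message prints even if upstream `smp_model.py` changes how it references the BERT directory. A small guard, sketched here with a hypothetical helper name, makes that failure visible:

```python
def replace_checked(text, old, new):
    """Replace every occurrence of `old`, raising if there is nothing to replace."""
    count = text.count(old)
    if count == 0:
        raise ValueError(f"pattern not found: {old!r}")
    return text.replace(old, new), count


# Example with a stand-in line; the real file contents may differ.
patched, n = replace_checked(
    "bert = BertModel.from_pretrained('../bert_multilingual')",
    "'../bert_multilingual'",
    "'/content/drive/MyDrive/SMTPD_data/bert_multilingual'",
)
print(n, patched)
```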


In [8]:
%%writefile main_colab.py
# Modified main.py for Google Colab - Direct training on provided dataset
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from smp_model import *
from smp_data import *
from tqdm import tqdm
import csv
from tools import *
from transformers import logging
import warnings
import random
import logging as log
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description='SMTPD Model Trainer')
parser.add_argument('--warm_start_epoch', type=int, default=0)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--epochs', type=int, default=40)
parser.add_argument('--lr', default=1e-2, type=float)

parser.add_argument('--images_dir', type=str, required=True)
parser.add_argument('--gt_path', type=str, default="0")
parser.add_argument('--data_files', type=str, required=True)
parser.add_argument('--seq_len', type=int, default=29)
parser.add_argument('--ckpt_path', type=str, required=True)
parser.add_argument('--result_file', type=str, default='all_result.csv')
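def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "False") as True
    return str(v).strip().lower() in ('true', '1', 'yes')

parser.add_argument('--write', type=str2bool, default=True)
parser.add_argument('--train', type=str2bool, default=False)
parser.add_argument('--test', type=str2bool, default=False)
parser.add_argument('--K_fold', type=int, default=0)
parser.add_argument('--use_mlp', type=str2bool, default=False)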
parser.add_argument('--resume_from', type=str, default=None, help='Path to checkpoint to resume from')

# Import CustomLoss and other functions from original main.py
exec(open('main.py').read().split('if __name__')[0])

def load_data(args, K, n):
    """Load data with K-fold split - trains on full provided dataset"""
    random_seed = 23
    random_generator = random.Random(random_seed)
    data_files = args.data_files

    # Load dataset (uses whatever data is provided in the CSV/images)
    data_set = youtube_data_lstm(data_files, args.images_dir, args.gt_path)
    total_size = len(data_set)

    print(f"\n📊 Dataset loaded: {total_size:,} samples")

    # Create indices and shuffle
    indices = list(range(total_size))
    random_generator.shuffle(indices)

    batch_size = args.batch_size
    fold_size = total_size // K

    # K-fold split
    val_start = n * fold_size
    val_end = (n + 1) * fold_size

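    val_indices = indices[val_start:val_end]
    test_indices = val_indices  # note: the test split reuses the validation fold
    val_index_set = set(val_indices)  # set lookup keeps the filter linear in dataset size
    train_indices = [i for i in indices if i not in val_index_set]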

    train_set = Subset(data_set, train_indices)
    val_set = Subset(data_set, val_indices)
    test_set = Subset(data_set, test_indices)

    print(f"📈 Train: {len(train_indices):,} | Val: {len(val_indices):,} | Test: {len(test_indices):,}")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                             num_workers=args.num_workers, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                           num_workers=args.num_workers, drop_last=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                            num_workers=args.num_workers, drop_last=True)

    return train_loader, val_loader, test_loader

if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    args = parser.parse_args()

    print("="*60)
    print("🚀 SMTPD Training Starting...")
    print("="*60)
    print(f"Device: {device}")
    print(f"K-fold: {args.K_fold}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Sequence length: {args.seq_len} days")
    print(f"Learning rate: {args.lr}")
    print("="*60)

    # Load data
    train_loader, val_loader, test_loader = load_data(args, 5, args.K_fold)

    # Create model
    if args.use_mlp:
        model = youtube_MLP(args.seq_len, args.batch_size)
    else:
        model = youtube_lstm3(args.seq_len, args.batch_size)

    # Resume from checkpoint if specified
    if args.resume_from and os.path.exists(args.resume_from):
        print(f"\n📂 Loading checkpoint: {args.resume_from}")
        model_dict = torch.load(args.resume_from, map_location=device)
        model.load_state_dict(model_dict)
        print("✅ Checkpoint loaded successfully!")
    elif args.test:
        import glob
        model_files = glob.glob(os.path.join(args.ckpt_path, str(args.K_fold) + "*.pth"))
        if model_files:
            # Newest file by modification time; a plain string sort ranks "0-9-..." after "0-15-..."
            latest_file = max(model_files, key=os.path.getmtime)
            model_dict = torch.load(latest_file, map_location=device)
            model.load_state_dict(model_dict)
            print(f'✅ Loaded model: {latest_file}')

    model = model.to(device)

    if args.train:
        train(args, model, train_loader, val_loader)
    elif args.test:
        test(args, model, test_loader)
    else:
        print("⚠️ Please choose --train=True or --test=True")

Overwriting main_colab.py
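
Boolean command-line flags deserve care in a script like this. `argparse` applies `type` as a plain function call, and `bool("False")` is `True` because any non-empty string is truthy, so a flag declared with `type=bool` cannot be switched off with `--flag=False`. The sketch below contrasts that with an explicit string-to-bool converter; the stricter `str2bool` shown here is illustrative and not taken from the SMTPD repository.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings; reject anything else."""
    text = str(value).strip().lower()
    if text in {"true", "1", "yes", "y"}:
        return True
    if text in {"false", "0", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


naive = argparse.ArgumentParser()
naive.add_argument("--use_mlp", type=bool, default=False)

fixed = argparse.ArgumentParser()
fixed.add_argument("--use_mlp", type=str2bool, default=False)

print(naive.parse_args(["--use_mlp=False"]).use_mlp)  # True: "False" is a non-empty string
print(fixed.parse_args(["--use_mlp=False"]).use_mlp)  # False
```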


## 5. Training Configuration

In [20]:
# ========================================
# 🎯 TRAINING CONFIGURATION
# ========================================

# K-fold cross-validation
K_FOLD = 0          # Which fold to use (0-4)

# Model settings
SEQ_LEN = 29        # Days to predict (29 = full 30 days)
USE_MLP = False     # False = LSTM (recommended), True = MLP

# Training hyperparameters
EPOCHS = 40      # Max epochs (early stopping usually stops at ~15-20)
BATCH_SIZE = 64    # Batch size (reduce to 32 if OOM)
LEARNING_RATE = 1e-2  # Learning rate

# Resume from checkpoint (if Colab disconnected)
RESUME_FROM = None  # Set to checkpoint path like: f'{CKPT_DIR}/0-15-0.2345.pth'

# ========================================

print("="*70)
print("📋 Training Configuration")
print("="*70)
print(f"Dataset:        {SUBSET_FOLDER}")
print(f"K-fold:         {K_FOLD}")
print(f"Epochs:         {EPOCHS}")
print(f"Batch size:     {BATCH_SIZE}")
print(f"Sequence len:   {SEQ_LEN} days")
print(f"Learning rate:  {LEARNING_RATE}")
print(f"Model:          {'MLP' if USE_MLP else 'LSTM'}")
if RESUME_FROM:
    print(f"Resume from:    {RESUME_FROM}")
print("="*70)

# Estimate training time based on dataset size
import pandas as pd
import os
if os.path.exists(CSV_PATH):
    df = pd.read_csv(CSV_PATH)
    samples = len(df)
    batches_per_epoch = (samples * 0.8) / BATCH_SIZE  # 80% for training
    time_per_epoch = batches_per_epoch * 0.4 / 60  # minutes, assuming ~0.4s per batch
    total_time = time_per_epoch * 15  # minutes, assuming early stopping at epoch 15

    print(f"\n⏱️  Estimated Training Time:")
    print(f"   Samples:          {samples:,}")
    print(f"   Batches/epoch:    {int(batches_per_epoch)}")
    print(f"   Time per epoch:   ~{time_per_epoch:.1f} minutes")
    print(f"   Total (est.):     ~{total_time:.1f} minutes")
    print(f"   Colab session:    {'✅ Fits in 1 session' if total_time < 11 * 60 else '⚠️ Needs 2 sessions'}")
    print("="*70)

📋 Training Configuration
Dataset:        subset_5_percent
K-fold:         0
Epochs:         40
Batch size:     64
Sequence len:   29 days
Learning rate:  0.01
Model:          LSTM

⏱️  Estimated Training Time:
   Samples:          14,124
   Batches/epoch:    176
   Time per epoch:   ~1.2 minutes
   Total (est.):     ~17.7 minutes
   Colab session:    ✅ Fits in 1 session
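
`K_FOLD` selects which of five contiguous slices of the shuffled indices becomes the validation set; the remaining indices are used for training. The index arithmetic from `load_data` can be tried in isolation like this (a standalone sketch with an illustrative function name, using the same seed of 23):

```python
import random


def kfold_indices(total_size, k, n, seed=23):
    """Return (train, val) index lists for fold n of k after a seeded shuffle."""
    indices = list(range(total_size))
    random.Random(seed).shuffle(indices)
    fold_size = total_size // k
    val = indices[n * fold_size:(n + 1) * fold_size]
    val_set = set(val)
    train = [i for i in indices if i not in val_set]
    return train, val


train, val = kfold_indices(total_size=14124, k=5, n=0)
print(len(train), len(val))  # 11300 2824
```

Because `fold_size` is floored, the last `total_size % k` shuffled samples (4 of 14,124 here) never land in any validation fold.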


## 6. Start Training

**Important:**
- Checkpoints are saved to Google Drive automatically
- Training stops early if validation doesn't improve for 5 epochs
- Monitor GPU usage in next section
- If Colab disconnects, you can resume from checkpoint
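
The early-stopping rule mentioned above is handled by the repository's training code, which this notebook does not display. As a generic illustration of a patience-based stop (not the repository's implementation), with `patience=5` matching the description:

```python
class EarlyStopper:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=5)
losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.63, 0.64, 0.65]  # made-up values
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}, best val loss {stopper.best}")
        break
```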

In [None]:
# Build the training command
cmd = f"""python main_colab.py \
    --train=True \
    --K_fold={K_FOLD} \
    --seq_len={SEQ_LEN} \
    --epochs={EPOCHS} \
    --batch_size={BATCH_SIZE} \
    --lr={LEARNING_RATE} \
    --use_mlp={USE_MLP} \
    --images_dir={IMG_DIR} \
    --data_files={CSV_PATH} \
    --ckpt_path={CKPT_DIR}"""

if RESUME_FROM:
    cmd += f" --resume_from={RESUME_FROM}"

print("🚀 Starting training...\n")
!{cmd}

🚀 Starting training...

2025-11-21 09:56:16.520599: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763718976.541786    7996 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763718976.548702    7996 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763718976.565753    7996 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763718976.565778    7996 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763718976.565783    7996 computation_placer.cc:

## 7. Monitor Training

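Colab executes one cell at a time, so these cells run once the training cell has finished or been interrupted. Use them to inspect GPU state, saved checkpoints, and the training log.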

In [None]:
# Monitor GPU usage (run this periodically)
!nvidia-smi --query-gpu=utilization.gpu,memory.used,memory.total,temperature.gpu --format=csv

# Check saved checkpoints
import glob
import os
checkpoints = sorted(glob.glob(f'{CKPT_DIR}/*.pth'))
print(f"\n📂 Saved checkpoints ({len(checkpoints)}):")
for ckpt in checkpoints[-5:]:  # Show last 5
    size_mb = os.path.getsize(ckpt) / (1024*1024)
    print(f"   {os.path.basename(ckpt):40s} ({size_mb:.1f} MB)")

In [None]:
# View training logs
log_file = f'{CKPT_DIR}/train_{K_FOLD}.log'
if os.path.exists(log_file):
    print(f"📄 Last 50 lines of training log:\n")
    !tail -50 {log_file}
else:
    print("⚠️ Log file not created yet. Training may not have started.")

## 8. Resume Training (If Disconnected)

If Colab disconnects:
1. Remount Google Drive (run Section 2)
2. Run cell below to find latest checkpoint
3. Update `RESUME_FROM` in Section 5
4. Re-run training cell (Section 6)

In [None]:
# Find latest checkpoint to resume from
import glob
import os

checkpoints = glob.glob(f'{CKPT_DIR}/{K_FOLD}-*.pth')
if checkpoints:
    # Pick by modification time; sorting names as strings ranks "0-9-..." after "0-15-..."
    latest = max(checkpoints, key=os.path.getmtime)
    print("✅ Latest checkpoint found!\n")
    print(f"Checkpoint: {latest}")
    print(f"\nTo resume training:")
    print(f"1. Go to Section 5 (Configuration)")
    print(f"2. Set: RESUME_FROM = '{latest}'")
    print(f"3. Re-run Section 6 (Start Training)")
else:
    print("⚠️ No checkpoints found.")
    print("Training may not have started yet or no checkpoints were saved.")
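
When choosing a checkpoint by file name, note that names such as `0-15-0.2345.pth` do not sort chronologically as strings: `0-9-…` ranks after `0-15-…`. The sketch below parses the epoch numerically instead. The `fold-epoch-loss.pth` pattern is inferred from the example in Section 5 and may not match the repository's naming exactly.

```python
import os
import re

CKPT_RE = re.compile(r"^(\d+)-(\d+)-([0-9.]+)\.pth$")  # assumed: fold-epoch-loss.pth


def parse_ckpt(path):
    """Return (fold, epoch, loss) parsed from the file name, or None if it does not match."""
    m = CKPT_RE.match(os.path.basename(path))
    return (int(m.group(1)), int(m.group(2)), float(m.group(3))) if m else None


def latest_by_epoch(paths, fold):
    """Pick the checkpoint with the highest epoch number for the given fold."""
    parsed = [(parse_ckpt(p), p) for p in paths]
    matches = [(info[1], p) for info, p in parsed if info and info[0] == fold]
    return max(matches)[1] if matches else None


names = ["0-9-0.2500.pth", "0-15-0.2345.pth", "0-2-0.4000.pth"]
print(sorted(names)[-1])          # 0-9-0.2500.pth (string order)
print(latest_by_epoch(names, 0))  # 0-15-0.2345.pth (numeric epoch order)
```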

## 9. Test Trained Model

In [None]:
# Test with best checkpoint
import glob
import os

checkpoints = glob.glob(f'{CKPT_DIR}/{K_FOLD}-*.pth')
if checkpoints:
    # Most recently written checkpoint (assumed best, since training stops early)
    best_ckpt = max(checkpoints, key=os.path.getmtime)
    print(f"🧪 Testing with checkpoint: {os.path.basename(best_ckpt)}\n")

    cmd = f"""python main_colab.py \
        --test=True \
        --K_fold={K_FOLD} \
        --seq_len={SEQ_LEN} \
        --batch_size={BATCH_SIZE} \
        --use_mlp={USE_MLP} \
        --images_dir={IMG_DIR} \
        --data_files={CSV_PATH} \
        --ckpt_path={CKPT_DIR} \
        --resume_from={best_ckpt}"""

    !{cmd}
else:
    print("⚠️ No trained model found. Please train first!")

## 10. Visualize Results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

# Load results
result_file = f'{CKPT_DIR}/all_result.csv'
if os.path.exists(result_file):
    df = pd.read_csv(result_file, header=None, names=['video_id', 'prediction', 'ground_truth'])

    # Parse predictions and ground truth
    import ast
    df['pred'] = df['prediction'].apply(ast.literal_eval)
    df['gt'] = df['ground_truth'].apply(ast.literal_eval)

    # Plot first 5 samples
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    for i in range(min(5, len(df))):
        pred = df.iloc[i]['pred']
        gt = df.iloc[i]['gt']

        axes[i].plot(gt, 'b-', label='Ground Truth', linewidth=2)
        axes[i].plot(pred, 'r--', label='Prediction', linewidth=2)
        axes[i].set_title(f'Sample {i+1}')
        axes[i].set_xlabel('Days')
        axes[i].set_ylabel('Popularity')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)

    # Hide last subplot
    axes[5].axis('off')

    plt.tight_layout()
    plt.savefig(f'{CKPT_DIR}/predictions_plot.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\n✅ Plot saved to: {CKPT_DIR}/predictions_plot.png")
else:
    print("⚠️ No results file found. Run testing first (Section 9)!")
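
Plots of five samples give only a qualitative impression. A rough numeric summary, sketched below in plain Python, is the mean absolute error at each day across all samples. It assumes every `pred` and `gt` entry is a list of the same length, as the parsing above suggests, and it is not necessarily the metric reported in the paper.

```python
def mae_per_day(preds, gts):
    """Mean absolute error at each day index, averaged over samples."""
    if not preds or len(preds) != len(gts):
        raise ValueError("preds and gts must be non-empty and the same length")
    n_days = len(preds[0])
    totals = [0.0] * n_days
    for p, g in zip(preds, gts):
        for d in range(n_days):
            totals[d] += abs(p[d] - g[d])
    return [t / len(preds) for t in totals]


# Toy example with two samples and three days (made-up numbers).
per_day = mae_per_day([[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]],
                      [[1.5, 2.0, 2.0], [2.5, 1.0, 2.0]])
print(per_day)  # [0.5, 0.5, 0.5]
# In the notebook: mae_per_day(df['pred'].tolist(), df['gt'].tolist())
```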

## 11. Download Results

Download trained models and results to your local machine

In [None]:
from google.colab import files
import glob
import os

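print("📥 Preparing downloads...\n")

# Download best checkpoint
checkpoints = glob.glob(f'{CKPT_DIR}/{K_FOLD}-*.pth')
if checkpoints:
    # Most recently written checkpoint; string sorting would rank "0-9-..." after "0-15-..."
    best_ckpt = max(checkpoints, key=os.path.getmtime)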
    print(f"Downloading: {os.path.basename(best_ckpt)}")
    files.download(best_ckpt)

# Download results
result_file = f'{CKPT_DIR}/all_result.csv'
if os.path.exists(result_file):
    print(f"Downloading: all_result.csv")
    files.download(result_file)

# Download plot
plot_file = f'{CKPT_DIR}/predictions_plot.png'
if os.path.exists(plot_file):
    print(f"Downloading: predictions_plot.png")
    files.download(plot_file)

# Download log
log_file = f'{CKPT_DIR}/train_{K_FOLD}.log'
if os.path.exists(log_file):
    print(f"Downloading: train_{K_FOLD}.log")
    files.download(log_file)

print("\n✅ Downloads complete!")

In [17]:
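# Patch main.py to remove the verbose argument from the LR scheduler.
# Run this before Section 6: main_colab.py executes main.py's definitions at start-up.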
with open('main.py', 'r') as f:
    content = f.read()

# Remove verbose=True from ReduceLROnPlateau
content = content.replace(
    'scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=1, verbose=True, min_lr=0)',
    'scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=1, min_lr=0)'
)

with open('main.py', 'w') as f:
    f.write(content)

print("✅ Fixed main.py - removed verbose argument from scheduler")

✅ Fixed main.py - removed verbose argument from scheduler


## 💡 Tips & Troubleshooting

### Speed up training:
1. Use smaller subset (create with `create_matched_subset.py`)
2. Reduce prediction days: `SEQ_LEN = 14`
3. Lower epochs: `EPOCHS = 20`

### If you get OOM (Out of Memory):
1. Reduce batch size: `BATCH_SIZE = 32` or `16`
2. Restart runtime: Runtime → Restart runtime
3. Use smaller subset

### If Colab disconnects:
1. Don't panic! Checkpoints are saved to Google Drive
2. Remount Drive (Section 2)
3. Find latest checkpoint (Section 8)
4. Set `RESUME_FROM` and restart training

### Training not improving:
1. Check learning rate (try `1e-3` or `5e-3`)
2. Verify data loaded correctly
3. Check training logs for errors

### Data not found:
1. Verify subset folder uploaded to Google Drive
2. Check `SUBSET_FOLDER` name in Section 2
3. Ensure Drive is mounted

### Creating subsets:
```bash
# On your local machine:
python create_matched_subset.py --ratio 0.05  # 5%
python create_matched_subset.py --ratio 0.10  # 10%
python create_matched_subset.py --ratio 0.25  # 25%
```

---

**Paper:** https://arxiv.org/abs/2503.04446  
**GitHub:** https://github.com/zhuwei321/SMTPD

Good luck with your training! 🚀