In [None]:
import os
import sys
import tensorflow as tf

# 1. Environment Detection & Drive Mounting
# `google.colab` can only be imported inside a Colab runtime, so a failed
# import is used as the signal that the notebook is running locally.
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running in Google Colab. Mounting Drive...")
    drive.mount('/content/drive')
except ImportError:
    IN_COLAB = False
    print("Running Locally.")

# 2. Path Setup
# The project root must be on sys.path so that the 'src' package is importable.
if IN_COLAB:
    # The root is located by content, not by folder name: it is the first
    # directory under SEARCH_PATH holding both a 'src' folder and a
    # 'requirements.txt' file. Change SEARCH_PATH if the project is not
    # stored under "My Drive".
    SEARCH_PATH = '/content/drive/MyDrive'
    PROJECT_DIR = None

    for root, dirs, files in os.walk(SEARCH_PATH):
        if 'src' in dirs and 'requirements.txt' in files:
            PROJECT_DIR = root
            break
        # Prune hidden folders in place (.Trash, .git, .ipynb_checkpoints, ...)
        # so os.walk never descends into them: this keeps the scan fast and
        # avoids matching a deleted or cached copy of the project.
        dirs[:] = [d for d in dirs if not d.startswith('.')]

    if PROJECT_DIR:
        print(f"Project root found at: {PROJECT_DIR}")
        os.chdir(PROJECT_DIR)
        # Guarded so that re-running this cell does not stack duplicate entries.
        if PROJECT_DIR not in sys.path:
            sys.path.append(PROJECT_DIR)
    else:
        raise FileNotFoundError("Could not find project root in Drive. Make sure 'headway-prediction' is synced.")
else:
    # Assume we are in 'notebooks/' and project root is one level up
    PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
    # Guarded so that re-running this cell does not stack duplicate entries.
    if PROJECT_DIR not in sys.path:
        sys.path.append(PROJECT_DIR)
    print(f"Running locally from: {PROJECT_DIR}")

In [None]:
from tensorflow.keras import mixed_precision

# Enable Mixed Precision to speed up training on T4/A100/V100 GPUs.
# The 'mixed_float16' policy computes in float16 while keeping variables in
# float32 (both dtypes are printed below). Expected speedup per the original
# note: 2x-3x.
#
# The float16 policy is only applied when TensorFlow can see a GPU. This
# notebook also supports running locally, and on CPU float16 compute is
# slower than float32, so the default 'float32' policy is kept there.
gpu_devices = tf.config.list_physical_devices('GPU')
policy_name = 'mixed_float16' if gpu_devices else 'float32'
if not gpu_devices:
    print("No GPU detected; keeping the float32 policy instead of mixed_float16.")

policy = mixed_precision.Policy(policy_name)
mixed_precision.set_global_policy(policy)

print(f"Global Mixed Precision Policy set to: {policy.name}")
print("Compute dtype:", policy.compute_dtype)
print("Variable dtype:", policy.variable_dtype)

In [None]:
from src.config import Config
from src.data.dataset import SubwayDataGenerator
from src.models.st_covnet import HeadwayConvLSTM

# 3. DEBUG FLAG
#   True  -> quick 30-second "Smoke Test"
#   False -> full production training
DEBUG_MODE = False
print(f"--- DEBUG MODE: {DEBUG_MODE} ---")

# Initialize Config
config = Config()

if DEBUG_MODE:
    print("WARNING: Overriding config for fast smoke test.")
    # Smoke-test overrides, applied in order: a smaller batch, a single
    # epoch to prove the end-to-end flow, and a tiny filter count for speed.
    for _override_name, _override_value in (
        ("BATCH_SIZE", 32),
        ("EPOCHS", 1),
        ("FILTERS", 16),
    ):
        setattr(config, _override_name, _override_value)

    # On Drive the working directory was already switched to the project
    # root by the setup cell; the data directory is echoed as a sanity check.
    print(f"Data Dir: {config.DATA_DIR}")

# 4. Load Data
gen = SubwayDataGenerator(config)
gen.load_data()

# 5. Create Datasets
# Each branch only decides the (start, end) index pair for the train and
# validation windows; the datasets themselves are built once, afterwards.
if not DEBUG_MODE:
    # Full data split. Requested split: 60% Train, 20% Val, 20% Test.
    # Only the train and validation windows are materialised in this cell.
    total_len = len(gen.headway_data)
    train_end = int(total_len * 0.6)
    val_end = int(total_len * 0.8)

    print(f"Creating FULL datasets (N={total_len})...")
    print(f"Train: 0-{train_end} | Val: {train_end}-{val_end}")

    train_span = (0, train_end)
    val_span = (train_end, val_end)
else:
    # Tiny fixed slice: first 1000 samples for training, next 200 for validation.
    print("Creating tiny DEBUG dataset...")
    train_span = (0, 1000)
    val_span = (1000, 1200)

train_ds = gen.make_dataset(start_index=train_span[0], end_index=train_span[1])
val_ds = gen.make_dataset(start_index=val_span[0], end_index=val_span[1])

print("Datasets ready.")

In [None]:
# 6. Build and Compile Model
from src.metrics import rmse_seconds, mae_seconds

# Build the network from the shared config via the project's builder class.
model_builder = HeadwayConvLSTM(config)
model = model_builder.build_model()

# Compile settings gathered in one place: Adam at the configured learning
# rate, MSE loss, and Keras' built-in 'mae' alongside the project's
# rmse_seconds / mae_seconds metric functions.
compile_kwargs = {
    "optimizer": tf.keras.optimizers.Adam(learning_rate=config.LEARNING_RATE),
    "loss": 'mse',
    "metrics": ['mae', rmse_seconds, mae_seconds],
}
model.compile(**compile_kwargs)

# The layer-by-layer summary is only printed during smoke tests.
if DEBUG_MODE:
    model.summary()

In [None]:
# 8. Evaluation
from src.evaluator import Evaluator

# No earlier cell in this notebook assigns `history` (the step numbering jumps
# from 6 to 8), so on a fresh kernel the loss plot below would fail with a
# NameError. If no training run has produced it yet, train here so that
# Restart & Run All works; an existing `history` is reused untouched.
# NOTE(review): assumes train_ds/val_ds yield (inputs, targets) pairs accepted
# by model.fit and that config.EPOCHS is the intended epoch count - confirm,
# and restore any callbacks the original training cell used.
if 'history' not in globals():
    print("No training history found; training the model before evaluation...")
    history = model.fit(train_ds, validation_data=val_ds, epochs=config.EPOCHS)

print("running evaluation checks...")
evaluator = Evaluator(config)

# 1. Plot Loss Curves
evaluator.plot_loss(history)

# 2. Visualize Predictions
# Per the original note, this uses the updated method in src/evaluator.py
# that supports the "Micrograph" style. It is called with the model, the
# validation dataset, and sample_idx=0.
print("Generating spatiotemporal visualization...")
evaluator.plot_spatiotemporal_prediction(model, val_ds, sample_idx=0)