In [None]:
import os
import sys
import tensorflow as tf

# 1. Environment Detection & Drive Mounting
# `google.colab` is only importable inside a Colab runtime, so a failed import
# is used as the signal that the notebook is running locally.
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running in Google Colab. Mounting Drive...")
    drive.mount('/content/drive')
except ImportError:
    IN_COLAB = False
    print("Running Locally.")

# 2. Path Setup
# The project root has to be on sys.path so that later cells can `import src...`.
if IN_COLAB:
    # The project folder may live anywhere below the mounted Drive root.
    SEARCH_PATH = '/content/drive/MyDrive'
    PROJECT_DIR = None

    # Directories that cannot be the project root and are expensive to crawl
    # over the Drive mount (VCS data, caches, virtualenvs, vendored packages).
    _SKIP_DIRS = {'.git', '__pycache__', '.ipynb_checkpoints', 'node_modules', 'venv', '.venv'}

    # The first directory holding both a 'src' folder and a 'requirements.txt'
    # is taken to be the project root.
    for root, dirs, files in os.walk(SEARCH_PATH):
        if 'src' in dirs and 'requirements.txt' in files:
            PROJECT_DIR = root
            break
        # Prune in place: with the default top-down walk, os.walk only
        # descends into the names still present in `dirs`.
        dirs[:] = [d for d in dirs if d not in _SKIP_DIRS]

    if PROJECT_DIR:
        print(f"Project root found at: {PROJECT_DIR}")
        # Later cells rely on the working directory being the project root.
        os.chdir(PROJECT_DIR)
    else:
        raise FileNotFoundError("Could not find project root in Drive. Make sure 'headway-prediction' is synced.")
else:
    # Assume the notebook is started from 'notebooks/' and the project root is one level up.
    PROJECT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
    print(f"Running locally from: {PROJECT_DIR}")
    if not os.path.isdir(os.path.join(PROJECT_DIR, 'src')):
        # Do not fail here; just make the broken assumption visible before
        # the `from src...` imports in the next cell raise.
        print(f"WARNING: no 'src' directory found under {PROJECT_DIR}; "
              "start the notebook from the project's 'notebooks/' folder.")

# Guard against duplicates so that re-running this cell does not keep
# growing sys.path with the same entry.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
from src.config import Config
from src.data.dataset import SubwayDataGenerator
from src.models.st_covnet import HeadwayConvLSTM

# 3. DEBUG FLAG
# True  -> quick end-to-end "Smoke Test" on a tiny model and data slice
# False -> full production training on the complete split
DEBUG_MODE = True

print(f"--- DEBUG MODE: {DEBUG_MODE} ---")

# Start from the project's default configuration.
config = Config()

if DEBUG_MODE:
    print("WARNING: Overriding config for fast smoke test.")
    # Settings replaced for the smoke test, applied in this order:
    # a smaller batch, a single epoch, and a tiny filter count for speed.
    smoke_test_overrides = {
        "BATCH_SIZE": 32,
        "EPOCHS": 1,
        "FILTERS": 16,
    }
    for setting_name, setting_value in smoke_test_overrides.items():
        setattr(config, setting_name, setting_value)

    # Echo the data directory so a wrong path is noticed before loading.
    print(f"Data Dir: {config.DATA_DIR}")

# 4. Load Data
gen = SubwayDataGenerator(config)
gen.load_data()

# 5. Create Datasets
# Each mode only decides where the train/validation boundaries fall; the
# datasets themselves are built once, below.
if not DEBUG_MODE:
    # 70% train, next 15% validation; the final 15% (test) is not built here.
    total_len = len(gen.headway_data)
    train_end = int(total_len * 0.7)
    val_end = int(total_len * 0.85)

    print(f"Creating FULL datasets (N={total_len})...")
    split_start, split_mid, split_stop = 0, train_end, val_end
else:
    # Tiny slice: indices 0-1000 for training, 1000-1200 for validation.
    print("Creating tiny DEBUG dataset...")
    split_start, split_mid, split_stop = 0, 1000, 1200

train_ds = gen.make_dataset(start_index=split_start, end_index=split_mid)
val_ds = gen.make_dataset(start_index=split_mid, end_index=split_stop)

print("Datasets ready.")

In [None]:
# 6. Build and Compile Model
model_builder = HeadwayConvLSTM(config)
model = model_builder.build_model()

# Adam with the configured learning rate; MSE is the training loss and MAE is
# tracked as an additional metric.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=config.LEARNING_RATE)
model.compile(optimizer=adam_optimizer, loss='mse', metrics=['mae'])

# The layer-by-layer summary is only printed for debug runs.
if DEBUG_MODE:
    model.summary()

# 7. Run Training
print(f"Starting Training (Epochs: {config.EPOCHS})...")

# Stop after 3 epochs without improvement of the monitored quantity (the
# callback's default monitor is used) and roll back to the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
fit_options = {
    "epochs": config.EPOCHS,
    "validation_data": val_ds,
    "callbacks": [early_stopping],
}
history = model.fit(train_ds, **fit_options)

print("Training cycle complete.")

In [None]:
# 8. Evaluation Smoke Test
# Evaluator lives in the project's `src` package, so this import only resolves
# once the path-setup cell has put the project root on sys.path.
from src.evaluator import Evaluator
import matplotlib.pyplot as plt

print("running evaluation checks...")
evaluator = Evaluator(config)

# 8a. Plot loss curves from the History object returned by model.fit
# (done even when the run was only a single short epoch).
evaluator.plot_loss(history)

# 8b. Visualize predictions
# The validation dataset is used so the plot shows how well the model
# generalizes (or overfits) rather than how well it fits the training data.
print("Generating spatiotemporal visualization...")

# The default plot_spatiotemporal_prediction method (presumably on Evaluator -
# TODO confirm) is not used here. The standalone plot_micrograph_style function
# defined below is called instead, to match the requested micrograph style
# (similar to the abstract images).

def plot_micrograph_style(model, dataset, sample_idx=0, direction=0, max_headway=30):
    """
    Plot actual vs. predicted headways for one sample as side-by-side heatmaps.

    Panel (a) shows the target and panel (b) the model prediction. Both panels
    use the same colour scale and share a single colourbar.

    Args:
        model: Object exposing a Keras-style ``predict(inputs, verbose=0)``.
        dataset: Dataset supporting ``take(1)`` that yields
            ``(inputs, targets)`` batches. Only the first batch is used.
        sample_idx: Index of the sample inside that first batch.
        direction: Index along the fourth axis of targets/predictions.
        max_headway: Factor that rescales the normalised values for display,
            and upper limit of the colour scale. Defaults to 30, the value
            previously hard-coded (the colourbar is labelled in minutes).

    Returns:
        None. The figure is rendered with ``plt.show()``. If the dataset
        yields no batch, a message is printed and nothing is drawn.
    """
    # Only one batch is needed, so fetch it directly instead of looping.
    first_batch = next(iter(dataset.take(1)), None)
    if first_batch is None:
        print("Dataset is empty; nothing to plot.")
        return

    inputs, targets = first_batch
    preds = model.predict(inputs, verbose=0)

    # Select one sample / direction / channel 0, then transpose so the second
    # array axis runs along x. Assumes the layout is
    # (batch, time, station, direction, channel), as implied by the axis
    # labels below - TODO confirm against the data generator.
    y_true = targets[sample_idx, :, :, direction, 0].numpy().T * max_headway
    y_pred = preds[sample_idx, :, :, direction, 0].T * max_headway

    # Constrained layout is used instead of plt.tight_layout(): tight_layout
    # cannot handle a colourbar attached to several axes, which is exactly
    # what the shared colourbar below is.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=150, constrained_layout=True)

    # Identical rendering options for both panels keep them comparable.
    heatmap_kwargs = dict(
        aspect='auto',
        cmap='RdYlGn_r',  # Reversed Red-Yellow-Green: green = low headway, red = high
        origin='lower',
        vmin=0,
        vmax=max_headway,
        interpolation='nearest',
    )

    # Panel (a): actual
    axes[0].imshow(y_true, **heatmap_kwargs)
    axes[0].set_title("(a) Actual", fontsize=12, pad=10)
    axes[0].set_xlabel("Time (Future Steps)", fontsize=10)
    axes[0].set_ylabel("Station Index", fontsize=10)
    axes[0].grid(False)  # No grid lines over the heatmap

    # Panel (b): predicted. The y axis is the same as panel (a), so its ticks
    # and label are hidden.
    pred_image = axes[1].imshow(y_pred, **heatmap_kwargs)
    axes[1].set_title("(b) Predicted", fontsize=12, pad=10)
    axes[1].set_xlabel("Time (Future Steps)", fontsize=10)
    axes[1].set_yticks([])
    axes[1].grid(False)

    # One colourbar for both panels (they share vmin/vmax, so either image
    # can serve as the mappable).
    cbar = fig.colorbar(pred_image, ax=axes.ravel().tolist(), pad=0.02, aspect=30)
    cbar.set_label('Headway (min)', rotation=270, labelpad=15)

    plt.show()

# Render the actual-vs-predicted comparison for the first sample of the first
# validation batch (direction keeps its default of 0).
plot_micrograph_style(model, val_ds, sample_idx=0)