# Task 1: Baseline TPP Models on Real-World Datasets

This notebook runs all 7 baseline TPP models from EasyTPP on 5 real-world datasets.

## Models
| Model | hidden_size | time_emb_size | num_layers | num_heads |
|-------|-------------|---------------|------------|------------|
| RMTPP | 32 | 16 | 2 | - |
| NHP | 64 | 16 | 2 | - |
| FullyNN | 32 | 16 | 2 | - |
| SAHP | 32 | 16 | 2 | 2 |
| THP | 64 | 16 | 2 | 2 |
| IntensityFree | 32 | 16 | 2 | - |
| AttNHP | 32 | 16 | 1 | 2 |

## Datasets
- Taxi, Amazon, Taobao, StackOverflow, Retweet

## Patches Applied
1. **FullyNN Gradient Fix**: Enables gradient computation during evaluation
2. **IntensityFree Median Fix**: Replaces unstable sampling with a closed-form point estimate (the mixture-weighted average of the component medians)
3. **Memory Optimization**: GPU cache clearing after every run, plus a smaller batch size and a reduced thinning `num_exp` for memory-intensive model/dataset combinations

## 1. Package Installation

In [None]:
import os
import sys

# Set PyTorch memory allocation config BEFORE importing torch
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

!pip install easy-tpp lightning pytorch-lightning hydra-core omegaconf torchmetrics stribor -q

print("✓ Core dependencies installed")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.9/44.9 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.2/99.2 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m846.0/846.0 kB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m57.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.5/154.5 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m65.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h✓ Core dependencies installed


In [None]:
import yaml
import pandas as pd
import numpy as np
import torch
import gc
import time
import traceback
import shutil
from datetime import datetime
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
from google.colab import drive

# Report the library versions and, when present, the GPU this session runs on.
has_cuda = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; shown here in decimal gigabytes.
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.1f} GB")

PyTorch version: 2.9.0+cu126
NumPy version: 2.0.2
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.2 GB


## 2. Apply Monkey Patches

### Patch 1: FullyNN Gradient Fix
FullyNN uses `torch.autograd.grad()` to compute intensity derivatives, but during validation/evaluation, EasyTPP runs under `torch.no_grad()` context which disables gradient tracking.

### Patch 2: IntensityFree Median Fix
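IntensityFree predicts the next inter-event time by sampling from a `LogNormalMixtureDistribution`. Log-normal distributions have heavy right tails, so a single sample can be an extreme outlier. We replace sampling with a closed-form estimate: the mixture-weighted average of the component medians, $\sum_i w_i \exp(a + b\,\mu_i)$, where $a$ is `mean_log_inter_time` and $b$ is `std_log_inter_time`. This is a stable approximation; it is not, in general, the exact median of the mixture.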

In [None]:
# ==============================================================================
# PATCH 1: FULLYNN GRADIENT FIX
# ==============================================================================

from easy_tpp.model.torch_model import torch_fullynn

def patched_compute_intensities_at_sample_times(self, time_seqs, time_delta_seqs, type_seqs, sample_dtimes, **kwargs):
    """Compute FullyNN intensities at sampled times with autograd forced on.

    Replacement for ``FullyNN.compute_intensities_at_sample_times``. The
    forward pass and the intensity layer both run inside
    ``torch.enable_grad()``, so the call still works when the caller is inside
    a ``torch.no_grad()`` block (as during validation/evaluation).

    Args:
        time_seqs, time_delta_seqs, type_seqs: forwarded unchanged (by keyword)
            to ``self.forward``.
        sample_dtimes: sampled time offsets; the size of its last dimension is
            used as the number of samples per position. It is cloned, so the
            caller's tensor is never flagged as requiring grad.
        **kwargs: only ``compute_last_step_only`` (default ``False``) is read.

    Returns:
        The second value returned by ``self.layer_intensity.forward``,
        detached from the autograd graph. When ``compute_last_step_only`` is
        true, only the last index along dimension 1 is kept (that dimension is
        retained with size 1).
    """
    last_step_only = kwargs.get('compute_last_step_only', False)

    with torch.enable_grad():
        history = self.forward(
            time_seqs=time_seqs,
            time_delta_seqs=time_delta_seqs,
            type_seqs=type_seqs,
        )

        n_samples = sample_dtimes.size()[-1]
        n_batch, n_steps, n_hidden = history.shape

        # Repeat every hidden state once per sample (a broadcast view, no copy).
        history_per_sample = history.unsqueeze(-2).expand(n_batch, n_steps, n_samples, n_hidden)

        # Fresh leaf tensor: the intensity layer differentiates w.r.t. it.
        dtimes_leaf = sample_dtimes.detach().clone().requires_grad_(True)

        _, intensities = self.layer_intensity.forward(
            hidden_states=history_per_sample,
            time_delta_seqs=dtimes_leaf,
        )

    # Callers only need values, so drop the graph built above.
    intensities = intensities.detach()

    return intensities[:, -1:, :, :] if last_step_only else intensities

# Install the replacement on the class itself, so every FullyNN instance
# resolves `compute_intensities_at_sample_times` to the patched function.
torch_fullynn.FullyNN.compute_intensities_at_sample_times = patched_compute_intensities_at_sample_times
print("✓ FullyNN patched for gradient computation during evaluation")

✓ FullyNN patched for gradient computation during evaluation


In [None]:
# ==============================================================================
# PATCH 2: INTENSITYFREE MEDIAN FIX - Distribution Class
# ==============================================================================

import torch.distributions as D
from torch.distributions import TransformedDistribution
from easy_tpp.model.torch_model import torch_intensity_free

# Pull the building blocks from EasyTPP's intensity-free module rather than from
# torch.distributions, so the replacement class below is assembled from the
# library's own objects.
# NOTE(review): presumably these are the variants that provide
# log_cdf / log_survival_function, which the class below calls — confirm.
Normal = torch_intensity_free.Normal
MixtureSameFamily = torch_intensity_free.MixtureSameFamily
clamp_preserve_gradients = torch_intensity_free.clamp_preserve_gradients


class FixedLogNormalMixtureDistribution(TransformedDistribution):
    """Log-normal mixture over inter-event times with a sampling-free point estimate.

    The distribution is a Gaussian mixture pushed through an optional affine
    map (``loc=mean_log_inter_time``, ``scale=std_log_inter_time``) followed by
    ``exp``. The affine map is omitted when it would be the identity.

    Compared with the class it replaces, this version:
    1. coerces the two normalisation constants to builtin floats (intended to
       avoid ``.sign()`` errors seen with ``numpy.float32`` scalars);
    2. exposes ``median``, a deterministic estimate used instead of sampling.
    """

    @staticmethod
    def _as_python_float(value):
        """Return ``value`` as a builtin scalar, via ``.item()`` when available."""
        return value.item() if hasattr(value, 'item') else float(value)

    def __init__(self, locs, log_scales, log_weights, mean_log_inter_time, std_log_inter_time, validate_args=None):
        gaussian_mixture = MixtureSameFamily(
            D.Categorical(logits=log_weights),
            Normal(loc=locs, scale=log_scales.exp()),
        )

        shift = self._as_python_float(mean_log_inter_time)
        scale = self._as_python_float(std_log_inter_time)
        self.mean_log_inter_time = shift
        self.std_log_inter_time = scale

        # Kept so `median` can be evaluated without touching the base distribution.
        self._locs, self._log_scales, self._log_weights = locs, log_scales, log_weights

        # exp is always the last step; the affine step is prepended only when
        # it is not the identity map.
        chain = [D.ExpTransform()]
        if not (shift == 0.0 and scale == 1.0):
            chain.insert(0, D.AffineTransform(loc=shift, scale=scale))
        self.transforms = chain

        # Product of the per-transform signs: +1 when the overall map is
        # increasing, -1 when decreasing (this swaps CDF and survival below).
        overall_sign = 1
        for step in chain:
            overall_sign = overall_sign * step.sign
        self.sign = int(overall_sign)

        super().__init__(gaussian_mixture, chain, validate_args=validate_args)

    @property
    def median(self):
        """Weighted average of the component medians: sum_i w_i * exp(a + b * mu_i).

        ``a`` and ``b`` are ``mean_log_inter_time`` and ``std_log_inter_time``;
        ``w`` is the softmax of the stored log-weights over the last dimension.
        Each log-normal component has median ``exp(a + b * mu_i)`` regardless
        of its scale. The weighted average of those medians is a stable point
        estimate; it is not, in general, the exact median of the mixture.
        """
        mixing_probs = torch.softmax(self._log_weights, dim=-1)
        component_log_medians = self.mean_log_inter_time + self.std_log_inter_time * self._locs
        return (mixing_probs * component_log_medians.exp()).sum(dim=-1)

    def _pull_back(self, x):
        """Map ``x`` back through the transforms (last applied first) into base space."""
        for transform in reversed(self.transforms):
            x = transform.inv(x)
        if self._validate_args:
            self.base_dist._validate_sample(x)
        return x

    def log_cdf(self, x):
        """Log CDF at ``x``; uses the base survival function when the map is decreasing."""
        base_x = self._pull_back(x)
        if self.sign == 1:
            return self.base_dist.log_cdf(base_x)
        return self.base_dist.log_survival_function(base_x)

    def log_survival_function(self, x):
        """Log survival function at ``x``; uses the base CDF when the map is decreasing."""
        base_x = self._pull_back(x)
        if self.sign == 1:
            return self.base_dist.log_survival_function(base_x)
        return self.base_dist.log_cdf(base_x)


# Rebind the name inside the EasyTPP module: code in that module that looks up
# `LogNormalMixtureDistribution` at call time now gets the fixed class.
torch_intensity_free.LogNormalMixtureDistribution = FixedLogNormalMixtureDistribution
print("✓ LogNormalMixtureDistribution patched with median property")

✓ LogNormalMixtureDistribution patched with median property


In [None]:
# ==============================================================================
# PATCH 2b: INTENSITYFREE PREDICTION METHOD PATCH
# ==============================================================================

def patched_predict_one_step_at_every_event(self, batch):
    """Predict the next inter-event time and event type at every position.

    Replacement for ``IntensityFree.predict_one_step_at_every_event`` that
    reads the time prediction from ``FixedLogNormalMixtureDistribution.median``
    instead of drawing samples.

    Args:
        batch: a 5-item sequence; only the second item (inter-event times) and
            the third item (event types) are used.

    Returns:
        ``(dtimes_pred, types_pred)``: the distribution's ``median`` and the
        arg-max over the log-softmax of ``self.mark_linear(context)``.
    """
    # Unpacking enforces the 5-item batch layout; three of the items are unused.
    _, time_delta_seq, event_seq, _, _ = batch

    # Remove the last event: it has no successor to predict.
    dtime_history = time_delta_seq[:, :-1]
    type_history = event_seq[:, :-1]

    context = self.forward(dtime_history, type_history)

    # The linear head is sliced into three groups of `k` values along the last
    # dimension: locations, log-scales, mixture logits.
    k = self.num_mix_components
    mixture_params = self.linear(context)
    locs = mixture_params[..., :k]
    log_scales = clamp_preserve_gradients(mixture_params[..., k:(2 * k)], -5.0, 3.0)
    log_weights = torch.log_softmax(mixture_params[..., (2 * k):], dim=-1)

    # Deterministic estimate instead of sample().mean().
    dtimes_pred = FixedLogNormalMixtureDistribution(
        locs=locs,
        log_scales=log_scales,
        log_weights=log_weights,
        mean_log_inter_time=self.mean_log_inter_time,
        std_log_inter_time=self.std_log_inter_time
    ).median

    mark_log_probs = torch.log_softmax(self.mark_linear(context), dim=-1)
    types_pred = torch.argmax(mark_log_probs, dim=-1)

    return dtimes_pred, types_pred


from easy_tpp.model.torch_model.torch_intensity_free import IntensityFree
# Class-level assignment: every IntensityFree instance now resolves
# `predict_one_step_at_every_event` to the patched function.
IntensityFree.predict_one_step_at_every_event = patched_predict_one_step_at_every_event
print("✓ IntensityFree.predict_one_step_at_every_event patched to use median")

# Closing banner for the patch section.
print("\n" + "="*70)
print("ALL PATCHES APPLIED SUCCESSFULLY")
print("="*70)

✓ IntensityFree.predict_one_step_at_every_event patched to use median

ALL PATCHES APPLIED SUCCESSFULLY


## 3. Mount Google Drive & Setup Directories

In [None]:
drive.mount('/content/drive')

# Everything the notebook reads or writes lives under this one Drive folder.
BASE_DIR = '/content/drive/MyDrive/Colab Notebooks/MilestoneFall2025'
DATASET_DIR, CHECKPOINT_DIR, CONFIG_DIR = (
    os.path.join(BASE_DIR, sub_dir) for sub_dir in ('Datasets', 'checkpoints', 'configs')
)
RESULTS_DIR = os.path.join(BASE_DIR, 'results/Task1')

# Only the output folders are created here; DATASET_DIR is not.
for out_dir in (RESULTS_DIR, CONFIG_DIR, CHECKPOINT_DIR):
    os.makedirs(out_dir, exist_ok=True)

print(f"✓ Base Directory: {BASE_DIR}")
print(f"✓ Dataset Directory: {DATASET_DIR}")
print(f"✓ Checkpoint Directory: {CHECKPOINT_DIR}")
print(f"✓ Results Directory: {RESULTS_DIR}")

Mounted at /content/drive
✓ Base Directory: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025
✓ Dataset Directory: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/Datasets
✓ Checkpoint Directory: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints
✓ Results Directory: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/results/Task1


## 4. Configuration Dictionaries

### 4.1 Dataset Configuration

In [None]:
# Data specifications for each dataset
def _dataset_spec(name, num_event_types, max_seq_len):
    """Build the data entry for the dataset stored in ``DATASET_DIR/<name>/``.

    Every dataset in this notebook uses the same file names and pads with the
    id ``num_event_types``, so only the name, the number of event types and
    the maximum sequence length vary.
    """
    return {
        "data_format": "pkl",
        "train_dir": os.path.join(DATASET_DIR, name, "train.pkl"),
        "valid_dir": os.path.join(DATASET_DIR, name, "dev.pkl"),
        "test_dir": os.path.join(DATASET_DIR, name, "test.pkl"),
        "data_specs": {
            "num_event_types": num_event_types,
            "pad_token_id": num_event_types,
            "padding_side": "right",
            "max_seq_len": max_seq_len,
            "strict_pad_leng": True,
        },
    }


# (dataset name, number of event types, max sequence length)
data_spec_dict = {
    name: _dataset_spec(name, n_types, seq_len)
    for name, n_types, seq_len in (
        ("taxi", 10, 100),
        ("amazon", 16, 100),
        ("taobao", 17, 150),
        ("stackoverflow", 22, 100),
        ("retweet", 3, 100),
    )
}

print("✓ Data specifications loaded for 5 datasets:")
print("\n".join(f"  - {ds}" for ds in data_spec_dict))

✓ Data specifications loaded for 5 datasets:
  - taxi
  - amazon
  - taobao
  - stackoverflow
  - retweet


### 4.2 Model Configuration (Corrected Hyperparameters)

| Model | hidden_size | time_emb_size | num_layers | num_heads |
|-------|-------------|---------------|------------|------------|
| RMTPP | 32 | 16 | 2 | - |
| NHP | 64 | 16 | 2 | - |
| FullyNN | 32 | 16 | 2 | - |
| SAHP | 32 | 16 | 2 | 2 |
| THP | 64 | 16 | 2 | 2 |
| IntensityFree | 32 | 16 | 2 | - |
| AttNHP | 32 | 16 | 1 | 2 |

In [None]:
# ==============================================================================
# MODEL SPECIFICATIONS
# ==============================================================================

# Common thinning parameters
default_thinning = dict(
    num_seq=10,
    num_sample=1,
    num_exp=500,
    look_ahead_time=10,
    patience_counter=5,
    over_sample_rate=5,
    num_samples_boundary=5,
    dtime_max=5,
)

# Memory-optimized thinning for OOM-prone models: identical to the default
# except that num_exp is reduced from 500 to 100.
memory_optimized_thinning = {**default_thinning, "num_exp": 100}

# Settings shared by several models, factored out so the per-model entries
# below list only what differs.
_MC_SAMPLING = {
    "loss_integral_num_sample_per_step": 20,
    "mc_num_sample_per_step": 20,
}
_ATTENTION = {
    "num_heads": 2,
    "sharing_param_layer": False,
    **_MC_SAMPLING,
}


def _model_spec(model_id, hidden_size, num_layers, thinning, **extra):
    """Build one model entry.

    All models share time_embed_size=16, dropout=0.0, use_ln=False and
    seed=2019; ``extra`` carries the model-specific keys. ``thinning`` is
    copied so no two entries share a thinning dict.
    """
    spec = {
        "model_id": model_id,
        "hidden_size": hidden_size,
        "time_embed_size": 16,
        "num_layers": num_layers,
        "dropout": 0.0,
        "use_ln": False,
        "seed": 2019,
    }
    spec.update(extra)
    spec["thinning"] = dict(thinning)
    return spec


model_spec_dict = {
    # RMTPP: hidden=32, time_emb=16, layers=2
    "RMTPP": _model_spec("RMTPP", 32, 2, default_thinning, **_MC_SAMPLING),

    # NHP: hidden=64, time_emb=16, layers=2
    "NHP": _model_spec("NHP", 64, 2, default_thinning, **_MC_SAMPLING),

    # FullyNN: hidden=32, time_emb=16, layers=2 (memory-optimized thinning)
    "FullyNN": _model_spec(
        "FullyNN", 32, 2, memory_optimized_thinning,
        model_specs={
            "num_mlp_layers": 3,
            "proper_marked_intensities": True,
        },
    ),

    # SAHP: hidden=32, time_emb=16, layers=2, heads=2
    "SAHP": _model_spec("SAHP", 32, 2, default_thinning, **_ATTENTION),

    # THP: hidden=64, time_emb=16, layers=2, heads=2
    "THP": _model_spec("THP", 64, 2, default_thinning, **_ATTENTION),

    # IntensityFree: hidden=32, time_emb=16, layers=2
    "IntensityFree": _model_spec(
        "IntensityFree", 32, 2, {**default_thinning, "num_step_gen": 10},
        sharing_param_layer=False,
        num_mix_components=3,
        model_specs={"num_mix_components": 3},
        **_MC_SAMPLING,
    ),

    # AttNHP: hidden=32, time_emb=16, layers=1 (only 1 layer per paper), heads=2,
    # memory-optimized thinning
    "AttNHP": _model_spec("AttNHP", 32, 1, memory_optimized_thinning, **_ATTENTION),
}

# Summary table; '-' marks a setting the model's entry does not define.
table_rule = "="*70
print("✓ Model specifications loaded for 7 models:")
print("\n" + table_rule)
print(f"{'Model':<15} {'hidden_size':<12} {'time_emb':<10} {'layers':<8} {'heads':<8}")
print("-"*70)
for model_id, spec in model_spec_dict.items():
    h, t, l, heads = (
        spec.get(field, '-')
        for field in ('hidden_size', 'time_embed_size', 'num_layers', 'num_heads')
    )
    print(f"{model_id:<15} {h:<12} {t:<10} {l:<8} {heads:<8}")
print(table_rule)

✓ Model specifications loaded for 7 models (CORRECTED HYPERPARAMETERS):

Model           hidden_size  time_emb   layers   heads   
----------------------------------------------------------------------
RMTPP           32           16         2        -       
NHP             64           16         2        -       
FullyNN         32           16         2        -       
SAHP            32           16         2        2       
THP             64           16         2        2       
IntensityFree   32           16         2        -       
AttNHP          32           16         1        2       


### 4.3 Training Configuration

In [None]:
# Default trainer configuration
trainer_config_default = dict(
    batch_size=256,
    max_epoch=200,
    shuffle=False,
    optimizer="adam",
    learning_rate=1e-3,
    valid_freq=1,
    use_tfb=False,
    metrics=["acc", "rmse"],
    seed=2019,
    gpu=0,
)

# Memory-optimized trainer config for OOM-prone combinations: the default
# settings with batch_size reduced from 256 to 32. `metrics` gets its own list
# so the two configs share no mutable state.
trainer_config_memory_optimized = {
    **trainer_config_default,
    "batch_size": 32,
    "metrics": ["acc", "rmse"],
}

# Define which model/dataset combinations need memory optimization
MEMORY_OPTIMIZED_COMBINATIONS = [
    (oom_model, large_dataset)
    for oom_model in ("FullyNN", "AttNHP")
    for large_dataset in ("amazon", "taobao", "stackoverflow")
]


def get_trainer_config(model_id, data_id):
    """Return a shallow copy of the trainer config for this model/dataset pair.

    Pairs listed in ``MEMORY_OPTIMIZED_COMBINATIONS`` get the memory-optimized
    config; every other pair gets the default one.
    """
    needs_small_batches = (model_id, data_id) in MEMORY_OPTIMIZED_COMBINATIONS
    template = trainer_config_memory_optimized if needs_small_batches else trainer_config_default
    return template.copy()


print("✓ Trainer configurations loaded")
print(f"  - Default batch size: {trainer_config_default['batch_size']}")
print(f"  - Memory-optimized batch size: {trainer_config_memory_optimized['batch_size']}")
print(f"  - Max epochs: {trainer_config_default['max_epoch']}")
print(f"  - Memory-optimized combinations: {len(MEMORY_OPTIMIZED_COMBINATIONS)}")

✓ Trainer configurations loaded
  - Default batch size: 256
  - Memory-optimized batch size: 32
  - Max epochs: 200
  - Memory-optimized combinations: 6


## 5. Helper Functions

In [None]:
def create_experiment_config(model_id, data_id, data_spec, model_spec, trainer_cfg):
    """Assemble the config dictionary for one model/dataset training run.

    Args:
        model_id: model name, stored in the base config.
        data_id: dataset name; keys the ``data`` section.
        data_spec, model_spec, trainer_cfg: embedded as-is (not copied).

    Returns:
        ``(config, experiment_id)`` where ``experiment_id`` is
        ``"<model_id>_<data_id>_train"`` and is also the key of the experiment
        section inside ``config``. Checkpoints are rooted at ``CHECKPOINT_DIR``.
    """
    experiment_id = f"{model_id}_{data_id}_train"

    base_config = {
        "stage": "train",
        "backend": "torch",
        "dataset_id": data_id,
        "runner_id": "std_tpp",
        "model_id": model_id,
        "base_dir": CHECKPOINT_DIR
    }
    experiment_section = {
        "base_config": base_config,
        "trainer_config": trainer_cfg,
        "model_config": model_spec
    }

    config = {
        "pipeline_config_id": "runner_config",
        "data": {data_id: data_spec},
    }
    config[experiment_id] = experiment_section

    return config, experiment_id


def _parse_test_metrics_line(line):
    """Parse one ``test loglike is ...`` log line into a metrics dict.

    The line is split on commas; each metric is the float that follows its
    marker text. Returns ``None`` if the log-likelihood is missing or any
    value present on the line is not a valid float, so a line is either
    accepted as a whole or ignored as a whole.
    """
    parts = line.split(',')

    def value_after(marker):
        for part in parts:
            if marker in part:
                return float(part.split(marker)[1].strip())
        return None

    try:
        log_likelihood = value_after('test loglike is')
        accuracy = value_after('acc is')
        rmse = value_after('rmse is')
    except (IndexError, ValueError):
        return None

    if log_likelihood is None:
        return None
    return {
        'log_likelihood': log_likelihood,
        'accuracy': accuracy,
        'rmse': rmse
    }


def extract_results_from_logs(log_path):
    """Extract the most recently logged test metrics from a training log.

    Note: this returns the values from the LAST well-formed
    ``test loglike is`` line in the file, not the best ones seen during
    training.

    Args:
        log_path: a log file, or a directory containing ``*.log`` files. For a
            directory, the alphabetically first ``.log`` file is read, so the
            choice does not depend on directory listing order.

    Returns:
        ``{'log_likelihood', 'accuracy', 'rmse'}`` taken together from one log
        line (``accuracy`` / ``rmse`` are ``None`` when that line does not
        report them), or ``None`` when the path is missing, no log file or no
        parsable line is found, or reading fails (a warning is printed).
    """
    try:
        if os.path.isfile(log_path):
            log_file = log_path
        elif os.path.isdir(log_path):
            log_files = sorted(f for f in os.listdir(log_path) if f.endswith('.log'))
            if not log_files:
                return None
            log_file = os.path.join(log_path, log_files[0])
        else:
            print(f"  Warning: Log path does not exist: {log_path}")
            return None

        latest_metrics = None
        # errors='replace': a stray undecodable byte must not discard the run's results.
        with open(log_file, 'r', errors='replace') as f:
            for line in f:
                if 'test loglike is' not in line:
                    continue
                parsed = _parse_test_metrics_line(line)
                if parsed is not None:
                    latest_metrics = parsed

        return latest_metrics
    except Exception as e:
        # Best effort: a missing result is reported by the caller, not raised.
        print(f"  Warning: Could not extract results from logs: {e}")
        return None


def find_latest_checkpoint_dir(base_dir, model_id, data_id):
    """Return the most recently modified run directory for a model/dataset pair.

    A sub-directory of ``base_dir`` counts as a match when it contains
    ``<model_id>_<data_id>_train_output.yaml``. Among the matches, the one with
    the largest directory modification time is returned.

    Returns:
        The path of the chosen directory, or ``None`` when nothing matches or
        the lookup raises (in which case a warning is printed).
    """
    try:
        marker_name = f"{model_id}_{data_id}_train_output.yaml"

        candidates = [
            path
            for path in (os.path.join(base_dir, entry) for entry in os.listdir(base_dir))
            if os.path.isdir(path) and os.path.exists(os.path.join(path, marker_name))
        ]

        return max(candidates, key=os.path.getmtime) if candidates else None
    except Exception as e:
        print(f"  Warning: Error finding checkpoint dir: {e}")
        return None


def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory if a GPU is present.

    Collection happens first so that tensors which have just become
    unreachable are freed before the CUDA cache is emptied.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


print("✓ Helper functions loaded")

✓ Helper functions loaded


In [None]:
def _make_result(model_id, data_id, status, start_time,
                 log_likelihood=None, accuracy=None, rmse=None, error=None):
    """Build one row of the results table; ``error`` is included only when given."""
    result = {
        'model': model_id,
        'dataset': data_id,
        'log_likelihood': log_likelihood,
        'accuracy': accuracy,
        'rmse': rmse,
        'status': status,
    }
    if error is not None:
        result['error'] = error
    result['time_seconds'] = time.time() - start_time
    return result


def run_single_experiment(model_id, data_id, results_list):
    """
    Run a single model-dataset experiment.

    Writes the experiment config to ``CONFIG_DIR``, trains through the EasyTPP
    runner, then reads the metrics from the log of the most recently modified
    matching directory under ``CHECKPOINT_DIR``.
    NOTE(review): if the run creates no directory of its own, that lookup can
    pick up a directory left by an earlier run — confirm this is acceptable.

    The result is appended to ``results_list`` and the whole list is rewritten
    to ``intermediate_results.csv`` in ``RESULTS_DIR``. Exceptions are caught
    and recorded, so one failing combination does not stop the sweep.

    Returns the result dictionary. ``status`` is one of ``'success'``,
    ``'completed_no_metrics'``, ``'no_checkpoint'`` or ``'failed'`` (the last
    one also carries an ``'error'`` message).
    """
    experiment_name = f"{model_id}_{data_id}"
    start_time = time.time()

    print(f"\n{'='*70}")
    print(f"Running {model_id} on {data_id}")
    print(f"{'='*70}")

    # Bound before the try so they can be released below even when the run fails.
    config = None
    runner = None

    try:
        # Get appropriate configs
        data_spec = data_spec_dict[data_id]
        model_spec = model_spec_dict[model_id]
        trainer_cfg = get_trainer_config(model_id, data_id)

        # Log if using memory-optimized settings
        if (model_id, data_id) in MEMORY_OPTIMIZED_COMBINATIONS:
            print(f"  ⚡ Using memory-optimized settings (batch_size={trainer_cfg['batch_size']})")

        # Create config
        config_dict, experiment_id = create_experiment_config(
            model_id=model_id,
            data_id=data_id,
            data_spec=data_spec,
            model_spec=model_spec,
            trainer_cfg=trainer_cfg
        )

        # Save config
        config_path = os.path.join(CONFIG_DIR, f"{experiment_name}.yaml")
        with open(config_path, 'w') as f:
            yaml.dump(config_dict, f)

        # Build and run
        config = Config.build_from_yaml_file(config_path, experiment_id=experiment_id)
        runner = Runner.build_from_config(config)
        runner.run()

        # Find checkpoint directory and extract results from its log
        checkpoint_dir = find_latest_checkpoint_dir(CHECKPOINT_DIR, model_id, data_id)
        metrics = None
        if checkpoint_dir:
            metrics = extract_results_from_logs(os.path.join(checkpoint_dir, 'log'))

        if not checkpoint_dir:
            result = _make_result(model_id, data_id, 'no_checkpoint', start_time)
            print(f"⚠ {model_id} on {data_id} - no checkpoint directory found")
        elif not metrics:
            result = _make_result(model_id, data_id, 'completed_no_metrics', start_time)
            print(f"⚠ {model_id} on {data_id} completed but no metrics found")
        else:
            result = _make_result(
                model_id, data_id, 'success', start_time,
                log_likelihood=metrics['log_likelihood'],
                accuracy=metrics['accuracy'],
                rmse=metrics['rmse'],
            )
            print(f"\n✓ {model_id} on {data_id} completed successfully!")
            print(f"  - Log-Likelihood: {metrics['log_likelihood']:.4f}")
            # `is not None`, not truthiness: 0.0 is a legitimate value to report.
            if metrics['accuracy'] is not None:
                print(f"  - Accuracy: {metrics['accuracy']:.4f}")
            if metrics['rmse'] is not None:
                print(f"  - RMSE: {metrics['rmse']:.4f}")

    except Exception as e:
        result = _make_result(model_id, data_id, 'failed', start_time, error=str(e))
        print(f"\n✗ Error running {model_id} on {data_id}: {str(e)}")
        traceback.print_exc()

    # Drop the references on both the success and the failure path, so the
    # model is collectable by the time clear_gpu_memory() runs (a failed run,
    # e.g. out-of-memory, is exactly when the memory is needed back).
    runner = None
    config = None

    # Append result
    results_list.append(result)

    # Save intermediate results
    intermediate_df = pd.DataFrame(results_list)
    intermediate_df.to_csv(os.path.join(RESULTS_DIR, 'intermediate_results.csv'), index=False)

    # Clear GPU memory
    clear_gpu_memory()

    return result


print("✓ Experiment runner function loaded")

✓ Experiment runner function loaded


## 6. Run Experiments

Running all 35 experiments (7 models × 5 datasets).

In [None]:
# Initialize results storage
# Fresh accumulator: each run_single_experiment call appends one row to it.
results_list = []
n_models, n_datasets = len(model_spec_dict), len(data_spec_dict)
total_experiments = n_models * n_datasets

print(f"Starting {total_experiments} experiments...")
print(f"Running {n_models} models on {n_datasets} datasets")
print(f"Results will be saved incrementally to: {RESULTS_DIR}")
print("="*70)

Starting 35 experiments...
Running 7 models on 5 datasets
Results will be saved incrementally to: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/results/Task1


In [None]:
# Run RMTPP on all datasets
banner = "#"*70
print("\n" + banner)
print("# MODEL 1/7: RMTPP")
print(banner)

# One run per dataset; each result is appended to results_list.
for data_id in data_spec_dict:
    run_single_experiment("RMTPP", data_id, results_list)


######################################################################
# MODEL 1/7: RMTPP
######################################################################

Running RMTPP on taxi
[31;1m2026-01-01 02:18:31,507 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 02:18:33,446 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model RMTPP using GPU with torch backend[0m
[38;20m2026-01-01 02:18:33,460 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-021831/RMTPP_taxi_train_output.yaml[0m
[38;20m2026-01-01 02:18:33,482 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-021831/log[0m
0.22442521993973832 0.29228809611195583
min_dt: 0.0002777777777777657
max_dt

In [None]:
# Run NHP on all datasets
banner = "#"*70
print("\n" + banner)
print("# MODEL 2/7: NHP")
print(banner)

# One run per dataset; each result is appended to results_list.
for data_id in data_spec_dict:
    run_single_experiment("NHP", data_id, results_list)


######################################################################
# MODEL 2/7: NHP
######################################################################

Running NHP on taxi
[31;1m2026-01-01 02:25:10,326 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 02:25:10,332 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model NHP using GPU with torch backend[0m
[38;20m2026-01-01 02:25:10,343 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-022510/NHP_taxi_train_output.yaml[0m
[38;20m2026-01-01 02:25:10,346 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-022510/log[0m
0.22442521993973832 0.29228809611195583
min_dt: 0.0002777777777777657
max_dt: 5.7213

In [None]:
# FullyNN sweep: one run_single_experiment call per key of data_spec_dict.
# NOTE: Uses memory-optimized settings for amazon, taobao, stackoverflow
# results_list is handed to every call (presumably to collect the run record;
# confirm against run_single_experiment).
print("\n" + "#" * 70, "# MODEL 3/7: FullyNN (with gradient patch)", "#" * 70, sep="\n")

for data_id in data_spec_dict:
    run_single_experiment("FullyNN", data_id, results_list)


######################################################################
# MODEL 3/7: FullyNN (with gradient patch)
######################################################################

Running FullyNN on taxi
[31;1m2026-01-01 03:16:44,207 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 03:16:44,214 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model FullyNN using GPU with torch backend[0m
[38;20m2026-01-01 03:16:44,226 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-031644/FullyNN_taxi_train_output.yaml[0m
[38;20m2026-01-01 03:16:44,229 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-031644/log[0m
0.22442521993973832 0.29228809611195583
min_dt

In [None]:
# SAHP sweep: one run_single_experiment call per key of data_spec_dict.
# results_list is handed to every call (presumably to collect the run record;
# confirm against run_single_experiment).
print("\n" + "#" * 70, "# MODEL 4/7: SAHP", "#" * 70, sep="\n")

for data_id in data_spec_dict:
    run_single_experiment("SAHP", data_id, results_list)


######################################################################
# MODEL 4/7: SAHP
######################################################################

Running SAHP on taxi
[31;1m2026-01-01 04:19:19,004 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 04:19:19,011 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model SAHP using GPU with torch backend[0m
[38;20m2026-01-01 04:19:19,024 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-041919/SAHP_taxi_train_output.yaml[0m
[38;20m2026-01-01 04:19:19,028 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-041919/log[0m
0.22442521993973832 0.29228809611195583
min_dt: 0.0002777777777777657
max_dt: 5.

In [None]:
# THP sweep: one run_single_experiment call per key of data_spec_dict.
# results_list is handed to every call (presumably to collect the run record;
# confirm against run_single_experiment).
print("\n" + "#" * 70, "# MODEL 5/7: THP", "#" * 70, sep="\n")

for data_id in data_spec_dict:
    run_single_experiment("THP", data_id, results_list)


######################################################################
# MODEL 5/7: THP
######################################################################

Running THP on taxi
[31;1m2026-01-01 04:26:28,377 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 04:26:28,384 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model THP using GPU with torch backend[0m
[38;20m2026-01-01 04:26:28,396 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-042628/THP_taxi_train_output.yaml[0m
[38;20m2026-01-01 04:26:28,400 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-042628/log[0m
0.22442521993973832 0.29228809611195583
min_dt: 0.0002777777777777657
max_dt: 5.7213

In [None]:
# IntensityFree sweep: one run_single_experiment call per key of data_spec_dict.
# NOTE: Uses median patch for stable RMSE
# results_list is handed to every call (presumably to collect the run record;
# confirm against run_single_experiment).
print("\n" + "#" * 70, "# MODEL 6/7: IntensityFree (with median patch)", "#" * 70, sep="\n")

for data_id in data_spec_dict:
    run_single_experiment("IntensityFree", data_id, results_list)


######################################################################
# MODEL 6/7: IntensityFree (with median patch)
######################################################################

Running IntensityFree on taxi
[31;1m2026-01-01 04:33:34,402 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 04:33:34,418 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model IntensityFree using GPU with torch backend[0m
[38;20m2026-01-01 04:33:34,433 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-043334/IntensityFree_taxi_train_output.yaml[0m
[38;20m2026-01-01 04:33:34,437 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-043334/log[0m
0.22442521993973832 0.29

In [None]:
# AttNHP sweep: one run_single_experiment call per key of data_spec_dict.
# NOTE: Uses memory-optimized settings for amazon, taobao, stackoverflow
# results_list is handed to every call (presumably to collect the run record;
# confirm against run_single_experiment).
print("\n" + "#" * 70, "# MODEL 7/7: AttNHP (memory-intensive)", "#" * 70, sep="\n")

for data_id in data_spec_dict:
    run_single_experiment("AttNHP", data_id, results_list)


######################################################################
# MODEL 7/7: AttNHP (memory-intensive)
######################################################################

Running AttNHP on taxi
[31;1m2026-01-01 04:38:39,462 - config.py[pid:2009;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class RunnerConfig[0m
[31;1m2026-01-01 04:38:39,469 - runner_config.py[pid:2009;line:140:update_config] - CRITICAL: train model AttNHP using GPU with torch backend[0m
[38;20m2026-01-01 04:38:39,493 - runner_config.py[pid:2009;line:35:__init__] - INFO: Save the config to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-043839/AttNHP_taxi_train_output.yaml[0m
[38;20m2026-01-01 04:38:39,498 - base_runner.py[pid:2009;line:176:save_log] - INFO: Save the log to /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/checkpoints/2009_137863975707264_260101-043839/log[0m
0.22442521993973832 0.29228809611195583
min_dt: 0.000

## 7. Results Summary & Formatting

In [None]:
# ==============================================================================
# FORMAT RESULTS FOR REPORT
# ==============================================================================
# Format: RMSE / Type Error Rate %
# Matches Table 1 style from EasyTPP paper
# ==============================================================================

def format_final_results(results_list):
    """Format raw experiment records into a paper-style model x dataset table.

    Each cell is rendered as ``"<rmse>/<type error rate>%"`` where the type
    error rate is ``(1 - accuracy) * 100``. A run is rendered as the string
    ``'NaN'`` when its status is not ``'success'`` or when its RMSE or
    accuracy is missing (``None`` or float NaN) or its accuracy is not > 0.

    Args:
        results_list: Sequence of per-run dicts providing the keys ``model``,
            ``dataset``, ``status``, ``rmse`` and ``accuracy``. When the same
            (model, dataset) pair appears more than once, the last record wins.

    Returns:
        pandas.DataFrame indexed by model, one column per dataset, holding the
        formatted strings. Rows/columns follow the fixed report order and only
        include models/datasets that are present in the input. (model, dataset)
        pairs without any record are left as missing values by the pivot.
        An empty DataFrame is returned when ``results_list`` has no records.
    """
    df = pd.DataFrame(results_list)
    if df.empty:
        # Nothing to pivot; avoid a KeyError on the missing columns.
        return pd.DataFrame()

    # Re-running an experiment cell adds a second record for the same pair,
    # which would make DataFrame.pivot raise; keep only the most recent one.
    df = df.drop_duplicates(subset=['model', 'dataset'], keep='last').copy()

    # Type error rate in percent. pd.notna is required because missing values
    # arrive as float NaN (not None) once the records are in a DataFrame.
    # Accuracy values that are not > 0 are treated as "no usable accuracy",
    # as in the original implementation.
    df['type_error_rate'] = df['accuracy'].apply(
        lambda x: (1 - x) * 100 if pd.notna(x) and x > 0 else None
    )

    def format_cell(row):
        """Render one run as 'RMSE/TypeErr%' or 'NaN' when it has no usable metrics."""
        if row['status'] != 'success':
            return 'NaN'
        if pd.isna(row['rmse']) or pd.isna(row['type_error_rate']):
            return 'NaN'
        return f"{row['rmse']:.3f}/{row['type_error_rate']:.1f}%"

    df['formatted'] = df.apply(format_cell, axis=1)

    # One row per model, one column per dataset.
    pivot_df = df.pivot(
        index='model',
        columns='dataset',
        values='formatted'
    )

    # Fixed column order for the report; datasets that were not run are skipped.
    col_order = ['amazon', 'retweet', 'stackoverflow', 'taobao', 'taxi']
    pivot_df = pivot_df[[c for c in col_order if c in pivot_df.columns]]

    # Fixed row order for the report; models that were not run are skipped.
    row_order = ['RMTPP', 'NHP', 'FullyNN', 'SAHP', 'THP', 'IntensityFree', 'AttNHP']
    pivot_df = pivot_df.reindex([r for r in row_order if r in pivot_df.index])

    return pivot_df


# Build the report table, show it, write it to CSV under RESULTS_DIR, and
# print a markdown rendering. Skipped with a warning when no results exist.
if not results_list:
    print("⚠ No results to format")
else:
    final_df = format_final_results(results_list)

    print(
        "="*80,
        "FINAL RESULTS TABLE (Format: RMSE / Type Error Rate %)",
        "="*80,
        sep="\n",
    )
    display(final_df)

    # Persist the table next to the other Task 1 results
    final_df.to_csv(os.path.join(RESULTS_DIR, 'Task1_results_final.csv'))
    print(f"\n✓ Results saved to: {RESULTS_DIR}/Task1_results_final.csv")

    # Markdown rendering; the model index becomes a regular first column
    print(
        "\n" + "="*80,
        "MARKDOWN TABLE (for report)",
        "="*80,
        sep="\n",
    )
    final_df_reset = final_df.reset_index()
    print(final_df_reset.to_markdown(index=False))

FINAL RESULTS TABLE (Format: RMSE / Type Error Rate %)


dataset,amazon,retweet,stackoverflow,taobao,taxi
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
RMTPP,0.461/67.0%,25.648/44.7%,1.417/57.5%,0.267/56.4%,0.396/8.8%
NHP,0.520/70.1%,22.953/39.9%,1.417/55.4%,0.844/41.3%,0.449/8.5%
FullyNN,2.884/70.2%,19.457/46.0%,1.705/58.7%,1.987/56.4%,0.883/58.0%
SAHP,0.549/68.1%,21.796/40.8%,1.271/55.9%,0.232/46.3%,0.344/8.9%
THP,0.468/65.8%,25.993/40.7%,1.403/54.6%,0.342/44.0%,0.370/8.9%
IntensityFree,0.329/65.0%,17.133/39.9%,1.044/54.6%,0.193/39.0%,0.301/8.4%
AttNHP,8711.428/84.1%,19.787/40.5%,1.309/64.4%,0.132/56.4%,0.386/11.6%



✓ Results saved to: /content/drive/MyDrive/Colab Notebooks/MilestoneFall2025/results/Task1/Task1_results_final.csv

MARKDOWN TABLE (for report)
| model         | amazon         | retweet      | stackoverflow   | taobao      | taxi        |
|:--------------|:---------------|:-------------|:----------------|:------------|:------------|
| RMTPP         | 0.461/67.0%    | 25.648/44.7% | 1.417/57.5%     | 0.267/56.4% | 0.396/8.8%  |
| NHP           | 0.520/70.1%    | 22.953/39.9% | 1.417/55.4%     | 0.844/41.3% | 0.449/8.5%  |
| FullyNN       | 2.884/70.2%    | 19.457/46.0% | 1.705/58.7%     | 1.987/56.4% | 0.883/58.0% |
| SAHP          | 0.549/68.1%    | 21.796/40.8% | 1.271/55.9%     | 0.232/46.3% | 0.344/8.9%  |
| THP           | 0.468/65.8%    | 25.993/40.7% | 1.403/54.6%     | 0.342/44.0% | 0.370/8.9%  |
| IntensityFree | 0.329/65.0%    | 17.133/39.9% | 1.044/54.6%     | 0.193/39.0% | 0.301/8.4%  |
| AttNHP        | 8711.428/84.1% | 19.787/40.5% | 1.309/64.4%     | 0.132/56.4% | 0.386

In [None]:
# ==============================================================================
# EXPERIMENT SUMMARY STATISTICS
# ==============================================================================
# Prints overall status counts, the summed runtime, a per-model breakdown and,
# when any run failed, the error recorded for each failed run.

if results_list:
    df = pd.DataFrame(results_list)

    print("="*70, "EXPERIMENT SUMMARY", "="*70, sep="\n")

    # Status counts; anything that is neither 'success' nor 'failed' is "other"
    total = len(df)
    successful = int((df['status'] == 'success').sum())
    failed = int((df['status'] == 'failed').sum())
    other = total - (successful + failed)

    print(f"Total experiments: {total}")
    print(f"  ✓ Successful: {successful}")
    print(f"  ✗ Failed: {failed}")
    if other > 0:
        print(f"  ? Other: {other}")

    # Wall-clock total across all recorded runs (seconds -> hours)
    total_time = df['time_seconds'].sum()
    print(f"\nTotal runtime: {total_time/3600:.2f} hours")

    # Per-model breakdown, in report order; models with no records are skipped
    print("\n" + "-"*70, "Per-Model Summary:", "-"*70, sep="\n")

    for model in ['RMTPP', 'NHP', 'FullyNN', 'SAHP', 'THP', 'IntensityFree', 'AttNHP']:
        model_df = df.loc[df['model'] == model]
        if len(model_df) == 0:
            continue
        success = int((model_df['status'] == 'success').sum())
        model_time = model_df['time_seconds'].sum()
        print(f"  {model}: {success}/5 successful ({model_time/60:.1f} min)")

    # List every failed run together with its recorded error, if any
    if failed > 0:
        print("\n" + "!"*70, "FAILED EXPERIMENTS:", "!"*70, sep="\n")
        failed_rows = df.loc[df['status'] == 'failed']
        for _, row in failed_rows.iterrows():
            print(f"  - {row['model']} on {row['dataset']}: {row.get('error', 'Unknown error')}")

EXPERIMENT SUMMARY
Total experiments: 35
  ✓ Successful: 35
  ✗ Failed: 0

Total runtime: 3.46 hours

----------------------------------------------------------------------
Per-Model Summary:
----------------------------------------------------------------------
  RMTPP: 5/5 successful (6.6 min)
  NHP: 5/5 successful (51.5 min)
  FullyNN: 5/5 successful (62.5 min)
  SAHP: 5/5 successful (7.1 min)
  THP: 5/5 successful (7.1 min)
  IntensityFree: 5/5 successful (5.1 min)
  AttNHP: 5/5 successful (67.5 min)


## 8. Verification & Quality Check

In [None]:
# ==============================================================================
# VERIFICATION - Check all results are present and reasonable
# ==============================================================================

def verify_results(results_list):
    """Print sanity checks over the collected experiment records.

    Three checks are reported on stdout:
      1. Every expected (model, dataset) combination has at least one record.
      2. RMSE of successful runs is present, non-negative and not above 1000.
      3. Accuracy of successful runs, when present, lies in [0, 1].

    Args:
        results_list: Sequence of per-run dicts providing the keys ``model``,
            ``dataset``, ``status``, ``rmse`` and ``accuracy``.

    Returns:
        None. All findings are printed.
    """
    df = pd.DataFrame(results_list)

    print("="*70)
    print("VERIFICATION CHECKS")
    print("="*70)

    if df.empty:
        # No columns to inspect; report and stop instead of raising KeyError.
        print("⚠ No results to verify")
        print("\n" + "="*70)
        return

    issues = []

    # Check 1: All model-dataset combinations present
    expected_models = ['RMTPP', 'NHP', 'FullyNN', 'SAHP', 'THP', 'IntensityFree', 'AttNHP']
    expected_datasets = ['taxi', 'amazon', 'taobao', 'stackoverflow', 'retweet']

    for model in expected_models:
        for dataset in expected_datasets:
            mask = (df['model'] == model) & (df['dataset'] == dataset)
            if not mask.any():
                issues.append(f"Missing: {model} on {dataset}")

    if issues:
        print("\n⚠ Missing experiments:")
        for issue in issues:
            print(f"  - {issue}")
    else:
        expected_total = len(expected_models) * len(expected_datasets)
        print(f"✓ All {expected_total} model-dataset combinations present")

    # Check 2: RMSE values are reasonable
    print("\nRMSE Sanity Check:")
    success_df = df[df['status'] == 'success']

    # Track every kind of RMSE problem so the "all reasonable" line is only
    # printed when nothing at all was flagged (not just when nothing is > 1000).
    rmse_issues = False
    for _, row in success_df.iterrows():
        rmse = row['rmse']
        if pd.isna(rmse):
            # Missing values show up as NaN in the DataFrame, not as None.
            print(f"  ⚠ Missing RMSE: {row['model']}/{row['dataset']}")
            rmse_issues = True
        elif rmse > 1000:
            print(f"  ⚠ High RMSE: {row['model']}/{row['dataset']} = {rmse:.2f}")
            rmse_issues = True
        elif rmse < 0:
            print(f"  ✗ Negative RMSE: {row['model']}/{row['dataset']} = {rmse:.2f}")
            rmse_issues = True

    if not rmse_issues:
        print("  ✓ All RMSE values are reasonable (< 1000)")

    # Check 3: Accuracy values are in [0, 1]
    print("\nAccuracy Sanity Check:")
    acc_issues = False
    for _, row in success_df.iterrows():
        acc = row['accuracy']
        if pd.notna(acc) and (acc < 0 or acc > 1):
            print(f"  ⚠ Invalid accuracy: {row['model']}/{row['dataset']} = {acc:.4f}")
            acc_issues = True

    if not acc_issues:
        print("  ✓ All accuracy values in valid range [0, 1]")

    print("\n" + "="*70)

# Run the verification checks only when results_list is non-empty.
if results_list:
    verify_results(results_list)

VERIFICATION CHECKS
✓ All 35 model-dataset combinations present

RMSE Sanity Check:
  ⚠ High RMSE: AttNHP/amazon = 8711.43

Accuracy Sanity Check:
  ✓ All accuracy values in valid range [0, 1]



In [None]:
# ==============================================================================
# AUTO-DISCONNECT RUNTIME
# ==============================================================================
# Disconnects the Colab runtime to save compute units after experiments complete.
# Add this as the last cell in your notebook.
#
# Both disconnect methods are best-effort: a failure is reported and never
# raised, so running this cell outside Colab is harmless.
# ==============================================================================

import time

# Optional: Wait a few seconds to ensure all files are saved to Drive
print("Waiting 30 seconds to ensure all files are synced to Google Drive...")
time.sleep(30)

print("✓ All experiments complete. Disconnecting runtime to save compute units...")

disconnect_requested = False

# Method 1: Using Colab's runtime API (preferred)
try:
    from google.colab import runtime
    runtime.unassign()
    disconnect_requested = True
except Exception as exc:  # e.g. ImportError outside Colab; not a bare except, so Ctrl-C still works
    print(f"⚠ runtime.unassign() not available: {exc!r}")

# Method 2: JavaScript fallback (only attempted if Method 1 failed)
if not disconnect_requested:
    try:
        from IPython.display import Javascript
        display(Javascript('google.colab.kernel.disconnect()'))
        disconnect_requested = True
    except Exception as exc:
        print(f"⚠ JavaScript disconnect not available: {exc!r}")

# Only claim success when one of the two methods actually ran.
if disconnect_requested:
    print("Runtime disconnected.")
else:
    print("⚠ Could not disconnect the runtime automatically; please disconnect it manually.")

Waiting 30 seconds to ensure all files are synced to Google Drive...
