In [None]:
# Record the interpreter and PyTorch versions this notebook was run with.
import os
import sys

import torch

torch_version_message = "PyTorch has version {}".format(torch.__version__)
print(torch_version_message)
print(sys.version)

PyTorch has version 2.8.0+cu126
3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]


In [None]:
# Install torch geometric
if 'IS_GRADESCOPE_ENV' not in os.environ:
  # Clean uninstall first
  !pip uninstall -y torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib

  torch_version = str(torch.__version__)
  scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
  sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
  !pip install torch-scatter -f $scatter_src
  !pip install torch-sparse -f $sparse_src
  !pip install torch-geometric
  !pip install torch-geometric-temporal
  !pip install ogb

Found existing installation: torch-geometric 2.7.0
Uninstalling torch-geometric-2.7.0:
  Successfully uninstalled torch-geometric-2.7.0
Found existing installation: torch_sparse 0.6.18+pt28cu126
Uninstalling torch_sparse-0.6.18+pt28cu126:
  Successfully uninstalled torch_sparse-0.6.18+pt28cu126
Found existing installation: torch_scatter 2.1.2+pt28cu126
Uninstalling torch_scatter-2.1.2+pt28cu126:
  Successfully uninstalled torch_scatter-2.1.2+pt28cu126
[0mLooking in links: https://pytorch-geometric.com/whl/torch-2.8.0+cu126.html
Collecting torch-scatter
  Using cached https://data.pyg.org/whl/torch-2.8.0%2Bcu126/torch_scatter-2.1.2%2Bpt28cu126-cp312-cp312-linux_x86_64.whl (10.9 MB)
Installing collected packages: torch-scatter
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch-geometric-temporal 0.56.2 requires torch-geometric, which is not installed.
to

In [None]:
from google.colab import drive
import pandas as pd

# Mount Google Drive so files under /content/drive/MyDrive are readable from this runtime.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Synthetic earthquake catalog stored on the mounted Drive.
data_path = "/content/drive/MyDrive/CS 224W/synthetic_data.xlsx"
df = pd.read_excel(io=data_path)
df

Unnamed: 0,fault_radius,event_time,magnitude,rec_interval,latitude,longitude,depth
0,50.0,0.656601,2.008661,,20.0,-122,12
1,50.0,1.106016,2.008661,0.449415,20.0,-122,12
2,50.0,1.596545,2.008661,0.490529,20.0,-122,12
3,50.0,2.044598,2.008661,0.448053,20.0,-122,12
4,50.0,2.475660,2.008661,0.431062,20.0,-122,12
...,...,...,...,...,...,...,...
2179,300.0,13.411939,3.564964,2.492310,38.0,-122,12
2180,300.0,17.263929,3.564964,3.851990,38.0,-122,12
2181,300.0,20.529043,3.564964,3.265114,38.0,-122,12
2182,300.0,22.699272,3.564964,2.170229,38.0,-122,12


DATA LOADING FOR EMBEDDING TIME SERIES INTO NODE FEATURES

In [None]:
import os
import numpy as np
import pandas as pd

def haversine(lon1, lat1, lon2, lat2):
    """Compute great-circle distance (km) between two lat/lon points.

    All arguments are in degrees and may be scalars or NumPy arrays that
    broadcast against each other; the result has the broadcast shape.
    The Earth is modelled as a sphere of radius 6371 km.
    """
    R = 6371.0
    lon1, lat1, lon2, lat2 = map(np.radians, [lon1, lat1, lon2, lat2])
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat / 2.0) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2.0) ** 2
    # Floating-point rounding can push `a` slightly above 1 for (near-)antipodal
    # points, which would make arcsin(sqrt(a)) NaN; clamp to arcsin's domain.
    c = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
    return R * c

# NOTE! change this pointer to the appropriate location in the drive
# (same file as `data_path` in the earlier loading cell).
DATA_FILE: str = "/content/drive/MyDrive/CS 224W/synthetic_data.xlsx"
# Default threshold for build_nodes_and_adj: node pairs whose great-circle
# distance is <= this many km are connected.
DIST_THRESHOLD_KM: float = 50.0
# NOTE(review): appears unused — the next cell reassigns N_MONTHS = 10 before
# build_monthly_samples (whose n_months default reads it) is defined.
N_MONTHS: int = 6

def load_catalog(path):
    """Read the earthquake catalog spreadsheet at ``path``.

    Column names are stripped of surrounding whitespace and lower-cased.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if os.path.exists(path):
        catalog = pd.read_excel(path)
        catalog.columns = [name.strip().lower() for name in catalog.columns]
        return catalog
    raise FileNotFoundError(path)


def detect_columns(df):
    """Return the catalog's standard column names.

    The names are fixed for this dataset; ``df`` is accepted for interface
    compatibility but is not inspected.

    Returns:
        Tuple ``(time_col, lat_col, lon_col, mag_col, id_col)``.
    """
    return ("event_time", "latitude", "longitude", "magnitude", "fault_radius")


def build_nodes_and_adj(df, lat_col, lon_col, dist_thresh_km=DIST_THRESHOLD_KM, id_col="fault_radius"):
    """
    Build the adjacency matrix over nodes (fault patches).

    Each distinct value of ``id_col`` is one node, located at the mean latitude /
    longitude of its rows. Nodes are connected if their great-circle distance is
    <= ``dist_thresh_km``. Because a node's distance to itself is 0, every node
    with finite coordinates has a self-loop, so its degree is >= 1.

    Args:
        df: catalog DataFrame.
        lat_col, lon_col: latitude / longitude column names (degrees).
        dist_thresh_km: connection threshold in km.
        id_col: column identifying a node. Defaults to "fault_radius".

    Returns:
        (adj, adj_norm, N): the 0/1 adjacency as an (N, N) float array, its
        symmetric normalization D^{-1/2} A D^{-1/2}, and the node count N.
        Rows/columns follow the sorted order of the ``id_col`` values.
    """
    # A single groupby supplies both coordinates and N, so the node order and the
    # node count always agree (groupby sorts keys and drops NaN ids).
    centroids = df.groupby(id_col)[[lat_col, lon_col]].mean()
    N = len(centroids)
    lat_vals = centroids[lat_col].to_numpy()
    lon_vals = centroids[lon_col].to_numpy()

    # Pairwise distances in one broadcast call: (N, 1) against (1, N) -> (N, N).
    D = haversine(lon_vals[:, None], lat_vals[:, None], lon_vals[None, :], lat_vals[None, :])

    adj = (D <= dist_thresh_km).astype(float)

    deg = adj.sum(axis=1)
    deg_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
    adj_norm = deg_inv_sqrt @ adj @ deg_inv_sqrt

    return adj, adj_norm, N

if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so this runs whenever the cell runs:
    # reload the catalog with normalized column names, then build the node graph.
    df = load_catalog(DATA_FILE)
    catalog_columns = detect_columns(df)
    time_col, lat_col, lon_col, mag_col, id_col = catalog_columns
    adj, adj_norm, N = build_nodes_and_adj(df, lat_col, lon_col)

In [None]:
# History length (months) for build_monthly_samples. This reassigns the N_MONTHS set
# in the earlier data-loading cell and must run before the `def` below, because the
# n_months default is bound when the function is defined.
N_MONTHS = 10


def years_to_datetimes(years, ref_year=2000):
    """Map fractional years since ``ref_year`` onto month-start Timestamps.

    Each value ``y`` becomes January 1 of ``ref_year`` plus ``int(y * 12)`` months,
    so any fraction of a month is truncated toward zero.

    Returns:
        list of pd.Timestamp, one per input value.
    """
    start = pd.Timestamp(f"{int(ref_year)}-01-01")
    return [start + pd.DateOffset(months=int(y * 12)) for y in years]


def _monthly_event_grid(df, node_id_col, mag_col, node_list, months):
    """Return (counts, mags), each of shape (len(node_list), len(months)).

    counts[i, k] is the number of rows for node node_list[i] whose "month" equals
    months[k]; mags[i, k] is the mean of those rows' ``mag_col`` values. Cells with
    no rows stay 0.
    """
    node_pos = {nid: i for i, nid in enumerate(node_list)}
    counts = np.zeros((len(node_list), len(months)))
    mags = np.zeros((len(node_list), len(months)))
    # One grouped pass over the catalog instead of re-filtering it once per node.
    stats = df.groupby([node_id_col, "month"])[mag_col].agg(["size", "mean"])
    for (nid, month), n_events, mean_mag in stats.itertuples(name=None):
        if nid in node_pos and month in months:
            k = months.get_loc(month)
            counts[node_pos[nid], k] = n_events
            mags[node_pos[nid], k] = mean_mag
    return counts, mags


def _window_features(counts_row, mags_row, t, n_months):
    """Feature vector (length n_months + 3) for one node from months [t - n_months, t).

    Layout: a 0/1 "any event" flag per month (oldest first), the total event count,
    the mean of the monthly mean magnitudes (months without events contribute 0),
    and the months since the most recent event (1..n_months, or -1 if none).
    """
    hist = counts_row[t - n_months:t]
    months_since_last = next((b for b in range(1, n_months + 1) if counts_row[t - b] > 0), -1)
    return np.concatenate([
        (hist > 0).astype(float),
        [hist.sum(), mags_row[t - n_months:t].mean(), months_since_last],
    ])


def build_monthly_samples(df, time_col, mag_col, node_id_col="fault_radius", n_months=N_MONTHS):
    """Build per-month node feature matrices and "event next month" labels.

    Events are bucketed into calendar months. A numeric ``time_col`` is treated as
    fractional years since 2000 (see years_to_datetimes); anything else is parsed with
    ``pd.to_datetime(errors="coerce")``. ``df`` itself is not modified.

    Args:
        df: catalog DataFrame.
        time_col: event time column.
        mag_col: magnitude column.
        node_id_col: column identifying a node.
        n_months: length of the history window used for features.

    Returns:
        (X_months, y_months): lists with one entry per target month t, for every month
        that has a full n_months history, up to and including the last month. Each X
        entry has shape (num_nodes, n_months + 3) (see _window_features). Each y entry
        has shape (num_nodes,) and is 1.0 where the node has at least one event in
        month t, else 0.0. Node rows follow sorted(df[node_id_col].unique()).
    """
    df = df.copy()
    if np.issubdtype(df[time_col].dtype, np.number):
        df["time_dt"] = years_to_datetimes(df[time_col].values)
    else:
        df["time_dt"] = pd.to_datetime(df[time_col], errors="coerce")
    df["month"] = df["time_dt"].dt.to_period("M").dt.to_timestamp()
    months = pd.date_range(start=df["month"].min(), end=df["month"].max(), freq="MS")

    node_list = sorted(df[node_id_col].unique())
    counts, mags = _monthly_event_grid(df, node_id_col, mag_col, node_list, months)

    X_months, y_months = [], []
    # Features read months [t - n_months, t) and the label reads month t, so the last
    # month (index len(months) - 1) is a valid target and is included.
    for t in range(n_months, len(months)):
        X_t = np.zeros((len(node_list), n_months + 3))
        for i in range(len(node_list)):
            X_t[i] = _window_features(counts[i], mags[i], t, n_months)
        X_months.append(X_t)
        y_months.append((counts[:, t] > 0).astype(float))

    return X_months, y_months


In [None]:
X, y = build_monthly_samples(df, time_col="event_time", mag_col="magnitude")

# X: list with one (num_nodes, N_MONTHS + 3) array per target month. Columns are N_MONTHS
#    binary "any event" flags, then the number of events in the context window, the mean
#    monthly magnitude over the window, and months since the last event (-1 if none).
#    NOTE: we probably want to replace features with slip
# y: list with one (num_nodes,) array per target month; each node label is 1 or 0
#    depending on whether that node has an event in the target month.

In [None]:
TRAIN_SPLIT = .7
VAL_SPLIT = .2

# Chronological split over target months: the first TRAIN_SPLIT fraction trains, the next
# VAL_SPLIT fraction validates, and the remainder is the test set.
# TECHNICALLY THERE IS SOME DATA LEAKAGE BETWEEN SPLITS. will deal with this later
train_idx = int(len(X) * TRAIN_SPLIT)
val_idx = train_idx + int(len(X) * VAL_SPLIT)  # driven by VAL_SPLIT, not a separate literal
X_train, X_val, X_test = X[:train_idx], X[train_idx:val_idx], X[val_idx:]
y_train, y_val, y_test = y[:train_idx], y[train_idx:val_idx], y[val_idx:]

DATA LOADING FOR THE TEMPORAL SNAPSHOT GRAPH

In [None]:
import math
group_by_nodes = df.groupby("fault_radius")
num_nodes = len(group_by_nodes)
# Group keys in groupby's (sorted) order; row i of the label grid is fault_radii[i].
fault_radii = list(group_by_nodes.size().index)

# Number of monthly time steps: whole years up to the latest event. If the latest event
# lies exactly on a year boundary, its month index floor(t * 12) equals ceil(t) * 12, so
# also reserve room for that month.
latest_event_years = group_by_nodes["event_time"].max().max()
latest_time = max(math.ceil(latest_event_years) * 12, math.floor(latest_event_years * 12) + 1)

# node_to_event_labels[i, k] = 1 if fault patch fault_radii[i] has an event in month k.
node_to_event_labels = np.zeros((num_nodes, latest_time))
for i, (_, group_df) in enumerate(group_by_nodes):
  event_months = np.floor(group_df["event_time"].dropna().to_numpy() * 12).astype(int)
  node_to_event_labels[i, event_months] = 1

node_to_event_labels

array([[0., 0., 0., ..., 0., 0., 1.],
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 1., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [None]:
from torch_geometric.data import HeteroData

# One node per (month, fault patch), laid out time-major: global node k * num_nodes + i
# is fault patch fault_radii[i] at month k (x repeats the radii list once per month,
# y flattens the (num_nodes, latest_time) label grid column-major to the same order).
hetero_data = HeteroData()
hetero_data["earthquake_source"].x = torch.asarray(fault_radii * latest_time)
# Labels must be a tensor: PyG's DataLoader does not concatenate NumPy arrays, so a
# batch would otherwise carry a Python list of per-sample arrays (e.g. `y=[8]`).
hetero_data["earthquake_source"].y = torch.as_tensor(node_to_event_labels.flatten(order='F'), dtype=torch.float)
# Use num_nodes from the label-grid cell (not N from the adjacency cell) so all three
# attributes are sized from the same source.
hetero_data["earthquake_source"].t = torch.arange(latest_time).repeat_interleave(num_nodes)

# node_ids[k, i] is the global index of fault patch i at month k.
node_ids = torch.arange(latest_time * num_nodes).view(latest_time, num_nodes)

# Spatial edges: within each month, patch i -> patch i + 1 (adjacent in fault_radii order).
# NOTE(review): directed one way only; add reverse edges if messages should flow both ways.
edge_index_spatial = torch.stack([node_ids[:, :-1].reshape(-1), node_ids[:, 1:].reshape(-1)])
hetero_data['earthquake_source', 'spatial', 'earthquake_source'].edge_index = edge_index_spatial

# Temporal edges: each patch at month k -> the same patch at month k + 1 (forward in time).
edge_index_temporal = torch.stack([node_ids[:-1].reshape(-1), node_ids[1:].reshape(-1)])
hetero_data['earthquake_source', 'temporal', 'earthquake_source'].edge_index = edge_index_temporal

hetero_data

HeteroData(
  earthquake_source={
    x=[36000],
    y=[36000],
    t=[36000],
  },
  (earthquake_source, spatial, earthquake_source)={ edge_index=[2, 35640] },
  (earthquake_source, temporal, earthquake_source)={ edge_index=[2, 35900] }
)

In [None]:
from torch_geometric.utils import subgraph
def get_time_window_subgraph(start_time, context_length, data=None):
  """Extract the snapshot subgraph for one sliding time window.

  Keeps every node whose time step t satisfies
  start_time <= t <= start_time + context_length (context_length + 1 steps), plus
  every spatial/temporal edge between kept nodes. Edge indices are relabelled to
  positions within the kept node list, matching the order of x / y / t.

  Args:
    start_time: first time step of the window.
    context_length: the window spans context_length + 1 time steps.
    data: source HeteroData; defaults to the global `hetero_data`.

  Returns:
    HeteroData with earthquake_source.{x, y, t} restricted to the window, one
    edge_index per edge type, and a graph-level `t_predict` equal to
    start_time + context_length + 1 (the first step after the window). Callers
    should only request windows whose t_predict is an existing time step.
  """
  if data is None:
    data = hetero_data
  source = data["earthquake_source"]
  end_time = start_time + context_length + 1  # exclusive upper bound of the window
  node_idx = torch.nonzero((source.t >= start_time) & (source.t < end_time)).view(-1)

  subgraph_sample = HeteroData()
  subgraph_sample['earthquake_source'].x = source.x[node_idx]
  subgraph_sample['earthquake_source'].y = source.y[node_idx]
  subgraph_sample['earthquake_source'].t = source.t[node_idx]
  subgraph_sample['t_predict'] = end_time
  for edge_type in data.edge_types:
    # subgraph() returns (edge_index, edge_attr); there is no edge_attr here.
    # num_nodes is passed explicitly instead of being inferred from edge_index.max().
    edge_index, _ = subgraph(node_idx, data[edge_type].edge_index,
                             relabel_nodes=True, num_nodes=source.t.numel())
    subgraph_sample[edge_type].edge_index = edge_index
  return subgraph_sample

In [None]:
from torch_geometric.loader import DataLoader
CONTEXT_LENGTH = 6
# get_time_window_subgraph(s, CONTEXT_LENGTH) covers steps s..s+CONTEXT_LENGTH and sets
# t_predict = s + CONTEXT_LENGTH + 1. Keep only start times whose t_predict is an existing
# step (< latest_time); this also guarantees every window is complete.
num_windows = latest_time - CONTEXT_LENGTH - 1
all_samples = [get_time_window_subgraph(start_time, CONTEXT_LENGTH) for start_time in range(num_windows)]
TRAIN_INDEX_END = int(len(all_samples) * TRAIN_SPLIT)
VAL_INDEX_END = TRAIN_INDEX_END + int(len(all_samples) * VAL_SPLIT)

BATCH_SIZE = 8

# Only training batches are shuffled; evaluation order does not need to be random.
train_loader = DataLoader(all_samples[:TRAIN_INDEX_END], batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(all_samples[TRAIN_INDEX_END:VAL_INDEX_END], batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(all_samples[VAL_INDEX_END:], batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Inspect one collated batch instead of printing every batch in the loader.
batch = next(iter(train_loader))
batch

HeteroDataBatch(
  t_predict=[8],
  earthquake_source={
    x=[5600],
    y=[8],
    t=[5600],
    batch=[5600],
    ptr=[9],
  },
  (earthquake_source, spatial, earthquake_source)={ edge_index=[2, 5544] },
  (earthquake_source, temporal, earthquake_source)={ edge_index=[2, 4800] }
)
HeteroDataBatch(
  t_predict=[8],
  earthquake_source={
    x=[5600],
    y=[8],
    t=[5600],
    batch=[5600],
    ptr=[9],
  },
  (earthquake_source, spatial, earthquake_source)={ edge_index=[2, 5544] },
  (earthquake_source, temporal, earthquake_source)={ edge_index=[2, 4800] }
)
HeteroDataBatch(
  t_predict=[8],
  earthquake_source={
    x=[5600],
    y=[8],
    t=[5600],
    batch=[5600],
    ptr=[9],
  },
  (earthquake_source, spatial, earthquake_source)={ edge_index=[2, 5544] },
  (earthquake_source, temporal, earthquake_source)={ edge_index=[2, 4800] }
)
HeteroDataBatch(
  t_predict=[8],
  earthquake_source={
    x=[5600],
    y=[8],
    t=[5600],
    batch=[5600],
    ptr=[9],
  },
  (earthquake_