In [1]:
%matplotlib widget
import numpy as np
import pandas as pd
import math
import torch
import torch.utils as utils
import scipy.sparse as sp
from torch_geometric.utils import from_scipy_sparse_matrix
from collections import defaultdict

In [2]:
dataset = "pems-bay"
val_rate = 0.20
test_rate = 0
n_his = 12
n_pred = 12
batch_size = 32

end_of_initial_data_index = 29792
data_per_step = 140

dataset_path = f"../../data/{dataset}/vel.csv"
adj_path = f"../../data/{dataset}/adj.npz"

In [3]:
def load_data(len_train, len_val):
    # load dataset from .csv file
    vel = pd.read_csv(dataset_path)

    train = vel[: len_train] # if metr_la (len_train = 23991), get 23991 from dataset for training
    val = vel[len_train: len_train + len_val] # if metr_la(23991 & 5140), get validation dataset from 23991-29131
    test = vel[len_train + len_val:] # if metr_la, get test dataset from 29131 to the final value

    return train, val, test

In [4]:
def data_transform(data, n_his, n_pred):
    n_vertex = data.shape[1] # number of nodes
    len_record = len(data)
    num = len_record - n_his - n_pred # number of sequences

    if num <= 0:
        return None, None
    
    # Init a NumPy array with 0s, representing input data tensor with 4-dimensions
    # if default (for training): [23976 x 1 x 12 x 207] - is 1 number of node features(???)
    x = np.zeros([num, 1, n_his, n_vertex])
    # Init a NumPy array with 0s, representing target data tensor with 2-dimensions
    # if default (for training): [23976 x 207]
    y = np.zeros([num, n_vertex])
    
    # loop over each sequence and change values of 1st dimension for both x and y tensor
    for i in range(num):
        head = i # define start of each sequence
        tail = i + n_his # define end of each sequence
        # data[head: tail] - get data from head to tail (0-12, 1-13, etc.)
        # Example (1st iteration): head = 0, tail = 12, so extract data from 0:12
        # Reshape that 2D matrix into a 3D one with dimensins [1, 12, 207]
        x[i, :, :, :] = data[head: tail].reshape(1, n_his, n_vertex)
        # Change values of 1st dimensions
        y[i] = data[tail + n_pred - 1]

    return torch.Tensor(x), torch.Tensor(y)

In [5]:
def zscore_preprocess_2d_data(train, val, test):
    # dimensions for train, val and test
    num_time_sequence_train, num_nodes_train = train.shape
    num_time_sequence_val, num_nodes_val = val.shape
    num_time_sequence_test, num_nodes_test = test.shape

    # Shape train, val, and test from 2D (time_sequence, num_nodes) to 2D (time_sequence * num_nodes, 1)
    train_shaped = train.reshape(-1, 1)
    val_shaped = val.reshape(-1, 1)
    test_shaped = test.reshape(-1, 1)

    # Reshape train, val and test from 2D (time_sequence * num_nodes, 1) back to 2D (time_sequence, num_nodes)
    train = train_shaped.reshape(num_time_sequence_train, num_nodes_train)
    val = val_shaped.reshape(num_time_sequence_val, num_nodes_val)
    test = test_shaped.reshape(num_time_sequence_test, num_nodes_test)

    return train, val, test

In [6]:
def create_train_iter_for_online(self, epoch, x_train, y_train):
    if epoch == 0:
        inital_x_train = x_train[:self.end_of_initial_data_index]
        inital_y_train = y_train[:self.end_of_initial_data_index]
        train_data = utils.data.TensorDataset(inital_x_train, inital_y_train)
        train_iter = utils.data.DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True)

        return train_iter
    else:
        current_train = self.train_dataset[:self.end_of_initial_data_index + (self.data_per_step * (epoch - 1))]

        new_train_datastep = self.train_dataset[
            self.end_of_initial_data_index + (self.data_per_step * (epoch - 1)):
            self.end_of_initial_data_index + (self.data_per_step * (epoch))
        ]

        random_sample_size = (self.batch_size - 1) * self.data_per_step
        # Randomly sample indices from current_train
        if len(current_train) > random_sample_size:
            random_indices = np.random.choice(current_train.shape[0], random_sample_size, replace=False)
        else:
            # If current_train has fewer than the required samples, take all of it
            random_indices = current_train
        new_train_datastep = [self.train_dataset[idx] for idx in range(len(new_train_datastep))]

        new_x_train = x_train[
            self.end_of_initial_data_index + (self.data_per_step * (epoch - 1)):
            self.end_of_initial_data_index + (self.data_per_step * (epoch))
        ]
        sampled_x_train = x_train[random_indices, :]
        new_x_train = torch.cat((sampled_x_train, new_x_train), dim=0)
        new_y_train = y_train[
            self.end_of_initial_data_index + (self.data_per_step * (epoch - 1)):
            self.end_of_initial_data_index + (self.data_per_step * (epoch))
        ]
        sampled_y_train = y_train[random_indices, :]
        new_y_train = torch.cat((sampled_y_train, new_y_train), dim=0)

        train_data = utils.data.TensorDataset(new_x_train, new_y_train)
        train_iter = utils.data.DataLoader(dataset=train_data, batch_size=self.batch_size, shuffle=True)

        return train_iter

In [7]:
def detect_congestion_alpha_propagation(
    v,                 # speeds [T, N] (numpy array or torch.Tensor)
    alpha=0.5,
    vhat=None,         # per-sensor medians [N]; if None, computed from v
    use_nanmedian=True
):
    # to numpy
    try:
        import torch
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
    except ImportError:
        pass

    v = np.asarray(v, dtype=np.float32)
    if v.ndim == 1:
        v = v[:, None]
    elif v.ndim > 2:
        T = v.shape[0]
        v = v.reshape(T, -1)

    T, N = v.shape

    # per-sensor medians
    if vhat is None:
        vhat = (np.nanmedian(v, axis=0) if use_nanmedian else np.median(np.nan_to_num(v, nan=0.0), axis=0))
    else:
        vhat = np.asarray(vhat, dtype=np.float32)
        assert vhat.shape == (N,), f"vhat must be shape (N,), got {vhat.shape}"

    thr      = alpha * vhat            # [N]
    thr_row  = thr[None, :]            # [1, N]

    # Rule 1: local threshold
    base = (v < thr_row)               # [T, N]

    # Rule 2: spatial continuity (both neighbors under their own thresholds)
    spatial = np.zeros_like(base, dtype=bool)
    if N >= 3:
        left  = v[:, :-2] < thr_row[:, :-2]
        right = v[:,  2:] < thr_row[:,  2:]
        spatial[:, 1:-1] = left & right

    # Rule 3: temporal continuity (both time neighbors under the same sensor’s threshold)
    temporal = np.zeros_like(base, dtype=bool)
    if T >= 3:
        prev_t = v[:-2, :] < thr_row
        next_t = v[ 2:,  :] < thr_row
        temporal[1:-1, :] = prev_t & next_t

    congested = base | spatial | temporal
    return congested, vhat


In [8]:
def evaluate_congestion_alpha_with_oracle(
    data_iter,
    alpha=0.5,
    vhat=None,
    oracle_mode_type="perfect",  # "perfect" | "worst"
    perfect_noise_abs=3.06,
    side_margin=0.1,
):
    import numpy as np
    sum_abs_error = 0.0
    sum_sq_error  = 0.0
    sum_gt        = 0.0
    total_preds   = 0

    TP = FP = TN = FN = 0
    total_points = 0
    total_gt_cong = 0
    total_pred_cong = 0

    used_vhat = vhat

    for item in data_iter:
        y = item[-1] if isinstance(item, (tuple, list)) else item
        try:
            import torch
            if isinstance(y, torch.Tensor):
                y_np = y.detach().cpu().numpy()
            else:
                y_np = np.asarray(y)
        except ImportError:
            y_np = np.asarray(y)
        y_np = np.asarray(y_np, dtype=np.float32)

        if y_np.ndim == 1:
            y_np = y_np[:, None]
        elif y_np.ndim > 2:
            T = y_np.shape[0]
            y_np = y_np.reshape(T, -1)
        T, N = y_np.shape

        gt_mask, used_vhat = detect_congestion_alpha_propagation(y_np, alpha=alpha, vhat=used_vhat)

        thr      = alpha * used_vhat
        thr_row  = thr[None, :]
        thr_full = np.broadcast_to(thr_row, y_np.shape)

        if oracle_mode_type == "perfect":
            base_gt = (y_np < thr_full)
            noise = np.random.uniform(-perfect_noise_abs, perfect_noise_abs, size=y_np.shape).astype(np.float32)
            y_pred = y_np + noise

            if np.any(base_gt):
                target_lo = thr_full - side_margin
                y_pred[base_gt] = np.minimum(y_pred[base_gt], target_lo[base_gt])

            if np.any(~base_gt):
                target_hi = thr_full + side_margin
                y_pred[~base_gt] = np.maximum(y_pred[~base_gt], target_hi[~base_gt])

        elif oracle_mode_type == "worst":
            y_pred = y_np.copy()
            base_gt = (y_np < thr_full)
            if np.any(base_gt):
                target_hi = thr_full + side_margin
                y_pred[base_gt] = target_hi[base_gt]
        else:
            raise ValueError(f"Unknown oracle_mode_type: {oracle_mode_type}")

        d = np.abs(y_np - y_pred)
        sum_abs_error += float(np.sum(d))
        sum_sq_error  += float(np.sum(d ** 2))
        sum_gt        += float(np.sum(y_np))
        total_preds   += d.size

        pred_mask, _ = detect_congestion_alpha_propagation(y_pred, alpha=alpha, vhat=used_vhat)

        gt = gt_mask.reshape(-1)
        pr = pred_mask.reshape(-1)

        tp = int(np.sum(gt & pr))
        tn = int(np.sum(~gt & ~pr))
        fp = int(np.sum(~gt & pr))
        fn = int(np.sum(gt & ~pr))

        TP += tp; TN += tn; FP += fp; FN += fn
        total_points   += gt.size
        total_gt_cong  += int(np.sum(gt))
        total_pred_cong+= int(np.sum(pr))

    MAE   = (sum_abs_error / total_preds) if total_preds > 0 else 0.0
    RMSE  = (np.sqrt(sum_sq_error / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (sum_abs_error / sum_gt) if sum_gt != 0 else 0.0

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall    = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1        = 2*precision*recall/(precision+recall) if (precision+recall) > 0 else 0.0
    iou       = TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 0.0
    accuracy  = (TP + TN) / total_points if total_points > 0 else 0.0
    cong_rate = total_gt_cong / total_points if total_points > 0 else 0.0

    return dict(
        # composed
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE,
        TP=TP, FP=FP, TN=TN, FN=FN,
        precision=precision, recall=recall, f1=f1, iou=iou, accuracy=accuracy,
        total_points=total_points,
        gt_congested=total_gt_cong,
        pred_congested=total_pred_cong,
        congestion_rate=cong_rate,
        alpha=alpha,
        oracle_mode_type=oracle_mode_type,
        # raw aggregates for global composition
        _sum_abs_error=sum_abs_error,
        _sum_sq_error=sum_sq_error,
        _sum_gt=sum_gt,
        _total_preds=total_preds,
    ), used_vhat


In [9]:
def aggregate_congestion_oracle_metrics(stats_list):
    import numpy as np

    # Speed-error aggregates
    total_sum_abs = sum(d.get("_sum_abs_error", 0.0) for d in stats_list)
    total_sum_sq  = sum(d.get("_sum_sq_error",  0.0) for d in stats_list)
    total_sum_gt  = sum(d.get("_sum_gt",        0.0) for d in stats_list)
    total_preds   = sum(d.get("_total_preds",     0) for d in stats_list)

    MAE  = (total_sum_abs / total_preds) if total_preds > 0 else 0.0
    RMSE = (np.sqrt(total_sum_sq / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (total_sum_abs / total_sum_gt) if total_sum_gt != 0 else 0.0

    # Classification/confusion aggregates
    TP = sum(d["TP"] for d in stats_list)
    FP = sum(d["FP"] for d in stats_list)
    TN = sum(d["TN"] for d in stats_list)
    FN = sum(d["FN"] for d in stats_list)
    total_points = sum(d["total_points"] for d in stats_list)
    gt_cong = sum(d["gt_congested"] for d in stats_list)
    pred_cong = sum(d["pred_congested"] for d in stats_list)

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall    = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1        = 2*precision*recall/(precision+recall) if (precision+recall) > 0 else 0.0
    iou       = TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 0.0
    accuracy  = (TP + TN) / total_points if total_points > 0 else 0.0
    cong_rate = gt_cong / total_points if total_points > 0 else 0.0

    return dict(
        # speed errors (composed like SCSR)
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE,
        # detection metrics
        precision=precision, recall=recall, f1=f1, iou=iou, accuracy=accuracy,
        congestion_rate=cong_rate,
        # bookkeeping
        total_points=total_points, gt_congested=gt_cong, pred_congested=pred_cong,
        total_preds=total_preds,
    )


In [10]:
def evaluate_cloudlet_pyg_new_metric_analysis(
    data_iter,
    big_err_threshold=20.0,
    change_window=12,
    change_delta=20.0,
    change_tolerance=10.0,
    cooldown=None,
    oracle_mode_type="perfect",
):
    if cooldown is None:
        cooldown = max(1, change_window // 2)

    # Aggregates (so we can compose across many calls)
    sum_abs_error = 0.0
    sum_sq_error  = 0.0
    sum_gt        = 0.0
    total_preds   = 0
    big_err_count = 0

    total_jam_events = 0
    total_jam_hits   = 0
    total_rec_events = 0
    total_rec_hits   = 0

    PERFECT_NOISE_ABS = min(3.06, change_tolerance * 0.49)  # always within tolerance

    with torch.no_grad():
        for item in data_iter:
            # accept (x, y) or just y
            y = item[-1] if isinstance(item, (tuple, list)) else item
            y_np = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else np.asarray(y)
            y_np = np.asarray(y_np, dtype=np.float32)

            # coerce to [T, N]
            if y_np.ndim == 1:
                y_np = y_np[:, None]
            elif y_np.ndim > 2:
                T = y_np.shape[0]
                y_np = y_np.reshape(T, -1)

            T, N = y_np.shape

            # ---- Build predictions
            if oracle_mode_type == "perfect":
                # small noise strictly within tolerance -> always "hit" at event endpoints
                noise = np.random.uniform(-PERFECT_NOISE_ABS, PERFECT_NOISE_ABS, size=y_np.shape).astype(np.float32)
                y_pred_np = y_np + noise

            elif oracle_mode_type == "worst":
                y_pred_np = y_np.copy()
                cool = np.zeros(N, dtype=int)
                for t in range(1, T):
                    cool = np.maximum(0, cool - 1)
                    w_start = max(0, t - change_window)
                    past = y_np[w_start:t, :]
                    if past.size == 0:
                        continue
                    cur = y_np[t, :]
                    jam_best = np.max(past - cur[None, :], axis=0)
                    rec_best = np.max(cur[None, :] - past, axis=0)
                    jam_mask = (jam_best >= change_delta) & (cool == 0)
                    rec_mask = (rec_best >= change_delta) & (cool == 0)
                    if not (jam_mask.any() or rec_mask.any()):
                        continue
                    idx_all = np.where(jam_mask | rec_mask)[0]
                    if idx_all.size > 0:
                        bump = change_tolerance + 1.0
                        cur_sel = cur[idx_all]
                        push_down = cur_sel > np.median(cur_sel)
                        y_pred_np[t, idx_all[push_down]]  = cur_sel[push_down]  - bump
                        y_pred_np[t, idx_all[~push_down]] = cur_sel[~push_down] + bump
                        cool[idx_all] = np.maximum(cool[idx_all], cooldown)
            else:
                raise ValueError(f"Unknown oracle_mode_type: {oracle_mode_type}")

            # ---- Base metrics
            d = np.abs(y_np - y_pred_np)
            sum_abs_error += float(np.sum(d))
            sum_sq_error  += float(np.sum(d ** 2))
            sum_gt        += float(np.sum(y_np))
            big_err_count += int(np.sum(d >= big_err_threshold))
            total_preds   += d.size

            # ---- Event detection + hits
            cool = np.zeros(N, dtype=int)
            for t in range(1, T):
                cool = np.maximum(0, cool - 1)
                w_start = max(0, t - change_window)
                past = y_np[w_start:t, :]
                if past.size == 0:
                    continue

                cur = y_np[t, :]
                pred_cur = y_pred_np[t, :]

                jam_best = np.max(past - cur[None, :], axis=0)
                rec_best = np.max(cur[None, :] - past, axis=0)

                jam_mask = (jam_best >= change_delta) & (cool == 0)
                rec_mask = (rec_best >= change_delta) & (cool == 0)

                if not (jam_mask.any() or rec_mask.any()):
                    continue

                abs_err = np.abs(pred_cur - cur)

                if jam_mask.any():
                    idx = np.where(jam_mask)[0]
                    total_jam_events += idx.size
                    total_jam_hits   += int(np.sum(abs_err[idx] <= change_tolerance))
                    cool[idx] = np.maximum(cool[idx], cooldown)

                if rec_mask.any():
                    idx = np.where(rec_mask)[0]
                    total_rec_events += idx.size
                    total_rec_hits   += int(np.sum(abs_err[idx] <= change_tolerance))
                    cool[idx] = np.maximum(cool[idx], cooldown)

    # Derived metrics for this call
    MAE  = (sum_abs_error / total_preds) if total_preds > 0 else 0.0
    RMSE = (np.sqrt(sum_sq_error / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (sum_abs_error / sum_gt) if sum_gt != 0 else 0.0
    BIG_ERR_RATE = (big_err_count / total_preds) if total_preds > 0 else 0.0

    sudden_events = total_jam_events + total_rec_events
    sudden_hits   = total_jam_hits   + total_rec_hits
    SUDDEN_EVENT_RATE = (sudden_hits / sudden_events) if sudden_events > 0 else 0.0
    JAM_EVENT_RATE    = (total_jam_hits / total_jam_events) if total_jam_events > 0 else 0.0
    REC_EVENT_RATE    = (total_rec_hits / total_rec_events) if total_rec_events > 0 else 0.0

    # Return both the composed metrics and the raw aggregates (to combine across epochs)
    return dict(
        # composed
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE,
        big_err_count=big_err_count, big_err_rate=BIG_ERR_RATE,
        sudden_event_count=sudden_events, sudden_event_hits=sudden_hits, sudden_event_rate=SUDDEN_EVENT_RATE,
        jam_event_count=total_jam_events, jam_event_hits=total_jam_hits, jam_event_rate=JAM_EVENT_RATE,
        rec_event_count=total_rec_events, rec_event_hits=total_rec_hits, rec_event_rate=REC_EVENT_RATE,
        # aggregates
        _sum_abs_error=sum_abs_error,
        _sum_sq_error=sum_sq_error,
        _sum_gt=sum_gt,
        _total_preds=total_preds,
    )

In [11]:
def aggregate_final_metrics(aggregate_list):
    # aggregate_list is a list of dicts returned by evaluate_cloudlet_pyg_new_metric_analysis_v2
    total_sum_abs = sum(d["_sum_abs_error"] for d in aggregate_list)
    total_sum_sq  = sum(d["_sum_sq_error"]  for d in aggregate_list)
    total_sum_gt  = sum(d["_sum_gt"]        for d in aggregate_list)
    total_preds   = sum(d["_total_preds"]   for d in aggregate_list)

    total_jam_events = sum(d["jam_event_count"] for d in aggregate_list)
    total_jam_hits   = sum(d["jam_event_hits"]  for d in aggregate_list)
    total_rec_events = sum(d["rec_event_count"] for d in aggregate_list)
    total_rec_hits   = sum(d["rec_event_hits"]  for d in aggregate_list)

    big_err_count    = sum(d["big_err_count"]   for d in aggregate_list)

    MAE  = (total_sum_abs / total_preds) if total_preds > 0 else 0.0
    RMSE = (np.sqrt(total_sum_sq / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (total_sum_abs / total_sum_gt) if total_sum_gt != 0 else 0.0
    BIG_ERR_RATE = (big_err_count / total_preds) if total_preds > 0 else 0.0

    sudden_events = total_jam_events + total_rec_events
    sudden_hits   = total_jam_hits   + total_rec_hits
    SUDDEN_EVENT_RATE = (sudden_hits / sudden_events) if sudden_events > 0 else 0.0
    JAM_EVENT_RATE    = (total_jam_hits / total_jam_events) if total_jam_events > 0 else 0.0
    REC_EVENT_RATE    = (total_rec_hits / total_rec_events) if total_rec_events > 0 else 0.0

    return dict(
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE, BIG_ERR_RATE=BIG_ERR_RATE,
        SUDDEN_EVENT_RATE=SUDDEN_EVENT_RATE,
        JAM_EVENT_RATE=JAM_EVENT_RATE,
        REC_EVENT_RATE=REC_EVENT_RATE,
        total_preds=total_preds, sudden_events=sudden_events
    )

In [18]:
def detect_congestion_alpha_propagation_combined(
        v,                      # [T, N] speeds (np or torch)
        edge_index,             # [2, E] torch.LongTensor or np.array of edges (0-based)
        alpha=0.5,
        vhat=None,
        use_nanmedian=True,
        require_both_when_deg2=True,   # if deg(i)==2, require both neighbors; else use ">=2 neighbors" rule
    ):
    # to numpy
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().numpy()
    v = np.asarray(v, dtype=np.float32)
    if v.ndim == 1:
        v = v[:, None]
    elif v.ndim > 2:
        v = v.reshape(v.shape[0], -1)
    T, N = v.shape

    # medians
    if vhat is None:
        vhat = (np.nanmedian(v, axis=0) if use_nanmedian else np.median(np.nan_to_num(v, nan=0.0), axis=0))
    else:
        vhat = np.asarray(vhat, dtype=np.float32)
        assert vhat.shape == (N,), f"vhat must be (N,), got {vhat.shape}"

    thr = alpha * vhat
    thr_full = thr[None, :].repeat(T, axis=0)
    base = (v < thr_full)  # Rule 1

    # build undirected adjacency lists
    if isinstance(edge_index, torch.Tensor):
        ei = edge_index.detach().cpu().numpy()
    else:
        ei = np.asarray(edge_index)
    assert ei.shape[0] == 2, "edge_index must be shape [2, E]"
    nbrs = defaultdict(list)
    for u, w in ei.T:
        nbrs[int(u)].append(int(w))
        nbrs[int(w)].append(int(u))  # undirected for “immediate neighbors”

    # Rule 2 graph-aware spatial continuity
    spatial = np.zeros_like(base, dtype=bool)
    for i in range(N):
        neigh = nbrs.get(i, [])
        deg = len(neigh)
        if deg < 2:
            continue
        # congested neighbors at time t
        # base[:, neigh] shape [T, deg] -> count how many neighbors are below threshold
        cong_neigh_count = np.sum(base[:, neigh], axis=1)  # [T]

        if require_both_when_deg2 and deg == 2:
            spatial[:, i] = (cong_neigh_count == 2)
        else:
            spatial[:, i] = (cong_neigh_count >= 2)

    # Rule 3 temporal continuity (same as before)
    temporal = np.zeros_like(base, dtype=bool)
    if T >= 3:
        prev_t = base[:-2, :]
        next_t = base[ 2:, :]
        temporal[1:-1, :] = prev_t & next_t

    congested = base | spatial | temporal
    return congested, vhat

def evaluate_oracle_scsr_and_alpha_combined(
    data_iter,
    edge_index,
    *,
    # SCSR params
    big_err_threshold=20.0,
    change_window=12,
    change_delta=20.0,
    change_tolerance=10.0,
    cooldown=None,
    # Oracle choice for BOTH SCSR + α-propagation
    oracle_mode_type="perfect",  # "perfect" | "worst"
    perfect_noise_abs=None,      # default will be min(3.06, change_tolerance*0.49)
    # α-propagation params
    alpha=0.5,
    vhat=None,                   # pass persistent medians
):
    """
    Builds y_pred using **SCSR oracle** logic. Then:
      - computes base errors & SCSR event hits,
      - runs α-based congestion on GT vs y_pred, and returns confusion metrics too.

    Returns (metrics_dict, used_vhat).
    metrics_dict includes BOTH SCSR & α metrics and raw aggregates for global composition.
    """
    if cooldown is None:
        cooldown = max(1, change_window // 2)

    if perfect_noise_abs is None:
        perfect_noise_abs = min(3.06, change_tolerance * 0.49)  # ensures SCSR hits at endpoints

    # ---------- global aggregates ----------
    # base errors
    sum_abs_error = 0.0
    sum_sq_error  = 0.0
    sum_gt        = 0.0
    total_preds   = 0
    big_err_count = 0

    # SCSR events
    total_jam_events = 0
    total_jam_hits   = 0
    total_rec_events = 0
    total_rec_hits   = 0

    # α-propagation confusion
    TP = FP = TN = FN = 0
    total_points = 0
    total_gt_cong = 0
    total_pred_cong = 0

    used_vhat = vhat

    with torch.no_grad():
        for item in data_iter:
            # accept (x, y) or just y
            y = item[-1] if isinstance(item, (tuple, list)) else item
            y_np = y.detach().cpu().numpy() if isinstance(y, torch.Tensor) else np.asarray(y)
            y_np = np.asarray(y_np, dtype=np.float32)

            # shape -> [T, N]
            if y_np.ndim == 1:
                y_np = y_np[:, None]
            elif y_np.ndim > 2:
                T = y_np.shape[0]
                y_np = y_np.reshape(T, -1)

            T, N = y_np.shape

            # ---------- Build y_pred via SCSR oracle ----------
            if oracle_mode_type == "perfect":
                # Add small noise, but keep always within change_tolerance around GT so SCSR hits are guaranteed.
                noise = np.random.uniform(-perfect_noise_abs, perfect_noise_abs, size=y_np.shape).astype(np.float32)
                y_pred_np = y_np + noise

            elif oracle_mode_type == "worst":
                # Copy GT, but at SCSR event endpoints force error > change_tolerance
                y_pred_np = y_np.copy()
                cool = np.zeros(N, dtype=int)
                for t in range(1, T):
                    cool = np.maximum(0, cool - 1)
                    w_start = max(0, t - change_window)
                    past = y_np[w_start:t, :]
                    if past.size == 0:
                        continue
                    cur = y_np[t, :]
                    jam_best = np.max(past - cur[None, :], axis=0)
                    rec_best = np.max(cur[None, :] - past, axis=0)
                    jam_mask = (jam_best >= change_delta) & (cool == 0)
                    rec_mask = (rec_best >= change_delta) & (cool == 0)
                    if not (jam_mask.any() or rec_mask.any()):
                        continue
                    idx_all = np.where(jam_mask | rec_mask)[0]
                    if idx_all.size > 0:
                        bump = change_tolerance + 1.0
                        cur_sel = cur[idx_all]
                        push_down = cur_sel > np.median(cur_sel)
                        y_pred_np[t, idx_all[push_down]]  = cur_sel[push_down]  - bump
                        y_pred_np[t, idx_all[~push_down]] = cur_sel[~push_down] + bump
                        cool[idx_all] = np.maximum(cool[idx_all], cooldown)
            else:
                raise ValueError(f"Unknown oracle_mode_type: {oracle_mode_type}")

            # ---------- Base errors ----------
            d = np.abs(y_np - y_pred_np)
            sum_abs_error += float(np.sum(d))
            sum_sq_error  += float(np.sum(d ** 2))
            sum_gt        += float(np.sum(y_np))
            big_err_count += int(np.sum(d >= big_err_threshold))
            total_preds   += d.size

            # ---------- SCSR event detection + hits ----------
            cool = np.zeros(N, dtype=int)
            for t in range(1, T):
                cool = np.maximum(0, cool - 1)
                w_start = max(0, t - change_window)
                past = y_np[w_start:t, :]
                if past.size == 0:
                    continue

                cur = y_np[t, :]
                pred_cur = y_pred_np[t, :]

                jam_best = np.max(past - cur[None, :], axis=0)
                rec_best = np.max(cur[None, :] - past, axis=0)

                jam_mask = (jam_best >= change_delta) & (cool == 0)
                rec_mask = (rec_best >= change_delta) & (cool == 0)
                if not (jam_mask.any() or rec_mask.any()):
                    continue

                abs_err = np.abs(pred_cur - cur)

                if jam_mask.any():
                    idx = np.where(jam_mask)[0]
                    total_jam_events += idx.size
                    total_jam_hits   += int(np.sum(abs_err[idx] <= change_tolerance))
                    cool[idx] = np.maximum(cool[idx], cooldown)

                if rec_mask.any():
                    idx = np.where(rec_mask)[0]
                    total_rec_events += idx.size
                    total_rec_hits   += int(np.sum(abs_err[idx] <= change_tolerance))
                    cool[idx] = np.maximum(cool[idx], cooldown)

            # ---------- α-based congestion: GT vs y_pred ----------
            gt_mask, used_vhat = detect_congestion_alpha_propagation_combined(y_np, edge_index, alpha=alpha, vhat=used_vhat)
            pr_mask, _         = detect_congestion_alpha_propagation_combined(y_pred_np, edge_index, alpha=alpha, vhat=used_vhat)

            gt = gt_mask.reshape(-1)
            pr = pr_mask.reshape(-1)

            tp = int(np.sum(gt & pr))
            tn = int(np.sum(~gt & ~pr))
            fp = int(np.sum(~gt & pr))
            fn = int(np.sum(gt & ~pr))

            TP += tp; TN += tn; FP += fp; FN += fn
            total_points   += gt.size
            total_gt_cong  += int(np.sum(gt))
            total_pred_cong+= int(np.sum(pr))

    # ---------- composed metrics ----------
    # Base errors
    MAE   = (sum_abs_error / total_preds) if total_preds > 0 else 0.0
    RMSE  = (np.sqrt(sum_sq_error / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (sum_abs_error / sum_gt) if sum_gt != 0 else 0.0
    BIG_ERR_RATE = (big_err_count / total_preds) if total_preds > 0 else 0.0

    # SCSR
    sudden_events = total_jam_events + total_rec_events
    sudden_hits   = total_jam_hits   + total_rec_hits
    SUDDEN_EVENT_RATE = (sudden_hits / sudden_events) if sudden_events > 0 else 0.0
    JAM_EVENT_RATE    = (total_jam_hits / total_jam_events) if total_jam_events > 0 else 0.0
    REC_EVENT_RATE    = (total_rec_hits / total_rec_events) if total_rec_events > 0 else 0.0

    # α-propagation detection metrics
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall    = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1        = 2*precision*recall/(precision+recall) if (precision+recall) > 0 else 0.0
    iou       = TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 0.0
    accuracy  = (TP + TN) / total_points if total_points > 0 else 0.0
    cong_rate = total_gt_cong / total_points if total_points > 0 else 0.0

    return dict(
        # ---- base errors (composed)
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE, BIG_ERR_RATE=BIG_ERR_RATE,

        # ---- SCSR composed
        sudden_event_count=sudden_events,
        sudden_event_hits=sudden_hits,
        sudden_event_rate=SUDDEN_EVENT_RATE,
        jam_event_count=total_jam_events, jam_event_hits=total_jam_hits, jam_event_rate=JAM_EVENT_RATE,
        rec_event_count=total_rec_events, rec_event_hits=total_rec_hits, rec_event_rate=REC_EVENT_RATE,

        # ---- α-propagation composed
        TP=TP, FP=FP, TN=TN, FN=FN,
        precision=precision, recall=recall, f1=f1, iou=iou, accuracy=accuracy,
        gt_congested=total_gt_cong, pred_congested=total_pred_cong,
        congestion_rate=cong_rate,
        alpha=alpha,
        oracle_mode_type=oracle_mode_type,

        # ---- raw aggregates for global composition
        _sum_abs_error=sum_abs_error,
        _sum_sq_error=sum_sq_error,
        _sum_gt=sum_gt,
        _total_preds=total_preds,
    ), used_vhat

In [13]:
def aggregate_combined(stats_list):
    import numpy as np

    # base errors
    total_sum_abs = sum(d["_sum_abs_error"] for d in stats_list)
    total_sum_sq  = sum(d["_sum_sq_error"]  for d in stats_list)
    total_sum_gt  = sum(d["_sum_gt"]        for d in stats_list)
    total_preds   = sum(d["_total_preds"]   for d in stats_list)

    MAE   = (total_sum_abs / total_preds) if total_preds > 0 else 0.0
    RMSE  = (np.sqrt(total_sum_sq / total_preds) if total_preds > 0 else 0.0)
    WMAPE = (total_sum_abs / total_sum_gt) if total_sum_gt != 0 else 0.0

    # SCSR
    jam_events = sum(d["jam_event_count"] for d in stats_list)
    jam_hits   = sum(d["jam_event_hits"]  for d in stats_list)
    rec_events = sum(d["rec_event_count"] for d in stats_list)
    rec_hits   = sum(d["rec_event_hits"]  for d in stats_list)
    sudden_events = jam_events + rec_events
    sudden_hits   = jam_hits   + rec_hits

    SUDDEN_EVENT_RATE = (sudden_hits / sudden_events) if sudden_events > 0 else 0.0
    JAM_EVENT_RATE    = (jam_hits / jam_events) if jam_events > 0 else 0.0
    REC_EVENT_RATE    = (rec_hits / rec_events) if rec_events > 0 else 0.0

    # α-propagation confusion
    TP = sum(d["TP"] for d in stats_list)
    FP = sum(d["FP"] for d in stats_list)
    TN = sum(d["TN"] for d in stats_list)
    FN = sum(d["FN"] for d in stats_list)
    total_points = TP + FP + TN + FN
    gt_cong = sum(d["gt_congested"] for d in stats_list)
    pred_cong = sum(d["pred_congested"] for d in stats_list)

    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall    = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1        = 2*precision*recall/(precision+recall) if (precision+recall) > 0 else 0.0
    iou       = TP / (TP + FP + FN) if (TP + FP + FN) > 0 else 0.0
    accuracy  = (TP + TN) / total_points if total_points > 0 else 0.0
    cong_rate = gt_cong / total_points if total_points > 0 else 0.0

    return dict(
        # base errors
        MAE=MAE, RMSE=RMSE, WMAPE=WMAPE,
        total_preds=total_preds,

        # SCSR
        SUDDEN_EVENT_RATE=SUDDEN_EVENT_RATE,
        JAM_EVENT_RATE=JAM_EVENT_RATE,
        REC_EVENT_RATE=REC_EVENT_RATE,
        sudden_events=sudden_events,

        # α-propagation
        precision=precision, recall=recall, f1=f1, iou=iou, accuracy=accuracy,
        congestion_rate=cong_rate,
        gt_congested=gt_cong, pred_congested=pred_cong,
        total_points=total_points
    )


In [15]:
def online_train(epoch, edge_index, x_train, y_train, val_iter, num_epochs, oracle_mode_type, agg, used_vhat, agg_stats, agg_combined):
    if epoch < num_epochs:
        # Create masked validation dataset
        new_x_val = x_train[
            end_of_initial_data_index + (data_per_step * (epoch)):
            end_of_initial_data_index + (data_per_step * (epoch + 1))
            ]
        new_y_val = y_train[
            end_of_initial_data_index + (data_per_step * (epoch)):
            end_of_initial_data_index + (data_per_step * (epoch + 1))
        ]
        new_val_data = utils.data.TensorDataset(new_x_val, new_y_val)
        new_val_iter = utils.data.DataLoader(dataset=new_val_data, batch_size=batch_size, shuffle=False)

        # res = evaluate_cloudlet_pyg_new_metric_analysis(
        #     new_val_iter, oracle_mode_type=oracle_mode_type
        # )
        # stats_cong, used_vhat = evaluate_congestion_alpha_with_oracle(
        #     new_val_iter,
        #     alpha=0.5, vhat=used_vhat,
        #     oracle_mode_type=oracle_mode_type,
        # )
        res_combined, used_vhat = evaluate_oracle_scsr_and_alpha_combined(
            new_val_iter,
            edge_index=edge_index,
            oracle_mode_type=oracle_mode_type,   # or "perfect"
            change_window=12, change_delta=20.0, change_tolerance=10.0,
            alpha=0.5, vhat=used_vhat
        )

        # agg_stats.append(stats_cong)
        # agg.append(res)
        agg_combined.append(res_combined)
    elif epoch == num_epochs:
        res_combined, used_vhat = evaluate_oracle_scsr_and_alpha_combined(
            val_iter,
            edge_index=edge_index,
            oracle_mode_type=oracle_mode_type,   # or "perfect"
            change_window=12, change_delta=20.0, change_tolerance=10.0,
            alpha=0.5, vhat=used_vhat
        )
        agg_combined.append(res_combined)
        # res = evaluate_cloudlet_pyg_new_metric_analysis(
        #     val_iter, oracle_mode_type=oracle_mode_type
        # )
        # stats_cong, used_vhat = evaluate_congestion_alpha_with_oracle(
        #     val_iter,
        #     alpha=0.5, vhat=used_vhat,
        #     oracle_mode_type=oracle_mode_type,
        # )

        # agg_stats.append(stats_cong)
        # agg.append(res)

In [16]:
adj = sp.load_npz(adj_path)
adj = adj.tocsc()
edge_index, _ = from_scipy_sparse_matrix(adj)

data_col = pd.read_csv(dataset_path).shape[0]

len_val = int(math.floor(data_col * val_rate))
len_test = int(math.floor(data_col * test_rate))
len_train = int(data_col - len_val - len_test)

if ((len_train - end_of_initial_data_index) % data_per_step == 0):
    print("End index and data step correctly selected!")
else:
    print(f"End index and data step WRONGLY selected: {(len_train - end_of_initial_data_index) % data_per_step}")
    exit(1)

train, val, test = load_data(len_train, len_val)
train, val, test = zscore_preprocess_2d_data(train.values, val.values, test.values)

x_train, y_train = data_transform(train, n_his, n_pred)
x_val, y_val = data_transform(val, n_his, n_pred)
# x_test, y_test = data_transform(test, n_his, n_pred)

train_data = utils.data.TensorDataset(x_train, y_train)
train_iter = utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_data = utils.data.TensorDataset(x_val, y_val)
val_iter = utils.data.DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)
# test_data = utils.data.TensorDataset(x_test, y_test)
# test_iter = utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

End index and data step correctly selected!


In [19]:
agg = []
agg_stats = []
agg_combined = []
used_vhat = None
num_epochs = int((len_train - end_of_initial_data_index) / data_per_step)

# for epoch in range(0, num_epochs + 1): 
#     online_train(epoch, x_train, y_train, val_iter, num_epochs, oracle_mode_type="perfect", agg=agg)

for epoch in range(0, num_epochs + 1): 
    online_train(epoch, edge_index, x_train, y_train, val_iter, num_epochs, oracle_mode_type="worst", agg=agg, used_vhat=used_vhat, agg_stats=agg_stats, agg_combined=agg_combined)

final_metrics = aggregate_final_metrics(agg)
print("FINAL MAE:", final_metrics["MAE"])
print("FINAL RMSE:", final_metrics["RMSE"])
print("FINAL WMAPE:", final_metrics["WMAPE"])
print("FINAL SUDDEN_EVENT_RATE:", final_metrics["SUDDEN_EVENT_RATE"])

final = aggregate_congestion_oracle_metrics(agg_stats)
print("\n======== α-based Congestion — FINAL ========")
print(f"MAE:   {final['MAE']:.6f}")
print(f"RMSE:  {final['RMSE']:.6f}")
print(f"WMAPE: {final['WMAPE']:.6f}")
print(f"Recall (hit rate): {final['recall']:.6f}")
print(f"Precision:         {final['precision']:.6f}")
print(f"F1:                {final['f1']:.6f}")
print(f"IoU:               {final['iou']:.6f}")
print(f"Accuracy:          {final['accuracy']:.6f}")
print(f"GT congestion rate:{final['congestion_rate']:.6f}")
print(f"Total points:      {final['total_points']}")
print(f"Total preds:       {final['total_preds']}")

final_combined = aggregate_combined(agg_combined)
print("=== Combined (shared y_pred) ===")
print(f"MAE={final_combined['MAE']:.6f}  RMSE={final_combined['RMSE']:.6f}  WMAPE={final_combined['WMAPE']:.6f}")
print(f"SCSR:  sudden_event_rate={final_combined['SUDDEN_EVENT_RATE']:.6f}  "
      f"jam_rate={final_combined['JAM_EVENT_RATE']:.6f}  rec_rate={final_combined['REC_EVENT_RATE']:.6f}")
print(f"ALPHA: recall={final_combined['recall']:.6f} precision={final_combined['precision']:.6f} "
      f"F1={final_combined['f1']:.6f} IoU={final_combined['iou']:.6f} accuracy={final_combined['accuracy']:.6f}")
print(f"ALPHA: gt_congestion_rate={final_combined['congestion_rate']:.6f}  "
      f"gt_cong={final_combined['gt_congested']}  pred_cong={final_combined['pred_congested']}")


FINAL MAE: 0.0
FINAL RMSE: 0.0
FINAL WMAPE: 0.0
FINAL SUDDEN_EVENT_RATE: 0.0

MAE:   0.000000
RMSE:  0.000000
WMAPE: 0.000000
Recall (hit rate): 0.000000
Precision:         0.000000
F1:                0.000000
IoU:               0.000000
Accuracy:          0.000000
GT congestion rate:0.000000
Total points:      0
Total preds:       0
=== Combined (shared y_pred) ===
MAE=0.088944  RMSE=0.989131  WMAPE=0.001423
SCSR:  sudden_event_rate=0.000000  jam_rate=0.000000  rec_rate=0.000000
ALPHA: recall=0.990698 precision=0.989271 F1=0.989984 IoU=0.980166 accuracy=0.993063
ALPHA: gt_congestion_rate=0.346045  gt_cong=2505152  pred_cong=2508766


In [20]:
agg = []
used_vhat = None
agg_stats = []
agg_combined = []
num_epochs = int((len_train - end_of_initial_data_index) / data_per_step)

for epoch in range(0, num_epochs + 1): 
    online_train(epoch, edge_index, x_train, y_train, val_iter, num_epochs, oracle_mode_type="perfect", agg=agg, used_vhat=used_vhat, agg_stats=agg_stats, agg_combined=agg_combined)

final_metrics = aggregate_final_metrics(agg)
print("FINAL MAE:", final_metrics["MAE"])
print("FINAL RMSE:", final_metrics["RMSE"])
print("FINAL WMAPE:", final_metrics["WMAPE"])
print("FINAL SUDDEN_EVENT_RATE:", final_metrics["SUDDEN_EVENT_RATE"])

final = aggregate_congestion_oracle_metrics(agg_stats)
print("\n======== α-based Congestion — FINAL ========")
print(f"MAE:   {final['MAE']:.6f}")
print(f"RMSE:  {final['RMSE']:.6f}")
print(f"WMAPE: {final['WMAPE']:.6f}")
print(f"Recall (hit rate): {final['recall']:.6f}")
print(f"Precision:         {final['precision']:.6f}")
print(f"F1:                {final['f1']:.6f}")
print(f"IoU:               {final['iou']:.6f}")
print(f"Accuracy:          {final['accuracy']:.6f}")
print(f"GT congestion rate:{final['congestion_rate']:.6f}")
print(f"Total points:      {final['total_points']}")
print(f"Total preds:       {final['total_preds']}")

final_combined = aggregate_combined(agg_combined)
print("=== Combined (shared y_pred) ===")
print(f"MAE={final_combined['MAE']:.6f}  RMSE={final_combined['RMSE']:.6f}  WMAPE={final_combined['WMAPE']:.6f}")
print(f"SCSR:  sudden_event_rate={final_combined['SUDDEN_EVENT_RATE']:.6f}  "
      f"jam_rate={final_combined['JAM_EVENT_RATE']:.6f}  rec_rate={final_combined['REC_EVENT_RATE']:.6f}")
print(f"ALPHA: recall={final_combined['recall']:.6f} precision={final_combined['precision']:.6f} "
      f"F1={final_combined['f1']:.6f} IoU={final_combined['iou']:.6f} accuracy={final_combined['accuracy']:.6f}")
print(f"ALPHA: gt_congestion_rate={final_combined['congestion_rate']:.6f}  "
      f"gt_cong={final_combined['gt_congested']}  pred_cong={final_combined['pred_congested']}")


FINAL MAE: 0.0
FINAL RMSE: 0.0
FINAL WMAPE: 0.0
FINAL SUDDEN_EVENT_RATE: 0.0

MAE:   0.000000
RMSE:  0.000000
WMAPE: 0.000000
Recall (hit rate): 0.000000
Precision:         0.000000
F1:                0.000000
IoU:               0.000000
Accuracy:          0.000000
GT congestion rate:0.000000
Total points:      0
Total preds:       0
=== Combined (shared y_pred) ===
MAE=1.530372  RMSE=1.767036  WMAPE=0.024488
SCSR:  sudden_event_rate=1.000000  jam_rate=1.000000  rec_rate=1.000000
ALPHA: recall=0.984189 precision=0.977885 F1=0.981027 IoU=0.962761 accuracy=0.986827
ALPHA: gt_congestion_rate=0.346045  gt_cong=2505152  pred_cong=2521302
