This notebook is the single source for all datasets.
Choose your dataset in the next cell. For **Something-Something** we use
`chunk_size=100`, `meta_lr=0.01`, and `ReduceLROnPlateau`.
For **HMDB/Kinetics/UCF** we use `chunk_size=300`, `meta_lr=0.001`, and `StepLR`.

Import the required libraries

In [None]:
from pathlib import Path
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TransformerConv
from torch_geometric.data import Data
from collections import defaultdict
import numpy as np
import os
import random
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix, precision_recall_fscore_support
import time
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import copy  # for KD teacher

try:
    import psutil
except ImportError:
    psutil = None

# Compatibility shim: probe for the legacy `np.float` alias and, when this
# NumPy build no longer provides it (attribute access raises AttributeError),
# re-create it and its sibling aliases as the plain Python builtins.
# NOTE(review): presumably needed by one of the libraries imported below that
# still references these aliases — confirm which one and drop when possible.
try:
    np.float
except AttributeError:
    np.float = float
    np.int = int
    np.bool = bool
    np.object = object
    np.str = str
    np.long = int

from river.drift import ADWIN, PageHinkley
from sklearn.metrics.pairwise import cosine_similarity

try:
    from sklearn_extra.cluster import KMedoids
except ImportError:
    KMedoids = None

In [None]:
# Choose one: "HMDB", "Kinetics", "UCF", "Something-Something"
DATASET = "UCF"  # change name of the dataset 

# Per-dataset hyper-parameters, unpacked as (chunk size, meta learning rate,
# scheduler key). Any name other than "Something-Something" gets the second
# tuple. SCHEDULER is read later by train_model: "plateau" selects
# ReduceLROnPlateau (stepped with validation F1), "step" selects StepLR.
CHUNK_SIZE, META_LR, SCHEDULER = (
    (100, 1e-2, "plateau")
    if DATASET == "Something-Something"
    else (300, 1e-3, "step")
)

print(f"Dataset={DATASET} | chunk={CHUNK_SIZE} | meta_lr={META_LR} | scheduler={SCHEDULER}")


Memory and latency helper functions

In [None]:
def get_memory_usage():
    """Return the resident set size (RSS) of this process in MiB.

    Returns:
        float | None: RSS in MiB, or None when psutil could not be imported.
    """
    if psutil is None:
        return None
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024.0 * 1024.0)

def get_cuda_memory_usage():
    """Return the CUDA memory currently allocated by PyTorch, in MiB.

    The device is synchronised first so pending kernels have finished
    before the allocator is queried.

    Returns:
        float | None: allocated memory in MiB, or None when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024.0 * 1024.0)

def plot_memory_usage(cpu_mem, gpu_mem=None):
    """Plot per-test-chunk memory usage.

    Args:
        cpu_mem: sequence of CPU RAM readings (MB), one per test chunk.
        gpu_mem: optional sequence of GPU VRAM readings (MB); only drawn
            when it is not None.
    """
    # (values, marker, legend label) for every curve to draw, in draw order.
    curves = [(cpu_mem, 'o', "CPU RAM (MB)")]
    if gpu_mem is not None:
        curves.append((gpu_mem, 's', "GPU VRAM (MB)"))

    plt.figure(figsize=(10,4))
    for values, marker, label in curves:
        plt.plot(values, marker=marker, label=label)
    plt.title("Memory Usage per Test Chunk")
    plt.xlabel("Test Chunk")
    plt.ylabel("Memory (MB)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


Load the STKG data from local files

In [None]:
def load_data_from_files(data_dir='...'): # pass the full path of the Data folder
    """Load the graph from four whitespace-separated text files in ``data_dir``.

    Files read (one record per line, blank lines are skipped):
      * ``node_features.txt`` - first column is dropped, the rest are floats.
      * ``edges.txt``         - every column is parsed as an int; the rows are
        transposed, so two columns per line give a ``[2, E]`` edge index.
      * ``edge_features.txt`` - first two columns are dropped, the rest are floats.
      * ``node_labels.txt``   - the second column is the integer label.
    NOTE(review): the dropped leading columns are presumably node ids /
    edge endpoints - confirm against the data export.

    Args:
        data_dir: folder containing the four files. The default ``'...'``
            keeps the original placeholder paths.

    Returns:
        torch_geometric.data.Data with ``x``, ``edge_index``, ``edge_attr``, ``y``.
    """
    def _read_rows(filename, parse_columns):
        # Parse every non-blank line of `filename` with `parse_columns`.
        rows = []
        with open(f'{data_dir}/{filename}','r') as f:
            for line in f:
                columns = line.strip().split()
                if columns:
                    rows.append(parse_columns(columns))
        return rows

    node_features = _read_rows('node_features.txt', lambda cols: [float(v) for v in cols[1:]])
    x = torch.tensor(node_features, dtype=torch.float)

    edge_index = _read_rows('edges.txt', lambda cols: [int(v) for v in cols])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

    edge_attr = _read_rows('edge_features.txt', lambda cols: [float(v) for v in cols[2:]])
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)

    y = _read_rows('node_labels.txt', lambda cols: int(cols[1]))
    y = torch.tensor(y, dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

The definition of continual learning metrics

In [None]:
class CLMetrics:
    """Per-chunk accuracy bookkeeping for continual-learning metrics.

    For every chunk index the tracker keeps the first, best and most recent
    accuracy, the full accuracy history, and the epoch at which the best
    accuracy was reached.
    """

    def __init__(self, num_chunks):
        """Create zeroed trackers for `num_chunks` chunks."""
        self.best_acc_per_chunk = [0.] * num_chunks
        self.last_acc_per_chunk = [0.] * num_chunks
        self.first_acc_per_chunk = [0.] * num_chunks
        self.seen = [False] * num_chunks
        self.chunk_history = [[] for _ in range(num_chunks)]
        # Epoch at which each chunk reached its highest accuracy so far
        # (None until an accuracy above 0 has been recorded).
        self.adaptation_epochs = [None] * num_chunks

    def update(self, chunk_idx, acc, epoch):
        """Record accuracy `acc` observed for chunk `chunk_idx` at `epoch`."""
        self.chunk_history[chunk_idx].append(acc)
        self.last_acc_per_chunk[chunk_idx] = acc
        if not self.seen[chunk_idx]:
            self.seen[chunk_idx] = True
            self.first_acc_per_chunk[chunk_idx] = acc
        # Strict comparison: ties keep the earliest epoch.
        if acc > self.best_acc_per_chunk[chunk_idx]:
            self.best_acc_per_chunk[chunk_idx] = acc
            self.adaptation_epochs[chunk_idx] = epoch

    def _forgetting_gaps(self):
        """Return (best - last) accuracy for every chunk."""
        return [best - last
                for best, last in zip(self.best_acc_per_chunk, self.last_acc_per_chunk)]

    def _recorded_epochs(self):
        """Return the adaptation epochs that have been set (non-None)."""
        return [epoch for epoch in self.adaptation_epochs if epoch is not None]

    def average_forgetting(self): # Average of each chunk (best - last)
        """Mean of (best - last) accuracy over all chunks."""
        return np.mean(self._forgetting_gaps())

    def average_forgetting_std(self):
        """Standard deviation of (best - last) accuracy over all chunks."""
        return np.std(self._forgetting_gaps())

    def adaptation_speed(self):
        """Mean epoch of peak accuracy, or None when no epoch was recorded."""
        epochs = self._recorded_epochs()
        if not epochs:
            return None
        return np.mean(epochs)

    def adaptation_speed_std(self):
        """Std of the epoch of peak accuracy, or None when none was recorded."""
        epochs = self._recorded_epochs()
        if not epochs:
            return None
        return np.std(epochs)

    def report(self):
        """Print average forgetting and adaptation speed (mean ± std)."""
        print("\n=== Continual Learning Metrics ===")
        forget_mean, forget_std = self.average_forgetting(), self.average_forgetting_std()
        adapt_mean, adapt_std = self.adaptation_speed(), self.adaptation_speed_std()
        print(f" Average Forgetting: {forget_mean:.4f} ± {forget_std:.4f}")
        if adapt_mean is None:
            print(f" Adaptation Speed (epoch of peak acc): None")
        else:
            print(f" Adaptation Speed (epoch of peak acc): {adapt_mean:.4f} ± {adapt_std:.4f}")

Definition of drift function

In [None]:
def detect_concept_drift(acc_history, window=5, threshold=2.5):
    """Flag drops in accuracy with a rolling z-score rule.

    Position i is flagged when its value is more than `threshold` standard
    deviations (plus a 1e-8 floor) below the mean of the previous `window`
    values. The first `window` positions are never flagged.

    Returns:
        list[bool]: one flag per entry of `acc_history`.
    """
    flags = []
    for pos in range(len(acc_history)):
        if pos < window:
            # Not enough history yet to form a full reference window.
            flags.append(False)
            continue
        reference = acc_history[pos - window:pos]
        spread = np.std(reference) + 1e-8  # floor avoids a zero-width band
        lower_bound = np.mean(reference) - threshold * spread
        flags.append(bool(acc_history[pos] < lower_bound))
    return flags

def detect_drift_adwin(acc_history, delta=0.1):
    """Feed `acc_history` to a fresh ADWIN detector, one value at a time.

    Returns:
        list: the detector's `drift_detected` flag after each update.
    """
    detector = ADWIN(delta=delta)

    def _observe(value):
        detector.update(value)
        return detector.drift_detected

    return [_observe(value) for value in acc_history]

def detect_drift_pagehinkley(acc_history, threshold=50, alpha=0.999, min_instances=5):
    """Feed `acc_history` to a fresh Page-Hinkley detector, one value at a time.

    `threshold`, `alpha` and `min_instances` are forwarded unchanged to the
    detector's constructor.

    Returns:
        list: the detector's `drift_detected` flag after each update.
    """
    detector = PageHinkley(threshold=threshold, alpha=alpha, min_instances=min_instances)

    def _observe(value):
        detector.update(value)
        return detector.drift_detected

    return [_observe(value) for value in acc_history]


Definition of forgetting rate function

In [None]:
# Module-level store: chunk id -> accuracies recorded so far for that chunk.
chunk_acc_history = defaultdict(list)
def windowed_forgetting_rate(chunk_id, current_acc, window_size=5):
    """Average accuracy drop of `current_acc` versus the last `window_size` records.

    Each earlier accuracy contributes max(0, earlier - current); the sum is
    divided by the number of records in the window. The first call for a
    chunk returns 0.0. `current_acc` is always appended to the module-level
    `chunk_acc_history` afterwards, so the function is stateful.
    """
    past = chunk_acc_history[chunk_id]
    forget = 0.0
    if past:
        recent = past[-window_size:]
        drops = [max(0, earlier - current_acc) for earlier in recent]
        forget = sum(drops) / len(recent)
    past.append(current_acc)
    return forget

Definition of plots for drifts

In [None]:

def plot_adwin_drift(accs, drift_flags):
    """Plot chunk accuracy and mark the chunks where ADWIN flagged drift.

    Args:
        accs: accuracy per test chunk.
        drift_flags: one truthy/falsy flag per chunk; flagged chunks are
            drawn as red markers on top of the accuracy line.
    """
    chunk_axis = np.arange(len(accs))
    plt.figure(figsize=(12,5))
    plt.plot(chunk_axis, accs, label='Chunk Accuracy')
    flagged = [pos for pos, hit in enumerate(drift_flags) if hit]
    flagged_accs = [accs[pos] for pos in flagged]
    plt.scatter(flagged, flagged_accs, color='red', label='ADWIN Drift Detected', zorder=5)
    plt.xlabel("Test Chunk")
    plt.ylabel("Accuracy")
    plt.title("ADWIN Concept Drift Detection")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_drift_detection(values, drift_flags, metric_name="Accuracy"):
    """Plot a per-chunk metric and mark the chunks flagged as drift.

    Args:
        values: metric value per test chunk.
        drift_flags: one truthy/falsy flag per chunk.
        metric_name: used in the legend, the y-axis label and the title.
    """
    chunk_axis = np.arange(len(values))
    plt.figure(figsize=(12,5))
    plt.plot(chunk_axis, values, label=f"Chunk {metric_name}")
    flagged = [pos for pos, hit in enumerate(drift_flags) if hit]
    flagged_values = [values[pos] for pos in flagged]
    plt.scatter(flagged, flagged_values, color='red', label='Drift Detected', zorder=5)
    plt.xlabel("Test Chunk")
    plt.ylabel(metric_name)
    plt.title(f"Z-Score Concept Drift Detection ({metric_name})")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


Focal loss definition handling class imbalance 

In [None]:

class ClasswiseFocalLoss(torch.nn.Module):
    """Focal loss with optional per-class weight (alpha) and focusing (gamma).

    Args:
        alpha: tensor indexed by class id giving the per-class weight;
            None means a weight of 1.0 for every sample.
        gamma: tensor indexed by class id giving the focusing exponent;
            None means 2.0 for every sample.
        reduction: 'mean', 'sum', or anything else for the per-sample tensor.
    """

    def __init__(self, alpha=None, gamma=None, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss of logits `inputs` against class ids `targets`."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) is the probability the model assigns to the true class.
        prob_true = torch.exp(-per_sample_ce)
        focusing = 2.0 if self.gamma is None else self.gamma[targets]
        weight = 1.0 if self.alpha is None else self.alpha[targets]
        per_sample = weight * (1 - prob_true) ** focusing * per_sample_ce

        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        return per_sample

def get_classwise_focal_loss(num_classes, device):
    """Build a ClasswiseFocalLoss with uniform settings for every class.

    Every class gets gamma = 2.0 and alpha = 1.0; both tensors are moved to
    `device`.
    """
    uniform_gamma = torch.full((num_classes,), 2.0, dtype=torch.float).to(device)
    uniform_alpha = torch.full((num_classes,), 1.0, dtype=torch.float).to(device)
    return ClasswiseFocalLoss(alpha=uniform_alpha, gamma=uniform_gamma)


Definition of confusion matrix before training

In [None]:

def plot_confusion_matrix(y_true, y_pred, num_classes):
    """Draw the confusion matrix of `y_pred` vs `y_true` as an annotated heatmap.

    The matrix always has one row/column per class id in
    ``range(num_classes)``, including classes absent from the inputs.
    """
    class_ids = list(range(num_classes))
    counts = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8,6))
    sns.heatmap(
        counts,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=list(class_ids),
        yticklabels=list(class_ids),
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.show()

Preparing data for the real-time continual learning simulation

In [None]:
def shuffle_and_split_data_into_time_chunks(data, timestamps, chunk_size=500, overlap_ratio=0.1):
    """Cut the graph into overlapping time-window sub-graphs.

    Window i covers timestamps in ``[i * step, i * step + chunk_size)`` with
    ``step = chunk_size - int(chunk_size * overlap_ratio)``. The number of
    windows is ``(max_time - chunk_size) // step + 1``, so nodes whose
    timestamp lies at or beyond the end of the last window are not assigned
    to any chunk. Despite the name, no shuffling happens here; any shuffling
    has to be encoded in `timestamps` by the caller.

    For each non-empty window the selected nodes are re-indexed from 0, and
    only edges with both endpoints inside the window are kept (with their
    `edge_attr` rows when `data.edge_attr` has one row per edge).

    Args:
        data: graph with `x`, `y`, `edge_index` and optional `edge_attr`.
        timestamps: tensor with one timestamp per node.
        chunk_size: width of each time window.
        overlap_ratio: fraction of `chunk_size` shared by consecutive windows.

    Returns:
        list of `Data` objects, one per non-empty window, in window order.
    """
    overlap = int(chunk_size * overlap_ratio)
    step_size = chunk_size - overlap
    max_time = timestamps.max().item()
    num_chunks = int((max_time - chunk_size) // step_size) + 1

    # Loop-invariant work: the edge list and the edge_attr alignment check do
    # not depend on the window, so compute them once instead of per chunk.
    all_edges = data.edge_index.t().tolist()
    has_aligned_edge_attr = (
        data.edge_attr is not None and data.edge_attr.size(0) == data.edge_index.size(1)
    )

    chunks = []
    for i in range(num_chunks):
        start_time = i * step_size
        end_time = start_time + chunk_size
        node_mask = (timestamps >= start_time) & (timestamps < end_time)
        selected_nodes = torch.nonzero(node_mask, as_tuple=False).squeeze()
        if selected_nodes.numel() == 0:
            continue
        if selected_nodes.dim() == 0:
            # A single hit squeezes down to a scalar; restore the list shape.
            selected_nodes = selected_nodes.unsqueeze(0)
        selected_nodes = selected_nodes.tolist()

        node_id_map = {old_id: new_id for new_id, old_id in enumerate(selected_nodes)}
        chunk_x = data.x[selected_nodes]
        chunk_y = data.y[selected_nodes]

        # Membership is tested against the dict (O(1) per lookup) rather than
        # the node list, which made this step O(edges * nodes) per chunk.
        edge_mask = [(src in node_id_map and dst in node_id_map)
                     for src, dst in all_edges]
        edge_mask = torch.tensor(edge_mask, dtype=torch.bool)
        chunk_edge_index = data.edge_index[:, edge_mask]

        chunk_edge_attr = None
        if has_aligned_edge_attr:
            chunk_edge_attr = data.edge_attr[edge_mask]

        if chunk_edge_index.numel() > 0:
            # Translate global node ids to chunk-local ids.
            reindexed_edges = [[node_id_map[src], node_id_map[dst]]
                               for src, dst in chunk_edge_index.t().tolist()]
            chunk_edge_index = torch.tensor(reindexed_edges).t().long()
        else:
            chunk_edge_index = torch.empty((2, 0), dtype=torch.long)

        chunk_data = Data(x=chunk_x, y=chunk_y, edge_index=chunk_edge_index)
        if chunk_edge_attr is not None:
            chunk_data.edge_attr = chunk_edge_attr.clone().detach()

        chunks.append(chunk_data)
        print(f"Chunk {i} created ({start_time}-{end_time}) with {len(selected_nodes)} nodes.")
    return chunks

def split_chunks_time_aware(chunks, timestamps, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Order chunks by mean timestamp and split them into train/val/test.

    Each chunk is paired with the mean of a contiguous slice of `timestamps`
    whose length equals the chunk's node count; slices are taken back to
    back in the order the chunks are given. Chunks are then sorted by that
    mean (stable for ties) and cut at the train and train+val ratios.

    NOTE(review): the slicing assumes chunk k owns the k-th contiguous run of
    `timestamps`. Chunks built from a timestamp mask, or with overlapping
    windows, would not satisfy that — confirm against the caller.

    Raises:
        AssertionError: if the three ratios do not sum to 1 (within 1e-5).

    Returns:
        tuple: (train_chunks, val_chunks, test_chunks).
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-5

    keyed_chunks = []
    offset = 0
    for chunk in chunks:
        size = chunk.x.shape[0]
        time_slice = timestamps[offset:offset + size]
        keyed_chunks.append((torch.mean(time_slice.float()).item(), chunk))
        offset += size

    ordered = [chunk for _, chunk in sorted(keyed_chunks, key=lambda pair: pair[0])]

    n_total = len(ordered)
    train_end = int(n_total * train_ratio)
    val_end = int(n_total * (train_ratio + val_ratio))
    return ordered[:train_end], ordered[train_end:val_end], ordered[val_end:]

def generate_balanced_timestamps(y):
    """Assign every node a unique synthetic timestamp in random order.

    The values 0 .. len(y)-1 are scattered over the nodes through a random
    permutation, so each timestamp is used exactly once. Only the length and
    device of `y` are used; the label values do not influence the result.

    Returns:
        float tensor shaped like `y`.
    """
    n_nodes = len(y)
    # A random permutation of 0..n-1: position k says which node gets time k.
    node_order = torch.randperm(n_nodes)
    timestamps = torch.zeros_like(y, dtype=torch.float)
    timestamps[node_order] = torch.arange(n_nodes).float()
    return timestamps


Function and class definition of cache cleaning during continuous learning

In [None]:

def kmedoids_clean(embeddings, max_size=300):
    """Pick the indices of the entries to keep when shrinking a buffer.

    When `sklearn_extra` is unavailable, or there are at most `max_size`
    embeddings, every index is kept. Otherwise K-Medoids (PAM, euclidean,
    k-medoids++ init, random_state=42) is fitted with ``max(max_size // 3, 1)``
    clusters and only the medoid indices are returned, so the result is
    smaller than `max_size`.

    Returns:
        numpy array of indices into `embeddings`.
    """
    n_items = len(embeddings)
    if KMedoids is None or n_items <= max_size:
        return np.arange(n_items)

    n_clusters = max(max_size // 3, 1)
    clusterer = KMedoids(
        n_clusters=n_clusters,
        method='pam',
        metric='euclidean',
        init='k-medoids++',
        random_state=42,
    )
    clusterer.fit(embeddings)
    return clusterer.medoid_indices_


class TGNMemory:
    """Dictionary-backed per-node memory: node id -> state vector.

    Stored states are detached CPU copies; nodes without an entry read as a
    zero vector of length `hidden_dim`.
    """

    def __init__(self, hidden_dim=128):
        self.memory_dict = {}
        self.hidden_dim = hidden_dim

    def initialize_node(self, node_id):
        """Give `node_id` a zero vector unless it already has a state."""
        # setdefault leaves an existing entry untouched.
        self.memory_dict.setdefault(node_id, torch.zeros(self.hidden_dim))

    def get_memory(self, node_id):
        """Return the stored state of `node_id`, or zeros if it has none."""
        if node_id in self.memory_dict:
            return self.memory_dict[node_id]
        return torch.zeros(self.hidden_dim)

    def update_memory(self, node_ids, new_states): # called after each training chunk with the embedding1 output
        """Overwrite the state of each id in `node_ids` with its row of `new_states`.

        `node_ids` must yield tensors (``.item()`` is called on each entry).
        """
        for row, node_id in enumerate(node_ids):
            key = int(node_id.item())
            self.memory_dict[key] = new_states[row].detach().cpu().clone()

    def clear_memory(self):
        """Drop every stored state."""
        self.memory_dict = {}

HybridSelectiveReplayBuffer definition for catastrophic forgetting mitigation


In [None]:

class HybridSelectiveReplayBuffer:
    """FIFO replay buffer of past chunks with class-proportional sampling.

    Each entry is a tuple ``(chunk, uncertainty, embedding, dominant_class)``:
      * chunk          - detached clone of the stored chunk,
      * uncertainty    - mean over nodes of the variance of the softmax
                         probabilities (NOTE(review): this variance grows as
                         predictions get more confident, so the name may be
                         inverted — confirm the intended ordering),
      * embedding      - mean of the chunk's raw node features (numpy),
      * dominant_class - most frequent label in the chunk.
    """

    def __init__(self, max_size=500, sampling_mode='proportional'):
        self.buffer = []
        self.max_size = max_size
        # Stored for callers; sample() does not currently consult it.
        self.sampling_mode = sampling_mode

    def add(self, chunk, model):
        """Score `chunk` with `model` and append it, evicting the oldest entry when full.

        The model is put in eval mode for scoring and then returned to the
        mode it was in before the call, so calling this inside a training
        loop no longer leaves dropout disabled.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                logits, _ = model(chunk.x, chunk.edge_index, edge_attr=chunk.edge_attr)
                probs = torch.softmax(logits, dim=1)
                uncertainty = torch.var(probs, dim=1).mean().item()
                embedding = torch.mean(chunk.x, dim=0).cpu().numpy()
                class_dist = torch.bincount(chunk.y, minlength=probs.size(1)).float()
                dominant_class = torch.argmax(class_dist).item()
        finally:
            model.train(was_training)
        self.buffer.append((chunk.clone().detach(), uncertainty, embedding, dominant_class))
        if len(self.buffer) > self.max_size:
            self.buffer.pop(0)

    def sample(self, batch_size):
        """Return at most `batch_size` stored chunks.

        Each dominant class gets a quota of ``round(share * batch_size)``
        entries, taken oldest first. Rounding can overshoot, so the selection
        is trimmed to `batch_size`; when it undershoots, the remaining slots
        are filled with the not-yet-selected entries of highest uncertainty
        score. Returns an empty list for an empty buffer or a non-positive
        `batch_size`.
        """
        if batch_size <= 0 or len(self.buffer) == 0:
            return []
        class_map = defaultdict(list)
        for idx, item in enumerate(self.buffer):
            class_map[item[3]].append(idx)
        total_len = len(self.buffer)

        selected_idxs = []
        for class_idxs in class_map.values():
            quota = int(round(len(class_idxs) / total_len * batch_size))
            selected_idxs.extend(class_idxs[:quota])
        # Per-class rounding may add up to more than requested.
        del selected_idxs[batch_size:]

        if len(selected_idxs) < batch_size:
            already_selected = set(selected_idxs)
            leftover = [i for i in range(total_len) if i not in already_selected]
            leftover.sort(key=lambda i: self.buffer[i][1], reverse=True)
            needed = batch_size - len(selected_idxs)
            selected_idxs.extend(leftover[:needed])
        return [self.buffer[i][0] for i in selected_idxs]

    def clean(self, max_size=300):
        """Shrink the buffer to the entries chosen by `kmedoids_clean`.

        Does nothing when the buffer holds at most `max_size` entries.
        """
        old_len = len(self.buffer)
        if old_len <= max_size:
            return
        embeddings = np.array([item[2] for item in self.buffer])
        keep_indices = kmedoids_clean(embeddings, max_size=max_size)
        self.buffer = [self.buffer[idx] for idx in keep_indices]
        print(f"Buffer cleaning with K-Medoids: {old_len} -> {len(self.buffer)}")

Self-attention to maintain high-level semantic consistency

In [None]:
class STSelfAttention(torch.nn.Module):
    """Graph attention over edges enriched with a learned time-gap embedding.

    A 4-head `TransformerConv` attends over edge features made of the
    caller's edge attributes (if any) followed by an embedding of the
    absolute timestamp difference between each edge's endpoints. Its output
    is layer-normalised over ``hidden_channels * 4`` features, passed through
    ReLU, projected back to `hidden_channels` and regularised with dropout.

    Args:
        hidden_channels: node feature width in and out.
        time_embed_dim: width of the time-gap embedding.
        original_edge_dim: width of the caller-supplied edge attributes.
    """

    def __init__(self, hidden_channels, time_embed_dim=32, original_edge_dim=0):
        super().__init__()
        self.time_encoder = torch.nn.Linear(1, time_embed_dim)
        self.combined_edge_dim = time_embed_dim + original_edge_dim
        self.attn = TransformerConv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            heads=4,
            edge_dim=self.combined_edge_dim
        )
        self.norm = torch.nn.LayerNorm(hidden_channels * 4)
        self.linear = torch.nn.Linear(hidden_channels * 4, hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x, edge_index, edge_attr=None, timestamps=None):
        """Apply attention; `x` is returned unchanged when the graph has no edges."""
        if edge_index.size(1) == 0:
            return x

        time_embed = None
        if timestamps is not None:
            src, dst = edge_index[0], edge_index[1]
            time_gap = torch.abs(timestamps[src] - timestamps[dst]).unsqueeze(1)
            time_embed = self.time_encoder(time_gap)

        # Edge features: [edge_attr | time_embed], or whichever one is present.
        if edge_attr is not None and time_embed is not None:
            edge_features = torch.cat([edge_attr, time_embed], dim=1)
        elif edge_attr is not None:
            edge_features = edge_attr
        else:
            edge_features = time_embed

        attended = self.attn(x, edge_index, edge_features)
        activated = F.relu(self.norm(attended))
        return self.dropout(self.linear(activated))

This layer integrates structural and temporal information

In [None]:

class TemporalEmbeddingLayer(torch.nn.Module):
    """Fuse a graph-convolved feature embedding with a time embedding.

    The last input column is treated as the time value; all other columns go
    through a GCN layer. Both embeddings are concatenated and projected back
    to `hidden_channels`, followed by dropout.

    Args:
        in_channels: input width including the trailing time column.
        hidden_channels: output width.
        time_embed_dim: width of the time embedding.
    """

    def __init__(self, in_channels, hidden_channels, time_embed_dim=32):
        super().__init__()
        # The time value sits in the last column, so the GCN sees one column fewer.
        self.gcn = GCNConv(in_channels - 1, hidden_channels)
        self.time_encoder = torch.nn.Sequential(
            torch.nn.Linear(1, time_embed_dim),
            torch.nn.ReLU(),
            torch.nn.LayerNorm(time_embed_dim)
        )
        self.feature_fusion = torch.nn.Linear(hidden_channels + time_embed_dim, hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Embed `x` (features plus trailing time column) over `edge_index`."""
        structural_in = x[:, :-1]
        time_in = x[:, -1:]
        structural = F.relu(self.gcn(structural_in, edge_index))
        temporal = self.time_encoder(time_in)
        fused = self.feature_fusion(torch.cat([structural, temporal], dim=1))
        return self.dropout(fused)

Episodic Graph Pattern Memory preserves historical context

In [None]:
class EGPM(torch.nn.Module):
    """Episodic Graph Pattern Memory: a GRU run across the nodes of a chunk.

    The node matrix is treated as one sequence (a batch of size 1 whose
    sequence axis is the node axis), so each node's output depends on the
    nodes that precede it in row order.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.gru = torch.nn.GRU(hidden_channels, hidden_channels, batch_first=True)

    def forward(self, x, hidden_state=None):
        """Run the GRU over `x` of shape [n_nodes, hidden_channels].

        Returns:
            tuple: (per-node outputs [n_nodes, hidden_channels], final GRU hidden state).
        """
        as_sequence = x.unsqueeze(0)  # add the batch axis expected by the GRU
        sequence_out, next_hidden = self.gru(as_sequence, hidden_state)
        return sequence_out.squeeze(0), next_hidden


Adaptive meta-learning layer for faster adaptation to dynamic data distributions and unforeseen concept drifts

In [None]:

class AdaptiveMetaLearning(torch.nn.Module):
    """Feature-adaptation head: Linear -> LayerNorm -> ReLU -> Dropout(0.5).

    Input and output both have `hidden_channels` features.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_channels, hidden_channels)
        self.norm = torch.nn.LayerNorm(hidden_channels)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        """Transform `x`; dropout is only active in training mode."""
        projected = self.linear(x)
        activated = F.relu(self.norm(projected))
        return self.dropout(activated)


def meta_update(model, inner_params, meta_lr=0.001):
    """Move each model parameter a fraction `meta_lr` towards `inner_params`.

    For every parameter that currently has a gradient:
    ``p <- p - meta_lr * (p - p_inner)``. Parameters are paired with
    `inner_params` by position; the names in both sequences are not compared.
    Parameters whose ``.grad`` is None are left untouched.

    Args:
        model: module whose parameters are updated.
        inner_params: iterable of ``(name, tensor)`` pairs in the same order
            as ``model.named_parameters()``.
        meta_lr: interpolation step size.
    """
    with torch.no_grad():
        for (_, param), (_, target) in zip(model.named_parameters(), inner_params):
            if param.grad is None:
                continue
            # Rebind .data (rather than updating in place) as the original does.
            param.data = param.data - meta_lr * (param.data - target.data)


MAIN MODEL STRUCTURE

In [None]:

class CAST_GRN(torch.nn.Module):
    """Main model: temporal embeddings -> EGPM (GRU) -> optional node memory
    -> spatio-temporal self-attention -> meta-adaptation layer -> classifier.

    Args:
        in_channels: input feature width; the last column is the time value.
        hidden_channels: width of every hidden representation.
        num_classes: number of output logits per node.
        time_embed_dim: width of the time embeddings.
        original_edge_dim: width of the edge attributes passed to `forward`.
        memory_module: optional external `TGNMemory`. Its `hidden_dim` has to
            equal `hidden_channels`, because the stored vectors are added to
            the hidden state.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_classes,
        time_embed_dim=32,
        original_edge_dim=0,
        memory_module: TGNMemory = None
    ):
        super().__init__()
        self.time_norm = torch.nn.LayerNorm(1)
        self.embedding1 = TemporalEmbeddingLayer(in_channels, hidden_channels, time_embed_dim)
        self.embedding2 = TemporalEmbeddingLayer(hidden_channels, hidden_channels, time_embed_dim)
        self.egpm = EGPM(hidden_channels)
        self.attn = STSelfAttention(
            hidden_channels,
            time_embed_dim,
            original_edge_dim=original_edge_dim
        )
        self.meta = AdaptiveMetaLearning(hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

        self.memory_module = memory_module  # external memory module reference (not a sub-module)
        self.hidden_channels = hidden_channels

    def forward(self, x, edge_index, edge_attr=None, hidden_state=None):
        """Classify every node of one chunk.

        Args:
            x: node features with the time value in the last column.
            edge_index: chunk-local edge index.
            edge_attr: optional edge attributes.
            hidden_state: optional initial GRU state for the EGPM.

        Returns:
            tuple: (logits [n_nodes, num_classes], final GRU hidden state).
        """
        # Keep the raw time column for the attention layer; feed a
        # layer-normalised copy to the embedding layers.
        time_feat = x[:, -1:].clone()
        x = torch.cat([x[:, :-1], self.time_norm(time_feat)], dim=1)
        x = self.embedding1(x, edge_index)
        x = self.embedding2(x, edge_index)

        # Sequential pass over the nodes with the EGPM (GRU).
        x, hidden_state = self.egpm(x, hidden_state)

        if self.memory_module is not None:  # TGNMemory usage step
            # Memory is looked up by row position (0 .. n_nodes-1); the
            # stored vectors live on the CPU and are moved to x's device.
            mem_list = [self.memory_module.get_memory(node_id)
                        for node_id in range(x.size(0))]
            mem_tensor = torch.stack(mem_list, dim=0).to(x.device)
            x = x + mem_tensor

        # Self-attention layer (edge + time embedding). squeeze(-1) keeps the
        # timestamps 1-D even for a single-node chunk; a bare squeeze() would
        # collapse them to a 0-dim tensor that cannot be indexed per edge.
        x = self.attn(x, edge_index, edge_attr=edge_attr, timestamps=time_feat.squeeze(-1))
        x = self.meta(x)

        out = self.classifier(x)
        return out, hidden_state

Fisher information: indicates the importance of each parameter

In [None]:

def compute_fisher_information(model, val_chunks, optimizer, criterion):
    """Estimate a diagonal Fisher information matrix from `val_chunks`.

    For every chunk the loss is back-propagated and the squared gradient of
    each parameter is accumulated; the sums are then averaged over the number
    of chunks. The model is switched to training mode and left there, and the
    gradients of the last chunk stay on the parameters.

    Args:
        model: module called as ``model(x, edge_index, edge_attr=...)`` and
            returning ``(logits, state)``.
        val_chunks: chunks exposing `x`, `edge_index`, `edge_attr` and `y`.
        optimizer: only used to clear gradients before each backward pass.
        criterion: loss taking ``(logits, y)``.

    Returns:
        dict mapping parameter name -> tensor shaped like the parameter.
        All zeros when `val_chunks` is empty (instead of NaN from 0 / 0).
    """
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.train()
    num_chunks = len(val_chunks)
    if num_chunks == 0:
        # Nothing to average: dividing the zero tensors by 0 would turn the
        # whole matrix into NaN and poison any EWC penalty built from it.
        return fisher

    for chunk in val_chunks:
        optimizer.zero_grad()
        logits, _ = model(chunk.x, chunk.edge_index, edge_attr=chunk.edge_attr)
        loss = criterion(logits, chunk.y)
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.pow(2)

    for n in fisher:
        fisher[n] /= num_chunks
    return fisher

EWC restricts significant changes to model parameters

In [None]:
def ewc_loss(model, fisher, prev_params, lambda_ewc=1.0):
    """Elastic Weight Consolidation penalty.

    Computes ``lambda_ewc * sum_n sum(fisher[n] * (p_n - prev_params[n])**2)``
    over every named parameter of `model`. `fisher` and `prev_params` must
    contain an entry for each parameter name.
    """
    penalty = 0
    for name, param in model.named_parameters():
        drift = param - prev_params[name]
        penalty = penalty + (fisher[name] * drift.pow(2)).sum()
    return lambda_ewc * penalty

To observe class distribution

In [None]:

def print_label_statistics(chunks, split_name="train"):
    """Print how often each label occurs in every chunk of a split.

    One line per chunk (numbered from 1), listing ``label: count`` pairs in
    ascending label order.
    """
    print(f"\n Label statistics for {split_name} chunks:")
    for position, chunk in enumerate(chunks, 1):
        tally = {}
        for label in chunk.y.tolist():
            tally[label] = tally.get(label, 0) + 1
        summary = ', '.join(f"{label}: {tally[label]}" for label in sorted(tally))
        print(f"  Chunk {position}: {summary}")

Initialize model weights

In [None]:

def initialize_weights(model):
    """Re-initialise every Linear / Conv1d / Conv2d sub-module of `model`.

    Weights get Xavier-uniform values and biases (when present) are zeroed.
    Sub-modules of any other type are left as they are.
    """
    target_types = (torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d)
    for module in model.modules():
        if not isinstance(module, target_types):
            continue
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)


Definition of plot functions to observe performance of model through different metrics

In [None]:
def plot_training_progress(epoch_accuracies, epoch_f1_scores,
                           epoch_forget_rates, epoch_times, window_forgets,
                           all_epoch_acc_dist):
    """Draw a 2x3 dashboard of per-epoch training statistics.

    Panels 1-5 are line plots (accuracy, F1, average forgetting, epoch time,
    windowed forgetting) against the epoch number starting at 1. Panel 6 is a
    box plot of the per-chunk accuracy distribution of every epoch.

    Args:
        epoch_accuracies, epoch_f1_scores, epoch_forget_rates, epoch_times,
        window_forgets: one value per epoch.
        all_epoch_acc_dist: one sequence of chunk accuracies per epoch.
    """
    epochs = list(range(1, len(epoch_accuracies) + 1))
    plt.figure(figsize=(18, 10))

    # (series, plot keyword arguments, title, y-axis label) for panels 1-5.
    line_panels = [
        (epoch_accuracies, {'marker': 'o'},
         'Accuracy over Epochs', 'Accuracy'),
        (epoch_f1_scores, {'marker': 's', 'color': 'orange'},
         'F1 Score over Epochs', 'F1 Score'),
        (epoch_forget_rates, {'marker': '^', 'color': 'red'},
         'Average Forgetting over Epochs', 'Forgetting'),
        (epoch_times, {'marker': 'd', 'color': 'green'},
         'Training Time per Epoch (s)', 'Seconds'),
        (window_forgets, {'marker': '*', 'color': 'purple'},
         'Windowed Forgetting Rate', 'Window Forget'),
    ]
    for slot, (series, line_style, title, y_label) in enumerate(line_panels, start=1):
        plt.subplot(2, 3, slot)
        plt.plot(epochs, series, **line_style)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.grid(True)

    plt.subplot(2, 3, 6)
    sns.boxplot(data=all_epoch_acc_dist)
    plt.title('Chunk Accuracy Distribution per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    n_boxes = len(all_epoch_acc_dist)
    # Relabel the 0-based box positions with 1-based epoch numbers.
    plt.xticks(ticks=range(n_boxes),
               labels=range(1, n_boxes + 1))
    plt.tight_layout()
    plt.show()


def plot_temporal_coherence_trend(coherence_scores):
    """Plot the temporal-coherence score of every test chunk.

    The y-axis is fixed to [0, 1.05].
    """
    plt.figure(figsize=(8, 4))
    plt.plot(coherence_scores, marker='o')
    plt.ylim(0, 1.05)
    plt.xlabel("Test Chunk")
    plt.ylabel("Coherence Score")
    plt.title("Temporal Coherence per Test Chunk")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_forgetting_curve(cl_metrics):
    """Plot best vs. last accuracy per chunk from a `CLMetrics` tracker.

    The vertical gap between the two curves is the forgetting of each chunk.
    """
    plt.figure(figsize=(10, 4))
    curves = [
        (cl_metrics.best_acc_per_chunk, "Best Accuracy per Chunk", 'royalblue'),
        (cl_metrics.last_acc_per_chunk, "Last Accuracy per Chunk", 'orange'),
    ]
    for values, label, color in curves:
        plt.plot(values, label=label, linewidth=2, color=color)
    plt.title("Forgetting Curve (Chunk Accuracy)")
    plt.xlabel("Chunk")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

To achieve optimal balance between rapid adaptation and long-term knowledge retention

In [None]:
def knowledge_distillation_loss(student_logits, teacher_logits, alpha=0.3, temperature=1.0):
    """Weighted KL divergence between softened teacher and student outputs.

    Both logit tensors are divided by `temperature`; the batch-mean
    KL(teacher || student) is scaled by ``temperature ** 2`` and then by
    `alpha`.

    Returns:
        scalar tensor.
    """
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return alpha * (divergence * (temperature ** 2))

In [None]:

def train_model(
    chunks,
    val_chunks,
    model,
    optimizer,
    criterion,
    buffer,
    fisher,
    prev_params,
    num_epochs=10,
    use_meta_update=False,
    meta_lr=0.001,
    use_tgn_memory=False,
    use_kd=False,
    kd_alpha=0.3,
    kd_temperature=1.0
):
    """Train `model` chunk by chunk with EWC, replay and optional KD / memory.

    Per chunk the loss is the criterion loss, plus (from the second epoch)
    the EWC penalty, plus (from the third epoch, once the buffer holds more
    than 5 chunks) the loss on replayed chunks, plus (when `use_kd`) a
    distillation term against a teacher copy refreshed after every epoch.
    After each epoch the model is validated, the scheduler selected by the
    module-level `SCHEDULER` is stepped, and the replay buffer is cleaned.

    Args:
        chunks: training chunks (each with `x`, `edge_index`, `edge_attr`, `y`).
        val_chunks: validation chunks.
        model: network returning ``(logits, hidden_state)``.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking ``(logits, y)``.
        buffer: replay buffer offering `add`, `sample`, `clean` and `.buffer`.
        fisher, prev_params: per-parameter dicts consumed by `ewc_loss`.
        num_epochs: number of passes over `chunks`.
        use_meta_update, meta_lr: enable / step size of `meta_update`.
        use_tgn_memory: refresh `model.memory_module` after every chunk.
        use_kd, kd_alpha, kd_temperature: knowledge-distillation settings.

    Returns:
        tuple: (train accuracies, train F1s, average forgetting per epoch,
        epoch times, best model state dict, validation accuracies,
        validation F1s, windowed forgetting per epoch, per-class F1 history,
        replay label counter, per-epoch chunk accuracies, CLMetrics).
    """
    if SCHEDULER == "step":
        scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
        def scheduler_step(val_f1):
            scheduler.step()
    else:
        scheduler = ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)
        def scheduler_step(val_f1):
            scheduler.step(val_f1)


    accs, f1s, forgets, times = [], [], [], []
    val_AC, val_F = [], []
    window_forgets = []
    epoch_chunk_accuracies = []
    per_class_f1_history = defaultdict(list)
    replay_class_counter = defaultdict(int)
    best_val_f1 = 0.0
    best_model_state = None

    num_chunks = len(chunks)
    cl_metrics = CLMetrics(num_chunks)

    if use_tgn_memory:
        for chunk in chunks:
            n_nodes = chunk.x.size(0)
            for nid in range(n_nodes):
                model.memory_module.initialize_node(nid)

    teacher_model = None
    if use_kd:
        teacher_model = copy.deepcopy(model)
        teacher_model.eval()

    for epoch in range(num_epochs):
        model.train()
        epoch_accs = []
        epoch_f1s_local = []
        start = time.time()

        if use_meta_update:
            # NOTE(review): this snapshot is never read afterwards — see the
            # meta-update step at the end of the epoch.
            saved_params = [(n, p.clone().detach()) for n, p in model.named_parameters()]

        for i, chunk in enumerate(chunks):
            optimizer.zero_grad()
            logits, hidden_state = model(chunk.x, chunk.edge_index, edge_attr=chunk.edge_attr)
            loss = criterion(logits, chunk.y)

            # EWC Loss after first epoch
            if epoch > 0:
                loss += ewc_loss(model, fisher, prev_params)

            if epoch > 1 and len(buffer.buffer) > 5: #taking samples from replay
                replay_samples = min(5 + epoch, 40)
                for replay_data in buffer.sample(replay_samples):
                    rlogits, _ = model(replay_data.x, replay_data.edge_index, edge_attr=replay_data.edge_attr)
                    loss += criterion(rlogits, replay_data.y)
                    for lbl in replay_data.y.tolist():
                        replay_class_counter[lbl] += 1

            # Knowledge Distillation
            if use_kd and teacher_model is not None:
                with torch.no_grad():
                    tlogits, _ = teacher_model(chunk.x, chunk.edge_index, edge_attr=chunk.edge_attr)
                kd_loss_val = knowledge_distillation_loss(
                    student_logits=logits,
                    teacher_logits=tlogits,
                    alpha=kd_alpha,
                    temperature=kd_temperature
                )
                loss += kd_loss_val

            loss.backward()
            clip_grad_norm_(model.parameters(), 0.8)
            optimizer.step()

            # update TGNMemory
            if use_tgn_memory:
                node_ids = torch.arange(chunk.x.size(0))
                with torch.no_grad():
                    mem_states = model.embedding1(chunk.x, chunk.edge_index)
                model.memory_module.update_memory(node_ids, mem_states)

            buffer.add(chunk, model)
            # buffer.add scores the chunk in eval mode; make sure the next
            # chunk is trained with dropout enabled again.
            model.train()

            preds = logits.argmax(dim=1)
            acc = accuracy_score(chunk.y.cpu().numpy(), preds.cpu().numpy())
            f1_ = f1_score(chunk.y.cpu().numpy(), preds.cpu().numpy(), average='weighted')

            # CLMetrics update (best/last accuracy)
            cl_metrics.update(i, acc, epoch)

            chunk_true = chunk.y.cpu().numpy()
            chunk_pred = preds.cpu().numpy()
            n_classes = int(chunk.y.max().item()) + 1
            prec, rec, f1_scores, _ = precision_recall_fscore_support(
                chunk_true, chunk_pred, zero_division=0, labels=range(n_classes)
            )
            for class_id in range(n_classes):
                per_class_f1_history[class_id].append(f1_scores[class_id])

            epoch_accs.append(acc)
            epoch_f1s_local.append(f1_)

            print(f"Epoch {epoch+1} Chunk {i+1}/{len(chunks)}"
                  f" | Loss: {loss.item():.4f}"
                  f" | Acc: {acc:.4f}"
                  f" | F1: {f1_:.4f}")

        epoch_chunk_accuracies.append(epoch_accs)
        times.append(time.time() - start)

        avg_acc = np.mean(epoch_accs)
        avg_f1 = np.mean(epoch_f1s_local)

        current_avg_forget = cl_metrics.average_forgetting()
        accs.append(avg_acc)
        f1s.append(avg_f1)
        forgets.append(current_avg_forget)

        # Windowed forgetting: drop of each chunk's accuracy below the best
        # accuracy of the three chunks before it within this epoch.
        window_vals = []
        for idx_ in range(len(epoch_accs)):
            if idx_ >= 3:
                max_past = max(epoch_accs[idx_-3: idx_])
                window_vals.append(max(0, max_past - epoch_accs[idx_]))
        avg_window_forget = np.mean(window_vals) if window_vals else 0
        window_forgets.append(avg_window_forget)

        # Validation
        val_accs_epoch = []
        val_f1s_epoch = []
        model.eval()
        for val_chunk in val_chunks:
            with torch.no_grad():
                val_out, _ = model(val_chunk.x, val_chunk.edge_index, edge_attr=val_chunk.edge_attr)
                preds_val = val_out.argmax(dim=1)
            vacc = accuracy_score(val_chunk.y.cpu().numpy(), preds_val.cpu().numpy())
            vf1 = f1_score(val_chunk.y.cpu().numpy(), preds_val.cpu().numpy(), average='weighted')
            val_accs_epoch.append(vacc)
            val_f1s_epoch.append(vf1)
        val_f1_ = np.mean(val_f1s_epoch)
        # The scheduler is stepped after validation so the plateau scheduler
        # can use this epoch's validation F1.
        scheduler_step(val_f1_)
        val_acc_ = np.mean(val_accs_epoch)
        val_AC.append(val_acc_)
        val_F.append(val_f1_)

        if use_meta_update:
            # NOTE(review): `new_params` is a copy of the model's current
            # parameters, so `p - p_i` inside meta_update is zero and the
            # values do not change. `saved_params` (epoch start) looks like
            # the intended reference, but which direction the interpolation
            # should go is not clear from this file — confirm before changing.
            with torch.no_grad():
                new_params = [(n, p.clone().detach()) for n, p in model.named_parameters()]
                meta_update(model, new_params, meta_lr=meta_lr)

        print(f"\n Epoch {epoch+1} Summary:"
              f" Time: {times[-1]:.2f}s"
              f" | Avg Train Acc: {avg_acc:.4f}"
              f" | Avg Train F1: {avg_f1:.4f}"
              f" | Avg Forgetting: {current_avg_forget:.4f}"
              f" | Windowed Forget: {avg_window_forget:.4f}"
              f" | Val Acc: {val_acc_:.4f}"
              f" | Val F1: {val_f1_:.4f}\n")

        if use_kd and teacher_model is not None:
            teacher_model.load_state_dict(model.state_dict())
            teacher_model.eval()

        buffer.clean(max_size=300)

        if val_f1_ > best_val_f1:
            best_val_f1 = val_f1_
            # state_dict() hands back references to the live parameter
            # tensors; copy them so later epochs cannot overwrite the
            # snapshot of the best model.
            best_model_state = copy.deepcopy(model.state_dict())

    return (
        accs, f1s, forgets, times, best_model_state,
        val_AC, val_F, window_forgets,
        per_class_f1_history, replay_class_counter,
        epoch_chunk_accuracies, cl_metrics
    )

To measure prediction consistency across sequential inputs

In [None]:

def temporal_coherence(preds, timestamps):
    """Score how stable a prediction sequence is once ordered by time.

    Predictions are sorted by their timestamps; the fraction of adjacent
    pairs whose labels differ is then subtracted from 1.

    Args:
        preds: 1-D sequence of predicted labels.
        timestamps: 1-D sequence of sortable time values, one per prediction.

    Returns:
        1.0 when there are fewer than two predictions or the label never
        changes between neighbours; 0.0 when every adjacent pair differs.

    Raises:
        ValueError: if ``preds`` and ``timestamps`` differ in length.
    """
    if len(preds) <= 1:
        return 1.0

    # A stable sort keeps samples with identical timestamps in their input
    # order, so the score does not depend on how the sort breaks ties.
    order = np.argsort(np.asarray(timestamps), kind="stable")
    if order.ndim != 1 or order.shape[0] != len(preds):
        raise ValueError(
            f"preds and timestamps must have the same length "
            f"(got {len(preds)} predictions and timestamps of shape {order.shape})"
        )

    ordered_preds = np.asarray(preds)[order]
    transitions = np.sum(ordered_preds[1:] != ordered_preds[:-1])
    max_transitions = len(ordered_preds) - 1
    return 1.0 - (transitions / max_transitions)

In [None]:
def evaluate_model(model, data, verbose=False, log_memory=True):
    """Run one forward pass over ``data`` and report accuracy, F1 and latency.

    Args:
        model: network called as ``model(x, edge_index, edge_attr=...)`` and
            returning a ``(logits, extra)`` pair; it is switched to eval mode.
        data: graph chunk exposing ``x``, ``edge_index``, ``edge_attr``, ``y``
            and ``to(device)``; it is moved to the model's device.
        verbose: when True, print the predicted/true class sets and a
            classification report.
        log_memory: when True, sample process and CUDA memory after the
            forward pass; otherwise both memory fields are ``None``.

    Returns:
        Tuple ``(acc, f1, inf_time, preds, labels, mem_usage, cuda_usage)``:
        accuracy, weighted F1, forward-pass wall time in seconds, predicted
        labels and true labels as NumPy arrays, the value returned by
        ``get_memory_usage()`` and the value returned by
        ``get_cuda_memory_usage()``.
    """
    model.eval()
    device = next(model.parameters()).device
    data = data.to(device)
    on_cuda = device.type == "cuda"

    with torch.no_grad():
        # Time only the forward pass: memory probes are taken afterwards so
        # their overhead does not inflate the reported latency.
        if on_cuda:
            # CUDA kernels launch asynchronously; drain pending work first.
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        logits, _ = model(data.x, data.edge_index, edge_attr=data.edge_attr)
        if on_cuda:
            # Wait for the forward kernels to finish before stopping the clock.
            torch.cuda.synchronize(device)
        inf_time = time.perf_counter() - start

    mem_usage = get_memory_usage() if log_memory else None
    cuda_usage = get_cuda_memory_usage() if log_memory else None

    preds = logits.argmax(dim=1).cpu()
    y_pred = preds.numpy()
    y_true = data.y.cpu().numpy()
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    if verbose:
        print(f"Prediction Classes: {torch.unique(preds)}"
              f" | True Classes: {torch.unique(data.y)}")
        print("Classification Report:")
        print(classification_report(y_true, y_pred, zero_division=0))

    return acc, f1, inf_time, y_pred, y_true, mem_usage, cuda_usage

Run the full pipeline: load the data, build time-ordered chunks, train the model, and evaluate it on the held-out test chunks.

In [None]:
def _plot_pagehinkley_drift(test_accs, drift_flags):
    """Plot per-chunk test accuracy and mark the chunks flagged as drift."""
    drift_points = [i for i, flag in enumerate(drift_flags) if flag]
    plt.figure(figsize=(12, 5))
    plt.plot(np.arange(len(test_accs)), test_accs, label="Chunk Accuracy")
    plt.scatter(
        drift_points, [test_accs[i] for i in drift_points],
        label="Page-Hinkley Drift", zorder=5
    )
    plt.xlabel("Test Chunk")
    plt.ylabel("Accuracy")
    plt.title("Page-Hinkley Concept Drift Detection")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def main():
    """Train CAST_GRN on the selected dataset and report test-time results.

    Steps, in order: seed the RNGs, load the data, generate and persist
    timestamps, build and split time chunks, construct the model, optimizer,
    loss and replay buffer, train, reload the best checkpoint returned by
    ``train_model``, evaluate every test chunk, then print and plot the
    classification, continual-learning, latency and drift-detection results.
    """
    seed_value = 42
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)

    data = load_data_from_files()
    num_classes = len(torch.unique(data.y))

    # Timestamp creation. torch.save overwrites any stale file, and the
    # in-memory object is used directly instead of being read back from disk.
    timestamps = generate_balanced_timestamps(data.y)
    torch.save(timestamps, "timestamps.pt")
    timestamps = timestamps.float()

    # Build chunks -> then split (time-aware)
    chunks = shuffle_and_split_data_into_time_chunks(
        data, timestamps, chunk_size=CHUNK_SIZE, overlap_ratio=0.1
    )
    train_chunks, val_chunks, test_chunks = split_chunks_time_aware(
        chunks, timestamps, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1
    )

    device = torch.device("cpu") # you can change with GPU if you have
    memory_module = TGNMemory(hidden_dim=128)

    original_edge_dim = (
        data.edge_attr.size(1)
        if hasattr(data, "edge_attr") and data.edge_attr is not None
        else 0
    )

    model = CAST_GRN(
        in_channels=data.x.size(1),
        hidden_channels=128,
        num_classes=num_classes,
        time_embed_dim=32,
        original_edge_dim=original_edge_dim,
        memory_module=memory_module,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Build the loss on the same device as the model; using data.x.device
    # would leave it on the CPU once `device` above is switched to a GPU.
    criterion = get_classwise_focal_loss(
        num_classes=num_classes,
        device=device,
    )

    buffer = HybridSelectiveReplayBuffer(
        max_size=300,
        sampling_mode="uncertainty-class",  # chose best sampling mode
    )

    fisher = compute_fisher_information(model, val_chunks, optimizer, criterion)
    # Detach the snapshot so it is a constant anchor: a clone that is still
    # attached to autograd would send a cancelling gradient back to the live
    # parameter in any penalty of the form (p - prev).
    # NOTE(review): assumes train_model uses prev_params as such an anchor — confirm.
    prev_params = {n: p.detach().clone() for n, p in model.named_parameters()}

    print_label_statistics(train_chunks, "train")
    print_label_statistics(val_chunks, "validation")
    print_label_statistics(test_chunks, "test")

    results = train_model(
        chunks=train_chunks,
        val_chunks=val_chunks,
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        buffer=buffer,
        fisher=fisher,
        prev_params=prev_params,
        num_epochs=10,
        use_meta_update=True,
        meta_lr=META_LR,
        use_tgn_memory=True,
        use_kd=True,
        kd_alpha=0.3,
        kd_temperature=1.0,
    )

    (
        accs, f1s, forgets, times_list,
        best_model_state, val_AC, val_F,
        window_forgets, per_class_f1_history,
        replay_class_counter, all_epoch_acc_dist,
        cl_metrics
    ) = results

    plot_training_progress(
        val_AC, val_F, forgets, times_list, window_forgets, all_epoch_acc_dist
    )

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Test & analysis
    all_preds, all_labels = [], []
    test_accs, test_f1s, test_times = [], [], []
    coherence_scores = []
    # Collected for the optional plot_memory_usage call further down.
    test_cpu_mem, test_gpu_mem = [], []

    for test_data in test_chunks:
        test_data = test_data.to(device)
        acc, f1_, t, preds, labels, ram, cuda_mem = evaluate_model(
            model, test_data, log_memory=True
        )
        # Predictions are scored in the order they appear in the chunk, so
        # their positions serve as the timestamps.
        coherence = temporal_coherence(preds, np.arange(len(preds)))

        test_accs.append(acc)
        test_f1s.append(f1_)
        test_times.append(t)
        coherence_scores.append(coherence)
        all_preds.extend(preds)
        all_labels.extend(labels)
        test_cpu_mem.append(ram)
        test_gpu_mem.append(cuda_mem)

    print("\n Final Combined Classification Report:")
    print(classification_report(all_labels, all_preds, zero_division=0))

    plot_temporal_coherence_trend(coherence_scores)

    print("\n=== Final Test Results ===")
    print(f" Avg Test Accuracy: {np.mean(test_accs):.4f} ± {np.std(test_accs):.4f}")
    print(f" Avg F1 Score: {np.mean(test_f1s):.4f} ± {np.std(test_f1s):.4f}")
    print(f" Avg Temporal Coherence: {np.mean(coherence_scores):.4f} ± {np.std(coherence_scores):.4f}")
    print(f" Avg Inference Time: {np.mean(test_times):.4f}s ± {np.std(test_times):.4f}")

    # Continual metrics
    avg_forget = cl_metrics.average_forgetting() * 100
    std_forget = cl_metrics.average_forgetting_std() * 100
    avg_adapt = cl_metrics.adaptation_speed()
    std_adapt = cl_metrics.adaptation_speed_std()

    print("\n=== CONTINUAL METRICS ===")
    print(f"Average Forgetting (%): {avg_forget:.4f} ± {std_forget:.4f}")
    if avg_adapt is not None:
        print(f"Adaptation Speed (epoch of peak acc): {avg_adapt:.4f} ± {std_adapt:.4f}")
    else:
        print("Adaptation Speed (epoch of peak acc): None")

    lat_ms = np.array(test_times) * 1000.0
    avg_lat = np.mean(lat_ms)
    std_lat = np.std(lat_ms)
    print(f"Average Inference Latency (ms): {avg_lat:.3f} ± {std_lat:.3f}")

    plot_confusion_matrix(all_labels, all_preds, num_classes=num_classes)
    plot_forgetting_curve(cl_metrics)
    #plot_memory_usage(test_cpu_mem, test_gpu_mem)

    # Drift detection over the per-chunk test accuracies, three detectors.
    drift_flags_acc = detect_concept_drift(test_accs, window=5, threshold=2.5)
    print("Z-Score drift (Accuracy):", [i for i, f in enumerate(drift_flags_acc) if f])
    plot_drift_detection(test_accs, drift_flags_acc, "Accuracy")

    drift_flags_adw = detect_drift_adwin(test_accs, delta=0.5)
    print("ADWIN (river) drift (Accuracy):", [i for i, f in enumerate(drift_flags_adw) if f])
    plot_adwin_drift(test_accs, drift_flags_adw)

    drift_flags_ph = detect_drift_pagehinkley(test_accs, threshold=2, alpha=0.9, min_instances=2)
    print("Page-Hinkley drift (Accuracy):", [i for i, f in enumerate(drift_flags_ph) if f])
    _plot_pagehinkley_drift(test_accs, drift_flags_ph)


# Entry point: run the pipeline only when this code is executed as the main
# module, not when it is imported.
if __name__ == "__main__":
    main()
