# Circular Buffer

In [14]:
class Circular_Buffer:
    """Fixed-capacity ring buffer addressed by relative index (0 = oldest).

    Once ``size`` elements are stored, each ``append`` overwrites the oldest
    element and returns it, so callers can react to the evicted value.
    """

    def __init__(self, size):
        self.buffer = [None] * size
        self.size = size
        self.start = 0  # Physical position of the oldest element
        self.count = 0  # Number of elements currently in the buffer

    def _physical_index(self, idx):
        """Map a relative index (0 = oldest) to a position in ``self.buffer``."""
        return (self.start + idx) % self.size

    def append(self, value):
        """Append ``value`` as the newest element.

        Returns:
            The evicted oldest element if the buffer was already full,
            otherwise ``None``. An evicted element may itself be ``None``,
            so a ``None`` result does not by itself prove nothing was evicted.
        """
        if self.count < self.size:
            self.buffer[self._physical_index(self.count)] = value
            self.count += 1
            return None  # No value was overwritten
        # Full: the slot after the newest element is the oldest one.
        overwritten = self.buffer[self.start]
        self.buffer[self.start] = value
        self.start = (self.start + 1) % self.size
        return overwritten

    def set_at(self, index, value):
        """Set value at a relative index within the buffer (0 = oldest).

        Raises:
            IndexError: if ``index`` is not in ``[0, count)``.
        """
        if index < 0 or index >= self.count:
            raise IndexError(f"Index out of bounds in circular buffer {index} {value}")
        self.buffer[self._physical_index(index)] = value

    def get(self, idx):
        """Return the element at the given relative index (0 is the oldest element).

        Raises:
            IndexError: if ``idx`` is not in ``[0, count)``.
        """
        if idx < 0 or idx >= self.count:
            raise IndexError(f"Index out of bounds in circular buffer. {idx}")
        return self.buffer[self._physical_index(idx)]

    def print_array(self):
        """Print the contents of the buffer in order from oldest to newest."""
        print(self.get_array())

    def get_array(self):
        """Return the stored elements as a new list, oldest first."""
        return [self.buffer[self._physical_index(i)] for i in range(self.count)]

    def __repr__(self):
        # get() requires an index; the whole ordered content comes from get_array().
        return f"Circular_Buffer({self.get_array()})"


# Data Window

In [4]:
class Data_Window:
    """Sliding window over the most recent ``data_window_size`` stream points.

    Points are addressed by absolute index (0 = first point ever inserted).
    The window covers ``[abs_idx_min, abs_idx_max]`` inclusive. Three parallel
    circular buffers hold, for each point in the window, its data, its
    assigned cluster id (None until assigned), and whether it was labeled.
    """

    def __init__(self, data_window_size = 1000):
        self.abs_idx_max = -1  # absolute index of the most recently inserted point (-1 = empty)
        self.abs_idx_min = 0   # absolute index of the oldest point still in the window
        self.data_window_size = data_window_size
        self.assigned_cluster_id_window = Circular_Buffer(data_window_size)
        self.is_point_labeled_window = Circular_Buffer(data_window_size)
        self.data_in_window = Circular_Buffer(data_window_size)
        # Cluster id of the most recently evicted point (its abs_idx is abs_idx_min - 1).
        # None both when nothing was evicted and when the evicted point had no id.
        self.last_removed_cluster_id = None

    def _window_offset(self, abs_index):
        """Convert an absolute index to a 0-based offset within the window.

        Raises:
            IndexError: if ``abs_index`` is outside ``[abs_idx_min, abs_idx_max]``.
        """
        if not (self.abs_idx_min <= abs_index <= self.abs_idx_max):
            # Both bounds are inclusive, so the message uses a closed interval.
            raise IndexError(f"abs_index {abs_index} is out of the window range "
                             f"[{self.abs_idx_min}, {self.abs_idx_max}]")
        return abs_index - self.abs_idx_min

    def insert_data(self, data_point):
        """Append a new point, evicting the oldest one once the window is full."""
        self.data_in_window.append(data_point)
        self.abs_idx_max += 1
        self.abs_idx_min = max(0, self.abs_idx_max - self.data_window_size + 1)

        # New points start unassigned; keep the evicted point's id for cleanup.
        self.last_removed_cluster_id = self.assigned_cluster_id_window.append(None)

        self.is_point_labeled_window.append(False)

    def get_data_point(self, abs_index):
        """Return the data stored for ``abs_index`` (IndexError if outside the window)."""
        return self.data_in_window.get(self._window_offset(abs_index))

    def update_cluster_id_at(self, abs_index, new_id):
        """Record ``new_id`` as the cluster of the point at ``abs_index``."""
        self.assigned_cluster_id_window.set_at(self._window_offset(abs_index), new_id)

    def updated_labeled_window(self, abs_index):
        """Mark the point at ``abs_index`` as labeled."""
        self.is_point_labeled_window.set_at(self._window_offset(abs_index), True)

# Labeled Data

In [5]:
class Labeled_Data:
  """Parallel-array store of every point labeled by the oracle.

  Entry ``i`` of each ``*_array`` describes the same point. Lookup by
  absolute stream index goes through a dict, so it is O(1) instead of an
  O(n) ``list.index`` scan; callers invoke ``get_data`` inside loops.
  """

  def __init__(self):
    self.abs_idx_array = []
    self.data_array = []
    self.cluster_id_array = []
    self.label_array = []
    self.relevance_array = []
    # abs_idx -> position in the arrays above. The first occurrence wins,
    # matching the semantics of list.index.
    self._position_by_abs_idx = {}

  def add_point(self, abs_idx, data_point, cluster_id, label, relevance):
    """Append one labeled point to every parallel array."""
    self._position_by_abs_idx.setdefault(abs_idx, len(self.abs_idx_array))
    self.abs_idx_array.append(abs_idx)
    self.data_array.append(data_point)
    self.cluster_id_array.append(cluster_id)
    self.label_array.append(label)
    self.relevance_array.append(relevance)

  def get_data(self, abs_idx):
    """Return the stored data for ``abs_idx``.

    Raises:
      ValueError: if ``abs_idx`` was never added.
    """
    return self.data_array[self.get_ld_index(abs_idx)]

  def get_ld_index(self, abs_idx):
    """Return the position of ``abs_idx`` in the parallel arrays.

    Raises:
      ValueError: if ``abs_idx`` was never added (same type as list.index).
    """
    try:
      return self._position_by_abs_idx[abs_idx]
    except KeyError:
      raise ValueError(f"{abs_idx!r} is not in list") from None

# Subspace Partition

In [6]:
class Subspace_Partition:
    """Ordered collection of clusters; a cluster's id is its index in ``cluster_list``."""

    def __init__(self):
        # Cluster objects; this class only ever appends, so existing ids stay stable.
        self.cluster_list = []
        # Every label for which a cluster has been created.
        self.set_of_known_labels = set()

    def create_new_cluster(self, label, relevance, l_pts, o_pts, labeled_data):
        """Create a cluster, register its label, and return the new cluster id."""
        self.set_of_known_labels.add(label)
        self.cluster_list.append(Cluster(label, relevance, l_pts, o_pts, labeled_data))
        return len(self.cluster_list) - 1

# Cluster

In [7]:
class Cluster:
  """A cluster of labeled points (l_pts) and unlabeled members (o_pts).

  Both lists hold absolute stream indices. ``diameter`` is the largest
  pairwise Euclidean distance between the cluster's labeled points, or 0
  when there are fewer than two.
  """

  def __init__(self, label, relevance, l_pts, o_pts, labeled_data):
    self.label = label
    self.relevance = relevance
    self.l_pts = l_pts  # NOTE: the caller's list is stored, not copied
    self.o_pts = o_pts
    self.diameter = 0
    # cluster id is this cluster's position in Subspace_Partition.cluster_list

    if len(l_pts) > 1:
      self.update_diameter(labeled_data)

  def add_l_pt(self, abs_idx, labeled_data):
    """Add a labeled point and grow ``diameter`` incrementally.

    Only distances from the new point to the existing l_pts can exceed the
    current diameter. Checking just those is O(len(l_pts)) rather than an
    O(len(l_pts)**2) recomputation, and it gives the same result.
    ``labeled_data`` must already contain ``abs_idx``.
    """
    new_data = labeled_data.get_data(abs_idx)
    for other_idx in self.l_pts:
      distance = np.linalg.norm(new_data - labeled_data.get_data(other_idx))
      if distance > self.diameter:
        self.diameter = distance
    self.l_pts.append(abs_idx)

  def add_o_pt(self, abs_idx):
    """Attach an unlabeled point (by absolute index) to this cluster."""
    self.o_pts.append(abs_idx)

  def update_diameter(self, labeled_data):
    """Recompute ``diameter`` from scratch over all pairs of l_pts."""
    # Fetch each point's data once instead of once per pair.
    points = [labeled_data.get_data(idx) for idx in self.l_pts]
    largest_distance = 0
    for i, point_i in enumerate(points):
      for point_j in points[:i]:
        distance = np.linalg.norm(point_i - point_j)
        if largest_distance < distance:
          largest_distance = distance
    self.diameter = largest_distance

# Oracle

In [8]:
class Oracle:
  """Ground-truth source answering label/relevance queries by stream index."""

  def __init__(self, X, y):
    self.X = X  # data points, indexed like the stream
    self.y = y  # per-point records whose first two fields are (label, relevance)

  def answer_query(self, abs_index):
    """Return ``(label, relevance)`` for the point at ``abs_index``."""
    record = self.y[abs_index]
    return record[0], record[1]


# Data Stream

In [9]:
class Data_Stream:
  """Replay a fixed dataset one point at a time, in index order."""

  def __init__(self, X, y):
    self.X = X  # data points, emitted in order
    self.y = y  # per-point (label, relevance); stored but not read by the stream
    self.stream_counter = 0  # index of the next point to emit

  def stream_new_data_point(self):
    """Return the next data point and advance the stream position."""
    current = self.stream_counter
    point = self.X[current]  # read first, so an exhausted stream does not advance
    self.stream_counter = current + 1
    return point

  def get_remaining_num_points(self):
    """Number of points not yet streamed."""
    remaining = len(self.X) - self.stream_counter
    return remaining

# ARED

In [15]:
## import numpy as np
import matplotlib.pyplot as plt

class ARED:
    """Stream-based active learner that grows a partition of labeled clusters.

    For each incoming point, the nearest labeled point (across all clusters)
    selects a comparison cluster. If that cluster is not relevant and the
    point is not anomalous for it, the point joins the cluster unlabeled, as
    an o_pt. Otherwise the oracle is queried:
      * a label equal to the comparison cluster's label adds the point to
        that cluster as an l_pt;
      * any other label splits off a new cluster.
    """

    def __init__(self, oracle, kappa=1.0, data_window_size=1000, verbose = False):
        self.kappa = kappa  # multiplier on the nearest-l_pt distance in anomalous()
        self.data_window = Data_Window(data_window_size)
        self.labeled_data = Labeled_Data()
        self.subspace_partition = Subspace_Partition()
        self.oracle = oracle  # must provide answer_query(abs_idx) -> (label, relevance)
        self.verbose = verbose

    def process_first_point(self, data_point):
        """Insert and query the first point, then seed a new cluster with it."""
        self.data_window.insert_data(data_point)
        data_point_abs_idx = self.data_window.abs_idx_max

        label, relevance = self.query(data_point_abs_idx)

        cluster_id = len(self.subspace_partition.cluster_list)  # 0 on a fresh model

        self.data_window.update_cluster_id_at(data_point_abs_idx, cluster_id)
        self.labeled_data.add_point(data_point_abs_idx, data_point, cluster_id, label, relevance)
        self.subspace_partition.create_new_cluster(
            label, relevance, [data_point_abs_idx], [], self.labeled_data)

        if self.verbose:
            print("new cluster:", cluster_id, [data_point_abs_idx])

    def determine_comparison_cluster(self, data_point):
        """Find the cluster owning the labeled point nearest to ``data_point``.

        Returns:
            ``(cluster_id, distance)`` for the nearest l_pt, or ``None`` when
            no cluster has any labeled points.
        """
        shortest_distance = np.inf
        closest = None

        for cluster_id, cluster in enumerate(self.subspace_partition.cluster_list):
            for abs_l_pt_index in cluster.l_pts:
                l_pt_data = self.labeled_data.get_data(abs_l_pt_index)
                distance = np.linalg.norm(l_pt_data - data_point)
                if distance < shortest_distance:
                    shortest_distance = distance
                    closest = (cluster_id, distance)

        return closest

    def anomalous(self, data_point, cluster_id, distance):
        """Return True if a point is too far from its comparison cluster.

        ``data_point`` is not used. The decision depends only on ``distance``
        (to the nearest l_pt) and the cluster. A cluster with fewer than two
        l_pts has no meaningful diameter, so every point is anomalous for it.
        Otherwise the point is anomalous when ``kappa * distance`` exceeds the
        diameter; with ``kappa == 2`` that means ``distance > diameter / 2``.
        """
        cluster = self.subspace_partition.cluster_list[cluster_id]

        if len(cluster.l_pts) <= 1:
            return True

        return distance * self.kappa > cluster.diameter

    def query(self, abs_data_index):
        """Ask the oracle about a point and mark it labeled in the data window.

        Returns:
            ``(label, relevance)`` as answered by the oracle.
        """
        self.data_window.updated_labeled_window(abs_data_index)
        return self.oracle.answer_query(abs_data_index)

    def add_o_pt(self, abs_idx, cluster_id):
        """Attach an unlabeled point to ``cluster_id`` and record it in the window."""
        if self.verbose:
            print("add_o_pt:", abs_idx, cluster_id)

        cluster = self.subspace_partition.cluster_list[cluster_id]
        cluster.add_o_pt(abs_idx)

        # Keep assigned_cluster_id_window in sync for window maintenance.
        self.data_window.update_cluster_id_at(abs_idx, cluster_id)

    def add_l_pt(self, abs_idx, data_point, cluster_id):
        """Add a labeled point to an existing cluster under that cluster's label."""
        if self.verbose:
            print("add_l_pt:", abs_idx, cluster_id)

        cluster = self.subspace_partition.cluster_list[cluster_id]

        # The point is stored with the cluster's label and relevance.
        label = cluster.label
        relevance = cluster.relevance

        self.data_window.update_cluster_id_at(abs_idx, cluster_id)

        # Must happen before cluster.add_l_pt, which looks the point up for the diameter.
        self.labeled_data.add_point(abs_idx, data_point, cluster_id, label, relevance)
        cluster.add_l_pt(abs_idx, self.labeled_data)

    def split(self, data_point, data_point_idx, new_cluster_label, relevance, old_cluster_id):
        """Create a new cluster seeded by a freshly labeled point.

        Each o_pt of the old cluster then moves to the new cluster when the
        new seed point is at least as close to it as the old cluster's nearest
        l_pt. Ties go to the new cluster.
        """
        new_cluster_id = len(self.subspace_partition.cluster_list)
        self.labeled_data.add_point(data_point_idx, data_point, new_cluster_id, new_cluster_label, relevance)
        self.data_window.update_cluster_id_at(data_point_idx, new_cluster_id)
        self.subspace_partition.create_new_cluster(
            new_cluster_label, relevance, [data_point_idx], [], self.labeled_data)

        if self.verbose:
            print("new cluster:", new_cluster_id, [data_point_idx])

        old_cluster = self.subspace_partition.cluster_list[old_cluster_id]
        o_pts_abs_inds_to_split = old_cluster.o_pts

        if not o_pts_abs_inds_to_split:
            return

        # Fetch the old cluster's l_pt data once rather than once per o_pt.
        old_l_pt_data = [self.labeled_data.get_data(idx) for idx in old_cluster.l_pts]

        new_cluster_o_pts_abs_inds = []
        old_cluster_o_pts_abs_inds = []

        for o_pt_index in o_pts_abs_inds_to_split:
            o_pt = self.data_window.get_data_point(o_pt_index)

            distance_to_existing = min(np.linalg.norm(o_pt - l_pt) for l_pt in old_l_pt_data)
            distance_to_new = np.linalg.norm(o_pt - data_point)

            if distance_to_existing < distance_to_new:
                old_cluster_o_pts_abs_inds.append(o_pt_index)
            else:
                # Debug output; previously printed unconditionally.
                if self.verbose:
                    print(distance_to_new, distance_to_existing, o_pt_index)
                new_cluster_o_pts_abs_inds.append(o_pt_index)
                # Keep assigned_cluster_id_window correct for later window maintenance.
                self.data_window.update_cluster_id_at(o_pt_index, new_cluster_id)

        if self.verbose:
            print("Split :")
            print("old_cluster_id w/ o_pts:", old_cluster_id, old_cluster_o_pts_abs_inds)
            print("new_cluster_id w/ o_pts:", new_cluster_id, new_cluster_o_pts_abs_inds)

        self.subspace_partition.cluster_list[new_cluster_id].o_pts = new_cluster_o_pts_abs_inds
        self.subspace_partition.cluster_list[old_cluster_id].o_pts = old_cluster_o_pts_abs_inds

    def relevance_processing(self, new_cluster_id):
        """Hook for extra handling of a newly created relevant cluster (not implemented)."""
        pass

    def subspace_partition_maintenance(self, forgotten_abs_idx, forgotten_point_cluster_id):
        """Remove an evicted unlabeled point from its cluster's o_pts.

        Raises:
            ValueError: if the point is not among that cluster's o_pts.
        """
        cluster = self.subspace_partition.cluster_list[forgotten_point_cluster_id]

        if self.verbose:
            print(forgotten_abs_idx, forgotten_point_cluster_id)

        cluster.o_pts.remove(forgotten_abs_idx)

    def showframe(self, abs_index):
        """Draw a windowed point as an image (assumes 128x128x3 flattened frames)."""
        im_data = self.data_window.get_data_point(abs_index)
        im_data = im_data.reshape([128, 128, 3])
        plt.imshow(im_data, cmap='gray')
        plt.title(f"Index: {abs_index}")
        plt.axis('off')

    def process_point(self, data_point):
        """Process one stream point. Seeds the model if no cluster exists yet."""
        if not self.subspace_partition.cluster_list:
            # There is nothing to compare against yet.
            self.process_first_point(data_point)
            return

        if self.verbose:
            print("labeled id array:", self.labeled_data.cluster_id_array)
            print("labeled abs array:", self.labeled_data.abs_idx_array)
            print("data window assigned id:", self.data_window.assigned_cluster_id_window.get_array())

        # Read before inserting. Once the window is full, relative index 0 is
        # exactly the point the insertion below evicts.
        is_forgotten_point_labeled = self.data_window.is_point_labeled_window.get(0)

        self.data_window.insert_data(data_point)
        data_point_abs_idx = self.data_window.abs_idx_max

        forgotten_abs_idx = self.data_window.abs_idx_min - 1
        forgotten_pt_cluster_id = self.data_window.last_removed_cluster_id

        # A non-None id means a point was evicted. Labeled points stay in
        # their cluster's l_pts; only unlabeled members are dropped.
        if forgotten_pt_cluster_id is not None and not is_forgotten_point_labeled:
            self.subspace_partition_maintenance(forgotten_abs_idx, forgotten_pt_cluster_id)

        comp_cluster_id, distance = self.determine_comparison_cluster(data_point)
        comparison_cluster = self.subspace_partition.cluster_list[comp_cluster_id]

        # Points close to a non-relevant cluster join it without a query.
        if not comparison_cluster.relevance and not self.anomalous(data_point, comp_cluster_id, distance):
            self.add_o_pt(data_point_abs_idx, comp_cluster_id)
            return

        label, relevant = self.query(data_point_abs_idx)

        # add_l_pt stores the point under the comparison cluster's label, so it
        # is correct only when the oracle's label matches. (The old check
        # `label in set_of_known_labels` was inverted and ignored which cluster
        # the label belongs to.)
        if label == comparison_cluster.label:
            self.add_l_pt(data_point_abs_idx, data_point, comp_cluster_id)
            return

        # Any other label, new or owned by another cluster, starts a new cluster.
        self.split(data_point, data_point_abs_idx, label, relevant, comp_cluster_id)

        if relevant:
            self.relevance_processing(len(self.subspace_partition.cluster_list) - 1)

# ARED on Parking lot data

In [None]:
import pickle
import cv2
import pandas as pd
import matplotlib.pyplot as plt


# Input files for the parking-lot experiment, relative to the notebook's working directory.
features_path = "./features.pkl"
labels_path = "./labels.csv"

# NOTE(review): pickle.load can execute arbitrary code. Only load a
# features.pkl produced by a trusted pipeline.
with open(features_path, 'rb') as features_file:
    # Presumably flattened frames (ARED.showframe reshapes them to 128x128x3) — confirm.
    features = pickle.load(features_file)

labels_df = pd.read_csv(labels_path)

# Both messages are printed only after both files have loaded.
print("Features loaded successfully.")
print("Labels loaded successfully.")

# Create Skewed MNIST

In [24]:
import pickle
from sklearn.datasets import fetch_openml

# Create skewed subset
def create_skewed_mnist(X, y, sparsity_levels, n_events, seed=42):
    """Draw a class-imbalanced, shuffled subset of a digit dataset.

    Digits 0-9 are put into a seeded random order. The i-th digit in that
    order receives ``sparsity_levels[i]`` samples, drawn without replacement.
    A digit with too few samples contributes all of them and a warning is
    printed. The pooled indices are shuffled and truncated to ``n_events``.

    Args:
        X: array indexable by an integer index array (rows = samples).
        y: label array compared against ``str(digit)``, so labels are
            expected to be strings such as ``"7"``.
        sparsity_levels: samples per digit; only the first 10 entries are used.
        n_events: maximum number of samples returned.
        seed: seed for the private random generator. The default reproduces
            the original fixed seed of 42.

    Returns:
        (X_skewed, y_skewed)
    """
    # A private RandomState seeded like np.random.seed(seed) yields exactly
    # the same draws, without resetting NumPy's global RNG for other code.
    rng = np.random.RandomState(seed)
    digit_order = rng.permutation(10)
    indices = []
    for digit, count in zip(digit_order, sparsity_levels):
        digit_indices = np.where(y == str(digit))[0]
        if len(digit_indices) >= count:
            indices.extend(rng.choice(digit_indices, count, replace=False))
        else:
            print(f"Warning: Not enough samples for digit {digit}, using all {len(digit_indices)}")
            indices.extend(digit_indices)
    # dtype=int keeps the array usable as an index even when it is empty.
    indices = np.array(indices, dtype=int)
    rng.shuffle(indices)
    indices = indices[:n_events]
    return X[indices], y[indices]

def load_and_skew_mnist(sparsity_levels, n_events, save_path="mnist_full.pkl"):
    """
    Download MNIST, pickle the full dataset, and build a skewed subset.

    The full (X, y) pair is written to ``save_path`` on every call, replacing
    any existing file. The subset comes from create_skewed_mnist.

    Args:
        sparsity_levels: list of 10 integers, number of samples per digit.
        n_events: total number of samples to include in final skewed dataset.
        save_path: path to store full MNIST data as pickle.

    Returns:
        X_skewed, y_skewed: filtered and shuffled MNIST subset.
        X, y: full MNIST dataset.
    """
    print("Loading MNIST from OpenML...")
    dataset = fetch_openml("mnist_784", version=1, as_frame=False)
    X = dataset.data
    y = dataset.target  # y is a string array of digits
    print(f"Full MNIST loaded: {X.shape[0]} samples")

    with open(save_path, "wb") as out_file:
        pickle.dump((X, y), out_file)
    print(f"Full MNIST (X, y) saved to {save_path}")

    print(f"Creating skewed subset with sparsity {sparsity_levels} and max {n_events} events...")
    skewed_pair = create_skewed_mnist(X, y, sparsity_levels, n_events)
    X_skewed, y_skewed = skewed_pair
    print(f"Skewed dataset shape: {X_skewed.shape}")

    return X_skewed, y_skewed, X, y

# Per-digit sample counts. create_skewed_mnist assigns them to the digits in a
# seeded random order.
sparsity_levels = [10000, 5000, 2000, 2000, 1000, 500, 300, 100, 50, 20]
# Each digit contributes at most its count, so this cap never truncates the subset.
n_events = sum(sparsity_levels)

X_skewed, y_skewed, X_full, y_full = load_and_skew_mnist(
    sparsity_levels=sparsity_levels,
    n_events=n_events,
)

def generate_is_relevant(y_skewed, relevant_digits=frozenset({"8"})):
    """Return a boolean NumPy mask marking labels that are in ``relevant_digits``.

    The default is an immutable frozenset, so it cannot be mutated and shared
    across calls the way a mutable ``set`` default could be.
    """
    # Mark some digits as relevant (e.g., rare or important events).
    return np.array([label in relevant_digits for label in y_skewed], dtype=bool)

Loading MNIST from OpenML...
Full MNIST loaded: 70000 samples
Full MNIST (X, y) saved to mnist_full.pkl
Creating skewed subset with sparsity [10000, 5000, 2000, 2000, 1000, 500, 300, 100, 50, 20] and max 20970 events...
Skewed dataset shape: (17795, 784)


### ARED on Skewed MNIST

In [None]:
from sklearn.datasets import fetch_openml
from collections import Counter
import numpy as np

def generate_is_relevant(label_list, relevant_set):
    """Return a plain list of bools, True where a label is in ``relevant_set``."""
    flags = []
    for label in label_list:
        flags.append(label in relevant_set)
    return flags

def _mean(values):
    """Arithmetic mean of ``values``, or 0.0 when there are none."""
    values = list(values)
    return sum(values) / len(values) if values else 0.0


def main():
    """Run ARED over the skewed MNIST stream and print cluster statistics.

    Uses the module-level ``X_skewed`` / ``y_skewed`` from the
    "Create Skewed MNIST" cell, which must be run first.
    """
    # Mark the two least common digits as relevant.
    digit_counts = Counter(y_skewed)
    least_common_digits = [digit for digit, _ in digit_counts.most_common()[-2:]]

    print(f"Least common digits: {least_common_digits} (marked as relevant)")

    relevance_array = generate_is_relevant(y_skewed, set(least_common_digits))
    y_w_rel = list(zip(y_skewed, relevance_array))

    data_stream = Data_Stream(X_skewed, y_w_rel)
    oracle = Oracle(X_skewed, y_w_rel)
    ared = ARED(oracle, kappa=2, data_window_size=1000, verbose=False)

    # Stream every point: the first seeds the model, the rest go through process_point.
    ared.process_first_point(data_stream.stream_new_data_point())
    for _ in range(data_stream.get_remaining_num_points()):
        ared.process_point(data_stream.stream_new_data_point())

    print("ARED COMPLETE")

    clusters = ared.subspace_partition.cluster_list
    o_pt_counts = [len(cluster.o_pts) for cluster in clusters]

    # Mean o_pts per cluster.
    print(_mean(o_pt_counts))
    # Mean o_pts among clusters that have any. Prints 0.0 when none do,
    # where the old code raised ZeroDivisionError.
    print(_mean(count for count in o_pt_counts if count != 0))
    # Mean l_pts per cluster.
    print(_mean(len(cluster.l_pts) for cluster in clusters))

    # Total number of labeled (queried) points.
    print(len(ared.labeled_data.data_array))


main()


Least common digits: ['3', '6'] (marked as relevant)


# Performance Evaluator

In [None]:
import numpy as np
from sklearn.metrics import precision_recall_curve, auc
from collections import defaultdict, Counter

class PerformanceEvaluator:
    """Collect A/RED predictions and oracle queries, then compute metrics.

    ``ground_truth`` holds only answers recorded through ``record_query``,
    so query statistics reflect real queries. Truth fetched from the oracle
    purely for scoring is cached separately in ``_evaluation_truth``.
    """

    def __init__(self, oracle):
        self.oracle = oracle  # needs answer_query(abs_idx) -> (label, relevance)
        self.predictions = []  # (abs_idx, pred_label, pred_relevance, confidence)
        self.ground_truth = {}  # abs_idx -> (true_label, true_relevance), recorded queries only
        self._evaluation_truth = {}  # abs_idx -> (true_label, true_relevance), scoring-only lookups
        self.class_discovery = {}  # class -> total_events when the class was first queried
        self.total_events = 0
        self.total_queries = 0

    def record_prediction(self, abs_idx, pred_label, pred_relevance, confidence=1.0):
        """Record a prediction (O(1))."""
        self.predictions.append((abs_idx, pred_label, pred_relevance, confidence))

    def record_query(self, abs_idx, true_label, true_relevance):
        """Record an oracle query (O(1)) and note first-time class discoveries."""
        self.ground_truth[abs_idx] = (true_label, true_relevance)
        self.total_queries += 1

        # Track class discovery for the n_missed metric.
        if true_label not in self.class_discovery:
            self.class_discovery[true_label] = self.total_events

    def record_point_processed(self):
        """Record that one stream point was processed (O(1))."""
        self.total_events += 1

    def get_ground_truth(self, abs_idx):
        """Return ``(true_label, true_relevance)`` for ``abs_idx``.

        Prefers the answer recorded by ``record_query``. Otherwise it asks the
        oracle once and caches the answer outside ``ground_truth``, so scoring
        lookups are not counted as queries.
        """
        if abs_idx in self.ground_truth:
            return self.ground_truth[abs_idx]
        if abs_idx not in self._evaluation_truth:
            true_label, true_relevance = self.oracle.answer_query(abs_idx)
            self._evaluation_truth[abs_idx] = (true_label, true_relevance)
        return self._evaluation_truth[abs_idx]

    def compute_metrics(self):
        """Compute all metrics at once; returns {} when nothing was predicted."""
        if not self.predictions:
            return {}

        y_true_labels = []
        y_pred_labels = []
        y_true_relevance = []
        y_pred_relevance = []

        for abs_idx, pred_label, pred_relevance, _ in self.predictions:
            true_label, true_relevance = self.get_ground_truth(abs_idx)
            y_true_labels.append(true_label)
            y_pred_labels.append(pred_label)
            y_true_relevance.append(true_relevance)
            y_pred_relevance.append(pred_relevance)

        # 1. BALANCED ACCURACY: average of per-class recalls.
        class_correct = defaultdict(int)
        class_total = defaultdict(int)
        for true_label, pred_label in zip(y_true_labels, y_pred_labels):
            class_total[true_label] += 1
            if true_label == pred_label:
                class_correct[true_label] += 1
        per_class_recall = {cls: class_correct[cls] / class_total[cls] for cls in class_total}
        balanced_accuracy = np.mean(list(per_class_recall.values()))

        # 2. CLASS DISCOVERY: discovery rate + n_missed.
        all_true_classes = set(y_true_labels)
        discovered_classes = set(self.class_discovery)
        # Count only discovered classes that were also evaluated. Classes that
        # were queried but never predicted must not push the rate above 1.
        discovery_rate = len(all_true_classes & discovered_classes) / len(all_true_classes)
        n_missed = sum(self.class_discovery.values())  # Total events before all discoveries

        # 3. RELEVANCE PREDICTION: TP, FP, FN counts.
        relevance_pairs = list(zip(y_true_relevance, y_pred_relevance))
        tp = sum(1 for true_rel, pred_rel in relevance_pairs if true_rel and pred_rel)
        fp = sum(1 for true_rel, pred_rel in relevance_pairs if not true_rel and pred_rel)
        fn = sum(1 for true_rel, pred_rel in relevance_pairs if true_rel and not pred_rel)

        rel_precision = tp / max(1, tp + fp)
        rel_recall = tp / max(1, tp + fn)  # This is rel_acc from paper
        # Harmonic mean. P + R lies in [0, 2], so it must not be clamped to >= 1.
        pr_sum = rel_precision + rel_recall
        rel_f1 = 2 * rel_precision * rel_recall / pr_sum if pr_sum > 0 else 0.0

        # 4. RELEVANT CLASS PERFORMANCE: recall of each class seen with a relevant point.
        relevant_classes = {label for label, is_rel in zip(y_true_labels, y_true_relevance) if is_rel}
        relevant_class_recalls = [per_class_recall[cls] for cls in relevant_classes]
        avg_relevant_recall = np.mean(relevant_class_recalls) if relevant_class_recalls else 0

        # 5. QUERY EFFICIENCY (ground_truth contains recorded queries only).
        query_rate = self.total_queries / max(1, self.total_events)
        relevant_queries = sum(1 for _, is_rel in self.ground_truth.values() if is_rel)
        relevant_query_precision = relevant_queries / max(1, self.total_queries)

        return {
            # GOAL METRICS (most important)
            'discovery_rate': discovery_rate,           # Goal 1: Find all classes
            'avg_relevant_recall': avg_relevant_recall, # Goal 2: Find all points in relevant classes

            # PAPER METRICS (for literature comparison)
            'balanced_accuracy': balanced_accuracy,     # Classification across all classes
            'n_missed': n_missed,                       # Speed of class discovery
            'rel_acc': rel_recall,                      # Relevance prediction recall
            'query_rate': query_rate,                   # Query efficiency

            # DETAILED METRICS
            'rel_precision': rel_precision,
            'rel_recall': rel_recall,                   # alias of rel_acc, read by print_report
            'rel_f1': rel_f1,
            'tp': tp, 'fp': fp, 'fn': fn,
            'num_relevant_classes': len(relevant_classes),
            'relevant_query_precision': relevant_query_precision,
            'total_predictions': len(self.predictions),
            'total_queries': self.total_queries,
            'missed_classes': all_true_classes - discovered_classes
        }

    def compute_auprc_per_class(self, target_classes=None):
        """Compute AUPRC for specific classes (optional); {} if nothing to score."""
        if not self.predictions or not target_classes:
            return {}

        auprc_scores = {}
        for target_class in target_classes:
            y_true = []
            y_scores = []

            for abs_idx, pred_label, _, confidence in self.predictions:
                true_label, _ = self.get_ground_truth(abs_idx)
                y_true.append(true_label == target_class)
                y_scores.append(confidence if pred_label == target_class else 0.0)

            if sum(y_true) > 0:
                precision, recall, _ = precision_recall_curve(y_true, y_scores)
                auprc_scores[target_class] = auc(recall, precision)
            else:
                auprc_scores[target_class] = 0.0

        return auprc_scores

    def print_report(self, target_classes=None):
        """Print a concise performance report."""
        metrics = self.compute_metrics()

        print("A/RED PERFORMANCE REPORT")
        print("=" * 50)

        if not metrics:
            print("No predictions recorded.")
            print("=" * 50)
            return

        # Goal metrics first
        print(f"GOAL 1 - Class Discovery: {metrics['discovery_rate']:.3f}")
        print(f"GOAL 2 - Relevant Recall: {metrics['avg_relevant_recall']:.3f}")

        # Paper metrics
        print("\nPAPER METRICS:")
        print(f"  Balanced Accuracy: {metrics['balanced_accuracy']:.3f}")
        print(f"  n_missed: {metrics['n_missed']}")
        print(f"  rel_acc: {metrics['rel_acc']:.3f}")
        print(f"  query_rate: {metrics['query_rate']:.3f}")

        # Key details
        print("\nDETAILS:")
        print(f"  Relevance - P:{metrics['rel_precision']:.3f} R:{metrics['rel_recall']:.3f} F1:{metrics['rel_f1']:.3f}")
        print(f"  TP:{metrics['tp']} FP:{metrics['fp']} FN:{metrics['fn']}")
        print(f"  Relevant classes: {metrics['num_relevant_classes']}")
        print(f"  Total predictions: {metrics['total_predictions']}")

        if metrics['missed_classes']:
            print(f"  Missed classes: {sorted(metrics['missed_classes'])}")

        # Optional AUPRC
        if target_classes:
            auprc = self.compute_auprc_per_class(target_classes)
            if auprc:
                print("\nAUPRC per class:")
                for cls, score in sorted(auprc.items()):
                    print(f"  Class {cls}: {score:.3f}")

        print("=" * 50)


# Simple integration example
def simple_integration_example():
    """
    Simple example of integrating PerformanceEvaluator with A/RED.

    Documentation only: the body is ``pass``, so calling it does nothing.

    During streaming:
    1. Call evaluator.record_prediction() after each prediction
    2. Call evaluator.record_query() when querying oracle
    3. Call evaluator.record_point_processed() for each point

    record_query() stores the current total_events as the discovery time of a
    new class. Calling record_point_processed() first therefore counts the
    current point.

    After streaming:
    4. Call evaluator.print_report() to see results
    """

    # Example usage:
    # evaluator = PerformanceEvaluator(oracle)

    # During streaming loop:
    # evaluator.record_point_processed()
    # evaluator.record_prediction(abs_idx, predicted_label, predicted_relevance)
    # if query_made:
    #     evaluator.record_query(abs_idx, true_label, true_relevance)

    # After streaming (print_report's optional argument is target_classes):
    # evaluator.print_report(target_classes=['8', '9'])

    pass