240 changes: 240 additions & 0 deletions python/src/count_sketch.py
@@ -0,0 +1,240 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np
from random import randint, seed
from streaming_heap import StreamingHeap


class CountSketch:
    def __init__(self, num_buckets: int, num_levels: int, phi: float, rng_seed: int = 100):
"""
# TODO: add string update functionality.
:param num_buckets:int -- number of columns in the sketch table
:param num_levels:int -- number of rows in the sketch table
:param phi:float -- The threshold for heavy hitters
:param rng_seed:int -- seed for the randomisation

Attributes:
- self.p :: int - a large prime used to generate the random hashes.
- self.w :: int - the number of hash buckets for the table
- self.d :: int - the number of levels that are repeated
        - self.table :: np.ndarray of floats -- the sketch table that is updated as stream items are observed.


HELPFUL SOURCES:
http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/11/Small11.pdf
http://web.stanford.edu/class/archive/cs/cs166/cs166.1146/lectures/12/Small12.pdf
https://people.cs.umass.edu/~mcgregor/711S12/sketches1.pdf
https://cs.au.dk/~gerth/advising/thesis/jonas-nicolai-hovmand_morten-houmoeller-nygaard.pdf

        Denote F2 = ||f||_2 where f is the frequency vector underlying the data stream.
        e.g. if stream = (1, 1, 2, 1, 5) then f = (0, 3, 1, 0, 0, 1).
        If the stream is weighted, i.e. we receive (item, weight) pairs, then the same idea applies:
        stream = [(1, 4), (2, 1), (1, -1), (5, 1)] also has f = (0, 3, 1, 0, 0, 1) -- nb. using 0-based
        indexing for consistency with python.

        The a-heavy hitters problem is to identify all items i for which f[i] > a*F2.
        Solving this problem requires linear space in the worst case, so a relaxed version is solved instead:
        return b-heavy hitters for b = a - t and t > 0.
        Hence, we permit some _false positives_: items that are flagged as being b-heavy, but may not be a-heavy.

A CountSketch can achieve this aim in small space by maintaining a sketch that is used to estimate f[i].
Specifically, the CountSketch returns an estimate g[i] for which
f[i] - epsilon*F2 <= g[i] <= f[i] + epsilon*F2.
The value epsilon is between 0 and 1 and is the worst-case error for estimating the frequency f[i] using the
value g[i].
        epsilon can be explicitly calculated from the sketch width self.w (the number of levels self.d
        controls the failure probability) -- see ``get_epsilon``.

        We return the set of heavy indices i by using the sketch to find the (relaxed) b-heavy hitters.
The CountSketch guarantee ensures f[i] - epsilon*F2 <= g[i] so we will opt to find all i such that:
g[i] >= (b - epsilon)*F2 or equivalently g[i] >= ( (a - t) - epsilon)*F2.

        We *set* phi ( = b = a - t ) directly rather than working with a and t separately, and epsilon is
        explicitly known.
Note that there are two sources of approximation:
- the ``t'' is the heavy hitter approximate relaxation parameter
- the ``epsilon'' is the frequency estimation parameter.
"""
seed(rng_seed)
self.p = 2 ** 31 - 1
self.w = num_buckets
self.d = num_levels
self.table = np.zeros((self.d, self.w), dtype=float)
self._init_hashes()
self.phi = phi
        assert 0.0 < self.phi <= 1.0, f"phi={self.phi:.5f} must lie in (0, 1]."
        self.total_weight = 0.0  # Sum of all weights seen; for nonnegative updates this is the L1 norm of the frequency vector.
        self.max_num_items = int(np.ceil(1. / self.phi))
self.heavy_hitter_detector = StreamingHeap(self.max_num_items)
self.merged_sketches = [] # A list of all sketches that have been merged with self.
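    # A worked parameter example (illustrative numbers, not defaults): constructing
    # CountSketch(num_buckets=1024, num_levels=5, phi=0.05) allocates a 5 x 1024
    # float table and caps the heavy-hitter heap at ceil(1 / 0.05) = 20 items.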

def _init_hashes(self):
"""
Initialises the hash functions for bucket and sign selection by generating (a,b) pairs for the hash family.
A new (a,b) pair is necessary for each of the hash functions
"""
self.bucket_hash_params = {i: self._get_ab_hash() for i in range(self.d)}
self.sign_hash_params = {i: self._get_ab_hash() for i in range(self.d)}

def _get_ab_hash(self):
"""
        We need 0 <= a <= self.p - 1 and 0 <= b <= self.p - 1 as discussed on page 18
        of https://people.cs.umass.edu/~mcgregor/711S12/sketches1.pdf
        (a = 0 is permitted here: for fixed x != y the map (a, b) -> (ax + b, ay + b) mod p
        is a bijection, so the family is exactly 2-wise independent.)
:return: the (a,b) pair used to define the hash family
"""
return randint(0, self.p - 1), randint(0, self.p - 1) # random (a,b) pairs from ([p], [p])

def _bucket_hash(self, x: int, a: int, b: int, buckets: int):
"""
        Generic function for evaluating a 2-wise independent hash function.
        We use this function for the bucket locations with buckets <- self.w.
        It is also used to generate the random signs with buckets <- 2 -- see ``_sign_hash``.

        :param x: stream item
        :param a: multiplier of the (a, b) hash pair
        :param b: additive offset of the (a, b) hash pair
        :param buckets: number of buckets to hash into
        :return: h:int the hash value (aka bucket index) for item x observed in the stream.
"""
h = (a * x + b) % self.p
h = h % buckets
return h

def _sign_hash(self, x: int, a: int, b: int):
"""
Returns the 2-wise independent sign hash.
This generates the signs when items are put into buckets.
:param x:stream item
:param a:
:param b:
:return:s -- int the sign +1 or -1 used for the hashing.
"""
s = 2.*self._bucket_hash(x, a, b, 2) - 1.
return s
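    # Worked example of the sign mapping: _bucket_hash(x, a, b, 2) returns h in {0, 1},
    # and 2.*h - 1. maps this to -1.0 or +1.0, so each item contributes to its bucket
    # with a pairwise-independent random sign.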

def _insert(self, item: int, weight=1.0):
"""
Inserts the item into the sketch table
        :param item: int -- the stream item to insert
        :param weight: float -- the (possibly negative) weight of the update
"""
        if not isinstance(item, (int, np.integer)):
            # accepts ``int`` and any ``np.int*`` type for the input item
            raise TypeError("Input item must be an int.")
for ii in range(self.d):
a_bucket, b_bucket = self.bucket_hash_params[ii] # Gets the (a,b) pair used for *bucket* hashes at level ii
a_sign, b_sign = self.sign_hash_params[ii] # Gets the (a,b) pair used for *sign* hashes at level ii
bucket = self._bucket_hash(item, a_bucket, b_bucket, self.w)
sign = self._sign_hash(item, a_sign, b_sign)
self.table[ii, bucket] += sign*weight

def update(self, item: int, weight=1.0):
"""
Updates the sketch by:
1. Inserting (item, weight) into self.table
2. Estimating the current frequency and pushing it into the heavy_hitter_detector
(nb. item only pushed into the detector if it has large enough frequency.
This logic is dealt with in the detector class itself.)
:param item:int
:param weight:float
"""
self.total_weight += weight
self._insert(item, weight)
self.heavy_hitter_detector.push(item, self.get_estimate(item))

def get_epsilon(self):
"""
:return: an estimate of the worst-case epsilon error achieved by the sketch for frequency estimation.
"""
return np.sqrt(np.e / self.w)
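    # e.g. with self.w = 1024 this gives epsilon = sqrt(e / 1024) ~= 0.0515, so each
    # estimate g[i] satisfies |g[i] - f[i]| <= 0.0515 * F2 with constant probability
    # per level; the median over the self.d levels amplifies this guarantee.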

def get_estimate(self, item: int):
"""
:param item:int -- the identifier of the item whose frequency is being queried.
:return:np.median(_estimates) -- The count sketch frequency estimate.
"""
_estimates = self._get_frequency_estimate(item)
for sk in self.merged_sketches:
_estimates += sk._get_frequency_estimate(item)
        return np.median(_estimates)  # TODO - replace this with max(0.0, np.median(_estimates))?

def _get_frequency_estimate(self, item: int):
"""
        :param item: int -- the identifier of the item whose frequency is being queried.
        :return: _estimates -- an array of frequency estimates, one entry per level of the sketch.
        """
        _buckets = {lvl: self._bucket_hash(item, *self.bucket_hash_params[lvl], self.w)
                    for lvl in range(self.d)}
        _signs = {lvl: self._sign_hash(item, *self.sign_hash_params[lvl])
                  for lvl in range(self.d)}
        _estimates = np.array([self.table[lvl, _buckets[lvl]] * _signs[lvl] for lvl in range(self.d)])
        return _estimates
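    # Each level supplies the single-bucket estimate table[level, bucket(item)] * sign(item),
    # an unbiased estimate of f[item]; get_estimate above takes the median across the
    # self.d levels, which drives the failure probability down exponentially in self.d.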

def get_frequent_items(self):
"""
:return: a list of the heavy hitters and their frequencies.
"""
        # Refresh each stored count with the latest sketch estimate before reporting.
        for k in self.heavy_hitter_detector.heap.keys():
            self.heavy_hitter_detector.heap[k] = self.get_estimate(k)
return list(self.heavy_hitter_detector.heap.items())

def get_total_weight(self):
"""
Returns the total weight inserted on the stream.
        This is the net mass inserted over the stream; for nonnegative updates it equals
        the ell_1 norm of the underlying frequency vector.
"""
return self.total_weight

def is_empty(self):
"""
        Returns True if the sketch self.table is empty and False otherwise.
        """
        # np.any(self.table) is True iff some entry of self.table is nonzero.
        return not np.any(self.table)

Contributor comment (on is_empty):
There's a valid philosophical debate over whether adding and then subtracting an item means a sketch is truly empty. It might seem trivial with a single item, but imagine merging 2 sketches that were created with the same items but opposite weights -- if we treat it as empty, we lose the information that it has ingested data.

Certainly open to the idea that we don't care about a distinction, but it's worth thinking about what we want to capture by "empty".

def merge(self, other_count_sketch):
"""
:param other_count_sketch: another count_sketch object to be merged into self.

        The count sketch is a linear sketch, so we can simply add other_count_sketch.table
        into self.table.
        Then we need to adjust the dictionaries of heavy hitters.
        """
        try:
            self.table += other_count_sketch.table
            self.merged_sketches.append(other_count_sketch)
            self.total_weight += other_count_sketch.total_weight
            # Call a naive merge on StreamingHeap (could be improved by using min_heaps rather than just dicts.)
            self.heavy_hitter_detector.merge(other_count_sketch.heavy_hitter_detector)
        except AttributeError:
            raise AttributeError("Argument must be a CountSketch object")
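    # Linearity in action (illustrative): if sk1 and sk2 are built with the same
    # rng_seed (hence identical hash functions), then table_1 + table_2 equals the
    # table a single sketch would have built over the concatenation of the two
    # streams -- this is what makes the merge step above sound.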

def __str__(self):
return (
'### Count sketch summary:\n'
f' num. buckets : {self.w}\n'
f' num. levels : {self.d}\n'
f' worst-case error : {(self.get_epsilon()):.5f}\n'
f' Heavy hitter threshold : {self.phi:.5f}\n'
'### End sketch summary')
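

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only (the stream below is made up and this
    # block is not part of the sketch API). Assumes streaming_heap.py is importable.
    np.random.seed(42)
    cs = CountSketch(num_buckets=1024, num_levels=5, phi=0.05)
    # A skewed stream: item 7 dominates, the rest is uniform noise over [0, 100).
    stream = [7] * 500 + [int(x) for x in np.random.randint(0, 100, size=500)]
    for x in stream:
        cs.update(x)
    print(cs)
    print("total weight :", cs.get_total_weight())    # 1000.0
    print("estimate f[7]:", cs.get_estimate(7))       # expect roughly 500 (plus noise)
    print("heavy hitters:", cs.get_frequent_items())  # item 7 should appear near the top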
117 changes: 117 additions & 0 deletions python/src/streaming_heap.py
@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import numpy as np


class StreamingHeap:
def __init__(self, max_heap_size):
"""
This is a basic implementation of a heap used for maintaining heavy hitters
with the CountSketch.

We keep a dictionary of items in a (key, value) pair as (item, estimate_count)
in the dictionary.

        Note that this *** IS NOT *** strictly a heap in the current implementation.
        At least one optimisation would be to store the items in a true heap data structure,
        which would avoid the linear scan for the minimum in _set_heap_minimum.
"""
        self.max_heap_size = max_heap_size
        self.heap = {}
        self.minimum_count = -np.inf  # Baseline for comparisons that is *always* less than any inserted value.
        self.minimum_count_id = None  # Key of the current minimum; set once the heap is populated.

    def _set_heap_minimum(self):
        """
        Caches the minimum value in the heap (and its key) for quicker threshold checking over the stream.
        Ties are broken arbitrarily in favour of the first minimal item.
        """
        if not self.heap:
            return
        self.minimum_count_id = min(self.heap, key=self.heap.get)
        self.minimum_count = self.heap[self.minimum_count_id]

    def _is_space_in_heap(self):
        """
        Returns True if the heap is not full (i.e. has fewer keys than ``self.max_heap_size``) and False otherwise.
        """
        return len(self.heap) < self.max_heap_size

    def _item_enter_sketch(self, estimate_count):
        """
        Returns True if an item with value estimate_count is larger than the current minimum,
        and hence should be permitted into the heap, and False otherwise.
        """
        return estimate_count > self.minimum_count

    def push(self, item, estimate_count):
        """
        Inserts key ``item`` into the heap with value ``estimate_count``.
        """
        if item in self.heap or self._is_space_in_heap():
            # If the item is already tracked then simply adjust its counter.
            # Alternatively, if there is room in the heap, just add a new counter.
            # Either way, refresh the cached minimum afterwards.
            self.heap[item] = estimate_count
            self._set_heap_minimum()
        elif not self._item_enter_sketch(estimate_count):
            # Do nothing: the count is too small for the item to enter the heap.
            return
        else:
            # The heap is full and the item is not tracked, so we remove the (key, value)
            # pair with the minimum count [i.e. (self.minimum_count_id, self.minimum_count)]
            # and insert (item, estimate_count) in its place.
            # Then reset the cached minimum count and key for the heap.
            self.heap.pop(self.minimum_count_id)
            self.heap[item] = estimate_count
            self._set_heap_minimum()
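    # Worked trace (max_heap_size=2, illustrative): push('a', 5) and push('b', 2) fill
    # the heap; push('c', 7) evicts the minimum ('b', 2); a later push('e', 1) is
    # rejected because 1 is below the current minimum count of 5.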

    def merge(self, other_streaming_heap):
        """
        The merge step on heaps is done naively and could be improved by using min_heaps rather than just dicts.
        :param other_streaming_heap: another StreamingHeap to be merged into self.
        """
        try:
            all_k_v = self.heap.copy()
            all_k_v.update(other_streaming_heap.heap)
        except AttributeError:
            raise TypeError("Argument must be a StreamingHeap object")
        # Keep only the (at most) max_heap_size keys with the largest counts.
        sorted_keys = sorted(all_k_v, key=all_k_v.get, reverse=True)
        self.heap = {k: all_k_v[k] for k in sorted_keys[:self.max_heap_size]}
        self._set_heap_minimum()

def keys(self):
"""
        Exposes the keys() of the underlying dictionary.
"""
return self.heap.keys()

    def __str__(self):
        """
        Returns the string representation of the underlying dictionary,
        so that printing a StreamingHeap prints its dict.
        """
        return str(self.heap)
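

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: track the 3 largest counts on a stream.
    heap = StreamingHeap(max_heap_size=3)
    for item, count in [("a", 5.0), ("b", 2.0), ("c", 7.0), ("d", 9.0), ("b", 3.0)]:
        heap.push(item, count)
    # ("b", 2.0) is evicted when "d" arrives on a full heap; the later push of
    # ("b", 3.0) is rejected because 3.0 is below the current minimum count of 5.0.
    print(heap)  # -> {'a': 5.0, 'c': 7.0, 'd': 9.0}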