Closes #2462: Categorical hashing #2487

Merged
1 change: 1 addition & 0 deletions ServerModules.cfg
@@ -11,6 +11,7 @@ RandMsg
IndexingMsg
UniqueMsg
In1dMsg
HashMsg
HistogramMsg
SequenceMsg
SortMsg
24 changes: 24 additions & 0 deletions arkouda/categorical.py
@@ -659,6 +659,30 @@ def unique(self) -> Categorical:
arange(self._categories_used.size), self._categories_used, NAvalue=self.NAvalue
)

def hash(self) -> Tuple[pdarray, pdarray]:
"""
Compute a 128-bit hash of each element of the Categorical.

Returns
-------
Tuple[pdarray,pdarray]
A tuple of two int64 pdarrays. The ith hash value is the concatenation
of the ith values from each array.

Notes
-----
The implementation uses SipHash128, a fast and balanced hash function (used
by Python for dictionaries and sets). For realistic numbers of strings (up
to about 10**15), the probability of a collision between two 128-bit hash
values is negligible.
"""
rep_msg = generic_msg(
cmd="categoricalHash",
args={"objType": self.objType, "categories": self.categories, "codes": self.codes},
)
hashes = json.loads(rep_msg)
return create_pdarray(hashes["upperHash"]), create_pdarray(hashes["lowerHash"])

def group(self) -> pdarray:
"""
Return the permutation that groups the array, placing equivalent
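
For context, a minimal client-side sketch of how the new method could be called; the data and variable names below are illustrative and assume a running arkouda server, they are not part of the PR:

import arkouda as ak

ak.connect()

cat = ak.Categorical(ak.array(["dog", "cat", "dog", "fish"]))
upper, lower = cat.hash()
# Equal values share a 128-bit hash: rows 0 and 2 ("dog") receive the same
# (upper[i], lower[i]) pair, while "cat" and "fish" receive distinct pairs.
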
50 changes: 30 additions & 20 deletions arkouda/numeric.py
@@ -1,15 +1,12 @@
import json
from enum import Enum
from typing import TYPE_CHECKING, ForwardRef, List, Optional, Tuple, Union
from typing import ForwardRef, List, Optional, Tuple, Union
from typing import cast as type_cast
from typing import no_type_check

import numpy as np # type: ignore
from typeguard import typechecked

if TYPE_CHECKING:
from arkouda.segarray import SegArray

from arkouda.client import generic_msg
from arkouda.dtypes import (
BigInt,
@@ -28,6 +25,7 @@
from arkouda.strings import Strings

Categorical = ForwardRef("Categorical")
SegArray = ForwardRef("SegArray")

__all__ = [
"cast",
@@ -46,8 +44,6 @@
"ErrorMode",
]

hashable = Union[pdarray, Strings, "SegArray"]


class ErrorMode(Enum):
strict = "strict"
@@ -428,26 +424,37 @@ def cos(pda: pdarray) -> pdarray:
return create_pdarray(repMsg)


def _hash_helper(a: hashable):
from arkouda import SegArray as Segarray_
def _hash_helper(a):
from arkouda import Categorical as Categorical_
from arkouda import SegArray as SegArray_

if isinstance(a, Segarray_):
if isinstance(a, SegArray_):
return json.dumps(
{"segments": a.segments.name, "values": a.values.name, "valObjType": a.values.objType}
)
elif isinstance(a, Categorical_):
return json.dumps({"categories": a.categories.name, "codes": a.codes.name})
else:
return a.name


# this signature is `# type: ignore`d and does not actually do any type checking;
# the type hints are only a reference for which types are expected,
# and type validation is done within the function
def hash(
pda: Union[hashable, List[hashable]], full: bool = True
pda: Union[ # type: ignore
Union[pdarray, Strings, SegArray, Categorical],
List[Union[pdarray, Strings, SegArray, Categorical]],
],
full: bool = True,
) -> Union[Tuple[pdarray, pdarray], pdarray]:
"""
Return an element-wise hash of the array or list of arrays.

Parameters
----------
pda : Union[pdarray, Strings, Segarray], List[Union[pdarray, Strings, Segarray]]]
    pda : Union[pdarray, Strings, SegArray, Categorical],
        List[Union[pdarray, Strings, SegArray, Categorical]]

full : bool
This is only used when a single pdarray is passed into hash
@@ -484,25 +491,28 @@ def hash(
fixed key for the hash, which makes it possible for an
adversary with control over input to engineer collisions.

In the case of a list of pdrrays, Strings, or Segarrays
    In the case of a list of pdarrays, Strings, Categoricals, or SegArrays
being passed, a non-linear function must be applied to each
array since hashes of subsequent arrays cannot be simply XORed
because equivalent values will cancel each other out, hence we
do a rotation by the ordinal of the array.
"""
from arkouda import SegArray as Segarray_
from arkouda import Categorical as Categorical_
from arkouda import SegArray as SegArray_

if isinstance(pda, (pdarray, Strings, Segarray_)):
if isinstance(pda, (pdarray, Strings, SegArray_, Categorical_)):
return _hash_single(pda, full) if isinstance(pda, pdarray) else pda.hash()
elif isinstance(pda, List):
if any(wrong_type := [not isinstance(a, (pdarray, Strings, Segarray_)) for a in pda]):
if any(
wrong_type := [not isinstance(a, (pdarray, Strings, SegArray_, Categorical_)) for a in pda]
):
raise TypeError(
f"Unsupported type {type(pda[np.argmin(wrong_type)])}. Supported types are pdarray,"
f" SegArray, Strings, and Lists of these types."
f" SegArray, Strings, Categoricals, and Lists of these types."
)
types_list = [a.objType for a in pda]
names_list = [_hash_helper(a) for a in pda]
repMsg = type_cast(
rep_msg = type_cast(
str,
generic_msg(
cmd="hashList",
@@ -514,12 +524,12 @@
},
),
)
a, b = repMsg.split("+")
return create_pdarray(a), create_pdarray(b)
hashes = json.loads(rep_msg)
return create_pdarray(hashes["upperHash"]), create_pdarray(hashes["lowerHash"])
else:
raise TypeError(
f"Unsupported type {type(pda)}. Supported types are pdarray,"
f" SegArray, Strings, and Lists of these types."
f" SegArray, Strings, Categoricals, and Lists of these types."
)


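
A hedged usage sketch of the extended hash function; the variable names and data are illustrative, and a running server is assumed:

import arkouda as ak

ak.connect()

ids = ak.array([1, 2, 1, 3])
labels = ak.Categorical(ak.array(["x", "y", "x", "z"]))

# A single Categorical dispatches to Categorical.hash().
upper, lower = ak.hash(labels)

# For a list, each array's 128-bit hash is rotated by its position before
# being XORed in, so equal arrays in different positions do not cancel and
# [ids, labels] hashes differently than [labels, ids].
upper2, lower2 = ak.hash([ids, labels])
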
14 changes: 11 additions & 3 deletions src/AryUtil.chpl
@@ -181,9 +181,17 @@ module AryUtil
hasStr = true;
}
when ObjType.CATEGORICAL {
// passed only Categorical.codes.name to be sorted on
var g = getGenericTypedArrayEntry(name, st);
thisSize = g.size;
if st.contains(name) {
// passed only Categorical.codes.name to be sorted on
var g = getGenericTypedArrayEntry(name, st);
thisSize = g.size;
}
else {
var catComps = jsonToMap(name);
var codesName = catComps["codes"];
var codes = getGenericTypedArrayEntry(codesName, st);
thisSize = codes.size;
}
}
when ObjType.SEGARRAY {
var segComps = jsonToMap(name);
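
This branch now accepts a Categorical either as a bare symbol-table name for the codes array (the pre-existing sort path) or as a JSON map naming both components, as produced by _hash_helper on the client. A small Python sketch of the second form, using hypothetical symbol names:

import json

# Hypothetical symbol-table names; the real ones are generated by the server.
payload = json.dumps({"categories": "id_6", "codes": "id_7"})
# On the server, jsonToMap(name) recovers this map and the "codes" entry is
# looked up to determine the array length.
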
29 changes: 0 additions & 29 deletions src/EfuncMsg.chpl
@@ -745,34 +745,6 @@ module EfuncMsg
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc hashArraysMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
var n = msgArgs.get("length").getIntValue();
var s = msgArgs.get("size").getIntValue();
var namesList = msgArgs.get("nameslist").getList(n);
var typesList = msgArgs.get("typeslist").getList(n);
var (size, hasStr, names, types) = validateArraysSameLength(n, namesList, typesList, st);

// Call hashArrays on list of given array names
var hashes = hashArrays(size, names, types, st);
var upper = makeDistArray(s, uint);
var lower = makeDistArray(s, uint);

// Assign upper and lower bit values to their respective entries
forall (up, low, h) in zip(upper, lower, hashes) {
(up, low) = h;
}

var upperName = st.nextName();
st.addEntry(upperName, new shared SymEntry(upper));
var lowerName = st.nextName();
st.addEntry(lowerName, new shared SymEntry(lower));

var repMsg = "created %s+created %s".format(st.attrib(upperName), st.attrib(lowerName));
eLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}


/* The 'where' function takes a boolean array and two other arguments A and B, and
returns an array with A where the boolean is true and B where it is false. A and B
can be vectors or scalars.
@@ -872,5 +844,4 @@
registerFunction("efunc3vs", efunc3vsMsg, getModuleName());
registerFunction("efunc3sv", efunc3svMsg, getModuleName());
registerFunction("efunc3ss", efunc3ssMsg, getModuleName());
registerFunction("hashList", hashArraysMsg, getModuleName());
}
99 changes: 99 additions & 0 deletions src/HashMsg.chpl
@@ -0,0 +1,99 @@
module HashMsg {
use Reflection;
use ServerErrors;
use ServerErrorStrings;
use Logging;
use Message;

use MultiTypeSymbolTable;
use MultiTypeSymEntry;
use CommAggregation;
use ServerConfig;
use SegmentedString;
use AryUtil;
use UniqueMsg;
use ArkoudaMapCompat;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
const hmLogger = new Logger(logLevel, logChannel);

proc categoricalHash(categoriesName: string, codesName: string, st: borrowed SymTab) throws {
var categories = getSegString(categoriesName, st);
var codes = toSymEntry(getGenericTypedArrayEntry(codesName, st), int);
// hash categories first
var hashes = categories.siphash();
// then do expansion indexing at codes
ref ca = codes.a;
var expandedHashes: [ca.domain] (uint, uint);
forall (eh, c) in zip(expandedHashes, ca) with (var agg = newSrcAggregator((uint, uint))) {
agg.copy(eh, hashes[c]);
}
var hash1 = makeDistArray(ca.size, uint);
var hash2 = makeDistArray(ca.size, uint);
forall (h, h1, h2) in zip(expandedHashes, hash1, hash2) {
(h1,h2) = h:(uint,uint);
}
return (hash1, hash2);
}

proc categoricalHashMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
var pn = Reflection.getRoutineName();
var repMsg: string;
const objtype = msgArgs.getValueOf("objType").toUpper(): ObjType;
if objtype != ObjType.CATEGORICAL {
var errorMsg = notImplementedError(pn, objtype: string);
hmLogger.error(getModuleName(),getRoutineName(),getLineNumber(),errorMsg);
return new MsgTuple(errorMsg, MsgType.ERROR);
}
const categoriesName = msgArgs.getValueOf("categories");
const codesName = msgArgs.getValueOf("codes");
st.checkTable(categoriesName);
st.checkTable(codesName);
var (upper, lower) = categoricalHash(categoriesName, codesName, st);
var upperName = st.nextName();
st.addEntry(upperName, new shared SymEntry(upper));
var lowerName = st.nextName();
st.addEntry(lowerName, new shared SymEntry(lower));
var createdMap = new map(keyType=string,valType=string);
createdMap.add("upperHash", "created %s".format(st.attrib(upperName)));
createdMap.add("lowerHash", "created %s".format(st.attrib(lowerName)));
repMsg = "%jt".format(createdMap);
hmLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

proc hashArraysMsg(cmd: string, msgArgs: borrowed MessageArgs, st: borrowed SymTab): MsgTuple throws {
var n = msgArgs.get("length").getIntValue();
var s = msgArgs.get("size").getIntValue();
var namesList = msgArgs.get("nameslist").getList(n);
var typesList = msgArgs.get("typeslist").getList(n);
var (size, hasStr, names, types) = validateArraysSameLength(n, namesList, typesList, st);

// Call hashArrays on list of given array names
var hashes = hashArrays(size, names, types, st);
var upper = makeDistArray(s, uint);
var lower = makeDistArray(s, uint);

// Assign upper and lower bit values to their respective entries
forall (up, low, h) in zip(upper, lower, hashes) {
(up, low) = h;
}

var upperName = st.nextName();
st.addEntry(upperName, new shared SymEntry(upper));
var lowerName = st.nextName();
st.addEntry(lowerName, new shared SymEntry(lower));

var createdMap = new map(keyType=string,valType=string);
createdMap.add("upperHash", "created %s".format(st.attrib(upperName)));
createdMap.add("lowerHash", "created %s".format(st.attrib(lowerName)));
var repMsg = "%jt".format(createdMap);
hmLogger.debug(getModuleName(),getRoutineName(),getLineNumber(),repMsg);
return new MsgTuple(repMsg, MsgType.NORMAL);
}

use CommandMap;
registerFunction("hashList", hashArraysMsg, getModuleName());
registerFunction("categoricalHash", categoricalHashMsg, getModuleName());
}
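
Both handlers now reply with a JSON map keyed by "upperHash" and "lowerHash" instead of the old "+"-delimited string. A hedged sketch of how the client consumes it, mirroring the parsing added in categorical.py and numeric.py above (the attribute strings are abbreviated placeholders, not actual server output):

import json

rep_msg = '{"upperHash": "created id_10 ...", "lowerHash": "created id_11 ..."}'
hashes = json.loads(rep_msg)
upper_attr = hashes["upperHash"]   # each attribute string is handed to create_pdarray
lower_attr = hashes["lowerHash"]
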
10 changes: 9 additions & 1 deletion src/UniqueMsg.chpl
@@ -31,6 +31,7 @@ module UniqueMsg
use SipHash;
use CommAggregation;
use SegmentedArray;
use HashMsg;

private config const logLevel = ServerConfig.logLevel;
private config const logChannel = ServerConfig.logChannel;
@@ -250,7 +251,7 @@
}
for (name, objtype, i) in zip(names, types, 0..) {
select objtype.toUpper(): ObjType {
when ObjType.PDARRAY, ObjType.CATEGORICAL {
when ObjType.PDARRAY {
var g = getGenericTypedArrayEntry(name, st);
select g.dtype {
when DType.Int64 {
@@ -295,6 +296,13 @@
h ^= rotl((u,l), i);
}
}
when ObjType.CATEGORICAL {
var catComps = jsonToMap(name);
var (upper, lower) = categoricalHash(catComps["categories"], catComps["codes"], st);
forall (h, u, l) in zip(hashes, upper, lower) {
h ^= rotl((u,l), i);
}
}
}
}
return hashes;
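
Routing CATEGORICAL through categoricalHash here lets Categoricals take part in hash-based grouping alongside other array types. A hedged sketch of what this enables on the client, with illustrative data and assuming the usual GroupBy API:

import arkouda as ak

ak.connect()

ids = ak.array([1, 2, 1, 2])
labels = ak.Categorical(ak.array(["a", "a", "a", "b"]))

# Grouping on a mixed list of keys hashes each key array under the hood;
# the Categorical key now contributes its 128-bit hash like any other array.
g = ak.GroupBy([ids, labels])
keys, counts = g.count()
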