From 6fabe09de40c9ae8666bc93887efabc4bfee6d25 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Wed, 12 Oct 2016 10:49:36 +0100 Subject: [PATCH 01/19] Factor out weighted Chi Squared calculation for cleaner code --- CHAID/tree.py | 1 - 1 file changed, 1 deletion(-) diff --git a/CHAID/tree.py b/CHAID/tree.py index f876b63..5172a90 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -207,7 +207,6 @@ def generate_best_split(self, ind, dep, wt=None): sufficient_split = highest_p_join < self.alpha_merge and all( sum(node_v.values()) >= self.min_child_node_size for node_v in freq.values() ) - if sufficient_split and len(freq.values()) > 1: n_ij = np.array([ [f[dep_val] for dep_val in all_dep] for f in freq.values() From 9450af3b41a5a100e93d716c08df9cd081f6e3d3 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Wed, 12 Oct 2016 13:35:00 +0100 Subject: [PATCH 02/19] =?UTF-8?q?Ordinal=20Variables=3F!=3F!=3F!=3F=20?= =?UTF-8?q?=F0=9F=98=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHAID/column.py | 36 ++++++++++++++++++++++++++++++++++++ CHAID/tree.py | 15 ++++++++++++--- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 8ec720b..1672250 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -135,3 +135,39 @@ def group(self, x, y): self._groupings[x] += self._groupings[y] del self._groupings[y] self._arr[self._arr == y] = x + +class OrdinalColumn(Column): + def __init__(self, arr=None, metadata=None, + missing_id='', substitute=True): + super(self.__class__, self).__init__(arr, metadata, missing_id) + + for x in np.unique(self._arr): + self._groupings[x] = [x] + + def deep_copy(self): + """ + Returns a deep copy. + """ + return OrdinalColumn(self._arr, metadata=self.metadata, + missing_id=self._missing_id, substitute=False) + + def __getitem__(self, key): + return OrdinalColumn(self._arr[key], metadata=self.metadata, substitute=False) + + def __setitem__(self, key, value): + self._arr[key] = value + return OrdinalColumn(np.array(self._arr), metadata=self.metadata, substitute=False) + + def groups(self): + return list(self._groupings.values()) + + def possible_groupings(self): + range_labels = sorted(list(self._groupings.keys())) + canditates = zip(range_labels[0:], range_labels[1:]) + adjacent = lambda x, y: (max(self._groupings[x]) + 1) == min(self._groupings[y]) + return enumerate((x, y) for x,y in canditates if adjacent(x, y)) + + def group(self, x, y): + self._groupings[x] += self._groupings[y] + del self._groupings[y] + self._arr[self._arr == y] = x diff --git a/CHAID/tree.py b/CHAID/tree.py index 5172a90..87d4b9d 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -61,15 +61,24 @@ class Tree(object): array of names for the independent variables in the data """ def __init__(self, ndarr, arr, alpha_merge=0.05, max_depth=2, min_parent_node_size=30, - min_child_node_size=0, split_titles=None, split_threshold=0, weights=None): + min_child_node_size=0, split_titles=None, split_threshold=0, weights=None, + variable_types=None): self.alpha_merge = alpha_merge self.max_depth = max_depth self.min_parent_node_size = min_parent_node_size self.min_child_node_size = min_child_node_size self.split_titles = split_titles or [] self.vectorised_array = [] - for ind in range(0, ndarr.shape[1]): - self.vectorised_array.append(NominalColumn(ndarr[:, ind])) + variable_types = variable_types or ['nominal'] * ndarr.shape[1] + for ind, col_type in enumerate(variable_types): + if col_type == 'ordinal': + col = OrdinalColumn(ndarr[:, ind]) + elif col_type == 'nominal': + col = NominalColumn(ndarr[:, ind]) + else: + raise NotImplementedError('Unknown type ' + col_type) + self.vectorised_array.append(col) + self.data_size = ndarr.shape[0] self.node_count = 0 self.tree_store = None From 451f55acc9c4ad75346835c0e163d3da02078367 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Wed, 12 Oct 2016 14:26:23 +0100 Subject: [PATCH 03/19] Rework command line to accept ordinal variables --- CHAID/__init__.py | 2 +- CHAID/__main__.py | 41 +++++++++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/CHAID/__init__.py b/CHAID/__init__.py index 49027d3..2d076a0 100644 --- a/CHAID/__init__.py +++ b/CHAID/__init__.py @@ -1,6 +1,6 @@ from .split import Split from .tree import Tree from .node import Node -from .column import NominalColumn +from .column import NominalColumn, OrdinalColumn __version__ = "2.1.0" diff --git a/CHAID/__main__.py b/CHAID/__main__.py index 2bf8ac4..f56f960 100644 --- a/CHAID/__main__.py +++ b/CHAID/__main__.py @@ -6,6 +6,7 @@ import savReaderWriter as spss from .tree import Tree import pandas as pd +import numpy as np def main(): """Entry point when module is run from command line""" @@ -14,7 +15,14 @@ def main(): ' csv/sav file.') parser.add_argument('file') parser.add_argument('dependent_variable', nargs=1) - parser.add_argument('independent_variables', nargs='+') + + var = parser.add_argument_group('Independent Variable Specification') + var.add_argument('nominal_variables', nargs='*', help='The names of ' + 'independent variables to use that have no intrinsic ' + 'order to them') + var.add_argument('--ordinal-variables', type=str, nargs='*', + help='The names of independent variables to use that ' + 'have an intrinsic order but a finite amount of states') parser.add_argument('--weights', type=str, help='Name of weight column') parser.add_argument('--max-depth', type=int, help='Max depth of generated ' @@ -29,16 +37,20 @@ def main(): ' input with the node id of the node that that ' 'respondent has been placed into') group.add_argument('--predict', action='store_true', help='Add column to ' - 'input with the value of the dependent varaible that ' + 'input with the value of the dependent variable that ' 'the majority of respondents in that node selected') nspace = parser.parse_args() - data = pd.read_csv(nspace.file) - - # raw_data = spss.SavReader(nspace.file, returnHeader = True, rawMode=True) - # raw_data_list = list(raw_data) - # data = pd.DataFrame(raw_data_list) - # data = data.rename(columns=data.loc[0]).iloc[1:] + if nspace.file[-4:] == '.csv': + data = pd.read_csv(nspace.file) + elif nspace.file[-4:] == '.sav': + raw_data = spss.SavReader(nspace.file, returnHeader=True) + raw_data_list = list(raw_data) + data = pd.DataFrame(raw_data_list) + data = data.rename(columns=data.loc[0]).iloc[1:] + else: + print('Uknown file type') + exit(1) config = {} if nspace.max_depth: @@ -51,8 +63,17 @@ def main(): config['min_child_node_size'] = nspace.min_child_node_size if nspace.weights: config['weight'] = nspace.weights - tree = Tree.from_pandas_df(data, nspace.independent_variables, - nspace.dependent_variable[0], **config) + + ordinal = nspace.ordinal_variables or [] + nominal = nspace.nominal_variables or [] + independent_variables = nominal + ordinal + types = ['nominal'] * len(nominal) + ['ordinal'] * len(ordinal) + if len(independent_variables) == 0: + print('Need to provide at least one independent variable') + exit(1) + tree = Tree.from_pandas_df(data, independent_variables, + nspace.dependent_variable[0], + variable_types=types, **config) if nspace.classify: predictions = pd.Series(tree.node_predictions()) From 092b00a36c7bdf5de01185460cc3e48b15d2afb7 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Wed, 12 Oct 2016 16:58:11 +0100 Subject: [PATCH 04/19] Remove an enumerate --- CHAID/column.py | 4 ++-- CHAID/tree.py | 37 +++++++++++++++++++------------------ 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 1672250..b23b963 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -129,7 +129,7 @@ def groups(self): return list(self._groupings.values()) def possible_groupings(self): - return enumerate(combinations(self._groupings.keys(), 2)) + return combinations(self._groupings.keys(), 2) def group(self, x, y): self._groupings[x] += self._groupings[y] @@ -165,7 +165,7 @@ def possible_groupings(self): range_labels = sorted(list(self._groupings.keys())) canditates = zip(range_labels[0:], range_labels[1:]) adjacent = lambda x, y: (max(self._groupings[x]) + 1) == min(self._groupings[y]) - return enumerate((x, y) for x,y in canditates if adjacent(x, y)) + return ((x, y) for x,y in canditates if adjacent(x, y)) def group(self, x, y): self._groupings[x] += self._groupings[y] diff --git a/CHAID/tree.py b/CHAID/tree.py index 87d4b9d..c1b4845 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -4,7 +4,7 @@ from treelib import Tree as TreeLibTree from .node import Node from .split import Split -from .column import NominalColumn +from .column import NominalColumn, OrdinalColumn def chisquare(n_ij, weighted): @@ -95,7 +95,7 @@ def build_tree(self): @staticmethod def from_pandas_df(df, i_variables, d_variable, alpha_merge=0.05, max_depth=2, min_parent_node_size=30, min_child_node_size=0, split_threshold=0, - weight=None): + weight=None, variable_types=None): """ Helper method to pre-process a pandas data frame in order to run CHAID analysis @@ -119,12 +119,12 @@ def from_pandas_df(df, i_variables, d_variable, alpha_merge=0.05, max_depth=2, contain (default 30) """ ind_df = df[i_variables] - ind_df = ind_df ind_values = ind_df.values dep_values = df[d_variable].values weights = df[weight] if weight is not None else None return Tree(ind_values, dep_values, alpha_merge, max_depth, min_parent_node_size, - min_child_node_size, list(ind_df.columns.values), split_threshold, weights) + min_child_node_size, list(ind_df.columns.values), split_threshold, weights, + variable_types) def node(self, rows, ind, dep, wt=None, depth=0, parent=None, parent_decisions=None): """ internal method to create a node in the tree """ @@ -175,28 +175,30 @@ def generate_best_split(self, ind, dep, wt=None): split = Split(None, None, None, None, 0) relative_split_threshold = 1 - self.split_threshold all_dep = set(dep.arr) - for i, index in enumerate(ind): - index = index.deep_copy() - unique = set(index.arr) + for i, dep_var in enumerate(ind): + dep_var = dep_var.deep_copy() + unique = set(dep_var.arr) freq = {} for col in unique: - counts = np.unique(dep.arr[index.arr == col], return_counts=True) + counts = np.unique(dep.arr[dep_var.arr == col], return_counts=True) if wt is None: freq[col] = cl.defaultdict(int) freq[col].update(np.transpose(counts)) else: freq[col] = cl.defaultdict(int) for dep_v in set(dep.arr): - freq[col][dep_v] = wt[(index.arr == col) * (dep.arr == dep_v)].sum() + freq[col][dep_v] = wt[(dep_var.arr == col) * (dep.arr == dep_v)].sum() - while next(index.possible_groupings(), None) is not None: - groupings = list(index.possible_groupings()) + while next(dep_var.possible_groupings(), None) is not None: + groupings = list(dep_var.possible_groupings()) size = len(groupings) sub_data_columns = [('combinations', object), ('p', float), ('chi', float)] - sub_data = np.array([(None, 0, 1)]*size, dtype=sub_data_columns, order='F') - for j, comb in groupings: + choice = None + highest_p_join = None + split_chi = None + for comb in groupings: col1_freq = freq[comb[0]] col2_freq = freq[comb[1]] @@ -209,9 +211,8 @@ def generate_best_split(self, ind, dep, wt=None): chi, p_split, dof = chisquare(n_ij, wt is not None) - sub_data[j] = (comb, p_split, chi) - - choice, highest_p_join, chi_join = max(sub_data, key=lambda x: (x[1], x[2])) + if choice is None or p_split > highest_p_join or (p_split == highest_p_join and chi > split_chi): + choice, highest_p_join, split_chi = comb, p_split, chi sufficient_split = highest_p_join < self.alpha_merge and all( sum(node_v.values()) >= self.min_child_node_size for node_v in freq.values() @@ -224,7 +225,7 @@ def generate_best_split(self, ind, dep, wt=None): dof = (n_ij.shape[0] - 1) * (n_ij.shape[1] - 1) chi, p_split, dof = chisquare(n_ij, wt is not None) - temp_split = Split(i, index.groups(), chi, p_split, dof) + temp_split = Split(i, dep_var.groups(), chi, p_split, dof) better_split = not split.valid() or p_split < split.p or (p_split == split.p and chi > split.chi) @@ -243,7 +244,7 @@ def generate_best_split(self, ind, dep, wt=None): break - index.group(choice[0], choice[1]) + dep_var.group(choice[0], choice[1]) for val, count in freq[choice[1]].items(): freq[choice[0]][val] += count From 84069f3bc63a082b769f18112d3f7dc0d4057f85 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Wed, 12 Oct 2016 17:56:30 +0100 Subject: [PATCH 05/19] Fix performance issue --- CHAID/tree.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/CHAID/tree.py b/CHAID/tree.py index c1b4845..4f10e5a 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -72,7 +72,7 @@ def __init__(self, ndarr, arr, alpha_merge=0.05, max_depth=2, min_parent_node_si variable_types = variable_types or ['nominal'] * ndarr.shape[1] for ind, col_type in enumerate(variable_types): if col_type == 'ordinal': - col = OrdinalColumn(ndarr[:, ind]) + col = OrdinalColumn(ndarr[:, ind].astype(np.dtype(int))) elif col_type == 'nominal': col = NominalColumn(ndarr[:, ind]) else: @@ -174,20 +174,22 @@ def generate_best_split(self, ind, dep, wt=None): """ internal method to generate the best split """ split = Split(None, None, None, None, 0) relative_split_threshold = 1 - self.split_threshold - all_dep = set(dep.arr) + all_dep = np.unique(dep.arr) for i, dep_var in enumerate(ind): dep_var = dep_var.deep_copy() - unique = set(dep_var.arr) + unique = np.unique(dep_var.arr) freq = {} - for col in unique: - counts = np.unique(dep.arr[dep_var.arr == col], return_counts=True) - if wt is None: + if wt is None: + for col in unique: + counts = np.unique(np.compress(dep_var.arr == col, dep.arr), return_counts=True) freq[col] = cl.defaultdict(int) freq[col].update(np.transpose(counts)) - else: + else: + for col in unique: + counts = np.unique(np.compress(dep_var.arr == col, dep.arr), return_counts=True) freq[col] = cl.defaultdict(int) - for dep_v in set(dep.arr): + for dep_v in all_dep: freq[col][dep_v] = wt[(dep_var.arr == col) * (dep.arr == dep_v)].sum() while next(dep_var.possible_groupings(), None) is not None: From c102b173a10650d60277a1beac13cb0d75e175e5 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Thu, 13 Oct 2016 10:07:07 +0100 Subject: [PATCH 06/19] Fix misguided refactor --- CHAID/tree.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/CHAID/tree.py b/CHAID/tree.py index 4f10e5a..b1c3ce7 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -175,32 +175,29 @@ def generate_best_split(self, ind, dep, wt=None): split = Split(None, None, None, None, 0) relative_split_threshold = 1 - self.split_threshold all_dep = np.unique(dep.arr) - for i, dep_var in enumerate(ind): - dep_var = dep_var.deep_copy() - unique = np.unique(dep_var.arr) + for i, ind_var in enumerate(ind): + ind_var = ind_var.deep_copy() + unique = np.unique(ind_var.arr) freq = {} if wt is None: for col in unique: - counts = np.unique(np.compress(dep_var.arr == col, dep.arr), return_counts=True) + counts = np.unique(np.compress(ind_var.arr == col, dep.arr), return_counts=True) freq[col] = cl.defaultdict(int) freq[col].update(np.transpose(counts)) else: for col in unique: - counts = np.unique(np.compress(dep_var.arr == col, dep.arr), return_counts=True) + counts = np.unique(np.compress(ind_var.arr == col, dep.arr), return_counts=True) freq[col] = cl.defaultdict(int) for dep_v in all_dep: - freq[col][dep_v] = wt[(dep_var.arr == col) * (dep.arr == dep_v)].sum() - - while next(dep_var.possible_groupings(), None) is not None: - groupings = list(dep_var.possible_groupings()) - size = len(groupings) + freq[col][dep_v] = wt[(ind_var.arr == col) * (dep.arr == dep_v)].sum() + while next(ind_var.possible_groupings(), None) is not None: sub_data_columns = [('combinations', object), ('p', float), ('chi', float)] choice = None highest_p_join = None split_chi = None - for comb in groupings: + for comb in ind_var.possible_groupings(): col1_freq = freq[comb[0]] col2_freq = freq[comb[1]] @@ -227,7 +224,7 @@ def generate_best_split(self, ind, dep, wt=None): dof = (n_ij.shape[0] - 1) * (n_ij.shape[1] - 1) chi, p_split, dof = chisquare(n_ij, wt is not None) - temp_split = Split(i, dep_var.groups(), chi, p_split, dof) + temp_split = Split(i, ind_var.groups(), chi, p_split, dof) better_split = not split.valid() or p_split < split.p or (p_split == split.p and chi > split.chi) @@ -246,7 +243,7 @@ def generate_best_split(self, ind, dep, wt=None): break - dep_var.group(choice[0], choice[1]) + ind_var.group(choice[0], choice[1]) for val, count in freq[choice[1]].items(): freq[choice[0]][val] += count From 08371e648d198ed61abb8f1365adcc980e9461cd Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Thu, 13 Oct 2016 11:32:06 +0100 Subject: [PATCH 07/19] Add tests for ordinal column --- CHAID/column.py | 8 +-- ...{test_vector.py => test_nominal_column.py} | 0 tests/test_ordinal_column.py | 53 +++++++++++++++++++ 3 files changed, 57 insertions(+), 4 deletions(-) rename tests/{test_vector.py => test_nominal_column.py} (100%) create mode 100644 tests/test_ordinal_column.py diff --git a/CHAID/column.py b/CHAID/column.py index b23b963..f4ad1d3 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -138,7 +138,7 @@ def group(self, x, y): class OrdinalColumn(Column): def __init__(self, arr=None, metadata=None, - missing_id='', substitute=True): + missing_id=''): super(self.__class__, self).__init__(arr, metadata, missing_id) for x in np.unique(self._arr): @@ -149,14 +149,14 @@ def deep_copy(self): Returns a deep copy. """ return OrdinalColumn(self._arr, metadata=self.metadata, - missing_id=self._missing_id, substitute=False) + missing_id=self._missing_id) def __getitem__(self, key): - return OrdinalColumn(self._arr[key], metadata=self.metadata, substitute=False) + return OrdinalColumn(self._arr[key], metadata=self.metadata) def __setitem__(self, key, value): self._arr[key] = value - return OrdinalColumn(np.array(self._arr), metadata=self.metadata, substitute=False) + return OrdinalColumn(np.array(self._arr), metadata=self.metadata) def groups(self): return list(self._groupings.values()) diff --git a/tests/test_vector.py b/tests/test_nominal_column.py similarity index 100% rename from tests/test_vector.py rename to tests/test_nominal_column.py diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py new file mode 100644 index 0000000..e20a406 --- /dev/null +++ b/tests/test_ordinal_column.py @@ -0,0 +1,53 @@ +""" +Testing module for the class NominalColumn +""" +from unittest import TestCase +import numpy as np +from numpy import nan +from setup_tests import list_ordered_equal,list_unordered_equal, CHAID + +NAN = float('nan') + +class TestOrdinalDeepCopy(TestCase): + """ Test fixture class for deep copy method """ + def setUp(self): + """ Setup for copy tests""" + # Use string so numpy array dtype is object and may store references + arr = np.array([1, 2, 3, 3, 3, 3]) + self.orig = CHAID.OrdinalColumn(arr) + self.copy = self.orig.deep_copy() + + def test_deep_copy_does_copy(self): + """ Ensure a copy actually happens when deep_copy is called """ + assert id(self.orig) != id(self.copy), 'The vector objects must be different' + assert list_ordered_equal(self.copy, self.orig), 'Vector contents must be the same' + + def test_changing_copy(self): + """ Test that altering the copy doesn't alter the original """ + self.copy.arr[0] = 55.0 + assert not list_ordered_equal(self.copy, self.orig), 'Altering one vector should not affect the other' + + def test_metadata(self): + """ Ensure metadata is copied correctly or deep_copy """ + assert self.copy.metadata == self.orig.metadata, 'Copied metadata should be equivilent' + +class TestOrdinalGrouping(TestCase): + """ Test fixture class for deep copy method """ + def setUp(self): + """ Setup for grouping tests """ + arr = np.array([1, 2, 3, 3, 3, 3, 4, 5, 10]) + self.col = CHAID.OrdinalColumn(arr) + + def test_possible_groups(self): + """ Ensure a groupings are only adjacent numbers """ + groupings = list(self.col.possible_groupings()) + expected_groupings = [(1, 2), (2, 3), (3, 4), (4, 5)] + assert list_unordered_equal(expected_groupings, groupings) + + def test_groups_after_grouping(self): + """ Ensure a copy actually happens when deep_copy is called """ + self.col.group(3, 4) + self.col.group(3, 2) + groupings = list(self.col.possible_groupings()) + expected_groupings = [(1, 3), (3, 5)] + assert list_unordered_equal(expected_groupings, groupings) From 4f83f2a5a664ec98cf480ce755244baae77c7990 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Thu, 13 Oct 2016 13:13:38 +0100 Subject: [PATCH 08/19] Add Nan test case --- CHAID/column.py | 28 ++++++++++++++++++++-------- tests/test_ordinal_column.py | 26 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index f4ad1d3..332172a 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -123,7 +123,7 @@ def __getitem__(self, key): def __setitem__(self, key, value): self._arr[key] = value - return NominalColumn(np.array(self._arr), metadata=self.metadata, substitute=False) + return self def groups(self): return list(self._groupings.values()) @@ -142,7 +142,8 @@ def __init__(self, arr=None, metadata=None, super(self.__class__, self).__init__(arr, metadata, missing_id) for x in np.unique(self._arr): - self._groupings[x] = [x] + self._groupings[x] = [x, x] + self._possible_groups = None def deep_copy(self): """ @@ -156,18 +157,29 @@ def __getitem__(self, key): def __setitem__(self, key, value): self._arr[key] = value - return OrdinalColumn(np.array(self._arr), metadata=self.metadata) + return self def groups(self): return list(self._groupings.values()) def possible_groupings(self): - range_labels = sorted(list(self._groupings.keys())) - canditates = zip(range_labels[0:], range_labels[1:]) - adjacent = lambda x, y: (max(self._groupings[x]) + 1) == min(self._groupings[y]) - return ((x, y) for x,y in canditates if adjacent(x, y)) + if self._possible_groups is None: + ranges = sorted(self._groupings.items()) + candidates = zip(ranges[0:], ranges[1:]) + self._possible_groups = [ + (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates + if (minmax1[1] + 1) == minmax2[0] + ] + return self._possible_groups.__iter__() def group(self, x, y): - self._groupings[x] += self._groupings[y] + self._possible_groups = None + x_max = self._groupings[x][1] + y_min = self._groupings[y][0] + if y_min > x_max: + self._groupings[x][1] = self._groupings[y][1] + else: + self._groupings[x][0] = y_min + del self._groupings[y] self._arr[self._arr == y] = x diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index e20a406..8f9744c 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -51,3 +51,29 @@ def test_groups_after_grouping(self): groupings = list(self.col.possible_groupings()) expected_groupings = [(1, 3), (3, 5)] assert list_unordered_equal(expected_groupings, groupings) + +class TestOrdinalGroupingWithNAN(TestCase): + """ Test fixture class for deep copy method """ + def setUp(self): + """ Setup for grouping tests """ + arr = np.array([1, 2, NAN, 3, 3, NAN, 3, 3, NAN, 4, 5, 10]) + self.col = CHAID.OrdinalColumn(arr) + + def test_possible_groups(self): + """ Ensure a groupings are only adjacent numbers """ + groupings = list(self.col.possible_groupings()) + expected_groupings = [ + (1, 2), (2, 3), (3, 4), (4, 5), (1, NAN), (2, NAN), (3, NAN), + (4, NAN), (5, NAN), (10, NAN) + ] + assert list_unordered_equal(expected_groupings, groupings) + + def test_groups_after_grouping(self): + """ Ensure a copy actually happens when deep_copy is called """ + self.col.group(3, 4) + self.col.group(3, 2) + groupings = list(self.col.possible_groupings()) + expected_groupings = [ + (1, 3), (3, 5), (1, NAN), (3, NAN), (5, NAN), (10, NAN) + ] + assert list_unordered_equal(expected_groupings, groupings) From fdff416e067f0ae2c040729dd0dfd77bccec06db Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Thu, 13 Oct 2016 17:46:18 +0100 Subject: [PATCH 09/19] Allow variable types to be passed in as a Dict. --- CHAID/tree.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHAID/tree.py b/CHAID/tree.py index b1c3ce7..0f49ddb 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -122,6 +122,8 @@ def from_pandas_df(df, i_variables, d_variable, alpha_merge=0.05, max_depth=2, ind_values = ind_df.values dep_values = df[d_variable].values weights = df[weight] if weight is not None else None + if isinstance(variable_types, dict): + variable_types = [variable_types[col] for col in i_variables] return Tree(ind_values, dep_values, alpha_merge, max_depth, min_parent_node_size, min_child_node_size, list(ind_df.columns.values), split_threshold, weights, variable_types) From 0f807a56a8ca0b35765510301feb5f8465711fa1 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Mon, 17 Oct 2016 10:56:22 +0100 Subject: [PATCH 10/19] Fix incorrect groups reported bug and add specs for said bug --- CHAID/column.py | 26 ++++++++++++----- CHAID/tree.py | 2 +- tests/setup_tests.py | 6 ++-- tests/test_ordinal_column.py | 56 ++++++++++++++++++++++++------------ 4 files changed, 60 insertions(+), 30 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 332172a..a48d71a 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -138,11 +138,17 @@ def group(self, x, y): class OrdinalColumn(Column): def __init__(self, arr=None, metadata=None, - missing_id=''): + missing_id='', orig_type=None): super(self.__class__, self).__init__(arr, metadata, missing_id) + if orig_type is None: + self.orig_type = self._arr.dtype.type + self._arr = self._arr.astype(np.dtype(int)) + else: + self.orig_type = orig_type + for x in np.unique(self._arr): - self._groupings[x] = [x, x] + self._groupings[x] = [x, x + 1] self._possible_groups = None def deep_copy(self): @@ -150,30 +156,36 @@ def deep_copy(self): Returns a deep copy. """ return OrdinalColumn(self._arr, metadata=self.metadata, - missing_id=self._missing_id) + missing_id=self._missing_id, orig_type=self.orig_type) def __getitem__(self, key): - return OrdinalColumn(self._arr[key], metadata=self.metadata) + return OrdinalColumn(self._arr[key], metadata=self.metadata, + missing_id=self._missing_id, orig_type=self.orig_type) def __setitem__(self, key, value): self._arr[key] = value return self def groups(self): - return list(self._groupings.values()) + groups = [ + [self.orig_type(x) for x in range(min, max)] for min, max in self._groupings.values() + ] + return groups def possible_groupings(self): if self._possible_groups is None: ranges = sorted(self._groupings.items()) candidates = zip(ranges[0:], ranges[1:]) self._possible_groups = [ - (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates - if (minmax1[1] + 1) == minmax2[0] + (self.orig_type(k1), self.orig_type(k2)) for (k1, minmax1), (k2, minmax2) in candidates + if (minmax1[1]) == minmax2[0] ] return self._possible_groups.__iter__() def group(self, x, y): self._possible_groups = None + x = int(x) + y = int(y) x_max = self._groupings[x][1] y_min = self._groupings[y][0] if y_min > x_max: diff --git a/CHAID/tree.py b/CHAID/tree.py index 0f49ddb..5458dd1 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -72,7 +72,7 @@ def __init__(self, ndarr, arr, alpha_merge=0.05, max_depth=2, min_parent_node_si variable_types = variable_types or ['nominal'] * ndarr.shape[1] for ind, col_type in enumerate(variable_types): if col_type == 'ordinal': - col = OrdinalColumn(ndarr[:, ind].astype(np.dtype(int))) + col = OrdinalColumn(ndarr[:, ind]) elif col_type == 'nominal': col = NominalColumn(ndarr[:, ind]) else: diff --git a/tests/setup_tests.py b/tests/setup_tests.py index 5f59a7f..9012fa8 100644 --- a/tests/setup_tests.py +++ b/tests/setup_tests.py @@ -5,6 +5,7 @@ from collections import Iterable import os import sys +from math import isnan ROOT_FOLDER = os.path.realpath(os.path.dirname(os.path.realpath(__file__)) + '/../') @@ -12,7 +13,6 @@ import CHAID - def list_unordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" if isinstance(list_a, Iterable) and isinstance(list_b, Iterable): @@ -20,7 +20,7 @@ def list_unordered_equal(list_a, list_b): list_b = sorted(list_b) return all(list_unordered_equal(*item) for item in zip(list_a, list_b)) else: - return list_a == list_b + return list_a == list_b or (isnan(list_a) and isnan(list_b)) def list_ordered_equal(list_a, list_b): @@ -28,4 +28,4 @@ def list_ordered_equal(list_a, list_b): if isinstance(list_a, Iterable) and isinstance(list_b, Iterable): return all(list_ordered_equal(*item) for item in zip(list_a, list_b)) else: - return list_a == list_b + return list_a == list_b or (isnan(list_a) and isnan(list_b)) diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index 8f9744c..571ccb4 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -35,45 +35,63 @@ class TestOrdinalGrouping(TestCase): """ Test fixture class for deep copy method """ def setUp(self): """ Setup for grouping tests """ - arr = np.array([1, 2, 3, 3, 3, 3, 4, 5, 10]) + arr = np.array([1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 10.0]) self.col = CHAID.OrdinalColumn(arr) def test_possible_groups(self): - """ Ensure a groupings are only adjacent numbers """ + """ Ensure possible groups are only adjacent numbers """ groupings = list(self.col.possible_groupings()) - expected_groupings = [(1, 2), (2, 3), (3, 4), (4, 5)] - assert list_unordered_equal(expected_groupings, groupings) + possible_groupings = [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)] + assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' + + groups = list(self.col.groups()) + actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_grouping(self): """ Ensure a copy actually happens when deep_copy is called """ - self.col.group(3, 4) - self.col.group(3, 2) + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + groupings = list(self.col.possible_groupings()) - expected_groupings = [(1, 3), (3, 5)] - assert list_unordered_equal(expected_groupings, groupings) + possible_groupings = [(1.0, 3.0), (3.0, 5.0)] + assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' + + groups = list(self.col.groups()) + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' class TestOrdinalGroupingWithNAN(TestCase): """ Test fixture class for deep copy method """ def setUp(self): """ Setup for grouping tests """ - arr = np.array([1, 2, NAN, 3, 3, NAN, 3, 3, NAN, 4, 5, 10]) + arr = np.array([1.0, 2.0, NAN, 3.0, 3.0, NAN, 3.0, 3.0, NAN, 4.0, 5.0, 10.0]) self.col = CHAID.OrdinalColumn(arr) def test_possible_groups(self): - """ Ensure a groupings are only adjacent numbers """ + """ Ensure possible groups are only adjacent numbers """ groupings = list(self.col.possible_groupings()) - expected_groupings = [ - (1, 2), (2, 3), (3, 4), (4, 5), (1, NAN), (2, NAN), (3, NAN), - (4, NAN), (5, NAN), (10, NAN) + possible_groupings = [ + (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, NAN), (2.0, NAN), (3.0, NAN), + (4.0, NAN), (5.0, NAN), (10.0, NAN) ] - assert list_unordered_equal(expected_groupings, groupings) + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, before any groups are identified, possible grouping are incorrectly calculated.' + + groups = list(self.col.groups()) + actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [NAN], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_grouping(self): """ Ensure a copy actually happens when deep_copy is called """ - self.col.group(3, 4) - self.col.group(3, 2) + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) groupings = list(self.col.possible_groupings()) - expected_groupings = [ - (1, 3), (3, 5), (1, NAN), (3, NAN), (5, NAN), (10, NAN) + + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0), (1.0, NAN), (3.0, NAN), (5.0, NAN), (10.0, NAN) ] - assert list_unordered_equal(expected_groupings, groupings) + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' + + groups = list(self.col.groups()) + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], [NAN]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' From 1b0ccb7ea941d39be9618e4329daa5e5fd0f296e Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Mon, 17 Oct 2016 16:15:49 +0100 Subject: [PATCH 11/19] Get NaN behavour correct. --- CHAID/column.py | 46 +++++++++++++++++++++++------------- CHAID/tree.py | 3 +-- tests/setup_tests.py | 7 ++++-- tests/test_ordinal_column.py | 17 ++++++------- 4 files changed, 45 insertions(+), 28 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index a48d71a..2e51d76 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -1,5 +1,5 @@ import numpy as np -import math +from math import isnan from itertools import combinations from .mapping_dict import MappingDict @@ -106,7 +106,7 @@ def substitute_values(self, vect): """ unique = np.unique(vect) unique = [ - x for x in unique if not isinstance(x, float) or not math.isnan(x) + x for x in unique if not isinstance(x, float) or not isnan(x) ] arr = np.zeros(len(vect), dtype=int) - 1 @@ -138,48 +138,62 @@ def group(self, x, y): class OrdinalColumn(Column): def __init__(self, arr=None, metadata=None, - missing_id='', orig_type=None): + missing_id='', substitute=True): super(self.__class__, self).__init__(arr, metadata, missing_id) - if orig_type is None: - self.orig_type = self._arr.dtype.type - self._arr = self._arr.astype(np.dtype(int)) - else: - self.orig_type = orig_type + if substitute: + self._arr, self.orig_type = self.substitute_values(self._arr) + for x in np.unique(self._arr): self._groupings[x] = [x, x + 1] + self._nan = np.array([np.nan]).astype(int)[0] self._possible_groups = None + def substitute_values(self, vect): + if not np.issubdtype(vect.dtype, np.integer): + uniq = np.unique(vect) + uniq_as_int = uniq.astype(float).astype(int) + nan = self._missing_id + self._metadata = { + new : old if not isnan(float(old)) else nan for old, new in zip(uniq, uniq_as_int) + } + self._arr = self._arr.astype(float) + return self._arr.astype(int), self._arr.dtype.type + def deep_copy(self): """ Returns a deep copy. """ return OrdinalColumn(self._arr, metadata=self.metadata, - missing_id=self._missing_id, orig_type=self.orig_type) + missing_id=self._missing_id, substitute=True) def __getitem__(self, key): return OrdinalColumn(self._arr[key], metadata=self.metadata, - missing_id=self._missing_id, orig_type=self.orig_type) + missing_id=self._missing_id, substitute=True) def __setitem__(self, key, value): self._arr[key] = value return self def groups(self): - groups = [ - [self.orig_type(x) for x in range(min, max)] for min, max in self._groupings.values() + vals = self._groupings.values() + return [ + [x for x in range(minmax[0], minmax[1])] for minmax in vals ] - return groups def possible_groupings(self): if self._possible_groups is None: ranges = sorted(self._groupings.items()) candidates = zip(ranges[0:], ranges[1:]) self._possible_groups = [ - (self.orig_type(k1), self.orig_type(k2)) for (k1, minmax1), (k2, minmax2) in candidates - if (minmax1[1]) == minmax2[0] + (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates + if minmax1[1] == minmax2[0] ] + if self._metadata.has_key(self._nan): + self._possible_groups += list( + (key, self._nan) for key in self._groupings.keys() if key != self._nan + ) return self._possible_groups.__iter__() def group(self, x, y): @@ -188,7 +202,7 @@ def group(self, x, y): y = int(y) x_max = self._groupings[x][1] y_min = self._groupings[y][0] - if y_min > x_max: + if y_min >= x_max: self._groupings[x][1] = self._groupings[y][1] else: self._groupings[x][0] = y_min diff --git a/CHAID/tree.py b/CHAID/tree.py index 5458dd1..eaff76a 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -195,7 +195,6 @@ def generate_best_split(self, ind, dep, wt=None): freq[col][dep_v] = wt[(ind_var.arr == col) * (dep.arr == dep_v)].sum() while next(ind_var.possible_groupings(), None) is not None: - sub_data_columns = [('combinations', object), ('p', float), ('chi', float)] choice = None highest_p_join = None split_chi = None @@ -230,7 +229,7 @@ def generate_best_split(self, ind, dep, wt=None): better_split = not split.valid() or p_split < split.p or (p_split == split.p and chi > split.chi) - if not split.valid() or better_split: + if better_split: split, temp_split = temp_split, split chi_threshold = relative_split_threshold * split.chi diff --git a/tests/setup_tests.py b/tests/setup_tests.py index 9012fa8..fd9c8ae 100644 --- a/tests/setup_tests.py +++ b/tests/setup_tests.py @@ -13,9 +13,12 @@ import CHAID +def islist(a): + return isinstance(a, Iterable) and not isinstance(a, str) + def list_unordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" - if isinstance(list_a, Iterable) and isinstance(list_b, Iterable): + if islist(list_a) and islist(list_b): list_a = sorted(list_a) list_b = sorted(list_b) return all(list_unordered_equal(*item) for item in zip(list_a, list_b)) @@ -25,7 +28,7 @@ def list_unordered_equal(list_a, list_b): def list_ordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" - if isinstance(list_a, Iterable) and isinstance(list_b, Iterable): + if islist(list_a) and islist(list_b): return all(list_ordered_equal(*item) for item in zip(list_a, list_b)) else: return list_a == list_b or (isnan(list_a) and isnan(list_b)) diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index 571ccb4..b4977b2 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -70,28 +70,29 @@ def setUp(self): def test_possible_groups(self): """ Ensure possible groups are only adjacent numbers """ - groupings = list(self.col.possible_groupings()) + groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ - (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, NAN), (2.0, NAN), (3.0, NAN), - (4.0, NAN), (5.0, NAN), (10.0, NAN) + (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, ''), (2.0, ''), (3.0, ''), + (4.0, ''), (5.0, ''), (10.0, '') ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, before any groups are identified, possible grouping are incorrectly calculated.' groups = list(self.col.groups()) - actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [NAN], [10.0]] + groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [''], [10.0]] assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_grouping(self): """ Ensure a copy actually happens when deep_copy is called """ self.col.group(3.0, 4.0) self.col.group(3.0, 2.0) - groupings = list(self.col.possible_groupings()) + groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ - (1.0, 3.0), (3.0, 5.0), (1.0, NAN), (3.0, NAN), (5.0, NAN), (10.0, NAN) + (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' - groups = list(self.col.groups()) - actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], [NAN]] + groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' From 4ca04a96c93659fec656692d158962e741169427 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Mon, 17 Oct 2016 17:21:38 +0100 Subject: [PATCH 12/19] Fix empty node issue --- CHAID/column.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 2e51d76..215f995 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -144,7 +144,6 @@ def __init__(self, arr=None, metadata=None, if substitute: self._arr, self.orig_type = self.substitute_values(self._arr) - for x in np.unique(self._arr): self._groupings[x] = [x, x + 1] self._nan = np.array([np.nan]).astype(int)[0] @@ -190,7 +189,7 @@ def possible_groupings(self): (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates if minmax1[1] == minmax2[0] ] - if self._metadata.has_key(self._nan): + if self._nan in self._arr: self._possible_groups += list( (key, self._nan) for key in self._groupings.keys() if key != self._nan ) From 0d52cbbf57dd59c54d1447c997a6335db89ed4ac Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Mon, 17 Oct 2016 17:45:37 +0100 Subject: [PATCH 13/19] Deal with NaNs correctly when returning groups containing NaN --- CHAID/column.py | 22 +++++++++++++--------- tests/test_ordinal_column.py | 20 ++++++++++++++++++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 215f995..484c66c 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -145,7 +145,7 @@ def __init__(self, arr=None, metadata=None, self._arr, self.orig_type = self.substitute_values(self._arr) for x in np.unique(self._arr): - self._groupings[x] = [x, x + 1] + self._groupings[x] = [x, x + 1, False] self._nan = np.array([np.nan]).astype(int)[0] self._possible_groups = None @@ -178,7 +178,8 @@ def __setitem__(self, key, value): def groups(self): vals = self._groupings.values() return [ - [x for x in range(minmax[0], minmax[1])] for minmax in vals + [x for x in range(minmax[0], minmax[1])] + ([self._nan] if minmax[2] else []) + for minmax in vals ] def possible_groupings(self): @@ -197,14 +198,17 @@ def possible_groupings(self): def group(self, x, y): self._possible_groups = None - x = int(x) - y = int(y) - x_max = self._groupings[x][1] - y_min = self._groupings[y][0] - if y_min >= x_max: - self._groupings[x][1] = self._groupings[y][1] + if y != self._nan: + x = int(x) + y = int(y) + x_max = self._groupings[x][1] + y_min = self._groupings[y][0] + if y_min >= x_max: + self._groupings[x][1] = self._groupings[y][1] + else: + self._groupings[x][0] = y_min else: - self._groupings[x][0] = y_min + self._groupings[x][2] = True del self._groupings[y] self._arr[self._arr == y] = x diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index b4977b2..182264e 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -83,7 +83,7 @@ def test_possible_groups(self): assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_grouping(self): - """ Ensure a copy actually happens when deep_copy is called """ + """ Ensure possible groups are only adjacent numbers after identifing some groups """ self.col.group(3.0, 4.0) self.col.group(3.0, 2.0) @@ -95,4 +95,20 @@ def test_groups_after_grouping(self): groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] - assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' + + def test_groups_after_grouping_with_nan(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + self.col.group(3.0, self.col._nan) + + groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0) + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' + + groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' From 1d8984a07801db656e7833375bb91ab359110b3f Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 18 Oct 2016 14:02:11 +0100 Subject: [PATCH 14/19] Fix for disappering group items --- CHAID/column.py | 29 ++++++++++++++------- CHAID/tree.py | 3 ++- tests/setup_tests.py | 12 ++++++--- tests/test_ordinal_column.py | 49 +++++++++++++++++++++++++++++++++++- 4 files changed, 78 insertions(+), 15 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index 484c66c..d36e21e 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -24,7 +24,6 @@ def __init__(self, arr=None, metadata=None, self._metadata = dict(metadata or {}) self._arr = np.array(arr) self._missing_id = missing_id - self._groupings = MappingDict() def __iter__(self): return iter(self._arr) @@ -80,6 +79,8 @@ def __init__(self, arr=None, metadata=None, super(self.__class__, self).__init__(arr, metadata, missing_id) if substitute: self.substitute_values(arr) + + self._groupings = MappingDict() for x in np.unique(self._arr): self._groupings[x] = [x] @@ -138,24 +139,31 @@ def group(self, x, y): class OrdinalColumn(Column): def __init__(self, arr=None, metadata=None, - missing_id='', substitute=True): + missing_id='', groupings=None, substitute=True): super(self.__class__, self).__init__(arr, metadata, missing_id) if substitute: self._arr, self.orig_type = self.substitute_values(self._arr) - for x in np.unique(self._arr): - self._groupings[x] = [x, x + 1, False] + self._groupings = {} + if groupings is None: + for x in np.unique(self._arr): + self._groupings[x] = [x, x + 1, False] + else: + for x in np.unique(self._arr): + self._groupings[x] = list(groupings[x]) self._nan = np.array([np.nan]).astype(int)[0] self._possible_groups = None def substitute_values(self, vect): if not np.issubdtype(vect.dtype, np.integer): uniq = np.unique(vect) - uniq_as_int = uniq.astype(float).astype(int) + uniq_floats = uniq.astype(float) + uniq_ints = uniq_floats.astype(int) nan = self._missing_id self._metadata = { - new : old if not isnan(float(old)) else nan for old, new in zip(uniq, uniq_as_int) + new : nan if isnan(as_float) else old + for old, as_float, new in zip(uniq, uniq_floats, uniq_ints) } self._arr = self._arr.astype(float) return self._arr.astype(int), self._arr.dtype.type @@ -165,11 +173,13 @@ def deep_copy(self): Returns a deep copy. """ return OrdinalColumn(self._arr, metadata=self.metadata, - missing_id=self._missing_id, substitute=True) + missing_id=self._missing_id, substitute=True, + groupings=self._groupings) def __getitem__(self, key): return OrdinalColumn(self._arr[key], metadata=self.metadata, - missing_id=self._missing_id, substitute=True) + missing_id=self._missing_id, substitute=True, + groupings=self._groupings) def __setitem__(self, key, value): self._arr[key] = value @@ -190,7 +200,7 @@ def possible_groupings(self): (k1, k2) for (k1, minmax1), (k2, minmax2) in candidates if minmax1[1] == minmax2[0] ] - if self._nan in self._arr: + if self._nan in self._arr: self._possible_groups += list( (key, self._nan) for key in self._groupings.keys() if key != self._nan ) @@ -207,6 +217,7 @@ def group(self, x, y): self._groupings[x][1] = self._groupings[y][1] else: self._groupings[x][0] = y_min + self._groupings[x][2] = self._groupings[x][2] or self._groupings[y][2] else: self._groupings[x][2] = True diff --git a/CHAID/tree.py b/CHAID/tree.py index eaff76a..5458dd1 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -195,6 +195,7 @@ def generate_best_split(self, ind, dep, wt=None): freq[col][dep_v] = wt[(ind_var.arr == col) * (dep.arr == dep_v)].sum() while next(ind_var.possible_groupings(), None) is not None: + sub_data_columns = [('combinations', object), ('p', float), ('chi', float)] choice = None highest_p_join = None split_chi = None @@ -229,7 +230,7 @@ def generate_best_split(self, ind, dep, wt=None): better_split = not split.valid() or p_split < split.p or (p_split == split.p and chi > split.chi) - if better_split: + if not split.valid() or better_split: split, temp_split = temp_split, split chi_threshold = relative_split_threshold * split.chi diff --git a/tests/setup_tests.py b/tests/setup_tests.py index fd9c8ae..a748cbb 100644 --- a/tests/setup_tests.py +++ b/tests/setup_tests.py @@ -19,9 +19,11 @@ def islist(a): def list_unordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" if islist(list_a) and islist(list_b): - list_a = sorted(list_a) - list_b = sorted(list_b) - return all(list_unordered_equal(*item) for item in zip(list_a, list_b)) + list_a = [item_a for item_a in list_a] + list_b = [item_b for item_b in list_b] + list_a.sort() + list_b.sort() + return len(list_a) == len(list_b) and all(list_unordered_equal(*item) for item in zip(list_a, list_b)) else: return list_a == list_b or (isnan(list_a) and isnan(list_b)) @@ -29,6 +31,8 @@ def list_unordered_equal(list_a, list_b): def list_ordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" if islist(list_a) and islist(list_b): - return all(list_ordered_equal(*item) for item in zip(list_a, list_b)) + list_a = [item_a for item_a in list_a] + list_b = [item_b for item_b in list_b] + return len(list_a) == len(list_b) and all(list_ordered_equal(*item) for item in zip(list_a, list_b)) else: return list_a == list_b or (isnan(list_a) and isnan(list_b)) diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index 182264e..1dba19b 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -61,6 +61,20 @@ def test_groups_after_grouping(self): actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' + def test_groups_after_copy(self): + """ Ensure a copy actually happens when deep_copy is called """ + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + col = self.col.deep_copy() + + groupings = list(col.possible_groupings()) + possible_groupings = [(1.0, 3.0), (3.0, 5.0)] + assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' + + groups = list(col.groups()) + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' + class TestOrdinalGroupingWithNAN(TestCase): """ Test fixture class for deep copy method """ def setUp(self): @@ -99,9 +113,9 @@ def test_groups_after_grouping(self): def test_groups_after_grouping_with_nan(self): """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" + self.col.group(4.0, self.col._nan) self.col.group(3.0, 4.0) self.col.group(3.0, 2.0) - self.col.group(3.0, self.col._nan) groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ @@ -112,3 +126,36 @@ def test_groups_after_grouping_with_nan(self): groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' + + def test_groups_after_copy(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups """ + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + col = self.col.deep_copy() + + groupings = [ (col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' + + groups = [ [col.metadata[i] for i in group] for group in col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' + + def test_groups_after_copy_with_nan(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" + self.col.group(3.0, 4.0) + self.col.group(3.0, self.col._nan) + self.col.group(3.0, 2.0) + col = self.col.deep_copy() + + groupings = [ (col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0) + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' + + groups = [ [col.metadata[i] for i in group] for group in col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' From 4ac893a3c6037af2bba6aa813ce79fa8d1a2de55 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 18 Oct 2016 14:25:19 +0100 Subject: [PATCH 15/19] Fix broken spec (that was wrong in the first place) --- tests/test_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tree.py b/tests/test_tree.py index f2be8b8..d7ef27b 100644 --- a/tests/test_tree.py +++ b/tests/test_tree.py @@ -68,7 +68,7 @@ def test_best_split_with_combination(): assert list_ordered_equal(ndarr, orig_ndarr), 'Calling chaid should have no side affects for original numpy arrays' assert list_ordered_equal(arr, orig_arr), 'Calling chaid should have no side affects for original numpy arrays' assert split.column_id == 0, 'Identifies correct column to split on' - assert list_unordered_equal(split.split_map, [[1], [2, 3]]), 'Correctly identifies catagories' + assert list_unordered_equal(split.split_map, [[1], [2], [3]]), 'Correctly identifies catagories' assert list_unordered_equal(split.surrogates, []), 'No surrogates should be generated' assert split.p < 0.015 From 15c88fd5a21863408c0251c3e92b90f0b8cdca50 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 18 Oct 2016 15:50:12 +0100 Subject: [PATCH 16/19] PR Comments --- CHAID/__main__.py | 2 +- CHAID/column.py | 15 ++++++++------- CHAID/tree.py | 15 ++++++++++----- tests/test_ordinal_column.py | 35 ++++++++++++++++++----------------- 4 files changed, 37 insertions(+), 30 deletions(-) diff --git a/CHAID/__main__.py b/CHAID/__main__.py index f56f960..4e17bea 100644 --- a/CHAID/__main__.py +++ b/CHAID/__main__.py @@ -49,7 +49,7 @@ def main(): data = pd.DataFrame(raw_data_list) data = data.rename(columns=data.loc[0]).iloc[1:] else: - print('Uknown file type') + print('Unknown file type') exit(1) config = {} diff --git a/CHAID/column.py b/CHAID/column.py index d36e21e..ac8ccbc 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -37,9 +37,6 @@ def __setitem__(self, key, value): def possible_groupings(self): raise NotImplementedError - def combine(self, x, y): - raise NotImplementedError - def deep_copy(self): """ Returns a deep copy. @@ -137,13 +134,17 @@ def group(self, x, y): del self._groupings[y] self._arr[self._arr == y] = x + class OrdinalColumn(Column): + """ + A column containing integer values that have an order + """ def __init__(self, arr=None, metadata=None, missing_id='', groupings=None, substitute=True): super(self.__class__, self).__init__(arr, metadata, missing_id) if substitute: - self._arr, self.orig_type = self.substitute_values(self._arr) + self._arr, self.orig_type = self.substitute_values(self._arr) self._groupings = {} if groupings is None: @@ -162,7 +163,7 @@ def substitute_values(self, vect): uniq_ints = uniq_floats.astype(int) nan = self._missing_id self._metadata = { - new : nan if isnan(as_float) else old + new: nan if isnan(as_float) else old for old, as_float, new in zip(uniq, uniq_floats, uniq_ints) } self._arr = self._arr.astype(float) @@ -201,9 +202,9 @@ def possible_groupings(self): if minmax1[1] == minmax2[0] ] if self._nan in self._arr: - self._possible_groups += list( + self._possible_groups += [ (key, self._nan) for key in self._groupings.keys() if key != self._nan - ) + ] return self._possible_groups.__iter__() def group(self, x, y): diff --git a/CHAID/tree.py b/CHAID/tree.py index 5458dd1..c7506a4 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -59,6 +59,10 @@ class Tree(object): contain (default 30) split_titles : array-like array of names for the independent variables in the data + variable_types : array-like or dict + array of variable types, or dict of column names to variable types. + Supported variable types are the strings 'nominal' or 'ordinal' in + lower case """ def __init__(self, ndarr, arr, alpha_merge=0.05, max_depth=2, min_parent_node_size=30, min_child_node_size=0, split_titles=None, split_threshold=0, weights=None, @@ -117,6 +121,10 @@ def from_pandas_df(df, i_variables, d_variable, alpha_merge=0.05, max_depth=2, min_parent_node_size : float the threshold value of the number of respondents that the node must contain (default 30) + variable_types : array-like or dict + array of variable types, or dict of column names to variable types. + Supported variable types are the strings 'nominal' or 'ordinal' in + lower case """ ind_df = df[i_variables] ind_values = ind_df.values @@ -195,10 +203,7 @@ def generate_best_split(self, ind, dep, wt=None): freq[col][dep_v] = wt[(ind_var.arr == col) * (dep.arr == dep_v)].sum() while next(ind_var.possible_groupings(), None) is not None: - sub_data_columns = [('combinations', object), ('p', float), ('chi', float)] - choice = None - highest_p_join = None - split_chi = None + choice, highest_p_join, split_chi = None, None, None for comb in ind_var.possible_groupings(): col1_freq = freq[comb[0]] col2_freq = freq[comb[1]] @@ -230,7 +235,7 @@ def generate_best_split(self, ind, dep, wt=None): better_split = not split.valid() or p_split < split.p or (p_split == split.p and chi > split.chi) - if not split.valid() or better_split: + if better_split: split, temp_split = temp_split, split chi_threshold = relative_split_threshold * split.chi diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index 1dba19b..295f719 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -1,18 +1,16 @@ """ -Testing module for the class NominalColumn +Testing module for the class OrdinalColumn """ from unittest import TestCase import numpy as np from numpy import nan -from setup_tests import list_ordered_equal,list_unordered_equal, CHAID +from setup_tests import list_ordered_equal, list_unordered_equal, CHAID -NAN = float('nan') class TestOrdinalDeepCopy(TestCase): """ Test fixture class for deep copy method """ def setUp(self): """ Setup for copy tests""" - # Use string so numpy array dtype is object and may store references arr = np.array([1, 2, 3, 3, 3, 3]) self.orig = CHAID.OrdinalColumn(arr) self.copy = self.orig.deep_copy() @@ -31,6 +29,7 @@ def test_metadata(self): """ Ensure metadata is copied correctly or deep_copy """ assert self.copy.metadata == self.orig.metadata, 'Copied metadata should be equivilent' + class TestOrdinalGrouping(TestCase): """ Test fixture class for deep copy method """ def setUp(self): @@ -75,16 +74,18 @@ def test_groups_after_copy(self): actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' -class TestOrdinalGroupingWithNAN(TestCase): + +class TestOrdinalGroupingWithnan(TestCase): """ Test fixture class for deep copy method """ def setUp(self): """ Setup for grouping tests """ - arr = np.array([1.0, 2.0, NAN, 3.0, 3.0, NAN, 3.0, 3.0, NAN, 4.0, 5.0, 10.0]) + arr = np.array([1.0, 2.0, nan, 3.0, 3.0, nan, 3.0, 3.0, nan, 4.0, 5.0, 10.0]) self.col = CHAID.OrdinalColumn(arr) def test_possible_groups(self): """ Ensure possible groups are only adjacent numbers """ - groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + metadata = self.col.metadata + groupings = [(metadata[x], metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, ''), (2.0, ''), (3.0, ''), (4.0, ''), (5.0, ''), (10.0, '') @@ -92,7 +93,7 @@ def test_possible_groups(self): assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, before any groups are identified, possible grouping are incorrectly calculated.' groups = list(self.col.groups()) - groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [''], [10.0]] assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' @@ -101,29 +102,29 @@ def test_groups_after_grouping(self): self.col.group(3.0, 4.0) self.col.group(3.0, 2.0) - groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + groupings = [(self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' - groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' - def test_groups_after_grouping_with_nan(self): + def test_groups_grouping_with_nan(self): """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" self.col.group(4.0, self.col._nan) self.col.group(3.0, 4.0) self.col.group(3.0, 2.0) - groupings = [ (self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + groupings = [(self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] possible_groupings = [ (1.0, 3.0), (3.0, 5.0) ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' - groups = [ [self.col.metadata[i] for i in group] for group in self.col.groups()] + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' @@ -133,13 +134,13 @@ def test_groups_after_copy(self): self.col.group(3.0, 2.0) col = self.col.deep_copy() - groupings = [ (col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + groupings = [(col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] possible_groupings = [ (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' - groups = [ [col.metadata[i] for i in group] for group in col.groups()] + groups = [[col.metadata[i] for i in group] for group in col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' @@ -150,12 +151,12 @@ def test_groups_after_copy_with_nan(self): self.col.group(3.0, 2.0) col = self.col.deep_copy() - groupings = [ (col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + groupings = [(col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] possible_groupings = [ (1.0, 3.0), (3.0, 5.0) ] assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' - groups = [ [col.metadata[i] for i in group] for group in col.groups()] + groups = [[col.metadata[i] for i in group] for group in col.groups()] actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' From 5902d0ed522d4fccf521be1b5d435c1932bc3156 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 18 Oct 2016 16:19:04 +0100 Subject: [PATCH 17/19] PEP8 fixes --- CHAID/__main__.py | 1 + CHAID/column.py | 2 ++ CHAID/node.py | 1 + CHAID/tree.py | 4 ++-- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHAID/__main__.py b/CHAID/__main__.py index 4e17bea..b071ea9 100644 --- a/CHAID/__main__.py +++ b/CHAID/__main__.py @@ -8,6 +8,7 @@ import pandas as pd import numpy as np + def main(): """Entry point when module is run from command line""" diff --git a/CHAID/column.py b/CHAID/column.py index ac8ccbc..a571c57 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -3,6 +3,7 @@ from itertools import combinations from .mapping_dict import MappingDict + class Column(object): """ A numpy array with metadata @@ -66,6 +67,7 @@ def metadata(self): """ return self._metadata + class NominalColumn(Column): """ A column containing numerical values that are unrelated to diff --git a/CHAID/node.py b/CHAID/node.py index ccbaf4e..3655124 100644 --- a/CHAID/node.py +++ b/CHAID/node.py @@ -1,6 +1,7 @@ from .split import Split import numpy as np + class Node(object): """ A node in the CHAID tree diff --git a/CHAID/tree.py b/CHAID/tree.py index c7506a4..66b1a92 100644 --- a/CHAID/tree.py +++ b/CHAID/tree.py @@ -16,7 +16,7 @@ def chisquare(n_ij, weighted): m_ij = n_ij / n_ij nan_mask = np.isnan(m_ij) - m_ij[nan_mask] = 0.000001 # otherwise it breaks the chi-squared test + m_ij[nan_mask] = 0.000001 # otherwise it breaks the chi-squared test w_ij = m_ij n_ij_col_sum = n_ij.sum(axis=1) @@ -36,6 +36,7 @@ def chisquare(n_ij, weighted): return (chi, p_val, dof) + class Tree(object): """ Create a CHAID object which contains all the information of the tree @@ -260,7 +261,6 @@ def generate_best_split(self, ind, dep, wt=None): split.sub_split_values(ind[split.column_id].metadata) return split - def to_tree(self): """ returns a TreeLib tree """ tree = TreeLibTree() From c03df768937ab50a3497e78d6ab7907846675127 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 18 Oct 2016 16:21:41 +0100 Subject: [PATCH 18/19] Add dtype=oject tests --- tests/test_ordinal_column.py | 90 ++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index 295f719..b5ce2ec 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -75,6 +75,96 @@ def test_groups_after_copy(self): assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' +class TestOrdinalWithObjects(TestCase): + """ Test fixture class for deep copy method """ + def setUp(self): + """ Setup for grouping tests """ + arr = np.array( + [1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 5.0, 10.0, None], + dtype=object + ) + self.col = CHAID.OrdinalColumn(arr) + + def test_possible_groups(self): + """ Ensure possible groups are only adjacent numbers """ + metadata = self.col.metadata + groupings = [(metadata[x], metadata[y]) for x, y in self.col.possible_groupings()] + possible_groupings = [ + (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, ''), (2.0, ''), (3.0, ''), + (4.0, ''), (5.0, ''), (10.0, '') + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, before any groups are identified, possible grouping are incorrectly calculated.' + + groups = list(self.col.groups()) + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [''], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported' + + def test_groups_after_grouping(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups """ + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + + groupings = [(self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' + + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' + + def test_groups_grouping_with_nan(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" + self.col.group(4.0, self.col._nan) + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + + groupings = [(self.col.metadata[x], self.col.metadata[y]) for x, y in self.col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0) + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' + + groups = [[self.col.metadata[i] for i in group] for group in self.col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' + + def test_groups_after_copy(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups """ + self.col.group(3.0, 4.0) + self.col.group(3.0, 2.0) + col = self.col.deep_copy() + + groupings = [(col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0), (1.0, ''), (3.0, ''), (5.0, ''), (10.0, '') + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups are identified, possible grouping incorrectly identified.' + + groups = [[col.metadata[i] for i in group] for group in col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['']] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported' + + def test_groups_after_copy_with_nan(self): + """ Ensure possible groups are only adjacent numbers after identifing some groups containing nans""" + self.col.group(3.0, 4.0) + self.col.group(3.0, self.col._nan) + self.col.group(3.0, 2.0) + col = self.col.deep_copy() + + groupings = [(col.metadata[x], col.metadata[y]) for x, y in col.possible_groupings()] + possible_groupings = [ + (1.0, 3.0), (3.0, 5.0) + ] + assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.' + + groups = [[col.metadata[i] for i in group] for group in col.groups()] + actual_groups = [[1.0], [2.0, 3.0, 4.0, ''], [5.0], [10.0]] + assert list_unordered_equal(actual_groups, groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported' + + class TestOrdinalGroupingWithnan(TestCase): """ Test fixture class for deep copy method """ def setUp(self): From cf4a388f50b1c884c8b9136eb757b4a22ac61358 Mon Sep 17 00:00:00 2001 From: Richard Fitzgerald Date: Tue, 25 Oct 2016 11:11:19 +0100 Subject: [PATCH 19/19] Python 3 fixes --- CHAID/column.py | 4 ++-- tests/setup_tests.py | 14 ++++++++++---- tests/test_ordinal_column.py | 20 ++++++++++---------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/CHAID/column.py b/CHAID/column.py index a571c57..4e6680e 100644 --- a/CHAID/column.py +++ b/CHAID/column.py @@ -160,8 +160,8 @@ def __init__(self, arr=None, metadata=None, def substitute_values(self, vect): if not np.issubdtype(vect.dtype, np.integer): - uniq = np.unique(vect) - uniq_floats = uniq.astype(float) + uniq = set(vect) + uniq_floats = np.array(list(uniq), dtype=float) uniq_ints = uniq_floats.astype(int) nan = self._missing_id self._metadata = { diff --git a/tests/setup_tests.py b/tests/setup_tests.py index a748cbb..de72e47 100644 --- a/tests/setup_tests.py +++ b/tests/setup_tests.py @@ -9,23 +9,29 @@ ROOT_FOLDER = os.path.realpath(os.path.dirname(os.path.realpath(__file__)) + '/../') -sys.path.append(ROOT_FOLDER) +sys.path = [ROOT_FOLDER] + sys.path import CHAID + def islist(a): return isinstance(a, Iterable) and not isinstance(a, str) + +def str_ndlist(a): + return [str_ndlist(i) for i in a] if islist(a) else str(a) + + def list_unordered_equal(list_a, list_b): """ Compares the unordered contents of two nd lists""" if islist(list_a) and islist(list_b): - list_a = [item_a for item_a in list_a] - list_b = [item_b for item_b in list_b] + list_a = [str_ndlist(item_a) for item_a in list_a] + list_b = [str_ndlist(item_b) for item_b in list_b] list_a.sort() list_b.sort() return len(list_a) == len(list_b) and all(list_unordered_equal(*item) for item in zip(list_a, list_b)) else: - return list_a == list_b or (isnan(list_a) and isnan(list_b)) + return list_a == list_b or (isinstance(float, str) and isnan(list_a) and isnan(list_b)) def list_ordered_equal(list_a, list_b): diff --git a/tests/test_ordinal_column.py b/tests/test_ordinal_column.py index b5ce2ec..51267cb 100644 --- a/tests/test_ordinal_column.py +++ b/tests/test_ordinal_column.py @@ -40,38 +40,38 @@ def setUp(self): def test_possible_groups(self): """ Ensure possible groups are only adjacent numbers """ groupings = list(self.col.possible_groupings()) - possible_groupings = [(1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0)] + possible_groupings = [(1, 2), (2, 3), (3, 4), (4, 5)] assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' groups = list(self.col.groups()) - actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], [10.0]] + actual_groups = [[1], [2], [3], [4], [5], [10]] assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_grouping(self): """ Ensure a copy actually happens when deep_copy is called """ - self.col.group(3.0, 4.0) - self.col.group(3.0, 2.0) + self.col.group(3, 4) + self.col.group(3, 2) groupings = list(self.col.possible_groupings()) - possible_groupings = [(1.0, 3.0), (3.0, 5.0)] + possible_groupings = [(1, 3), (3, 5)] assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' groups = list(self.col.groups()) - actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] + actual_groups = [[1], [2, 3, 4], [5], [10]] assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported' def test_groups_after_copy(self): """ Ensure a copy actually happens when deep_copy is called """ - self.col.group(3.0, 4.0) - self.col.group(3.0, 2.0) + self.col.group(3, 4) + self.col.group(3, 2) col = self.col.deep_copy() groupings = list(col.possible_groupings()) - possible_groupings = [(1.0, 3.0), (3.0, 5.0)] + possible_groupings = [(1, 3), (3, 5)] assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.' groups = list(col.groups()) - actual_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0]] + actual_groups = [[1], [2, 3, 4], [5], [10]] assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported'