Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in low memory random sampling #130

Merged
merged 2 commits into from
Feb 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 34 additions & 21 deletions recordlinkage/algorithms/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from recordlinkage.measures import max_pairs
from recordlinkage.measures import full_index_size


def _map_tril_1d_on_2d(indices, dims):
Expand All @@ -17,20 +17,13 @@ def _map_tril_1d_on_2d(indices, dims):
return np.array([r, c], dtype=np.int64)


def _unique_rows_numpy(a):
"""return unique rows"""
a = np.ascontiguousarray(a)
unique_a = np.unique(a.view([('', a.dtype)] * a.shape[1]))
return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1]))


def random_pairs_with_replacement(n, shape, random_state=None):
"""make random record pairs"""

if not isinstance(random_state, np.random.RandomState):
random_state = np.random.RandomState(random_state)

n_max = max_pairs(shape)
n_max = full_index_size(shape)

if n_max <= 0:
raise ValueError('n_max must be larger than 0')
Expand All @@ -41,13 +34,19 @@ def random_pairs_with_replacement(n, shape, random_state=None):
if len(shape) == 1:
return _map_tril_1d_on_2d(indices, shape[0])
else:
return np.unravel_index(indices, shape)
return np.array(np.unravel_index(indices, shape))


def random_pairs_without_replacement_small_frames(
def random_pairs_without_replacement(
n, shape, random_state=None):
"""Return record pairs for dense sample.

Sample random record pairs without replacement bounded by the
maximum number of record pairs (based on shape). This algorithm is
efficient and fast for relative small samples.
"""

n_max = max_pairs(shape)
n_max = full_index_size(shape)

if not isinstance(random_state, np.random.RandomState):
random_state = np.random.RandomState(random_state)
Expand All @@ -63,16 +62,27 @@ def random_pairs_without_replacement_small_frames(
if len(shape) == 1:
return _map_tril_1d_on_2d(sample, shape[0])
else:
return np.unravel_index(sample, shape)
return np.array(np.unravel_index(sample, shape))


def random_pairs_without_replacement_large_frames(
def random_pairs_without_replacement_low_memory(
n, shape, random_state=None):
"""Make a sample of random pairs with replacement"""
"""Make a sample of random pairs with replacement.

Sample random record pairs without replacement bounded by the
maximum number of record pairs (based on shape). This algorithm
consumes low memory and is fast for relatively small samples.
"""

n_max = full_index_size(shape)

n_max = max_pairs(shape)
if not isinstance(random_state, np.random.RandomState):
random_state = np.random.RandomState(random_state)

if not isinstance(n, int) or n <= 0 or n > n_max:
raise ValueError("n must be a integer satisfying 0<n<=%s" % n_max)

sample = np.array([])
sample = np.array([], dtype=np.int64)

# Run as long as the number of pairs is less than the requested number
# of pairs n.
Expand All @@ -81,14 +91,17 @@ def random_pairs_without_replacement_large_frames(
# The number of pairs to sample (sample twice as much record pairs
# because the duplicates are dropped).
n_sample_size = (n - len(sample)) * 2
sample = random_state.randint(n_max, size=n_sample_size)
sample_sub = random_state.randint(
n_max,
size=n_sample_size
)

# concatenate pairs and deduplicate
pairs_non_unique = np.append(sample, sample)
sample = _unique_rows_numpy(pairs_non_unique)
pairs_non_unique = np.append(sample, sample_sub)
sample = np.unique(pairs_non_unique)

# return 2d indices
if len(shape) == 1:
return _map_tril_1d_on_2d(sample[0:n], shape[0])
else:
return np.unravel_index(sample[0:n], shape)
return np.array(np.unravel_index(sample[0:n], shape))
17 changes: 10 additions & 7 deletions recordlinkage/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from recordlinkage import rl_logging as logging
from recordlinkage.algorithms.indexing import (
random_pairs_with_replacement,
random_pairs_without_replacement_large_frames,
random_pairs_without_replacement_small_frames)
random_pairs_without_replacement_low_memory,
random_pairs_without_replacement)
from recordlinkage.base import BaseIndexAlgorithm
from recordlinkage.measures import full_index_size
from recordlinkage.utils import DeprecationHelper, listify, construct_multiindex
Expand Down Expand Up @@ -412,13 +412,16 @@ def _link_index(self, df_a, df_b):
raise ValueError(
"n must be a integer satisfying 0<n<=%s" % n_max)

# the fraction of pairs in the sample
frac = self.n / n_max

# large dataframes
if n_max < 1e6:
pairs = random_pairs_without_replacement_small_frames(
if n_max < 1e6 or frac > 0.5:
pairs = random_pairs_without_replacement(
self.n, shape, self.random_state)
# small dataframes
else:
pairs = random_pairs_without_replacement_large_frames(
pairs = random_pairs_without_replacement_low_memory(
self.n, shape, self.random_state)

levels = [df_a.index.values, df_b.index.values]
Expand Down Expand Up @@ -447,11 +450,11 @@ def _dedup_index(self, df_a):

# large dataframes
if n_max < 1e6:
pairs = random_pairs_without_replacement_small_frames(
pairs = random_pairs_without_replacement(
self.n, shape, self.random_state)
# small dataframes
else:
pairs = random_pairs_without_replacement_large_frames(
pairs = random_pairs_without_replacement_low_memory(
self.n, shape, self.random_state)

levels = [df_a.index.values, df_a.index.values]
Expand Down