From 8ce3cf78ae5c6563b8026a45e3fa5280863fc72b Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Thu, 21 Jul 2022 16:45:38 +0200 Subject: [PATCH] proposed fix in #169 --- ctgan/data_sampler.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ctgan/data_sampler.py b/ctgan/data_sampler.py index 8a0a956b..d9c9f4e5 100644 --- a/ctgan/data_sampler.py +++ b/ctgan/data_sampler.py @@ -59,11 +59,13 @@ def is_discrete_column(column_info): st = 0 current_id = 0 + discrete_st = 0 current_cond_st = 0 for column_info in output_info: if is_discrete_column(column_info): span_info = column_info[0] ed = st + span_info.dim + discrete_ed = discrete_st + span_info.dim category_freq = np.sum(data[:, st:ed], axis=0) if log_frequency: category_freq = np.log(category_freq + 1) @@ -71,9 +73,13 @@ def is_discrete_column(column_info): self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob self._discrete_column_cond_st[current_id] = current_cond_st self._discrete_column_n_category[current_id] = span_info.dim + + self._discrete_column_matrix_st[current_id] = discrete_st + current_cond_st += span_info.dim current_id += 1 st = ed + discrete_st = discrete_ed else: st += sum([span_info.dim for span_info in column_info]) @@ -150,7 +156,7 @@ def dim_cond_vec(self): def generate_cond_from_condition_column_info(self, condition_info, batch): """Generate the condition vector.""" vec = np.zeros((batch, self._n_categories), dtype='float32') - id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']] + id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']] id_ += condition_info['value_id'] vec[:, id_] = 1 return vec