Skip to content

Commit

Permalink
undo the proposed fix in sdv-dev#169
Browse files Browse the repository at this point in the history
  • Loading branch information
AndresAlgaba committed Jul 25, 2022
1 parent 8ce3cf7 commit b401998
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions ctgan/data_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,27 +59,21 @@ def is_discrete_column(column_info):

st = 0
current_id = 0
discrete_st = 0
current_cond_st = 0
for column_info in output_info:
if is_discrete_column(column_info):
span_info = column_info[0]
ed = st + span_info.dim
discrete_ed = discrete_st + span_info.dim
category_freq = np.sum(data[:, st:ed], axis=0)
if log_frequency:
category_freq = np.log(category_freq + 1)
category_prob = category_freq / np.sum(category_freq)
self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
self._discrete_column_cond_st[current_id] = current_cond_st
self._discrete_column_n_category[current_id] = span_info.dim

self._discrete_column_matrix_st[current_id] = discrete_st

current_cond_st += span_info.dim
current_id += 1
st = ed
discrete_st = discrete_ed
else:
st += sum([span_info.dim for span_info in column_info])

Expand Down Expand Up @@ -156,7 +150,7 @@ def dim_cond_vec(self):
def generate_cond_from_condition_column_info(self, condition_info, batch):
"""Generate the condition vector."""
vec = np.zeros((batch, self._n_categories), dtype='float32')
id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
id_ = self._discrete_column_cond_st[condition_info['discrete_column_id']]
id_ += condition_info['value_id']
vec[:, id_] = 1
return vec

0 comments on commit b401998

Please sign in to comment.