/
synthesizer.py
261 lines (210 loc) · 9.16 KB
/
synthesizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import numpy as np
import torch
from torch import optim
from torch.nn import functional
from ctgan.conditional import ConditionalGenerator
from ctgan.models import Discriminator, Generator
from ctgan.sampler import Sampler
from ctgan.transformer import DataTransformer
class CTGANSynthesizer(object):
    """Conditional Table GAN Synthesizer.

    This is the core class of the CTGAN project, where the different components
    are orchestrated together.

    For more details about the process, please check the [Modeling Tabular data using
    Conditional GAN](https://arxiv.org/abs/1907.00503) paper.

    Args:
        embedding_dim (int):
            Size of the random sample passed to the Generator. Defaults to 128.
        gen_dim (tuple or list of ints):
            Size of the output samples for each one of the Residuals. A Residual Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        dis_dim (tuple or list of ints):
            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
            will be created for each one of the values provided. Defaults to (256, 256).
        l2scale (float):
            Weight Decay for the Generator's Adam Optimizer (the Discriminator's
            optimizer uses no weight decay). Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step. ``fit`` requires it
            to be even. Defaults to 500.
    """

    def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
                 l2scale=1e-6, batch_size=500):
        self.embedding_dim = embedding_dim
        self.gen_dim = gen_dim
        self.dis_dim = dis_dim
        self.l2scale = l2scale
        self.batch_size = batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _apply_activate(self, data):
        """Apply the per-span output activation to raw generator output.

        Walks ``self.transformer.output_info``, a sequence of ``(width, kind)``
        entries, and slices ``data`` column-wise. ``'tanh'`` spans get
        ``torch.tanh``; ``'softmax'`` spans get a Gumbel-softmax with
        temperature 0.2.

        Args:
            data (torch.Tensor): 2-D generator output whose columns are laid
                out in ``output_info`` order.

        Returns:
            torch.Tensor: the activated spans concatenated along dim 1.

        Raises:
            ValueError: if an ``output_info`` entry has an unknown kind.
        """
        data_t = []
        st = 0
        for item in self.transformer.output_info:
            ed = st + item[0]
            if item[1] == 'tanh':
                data_t.append(torch.tanh(data[:, st:ed]))
            elif item[1] == 'softmax':
                data_t.append(functional.gumbel_softmax(data[:, st:ed], tau=0.2))
            else:
                # Explicit raise instead of ``assert 0``: asserts vanish under -O,
                # which would silently mis-slice the remaining columns.
                raise ValueError(
                    'Unexpected activation {!r} in output_info'.format(item[1]))
            st = ed
        return torch.cat(data_t, dim=1)

    def _cond_loss(self, data, c, m):
        """Masked cross-entropy between generated discrete spans and condition ``c``.

        A ``'softmax'`` span that directly follows a ``'tanh'`` span is skipped
        (not compared against ``c``). Presumably that span is the mode indicator
        of a continuous column; confirm against DataTransformer. Every other
        ``'softmax'`` span is scored with cross-entropy against the argmax of
        the matching block of columns in ``c``.

        Args:
            data (torch.Tensor): raw (pre-activation) generator output.
            c (torch.Tensor): condition vector; one column block per scored span.
            m (torch.Tensor): per-row, per-scored-span weights that multiply the
                losses (presumably a one-hot mask of the conditioned column;
                TODO confirm in ConditionalGenerator).

        Returns:
            torch.Tensor: scalar, the weighted loss sum divided by the batch size.

        Raises:
            ValueError: if an ``output_info`` entry has an unknown kind.
        """
        loss = []
        st = 0
        st_c = 0
        skip = False
        for item in self.transformer.output_info:
            if item[1] == 'tanh':
                st += item[0]
                skip = True
            elif item[1] == 'softmax':
                if skip:
                    # Advance the data cursor only; this span has no block in ``c``.
                    skip = False
                    st += item[0]
                    continue
                ed = st + item[0]
                ed_c = st_c + item[0]
                tmp = functional.cross_entropy(
                    data[:, st:ed],
                    torch.argmax(c[:, st_c:ed_c], dim=1),
                    reduction='none'
                )
                loss.append(tmp)
                st = ed
                st_c = ed_c
            else:
                raise ValueError(
                    'Unexpected activation {!r} in output_info'.format(item[1]))
        loss = torch.stack(loss, dim=1)
        return (loss * m).sum() / data.size()[0]

    @staticmethod
    def _concat_condition(data, cond):
        """Append condition vector ``cond`` to ``data`` along dim 1, if present.

        Returns ``data`` unchanged (same object) when ``cond`` is None.
        """
        if cond is None:
            return data
        return torch.cat([data, cond], dim=1)

    def _sample_noise_and_cond(self, mean, std):
        """Draw generator noise plus one batch of condition vectors.

        Returns:
            tuple: ``(fakez, c1, m1, col, opt)``. When the conditional generator
            returns None, ``c1``, ``m1``, ``col`` and ``opt`` are None and ``fakez``
            is pure noise. Otherwise ``c1`` and ``m1`` are tensors moved to
            ``self.device``, and ``c1`` is appended to the noise.
        """
        fakez = torch.normal(mean=mean, std=std)
        condvec = self.cond_generator.sample(self.batch_size)
        if condvec is None:
            return fakez, None, None, None, None
        c1, m1, col, opt = condvec
        c1 = torch.from_numpy(c1).to(self.device)
        m1 = torch.from_numpy(m1).to(self.device)
        return torch.cat([fakez, c1], dim=1), c1, m1, col, opt

    def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
        """Fit the CTGAN Synthesizer models to the training data.

        Args:
            train_data (numpy.ndarray or pandas.DataFrame):
                Training Data. It must be a 2-dimensional numpy array or a
                pandas.DataFrame.
            discrete_columns (list-like):
                List of discrete columns to be used to generate the Conditional
                Vector. If ``train_data`` is a Numpy array, this list should
                contain the integer indices of the columns. Otherwise, if it is
                a ``pandas.DataFrame``, this list should contain the column names.
            epochs (int):
                Number of training epochs. Defaults to 300.
            log_frequency (boolean):
                Whether to use log frequency of categorical levels in conditional
                sampling. Defaults to ``True``.

        Raises:
            AssertionError: if ``self.batch_size`` is odd.
        """
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        train_data = self.transformer.transform(train_data)

        data_sampler = Sampler(train_data, self.transformer.output_info)
        data_dim = self.transformer.output_dimensions
        self.cond_generator = ConditionalGenerator(
            train_data,
            self.transformer.output_info,
            log_frequency
        )

        self.generator = Generator(
            self.embedding_dim + self.cond_generator.n_opt,
            self.gen_dim,
            data_dim
        ).to(self.device)

        discriminator = Discriminator(
            data_dim + self.cond_generator.n_opt,
            self.dis_dim
        ).to(self.device)

        optimizerG = optim.Adam(
            self.generator.parameters(), lr=2e-4, betas=(0.5, 0.9),
            weight_decay=self.l2scale
        )
        optimizerD = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))

        assert self.batch_size % 2 == 0
        mean = torch.zeros(self.batch_size, self.embedding_dim, device=self.device)
        std = mean + 1

        steps_per_epoch = max(len(train_data) // self.batch_size, 1)
        for i in range(epochs):
            for _ in range(steps_per_epoch):
                # ---- Discriminator step ----
                fakez, c1, m1, col, opt = self._sample_noise_and_cond(mean, std)
                if c1 is None:
                    real = data_sampler.sample(self.batch_size, col, opt)
                    c2 = None
                else:
                    # Real rows are drawn for a shuffled copy of the fake batch's
                    # conditions, and paired with those shuffled conditions (c2).
                    perm = np.arange(self.batch_size)
                    np.random.shuffle(perm)
                    real = data_sampler.sample(self.batch_size, col[perm], opt[perm])
                    c2 = c1[perm]

                fake = self.generator(fakez)
                fakeact = self._apply_activate(fake)

                real = torch.from_numpy(real.astype('float32')).to(self.device)

                # Always feed the *activated* fake samples, as the generator step
                # below does. The unconditional branch used to pass raw ``fake``
                # logits, which are not in the same space as ``real``.
                fake_cat = self._concat_condition(fakeact, c1)
                real_cat = self._concat_condition(real, c2)

                y_fake = discriminator(fake_cat)
                y_real = discriminator(real_cat)

                pen = discriminator.calc_gradient_penalty(real_cat, fake_cat, self.device)
                loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                optimizerD.zero_grad()
                pen.backward(retain_graph=True)
                loss_d.backward()
                optimizerD.step()

                # ---- Generator step ----
                fakez, c1, m1, col, opt = self._sample_noise_and_cond(mean, std)
                fake = self.generator(fakez)
                fakeact = self._apply_activate(fake)
                y_fake = discriminator(self._concat_condition(fakeact, c1))

                if c1 is None:
                    cross_entropy = 0
                else:
                    # Conditional loss is computed on the raw logits.
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad()
                loss_g.backward()
                optimizerG.step()

            print("Epoch %d, Loss G: %.4f, Loss D: %.4f" %
                  (i + 1, loss_g.detach().cpu(), loss_d.detach().cpu()),
                  flush=True)

    def sample(self, n):
        """Sample data similar to the training data.

        Args:
            n (int):
                Number of rows to sample.

        Returns:
            numpy.ndarray or pandas.DataFrame
        """
        steps = n // self.batch_size + 1
        mean = torch.zeros(self.batch_size, self.embedding_dim)
        std = mean + 1
        data = []
        # No gradients are needed for sampling; the outputs were detached anyway.
        with torch.no_grad():
            for _ in range(steps):
                fakez = torch.normal(mean=mean, std=std).to(self.device)
                condvec = self.cond_generator.sample_zero(self.batch_size)
                if condvec is not None:
                    c1 = torch.from_numpy(condvec).to(self.device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self.generator(fakez)
                fakeact = self._apply_activate(fake)
                data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)[:n]
        return self.transformer.inverse_transform(data, None)