### import packages and functions

In [1]:
import os
import time

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential
from torch.nn import functional as F
# import octgan.synthesizers as synthesizers


### Define networks

In [18]:


class Discriminator(Module):
    def __init__(self, input_dim, dis_dims, pack=1):
        super(Discriminator, self).__init__()
        dim = input_dim * pack
        self.pack = pack
        self.packdim = dim
        seq = []
        for item in list(dis_dims):
            seq += [
                Linear(dim, item),
                LeakyReLU(0.2),
                Dropout(0.5)
            ]
            dim = item
        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)

    def forward(self, input):
        assert input.size()[0] % self.pack == 0
        return self.seq(input.view(-1, self.packdim))


class Residual(Module):
    def __init__(self, i, o):
        super(Residual, self).__init__()
        self.fc = Linear(i, o)
        self.bn = BatchNorm1d(o)
        self.relu = ReLU()

    def forward(self, input):
        out = self.fc(input)
        out = self.bn(out)
        out = self.relu(out)
        return torch.cat([out, input], dim=1)


class Generator(Module):
    def __init__(self, embedding_dim, gen_dims, data_dim):
        super(Generator, self).__init__()
        dim = embedding_dim
        seq = []
        for item in list(gen_dims):
            seq += [
                Residual(dim, item)
            ]
            dim += item
        seq.append(Linear(dim, data_dim))
        self.seq = Sequential(*seq)

    def forward(self, input):
        data = self.seq(input)
        return data



### Define useful functions

In [19]:
def apply_activate(data, output_info):
    data_t = []
    st = 0
    for item in output_info:
        if item[1] == 'tanh':
            ed = st + item[0]
            data_t.append(torch.tanh(data[:, st:ed]))
            st = ed
        elif item[1] == 'softmax':
            ed = st + item[0]
            data_t.append(F.gumbel_softmax(data[:, st:ed], tau=0.2))
            st = ed
        else:
            assert 0
    return torch.cat(data_t, dim=1)


def random_choice_prob_index(a, axis=1):
    r = np.expand_dims(np.random.rand(a.shape[1 - axis]), axis=axis)
    return (a.cumsum(axis=axis) > r).argmax(axis=axis)


class Cond(object):
    def __init__(self, data, output_info):
        self.model = []

        st = 0
        skip = False
        max_interval = 0
        counter = 0
        for item in output_info:
            if item[1] == 'tanh':
                st += item[0]
                skip = True
                continue
            elif item[1] == 'softmax':
                if skip:
                    skip = False
                    st += item[0]
                    continue

                ed = st + item[0]
                max_interval = max(max_interval, ed - st)
                counter += 1
                self.model.append(np.argmax(data[:, st:ed], axis=-1))
                st = ed
            else:
                assert 0
        assert st == data.shape[1]

        self.interval = []
        self.n_col = 0
        self.n_opt = 0
        skip = False
        st = 0
        self.p = np.zeros((counter, max_interval))
        for item in output_info:
            if item[1] == 'tanh':
                skip = True
                st += item[0]
                continue
            elif item[1] == 'softmax':
                if skip:
                    st += item[0]
                    skip = False
                    continue
                ed = st + item[0]
                tmp = np.sum(data[:, st:ed], axis=0)
                tmp = np.log(tmp + 1)
                tmp = tmp / np.sum(tmp)
                self.p[self.n_col, :item[0]] = tmp
                self.interval.append((self.n_opt, item[0]))
                self.n_opt += item[0]
                self.n_col += 1
                st = ed
            else:
                assert 0
        self.interval = np.asarray(self.interval)

    def sample(self, batch):
        if self.n_col == 0:
            return None
        batch = batch
        idx = np.random.choice(np.arange(self.n_col), batch)

        vec1 = np.zeros((batch, self.n_opt), dtype='float32')
        mask1 = np.zeros((batch, self.n_col), dtype='float32')
        mask1[np.arange(batch), idx] = 1
        opt1prime = random_choice_prob_index(self.p[idx])
        opt1 = self.interval[idx, 0] + opt1prime
        vec1[np.arange(batch), opt1] = 1

        return vec1, mask1, idx, opt1prime

    def sample_zero(self, batch):
        if self.n_col == 0:
            return None
        vec = np.zeros((batch, self.n_opt), dtype='float32')
        idx = np.random.choice(np.arange(self.n_col), batch)
        for i in range(batch):
            col = idx[i]
            pick = int(np.random.choice(self.model[col]))
            vec[i, pick + self.interval[col, 0]] = 1
        return vec


def cond_loss(data, output_info, c, m):
    loss = []
    st = 0
    st_c = 0
    skip = False
    for item in output_info:
        if item[1] == 'tanh':
            st += item[0]
            skip = True

        elif item[1] == 'softmax':
            if skip:
                skip = False
                st += item[0]
                continue

            ed = st + item[0]
            ed_c = st_c + item[0]
            tmp = F.cross_entropy(
                data[:, st:ed],
                torch.argmax(c[:, st_c:ed_c], dim=1),
                reduction='none'
            )
            loss.append(tmp)
            st = ed
            st_c = ed_c

        else:
            assert 0
    loss = torch.stack(loss, dim=1)

    return (loss * m).sum() / data.size()[0]


class Sampler(object):
    """docstring for Sampler."""

    def __init__(self, data, output_info):
        super(Sampler, self).__init__()
        self.data = data
        self.model = []
        self.n = len(data)

        st = 0
        skip = False
        for item in output_info:
            if item[1] == 'tanh':
                st += item[0]
                skip = True
            elif item[1] == 'softmax':
                if skip:
                    skip = False
                    st += item[0]
                    continue
                ed = st + item[0]
                tmp = []
                for j in range(item[0]):
                    tmp.append(np.nonzero(data[:, st + j])[0])
                self.model.append(tmp)
                st = ed
            else:
                assert 0
        assert st == data.shape[1]

    def sample(self, n, col, opt):
        if col is None:
            idx = np.random.choice(np.arange(self.n), n)
            return self.data[idx]
        idx = []
        for c, o in zip(col, opt):
            idx.append(np.random.choice(self.model[c][o]))
        return self.data[idx]


def calc_gradient_penalty(netD, real_data, fake_data, device='cpu', pac=10, lambda_=10):
    alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real_data.size(1))
    alpha = alpha.view(-1, real_data.size(1))

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    disc_interpolates = netD(interpolates)

    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones(disc_interpolates.size(), device=device),
        create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = (
        (gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1) ** 2).mean() * lambda_
    return gradient_penalty


In [24]:
from data import load_dataset
from transformers import BGMTransformer
train, test, meta, categorical_columns, ordinal_columns = load_dataset("adult")

transformer = BGMTransformer()
transformer.fit(train, categorical_columns, ordinal_columns)
train_data = transformer.transform(train)





In [25]:

if meta['problem_type'] == 'binary_classification':
    metric = 'binary_f1'
elif meta['problem_type'] == 'multiclass_classification':
    metric = 'macro_f1'
else:
    metric = 'r2'

data_sampler = Sampler(train_data, transformer.output_info)
data_dim = transformer.output_dim
cond_generator = Cond(train_data, transformer.output_info)

embedding_dim = 64
gen_dim = [256, 256]
dis_dim = [256, 256]
lr = 2e-03
batch_size = 1024

generator = Generator(
    embedding_dim + cond_generator.n_opt,
    gen_dim,
    data_dim).to('cuda')

discriminator = Discriminator(
    data_dim + cond_generator.n_opt,
    dis_dim).to('cuda')

optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.9))

if len(train_data) <= batch_size:
    batch_size = (len(train_data) // 10)*10

assert batch_size % 2 == 0
mean = torch.zeros(batch_size, embedding_dim, device='cuda')
std = mean + 1


In [32]:
epochs = 300
device = 'cuda'

steps_per_epoch = len(train_data) // batch_size
print(len(train_data))
for i in range(epochs):
    print(i)
    for id_ in range(steps_per_epoch):
        fakez = torch.normal(mean=mean, std=std)

        condvec = cond_generator.sample(batch_size)
        if condvec is None:
            c1, m1, col, opt = None, None, None, None
            real = data_sampler.sample(batch_size, col, opt)
        else:
            c1, m1, col, opt = condvec
            c1 = torch.from_numpy(c1).to(device)
            m1 = torch.from_numpy(m1).to(device)
            fakez = torch.cat([fakez, c1], dim=1)

            perm = np.arange(batch_size)
            np.random.shuffle(perm)
            real = data_sampler.sample(batch_size, col[perm], opt[perm])
            c2 = c1[perm]

        fake = generator(fakez)
        fakeact = apply_activate(fake, transformer.output_info)

        real = torch.from_numpy(real.astype('float32')).to(device)

        if c1 is not None:
            fake_cat = torch.cat([fakeact, c1], dim=1).to(device)
            real_cat = torch.cat([real, c2], dim=1).to(device)
        else:
            real_cat = real.to(device)
            fake_cat = fake.to(device)

        y_fake = discriminator(fake_cat)
        y_real = discriminator(real_cat)

        loss_d = -torch.mean(y_real) + torch.mean(y_fake)
        pen = calc_gradient_penalty(discriminator, real_cat, fake_cat, device)

        optimizerD.zero_grad()
        pen.backward(retain_graph=True)
        loss_d.backward()
        optimizerD.step()

        fakez = torch.normal(mean=mean, std=std)
        condvec = cond_generator.sample(batch_size)

        if condvec is None:
            c1, m1, col, opt = None, None, None, None
        else:
            c1, m1, col, opt = condvec
            c1 = torch.from_numpy(c1).to(device)
            m1 = torch.from_numpy(m1).to(device)
            fakez = torch.cat([fakez, c1], dim=1)

        fake = generator(fakez)
        fakeact = apply_activate(fake, transformer.output_info)

        if c1 is not None:
            y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
        else:
            y_fake = discriminator(fakeact)

        if condvec is None:
            cross_entropy = 0
        else:
            cross_entropy = cond_loss(fake, transformer.output_info, c1, m1)

        loss_g = -torch.mean(y_fake) + cross_entropy

        optimizerG.zero_grad()
        loss_g.backward()
        optimizerG.step()


36178
0


RuntimeError: The size of tensor a (1020) must match the size of tensor b (1024) at non-singleton dimension 0

In [31]:
real_cat.shape

torch.Size([1024, 252])

In [None]:


def sample(self, n):
# n=18000
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

generator.eval()

output_info = transformer.output_info
steps = n // batch_size + 1
# start = time.time()
torch.cuda.synchronize()
starter.record()

data = []
for i in range(steps):
    mean = torch.zeros(batch_size, embedding_dim)
    std = mean + 1
    fakez = torch.normal(mean=mean, std=std).to(device)

    condvec = cond_generator.sample_zero(batch_size)
    if condvec is None:
        pass
    else:
        c1 = condvec
        c1 = torch.from_numpy(c1).to(device)
        fakez = torch.cat([fakez, c1], dim=1)

    fake = generator(fakez)
    fakeact = apply_activate(fake, output_info)
    data.append(fakeact.detach().cpu().numpy())

# end = time.time()
ender.record()
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)

print(f"{curr_time} takes to sampling {n} records.")
# exit()
data = np.concatenate(data, axis=0)
data = data[:n]
data = transformer.inverse_transform(data, None)


# if not os.path.exists(save_dir):
#     os.makedirs(save_dir)
#     np.savetxt(os.path.join(save_dir, f"{sample}.csv"), data, delimiter=',')

return data