In [19]:
import warnings

import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.nn import BatchNorm1d, Dropout, LeakyReLU, Linear, Module, ReLU, Sequential, functional
from tqdm import tqdm

from data_transformer.data_sampler import DataSampler
from data_transformer.data_transformer import DataTransformer
import wandb
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    for _ in range(10):
        transformed = functional.gumbel_softmax(logits, tau=tau, hard=hard, eps=eps, dim=dim)
        if not torch.isnan(transformed).any():
            return transformed

class SmallRes(Module):
    def __init__(self, d_input, d_output):
        super().__init__()
        self.layers = Sequential(
            Linear(d_input,d_output),
            BatchNorm1d(d_output),
            LeakyReLU(0.2)
        )
    def forward(self, x):
        out = self.layers(x)
        return torch.cat([x, out], dim = 1)
class Generator(Module):
    def __init__(self, embedding_dim, generator_dim, data_dim):
        super(Generator, self).__init__()
        seq = []
        dim = embedding_dim
        for index in generator_dim:
            seq.append(SmallRes(dim,index))
            dim += index
        seq.append(Dropout(0.3))
        seq.append(Linear(dim, data_dim))
        self.layers = Sequential(*seq)
                
    def forward(self, noise, cond_vec=None):
        if cond_vec is not None:
            x = torch.cat([noise, cond_vec], dim=1)

        else:
            x = noise
            
        for layer in self.layers:
            x = layer(x)
        return x
class Discriminator(Module):
    def __init__(self, input_dim, discriminator_dim, pac=10):
        super().__init__()
        dim = input_dim * pac
        self.pac = pac
        self.pacdim = dim
        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), BatchNorm1d(item), LeakyReLU(0.2), Dropout(0.3)]
            dim = item

        seq += [Linear(dim, 1)]
        self.seq = Sequential(*seq)
        
    def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):

        batch_size = real_data.size(0) // pac
        alpha = torch.rand(batch_size, 1, 1, device=device)
        alpha = alpha.repeat(1, pac, real_data.size(1)).view(-1, real_data.size(1))
        
        interpolates = alpha * real_data + (1 - alpha) * fake_data
        interpolates.requires_grad_(True)
        
        disc_interpolates = self(interpolates)
        
        ones = torch.ones(disc_interpolates.size(), device=device)
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )[0]
        
        gradients = gradients.view(-1, pac * real_data.size(1))
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = lambda_ * ((gradient_norm - 1) ** 2).mean()
        
        return gradient_penalty
    def forward(self, input_):
        assert input_.size()[0] % self.pac == 0
        return self.seq(input_.view(-1, self.pacdim))
class CTGAN():
    def __init__(self,
        embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4,generator_decay=1e-6, discriminator_lr=2e-4, 
        discriminator_decay=1e-6,  discriminator_steps=1, log_frequency=True, pac=10, cuda=True):


        self._embedding_dim = embedding_dim
        self._generator_dim = generator_dim
        self._discriminator_dim = discriminator_dim

        self._generator_lr = generator_lr
        self._generator_decay = generator_decay
        self._discriminator_lr = discriminator_lr
        self._discriminator_decay = discriminator_decay

        self._discriminator_steps = discriminator_steps
        self._log_frequency = log_frequency
        self.pac = pac

        if not cuda or not torch.cuda.is_available():
            device = 'cpu'
        elif isinstance(cuda, str):
            device = cuda
        else:
            device = 'cuda'

        self._device = torch.device(device)

        self._transformer = None
        self._data_sampler = None
        self._generator = None
        self.loss_values = None
    def _apply_activate(self, data):
        data_t = []
        st = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if span_info.activation_fn == 'tanh':
                    ed = st + span_info.dim
                    data_t.append(torch.tanh(data[:, st:ed]))
                    st = ed
                elif span_info.activation_fn == 'softmax':
                    ed = st + span_info.dim
                    transformed = _gumbel_softmax(data[:, st:ed], tau=0.2)
                    data_t.append(transformed)
                    st = ed
                else:
                    raise ValueError(f'Unexpected activation function {span_info.activation_fn}.')

        return torch.cat(data_t, dim=1)
    def _cond_loss(self, data, c, m):
        loss = []
        st = 0
        st_c = 0
        for column_info in self._transformer.output_info_list:
            for span_info in column_info:
                if len(column_info) != 1 or span_info.activation_fn != 'softmax':
                    st += span_info.dim
                else:
                    ed = st + span_info.dim
                    ed_c = st_c + span_info.dim
                    tmp = functional.cross_entropy(
                        data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
                    )
                    loss.append(tmp)
                    st = ed
                    st_c = ed_c

        loss = torch.stack(loss, dim=1)  

        return (loss * m).sum() / data.size()[0]
    def _validate_discrete_columns(self, train_data, discrete_columns):
        if isinstance(train_data, pd.DataFrame):
            invalid_columns = set(discrete_columns) - set(train_data.columns)
        elif isinstance(train_data, np.ndarray):
            invalid_columns = []
            for column in discrete_columns:
                if column < 0 or column >= train_data.shape[1]:
                    invalid_columns.append(column)
        else:
            raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

        if invalid_columns:
            raise ValueError(f'Invalid columns found: {invalid_columns}')
    def _validate_null_data(self, train_data, discrete_columns):
        if isinstance(train_data, pd.DataFrame):
            continuous_cols = list(set(train_data.columns) - set(discrete_columns))
            any_nulls = train_data[continuous_cols].isna().any().any()
        else:
            continuous_cols = [i for i in range(train_data.shape[1]) if i not in discrete_columns]
            any_nulls = pd.DataFrame(train_data)[continuous_cols].isna().any().any()

        if any_nulls:
            raise ValueError(
                'CTGAN does not support null values in the continuous training data. '
                'Please remove all null values from your continuous training data.'
            )
    def fit(self, train_data, discrete_columns=(), epochs=300, batch_size=500, track = False):
        assert batch_size % 2 == 0
        self._validate_discrete_columns(train_data, discrete_columns)
        self._validate_null_data(train_data, discrete_columns)

        self._transformer = DataTransformer()
        self._transformer.fit(train_data, discrete_columns)

        train_data = self._transformer.transform(train_data)

        self._data_sampler = DataSampler(
            train_data, self._transformer.output_info_list, self._log_frequency
        )

        data_dim = self._transformer.output_dimensions

        self._generator = Generator(
            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
        ).to(self._device)

        discriminator = Discriminator(
            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
        ).to(self._device)

        optimizerG = optim.Adam(
            self._generator.parameters(),
            lr=self._generator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._generator_decay,
        )

        optimizerD = optim.Adam(
            discriminator.parameters(),
            lr=self._discriminator_lr,
            betas=(0.5, 0.9),
            weight_decay=self._discriminator_decay,
        )

        mean = torch.zeros(batch_size, self._embedding_dim, device=self._device)
        std = mean + 1

        self.loss_values = pd.DataFrame(columns=['Epoch', 'Generator Loss', 'Distriminator Loss'])


        steps_per_epoch = max(len(train_data) // batch_size, 1)
        if track:
            run = wandb.init(project="Data Augmentation - CTGAN", name="ctgan", config={
                "epochs": epochs,
                "batch_size": batch_size,
                "generator_lr": self._generator_lr,
                "discriminator_lr": self._discriminator_lr,
                "pac": self.pac,
                "data_dim": data_dim
            })
        for i in range(epochs):
            for id_ in range(steps_per_epoch):
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(
                            train_data, batch_size, col, opt
                        )
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
                        m1 = torch.from_numpy(m1).to(self._device)
                        fakez = torch.cat([fakez, c1], dim=1)

                        perm = np.arange(batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
                            train_data, batch_size, col[perm], opt[perm]
                        )
                        c2 = c1[perm]

                    fake = self._generator(fakez)
                    fakeact = self._apply_activate(fake)

                    real = torch.from_numpy(real.astype('float32')).to(self._device)

                    if c1 is not None:
                        fake_cat = torch.cat([fakeact, c1], dim=1)
                        real_cat = torch.cat([real, c2], dim=1)
                    else:
                        real_cat = real
                        fake_cat = fakeact

                    y_fake = discriminator(fake_cat)
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
                        real_cat, fake_cat, self._device, self.pac
                    )
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad(set_to_none=False)
                    pen.backward(retain_graph=True)
                    loss_d.backward()
                    optimizerD.step()

                fakez = torch.normal(mean=mean, std=std)
                condvec = self._data_sampler.sample_condvec(batch_size)

                if condvec is None:
                    c1, m1, col, opt = None, None, None, None
                else:
                    c1, m1, col, opt = condvec
                    c1 = torch.from_numpy(c1).to(self._device)
                    m1 = torch.from_numpy(m1).to(self._device)
                    fakez = torch.cat([fakez, c1], dim=1)

                fake = self._generator(fakez)
                fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = discriminator(fakeact)

                if condvec is None:
                    cross_entropy = 0
                else:
                    cross_entropy = self._cond_loss(fake, c1, m1)

                loss_g = -torch.mean(y_fake) + cross_entropy

                optimizerG.zero_grad(set_to_none=False)
                loss_g.backward()
                optimizerG.step()

            generator_loss = loss_g.detach().cpu().item()
            discriminator_loss = loss_d.detach().cpu().item()
            if track:
                wandb.log({
                "epoch": i,
                "generator_loss": generator_loss,
                "discriminator_loss": discriminator_loss
                })
            if (i + 1) % 10 == 0 or i == 0:
                print(f"Epoch [{i+1}/{epochs}], "
                      f"Train - G Loss: {generator_loss:.4f}, D Loss: {discriminator_loss:.4f}, "
                      f"LR - Generator: {self._generator_lr:.6f}, Discriminator: {self._discriminator_lr:.6f}")
            epoch_loss_df = pd.DataFrame({
                'Epoch': [i],
                'Generator Loss': [generator_loss],
                'Discriminator Loss': [discriminator_loss],
            })
            if not self.loss_values.empty:
                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
                    drop=True
                )
            else:
                self.loss_values = epoch_loss_df
        wandb.finish()
        
    def sample(self, n, condition_column=None, condition_value=None, batch_size= 500):
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
                condition_column, condition_value
            )
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
                condition_info, batch_size
            )
        else:
            global_condition_vec = None

        steps = n // batch_size + 1
        data = []
        for i in range(steps):
            mean = torch.zeros(batch_size, self._embedding_dim)
            std = mean + 1
            fakez = torch.normal(mean=mean, std=std).to(self._device)

            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
            else:
                condvec = self._data_sampler.sample_original_condvec(batch_size)

            if condvec is None:
                pass
            else:
                c1 = condvec
                c1 = torch.from_numpy(c1).to(self._device)
                fakez = torch.cat([fakez, c1], dim=1)

            fake = self._generator(fakez)
            fakeact = self._apply_activate(fake)
            data.append(fakeact.detach().cpu().numpy())

        data = np.concatenate(data, axis=0)
        data = data[:n]

        return self._transformer.inverse_transform(data)

    def set_device(self, device):
        self._device = device
        if self._generator is not None:
            self._generator.to(self._device)
    def save(self, path):
        if self._generator is None or self._discriminator is None:
            raise RuntimeError("Model not trained. Call fit() first.")
        checkpoint = {
            'generator': self._generator.state_dict(),
            'discriminator': self._discriminator.state_dict(),
            'generator_optimizer': self.optimizerG.state_dict(),
            'discriminator_optimizer': self.optimizerD.state_dict(),
            'transformer': self._transformer  
        }
        torch.save(checkpoint, path)


In [59]:
a = pd.DataFrame({
    'A': [1,2,1,2,1,2,1,3,3,2,2,2,1,1,1], 
    'B': ['z','z','zz', 'z', 'z','z','z','zz', 'z', 'z','z','z','zz', 'z', 'z'], 
    'C': [0.99404096, 0.58273721, 0.21701061, 0.1175965,  0.68291119, 0.62865904, 
          0.68754258, 0.51539969, 0.70036077, 0.94512348, 0.13780938, 0.04576671,
          0.0784216, 0.19138225, 0.78545446]
})

# Create a copy of the dataframe to preserve original
df = a.copy()

# Identify categorical columns
categorical_columns = ['A', 'B']
numeric_columns = ['C']
ctgan = CTGAN()
ctgan.fit(a, categorical_columns,epochs= 100)

Epoch [1/100], Train - G Loss: 0.7460, D Loss: 0.0235, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [10/100], Train - G Loss: 0.6551, D Loss: 0.0164, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [20/100], Train - G Loss: 0.7027, D Loss: -0.0121, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [30/100], Train - G Loss: 0.7212, D Loss: 0.0645, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [40/100], Train - G Loss: 0.4899, D Loss: 0.0557, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [50/100], Train - G Loss: 0.4857, D Loss: 0.0158, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [60/100], Train - G Loss: 0.5198, D Loss: 0.0043, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [70/100], Train - G Loss: 0.4258, D Loss: 0.0398, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [80/100], Train - G Loss: 0.3130, D Loss: -0.0181, LR - Generator: 0.000200, Discriminator: 0.000200
Epoch [90/100], Train - G Loss: 0.3534, D Los

In [41]:
ctgan.sample(n=16, condition_column='A', condition_value= 1)

Unnamed: 0,A,B,C
0,2,z,0.991171
1,2,z,1.560485
2,2,zz,0.044675
3,3,zz,0.94736
4,1,z,1.32115
5,3,z,0.526378
6,1,zz,0.112976
7,1,z,0.530903
8,3,z,1.207839
9,2,z,1.546595


In [61]:
ctgan.save("a")

AttributeError: 'CTGAN' object has no attribute '_discriminator'