In [118]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import pandas as pd  # Assuming .svc is tabular data



In [119]:
class GAN_Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per input row.

    The network is ``input_dim -> 128 -> 64 -> 1`` with ReLU activations
    between the linear layers and a sigmoid on the final output.

    Args:
        input_dim: Width of each input vector.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Layer order is significant: it fixes the state_dict keys
        # (model.0, model.2, model.4) and the order weights are initialised in.
        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the sigmoid output."""
        scores = self.model(x)
        return scores

In [120]:
class GAN_Generator(nn.Module):
    """Fully connected generator mapping latent vectors to sigmoid-bounded rows.

    The network is ``latent_dim -> 64 -> 128 -> output_dim`` with ReLU
    activations between the linear layers and a sigmoid on the final output,
    so every generated value lies between 0 and 1.

    Args:
        latent_dim: Width of each latent input vector.
        output_dim: Width of each generated row.
    """

    def __init__(self, latent_dim, output_dim):
        super().__init__()
        # Layer order is significant: it fixes the state_dict keys
        # (model.0, model.2, model.4) and the order weights are initialised in.
        layers = [
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run latent batch ``z`` through the stack and return the generated rows."""
        generated = self.model(z)
        return generated

In [121]:
import numpy as np

def load_svc_file(filepath, expected_num_features=6):
    """Load the numeric rows of an ``.svc`` text file into a 2-D NumPy array.

    The first line of the file (the row-count header) is skipped and blank
    lines are ignored. Every remaining line is split on whitespace and parsed
    as floats; the first value on the line is dropped and the rest are kept
    only when exactly ``expected_num_features`` values remain. Lines that
    cannot be parsed or have a different width are reported with ``print``
    and skipped.

    Args:
        filepath: Path of the ``.svc`` file to read.
        expected_num_features: Number of values a line must have after its
            first value is dropped. Defaults to 6, the previously hard-coded
            width.

    Returns:
        ``np.ndarray`` of floats with shape ``(n_valid_rows, expected_num_features)``.

    Raises:
        ValueError: If the file contains no valid data rows (including a
            completely empty file).
    """
    data = []

    with open(filepath, 'r') as file:
        # Skip the header line. The default stops an empty file from leaking a
        # bare StopIteration; it falls through to the ValueError below instead.
        next(file, None)

        for line in file:
            stripped = line.strip()

            # Skip empty lines
            if not stripped:
                continue

            try:
                # Parse every token, then drop the first value of the line
                values = [float(token) for token in stripped.split()][1:]
            except ValueError as e:
                print(f"Error processing line: '{stripped}': {e}")
                continue

            if len(values) == expected_num_features:
                data.append(values)
            else:
                print(f"Warning: Skipped line due to unexpected number of features: '{stripped}'")

    if not data:
        raise ValueError(f"No valid data found in file: {filepath}")

    # Convert the list to a NumPy array
    return np.array(data)

In [122]:
def save_synthetic_data(data, filepath):
    """Write ``data`` to ``filepath`` as integer text with a row-count header.

    The array is cast with ``astype(int)``. The file then holds the number of
    rows on its first line, followed by one line per row with the row's
    values separated by single spaces. Every line, including the last, ends
    with a newline.

    Args:
        data: 2-D NumPy array of values to write.
        filepath: Destination path; an existing file is overwritten.
    """
    as_ints = data.astype(int)

    # Header first, then the rows; the text is fully built before the file is opened
    header = str(as_ints.shape[0])
    row_lines = [" ".join(str(value) for value in row) for row in as_ints]
    text = "".join(f"{entry}\n" for entry in [header, *row_lines])

    with open(filepath, 'w') as f:
        f.write(text)


In [123]:
# Main function to handle folder processing
def process_svc_folder(input_folder, output_folder, vae, generator, latent_dim, epochs):
    """Train on every ``.svc`` file in a folder and write one synthetic file per input.

    Files are visited in sorted name order. For each file the data is loaded,
    normalized, reshaped to ``(-1, 7)``, used for ``epochs`` full-batch steps
    of VAE, discriminator and generator training, and finally the generator
    is sampled once per training row. The samples are denormalized and saved
    under the same file name in ``output_folder``.

    Besides its arguments, this function reads these names from the enclosing
    (notebook) scope, so they must exist before it is called: ``normalize``,
    ``denormalize``, ``load_svc_file``, ``save_synthetic_data``,
    ``discriminator``, ``vae_optimizer``, ``discriminator_optimizer``,
    ``generator_optimizer``, ``reconstruction_loss_fn`` and ``gan_loss_fn``.

    The models are not re-initialised between files, so training accumulates
    from one file to the next; sorting the file names keeps that sequence the
    same on every run and platform.

    Args:
        input_folder: Directory scanned (non-recursively) for ``*.svc`` files.
        output_folder: Directory the synthetic files are written to; created
            if missing.
        vae: Module called as ``vae(batch)`` that returns
            ``(reconstruction, mean, log_var)``.
        generator: Module mapping ``(n, latent_dim)`` noise to synthetic rows.
        latent_dim: Width of the noise vectors fed to ``generator``.
        epochs: Number of training steps to run per file.
    """
    # Width every batch is reshaped to.
    # NOTE(review): presumably this must equal the input_dim the models were
    # built with - confirm against the configuration cell.
    feature_dim = 7

    # exist_ok avoids the check-then-create race of a separate os.path.exists test
    os.makedirs(output_folder, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort for reproducible runs
    for filename in sorted(os.listdir(input_folder)):
        if not filename.endswith(".svc"):
            continue

        filepath = os.path.join(input_folder, filename)
        print(f"Processing file: {filename}")

        # Load the .svc file; one unusable file should not abort the whole folder
        try:
            svc_data = load_svc_file(filepath)
        except ValueError as e:
            print(f"Skipping {filename}: {e}")
            continue

        # Normalize the data
        normalized_svc = normalize(svc_data)
        # NOTE(review): a single global min/max over all columns is handed to
        # denormalize - confirm this matches how normalize scales the data.
        original_min, original_max = np.min(svc_data), np.max(svc_data)

        # Convert to a float32 tensor (the deprecated Variable wrapper is not needed)
        real_data = torch.as_tensor(normalized_svc, dtype=torch.float32)

        # Ensure there is enough data to form a batch
        if real_data.size(0) < feature_dim:
            print(f"Not enough data in {filename}. Skipping.")
            continue

        # The reshape below is only a no-op when the data is already
        # (rows, feature_dim). Any other width is regrouped across rows, so
        # make that visible instead of letting it happen silently.
        if real_data.dim() != 2 or real_data.size(1) != feature_dim:
            print(
                f"Warning: {filename} loaded with shape {tuple(real_data.shape)}, "
                f"which does not have {feature_dim} columns; reshaping to "
                f"(-1, {feature_dim}) regroups values across the original rows."
            )

        # Reshape to (batch_size, feature_dim)
        try:
            real_data = real_data.view(-1, feature_dim)
        except RuntimeError as e:
            print(f"Error reshaping real_data: {e}")
            print(f"real_data contents: {real_data}")
            continue  # Skip this file if the element count does not fit

        batch_size = real_data.size(0)

        # Discriminator targets depend only on the batch size, so build them once
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train VAE-GAN (for simplicity, we train per file)
        for epoch in range(epochs):
            # Step 1: VAE training
            vae_optimizer.zero_grad()

            print(f"Epoch {epoch + 1}/{epochs} - real_data shape: {real_data.shape}")

            recon_data, mean, log_var = vae(real_data)
            reconstruction_loss = reconstruction_loss_fn(recon_data, real_data)
            # KL term, averaged over all elements
            kl_divergence = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp())

            vae_loss = reconstruction_loss + kl_divergence
            vae_loss.backward()
            vae_optimizer.step()

            # Step 2: discriminator training
            discriminator_optimizer.zero_grad()

            # Real data loss
            real_output = discriminator(real_data)
            d_loss_real = gan_loss_fn(real_output, real_labels)

            # Fake data generation and loss; detach so this backward pass
            # does not reach the generator's parameters
            latent_samples = torch.randn(batch_size, latent_dim)
            fake_data = generator(latent_samples)
            fake_output = discriminator(fake_data.detach())
            d_loss_fake = gan_loss_fn(fake_output, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            discriminator_optimizer.step()

            # Step 3: generator training - score the same fakes against "real" labels
            generator_optimizer.zero_grad()
            fake_output = discriminator(fake_data)
            g_loss = gan_loss_fn(fake_output, real_labels)
            g_loss.backward()
            generator_optimizer.step()

        # Generate synthetic data using the generator (no gradients needed)
        with torch.no_grad():
            latent_samples = torch.randn(batch_size, latent_dim)
            synthetic_data = generator(latent_samples).numpy()

        # Denormalize the generated synthetic data
        synthetic_data_denorm = denormalize(synthetic_data, original_min, original_max)

        # Save synthetic data to the output folder
        output_filepath = os.path.join(output_folder, filename)
        save_synthetic_data(synthetic_data_denorm, output_filepath)
        print(f"Synthetic data saved to: {output_filepath}")

In [124]:
# Reproducibility: seed the RNGs before any weights are initialised or noise is
# sampled, so a Restart & Run All produces the same models and synthetic data.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
# Width of each training row; process_svc_folder reshapes every batch to (-1, 7).
# NOTE(review): load_svc_file keeps 6 values per line by default while the
# models are built for 7 - confirm which width is intended.
input_dim = 7
latent_dim = 7
lr = 1e-3
epochs = 10

# Initialize models (VAE is defined outside the cells shown here)
vae = VAE(input_dim, latent_dim)
discriminator = GAN_Discriminator(input_dim)
generator = GAN_Generator(latent_dim, input_dim)

# Loss and Optimizer
# These names are read as globals inside process_svc_folder, so they must keep
# exactly these spellings.
vae_optimizer = optim.Adam(vae.parameters(), lr=lr)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
generator_optimizer = optim.Adam(generator.parameters(), lr=lr)

reconstruction_loss_fn = nn.MSELoss()
gan_loss_fn = nn.BCELoss()

# Specify input and output folders
input_folder = "../test/samplefew"
output_folder = "../test/output/vaegan"

# Process the folder
process_svc_folder(input_folder, output_folder, vae, generator, latent_dim, epochs=epochs)

Processing file: 0_hw00001(2).svc
Epoch 1/10 - real_data shape: torch.Size([2952, 7])
Epoch 2/10 - real_data shape: torch.Size([2952, 7])
Epoch 3/10 - real_data shape: torch.Size([2952, 7])
Epoch 4/10 - real_data shape: torch.Size([2952, 7])
Epoch 5/10 - real_data shape: torch.Size([2952, 7])
Epoch 6/10 - real_data shape: torch.Size([2952, 7])
Epoch 7/10 - real_data shape: torch.Size([2952, 7])
Epoch 8/10 - real_data shape: torch.Size([2952, 7])
Epoch 9/10 - real_data shape: torch.Size([2952, 7])
Epoch 10/10 - real_data shape: torch.Size([2952, 7])
Synthetic data saved to: ../test/output/vaegan\0_hw00001(2).svc
Processing file: 0_hw00001(2)21.svc
Epoch 1/10 - real_data shape: torch.Size([1788, 7])
Epoch 2/10 - real_data shape: torch.Size([1788, 7])
Epoch 3/10 - real_data shape: torch.Size([1788, 7])
Epoch 4/10 - real_data shape: torch.Size([1788, 7])
Epoch 5/10 - real_data shape: torch.Size([1788, 7])
Epoch 6/10 - real_data shape: torch.Size([1788, 7])
Epoch 7/10 - real_data shape: tor