In [1]:
import pandas as pd
from sklearn.datasets import load_iris

# Project-local imports (table_gan package)
from table_gan.Data.dataset import CTGan_data_set
from table_gan.Benchmark.benchmark import Benchmark
from table_gan.Model.Gans.WCTGan import WCTGan
from table_gan.Model.Gans._gan_utils import plot_gan_losses

## Load Data

In [3]:
# as_frame=True makes `frame` a single pandas DataFrame holding the features
# together with the "target" column used for conditioning further down.
iris = load_iris(as_frame=True)
df = iris.frame

## Create Custom Dataset
To train the model, wrap your DataFrame in a `CTGan_data_set`, declaring which columns are categorical (`cat_cols`) and which columns the generator is conditioned on (`cond_cols`).

In [4]:
# "target" (the class label) is both a categorical column and the column
# that new samples can later be conditioned on.
data_set = CTGan_data_set(data=df, cat_cols=["target"], cond_cols=["target"])

## Create the WCTGan Model

In [5]:
# Instantiate the generator with its default configuration (no arguments passed).
wctgan = WCTGan()

## Train the Model and Plot the Training Losses

In [None]:
# Training returns two loss records, one for the critic and one for the generator.
# They are kept as named globals so they can be inspected after plotting.
crit_loss, gen_loss = wctgan.fit(data_set)
plot_gan_losses(crit_loss, gen_loss)

## Generate New (Conditional) Data

To generate data for a given condition, build a DataFrame that describes the condition (for example, one row per sample to generate).
Conditioning does not guarantee that every generated row strictly satisfies the condition, but it does steer the generation process.

To generate data without a specific condition, pass only the number of rows you want to generate.

In [None]:
# Number of synthetic rows to generate in both examples below.
N_SYN_ROWS = 160

# Conditional generation: one condition row per requested sample, all asking for
# target == 1. The condition steers generation but is not enforced, so other
# classes may still appear. Kept in its own variable so the benchmark below is
# not affected.
cond_df = pd.DataFrame([{"target": 1}] * N_SYN_ROWS)
syn_cond_df = wctgan.gen(cond_df=cond_df)

# Unconditional generation: only the row count is given. This is the frame that
# is benchmarked against the full real dataset in the next section.
syn_df = wctgan.gen(N_SYN_ROWS)

# Rich display of a small preview instead of printing the whole frame.
# `display` is available in IPython kernels; assumes gen() returns DataFrames.
display(syn_df.head(), syn_cond_df.head())

## Benchmark the Generated Data
There are various ways to benchmark generated data. The `Benchmark` class implements several methods you can use to validate it.
Note, however, that GANs often struggle to accurately replicate even simple data distributions (**TODO**: add a source for this claim).

In [None]:
benchmark = Benchmark()

# Compare the real iris frame with the unconditional synthetic frame.
# plot=True presumably also renders a figure.
mean_rfc = benchmark.mean_rfc(df, syn_df, plot=True)

# Show the returned score as the cell output; otherwise it is computed but never visible.
mean_rfc