In [None]:
import os
import time
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from itertools import combinations

# 1) Load the full population dataset
x_population = pd.read_csv("data/h_population.csv")

# 2) Encode each categorical column as integer labels.
#    The fitted encoders are kept so the sample can be mapped with the *same* codes.
label_encoders = {}
for col in x_population.columns:
    le = LabelEncoder()
    x_population[col] = le.fit_transform(x_population[col])
    label_encoders[col] = le

# 3) Convert columns to pandas 'category' dtype so every column is treated as a
#    discrete attribute downstream (OneHotEncoder itself also accepts the int codes).
x_population = x_population.apply(lambda col: col.astype("category"))

# 4) Create and fit the OneHotEncoder on the *population*, so the one-hot layout
#    (and enc.categories_) covers every category, including ones absent from the sample.
#    Note: `sparse_output` requires scikit-learn >= 1.2 (older versions use `sparse`).
enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
enc.fit(x_population)

# 5) Load the actual sample data.
# # Example of splitting 5% training data from the population instead:
# x_train_split = train_test_split(
#     x_population, test_size=0.05, shuffle=True, random_state=1004
# )[1]
x_train_split = pd.read_csv("data/h_sample.csv")
# Keep exactly the population's columns, in the population's order: OneHotEncoder
# checks feature names against those seen in fit, and a missing column now fails
# here with a KeyError naming it. Extra columns in the sample are ignored.
x_train_split = x_train_split[list(x_population.columns)].copy()

# 6) Apply the same LabelEncoder mappings to the sample data.
#    LabelEncoder.transform raises ValueError for a value never seen in the population.
for col in x_train_split.columns:
    x_train_split[col] = label_encoders[col].transform(x_train_split[col])

# 7) Match the population's 'category' dtype (kept for symmetry with step 3).
x_train_split = x_train_split.apply(lambda col: col.astype("category"))

# 8) One-hot encode the sample with the population-fitted encoder.
x_train_encoded = enc.transform(x_train_split)

# 9) Convert the encoded array into a PyTorch tensor
x_train_tensor = torch.tensor(x_train_encoded, dtype=torch.float32)

print("Training data shape:", x_train_tensor.shape)


Training data shape: torch.Size([53315, 70])


In [2]:

##########################
# 2. MODEL DEFINITIONS
##########################
class Generator(nn.Module):
    """Map noise vectors to concatenated per-column categorical probabilities.

    Args:
        latent_dim: size of the noise vector passed to ``forward``.
        hidden_dims: widths of the hidden Linear -> ReLU -> BatchNorm1d stack.
            May be empty, in which case the heads read the noise directly.
        cat_dims: number of categories of each column; one softmax head per entry.
    """

    def __init__(self, latent_dim, hidden_dims, cat_dims):
        super().__init__()
        layers = []
        in_dim = latent_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(in_dim, dim))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(dim))
            in_dim = dim
        # An empty nn.Sequential returns its input unchanged, so hidden_dims=[] works.
        self.hidden = nn.Sequential(*layers)

        # One output "head" per categorical column. `in_dim` now holds the width of
        # the last hidden layer (or latent_dim when there are none); indexing
        # hidden_dims[-1] would raise IndexError for an empty list.
        self.heads = nn.ModuleList([nn.Linear(in_dim, cd) for cd in cat_dims])

    def forward(self, noise):
        """Return a (batch, sum(cat_dims)) tensor; each column's slice is a softmax summing to 1.

        Args:
            noise: (batch, latent_dim) tensor.
        """
        h = self.hidden(noise)
        return torch.cat([F.softmax(head(h), dim=1) for head in self.heads], dim=1)

class Critic(nn.Module):
    """WGAN critic: an MLP with LeakyReLU(0.2) hidden layers and one raw score per row.

    Args:
        input_dim: width of each input row (the total one-hot width).
        hidden_dims: widths of the hidden Linear -> LeakyReLU layers; may be empty.
    """

    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        layers = []
        in_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(in_dim, dim))
            # No batch normalisation: the gradient penalty is computed per sample.
            layers.append(nn.LeakyReLU(0.2))
            in_dim = dim
        # Unbounded output (no sigmoid), as required for a Wasserstein critic.
        # `in_dim` is the last hidden width, or input_dim when hidden_dims is empty.
        layers.append(nn.Linear(in_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return critic scores of shape (batch, 1)."""
        return self.model(x)

def gradient_penalty(critic, real_data, fake_data, device):
    """WGAN-GP penalty: batch mean of (||grad_x critic(x_hat)||_2 - 1)^2.

    ``x_hat`` is a per-sample random interpolation between real and fake rows.
    Inputs may have any rank >= 1 with the batch on dim 0; each sample's
    gradient is flattened before its L2 norm is taken.

    Args:
        critic: module mapping a batch to scores.
        real_data: tensor of shape (batch, ...).
        fake_data: tensor with the same shape as ``real_data``.
        device: device on which the interpolation coefficients are drawn.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it is
        differentiable w.r.t. the critic's parameters and can be added to the
        critic loss.
    """
    batch_size = real_data.size(0)
    # One coefficient per sample, shaped (batch, 1, 1, ...) so it broadcasts over
    # every non-batch dim. (expand_as on a (batch, 1) tensor lines the batch axis
    # up correctly only for 2-D inputs.)
    eps_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device, dtype=real_data.dtype)

    # Detached inputs make the interpolate a fresh leaf: we differentiate w.r.t. it,
    # and the penalty never back-propagates into whatever produced fake_data.
    interpolates = (
        epsilon * real_data.detach() + (1 - epsilon) * fake_data.detach()
    ).requires_grad_(True)

    critic_interpolates = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=critic_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(critic_interpolates),
        create_graph=True,   # keep the graph so the penalty itself has gradients
        retain_graph=True,
    )[0]

    # reshape rather than view: the gradient tensor is not guaranteed contiguous.
    gradient_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()



In [3]:
##########################
# 3. Utils
##########################

def wide_to_long(samples_pop, cat_boundaries, enc):
    """
    Convert the wide probability output to one categorical value per column.

    Column i of the result is drawn from the (renormalised) weights in
    ``samples_pop[:, cat_boundaries[i]:cat_boundaries[i+1]]`` and mapped to
    ``enc.categories_[i]``. A block with no positive mass (sum <= 0 or NaN)
    is sampled uniformly.

    Args:
        samples_pop: 2-D array (n_rows, n_onehot) of non-negative weights,
            e.g. the generator's softmax output. It is not modified.
        cat_boundaries: offsets such that column i occupies
            [cat_boundaries[i], cat_boundaries[i+1]); length n_columns + 1.
        enc: fitted encoder exposing ``categories_`` and ``feature_names_in_``.

    Returns:
        DataFrame with one 'category'-dtype column per name in
        ``enc.feature_names_in_``.

    Draws use NumPy's global RNG, so ``np.random.seed`` makes them reproducible.
    """
    samples = np.asarray(samples_pop, dtype=np.float64)
    n_rows = samples.shape[0]
    sampled = {}
    for i in range(len(cat_boundaries) - 1):
        start, end = cat_boundaries[i], cat_boundaries[i + 1]
        block = samples[:, start:end]
        totals = block.sum(axis=1, keepdims=True)
        # np.where returns a new array, so the caller's samples_pop is never
        # written to (the old per-row `probs /= sum_p` normalised it in place).
        # Rows without usable mass fall back to a uniform distribution.
        weights = np.where(totals > 0, block, 1.0)
        # Vectorised inverse-CDF sampling: for u ~ U[0, 1) the chosen index is
        # the number of CDF entries <= u. Dividing by the last entry makes it
        # exactly 1.0, so the index stays < n_categories and categories with
        # zero weight are never selected.
        cdf = np.cumsum(weights, axis=1)
        cdf /= cdf[:, -1:]
        u = np.random.random_sample(n_rows)
        chosen_idx = (cdf <= u[:, None]).sum(axis=1)
        sampled[enc.feature_names_in_[i]] = np.asarray(enc.categories_[i])[chosen_idx]

    df = pd.DataFrame(sampled, columns=enc.feature_names_in_)
    return df.apply(lambda x: x.astype('category'))

def SRMSE(x_population: pd.DataFrame, resamples: pd.DataFrame):
    """Compute the marginal and bivariate standardised RMSE (SRMSE).

    For each column (marginal) and each unordered column pair (bivariate), the
    value / value-pair counts of both frames are outer-aligned, cells missing on
    one side count as 0, and each side is divided by its own non-null row count.
    All shares of one kind are pooled and
    SRMSE = RMSE(population shares, synthetic shares) / mean(population shares).
    Pairs where either frame has no complete rows are skipped.

    Args:
        x_population: reference frame.
        resamples: synthetic frame; must contain every column of x_population.

    Returns:
        (srmse_marginal, srmse_bivariate); each is NaN when there is nothing
        to compare or the mean population share is zero.
    """
    def aligned_shares(syn_counts, pop_counts, n_syn, n_pop):
        # Outer join so a cell seen on only one side is compared against 0.
        joined = pd.merge(syn_counts.rename('syn'), pop_counts.rename('pop'),
                          left_index=True, right_index=True,
                          how='outer').fillna(0)
        return (joined['pop'].values / (n_pop or 1),
                joined['syn'].values / (n_syn or 1))

    def pooled_srmse(pop_parts, syn_parts):
        if not pop_parts:
            return np.nan
        reference = np.concatenate(pop_parts)
        candidate = np.concatenate(syn_parts)
        if not reference.size:
            return np.nan
        rmse = np.linalg.norm(reference - candidate) / np.sqrt(len(reference))
        reference_mean = reference.mean()
        return rmse / reference_mean if reference_mean else np.nan

    # ── Marginal ───────────────────────────────────────────
    pop_marginal, syn_marginal = [], []
    for column in x_population.columns:
        pop_values = x_population[column].dropna()
        syn_values = resamples[column].dropna()
        pop_share, syn_share = aligned_shares(syn_values.value_counts(),
                                              pop_values.value_counts(),
                                              len(syn_values), len(pop_values))
        pop_marginal.append(pop_share)
        syn_marginal.append(syn_share)

    # ── Bivariate ──────────────────────────────────────────
    pop_pairs, syn_pairs = [], []
    for first, second in combinations(x_population.columns, 2):
        pop_pair = x_population[[first, second]].dropna()
        syn_pair = resamples[[first, second]].dropna()
        if pop_pair.empty or syn_pair.empty:
            continue  # nothing to compare for this pair
        pop_share, syn_share = aligned_shares(
            pd.crosstab(syn_pair[first], syn_pair[second]).stack(),
            pd.crosstab(pop_pair[first], pop_pair[second]).stack(),
            len(syn_pair), len(pop_pair))
        pop_pairs.append(pop_share)
        syn_pairs.append(syn_share)

    return pooled_srmse(pop_marginal, syn_marginal), pooled_srmse(pop_pairs, syn_pairs)


def calculate_precision_recall(population_df: pd.DataFrame,
                               generated_df: pd.DataFrame):
    """Row-level precision, recall and F1 between generated and population records.

    Rows containing any NaN are dropped; every remaining row is treated as a tuple.
      precision = share of generated rows whose tuple occurs in the population
      recall    = share of population rows whose tuple occurs in the generated data
      f1        = harmonic mean of the two (0 when both are 0)

    Returns:
        dict with 'precision', 'recall', 'f1_score' (rounded to 4 decimals),
        'unique_combinations' ({'population', 'generated'} distinct tuple counts)
        and 'matching_combinations' ({'unique_types': distinct tuples in both,
        'total_count': generated rows whose tuple is in the population}).
        All values are 0 when either frame is empty after dropping NaN rows.
    """
    pop_rows = population_df.dropna()
    gen_rows = generated_df.dropna()
    if gen_rows.empty or pop_rows.empty:
        return dict(precision=0, recall=0, f1_score=0,
                    unique_combinations={'population': 0, 'generated': 0},
                    matching_combinations={'unique_types': 0, 'total_count': 0})

    pop_keys = pop_rows.apply(tuple, axis=1)
    gen_keys = gen_rows.apply(tuple, axis=1)
    pop_unique = set(pop_keys)
    gen_unique = set(gen_keys)

    gen_hit = gen_keys.isin(pop_unique)
    pop_hit = pop_keys.isin(gen_unique)
    precision = gen_hit.mean()
    recall = pop_hit.mean()
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    shared_types = pop_unique & gen_unique
    # Summing generated counts over the shared tuples is the same as counting
    # the generated rows that were found in the population.
    matched_rows = gen_hit.sum()

    summary = {}
    summary['precision'] = round(float(precision), 4)
    summary['recall'] = round(float(recall), 4)
    summary['f1_score'] = round(float(f1), 4)
    summary['unique_combinations'] = {
        'population': int(len(pop_unique)),
        'generated': int(len(gen_unique)),
    }
    summary['matching_combinations'] = {
        'unique_types': len(shared_types),
        'total_count': int(matched_rows),
    }
    return summary

def evaluate_coverage(
    x_population: pd.DataFrame,
    x_sample: pd.DataFrame,
    generator,
    cat_boundaries,
    enc,
    device,
    latent_dim: int,
    batch_size: int = 256,
):
    """
    Generate a synthetic population and score it against the real data.

    1) Draw ``len(x_population)`` synthetic rows from ``generator`` in batches of
       ``batch_size`` and decode them with ``wide_to_long``.
    2) Compute STZ/SAZ ratios, marginal & bivariate SRMSE, and Precision/Recall/F1,
       all on string-converted values aligned to the population's columns.
    3) Return a dictionary containing all metrics.

    STZ ratio: share of generated rows found in neither the sample nor the population.
    SAZ ratio: share of generated rows absent from the sample but present in the population.

    The generator is moved to ``device`` and put in eval mode while sampling; its
    previous train/eval mode is restored afterwards.
    """
    # ── 1. Synthetic data generation ───────────────────────────────────────
    generator.to(device)
    was_training = generator.training
    generator.eval()  # e.g. BatchNorm layers use running statistics while sampling
    n_gen = len(x_population)
    results, done = [], 0
    t0 = time.time()

    try:
        with torch.no_grad(), tqdm(total=n_gen, desc="Generating") as bar:
            while done < n_gen:
                cur = min(batch_size, n_gen - done)
                noise = torch.randn(cur, latent_dim, device=device)
                fake = generator(noise).cpu().numpy()
                results.append(wide_to_long(fake, cat_boundaries, enc))
                done += cur
                bar.update(cur)
    finally:
        # Restore the caller's mode so training can resume after an evaluation.
        generator.train(was_training)

    resamples = pd.concat(results, ignore_index=True)
    infer_sec = time.time() - t0

    # ── 2. Common preprocessing (string conversion & column alignment) ─────
    pop_df = x_population.astype(str).reset_index(drop=True)
    columns = pop_df.columns
    # Align the sample too, so row comparisons below are column-for-column.
    sam_df = x_sample.astype(str).reset_index(drop=True).reindex(columns=columns)
    gen_df = resamples.astype(str).reindex(columns=columns)

    # ── 3-A. STZ / SAZ computation ─────────────────────────────────────────
    # Compare whole rows as tuples. Joining the cell strings with "" let distinct
    # rows collide, e.g. ('1', '23') and ('12', '3') both became "123".
    pop_keys = set(pop_df.itertuples(index=False, name=None))
    sam_keys = set(sam_df.itertuples(index=False, name=None))
    gen_keys = list(gen_df.itertuples(index=False, name=None))
    n_stz = sum(k not in sam_keys and k not in pop_keys for k in gen_keys)
    n_saz = sum(k not in sam_keys and k in pop_keys for k in gen_keys)
    stz_ratio = round(n_stz / len(gen_df), 4)
    saz_ratio = round(n_saz / len(gen_df), 4)

    # ── 3-B. SRMSE computation ─────────────────────────────────────────────
    srmse_mar, srmse_bi = SRMSE(pop_df, gen_df)

    # ── 3-C. Precision, Recall, F1 computation ─────────────────────────────
    pr = calculate_precision_recall(pop_df, gen_df)

    # ── 4. Packaging results ───────────────────────────────────────────────
    metrics = {
        "stz_ratio": stz_ratio,
        "saz_ratio": saz_ratio,
        "srmse_marginal": srmse_mar,
        "srmse_bivariate": srmse_bi,
        "precision": pr["precision"],
        "recall": pr["recall"],
        "f1_score": pr["f1_score"],
        "unique_combinations": pr["unique_combinations"],
        "matching_combinations": pr["matching_combinations"],
        "infer_time_sec": round(infer_sec, 2),
    }
    return metrics

In [4]:
# Reproducibility: seed every RNG this run draws from.
SEED = 1004  # same value as random_state in the commented-out split example of the data cell
torch.manual_seed(SEED)
np.random.seed(SEED)  # wide_to_long samples categories with NumPy's global RNG

device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-column category counts, read from the fitted OneHotEncoder so the generator's
# heads line up exactly with the one-hot columns of x_train_tensor.
cat_dims = [len(cats) for cats in enc.categories_]
total_output_dim = sum(cat_dims)
assert total_output_dim == x_train_tensor.shape[1], (
    f"generator width {total_output_dim} != one-hot width {x_train_tensor.shape[1]}"
)

latent_dim = 256
hidden_dims = [256, 256, 256]
batch_size = 256
n_critic = 3      # critic updates per generator update
lambda_gp = 10    # gradient-penalty weight
epochs = 250

generator = Generator(latent_dim, hidden_dims, cat_dims).to(device)
critic    = Critic(total_output_dim, hidden_dims).to(device)

optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.9))
optimizer_C = optim.Adam(critic.parameters(),   lr=0.0002, betas=(0.5, 0.9))

# Prepare a DataLoader to iterate over the entire training set each epoch
train_dataset = TensorDataset(x_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

os.makedirs("Vanilla-WGAN", exist_ok=True)
best_g_loss = float("inf")
save_interval = 50

train_start_time = time.time()

for epoch in tqdm(range(epochs), desc="Training Progress"):
    # Re-enter training mode every epoch so an earlier eval() call (e.g. from an
    # evaluation helper) cannot leave the generator's BatchNorm layers frozen.
    generator.train()
    critic.train()

    epoch_loss_G = 0.0
    epoch_loss_C = 0.0
    batch_count = 0

    for real_data_tuple in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
        real_data = real_data_tuple[0].to(device)
        batch_count += 1

        # (A) Critic updates (n_critic times; the same real batch is reused)
        for _ in range(n_critic):
            noise = torch.randn(real_data.size(0), latent_dim, device=device)
            with torch.no_grad():
                fake_data = generator(noise)

            optimizer_C.zero_grad()
            real_score = critic(real_data).mean()
            fake_score = critic(fake_data).mean()
            gp = gradient_penalty(critic, real_data, fake_data, device)
            d_loss = -real_score + fake_score + lambda_gp * gp
            d_loss.backward()
            optimizer_C.step()
            epoch_loss_C += d_loss.item()

        # (B) Generator update (1 step)
        noise = torch.randn(real_data.size(0), latent_dim, device=device)
        fake_data = generator(noise)
        optimizer_G.zero_grad()
        g_loss = -critic(fake_data).mean()
        g_loss.backward()
        optimizer_G.step()
        epoch_loss_G += g_loss.item()

    # Average losses across mini-batches (guarded against an empty loader)
    avg_loss_G = epoch_loss_G / max(batch_count, 1)
    avg_loss_C = epoch_loss_C / max(batch_count * n_critic, 1)

    if (epoch + 1) % 10 == 0:
        tqdm.write(f"Epoch [{epoch+1}/{epochs}] | D Loss: {avg_loss_C:.4f} | G Loss: {avg_loss_G:.4f}")

    # Check if generator improved => save "best"
    # NOTE(review): the generator loss is -E[critic(fake)], whose scale drifts as the
    # critic changes, so the lowest G loss is not a reliable quality signal. Consider
    # selecting checkpoints with evaluate_coverage metrics instead.
    if avg_loss_G < best_g_loss:
        best_g_loss = avg_loss_G
        torch.save(generator.state_dict(), "Vanilla-WGAN/generator_best.pth")
        torch.save(critic.state_dict(),    "Vanilla-WGAN/critic_best.pth")
        tqdm.write(f"  [Epoch {epoch+1}] Best model saved (G loss = {avg_loss_G:.4f})")

    # Periodic checkpoint
    if (epoch + 1) % save_interval == 0:
        torch.save(generator.state_dict(), f"Vanilla-WGAN/generator_epoch_{epoch+1}.pth")
        torch.save(critic.state_dict(),    f"Vanilla-WGAN/critic_epoch_{epoch+1}.pth")
        tqdm.write(f"  [Epoch {epoch+1}] Checkpoint saved.")

train_end_time = time.time()
train_duration = train_end_time - train_start_time
print("\nTraining complete.")
print(f"Total training time: {train_duration:.2f} seconds")


Training Progress:   0%|                                                                       | 0/250 [00:00<?, ?it/s]
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

[Ach 1/250:   0%|▎                                                                    | 1/209 [00:00<00:57,  3.64it/s]
[Ach 1/250:   2%|█▋                                                                   | 5/209 [00:00<00:14, 14.28it/s]
[Ach 1/250:   4%|██▉                                                                  | 9/209 [00:00<00:09, 20.09it/s]
[Ach 1/250:   6%|████▏                                                               | 13/209 [00:00<00:07, 25.33it/s]
[Ach 1/250:   8%|█████▏                                                              | 16/209 [00:00<00:07, 26.07it/s]
[Ach 1/250:  10%|██████▌                                                             | 20/209 [00:00<00:06, 28.17it/s]
[Ach 1/250:  11%|███████▊                             

  [Epoch 1] Best model saved (G loss = -0.6804)



[Ach 2/250:   0%|                                                                             | 0/209 [00:00<?, ?it/s]
[Ach 2/250:   1%|▉                                                                    | 3/209 [00:00<00:07, 28.30it/s]
[Ach 2/250:   3%|█▉                                                                   | 6/209 [00:00<00:07, 26.36it/s]
[Ach 2/250:   4%|██▉                                                                  | 9/209 [00:00<00:07, 27.45it/s]
[Ach 2/250:   6%|███▉                                                                | 12/209 [00:00<00:06, 28.18it/s]
[Ach 2/250:   7%|████▉                                                               | 15/209 [00:00<00:07, 26.85it/s]
[Ach 2/250:   9%|█████▊                                                              | 18/209 [00:00<00:07, 26.79it/s]
[Ach 2/250:  10%|██████▊                                                             | 21/209 [00:00<00:07, 26.57it/s]
[Ach 2/250:  11%|███████▊             

  [Epoch 2] Best model saved (G loss = -0.8964)



[Ach 3/250:   0%|                                                                             | 0/209 [00:00<?, ?it/s]
[Ach 3/250:   1%|▉                                                                    | 3/209 [00:00<00:07, 27.05it/s]
[Ach 3/250:   3%|█▉                                                                   | 6/209 [00:00<00:07, 27.17it/s]
[Ach 3/250:   4%|██▉                                                                  | 9/209 [00:00<00:07, 25.35it/s]
[Ach 3/250:   6%|███▉                                                                | 12/209 [00:00<00:07, 26.22it/s]
[Ach 3/250:   7%|████▉                                                               | 15/209 [00:00<00:07, 27.03it/s]
[Ach 3/250:   9%|█████▊                                                              | 18/209 [00:00<00:06, 27.53it/s]
[Ach 3/250:  10%|██████▊                                                             | 21/209 [00:00<00:07, 25.66it/s]
[Ach 3/250:  11%|███████▊             

  [Epoch 3] Best model saved (G loss = -0.9465)



[Ach 4/250:   0%|                                                                             | 0/209 [00:00<?, ?it/s]
[Ach 4/250:   1%|▉                                                                    | 3/209 [00:00<00:08, 25.53it/s]
[Ach 4/250:   3%|█▉                                                                   | 6/209 [00:00<00:07, 25.83it/s]
[Ach 4/250:   4%|██▉                                                                  | 9/209 [00:00<00:07, 25.96it/s]
[Ach 4/250:   6%|███▉                                                                | 12/209 [00:00<00:07, 25.09it/s]
[Ach 4/250:   8%|█████▏                                                              | 16/209 [00:00<00:07, 26.10it/s]
[Ach 4/250:   9%|██████▏                                                             | 19/209 [00:00<00:07, 26.20it/s]
[Ach 4/250:  11%|███████▍                                                            | 23/209 [00:00<00:06, 27.51it/s]
[Ach 4/250:  13%|████████▊            

  [Epoch 9] Best model saved (G loss = -1.0079)



[Ach 10/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 10/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 24.92it/s]
[Ach 10/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 25.04it/s]
[Ach 10/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 27.16it/s]
[Ach 10/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 28.10it/s]
[Ach 10/250:   8%|█████▏                                                             | 16/209 [00:00<00:06, 28.22it/s]
[Ach 10/250:   9%|██████                                                             | 19/209 [00:00<00:06, 28.03it/s]
[Ach 10/250:  11%|███████                                                            | 22/209 [00:00<00:06, 27.19it/s]
[Ach 10/250:  12%|████████            

Epoch [10/250] | D Loss: -0.3709 | G Loss: -1.0342
  [Epoch 10] Best model saved (G loss = -1.0342)



[Ach 11/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 11/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 28.79it/s]
[Ach 11/250:   3%|██▎                                                                 | 7/209 [00:00<00:06, 30.88it/s]
[Ach 11/250:   5%|███▌                                                               | 11/209 [00:00<00:06, 30.54it/s]
[Ach 11/250:   7%|████▊                                                              | 15/209 [00:00<00:06, 28.51it/s]
[Ach 11/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.96it/s]
[Ach 11/250:  11%|███████                                                            | 22/209 [00:00<00:06, 28.52it/s]
[Ach 11/250:  12%|████████                                                           | 25/209 [00:00<00:06, 28.49it/s]
[Ach 11/250:  13%|████████▉           

  [Epoch 11] Best model saved (G loss = -1.0599)



[Ach 12/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 12/250:   1%|▉                                                                   | 3/209 [00:00<00:06, 29.92it/s]
[Ach 12/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 28.97it/s]
[Ach 12/250:   5%|███▏                                                               | 10/209 [00:00<00:06, 29.78it/s]
[Ach 12/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 29.33it/s]
[Ach 12/250:   8%|█████▏                                                             | 16/209 [00:00<00:06, 28.22it/s]
[Ach 12/250:   9%|██████                                                             | 19/209 [00:00<00:06, 27.50it/s]
[Ach 12/250:  11%|███████                                                            | 22/209 [00:00<00:06, 27.35it/s]
[Ach 12/250:  12%|████████            

  [Epoch 13] Best model saved (G loss = -1.1068)



[Ach 14/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 14/250:   1%|▉                                                                   | 3/209 [00:00<00:09, 20.81it/s]
[Ach 14/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 25.17it/s]
[Ach 14/250:   4%|██▉                                                                 | 9/209 [00:00<00:08, 24.90it/s]
[Ach 14/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.17it/s]
[Ach 14/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.05it/s]
[Ach 14/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.60it/s]
[Ach 14/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 27.35it/s]
[Ach 14/250:  11%|███████▋            

  [Epoch 16] Best model saved (G loss = -1.1256)



[Ach 17/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 17/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 23.00it/s]
[Ach 17/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 24.83it/s]
[Ach 17/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 28.17it/s]
[Ach 17/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 28.10it/s]
[Ach 17/250:   8%|█████▏                                                             | 16/209 [00:00<00:06, 28.67it/s]
[Ach 17/250:   9%|██████                                                             | 19/209 [00:00<00:06, 27.81it/s]
[Ach 17/250:  11%|███████                                                            | 22/209 [00:00<00:06, 27.56it/s]
[Ach 17/250:  12%|████████            

  [Epoch 17] Best model saved (G loss = -1.1474)



[Ach 18/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 18/250:   2%|█▎                                                                  | 4/209 [00:00<00:06, 31.12it/s]
[Ach 18/250:   4%|██▌                                                                 | 8/209 [00:00<00:06, 29.31it/s]
[Ach 18/250:   5%|███▌                                                               | 11/209 [00:00<00:06, 29.53it/s]
[Ach 18/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 27.69it/s]
[Ach 18/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.48it/s]
[Ach 18/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 27.45it/s]
[Ach 18/250:  11%|███████▋                                                           | 24/209 [00:00<00:06, 27.04it/s]
[Ach 18/250:  13%|████████▋           

  [Epoch 19] Best model saved (G loss = -1.1724)



[Ach 20/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 20/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 23.35it/s]
[Ach 20/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.52it/s]
[Ach 20/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 25.71it/s]
[Ach 20/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 25.83it/s]
[Ach 20/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.95it/s]
[Ach 20/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 25.14it/s]
[Ach 20/250:  11%|███████                                                            | 22/209 [00:00<00:06, 26.96it/s]
[Ach 20/250:  12%|████████            

Epoch [20/250] | D Loss: -0.2704 | G Loss: -1.1826
  [Epoch 20] Best model saved (G loss = -1.1826)



[Ach 21/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 21/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 28.18it/s]
[Ach 21/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 28.38it/s]
[Ach 21/250:   4%|██▉                                                                 | 9/209 [00:00<00:06, 28.97it/s]
[Ach 21/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.49it/s]
[Ach 21/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.68it/s]
[Ach 21/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 25.98it/s]
[Ach 21/250:  10%|██████▋                                                            | 21/209 [00:00<00:07, 25.99it/s]
[Ach 21/250:  11%|███████▋            

  [Epoch 21] Best model saved (G loss = -1.1830)



[Ach 22/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 22/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 27.20it/s]
[Ach 22/250:   3%|██▎                                                                 | 7/209 [00:00<00:07, 27.08it/s]
[Ach 22/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 27.44it/s]
[Ach 22/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 26.91it/s]
[Ach 22/250:   8%|█████▍                                                             | 17/209 [00:00<00:06, 27.86it/s]
[Ach 22/250:  10%|██████▍                                                            | 20/209 [00:00<00:06, 27.90it/s]
[Ach 22/250:  11%|███████▎                                                           | 23/209 [00:00<00:07, 25.90it/s]
[Ach 22/250:  13%|████████▋           

  [Epoch 29] Best model saved (G loss = -1.1852)



[Ach 30/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 30/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 26.27it/s]
[Ach 30/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.76it/s]
[Ach 30/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 28.24it/s]
[Ach 30/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 29.28it/s]
[Ach 30/250:   8%|█████▍                                                             | 17/209 [00:00<00:06, 29.81it/s]
[Ach 30/250:  10%|██████▍                                                            | 20/209 [00:00<00:06, 28.46it/s]
[Ach 30/250:  11%|███████▎                                                           | 23/209 [00:00<00:06, 26.82it/s]
[Ach 30/250:  12%|████████▎           

Epoch [30/250] | D Loss: -0.2347 | G Loss: -1.1738



[Ach 31/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 31/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 25.99it/s]
[Ach 31/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 24.63it/s]
[Ach 31/250:   4%|██▉                                                                 | 9/209 [00:00<00:08, 23.78it/s]
[Ach 31/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 25.77it/s]
[Ach 31/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 27.08it/s]
[Ach 31/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.36it/s]
[Ach 31/250:  10%|██████▋                                                            | 21/209 [00:00<00:07, 25.28it/s]
[Ach 31/250:  11%|███████▋            

  [Epoch 31] Best model saved (G loss = -1.1854)



[Ach 32/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 32/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 25.12it/s]
[Ach 32/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 23.51it/s]
[Ach 32/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 25.08it/s]
[Ach 32/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 25.36it/s]
[Ach 32/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 27.33it/s]
[Ach 32/250:  10%|██████▍                                                            | 20/209 [00:00<00:06, 28.09it/s]
[Ach 32/250:  11%|███████▎                                                           | 23/209 [00:00<00:06, 27.14it/s]
[Ach 32/250:  12%|████████▎           

  [Epoch 32] Best model saved (G loss = -1.1884)



[Ach 33/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 33/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 26.61it/s]
[Ach 33/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.88it/s]
[Ach 33/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.62it/s]
[Ach 33/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 25.38it/s]
[Ach 33/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 25.98it/s]
[Ach 33/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 24.40it/s]
[Ach 33/250:  11%|███████                                                            | 22/209 [00:00<00:06, 27.09it/s]
[Ach 33/250:  12%|████████▎           

  [Epoch 34] Best model saved (G loss = -1.1987)



[Ach 35/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 35/250:   2%|█▎                                                                  | 4/209 [00:00<00:07, 29.15it/s]
[Ach 35/250:   3%|██▎                                                                 | 7/209 [00:00<00:06, 29.57it/s]
[Ach 35/250:   5%|███▌                                                               | 11/209 [00:00<00:06, 28.96it/s]
[Ach 35/250:   7%|████▍                                                              | 14/209 [00:00<00:07, 27.06it/s]
[Ach 35/250:   9%|█████▊                                                             | 18/209 [00:00<00:06, 28.56it/s]
[Ach 35/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 28.12it/s]
[Ach 35/250:  11%|███████▋                                                           | 24/209 [00:00<00:06, 26.44it/s]
[Ach 35/250:  13%|████████▋           

  [Epoch 37] Best model saved (G loss = -1.2007)



[Ach 38/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 38/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 25.37it/s]
[Ach 38/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.58it/s]
[Ach 38/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 27.58it/s]
[Ach 38/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 26.83it/s]
[Ach 38/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 25.55it/s]
[Ach 38/250:   9%|██████                                                             | 19/209 [00:00<00:07, 25.99it/s]
[Ach 38/250:  11%|███████                                                            | 22/209 [00:00<00:07, 25.46it/s]
[Ach 38/250:  12%|████████▎           

Epoch [40/250] | D Loss: -0.2061 | G Loss: -1.1771



[Ach 41/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 41/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 23.03it/s]
[Ach 41/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 25.82it/s]
[Ach 41/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 25.62it/s]
[Ach 41/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.59it/s]
[Ach 41/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 27.45it/s]
[Ach 41/250:   9%|█████▊                                                             | 18/209 [00:00<00:06, 27.95it/s]
[Ach 41/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 28.31it/s]
[Ach 41/250:  11%|███████▋            

  [Epoch 41] Best model saved (G loss = -1.2402)



[Ach 42/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 42/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 27.70it/s]
[Ach 42/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.19it/s]
[Ach 42/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 27.84it/s]
[Ach 42/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 27.79it/s]
[Ach 42/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 25.80it/s]
[Ach 42/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 25.71it/s]
[Ach 42/250:  10%|██████▋                                                            | 21/209 [00:00<00:07, 25.47it/s]
[Ach 42/250:  11%|███████▋            

  [Epoch 49] Best model saved (G loss = -1.2522)



[Ach 50/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 50/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 28.21it/s]
[Ach 50/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.64it/s]
[Ach 50/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.36it/s]
[Ach 50/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.88it/s]
[Ach 50/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.47it/s]
[Ach 50/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.33it/s]
[Ach 50/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 27.22it/s]
[Ach 50/250:  11%|███████▋            

Epoch [50/250] | D Loss: -0.1896 | G Loss: -1.2156
  [Epoch 50] Checkpoint saved.



[Ach 51/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 51/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 23.22it/s]
[Ach 51/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 24.43it/s]
[Ach 51/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 25.59it/s]
[Ach 51/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.38it/s]
[Ach 51/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.75it/s]
[Ach 51/250:   9%|██████                                                             | 19/209 [00:00<00:06, 28.12it/s]
[Ach 51/250:  11%|███████                                                            | 22/209 [00:00<00:06, 28.35it/s]
[Ach 51/250:  12%|████████            

  [Epoch 53] Best model saved (G loss = -1.2734)



[Ach 54/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 54/250:   1%|▉                                                                   | 3/209 [00:00<00:09, 22.26it/s]
[Ach 54/250:   3%|██▎                                                                 | 7/209 [00:00<00:07, 27.80it/s]
[Ach 54/250:   5%|███▌                                                               | 11/209 [00:00<00:07, 28.25it/s]
[Ach 54/250:   7%|████▍                                                              | 14/209 [00:00<00:07, 27.65it/s]
[Ach 54/250:   8%|█████▍                                                             | 17/209 [00:00<00:06, 27.80it/s]
[Ach 54/250:  10%|██████▍                                                            | 20/209 [00:00<00:06, 27.27it/s]
[Ach 54/250:  11%|███████▎                                                           | 23/209 [00:00<00:06, 26.75it/s]
[Ach 54/250:  12%|████████▎           

Epoch [60/250] | D Loss: -0.1811 | G Loss: -1.2724



[Ach 61/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 61/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 24.38it/s]
[Ach 61/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.25it/s]
[Ach 61/250:   4%|██▉                                                                 | 9/209 [00:00<00:08, 24.29it/s]
[Ach 61/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 24.91it/s]
[Ach 61/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 25.12it/s]
[Ach 61/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 24.96it/s]
[Ach 61/250:  11%|███████                                                            | 22/209 [00:00<00:06, 26.97it/s]
[Ach 61/250:  12%|████████            

  [Epoch 63] Best model saved (G loss = -1.2840)



[Ach 64/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 64/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 25.18it/s]
[Ach 64/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.60it/s]
[Ach 64/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.96it/s]
[Ach 64/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 27.47it/s]
[Ach 64/250:   8%|█████▏                                                             | 16/209 [00:00<00:06, 28.74it/s]
[Ach 64/250:   9%|██████                                                             | 19/209 [00:00<00:06, 28.63it/s]
[Ach 64/250:  11%|███████                                                            | 22/209 [00:00<00:06, 28.85it/s]
[Ach 64/250:  12%|████████            

  [Epoch 64] Best model saved (G loss = -1.2985)



[Ach 65/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 65/250:   2%|█▎                                                                  | 4/209 [00:00<00:06, 29.46it/s]
[Ach 65/250:   3%|██▎                                                                 | 7/209 [00:00<00:07, 26.96it/s]
[Ach 65/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 25.56it/s]
[Ach 65/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 26.08it/s]
[Ach 65/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 25.50it/s]
[Ach 65/250:   9%|██████                                                             | 19/209 [00:00<00:07, 25.74it/s]
[Ach 65/250:  11%|███████                                                            | 22/209 [00:00<00:07, 26.54it/s]
[Ach 65/250:  12%|████████            

Epoch [70/250] | D Loss: -0.1733 | G Loss: -1.3024
  [Epoch 70] Best model saved (G loss = -1.3024)



[Ach 71/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 71/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 27.40it/s]
[Ach 71/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.01it/s]
[Ach 71/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 27.36it/s]
[Ach 71/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 27.53it/s]
[Ach 71/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 26.33it/s]
[Ach 71/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.36it/s]
[Ach 71/250:  10%|██████▋                                                            | 21/209 [00:00<00:07, 26.28it/s]
[Ach 71/250:  11%|███████▋            

  [Epoch 71] Best model saved (G loss = -1.3037)



[Ach 72/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 72/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 23.57it/s]
[Ach 72/250:   3%|██▎                                                                 | 7/209 [00:00<00:07, 27.20it/s]
[Ach 72/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 25.93it/s]
[Ach 72/250:   7%|████▍                                                              | 14/209 [00:00<00:06, 28.27it/s]
[Ach 72/250:   9%|█████▊                                                             | 18/209 [00:00<00:06, 28.03it/s]
[Ach 72/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 28.35it/s]
[Ach 72/250:  11%|███████▋                                                           | 24/209 [00:00<00:06, 28.30it/s]
[Ach 72/250:  13%|████████▋           

  [Epoch 73] Best model saved (G loss = -1.3500)



[Ach 74/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 74/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 24.17it/s]
[Ach 74/250:   3%|██▎                                                                 | 7/209 [00:00<00:07, 27.91it/s]
[Ach 74/250:   5%|███▏                                                               | 10/209 [00:00<00:06, 28.59it/s]
[Ach 74/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 27.39it/s]
[Ach 74/250:   8%|█████▍                                                             | 17/209 [00:00<00:07, 27.16it/s]
[Ach 74/250:  10%|██████▍                                                            | 20/209 [00:00<00:06, 27.40it/s]
[Ach 74/250:  11%|███████▎                                                           | 23/209 [00:00<00:06, 27.78it/s]
[Ach 74/250:  12%|████████▎           

  [Epoch 75] Best model saved (G loss = -1.3949)



[Ach 76/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 76/250:   1%|▉                                                                   | 3/209 [00:00<00:08, 25.16it/s]
[Ach 76/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.12it/s]
[Ach 76/250:   5%|███▏                                                               | 10/209 [00:00<00:07, 27.97it/s]
[Ach 76/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 27.25it/s]
[Ach 76/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 26.22it/s]
[Ach 76/250:   9%|██████                                                             | 19/209 [00:00<00:07, 25.55it/s]
[Ach 76/250:  11%|███████                                                            | 22/209 [00:00<00:07, 26.36it/s]
[Ach 76/250:  12%|████████            

Epoch [80/250] | D Loss: -0.1490 | G Loss: -1.3836



[Ach 81/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 81/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 26.26it/s]
[Ach 81/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.21it/s]
[Ach 81/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.78it/s]
[Ach 81/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 26.82it/s]
[Ach 81/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 27.58it/s]
[Ach 81/250:   9%|█████▊                                                             | 18/209 [00:00<00:06, 27.90it/s]
[Ach 81/250:  11%|███████                                                            | 22/209 [00:00<00:06, 28.27it/s]
[Ach 81/250:  12%|████████▎           

  [Epoch 81] Best model saved (G loss = -1.4043)



[Ach 82/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 82/250:   1%|▉                                                                   | 3/209 [00:00<00:09, 21.75it/s]
[Ach 82/250:   3%|█▉                                                                  | 6/209 [00:00<00:08, 24.71it/s]
[Ach 82/250:   4%|██▉                                                                 | 9/209 [00:00<00:08, 24.21it/s]
[Ach 82/250:   6%|███▊                                                               | 12/209 [00:00<00:07, 24.99it/s]
[Ach 82/250:   7%|████▊                                                              | 15/209 [00:00<00:07, 25.35it/s]
[Ach 82/250:   9%|█████▊                                                             | 18/209 [00:00<00:07, 26.74it/s]
[Ach 82/250:  10%|██████▋                                                            | 21/209 [00:00<00:06, 27.10it/s]
[Ach 82/250:  11%|███████▋            

  [Epoch 82] Best model saved (G loss = -1.4202)



[Ach 83/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 83/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 27.97it/s]
[Ach 83/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 25.49it/s]
[Ach 83/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.22it/s]
[Ach 83/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 28.35it/s]
[Ach 83/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 27.50it/s]
[Ach 83/250:   9%|██████                                                             | 19/209 [00:00<00:06, 27.59it/s]
[Ach 83/250:  11%|███████▎                                                           | 23/209 [00:00<00:06, 28.72it/s]
[Ach 83/250:  12%|████████▎           

  [Epoch 89] Best model saved (G loss = -1.4283)



[Ach 90/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 90/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 28.41it/s]
[Ach 90/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.08it/s]
[Ach 90/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 26.06it/s]
[Ach 90/250:   6%|████▏                                                              | 13/209 [00:00<00:07, 27.81it/s]
[Ach 90/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 27.28it/s]
[Ach 90/250:   9%|██████                                                             | 19/209 [00:00<00:06, 27.51it/s]
[Ach 90/250:  11%|███████                                                            | 22/209 [00:00<00:07, 25.92it/s]
[Ach 90/250:  12%|████████            

Epoch [90/250] | D Loss: -0.1437 | G Loss: -1.3664



[Ach 91/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 91/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 26.52it/s]
[Ach 91/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 26.29it/s]
[Ach 91/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 27.01it/s]
[Ach 91/250:   6%|████▏                                                              | 13/209 [00:00<00:06, 28.55it/s]
[Ach 91/250:   8%|█████▏                                                             | 16/209 [00:00<00:07, 26.61it/s]
[Ach 91/250:   9%|██████                                                             | 19/209 [00:00<00:07, 26.96it/s]
[Ach 91/250:  11%|███████                                                            | 22/209 [00:00<00:07, 26.09it/s]
[Ach 91/250:  12%|████████            

  [Epoch 93] Best model saved (G loss = -1.4535)



[Ach 94/250:   0%|                                                                            | 0/209 [00:00<?, ?it/s]
[Ach 94/250:   1%|▉                                                                   | 3/209 [00:00<00:07, 26.33it/s]
[Ach 94/250:   3%|█▉                                                                  | 6/209 [00:00<00:07, 27.24it/s]
[Ach 94/250:   4%|██▉                                                                 | 9/209 [00:00<00:07, 28.22it/s]
[Ach 94/250:   6%|███▊                                                               | 12/209 [00:00<00:06, 28.21it/s]
[Ach 94/250:   8%|█████▏                                                             | 16/209 [00:00<00:06, 29.98it/s]
[Ach 94/250:   9%|██████                                                             | 19/209 [00:00<00:07, 25.98it/s]
[Ach 94/250:  11%|███████                                                            | 22/209 [00:00<00:07, 26.66it/s]
[Ach 94/250:  12%|████████            

Epoch [100/250] | D Loss: -0.1401 | G Loss: -1.4083
  [Epoch 100] Checkpoint saved.



[Ach 101/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 101/250:   1%|▉                                                                  | 3/209 [00:00<00:09, 22.06it/s]
[Ach 101/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 26.12it/s]
[Ach 101/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.11it/s]
[Ach 101/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 26.47it/s]
[Ach 101/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 26.50it/s]
[Ach 101/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 26.32it/s]
[Ach 101/250:  10%|██████▋                                                           | 21/209 [00:00<00:07, 26.32it/s]
[Ach 101/250:  11%|███████▌           

Epoch [110/250] | D Loss: -0.1375 | G Loss: -1.3247



[Ach 111/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 111/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 27.34it/s]
[Ach 111/250:   3%|██▏                                                                | 7/209 [00:00<00:07, 28.24it/s]
[Ach 111/250:   5%|███▏                                                              | 10/209 [00:00<00:07, 25.72it/s]
[Ach 111/250:   6%|████                                                              | 13/209 [00:00<00:07, 25.16it/s]
[Ach 111/250:   8%|█████                                                             | 16/209 [00:00<00:07, 26.13it/s]
[Ach 111/250:  10%|██████▎                                                           | 20/209 [00:00<00:06, 27.75it/s]
[Ach 111/250:  11%|███████▎                                                          | 23/209 [00:00<00:06, 26.74it/s]
[Ach 111/250:  12%|████████▏          

Epoch [120/250] | D Loss: -0.1338 | G Loss: -1.3046



[Ach 121/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 121/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 25.86it/s]
[Ach 121/250:   3%|█▉                                                                 | 6/209 [00:00<00:08, 24.24it/s]
[Ach 121/250:   5%|███▏                                                              | 10/209 [00:00<00:07, 27.23it/s]
[Ach 121/250:   6%|████                                                              | 13/209 [00:00<00:07, 26.91it/s]
[Ach 121/250:   8%|█████                                                             | 16/209 [00:00<00:07, 26.57it/s]
[Ach 121/250:   9%|██████                                                            | 19/209 [00:00<00:07, 27.08it/s]
[Ach 121/250:  11%|██████▉                                                           | 22/209 [00:00<00:07, 26.13it/s]
[Ach 121/250:  12%|███████▉           

Epoch [130/250] | D Loss: -0.1309 | G Loss: -1.3206



[Ach 131/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 131/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 26.94it/s]
[Ach 131/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 27.91it/s]
[Ach 131/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 28.02it/s]
[Ach 131/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 28.03it/s]
[Ach 131/250:   7%|████▋                                                             | 15/209 [00:00<00:06, 28.09it/s]
[Ach 131/250:   9%|█████▋                                                            | 18/209 [00:00<00:06, 28.08it/s]
[Ach 131/250:  11%|██████▉                                                           | 22/209 [00:00<00:06, 28.87it/s]
[Ach 131/250:  12%|███████▉           

Epoch [140/250] | D Loss: -0.1292 | G Loss: -1.3726



[Ach 141/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 141/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 27.64it/s]
[Ach 141/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 25.81it/s]
[Ach 141/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 27.28it/s]
[Ach 141/250:   6%|███▊                                                              | 12/209 [00:00<00:06, 28.27it/s]
[Ach 141/250:   7%|████▋                                                             | 15/209 [00:00<00:06, 28.83it/s]
[Ach 141/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 27.12it/s]
[Ach 141/250:  11%|██████▉                                                           | 22/209 [00:00<00:06, 28.37it/s]
[Ach 141/250:  12%|████████▏          

Epoch [150/250] | D Loss: -0.1277 | G Loss: -1.3743
  [Epoch 150] Checkpoint saved.



[Ach 151/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 151/250:   1%|▉                                                                  | 3/209 [00:00<00:08, 24.80it/s]
[Ach 151/250:   3%|██▏                                                                | 7/209 [00:00<00:07, 27.67it/s]
[Ach 151/250:   5%|███▏                                                              | 10/209 [00:00<00:07, 25.09it/s]
[Ach 151/250:   6%|████                                                              | 13/209 [00:00<00:07, 26.06it/s]
[Ach 151/250:   8%|█████                                                             | 16/209 [00:00<00:07, 26.70it/s]
[Ach 151/250:   9%|██████                                                            | 19/209 [00:00<00:07, 25.85it/s]
[Ach 151/250:  11%|██████▉                                                           | 22/209 [00:00<00:07, 26.55it/s]
[Ach 151/250:  12%|████████▏          

Epoch [160/250] | D Loss: -0.1203 | G Loss: -1.4173



[Ach 161/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 161/250:   1%|▉                                                                  | 3/209 [00:00<00:08, 24.52it/s]
[Ach 161/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 26.03it/s]
[Ach 161/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 25.27it/s]
[Ach 161/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 25.65it/s]
[Ach 161/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 26.61it/s]
[Ach 161/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 26.42it/s]
[Ach 161/250:  10%|██████▋                                                           | 21/209 [00:00<00:07, 26.31it/s]
[Ach 161/250:  12%|███████▉           

Epoch [170/250] | D Loss: -0.1075 | G Loss: -1.4129



[Ach 171/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 171/250:   2%|█▎                                                                 | 4/209 [00:00<00:07, 28.63it/s]
[Ach 171/250:   4%|██▌                                                                | 8/209 [00:00<00:06, 29.62it/s]
[Ach 171/250:   5%|███▍                                                              | 11/209 [00:00<00:07, 28.26it/s]
[Ach 171/250:   7%|████▍                                                             | 14/209 [00:00<00:06, 28.46it/s]
[Ach 171/250:   8%|█████▎                                                            | 17/209 [00:00<00:06, 28.39it/s]
[Ach 171/250:  10%|██████▎                                                           | 20/209 [00:00<00:06, 27.51it/s]
[Ach 171/250:  11%|███████▌                                                          | 24/209 [00:00<00:06, 28.14it/s]
[Ach 171/250:  13%|████████▌          

  [Epoch 171] Best model saved (G loss = -1.4602)



[Ach 172/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 172/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 26.03it/s]
[Ach 172/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 26.02it/s]
[Ach 172/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.17it/s]
[Ach 172/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 26.99it/s]
[Ach 172/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 26.61it/s]
[Ach 172/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 26.51it/s]
[Ach 172/250:  10%|██████▋                                                           | 21/209 [00:00<00:07, 26.74it/s]
[Ach 172/250:  11%|███████▌           

  [Epoch 176] Best model saved (G loss = -1.4828)



[Ach 177/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 177/250:   1%|▉                                                                  | 3/209 [00:00<00:08, 25.62it/s]
[Ach 177/250:   3%|█▉                                                                 | 6/209 [00:00<00:08, 25.10it/s]
[Ach 177/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 25.29it/s]
[Ach 177/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 26.31it/s]
[Ach 177/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 25.56it/s]
[Ach 177/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 25.73it/s]
[Ach 177/250:  10%|██████▋                                                           | 21/209 [00:00<00:07, 25.72it/s]
[Ach 177/250:  11%|███████▌           

  [Epoch 177] Best model saved (G loss = -1.5503)



[Ach 178/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 178/250:   1%|▉                                                                  | 3/209 [00:00<00:08, 25.57it/s]
[Ach 178/250:   3%|██▏                                                                | 7/209 [00:00<00:07, 28.76it/s]
[Ach 178/250:   5%|███▏                                                              | 10/209 [00:00<00:07, 28.07it/s]
[Ach 178/250:   6%|████                                                              | 13/209 [00:00<00:07, 27.28it/s]
[Ach 178/250:   8%|█████▎                                                            | 17/209 [00:00<00:06, 27.72it/s]
[Ach 178/250:  10%|██████▋                                                           | 21/209 [00:00<00:06, 28.14it/s]
[Ach 178/250:  11%|███████▌                                                          | 24/209 [00:00<00:06, 28.11it/s]
[Ach 178/250:  13%|████████▌          

Epoch [180/250] | D Loss: -0.0925 | G Loss: -1.5619
  [Epoch 180] Best model saved (G loss = -1.5619)



[Ach 181/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 181/250:   2%|█▎                                                                 | 4/209 [00:00<00:07, 26.97it/s]
[Ach 181/250:   3%|██▏                                                                | 7/209 [00:00<00:07, 26.50it/s]
[Ach 181/250:   5%|███▏                                                              | 10/209 [00:00<00:07, 27.13it/s]
[Ach 181/250:   6%|████                                                              | 13/209 [00:00<00:07, 26.81it/s]
[Ach 181/250:   8%|█████                                                             | 16/209 [00:00<00:07, 27.35it/s]
[Ach 181/250:   9%|██████                                                            | 19/209 [00:00<00:06, 27.45it/s]
[Ach 181/250:  11%|██████▉                                                           | 22/209 [00:00<00:06, 27.93it/s]
[Ach 181/250:  12%|███████▉           

  [Epoch 181] Best model saved (G loss = -1.6009)



[Ach 182/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 182/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 28.28it/s]
[Ach 182/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 28.48it/s]
[Ach 182/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 28.33it/s]
[Ach 182/250:   6%|████                                                              | 13/209 [00:00<00:06, 29.39it/s]
[Ach 182/250:   8%|█████                                                             | 16/209 [00:00<00:06, 28.95it/s]
[Ach 182/250:  10%|██████▎                                                           | 20/209 [00:00<00:06, 29.46it/s]
[Ach 182/250:  11%|███████▌                                                          | 24/209 [00:00<00:06, 30.36it/s]
[Ach 182/250:  13%|████████▊          

Epoch [190/250] | D Loss: -0.0910 | G Loss: -1.5789



[Ach 191/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 191/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 26.36it/s]
[Ach 191/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 27.42it/s]
[Ach 191/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.45it/s]
[Ach 191/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 26.64it/s]
[Ach 191/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 27.47it/s]
[Ach 191/250:   9%|█████▋                                                            | 18/209 [00:00<00:06, 27.32it/s]
[Ach 191/250:  10%|██████▋                                                           | 21/209 [00:00<00:06, 27.30it/s]
[Ach 191/250:  11%|███████▌           

Epoch [200/250] | D Loss: -0.0892 | G Loss: -1.5319
  [Epoch 200] Checkpoint saved.



[Ach 201/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 201/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 26.93it/s]
[Ach 201/250:   3%|██▏                                                                | 7/209 [00:00<00:06, 29.13it/s]
[Ach 201/250:   5%|███▍                                                              | 11/209 [00:00<00:06, 30.81it/s]
[Ach 201/250:   7%|████▋                                                             | 15/209 [00:00<00:06, 29.69it/s]
[Ach 201/250:   9%|█████▋                                                            | 18/209 [00:00<00:06, 28.38it/s]
[Ach 201/250:  10%|██████▋                                                           | 21/209 [00:00<00:06, 28.85it/s]
[Ach 201/250:  11%|███████▌                                                          | 24/209 [00:00<00:06, 28.51it/s]
[Ach 201/250:  13%|████████▌          

Epoch [210/250] | D Loss: -0.0887 | G Loss: -1.5473



[Ach 211/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 211/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 29.04it/s]
[Ach 211/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 27.12it/s]
[Ach 211/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 27.28it/s]
[Ach 211/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 27.77it/s]
[Ach 211/250:   8%|█████                                                             | 16/209 [00:00<00:07, 26.82it/s]
[Ach 211/250:   9%|██████                                                            | 19/209 [00:00<00:07, 25.94it/s]
[Ach 211/250:  11%|██████▉                                                           | 22/209 [00:00<00:06, 26.85it/s]
[Ach 211/250:  12%|███████▉           

  [Epoch 211] Best model saved (G loss = -1.6011)



[Ach 212/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 212/250:   2%|█▎                                                                 | 4/209 [00:00<00:06, 30.30it/s]
[Ach 212/250:   4%|██▌                                                                | 8/209 [00:00<00:07, 28.70it/s]
[Ach 212/250:   6%|███▊                                                              | 12/209 [00:00<00:06, 30.09it/s]
[Ach 212/250:   8%|█████                                                             | 16/209 [00:00<00:06, 28.15it/s]
[Ach 212/250:   9%|██████                                                            | 19/209 [00:00<00:06, 28.20it/s]
[Ach 212/250:  11%|██████▉                                                           | 22/209 [00:00<00:06, 27.47it/s]
[Ach 212/250:  12%|███████▉                                                          | 25/209 [00:00<00:06, 28.00it/s]
[Ach 212/250:  14%|█████████▏         

  [Epoch 216] Best model saved (G loss = -1.6026)



[Ach 217/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 217/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 25.88it/s]
[Ach 217/250:   3%|█▉                                                                 | 6/209 [00:00<00:08, 25.07it/s]
[Ach 217/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.41it/s]
[Ach 217/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 25.57it/s]
[Ach 217/250:   7%|████▋                                                             | 15/209 [00:00<00:07, 26.02it/s]
[Ach 217/250:   9%|█████▋                                                            | 18/209 [00:00<00:07, 27.13it/s]
[Ach 217/250:  10%|██████▋                                                           | 21/209 [00:00<00:07, 26.78it/s]
[Ach 217/250:  11%|███████▌           

  [Epoch 218] Best model saved (G loss = -1.6299)



[Ach 219/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 219/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 28.40it/s]
[Ach 219/250:   3%|██▏                                                                | 7/209 [00:00<00:06, 30.11it/s]
[Ach 219/250:   5%|███▏                                                              | 10/209 [00:00<00:06, 29.40it/s]
[Ach 219/250:   6%|████                                                              | 13/209 [00:00<00:06, 28.24it/s]
[Ach 219/250:   8%|█████                                                             | 16/209 [00:00<00:07, 27.43it/s]
[Ach 219/250:   9%|██████                                                            | 19/209 [00:00<00:06, 28.19it/s]
[Ach 219/250:  11%|███████▎                                                          | 23/209 [00:00<00:06, 28.45it/s]
[Ach 219/250:  12%|████████▏          

Epoch [220/250] | D Loss: -0.0850 | G Loss: -1.5963



[Ach 221/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 221/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 28.91it/s]
[Ach 221/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 27.22it/s]
[Ach 221/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 27.17it/s]
[Ach 221/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 27.49it/s]
[Ach 221/250:   7%|████▋                                                             | 15/209 [00:00<00:06, 27.78it/s]
[Ach 221/250:   9%|█████▋                                                            | 18/209 [00:00<00:06, 28.00it/s]
[Ach 221/250:  10%|██████▋                                                           | 21/209 [00:00<00:06, 27.32it/s]
[Ach 221/250:  11%|███████▌           

  [Epoch 228] Best model saved (G loss = -1.6316)



[Ach 229/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 229/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 27.33it/s]
[Ach 229/250:   3%|█▉                                                                 | 6/209 [00:00<00:09, 22.40it/s]
[Ach 229/250:   4%|██▉                                                                | 9/209 [00:00<00:08, 24.23it/s]
[Ach 229/250:   6%|████                                                              | 13/209 [00:00<00:07, 26.77it/s]
[Ach 229/250:   8%|█████▎                                                            | 17/209 [00:00<00:06, 28.90it/s]
[Ach 229/250:  10%|██████▎                                                           | 20/209 [00:00<00:07, 26.76it/s]
[Ach 229/250:  11%|███████▎                                                          | 23/209 [00:00<00:07, 26.12it/s]
[Ach 229/250:  12%|████████▏          

Epoch [230/250] | D Loss: -0.0839 | G Loss: -1.6060



[Ach 231/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 231/250:   1%|▉                                                                  | 3/209 [00:00<00:08, 25.31it/s]
[Ach 231/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 26.19it/s]
[Ach 231/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.95it/s]
[Ach 231/250:   6%|████                                                              | 13/209 [00:00<00:07, 27.71it/s]
[Ach 231/250:   8%|█████▎                                                            | 17/209 [00:00<00:06, 28.19it/s]
[Ach 231/250:  10%|██████▎                                                           | 20/209 [00:00<00:07, 26.82it/s]
[Ach 231/250:  11%|███████▎                                                          | 23/209 [00:00<00:06, 27.17it/s]
[Ach 231/250:  12%|████████▏          

Epoch [240/250] | D Loss: -0.0845 | G Loss: -1.5939



[Ach 241/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 241/250:   2%|█▎                                                                 | 4/209 [00:00<00:07, 28.83it/s]
[Ach 241/250:   4%|██▌                                                                | 8/209 [00:00<00:06, 30.12it/s]
[Ach 241/250:   6%|███▊                                                              | 12/209 [00:00<00:06, 29.47it/s]
[Ach 241/250:   8%|█████                                                             | 16/209 [00:00<00:06, 29.05it/s]
[Ach 241/250:   9%|██████                                                            | 19/209 [00:00<00:06, 28.32it/s]
[Ach 241/250:  11%|███████▎                                                          | 23/209 [00:00<00:06, 28.47it/s]
[Ach 241/250:  13%|████████▌                                                         | 27/209 [00:00<00:06, 29.09it/s]
[Ach 241/250:  14%|█████████▍         

  [Epoch 245] Best model saved (G loss = -1.6354)



[Ach 246/250:   0%|                                                                           | 0/209 [00:00<?, ?it/s]
[Ach 246/250:   1%|▉                                                                  | 3/209 [00:00<00:07, 28.34it/s]
[Ach 246/250:   3%|█▉                                                                 | 6/209 [00:00<00:07, 25.99it/s]
[Ach 246/250:   4%|██▉                                                                | 9/209 [00:00<00:07, 26.95it/s]
[Ach 246/250:   6%|███▊                                                              | 12/209 [00:00<00:07, 25.37it/s]
[Ach 246/250:   8%|█████                                                             | 16/209 [00:00<00:07, 26.60it/s]
[Ach 246/250:   9%|██████                                                            | 19/209 [00:00<00:07, 26.37it/s]
[Ach 246/250:  11%|██████▉                                                           | 22/209 [00:00<00:07, 25.76it/s]
[Ach 246/250:  12%|███████▉           

Epoch [250/250] | D Loss: -0.0841 | G Loss: -1.6123
  [Epoch 250] Checkpoint saved.

Training complete.
Total training time: 1933.91 seconds





In [5]:
# Rebuild the generator with the training-time architecture and restore the best
# checkpoint. Relies on Generator, latent_dim, hidden_dims, cat_dims and device
# having been defined by earlier model/training cells.
generator = Generator(latent_dim, hidden_dims, cat_dims).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot run arbitrary code; a state_dict needs nothing more.
generator.load_state_dict(
    torch.load("Vanilla-WGAN/generator_best.pth", map_location=device, weights_only=True)
)
generator.eval()  # Set to evaluation mode (BatchNorm1d layers use running stats)

Generator(
  (hidden): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=256, out_features=256, bias=True)
    (4): ReLU()
    (5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU()
    (8): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (heads): ModuleList(
    (0): Linear(in_features=256, out_features=17, bias=True)
    (1): Linear(in_features=256, out_features=2, bias=True)
    (2): Linear(in_features=256, out_features=5, bias=True)
    (3): Linear(in_features=256, out_features=6, bias=True)
    (4-5): 2 x Linear(in_features=256, out_features=2, bias=True)
    (6): Linear(in_features=256, out_features=4, bias=True)
    (7): Linear(in_features=256, out_features=9, bias=True)
    (8)

In [6]:
# Prefix sums of the per-column one-hot widths: entry k is the first column of
# categorical block k in the encoded layout, and the last entry is the total width.
cat_boundaries = [0, *np.cumsum(cat_dims)]

# Score synthetic samples from the restored generator against the full population
# (and the training sample); the returned metrics dict is the cell's displayed output.
evaluate_coverage(
    x_population=x_population,
    x_sample=x_train_split,
    generator=generator,
    cat_boundaries=cat_boundaries,
    enc=enc,
    device=device,
    latent_dim=latent_dim,
    batch_size=256,
)


Generating: 100%|██████████████████████████████████████████████████████████| 1066319/1066319 [07:00<00:00, 2538.29it/s]


{'stz_ratio': 0.1862,
 'saz_ratio': 0.2942,
 'srmse_marginal': np.float64(0.032534788830581494),
 'srmse_bivariate': np.float64(0.09533603745985905),
 'precision': 0.8138,
 'recall': 0.8089,
 'f1_score': 0.8113,
 'unique_combinations': {'population': 264005, 'generated': 264006},
 'matching_combinations': {'unique_types': 111932, 'total_count': 867786},
 'infer_time_sec': 426.89}

In [None]:
# Set device (cuda or cpu)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture hyper-parameters -- must match the values used during training,
# otherwise load_state_dict below fails with shape mismatches.
latent_dim = 256
hidden_dims = [256, 256, 256]
# One generator head per column. Take the widths from the fitted OneHotEncoder:
# wide_to_long decodes against enc, so its categories_ define the column layout.
cat_dims = [len(categories) for categories in enc.categories_]

# Prefix sums of cat_dims: start/end column of each categorical block.
cat_boundaries = [0] + list(np.cumsum(cat_dims))

# Fixed sampling seed so the synthetic population can be regenerated exactly.
SEED = 1004
torch.manual_seed(SEED)

# Re-create the generator instance and load the best saved model weights.
# weights_only=True avoids executing arbitrary pickled code from the checkpoint.
generator = Generator(latent_dim, hidden_dims, cat_dims).to(device)
generator.load_state_dict(
    torch.load("Vanilla-WGAN/generator_best.pth", map_location=device, weights_only=True)
)
generator.eval()  # Set to evaluation mode

# Number of synthetic samples to generate (as many as in the full population)
num_samples = x_population.shape[0]  # or any desired number

# Per-batch DataFrames, concatenated once at the end (avoids quadratic concat).
generated_batches = []
batch_size = 256

# Generate synthetic data in batches
with torch.no_grad():
    for i in range(0, num_samples, batch_size):
        current_batch = min(batch_size, num_samples - i)
        # Latent noise z ~ N(0, I) with z in R^latent_dim.
        noise = torch.randn(current_batch, latent_dim, device=device)
        # Wide (one-hot-width) generator output for the batch, moved to host memory.
        fake_output = generator(noise).cpu().numpy()
        # Convert wide outputs back to one categorical value per column.
        df_batch = wide_to_long(fake_output, cat_boundaries, enc)
        generated_batches.append(df_batch)

# Concatenate all batches into a single DataFrame
generated_data = pd.concat(generated_batches, ignore_index=True)


Synthetic data generated and saved to 'generated_synthetic_data.csv'


In [8]:
# Map the label-encoded integer codes in generated_data back to the original
# category labels, column by column, with the LabelEncoders fitted on x_population.
decoded = pd.DataFrame(
    {
        col: label_encoders[col].inverse_transform(generated_data[col].astype(int).values)
        for col in generated_data.columns
    },
    index=generated_data.index,
)

# Last expression -> rich (HTML) display instead of print().
decoded.head()


       Age  Gender  Homeincome  Hometype  CarOwn  Driver  Workdays  Worktype  \
0  [45,50)       1           3         1       1       1         2         6   
1  [40,45)       1           2         1       1       1         1         8   
2  [35,40)       2           2         4       1       1         4         4   
3  [70,75)       2           1         1       1       2         4         2   
4  [55,60)       1           3         1       1       1         3         2   

   Student  NumHH  KidinHH                ComMode     ComTime  
0        4      3        2                    Car        Peak  
1        4      3        2  Public Transportation        Peak  
2        4      4        1             No commute  No commute  
3        4      3        1             No commute  No commute  
4        4      3        2                Walking     NonPeak  


In [None]:
# Persist the decoded synthetic population (original category labels) without the row index.
decoded.to_csv(path_or_buf="generated_synthetic_data_WGAN.csv", index=False)