<center>
  
# TABSYN: Tabular Data Synthesis with Diffusion Models

</center>

Two challenges regarding the extention of diffusion models to tabular data are:
1. **Diverse data types:** a single table can have different columns each containing data of different types, including numerical, categorical, text, etc.
2. **Varied distributions:** the distribution of data under different columns in a single table varry widely from column to column.

**TabSyn** addresses these challenges by introducing a latent space where tabular data of all columns are jointly represented. It then proceedes to train a diffusion model on the latent representations.
This tactic allows TabSyn to:
1. Train a single diffusion model for all data types in the dataset (i.e. Generality).
2. Optimize the distribution of latent embeddings to facilitate training of the subsequent diffusion model, thus generating higher quality synthetic data (i.e. Quality).
3. Require much fewer reverse steps during training of the diffusion model, and synthesize data faster (i.e. Speed).

In this notebook, we review and implement the TabSyn model. The notebook is organized as follows:

1. [Imports and Setup]()


2. [Berka Dataset]()
    
    
3. [TabSyn Algorithm]()
    
    3.1. [Load Config]()
    
    3.2. [Make Dataset]()
    
    3.3. [Instantiate Model]()
    
    3.4. [Train Model]()
        
    3.5. [Load Pretrained Model]()
    
    3.6. [Sample Data]()
    
    3.7. [Review Synthetic Data]()


# Imports and Setup

In this section, we import all necessary libraries and modules required for setting up the environment.

In [34]:
import os
import json
import pandas as pd
from pprint import pprint

import torch
from torch.utils.data import DataLoader

from midst_models.single_table_TabSyn.scripts.process_dataset import process_data

from midst_models.single_table_TabSyn.src.data import preprocess, TabularDataset
from midst_models.single_table_TabSyn.src.tabsyn.pipeline import TabSyn
from midst_models.single_table_TabSyn.src import load_config

# Berka Dataset

In this section, we will process the Transactions table from the Berka dataset. You can access the Berka dataset files for TabSyn [here](https://drive.google.com/drive/folders/18KHv3VQuRphMHqZQsQc-x2ALoIiAggA0?usp=drive_link).
The BERKA dataset is a comprehensive banking dataset originally released by the Czech bank ČSOB for the Financial Modeling and Analysis (FMA) competition in 1999. It provides detailed financial data on transactions, accounts, loans, credit cards, and demographic information for thousands of customers over multiple years.

Download the data files from the link above and place the train set in the `RAW_DATA_DIR` directory.
Note that the id columns (columns ending in "_id") should be removed from the training and test data.

Data info files are required for running the scripts. Sample info file for the transaction data is available in `data_info/trans.json`. The paths for the training and test data in the file can be modified as needed.

In [35]:
from midst_models.single_table_TabSyn.scripts.process_dataset import get_column_name_mapping, train_val_test_split
import numpy as np


INFO_DIR = "data_info"

# DATA_DIR = "data/"
SOURCE_DATA_DIR = "../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/"
DATA_DIR = os.path.join(SOURCE_DATA_DIR, 'data')
# RAW_DATA_DIR = os.path.join(DATA_DIR, "raw_data")
RAW_DATA_DIR=DATA_DIR
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, "processed_data")
SYNTH_DATA_DIR = os.path.join(DATA_DIR, "synthetic_data")
DATA_NAME = "trans"

MODEL_PATH = "models/tabsyn"

def process_data(name, info_path, data_dir, data_path=None, test_path=None):
    processed_data_dir = os.path.join(data_dir, "processed_data")
    
    with open(f"{info_path}/{name}.json", "r") as f:
        info = json.load(f)
    if data_path:
        info["data_path"] = data_path
    if test_path:
        info["test_path"] = test_path
    data_path = info["data_path"]
    print(data_path)

    if info["file_type"] == "csv":
        data_df = pd.read_csv(data_path, header=info["header"])
        print('reading', data_path)

    elif info["file_type"] == "xls":
        data_df = pd.read_excel(data_path, sheet_name="Data", header=1)
        data_df = data_df.drop("ID", axis=1)

    expected_columns = ['trans_id', 'account_id']
    if data_df.columns[:2].tolist() == expected_columns:
        data_df = data_df.iloc[:, 2:]
    else:
        raise ValueError(f"Not Matching {expected_columns}, Please Check {data_path}")
    
    num_data = data_df.shape[0]

    column_names = (
        info["column_names"] if info["column_names"] else data_df.columns.tolist()
    )

    num_col_idx = info["num_col_idx"]
    cat_col_idx = info["cat_col_idx"]
    target_col_idx = info["target_col_idx"]

    idx_mapping, inverse_idx_mapping, idx_name_mapping = get_column_name_mapping(
        data_df, num_col_idx, cat_col_idx, target_col_idx, column_names
    )

    num_columns = [column_names[i] for i in num_col_idx]
    cat_columns = [column_names[i] for i in cat_col_idx]
    target_columns = [column_names[i] for i in target_col_idx]

    if ("test_path" in info.keys()) and info["test_path"]:
        # if testing data is given

        test_path = info["test_path"]
        test_df = pd.read_csv(test_path)
        print('reading', test_path)
        if test_df.columns[:2].tolist() == expected_columns:
            test_df = test_df.iloc[:, 2:]
        else:
            raise ValueError(f"Not Matching {expected_columns}, Please Check {test_path}")
        train_df = data_df
    else:
        num_train = int(num_data * 0.99)
        num_test = num_data - num_train

        train_df, test_df, seed = train_val_test_split(
            data_df, cat_columns, num_train, num_test
        )

    train_df.columns = range(len(train_df.columns))
    test_df.columns = range(len(test_df.columns))

    col_info = {}

    for col_idx in num_col_idx:
        col_info[col_idx] = {}
        col_info["type"] = "numerical"
        col_info["max"] = float(train_df[col_idx].max())
        col_info["min"] = float(train_df[col_idx].min())

    for col_idx in cat_col_idx:
        col_info[col_idx] = {}
        col_info["type"] = "categorical"
        col_info["categorizes"] = list(set(train_df[col_idx]))

    for col_idx in target_col_idx:
        if info["task_type"] == "regression":
            col_info[col_idx] = {}
            col_info["type"] = "numerical"
            col_info["max"] = float(train_df[col_idx].max())
            col_info["min"] = float(train_df[col_idx].min())
        else:
            col_info[col_idx] = {}
            col_info["type"] = "categorical"
            col_info["categorizes"] = list(set(train_df[col_idx]))

    info["column_info"] = col_info

    train_df.rename(columns=idx_name_mapping, inplace=True)
    test_df.rename(columns=idx_name_mapping, inplace=True)

    for col in num_columns:
        train_df.loc[train_df[col] == "?", col] = np.nan
    for col in cat_columns:
        train_df.loc[train_df[col] == "?", col] = "nan"
    for col in num_columns:
        test_df.loc[test_df[col] == "?", col] = np.nan
    for col in cat_columns:
        test_df.loc[test_df[col] == "?", col] = "nan"

    X_num_train = train_df[num_columns].to_numpy().astype(np.float32)
    X_cat_train = train_df[cat_columns].to_numpy().astype(np.int64)
    y_train = train_df[target_columns].to_numpy()

    X_num_test = test_df[num_columns].to_numpy().astype(np.float32)
    X_cat_test = test_df[cat_columns].to_numpy().astype(np.int32)
    y_test = test_df[target_columns].to_numpy()

    if not os.path.exists(f"{processed_data_dir}/{name}"):
        os.makedirs(f"{processed_data_dir}/{name}")

    np.save(f"{processed_data_dir}/{name}/X_num_train.npy", X_num_train)
    np.save(f"{processed_data_dir}/{name}/X_cat_train.npy", X_cat_train)
    np.save(f"{processed_data_dir}/{name}/y_train.npy", y_train)

    np.save(f"{processed_data_dir}/{name}/X_num_test.npy", X_num_test)
    np.save(f"{processed_data_dir}/{name}/X_cat_test.npy", X_cat_test)
    np.save(f"{processed_data_dir}/{name}/y_test.npy", y_test)

    train_df[num_columns] = train_df[num_columns].astype(np.float32)
    test_df[num_columns] = test_df[num_columns].astype(np.float32)

    train_df.to_csv(f"{processed_data_dir}/{name}/train.csv", index=False)
    test_df.to_csv(f"{processed_data_dir}/{name}/test.csv", index=False)

    info["column_names"] = column_names
    info["train_num"] = train_df.shape[0]
    info["test_num"] = test_df.shape[0]

    info["idx_mapping"] = idx_mapping
    info["inverse_idx_mapping"] = inverse_idx_mapping
    info["idx_name_mapping"] = idx_name_mapping

    metadata = {"columns": {}}
    task_type = info["task_type"]
    num_col_idx = info["num_col_idx"]
    cat_col_idx = info["cat_col_idx"]
    target_col_idx = info["target_col_idx"]

    for i in num_col_idx:
        metadata["columns"][i] = {}
        metadata["columns"][i]["sdtype"] = "numerical"
        metadata["columns"][i]["computer_representation"] = "Float"

    for i in cat_col_idx:
        metadata["columns"][i] = {}
        metadata["columns"][i]["sdtype"] = "categorical"

    if task_type == "regression":
        for i in target_col_idx:
            metadata["columns"][i] = {}
            metadata["columns"][i]["sdtype"] = "numerical"
            metadata["columns"][i]["computer_representation"] = "Float"

    else:
        for i in target_col_idx:
            metadata["columns"][i] = {}
            metadata["columns"][i]["sdtype"] = "categorical"

    info["metadata"] = metadata

    with open(f"{processed_data_dir}/{name}/info.json", "w") as file:
        json.dump(info, file, indent=4)

    print(f"Processing and Saving {name} Successfully!")

    print("Dataset Name:", name)
    print("Total Size:", info["train_num"] + info["test_num"])
    print("Train Size:", info["train_num"])
    print("Test Size:", info["test_num"])
    if info["task_type"] == "regression":
        num = len(info["num_col_idx"] + info["target_col_idx"])
        cat = len(info["cat_col_idx"])
    else:
        cat = len(info["cat_col_idx"] + info["target_col_idx"])
        num = len(info["num_col_idx"])
    print("Number of Numerical Columns:", num)
    print("Number of Categorical Columns:", cat)


In [36]:
# process data
data_path = os.path.join(SOURCE_DATA_DIR, 'train_with_id.csv')
test_path = os.path.join(SOURCE_DATA_DIR, 'challenge_with_id.csv')
process_data(DATA_NAME, INFO_DIR, DATA_DIR, data_path=data_path, test_path=test_path)

# review data
df = pd.read_csv(os.path.join(PROCESSED_DATA_DIR, DATA_NAME, "train.csv"))
df.head(10)

../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/train_with_id.csv
reading ../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/train_with_id.csv
reading ../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/challenge_with_id.csv
Processing and Saving trans Successfully!
Dataset Name: trans
Total Size: 20200
Train Size: 20000
Test Size: 200
Number of Numerical Columns: 4
Number of Categorical Columns: 4


Unnamed: 0,trans_date,trans_type,operation,amount,balance,k_symbol,bank,account
0,336.0,0,3,2400.0,20515.0,1,0,0.0
1,2129.0,2,4,14.6,65847.0,6,0,0.0
2,1641.0,2,4,14.6,13507.4,6,0,0.0
3,515.0,2,4,14.6,36742.7,6,0,0.0
4,1984.0,0,2,3650.0,16299.2,1,8,78194776.0
5,1856.0,2,4,14.6,25490.5,6,0,0.0
6,1833.0,2,4,1700.0,54860.6,1,0,0.0
7,2006.0,2,4,30.0,29900.0,6,0,0.0
8,1093.0,2,4,15900.0,58053.2,1,0,0.0
9,801.0,0,3,12970.0,32283.0,1,0,0.0


In [37]:
# review json file and its contents
with open(f"{PROCESSED_DATA_DIR}/{DATA_NAME}/info.json", "r") as file:
    data_info = json.load(file)
data_info

{'name': 'trans',
 'task_type': 'regression',
 'header': 'infer',
 'column_names': ['trans_date',
  'trans_type',
  'operation',
  'amount',
  'balance',
  'k_symbol',
  'bank',
  'account'],
 'num_col_idx': [0, 4, 7],
 'cat_col_idx': [1, 2, 5, 6],
 'target_col_idx': [3],
 'file_type': 'csv',
 'data_path': '../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/train_with_id.csv',
 'test_path': '../../starter_kits/tabsyn_white_box/train_overfit/tabsyn_1/challenge_with_id.csv',
 'column_info': {'0': {},
  'type': 'numerical',
  'max': 78600.0,
  'min': 0.2,
  '4': {},
  '7': {},
  '1': {},
  'categorizes': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
  '2': {},
  '5': {},
  '6': {},
  '3': {}},
 'train_num': 20000,
 'test_num': 200,
 'idx_mapping': {'0': 0,
  '1': 3,
  '2': 4,
  '3': 7,
  '4': 1,
  '5': 5,
  '6': 6,
  '7': 2},
 'inverse_idx_mapping': {'0': 0,
  '3': 1,
  '4': 2,
  '7': 3,
  '1': 4,
  '5': 5,
  '6': 6,
  '2': 7},
 'idx_name_mapping': {'0': 'trans_date',
  '1': 'tr


Note that if you want to use a subset of the entire transaction table, you must still preprocess the full table, retain the main table, and pass it as the reference data to `preprocess` later. This is because the model should have access to all the categories for categorical columns in the data.

The sample data info files is available in `data_info/trans_all.json`. The paths for the training and test data in the file can be modified as needed.

In [38]:
# DATA_DIR_ALL = "all_data/"
# RAW_DATA_DIR_ALL = os.path.join(DATA_DIR_ALL, "raw_data")
# PROCESSED_DATA_DIR_ALL = os.path.join(DATA_DIR_ALL, "processed_data")
# DATA_NAME_ALL = "trans_all"
# PROCESSED_DATA_DIR_ALL = os.path.join(DATA_DIR_ALL, "processed_data")
DATA_DIR_ALL = "../../starter_kits/tabsyn_white_box_all_data/"
DATA_NAME_ALL = "trans_all_demo"
PROCESSED_DATA_DIR_ALL = os.path.join(DATA_DIR_ALL, "processed_data")
process_data(DATA_NAME_ALL, INFO_DIR, DATA_DIR_ALL)

REF_DATA_PATH = os.path.join(PROCESSED_DATA_DIR_ALL, DATA_NAME_ALL)

/work4/xiaoyuwu/MIDSTModelsMIA/starter_kits/tab_syn_train_merge.csv
reading /work4/xiaoyuwu/MIDSTModelsMIA/starter_kits/tab_syn_train_merge.csv
Processing and Saving trans_all_demo Successfully!
Dataset Name: trans_all_demo
Total Size: 460972
Train Size: 456362
Test Size: 4610
Number of Numerical Columns: 4
Number of Categorical Columns: 4


# TabSyn Algorithm

In this section, we will describe the design of TabSyn as well as its main hyperparameters loaded through config, which affect the model’s effectiveness. 

**TabSyn** consists of two parts:
1. A *variational auto-encoder (VAE)* which learns a joint representation space for the given tabular data.
2. A *Diffusion model* which learns the distribution of data in the joint representation space.

The figure below shows a diagram of the TabSyn model.

<p align="center">
<img src="https://github.com/user-attachments/assets/a7e6a218-dd8e-4ae8-a8e5-6fc3974b2e9b" width="1000"/>
</p>

**VAE**

The left-side of the figure shows the VAE which operates in the original data space. The VAE itself consists of two parts: an encoder and a decoder. It also contains the corresponding tokenizer and detokenizer.
Each row of the input tabular data ($\pmb{x}$) is tokenized, then embedded by a transformer. Another transformer decodes the embeddings and a detokenizer reconstructs the table ($\pmb{\tilde{x}}$). The VAE is trained by minimizing the reconstruction loss between $\pmb{x}$ and $\pmb{\tilde{x}}$.

After the VAE is fully trained, the whole data ($\pmb{x}$) is tokenized and embedded. The embedding of each row is flattened to form a 1-dimensional vector $\pmb{z}$.
These 1-dimensional embeddings for all rows are stored on disk, and will later be used to train the diffusion model.

**Diffusion**

The right-side of the figure shows the diffusion model which operates in the latent representation space; in other words, it only *sees* the embeddings obtained by the VAE, not the original tabular data.
The diffusion model can be similarly divided into two parts: a forward process, and a reverse process.

The forward process receives the embedded data points. A single data point is denoted by $\pmb{z_0}$ in the figure. Gaussian noise is incrementally added to the embeddings in numerous incremental steps during the forward process. The number of the steps is denoted by $T$ in the figure. $T$ should be high enough that the distribution of embeddings at step $t=T$ is essentially a standard Gaussian distribution; in other words, the signal-to-noise ratio is practically zero.

The reverse process, on the other hand, learns to *predict* an earlier-step embedding (e.g. $\pmb{z_{t-\Delta t}}$) from a later-step embedding (e.g. $\pmb{z_t}$) via a neural network.

After the diffusion model is fully trained, the reverse process can estimate the data distribution at step $t=0$ if it receives a standard Gaussian distribution at step $t=T$. New data points can be synthesized by sampling from this estimated distribution.


## Load Config

In this section, we will load the configuration file that contains the hyperparameters for the TabSyn model. 

In [39]:
config_path = os.path.join("src/configs", f"{DATA_NAME}.toml")
raw_config = load_config(config_path)

pprint(raw_config)

{'loss_params': {'lambd': 0.7, 'max_beta': 0.01, 'min_beta': 1e-05},
 'model_params': {'d_token': 4, 'factor': 32, 'n_head': 1, 'num_layers': 2},
 'task_type': 'regression',
 'train': {'diffusion': {'batch_size': 4096,
                         'num_dataset_workers': 4,
                         'num_epochs': 10001},
           'optim': {'diffusion': {'factor': 0.9,
                                   'lr': 0.001,
                                   'patience': 50,
                                   'weight_decay': 0},
                     'vae': {'factor': 0.95,
                             'lr': 0.001,
                             'patience': 10,
                             'weight_decay': 0}},
           'vae': {'batch_size': 4096,
                   'num_dataset_workers': 4,
                   'num_epochs': 4000}},
 'transforms': {'cat_encoding': None,
                'cat_min_frequency': None,
                'cat_nan_policy': None,
                'normalization': 'quantile',
      

The configuration file is a TOML file that contains the following hyperparameters:

1. **model_params:** specifies the structure of the transformers (both encoder and decoder) in the VAE model, including number of transformer layers, number of self-attnetion heads and token dimension.

2. **transforms:** specifies the transformations and preprocessing of the data before tokenization, such as cleaning, normalization, and encoding.
    - For preprocessing numerical features, we use the gaussian quantile transformation and replace the NaN values with mean of each row.
    - For categorical features, we use the one-hot encoding method. NaN values are left unchanged, but we have the option to replace them. We have the option to drop the values that appear with less than a given minimum frequency under each column. Furthermore, we have the option to add an extra encoding step for categorical features during tokenization.

3. **train.vae:** specifies training parameters of the VAE, including batch size, number of epochs, and number of dataset workers.

4. **train.diffusion:** specifies the same training parameters as above for the diffusion model.

5. **train.optim.vae:** specifies the parameters of the *Adam* optimizer and the `ReduceLROnPlateau` learning rate scheduler used to train the VAE. Optimizer parameters include initial learning rate and weight decay. LR scheduler parameters includer `factor` and `patience`.

6. **train.optim.diffusion:** specifies the same parameters as above for the diffusion model.

7. **loss_params:** specifies parameters of the loss function used to train the VAE including `max_beta`, `min_beta` and `lambd`.

$\beta$ is the coefficient of the KL divergence term in the VAE loss formula,

$\mathcal{L}_{vae} = \mathcal{L}_{mse} + \mathcal{L}_{ce} + \beta \mathcal{L}_{kl}$
.

Parameters `max_beta` and `min_beta` determine the range of $\beta$. $\beta$ is first set to `max_beta`. If the loss stops decreasing for a certain number of epochs (e.g. $10$ epochs), then at the end of each epoch after that (e.g. epoch $11$, $12$, etc.) $\beta$ is decreased by a factor of `lambd`,
$\beta_{new} = \lambda \beta_{curr}$,
until it reaches `beta_min`.


## Make Dataset

In this section, we pre-process the data and make a dataset object.

First, we determine transformations needed for the dataset, such as normalization and cleaning, in `transforms`. Next, using `preprocess` function we load the data from disk in arrays that contain both training and test data (`X_num` and `X_cat`), as well as the number of categories for each categorical feature (`categories`) and the number of numerical features (`d_numerical`).

We then separate the train and test data in different arrays and convert them to Pytorch tensors.
We create a dataset object (`TabularDataset`) with the train data. `TabularDataset` is a simple module which returns the tokens of a single row at a time. Each row constiutes a single data sample in TabSyn. Afterwards, we create a Dataloader for the train data using the `batch_size` and `num_workers` specified in config.

In contrast, we keep the test data as tensors (`X_test_num` and `X_test_cat`). If a GPU is available, we move these tensors to GPU so that they can be accessed by the model later on.

In [40]:
# preprocess data
X_num, X_cat, categories, d_numerical = preprocess(
    os.path.join(PROCESSED_DATA_DIR, DATA_NAME),
    ref_dataset_path=REF_DATA_PATH,
    transforms=raw_config["transforms"],
    task_type=raw_config["task_type"],
)

# separate train and test data
X_train_num, X_test_num = X_num
X_train_cat, X_test_cat = X_cat

# convert to float tensor
X_train_num, X_test_num = (
    torch.tensor(X_train_num).float(),
    torch.tensor(X_test_num).float(),
)
X_train_cat, X_test_cat = torch.tensor(X_train_cat), torch.tensor(X_test_cat)

# create dataset module
train_data = TabularDataset(X_train_num.float(), X_train_cat)

# move test data to gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_test_num = X_test_num.float().to(device)
X_test_cat = X_test_cat.to(device)

# create train dataloader
train_loader = DataLoader(
    train_data,
    batch_size=raw_config["train"]["vae"]["batch_size"],
    shuffle=True,
    num_workers=raw_config["train"]["vae"]["num_dataset_workers"],
)

No NaNs in numerical features, skipping
No NaNs in numerical features, skipping


## Instantiate Model

Next, we instantiate the model using the `TabSyn` class. `TabSyn` class takes the following arguments:

1. `train_loader`: dataloader for train data.
2. `X_test_num`: numerical features of the test data.
3. `X_test_cat`: categorical features of the train data.
4. `num_numerical_features`: number of numerical features in the dataset.
5. `num_classes`: number of classes (i.e. categories) of each categorical feature in the dataset.
6. `device`: the device on which the model and data exist, either "cpu" or "cuda".

In [41]:
tabsyn = TabSyn(
    train_loader,
    X_test_num,
    X_test_cat,
    num_numerical_features=d_numerical,
    num_classes=categories,
    device=device,
)

`TabSyn` class has the tools to instantiate VAE and diffusion models, train both, and sample from the trained diffusion model.
We will demonstrate how to use these tools in the following sections.

## Train Model


The VAE and the diffusion model are trained independently. The following subsections explain each training process.


### A. Train VAE

First, we need to instantiate the VAE using the `instantiate_vae` method. This method takes the VAE model hyperparameters, optimizer and lr scheduler parameters from config, and instantiates them.

In [42]:
# instantiate VAE model for training
tabsyn.instantiate_vae(
    **raw_config["model_params"], optim_params=raw_config["train"]["optim"]["vae"]
)

Successfully instantiated VAE model.


Now that we have instantiated the VAE, we can train it using the `train_vae` function.
This function receives the loss hyperparameters and number of epochs from the config.
Moreover, it recieves `save_path` which is the directory where trained model checkpoints will be saved.

In [43]:
os.makedirs(f"{MODEL_PATH}/{DATA_NAME}/vae", exist_ok=True)
tabsyn.train_vae(
    **raw_config["loss_params"],
    num_epochs=raw_config["train"]["vae"]["num_epochs"],
    save_path=os.path.join(MODEL_PATH, DATA_NAME, "vae"),
)

Epoch 1/4000: 100%|██████████| 5/5 [00:00<00:00,  8.46it/s]


epoch: 0, beta = 0.010000, Train MSE: 4.147132, Train CE:2.108088, Train KL:0.450256, Val MSE:3.007195, Val CE:2.016140, Train ACC:0.204231, Val ACC:0.237500


Epoch 2/4000: 100%|██████████| 5/5 [00:00<00:00,  7.11it/s]
Epoch 3/4000: 100%|██████████| 5/5 [00:00<00:00,  7.99it/s]
Epoch 4/4000: 100%|██████████| 5/5 [00:00<00:00,  7.74it/s]
Epoch 5/4000: 100%|██████████| 5/5 [00:00<00:00,  7.77it/s]
Epoch 6/4000:  40%|████      | 2/5 [00:00<00:00,  3.41it/s]


KeyboardInterrupt: 

After training the VAE, we embed the training data with the trained encoder and store the embeddings in a direcotry specified by `vae_ckpt_dir`.

In [None]:
# embed all inputs in the latent space 
tabsyn.save_vae_embeddings(
    X_train_num, X_train_cat, vae_ckpt_dir=os.path.join(MODEL_PATH, DATA_NAME, "vae")
)

### B. Train Diffusion Model

Now that we have stored the training data embeddings, we need to load and prepare them for the diffusion model.
We load the embeddings using `load_vae_embeddings`. We normalize the embeddings by subtracting the mean and dividing by the standard deviation. Then, we create a Dataloader with the specified `batch_size` and `num_workers` from the config.

In [44]:
# load latent space embeddings
train_z, _ = tabsyn.load_latent_embeddings(
    os.path.join(MODEL_PATH, DATA_NAME, "vae")
)  # train_z dim: B x in_dim

# normalize embeddings
mean, std = train_z.mean(0), train_z.std(0)
latent_train_data = (train_z - mean) / 2

# create data loader
latent_train_loader = DataLoader(
    latent_train_data,
    batch_size=raw_config["train"]["diffusion"]["batch_size"],
    shuffle=True,
    num_workers=raw_config["train"]["diffusion"]["num_dataset_workers"],
)

Now that the data is ready, we instantiate the diffusion model with `instantiate_diffusion`. The input dimension and hidden dimention of the diffusion model is determined by the dimension of the embeddings. 
Moreover, we instantiate the optimizer and lr scheduler using hyperparameters from config.

In [45]:
# instantiate diffusion model for training
tabsyn.instantiate_diffusion(
    in_dim=train_z.shape[1],
    hid_dim=train_z.shape[1],
    optim_params=raw_config["train"]["optim"]["diffusion"],
)

MLPDiffusion(
  (proj): Linear(in_features=32, out_features=1024, bias=True)
  (mlp): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): SiLU()
    (2): Linear(in_features=2048, out_features=2048, bias=True)
    (3): SiLU()
    (4): Linear(in_features=2048, out_features=1024, bias=True)
    (5): SiLU()
    (6): Linear(in_features=1024, out_features=32, bias=True)
  )
  (map_noise): PositionalEmbedding()
  (time_embed): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): SiLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
  )
)
The number of parameters: 10559520
Successfully instantiated diffusion model.


We train the diffusion model with `train_diffusion` function.
This function takes the following arguements:
1. `latent_train_loader`: dataloader for the latent representations which are used to train the diffusion model.
2. `num_epochs`: number of training epochs.
3. `ckpt_path`: directory where the model checkpoints will be stored.

In [46]:
os.makedirs(f"{MODEL_PATH}/{DATA_NAME}", exist_ok=True)
# train diffusion model

tabsyn.train_diffusion(
    latent_train_loader,
    num_epochs=raw_config["train"]["diffusion"]["num_epochs"]*10,
    ckpt_path=os.path.join(MODEL_PATH, DATA_NAME),
)

Epoch 1/100010: 100%|██████████| 5/5 [00:00<00:00,  8.52it/s, Loss=0.785]
Epoch 2/100010: 100%|██████████| 5/5 [00:00<00:00,  8.59it/s, Loss=0.714]
Epoch 3/100010: 100%|██████████| 5/5 [00:00<00:00,  8.38it/s, Loss=0.666]
Epoch 4/100010: 100%|██████████| 5/5 [00:00<00:00,  7.31it/s, Loss=0.667]
Epoch 5/100010: 100%|██████████| 5/5 [00:00<00:00,  8.47it/s, Loss=0.627]
Epoch 6/100010: 100%|██████████| 5/5 [00:00<00:00,  7.98it/s, Loss=0.592]
Epoch 7/100010: 100%|██████████| 5/5 [00:00<00:00,  7.82it/s, Loss=0.562]
Epoch 8/100010: 100%|██████████| 5/5 [00:00<00:00,  7.97it/s, Loss=0.536]
Epoch 9/100010: 100%|██████████| 5/5 [00:00<00:00,  8.46it/s, Loss=0.537]
Epoch 10/100010: 100%|██████████| 5/5 [00:00<00:00,  7.80it/s, Loss=0.474]
Epoch 11/100010: 100%|██████████| 5/5 [00:00<00:00,  8.20it/s, Loss=0.427]
Epoch 12/100010: 100%|██████████| 5/5 [00:00<00:00,  8.24it/s, Loss=0.393]
Epoch 13/100010: 100%|██████████| 5/5 [00:00<00:00,  7.86it/s, Loss=0.365]
Epoch 14/100010: 100%|██████████| 

Epoch 00412: reducing learning rate of group 0 to 9.0000e-04.


Epoch 413/100010: 100%|██████████| 5/5 [00:00<00:00, 11.21it/s, Loss=0.17] 
Epoch 414/100010: 100%|██████████| 5/5 [00:00<00:00, 10.38it/s, Loss=0.169]
Epoch 415/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.17] 
Epoch 416/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.17] 
Epoch 417/100010: 100%|██████████| 5/5 [00:00<00:00, 10.86it/s, Loss=0.176]
Epoch 418/100010: 100%|██████████| 5/5 [00:00<00:00, 11.13it/s, Loss=0.169]
Epoch 419/100010: 100%|██████████| 5/5 [00:00<00:00, 12.14it/s, Loss=0.161]
Epoch 420/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.161]
Epoch 421/100010: 100%|██████████| 5/5 [00:00<00:00, 11.76it/s, Loss=0.168]
Epoch 422/100010: 100%|██████████| 5/5 [00:00<00:00, 11.31it/s, Loss=0.169]
Epoch 423/100010: 100%|██████████| 5/5 [00:00<00:00, 10.99it/s, Loss=0.165]
Epoch 424/100010: 100%|██████████| 5/5 [00:00<00:00, 11.88it/s, Loss=0.162]
Epoch 425/100010: 100%|██████████| 5/5 [00:00<00:00, 11.24it/s, Loss=0.169]
Epoch 426/10

Epoch 00493: reducing learning rate of group 0 to 8.1000e-04.


Epoch 494/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.16] 
Epoch 495/100010: 100%|██████████| 5/5 [00:00<00:00, 11.09it/s, Loss=0.164]
Epoch 496/100010: 100%|██████████| 5/5 [00:00<00:00, 11.52it/s, Loss=0.156]
Epoch 497/100010: 100%|██████████| 5/5 [00:00<00:00, 11.94it/s, Loss=0.162]
Epoch 498/100010: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, Loss=0.158]
Epoch 499/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.164]
Epoch 500/100010: 100%|██████████| 5/5 [00:00<00:00, 12.09it/s, Loss=0.16] 
Epoch 501/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.17] 
Epoch 502/100010: 100%|██████████| 5/5 [00:00<00:00, 11.25it/s, Loss=0.173]
Epoch 503/100010: 100%|██████████| 5/5 [00:00<00:00, 11.76it/s, Loss=0.159]
Epoch 504/100010: 100%|██████████| 5/5 [00:00<00:00, 11.28it/s, Loss=0.173]
Epoch 505/100010: 100%|██████████| 5/5 [00:00<00:00, 11.66it/s, Loss=0.162]
Epoch 506/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.167]
Epoch 507/10

Epoch 00593: reducing learning rate of group 0 to 7.2900e-04.


Epoch 594/100010: 100%|██████████| 5/5 [00:00<00:00, 11.92it/s, Loss=0.167]
Epoch 595/100010: 100%|██████████| 5/5 [00:00<00:00, 11.25it/s, Loss=0.152]
Epoch 596/100010: 100%|██████████| 5/5 [00:00<00:00, 11.71it/s, Loss=0.166]
Epoch 597/100010: 100%|██████████| 5/5 [00:00<00:00, 11.51it/s, Loss=0.164]
Epoch 598/100010: 100%|██████████| 5/5 [00:00<00:00, 12.08it/s, Loss=0.16] 
Epoch 599/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.162]
Epoch 600/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.161]
Epoch 601/100010: 100%|██████████| 5/5 [00:00<00:00, 11.92it/s, Loss=0.163]
Epoch 602/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.162]
Epoch 603/100010: 100%|██████████| 5/5 [00:00<00:00, 11.31it/s, Loss=0.17] 
Epoch 604/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.167]
Epoch 605/100010: 100%|██████████| 5/5 [00:00<00:00, 10.92it/s, Loss=0.161]
Epoch 606/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.153]
Epoch 607/10

Epoch 00719: reducing learning rate of group 0 to 6.5610e-04.


Epoch 720/100010: 100%|██████████| 5/5 [00:00<00:00, 11.67it/s, Loss=0.16] 
Epoch 721/100010: 100%|██████████| 5/5 [00:00<00:00, 10.81it/s, Loss=0.17] 
Epoch 722/100010: 100%|██████████| 5/5 [00:00<00:00, 11.47it/s, Loss=0.159]
Epoch 723/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.162]
Epoch 724/100010: 100%|██████████| 5/5 [00:00<00:00, 11.46it/s, Loss=0.162]
Epoch 725/100010: 100%|██████████| 5/5 [00:00<00:00, 11.68it/s, Loss=0.168]
Epoch 726/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.167]
Epoch 727/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.159]
Epoch 728/100010: 100%|██████████| 5/5 [00:00<00:00, 11.63it/s, Loss=0.166]
Epoch 729/100010: 100%|██████████| 5/5 [00:00<00:00, 11.08it/s, Loss=0.162]
Epoch 730/100010: 100%|██████████| 5/5 [00:00<00:00, 11.26it/s, Loss=0.167]
Epoch 731/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.16] 
Epoch 732/100010: 100%|██████████| 5/5 [00:00<00:00, 10.97it/s, Loss=0.159]
Epoch 733/10

Epoch 00770: reducing learning rate of group 0 to 5.9049e-04.


Epoch 771/100010: 100%|██████████| 5/5 [00:00<00:00, 11.04it/s, Loss=0.158]
Epoch 772/100010: 100%|██████████| 5/5 [00:00<00:00, 11.67it/s, Loss=0.162]
Epoch 773/100010: 100%|██████████| 5/5 [00:00<00:00, 11.27it/s, Loss=0.161]
Epoch 774/100010: 100%|██████████| 5/5 [00:00<00:00, 11.46it/s, Loss=0.159]
Epoch 775/100010: 100%|██████████| 5/5 [00:00<00:00, 13.25it/s, Loss=0.162]
Epoch 776/100010: 100%|██████████| 5/5 [00:00<00:00, 11.84it/s, Loss=0.158]
Epoch 777/100010: 100%|██████████| 5/5 [00:00<00:00, 11.17it/s, Loss=0.163]
Epoch 778/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.156]
Epoch 779/100010: 100%|██████████| 5/5 [00:00<00:00, 11.60it/s, Loss=0.16] 
Epoch 780/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.165]
Epoch 781/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.16] 
Epoch 782/100010: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s, Loss=0.161]
Epoch 783/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.164]
Epoch 784/10

Epoch 00821: reducing learning rate of group 0 to 5.3144e-04.


Epoch 822/100010: 100%|██████████| 5/5 [00:00<00:00, 10.29it/s, Loss=0.157]
Epoch 823/100010: 100%|██████████| 5/5 [00:00<00:00, 11.07it/s, Loss=0.165]
Epoch 824/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.164]
Epoch 825/100010: 100%|██████████| 5/5 [00:00<00:00, 11.28it/s, Loss=0.164]
Epoch 826/100010: 100%|██████████| 5/5 [00:00<00:00, 11.92it/s, Loss=0.165]
Epoch 827/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.161]
Epoch 828/100010: 100%|██████████| 5/5 [00:00<00:00, 11.72it/s, Loss=0.151]
Epoch 829/100010: 100%|██████████| 5/5 [00:00<00:00, 11.86it/s, Loss=0.162]
Epoch 830/100010: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s, Loss=0.16] 
Epoch 831/100010: 100%|██████████| 5/5 [00:00<00:00, 12.24it/s, Loss=0.159]
Epoch 832/100010: 100%|██████████| 5/5 [00:00<00:00, 11.27it/s, Loss=0.16] 
Epoch 833/100010: 100%|██████████| 5/5 [00:00<00:00, 11.22it/s, Loss=0.16] 
Epoch 834/100010: 100%|██████████| 5/5 [00:00<00:00, 10.92it/s, Loss=0.162]
Epoch 835/10

Epoch 00961: reducing learning rate of group 0 to 4.7830e-04.


Epoch 962/100010: 100%|██████████| 5/5 [00:00<00:00, 11.17it/s, Loss=0.16] 
Epoch 963/100010: 100%|██████████| 5/5 [00:00<00:00, 11.77it/s, Loss=0.153]
Epoch 964/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.16] 
Epoch 965/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.162]
Epoch 966/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.157]
Epoch 967/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.156]
Epoch 968/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.15] 
Epoch 969/100010: 100%|██████████| 5/5 [00:00<00:00, 10.64it/s, Loss=0.161]
Epoch 970/100010: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s, Loss=0.16] 
Epoch 971/100010: 100%|██████████| 5/5 [00:00<00:00, 11.08it/s, Loss=0.155]
Epoch 972/100010: 100%|██████████| 5/5 [00:00<00:00, 10.73it/s, Loss=0.158]
Epoch 973/100010: 100%|██████████| 5/5 [00:00<00:00, 11.02it/s, Loss=0.157]
Epoch 974/100010: 100%|██████████| 5/5 [00:00<00:00, 11.51it/s, Loss=0.157]
Epoch 975/10

Epoch 01057: reducing learning rate of group 0 to 4.3047e-04.


Epoch 1058/100010: 100%|██████████| 5/5 [00:00<00:00, 11.06it/s, Loss=0.169]
Epoch 1059/100010: 100%|██████████| 5/5 [00:00<00:00, 11.17it/s, Loss=0.153]
Epoch 1060/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.162]
Epoch 1061/100010: 100%|██████████| 5/5 [00:00<00:00, 11.54it/s, Loss=0.162]
Epoch 1062/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.165]
Epoch 1063/100010: 100%|██████████| 5/5 [00:00<00:00, 11.54it/s, Loss=0.16] 
Epoch 1064/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.159]
Epoch 1065/100010: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, Loss=0.162]
Epoch 1066/100010: 100%|██████████| 5/5 [00:00<00:00, 11.91it/s, Loss=0.164]
Epoch 1067/100010: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s, Loss=0.158]
Epoch 1068/100010: 100%|██████████| 5/5 [00:00<00:00, 11.49it/s, Loss=0.16] 
Epoch 1069/100010: 100%|██████████| 5/5 [00:00<00:00, 11.21it/s, Loss=0.159]
Epoch 1070/100010: 100%|██████████| 5/5 [00:00<00:00, 11.47it/s, Loss=0.162]

Epoch 01124: reducing learning rate of group 0 to 3.8742e-04.


Epoch 1125/100010: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, Loss=0.157]
Epoch 1126/100010: 100%|██████████| 5/5 [00:00<00:00, 11.09it/s, Loss=0.167]
Epoch 1127/100010: 100%|██████████| 5/5 [00:00<00:00, 11.13it/s, Loss=0.159]
Epoch 1128/100010: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, Loss=0.16] 
Epoch 1129/100010: 100%|██████████| 5/5 [00:00<00:00, 10.53it/s, Loss=0.153]
Epoch 1130/100010: 100%|██████████| 5/5 [00:00<00:00, 12.21it/s, Loss=0.161]
Epoch 1131/100010: 100%|██████████| 5/5 [00:00<00:00, 11.00it/s, Loss=0.158]
Epoch 1132/100010: 100%|██████████| 5/5 [00:00<00:00, 11.36it/s, Loss=0.151]
Epoch 1133/100010: 100%|██████████| 5/5 [00:00<00:00, 11.94it/s, Loss=0.161]
Epoch 1134/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.163]
Epoch 1135/100010: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s, Loss=0.149]
Epoch 1136/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.166]
Epoch 1137/100010: 100%|██████████| 5/5 [00:00<00:00, 10.78it/s, Loss=0.167]

Epoch 01175: reducing learning rate of group 0 to 3.4868e-04.


Epoch 1176/100010: 100%|██████████| 5/5 [00:00<00:00, 10.77it/s, Loss=0.155]
Epoch 1177/100010: 100%|██████████| 5/5 [00:00<00:00, 11.50it/s, Loss=0.159]
Epoch 1178/100010: 100%|██████████| 5/5 [00:00<00:00, 10.78it/s, Loss=0.155]
Epoch 1179/100010: 100%|██████████| 5/5 [00:00<00:00, 10.86it/s, Loss=0.156]
Epoch 1180/100010: 100%|██████████| 5/5 [00:00<00:00, 11.00it/s, Loss=0.168]
Epoch 1181/100010: 100%|██████████| 5/5 [00:00<00:00, 11.91it/s, Loss=0.156]
Epoch 1182/100010: 100%|██████████| 5/5 [00:00<00:00, 12.26it/s, Loss=0.156]
Epoch 1183/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.161]
Epoch 1184/100010: 100%|██████████| 5/5 [00:00<00:00, 11.45it/s, Loss=0.157]
Epoch 1185/100010: 100%|██████████| 5/5 [00:00<00:00, 11.21it/s, Loss=0.155]
Epoch 1186/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.164]
Epoch 1187/100010: 100%|██████████| 5/5 [00:00<00:00, 11.72it/s, Loss=0.16] 
Epoch 1188/100010: 100%|██████████| 5/5 [00:00<00:00, 12.30it/s, Loss=0.153]

Epoch 01226: reducing learning rate of group 0 to 3.1381e-04.


Epoch 1227/100010: 100%|██████████| 5/5 [00:00<00:00, 11.18it/s, Loss=0.156]
Epoch 1228/100010: 100%|██████████| 5/5 [00:00<00:00, 11.27it/s, Loss=0.147]
Epoch 1229/100010: 100%|██████████| 5/5 [00:00<00:00, 10.95it/s, Loss=0.151]
Epoch 1230/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.156]
Epoch 1231/100010: 100%|██████████| 5/5 [00:00<00:00, 11.66it/s, Loss=0.159]
Epoch 1232/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.157]
Epoch 1233/100010: 100%|██████████| 5/5 [00:00<00:00, 11.45it/s, Loss=0.154]
Epoch 1234/100010: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s, Loss=0.159]
Epoch 1235/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.155]
Epoch 1236/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.155]
Epoch 1237/100010: 100%|██████████| 5/5 [00:00<00:00, 11.77it/s, Loss=0.157]
Epoch 1238/100010: 100%|██████████| 5/5 [00:00<00:00, 11.34it/s, Loss=0.158]
Epoch 1239/100010: 100%|██████████| 5/5 [00:00<00:00, 12.46it/s, Loss=0.16] 

Epoch 01277: reducing learning rate of group 0 to 2.8243e-04.


Epoch 1278/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.158]
Epoch 1279/100010: 100%|██████████| 5/5 [00:00<00:00, 11.49it/s, Loss=0.159]
Epoch 1280/100010: 100%|██████████| 5/5 [00:00<00:00, 11.73it/s, Loss=0.162]
Epoch 1281/100010: 100%|██████████| 5/5 [00:00<00:00, 11.20it/s, Loss=0.154]
Epoch 1282/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.162]
Epoch 1283/100010: 100%|██████████| 5/5 [00:00<00:00, 11.77it/s, Loss=0.155]
Epoch 1284/100010: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s, Loss=0.152]
Epoch 1285/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.158]
Epoch 1286/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.157]
Epoch 1287/100010: 100%|██████████| 5/5 [00:00<00:00, 11.71it/s, Loss=0.159]
Epoch 1288/100010: 100%|██████████| 5/5 [00:00<00:00, 12.69it/s, Loss=0.159]
Epoch 1289/100010: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, Loss=0.162]
Epoch 1290/100010: 100%|██████████| 5/5 [00:00<00:00, 11.88it/s, Loss=0.154]

Epoch 01328: reducing learning rate of group 0 to 2.5419e-04.


Epoch 1329/100010: 100%|██████████| 5/5 [00:00<00:00, 11.22it/s, Loss=0.153]
Epoch 1330/100010: 100%|██████████| 5/5 [00:00<00:00, 11.12it/s, Loss=0.162]
Epoch 1331/100010: 100%|██████████| 5/5 [00:00<00:00, 10.45it/s, Loss=0.164]
Epoch 1332/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.161]
Epoch 1333/100010: 100%|██████████| 5/5 [00:00<00:00, 11.20it/s, Loss=0.163]
Epoch 1334/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.154]
Epoch 1335/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.157]
Epoch 1336/100010: 100%|██████████| 5/5 [00:00<00:00, 10.52it/s, Loss=0.16] 
Epoch 1337/100010: 100%|██████████| 5/5 [00:00<00:00, 11.70it/s, Loss=0.157]
Epoch 1338/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.159]
Epoch 1339/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.161]
Epoch 1340/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.159]
Epoch 1341/100010: 100%|██████████| 5/5 [00:00<00:00, 10.54it/s, Loss=0.156]

Epoch 01417: reducing learning rate of group 0 to 2.2877e-04.


Epoch 1418/100010: 100%|██████████| 5/5 [00:00<00:00, 10.69it/s, Loss=0.155]
Epoch 1419/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.158]
Epoch 1420/100010: 100%|██████████| 5/5 [00:00<00:00, 11.00it/s, Loss=0.153]
Epoch 1421/100010: 100%|██████████| 5/5 [00:00<00:00, 10.98it/s, Loss=0.157]
Epoch 1422/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.158]
Epoch 1423/100010: 100%|██████████| 5/5 [00:00<00:00, 11.34it/s, Loss=0.154]
Epoch 1424/100010: 100%|██████████| 5/5 [00:00<00:00, 12.07it/s, Loss=0.161]
Epoch 1425/100010: 100%|██████████| 5/5 [00:00<00:00, 11.78it/s, Loss=0.159]
Epoch 1426/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.159]
Epoch 1427/100010: 100%|██████████| 5/5 [00:00<00:00, 11.51it/s, Loss=0.158]
Epoch 1428/100010: 100%|██████████| 5/5 [00:00<00:00, 12.08it/s, Loss=0.16] 
Epoch 1429/100010: 100%|██████████| 5/5 [00:00<00:00, 11.39it/s, Loss=0.158]
Epoch 1430/100010: 100%|██████████| 5/5 [00:00<00:00, 11.26it/s, Loss=0.16] 

Epoch 01468: reducing learning rate of group 0 to 2.0589e-04.


Epoch 1469/100010: 100%|██████████| 5/5 [00:00<00:00, 11.94it/s, Loss=0.163]
Epoch 1470/100010: 100%|██████████| 5/5 [00:00<00:00,  9.29it/s, Loss=0.158]
Epoch 1471/100010: 100%|██████████| 5/5 [00:00<00:00,  9.08it/s, Loss=0.151]
Epoch 1472/100010: 100%|██████████| 5/5 [00:00<00:00,  8.78it/s, Loss=0.152]
Epoch 1473/100010: 100%|██████████| 5/5 [00:00<00:00,  8.93it/s, Loss=0.163]
Epoch 1474/100010: 100%|██████████| 5/5 [00:00<00:00,  9.05it/s, Loss=0.154]
Epoch 1475/100010: 100%|██████████| 5/5 [00:00<00:00,  8.88it/s, Loss=0.159]
Epoch 1476/100010: 100%|██████████| 5/5 [00:00<00:00, 10.66it/s, Loss=0.156]
Epoch 1477/100010: 100%|██████████| 5/5 [00:00<00:00, 10.95it/s, Loss=0.15] 
Epoch 1478/100010: 100%|██████████| 5/5 [00:00<00:00, 10.34it/s, Loss=0.159]
Epoch 1479/100010: 100%|██████████| 5/5 [00:00<00:00, 11.46it/s, Loss=0.152]
Epoch 1480/100010: 100%|██████████| 5/5 [00:00<00:00, 11.82it/s, Loss=0.16] 
Epoch 1481/100010: 100%|██████████| 5/5 [00:00<00:00, 10.81it/s, Loss=0.161]

Epoch 01523: reducing learning rate of group 0 to 1.8530e-04.


Epoch 1524/100010: 100%|██████████| 5/5 [00:00<00:00, 11.66it/s, Loss=0.157]
Epoch 1525/100010: 100%|██████████| 5/5 [00:00<00:00, 12.50it/s, Loss=0.15] 
Epoch 1526/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.158]
Epoch 1527/100010: 100%|██████████| 5/5 [00:00<00:00, 11.85it/s, Loss=0.156]
Epoch 1528/100010: 100%|██████████| 5/5 [00:00<00:00, 11.49it/s, Loss=0.154]
Epoch 1529/100010: 100%|██████████| 5/5 [00:00<00:00, 10.81it/s, Loss=0.158]
Epoch 1530/100010: 100%|██████████| 5/5 [00:00<00:00, 11.69it/s, Loss=0.16] 
Epoch 1531/100010: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, Loss=0.16] 
Epoch 1532/100010: 100%|██████████| 5/5 [00:00<00:00, 11.03it/s, Loss=0.157]
Epoch 1533/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.158]
Epoch 1534/100010: 100%|██████████| 5/5 [00:00<00:00, 11.91it/s, Loss=0.154]
Epoch 1535/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.155]
Epoch 1536/100010: 100%|██████████| 5/5 [00:00<00:00, 11.52it/s, Loss=0.157]

Epoch 01574: reducing learning rate of group 0 to 1.6677e-04.


Epoch 1575/100010: 100%|██████████| 5/5 [00:00<00:00, 11.21it/s, Loss=0.15] 
Epoch 1576/100010: 100%|██████████| 5/5 [00:00<00:00, 11.11it/s, Loss=0.158]
Epoch 1577/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.157]
Epoch 1578/100010: 100%|██████████| 5/5 [00:00<00:00, 10.64it/s, Loss=0.156]
Epoch 1579/100010: 100%|██████████| 5/5 [00:00<00:00, 11.27it/s, Loss=0.161]
Epoch 1580/100010: 100%|██████████| 5/5 [00:00<00:00, 12.14it/s, Loss=0.156]
Epoch 1581/100010: 100%|██████████| 5/5 [00:00<00:00, 12.49it/s, Loss=0.154]
Epoch 1582/100010: 100%|██████████| 5/5 [00:00<00:00, 11.22it/s, Loss=0.158]
Epoch 1583/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.152]
Epoch 1584/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.163]
Epoch 1585/100010: 100%|██████████| 5/5 [00:00<00:00, 11.48it/s, Loss=0.153]
Epoch 1586/100010: 100%|██████████| 5/5 [00:00<00:00, 10.95it/s, Loss=0.15] 
Epoch 1587/100010: 100%|██████████| 5/5 [00:00<00:00, 11.93it/s, Loss=0.161]

Epoch 01667: reducing learning rate of group 0 to 1.5009e-04.


Epoch 1668/100010: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s, Loss=0.159]
Epoch 1669/100010: 100%|██████████| 5/5 [00:00<00:00, 11.48it/s, Loss=0.158]
Epoch 1670/100010: 100%|██████████| 5/5 [00:00<00:00, 11.02it/s, Loss=0.153]
Epoch 1671/100010: 100%|██████████| 5/5 [00:00<00:00, 10.93it/s, Loss=0.152]
Epoch 1672/100010: 100%|██████████| 5/5 [00:00<00:00, 11.53it/s, Loss=0.154]
Epoch 1673/100010: 100%|██████████| 5/5 [00:00<00:00, 11.61it/s, Loss=0.163]
Epoch 1674/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.15] 
Epoch 1675/100010: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s, Loss=0.157]
Epoch 1676/100010: 100%|██████████| 5/5 [00:00<00:00, 10.82it/s, Loss=0.152]
Epoch 1677/100010: 100%|██████████| 5/5 [00:00<00:00, 11.69it/s, Loss=0.16] 
Epoch 1678/100010: 100%|██████████| 5/5 [00:00<00:00, 10.84it/s, Loss=0.162]
Epoch 1679/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.156]
Epoch 1680/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.152]

Epoch 01718: reducing learning rate of group 0 to 1.3509e-04.


Epoch 1719/100010: 100%|██████████| 5/5 [00:00<00:00, 10.58it/s, Loss=0.158]
Epoch 1720/100010: 100%|██████████| 5/5 [00:00<00:00, 11.76it/s, Loss=0.157]
Epoch 1721/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.156]
Epoch 1722/100010: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, Loss=0.16] 
Epoch 1723/100010: 100%|██████████| 5/5 [00:00<00:00, 11.33it/s, Loss=0.157]
Epoch 1724/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.162]
Epoch 1725/100010: 100%|██████████| 5/5 [00:00<00:00, 11.12it/s, Loss=0.155]
Epoch 1726/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.159]
Epoch 1727/100010: 100%|██████████| 5/5 [00:00<00:00, 10.09it/s, Loss=0.147]
Epoch 1728/100010: 100%|██████████| 5/5 [00:00<00:00, 12.19it/s, Loss=0.154]
Epoch 1729/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.149]
Epoch 1730/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.157]
Epoch 1731/100010: 100%|██████████| 5/5 [00:00<00:00, 10.10it/s, Loss=0.16] 

Epoch 01769: reducing learning rate of group 0 to 1.2158e-04.


Epoch 1770/100010: 100%|██████████| 5/5 [00:00<00:00, 10.84it/s, Loss=0.156]
Epoch 1771/100010: 100%|██████████| 5/5 [00:00<00:00, 12.68it/s, Loss=0.159]
Epoch 1772/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.158]
Epoch 1773/100010: 100%|██████████| 5/5 [00:00<00:00, 10.44it/s, Loss=0.159]
Epoch 1774/100010: 100%|██████████| 5/5 [00:00<00:00, 11.74it/s, Loss=0.152]
Epoch 1775/100010: 100%|██████████| 5/5 [00:00<00:00, 11.18it/s, Loss=0.161]
Epoch 1776/100010: 100%|██████████| 5/5 [00:00<00:00, 11.40it/s, Loss=0.161]
Epoch 1777/100010: 100%|██████████| 5/5 [00:00<00:00, 11.37it/s, Loss=0.162]
Epoch 1778/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.152]
Epoch 1779/100010: 100%|██████████| 5/5 [00:00<00:00, 10.83it/s, Loss=0.162]
Epoch 1780/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.155]
Epoch 1781/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.153]
Epoch 1782/100010: 100%|██████████| 5/5 [00:00<00:00, 11.09it/s, Loss=0.158]

Epoch 01820: reducing learning rate of group 0 to 1.0942e-04.


Epoch 1821/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.161]
Epoch 1822/100010: 100%|██████████| 5/5 [00:00<00:00, 12.14it/s, Loss=0.154]
Epoch 1823/100010: 100%|██████████| 5/5 [00:00<00:00, 11.26it/s, Loss=0.159]
Epoch 1824/100010: 100%|██████████| 5/5 [00:00<00:00, 10.60it/s, Loss=0.165]
Epoch 1825/100010: 100%|██████████| 5/5 [00:00<00:00, 12.08it/s, Loss=0.157]
Epoch 1826/100010: 100%|██████████| 5/5 [00:00<00:00, 11.53it/s, Loss=0.166]
Epoch 1827/100010: 100%|██████████| 5/5 [00:00<00:00, 11.77it/s, Loss=0.152]
Epoch 1828/100010: 100%|██████████| 5/5 [00:00<00:00, 11.39it/s, Loss=0.157]
Epoch 1829/100010: 100%|██████████| 5/5 [00:00<00:00, 11.50it/s, Loss=0.158]
Epoch 1830/100010: 100%|██████████| 5/5 [00:00<00:00, 11.08it/s, Loss=0.153]
Epoch 1831/100010: 100%|██████████| 5/5 [00:00<00:00, 11.40it/s, Loss=0.153]
Epoch 1832/100010: 100%|██████████| 5/5 [00:00<00:00, 11.74it/s, Loss=0.148]
Epoch 1833/100010: 100%|██████████| 5/5 [00:00<00:00, 11.10it/s, Loss=0.152]

Epoch 01871: reducing learning rate of group 0 to 9.8477e-05.


Epoch 1872/100010: 100%|██████████| 5/5 [00:00<00:00, 10.93it/s, Loss=0.158]
Epoch 1873/100010: 100%|██████████| 5/5 [00:00<00:00, 11.45it/s, Loss=0.155]
Epoch 1874/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.153]
Epoch 1875/100010: 100%|██████████| 5/5 [00:00<00:00, 11.48it/s, Loss=0.159]
Epoch 1876/100010: 100%|██████████| 5/5 [00:00<00:00, 11.07it/s, Loss=0.162]
Epoch 1877/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.157]
Epoch 1878/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.156]
Epoch 1879/100010: 100%|██████████| 5/5 [00:00<00:00, 11.56it/s, Loss=0.157]
Epoch 1880/100010: 100%|██████████| 5/5 [00:00<00:00, 12.14it/s, Loss=0.155]
Epoch 1881/100010: 100%|██████████| 5/5 [00:00<00:00, 10.55it/s, Loss=0.153]
Epoch 1882/100010: 100%|██████████| 5/5 [00:00<00:00, 10.71it/s, Loss=0.158]
Epoch 1883/100010: 100%|██████████| 5/5 [00:00<00:00, 11.98it/s, Loss=0.156]
Epoch 1884/100010: 100%|██████████| 5/5 [00:00<00:00, 11.00it/s, Loss=0.152]

Epoch 01974: reducing learning rate of group 0 to 8.8629e-05.


Epoch 1975/100010: 100%|██████████| 5/5 [00:00<00:00, 11.13it/s, Loss=0.159]
Epoch 1976/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.159]
Epoch 1977/100010: 100%|██████████| 5/5 [00:00<00:00, 10.74it/s, Loss=0.153]
Epoch 1978/100010: 100%|██████████| 5/5 [00:00<00:00, 11.53it/s, Loss=0.147]
Epoch 1979/100010: 100%|██████████| 5/5 [00:00<00:00, 11.61it/s, Loss=0.149]
Epoch 1980/100010: 100%|██████████| 5/5 [00:00<00:00, 11.71it/s, Loss=0.156]
Epoch 1981/100010: 100%|██████████| 5/5 [00:00<00:00, 11.28it/s, Loss=0.164]
Epoch 1982/100010: 100%|██████████| 5/5 [00:00<00:00, 11.44it/s, Loss=0.158]
Epoch 1983/100010: 100%|██████████| 5/5 [00:00<00:00, 11.79it/s, Loss=0.155]
Epoch 1984/100010: 100%|██████████| 5/5 [00:00<00:00, 11.46it/s, Loss=0.152]
Epoch 1985/100010: 100%|██████████| 5/5 [00:00<00:00, 11.45it/s, Loss=0.157]
Epoch 1986/100010: 100%|██████████| 5/5 [00:00<00:00, 11.70it/s, Loss=0.153]
Epoch 1987/100010: 100%|██████████| 5/5 [00:00<00:00, 11.28it/s, Loss=0.158]

Epoch 02025: reducing learning rate of group 0 to 7.9766e-05.


Epoch 2026/100010: 100%|██████████| 5/5 [00:00<00:00, 11.24it/s, Loss=0.157]
Epoch 2027/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.155]
Epoch 2028/100010: 100%|██████████| 5/5 [00:00<00:00, 11.15it/s, Loss=0.16] 
Epoch 2029/100010: 100%|██████████| 5/5 [00:00<00:00, 11.71it/s, Loss=0.157]
Epoch 2030/100010: 100%|██████████| 5/5 [00:00<00:00, 11.76it/s, Loss=0.151]
Epoch 2031/100010: 100%|██████████| 5/5 [00:00<00:00, 11.70it/s, Loss=0.161]
Epoch 2032/100010: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s, Loss=0.157]
Epoch 2033/100010: 100%|██████████| 5/5 [00:00<00:00, 11.99it/s, Loss=0.155]
Epoch 2034/100010: 100%|██████████| 5/5 [00:00<00:00, 11.16it/s, Loss=0.156]
Epoch 2035/100010: 100%|██████████| 5/5 [00:00<00:00, 12.00it/s, Loss=0.151]
Epoch 2036/100010: 100%|██████████| 5/5 [00:00<00:00, 10.89it/s, Loss=0.154]
Epoch 2037/100010: 100%|██████████| 5/5 [00:00<00:00, 11.60it/s, Loss=0.15] 
Epoch 2038/100010: 100%|██████████| 5/5 [00:00<00:00, 10.70it/s, Loss=0.153]

Epoch 02076: reducing learning rate of group 0 to 7.1790e-05.


Epoch 2077/100010: 100%|██████████| 5/5 [00:00<00:00, 10.85it/s, Loss=0.162]
Epoch 2078/100010: 100%|██████████| 5/5 [00:00<00:00, 11.75it/s, Loss=0.155]
Epoch 2079/100010: 100%|██████████| 5/5 [00:00<00:00, 11.37it/s, Loss=0.146]
Epoch 2080/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.157]
Epoch 2081/100010: 100%|██████████| 5/5 [00:00<00:00, 11.90it/s, Loss=0.154]
Epoch 2082/100010: 100%|██████████| 5/5 [00:00<00:00, 11.19it/s, Loss=0.162]
Epoch 2083/100010: 100%|██████████| 5/5 [00:00<00:00, 11.04it/s, Loss=0.159]
Epoch 2084/100010: 100%|██████████| 5/5 [00:00<00:00, 10.89it/s, Loss=0.154]
Epoch 2085/100010: 100%|██████████| 5/5 [00:00<00:00, 11.35it/s, Loss=0.155]
Epoch 2086/100010: 100%|██████████| 5/5 [00:00<00:00, 10.69it/s, Loss=0.159]
Epoch 2087/100010: 100%|██████████| 5/5 [00:00<00:00, 11.48it/s, Loss=0.165]
Epoch 2088/100010: 100%|██████████| 5/5 [00:00<00:00, 11.08it/s, Loss=0.155]
Epoch 2089/100010: 100%|██████████| 5/5 [00:00<00:00, 11.90it/s, Loss=0.156]

Epoch 02127: reducing learning rate of group 0 to 6.4611e-05.


Epoch 2128/100010: 100%|██████████| 5/5 [00:00<00:00, 10.57it/s, Loss=0.162]
Epoch 2129/100010: 100%|██████████| 5/5 [00:00<00:00, 11.51it/s, Loss=0.16] 
Epoch 2130/100010: 100%|██████████| 5/5 [00:00<00:00, 11.10it/s, Loss=0.152]
Epoch 2131/100010: 100%|██████████| 5/5 [00:00<00:00, 11.14it/s, Loss=0.158]
Epoch 2132/100010: 100%|██████████| 5/5 [00:00<00:00, 12.43it/s, Loss=0.15] 
Epoch 2133/100010: 100%|██████████| 5/5 [00:00<00:00, 11.22it/s, Loss=0.152]
Epoch 2134/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.153]
Epoch 2135/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.156]
Epoch 2136/100010: 100%|██████████| 5/5 [00:00<00:00, 11.01it/s, Loss=0.162]
Epoch 2137/100010: 100%|██████████| 5/5 [00:00<00:00, 11.37it/s, Loss=0.159]
Epoch 2138/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.159]
Epoch 2139/100010: 100%|██████████| 5/5 [00:00<00:00, 11.56it/s, Loss=0.155]
Epoch 2140/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.153]

Epoch 02201: reducing learning rate of group 0 to 5.8150e-05.


Epoch 2202/100010: 100%|██████████| 5/5 [00:00<00:00, 11.81it/s, Loss=0.157]
Epoch 2203/100010: 100%|██████████| 5/5 [00:00<00:00, 10.68it/s, Loss=0.158]
Epoch 2204/100010: 100%|██████████| 5/5 [00:00<00:00, 12.79it/s, Loss=0.15] 
Epoch 2205/100010: 100%|██████████| 5/5 [00:00<00:00, 11.23it/s, Loss=0.152]
Epoch 2206/100010: 100%|██████████| 5/5 [00:00<00:00, 11.12it/s, Loss=0.156]
Epoch 2207/100010: 100%|██████████| 5/5 [00:00<00:00, 11.98it/s, Loss=0.16] 
Epoch 2208/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.15] 
Epoch 2209/100010: 100%|██████████| 5/5 [00:00<00:00, 11.40it/s, Loss=0.156]
Epoch 2210/100010: 100%|██████████| 5/5 [00:00<00:00, 12.12it/s, Loss=0.153]
Epoch 2211/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.159]
Epoch 2212/100010: 100%|██████████| 5/5 [00:00<00:00, 11.44it/s, Loss=0.151]
Epoch 2213/100010: 100%|██████████| 5/5 [00:00<00:00, 11.85it/s, Loss=0.152]
Epoch 2214/100010: 100%|██████████| 5/5 [00:00<00:00, 11.77it/s, Loss=0.153]

Epoch 02252: reducing learning rate of group 0 to 5.2335e-05.


Epoch 2253/100010: 100%|██████████| 5/5 [00:00<00:00, 11.46it/s, Loss=0.158]
Epoch 2254/100010: 100%|██████████| 5/5 [00:00<00:00, 11.65it/s, Loss=0.157]
Epoch 2255/100010: 100%|██████████| 5/5 [00:00<00:00, 11.90it/s, Loss=0.152]
Epoch 2256/100010: 100%|██████████| 5/5 [00:00<00:00, 11.24it/s, Loss=0.154]
Epoch 2257/100010: 100%|██████████| 5/5 [00:00<00:00, 11.59it/s, Loss=0.156]
Epoch 2258/100010: 100%|██████████| 5/5 [00:00<00:00, 11.93it/s, Loss=0.153]
Epoch 2259/100010: 100%|██████████| 5/5 [00:00<00:00,  9.10it/s, Loss=0.159]
Epoch 2260/100010: 100%|██████████| 5/5 [00:00<00:00, 11.18it/s, Loss=0.158]
Epoch 2261/100010: 100%|██████████| 5/5 [00:00<00:00,  9.13it/s, Loss=0.162]
Epoch 2262/100010: 100%|██████████| 5/5 [00:00<00:00, 11.09it/s, Loss=0.155]
Epoch 2263/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.152]
Epoch 2264/100010: 100%|██████████| 5/5 [00:00<00:00, 10.76it/s, Loss=0.162]
Epoch 2265/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.159]

Epoch 02303: reducing learning rate of group 0 to 4.7101e-05.


Epoch 2304/100010: 100%|██████████| 5/5 [00:00<00:00, 12.27it/s, Loss=0.159]
Epoch 2305/100010: 100%|██████████| 5/5 [00:00<00:00, 11.08it/s, Loss=0.156]
Epoch 2306/100010: 100%|██████████| 5/5 [00:00<00:00, 11.98it/s, Loss=0.158]
Epoch 2307/100010: 100%|██████████| 5/5 [00:00<00:00, 11.97it/s, Loss=0.151]
Epoch 2308/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.162]
Epoch 2309/100010: 100%|██████████| 5/5 [00:00<00:00, 12.31it/s, Loss=0.161]
Epoch 2310/100010: 100%|██████████| 5/5 [00:00<00:00, 12.38it/s, Loss=0.149]
Epoch 2311/100010: 100%|██████████| 5/5 [00:00<00:00, 11.58it/s, Loss=0.159]
Epoch 2312/100010: 100%|██████████| 5/5 [00:00<00:00, 12.06it/s, Loss=0.157]
Epoch 2313/100010: 100%|██████████| 5/5 [00:00<00:00, 11.67it/s, Loss=0.148]
Epoch 2314/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.157]
Epoch 2315/100010: 100%|██████████| 5/5 [00:00<00:00, 11.30it/s, Loss=0.159]
Epoch 2316/100010: 100%|██████████| 5/5 [00:00<00:00, 11.60it/s, Loss=0.16] 

Epoch 02354: reducing learning rate of group 0 to 4.2391e-05.


Epoch 2355/100010: 100%|██████████| 5/5 [00:00<00:00, 11.29it/s, Loss=0.157]
Epoch 2356/100010: 100%|██████████| 5/5 [00:00<00:00, 11.11it/s, Loss=0.161]
Epoch 2357/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.154]
Epoch 2358/100010: 100%|██████████| 5/5 [00:00<00:00, 11.22it/s, Loss=0.155]
Epoch 2359/100010: 100%|██████████| 5/5 [00:00<00:00, 12.09it/s, Loss=0.158]
Epoch 2360/100010: 100%|██████████| 5/5 [00:00<00:00, 11.95it/s, Loss=0.148]
Epoch 2361/100010: 100%|██████████| 5/5 [00:00<00:00, 10.68it/s, Loss=0.152]
Epoch 2362/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.155]
Epoch 2363/100010: 100%|██████████| 5/5 [00:00<00:00, 10.59it/s, Loss=0.159]
Epoch 2364/100010: 100%|██████████| 5/5 [00:00<00:00, 11.57it/s, Loss=0.162]
Epoch 2365/100010: 100%|██████████| 5/5 [00:00<00:00, 11.40it/s, Loss=0.156]
Epoch 2366/100010: 100%|██████████| 5/5 [00:00<00:00, 10.92it/s, Loss=0.146]
Epoch 2367/100010: 100%|██████████| 5/5 [00:00<00:00, 11.58it/s, Loss=0.157]

Epoch 02405: reducing learning rate of group 0 to 3.8152e-05.


Epoch 2406/100010: 100%|██████████| 5/5 [00:00<00:00, 11.35it/s, Loss=0.154]
Epoch 2407/100010: 100%|██████████| 5/5 [00:00<00:00, 10.72it/s, Loss=0.157]
Epoch 2408/100010: 100%|██████████| 5/5 [00:00<00:00, 12.41it/s, Loss=0.155]
Epoch 2409/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.154]
Epoch 2410/100010: 100%|██████████| 5/5 [00:00<00:00, 11.32it/s, Loss=0.165]
Epoch 2411/100010: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, Loss=0.153]
Epoch 2412/100010: 100%|██████████| 5/5 [00:00<00:00, 11.20it/s, Loss=0.147]
Epoch 2413/100010: 100%|██████████| 5/5 [00:00<00:00, 11.83it/s, Loss=0.158]
Epoch 2414/100010: 100%|██████████| 5/5 [00:00<00:00, 10.50it/s, Loss=0.161]
Epoch 2415/100010: 100%|██████████| 5/5 [00:00<00:00, 11.52it/s, Loss=0.16] 
Epoch 2416/100010: 100%|██████████| 5/5 [00:00<00:00, 10.88it/s, Loss=0.153]
Epoch 2417/100010: 100%|██████████| 5/5 [00:00<00:00, 12.43it/s, Loss=0.145]
Epoch 2418/100010: 100%|██████████| 5/5 [00:00<00:00, 11.34it/s, Loss=0.159]

Epoch 02456: reducing learning rate of group 0 to 3.4337e-05.


Epoch 2457/100010: 100%|██████████| 5/5 [00:00<00:00, 11.41it/s, Loss=0.153]
Epoch 2458/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.156]
Epoch 2459/100010: 100%|██████████| 5/5 [00:00<00:00, 11.54it/s, Loss=0.154]
Epoch 2460/100010: 100%|██████████| 5/5 [00:00<00:00, 11.89it/s, Loss=0.149]
Epoch 2461/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.163]
Epoch 2462/100010: 100%|██████████| 5/5 [00:00<00:00, 11.67it/s, Loss=0.15] 
Epoch 2463/100010: 100%|██████████| 5/5 [00:00<00:00, 11.03it/s, Loss=0.154]
Epoch 2464/100010: 100%|██████████| 5/5 [00:00<00:00, 11.89it/s, Loss=0.153]
Epoch 2465/100010: 100%|██████████| 5/5 [00:00<00:00, 10.60it/s, Loss=0.156]
Epoch 2466/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.162]
Epoch 2467/100010: 100%|██████████| 5/5 [00:00<00:00, 10.98it/s, Loss=0.152]
Epoch 2468/100010: 100%|██████████| 5/5 [00:00<00:00, 10.46it/s, Loss=0.159]
Epoch 2469/100010: 100%|██████████| 5/5 [00:00<00:00, 13.14it/s, Loss=0.153]

Epoch 02507: reducing learning rate of group 0 to 3.0903e-05.


Epoch 2508/100010: 100%|██████████| 5/5 [00:00<00:00, 11.33it/s, Loss=0.151]
Epoch 2509/100010: 100%|██████████| 5/5 [00:00<00:00, 10.93it/s, Loss=0.15] 
Epoch 2510/100010: 100%|██████████| 5/5 [00:00<00:00, 11.07it/s, Loss=0.151]
Epoch 2511/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.154]
Epoch 2512/100010: 100%|██████████| 5/5 [00:00<00:00, 10.50it/s, Loss=0.154]
Epoch 2513/100010: 100%|██████████| 5/5 [00:00<00:00,  9.29it/s, Loss=0.152]
Epoch 2514/100010: 100%|██████████| 5/5 [00:00<00:00, 11.62it/s, Loss=0.16] 
Epoch 2515/100010: 100%|██████████| 5/5 [00:00<00:00, 11.54it/s, Loss=0.158]
Epoch 2516/100010: 100%|██████████| 5/5 [00:00<00:00, 11.06it/s, Loss=0.162]
Epoch 2517/100010: 100%|██████████| 5/5 [00:00<00:00, 10.84it/s, Loss=0.164]
Epoch 2518/100010: 100%|██████████| 5/5 [00:00<00:00, 10.34it/s, Loss=0.152]
Epoch 2519/100010: 100%|██████████| 5/5 [00:00<00:00, 10.70it/s, Loss=0.16] 
Epoch 2520/100010: 100%|██████████| 5/5 [00:00<00:00, 11.91it/s, Loss=0.153]

Epoch 02558: reducing learning rate of group 0 to 2.7813e-05.


Epoch 2559/100010: 100%|██████████| 5/5 [00:00<00:00, 11.55it/s, Loss=0.153]
Epoch 2560/100010: 100%|██████████| 5/5 [00:00<00:00, 11.86it/s, Loss=0.155]
Epoch 2561/100010: 100%|██████████| 5/5 [00:00<00:00, 10.97it/s, Loss=0.158]
Epoch 2562/100010: 100%|██████████| 5/5 [00:00<00:00, 11.06it/s, Loss=0.155]
Epoch 2563/100010: 100%|██████████| 5/5 [00:00<00:00, 11.43it/s, Loss=0.156]
Epoch 2564/100010: 100%|██████████| 5/5 [00:00<00:00, 11.10it/s, Loss=0.15] 
Epoch 2565/100010: 100%|██████████| 5/5 [00:00<00:00, 11.45it/s, Loss=0.161]
Epoch 2566/100010: 100%|██████████| 5/5 [00:00<00:00, 11.64it/s, Loss=0.159]
Epoch 2567/100010: 100%|██████████| 5/5 [00:00<00:00, 11.69it/s, Loss=0.145]
Epoch 2568/100010: 100%|██████████| 5/5 [00:00<00:00, 11.42it/s, Loss=0.158]
Epoch 2569/100010: 100%|██████████| 5/5 [00:00<00:00, 11.71it/s, Loss=0.151]
Epoch 2570/100010: 100%|██████████| 5/5 [00:00<00:00, 11.92it/s, Loss=0.153]
Epoch 2571/100010: 100%|██████████| 5/5 [00:00<00:00, 12.05it/s, Loss=0.155]

Epoch 02609: reducing learning rate of group 0 to 2.5032e-05.


Epoch 2610/100010: 100%|██████████| 5/5 [00:00<00:00, 11.38it/s, Loss=0.162]
Epoch 2611/100010: 100%|██████████| 5/5 [00:00<00:00, 11.63it/s, Loss=0.16] 
Epoch 2612/100010: 100%|██████████| 5/5 [00:00<00:00, 10.94it/s, Loss=0.158]
Epoch 2613/100010: 100%|██████████| 5/5 [00:00<00:00, 11.80it/s, Loss=0.148]
Epoch 2614/100010: 100%|██████████| 5/5 [00:00<00:00, 11.05it/s, Loss=0.151]
Epoch 2615/100010: 100%|██████████| 5/5 [00:00<00:00, 11.47it/s, Loss=0.153]
Epoch 2616/100010: 100%|██████████| 5/5 [00:00<00:00, 11.36it/s, Loss=0.154]
Epoch 2617/100010: 100%|██████████| 5/5 [00:00<00:00, 12.11it/s, Loss=0.153]
Epoch 2618/100010: 100%|██████████| 5/5 [00:00<00:00, 11.27it/s, Loss=0.15] 
Epoch 2619/100010: 100%|██████████| 5/5 [00:00<00:00, 11.01it/s, Loss=0.156]
Epoch 2620/100010: 100%|██████████| 5/5 [00:00<00:00, 11.69it/s, Loss=0.151]
Epoch 2621/100010: 100%|██████████| 5/5 [00:00<00:00, 11.47it/s, Loss=0.16] 
Epoch 2622/100010: 100%|██████████| 5/5 [00:00<00:00, 10.91it/s, Loss=0.158]

Early stopping
Time:  1213.4014220237732





## Load Pretrained Model

Instead of training model from scratch, we can also load weights of a pre-trained model from a given checkpoint with `load_model_state` function.
If we haven't instantiated the VAE and diffusion model beforehand, we need to instantiate them first using `instantiate_vae` and `instantiate_diffusion` methods.

In [None]:
latent_embeddings_path = os.path.join(MODEL_PATH, DATA_NAME, "vae")
pretrained_model_path = os.path.join(MODEL_PATH, DATA_NAME)

# instantiate VAE model
tabsyn.instantiate_vae(**raw_config["model_params"], optim_params=None)

# load latent embeddings of input data
train_z, token_dim = tabsyn.load_latent_embeddings(latent_embeddings_path)

# instantiate diffusion model
tabsyn.instantiate_diffusion(
    in_dim=train_z.shape[1], hid_dim=train_z.shape[1], optim_params=None
)

# load state from checkpoint
tabsyn.load_model_state(ckpt_dir=pretrained_model_path, dif_ckpt_name="model.pt")

## Sample Data

Now that we trained the model effectively, using `sample` function we can generate synthetic data starting from compelete noise. The input of this function is as follows:

1. `train_z`: latent embeddings of the training data.
2. `info`: info about the data from the json file we reviewed at the beginning of this notebook.
3. `num_inverse`: detokenizer for numerical features.
4. `cat_inverse`: detokenizer for categorical features.
5. `save_path`: file-path where the synthetic table will be saved.

In [None]:
# load data info file
with open(os.path.join(PROCESSED_DATA_DIR, DATA_NAME, "info.json"), "r") as file:
    data_info = json.load(file)
data_info["token_dim"] = token_dim

# get inverse tokenizers
_, _, categories, d_numerical, num_inverse, cat_inverse = preprocess(
    os.path.join(PROCESSED_DATA_DIR, DATA_NAME),
    ref_dataset_path=REF_DATA_PATH,
    transforms=raw_config["transforms"],
    task_type=raw_config["task_type"],
    inverse=True,
)

os.makedirs(os.path.join(SYNTH_DATA_DIR, DATA_NAME), exist_ok=True)

# sample data
num_samples = train_z.shape[0]
in_dim = train_z.shape[1]
mean_input_emb = train_z.mean(0)
tabsyn.sample(
    num_samples,
    in_dim,
    mean_input_emb,
    info=data_info,
    num_inverse=num_inverse,
    cat_inverse=cat_inverse,
    save_path=os.path.join(SYNTH_DATA_DIR, DATA_NAME, "tabsyn.csv"),
)

## Review Synthetic Data

Finally here, we review the synthesized data. In the following `evaluate_synthetic_data.ipynb` notebook, we will evaluate this synthesized data with respect to various metrics.

In [None]:
df = pd.read_csv(os.path.join(SYNTH_DATA_DIR, DATA_NAME, "tabsyn.csv"))
df.head(10)

## References

**Zhang, Hengrui, et al.** "Mixed-type tabular data synthesis with score-based diffusion in latent space." *International Conference on Learning Representations (ICLR)* (2023).

**GitHub Repository:** [Amazon Science - Tabsyn](https://github.com/amazon-science/tabsyn)