In [1]:
import shutil
import os
import json
import torch
import numpy as np
import perceval as pcvl
import torchvision.transforms as transforms

from torch.utils.data import RandomSampler
from tqdm.notebook import tqdm

import sys; sys.path.insert(0, '..')
from models.qgan import QGAN
from helpers.data.digits import DigitsDataset

In [2]:
# ---- experiment-wide settings -------------------------------------------
# image_size and lossy are forwarded positionally to the QGAN constructor;
# batch_size additionally sizes the sampler and the DataLoader batches.
image_size = 8
batch_size = 4
lossy = True
# When False, the training loop skips the np.savetxt dumps of each run.
write_to_disk = True

# ---- optimisation settings ----------------------------------------------
# opt_iter_num * batch_size is the number of samples drawn by the sampler,
# so one pass over the dataloader yields opt_iter_num batches.
spsa_iter_num = 10_500
opt_iter_num = 1_500
# Second positional argument of qgan.fit (presumably the discriminator
# learning rate -- confirm against models.qgan).
lrD = 1.5e-3
# Bundle handed to qgan.fit as its third positional argument.
opt_params = dict(spsa_iter_num=spsa_iter_num, opt_iter_num=opt_iter_num)

In [3]:
# Generator configuration for this experiment. Every entry is read back by
# key in the training cell and passed on to the QGAN constructor; the whole
# mapping is also serialised to config.json next to the results.
config = dict(
    noise_dim=1,
    arch=["var", "var", "enc[2]", "var", "var"],
    input_state=[0, 1, 0, 1, 0],
    gen_count=8,
    pnr=False,
)

# Alternative configuration kept for reference:
# {"noise_dim": 2, "arch": ["var", "var", "enc[1, 4]", "var", "var"], "input_state": [0, 1, 0, 0, 1, 0], "gen_count": 4, "pnr": False}

# Number of successful training runs to collect.
runs = 10

In [4]:
# Digits dataset read from the CSV shipped with the helpers package; a
# ToTensor-only transform pipeline is handed to the dataset.
dataset = DigitsDataset(
    csv_file="../helpers/data/optdigits_csv.csv",
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Draw with replacement so that a single pass over the loader produces
# exactly opt_iter_num full batches, independent of the dataset size.
sampler = RandomSampler(
    dataset,
    replacement=True,
    num_samples=batch_size * opt_iter_num,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    drop_last=True,
)

In [5]:
# Train `runs` independent QGAN instances and dump each run's progress to
# ./noisy/run_<k>/. A run that raises is discarded (its directory removed)
# and retried, up to `max_attempts` attempts in total.

# clear save path for the results
# WARNING: any results left from a previous execution are deleted here.
path = "./noisy/"
# Hard cap on the total number of training attempts (successful + failed),
# so a configuration that always fails cannot loop forever.
max_attempts = 1000

if os.path.isdir(path):
    shutil.rmtree(path)
os.makedirs(path)

# Store the configuration next to the results so they stay interpretable.
with open(os.path.join(path, "config.json"), "w") as f:
    json.dump(config, f)

gen_arch = config["arch"]
noise_dim = config["noise_dim"]
input_state = config["input_state"]
pnr = config["pnr"]
gen_count = config["gen_count"]

# run_num counts successful runs; it is bumped optimistically at the start
# of an attempt and rolled back if that attempt fails.
run_num = 0
# several runs to average over
# The bar tracks completed runs (total=runs), not attempts, so it reflects
# real progress towards the requested number of results.
with tqdm(total=runs, desc="run", position=0, leave=False) as run_bar:
    for attempt in range(1, max_attempts + 1):
        if run_num == runs:
            break
        run_num += 1

        save_path = os.path.join(path, "run_" + str(run_num))
        os.makedirs(save_path)
        try:
            qgan = QGAN(
                image_size,
                gen_count,
                gen_arch,
                pcvl.BasicState(input_state),
                noise_dim,
                batch_size,
                pnr,
                lossy,
            )
            (
                D_loss_progress,
                G_loss_progress,
                G_params_progress,
                fake_data_progress,
            ) = qgan.fit(
                tqdm(dataloader, desc="iter", position=2, leave=False),
                lrD,
                opt_params,
                silent=True,
            )

            if write_to_disk:
                np.savetxt(
                    os.path.join(save_path, "fake_progress.csv"),
                    fake_data_progress,
                    delimiter=",",
                )
                # One row per recorded step, columns: D_loss, G_loss.
                np.savetxt(
                    os.path.join(save_path, "loss_progress.csv"),
                    np.array([D_loss_progress, G_loss_progress]).transpose(),
                    delimiter=",",
                    header="D_loss, G_loss",
                )
                np.savetxt(
                    os.path.join(save_path, "G_params_progress.csv"),
                    np.array(G_params_progress),
                    delimiter=",",
                )

        except Exception as exc:
            # Best effort: report the failure (with its type, since some
            # exceptions carry an empty message), drop the partial output
            # and free the run number for the next attempt.
            print(f"run {run_num} failed on attempt {attempt}: {exc!r}")
            shutil.rmtree(save_path)
            run_num -= 1
        else:
            run_bar.update(1)

# Make a shortfall visible instead of ending silently with fewer results.
if run_num < runs:
    print(
        f"only {run_num} of {runs} runs completed "
        f"after {max_attempts} attempts"
    )



run:   0%|          | 0/1000 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]

iter:   0%|          | 0/1500 [00:00<?, ?it/s]