In [1]:
import shutil
import json
import os
import torch
import numpy as np
import perceval as pcvl
import torchvision.transforms as transforms

from torch.utils.data import RandomSampler
from tqdm.notebook import tqdm

import sys; sys.path.insert(0, '..')
from models.qgan import QGAN
from helpers.data.digits import DigitsDataset

In [2]:
# --- definitions and constants ---------------------------------------------
image_size = 8          # image size handed to QGAN
batch_size = 4          # images per batch (also used to size the sampler)
lossy = False           # forwarded to QGAN as its `lossy` argument
write_to_disk = True    # when True, per-run CSV histories are saved

# --- optimization params ---------------------------------------------------
spsa_iter_num = 10500
opt_iter_num = 1500     # the sampler draws batch_size * opt_iter_num samples
lrD = 0.002             # passed to qgan.fit as lrD (presumably the discriminator learning rate)
opt_params = dict(spsa_iter_num=spsa_iter_num, opt_iter_num=opt_iter_num)

In [3]:
# Define the desired run configurations.
# Every configuration merges one input setting (`input_state`, `gen_count`, `pnr`)
# with one generator architecture (`noise_dim`, `arch`). Architecture grids are
# paired with input grids of a compatible mode count: `arch_grid_45modes` is
# used with both the 4-mode and the 5-mode inputs.
arch_grid_45modes = [
    {
        "noise_dim": 1,
        "arch": ["var", "enc[2]", "var"],
    },
    {
        "noise_dim": 1,
        "arch": ["var", "var", "enc[2]", "var", "var"]
    },
    {
        "noise_dim": 2,
        "arch": ["var", "enc[1]", "var", "enc[3]", "var"],
    },
    {
        "noise_dim": 2,
        "arch": ["var", "var", "enc[1]", "var", "var", "enc[3]", "var", "var"],
    }
]

input_grid_4modes = [
    {
        "input_state": [1, 1, 1, 1],
        "gen_count": 2,
        "pnr": True
    },
    {
        "input_state": [1, 1, 1, 1],
        "gen_count": 4,
        "pnr": False
    },
    {
        "input_state": [1, 0, 1, 1],
        "gen_count": 4,
        "pnr": True
    },
]

arch_grid_5modes = [
    {
        "noise_dim": 3,
        "arch": ["var", "enc[0]", "var", "enc[2]", "var", "enc[4]", "var"],
    }
]

input_grid_5modes = [
    {
        "input_state": [0, 1, 0, 1, 0],
        "gen_count": 4,
        "pnr": False
    },
    {
        "input_state": [1, 0, 1, 0, 1],
        "gen_count": 2,
        "pnr": True
    }
]


arch_grid_8modes = [
    {
        "noise_dim": 1,
        "arch": ["var", "enc[4]", "var"],
    },
    {
        "noise_dim": 1,
        "arch": ["var", "var", "enc[4]", "var", "var"]
    },
    {
        "noise_dim": 2,
        "arch": ["var", "enc[2]", "var", "enc[5]", "var"],
    },
    {
        "noise_dim": 2,
        "arch": ["var", "var", "enc[2]", "var", "var", "enc[5]", "var", "var"],
    },
    {
        "noise_dim": 3,
        "arch": ["var", "enc[1]", "var", "enc[4]", "var", "enc[6]", "var"],
    },
    {
        "noise_dim": 4,
        "arch": ["var", "enc[1]", "var", "enc[3]", "var", "enc[5]", "var", "enc[7]", "var"],
    },
]

input_grid_8modes = [
    {
        "input_state": [0, 0, 1, 0, 0, 1, 0, 0],
        "gen_count": 2,
        "pnr": False
    }
]


def _combine_grids(input_grid, arch_grid):
    """Return one merged config dict per (input, arch) pair, in input-major order.

    Each result is a new dict holding the input keys followed by the
    architecture keys (architecture values win on a key collision). The
    grids themselves are not modified; nested values such as the
    `input_state` lists are shared, not copied.
    """
    return [{**inp, **arch} for inp in input_grid for arch in arch_grid]


# Order matters: the index of each entry becomes its config_<k> directory name.
config_grid = (
    _combine_grids(input_grid_5modes, arch_grid_45modes + arch_grid_5modes)
    + _combine_grids(input_grid_4modes, arch_grid_45modes)
    + _combine_grids(input_grid_8modes, arch_grid_8modes)
)

# number of runs for each config combination
runs = 10

In [4]:
# Digit images from the optdigits CSV, converted to tensors on access.
to_tensor = transforms.Compose([transforms.ToTensor()])
dataset = DigitsDataset(
    csv_file="../helpers/data/optdigits_csv.csv",
    transform=to_tensor,
)

# Sample with replacement so that one pass over `dataloader` yields exactly
# `opt_iter_num` full batches of `batch_size` images, regardless of dataset size.
samples_per_pass = batch_size * opt_iter_num
sampler = RandomSampler(dataset, replacement=True, num_samples=samples_per_pass)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=True,
    sampler=sampler,
)

In [None]:
# Train `runs` independent QGANs for every entry of `config_grid` and store:
#   <path>/config_<k>/config.json      -- the configuration dict
#   <path>/config_<k>/run_<r>/*.csv    -- per-run histories (only if write_to_disk)
path = "./ideal/"

# True deletes all previous results before training (the original behaviour).
# With False, existing results are kept and every config_<k> directory that
# already exists is skipped -- note that a partially finished config is skipped too.
overwrite_results = True

# Failed runs are discarded and retried; give up on a configuration after this
# many attempts in total.
max_attempts_per_config = 1000


def _save_run(save_path, D_loss_progress, G_loss_progress, G_params_progress, fake_data_progress):
    """Write the histories returned by `QGAN.fit` as CSV files in `save_path`.

    Files written: fake_progress.csv, loss_progress.csv (header "D_loss, G_loss")
    and G_params_progress.csv.
    """
    np.savetxt(
        os.path.join(save_path, "fake_progress.csv"),
        fake_data_progress,
        delimiter=",",
    )
    np.savetxt(
        os.path.join(save_path, "loss_progress.csv"),
        # row i holds (D_loss_progress[i], G_loss_progress[i])
        np.array([D_loss_progress, G_loss_progress]).transpose(),
        delimiter=",",
        header="D_loss, G_loss",
    )
    np.savetxt(
        os.path.join(save_path, "G_params_progress.csv"),
        np.array(G_params_progress),
        delimiter=",",
    )


if overwrite_results and os.path.isdir(path):
    shutil.rmtree(path)

for config_num, config in enumerate(
    tqdm(config_grid, desc="config", position=0, leave=False)
):
    config_path = os.path.join(path, "config_" + str(config_num))
    if os.path.isdir(config_path):
        # only reached when earlier results were kept (overwrite_results=False)
        continue
    os.makedirs(config_path)

    with open(os.path.join(config_path, "config.json"), "w") as f:
        json.dump(config, f)

    gen_arch = config["arch"]
    noise_dim = config["noise_dim"]
    input_state = config["input_state"]
    pnr = config["pnr"]
    gen_count = config["gen_count"]

    # several runs to average over; successful runs are numbered run_1..run_<runs>
    run_num = 0
    attempts = 0
    with tqdm(total=runs, desc="run", position=1, leave=False) as run_bar:
        while run_num < runs and attempts < max_attempts_per_config:
            attempts += 1
            run_num += 1

            save_path = os.path.join(config_path, "run_" + str(run_num))
            os.makedirs(save_path)
            try:
                qgan = QGAN(
                    image_size,
                    gen_count,
                    gen_arch,
                    pcvl.BasicState(input_state),
                    noise_dim,
                    batch_size,
                    pnr,
                    lossy
                )
                (
                    D_loss_progress,
                    G_loss_progress,
                    G_params_progress,
                    fake_data_progress,
                ) = qgan.fit(
                    tqdm(dataloader, desc="iter", position=2, leave=False),
                    lrD,
                    opt_params,
                    silent=True,
                )

                if write_to_disk:
                    _save_run(
                        save_path,
                        D_loss_progress,
                        G_loss_progress,
                        G_params_progress,
                        fake_data_progress,
                    )

            except Exception as exc:
                # Discard the failed run and reuse its number on the next attempt.
                print(f"config {config_num}, run {run_num} failed: {type(exc).__name__}: {exc}")
                shutil.rmtree(save_path)
                run_num -= 1
            else:
                run_bar.update(1)

    if run_num < runs:
        print(
            f"config {config_num}: only {run_num} of {runs} runs succeeded "
            f"after {attempts} attempts"
        )
