In [1]:
import os
import shutil
import time

from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc

import tensorcircuit as tc
from opt_einsum import contract

import jax
from jax import numpy as jnp
from jax import random

import optax

from src.QDDPM_jax import QDDPM, HaarSampleGeneration
from src.distance_jax import naturalDistance, WassDistance, sinkhornDistance

# Render figure text with LaTeX only when the toolchain matplotlib needs for
# inline (Agg) output is available; with usetex=True and no `latex`/`dvipng`
# on PATH, every later figure would fail at draw time.
rc('text', usetex=all(shutil.which(exe) for exe in ('latex', 'dvipng')))
rc('axes', linewidth=3)  # thick axes frame for all figures

Please first ``pip install -U qiskit`` to enable related functionality in translation module
Please first ``pip install -U cirq`` to enable related functionality in translation module


In [7]:
# check which device platform JAX dispatches to by default (e.g. 'cpu' or 'gpu');
# jax.default_backend() is the public replacement for the deprecated
# jax.lib.xla_bridge.get_backend().platform
backend = jax.default_backend()
print(backend)

cpu


In [3]:
def Training_t(model, t, inputs_T, params_tot, epochs, dis_measure='wd',
               dis_params=None, learning_rate=5e-4, seed=None, log_every=100):
    '''
    Train the backward PQC of diffusion step ``t`` on the whole input set.

    The inputs for step ``t`` are produced from ``inputs_T`` by
    ``model.prepareInput_t`` using the parameters of the later steps. A fresh
    parameter vector for step ``t`` is then optimized with Adam so that the
    PQC output is close, under the chosen distance, to a random subset of the
    states in ``model.states_diff[t]``.

    Args:
    model: QDDPM model; must expose ``prepareInput_t``, ``backwardOutput_t``,
        ``states_diff``, ``n_tot`` and ``L``
    t: diffusion step whose PQC is trained
    inputs_T: input ensemble for the backward process; its first dimension is
        the number of samples (Ndata)
    params_tot: collection of PQC parameters for steps > t
    epochs: number of optimization iterations
    dis_measure: the distance measure to compare two distributions of quantum
        states: 'nat' uses ``naturalDistance``, 'wd' uses ``sinkhornDistance``
    dis_params: optional dict of extra keyword arguments for the distance
        function; for 'wd' it is merged over the default ``{'reg': 0.01}``
    learning_rate: Adam learning rate (constant, no decay schedule)
    seed: optional non-negative int driving both the parameter initialization
        and the target sampling; if None it is drawn from NumPy's global RNG,
        so seeding ``np.random`` beforehand also makes the run repeatable
    log_every: print the loss every ``log_every`` steps (<= 0 disables logging)

    Returns:
    (params_t, loss_hist): the trained parameters of step t and the list of
        per-iteration loss values

    Raises:
    ValueError: if ``dis_measure`` is not 'nat' or 'wd'
    '''
    # Fail fast, before the (expensive) input preparation.
    if dis_measure not in ('nat', 'wd'):
        raise ValueError(
            "dis_measure must be 'nat' or 'wd', got {!r}".format(dis_measure))

    Ndata = inputs_T.shape[0]

    input_tplus1 = model.prepareInput_t(
        inputs_T, params_tot, t, Ndata)  # prepare input
    states_diff = model.states_diff
    loss_hist = []  # record of training history

    # One seed drives both the JAX init key and the NumPy target sampler.
    if seed is None:
        seed = np.random.randint(low=0, high=10000)
    key = random.PRNGKey(seed)
    rng = np.random.default_rng(seed)

    # initialize parameters: one flat vector of length 2 * n_tot * L
    param_shape = 2 * model.n_tot * model.L
    params_t = random.normal(key, shape=(param_shape, ))

    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(params_t)

    if dis_measure == 'nat':
        distance = naturalDistance
        dist_kwargs = {}
    else:  # 'wd'
        distance = sinkhornDistance
        dist_kwargs = {'reg': 0.01}
    dist_kwargs.update(dis_params or {})

    def loss_func(params_t, input_tplus1, true_data):
        output_t = model.backwardOutput_t(input_tplus1, params_t)
        return distance(output_t, true_data, **dist_kwargs)

    # The whole optimization step (loss, gradient and optax update) is
    # compiled as one function.
    @jax.jit
    def update(params_t, input_tplus1, true_data, opt_state):
        loss_value, grads = jax.value_and_grad(loss_func)(
            params_t, input_tplus1, true_data)
        updates, new_opt_state = optimizer.update(grads, opt_state, params_t)
        new_params_t = optax.apply_updates(params_t, updates)
        return new_params_t, new_opt_state, loss_value

    n_pool = states_diff.shape[1]
    t0 = time.time()
    for step in range(epochs):
        # draw Ndata distinct target states of step t for this iteration
        indices = rng.choice(n_pool, size=Ndata, replace=False)
        true_data = states_diff[t, indices]

        params_t, opt_state, loss_value = update(
            params_t, input_tplus1, true_data, opt_state)

        if log_every > 0 and step % log_every == 0:
            print("Step {}, loss: {:.7f}, time elapsed: {:.4f} seconds".format(
                step, loss_value, time.time() - t0))

        loss_hist.append(loss_value)  # record the current loss

    return params_t, loss_hist

In [4]:
n, na = 6, 4  # n qubits for the data states (dimension 2 ** n), na extra qubits
T = 20        # number of diffusion steps
L = 12        # PQC depth parameter passed to QDDPM (model.L)

Ndata = 10000  # number of input states in the backward ensemble
epochs = 1201  # training iterations for the trained step

inputs_T = HaarSampleGeneration(Ndata, 2 ** n, seed=42)
# Build the file name from the config so it cannot silently disagree with n / T.
states_diff = jnp.load('data/cluster/cluster0Diff_n%dT%d_N10000.npy' % (n, T))

In [5]:
model = QDDPM(n=n, na=na, T=T, L=L)
model.set_diffusionSet(states_diff)
# One flat parameter vector per diffusion step. The width is taken from the
# model itself (2 * n_tot * L), the same expression Training_t uses to size
# each step's parameters, so the two cannot drift apart.
params_tot = jnp.zeros((T, 2 * model.n_tot * model.L))

In [6]:
# Train only the PQC of the last diffusion step, t = T - 1, with the 'nat' distance.
t_last = T - 1
params, loss = Training_t(model, t_last, inputs_T, params_tot, epochs,
                          dis_measure='nat')

Step 0, loss: 0.0015049, time elapsed: 25.7046 seconds


KeyboardInterrupt: 