In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys, os, pickle, time, datetime
from copy import deepcopy

import numpy as np
import matplotlib.pyplot as plt

from functools import partial
from p_tqdm import p_uimap

from scipy.ndimage.filters import uniform_filter1d

from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit
from qiskit.circuit import Parameter
from qiskit.tools.visualization import circuit_drawer
from qiskit.providers.aer import QasmSimulator

from circuit_builder import CircuitBuilder

from checkers import Checkers
from agents import QRQLAgent

In [3]:
# Aer QASM simulator instance configured with method='statevector' and precision='single'.
# NOTE(review): `backend` is not referenced by any other cell visible in this notebook —
# confirm whether it is still needed here.
backend = QasmSimulator(method='statevector', precision='single')

In [3]:
from stats import EpisodeStats

# Explicitly (re)binds EpisodeStats in this notebook's global namespace. The import
# above already creates the same binding, so as written this line changes nothing.
# NOTE(review): presumably kept as a workaround so pickled stats objects can resolve
# their class when loaded — confirm before removing.
globals()['EpisodeStats'] = EpisodeStats

def load_stats(stats_name):
    """Load a pickled stats object from ``<stats_name>.pkl``.

    Args:
        stats_name: Path of the stats file *without* the ``.pkl`` extension.

    Returns:
        The unpickled object, or ``None`` when no such file exists.
    """
    path = stats_name + '.pkl'

    if not os.path.isfile(path):
        return None

    # SECURITY: pickle.load can execute arbitrary code from the file; only load
    # stats files that were produced by this project.
    # The `with` block closes the file, so no explicit close() is needed.
    with open(path, 'rb') as f:
        return pickle.load(f)

def save_stats(stats_name, stats):
    """Pickle ``stats`` to ``<stats_name>.pkl`` using the highest pickle protocol.

    The parent directory is created when it does not exist yet, so a finished
    training run is not lost to a missing output folder.

    Args:
        stats_name: Destination path *without* the ``.pkl`` extension.
        stats: Any picklable object.
    """
    path = stats_name + '.pkl'

    # dirname is '' for a bare file name; os.makedirs('') would raise.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The `with` block closes the file, so no explicit close() is needed.
    with open(path, 'wb') as f:
        pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL)

In [4]:
def base_run(env, num_episodes, agent, num_runs=5):
    """Train independent copies of one agent and collect the results.

    Every repetition works on fresh deep copies of both the agent and the
    environment, so the objects passed in are never mutated.

    Args:
        env: Environment handed to ``train``; deep-copied for each run.
        num_episodes: Forwarded unchanged as the second argument of ``train``.
        agent: ``(name, agent_object)`` pair, e.g. one item of ``dict.items()``.
            ``agent_object`` must provide ``train(env, num_episodes)``.
        num_runs: Number of repetitions. Defaults to 5, the count that was
            previously hard-coded.

    Returns:
        ``(name, results)`` where ``results`` is a list holding the return
        value of ``train`` for each run, in run order.
    """
    # Imported locally, as in the original.
    # NOTE(review): presumably so the function stays self-contained when it is
    # shipped to p_uimap worker processes — confirm before hoisting.
    from copy import deepcopy

    name, prototype = agent
    results = [
        deepcopy(prototype).train(deepcopy(env), num_episodes)
        for _ in range(num_runs)
    ]
    return name, results

In [None]:
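# Archived agent configurations (all commented out). These entries use a five-argument
# QRQLAgent call; the active experiment cells below pass six arguments (an extra
# string such as '6x6' before the final 'classical'/'quantum' argument).
# Classical version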
#'stats/absolute/6x6_greedy/rql_01': QRQLAgent(0, 0.20, 1.00, 5, 'classical'),
#'stats/absolute/6x6_greedy/rql_02': QRQLAgent(0, 0.20, 1.00, 10, 'classical'),
#'stats/absolute/6x6_greedy/rql_03': QRQLAgent(0, 0.40, 1.00, 5, 'classical'),
#'stats/absolute/6x6_greedy/rql_04': QRQLAgent(0, 0.40, 1.00, 10, 'classical'),
#'stats/absolute/6x6_greedy/rql_05': QRQLAgent(0, 0.60, 1.00, 5, 'classical'),
#'stats/absolute/6x6_greedy/rql_06': QRQLAgent(0, 0.60, 1.00, 10, 'classical'),
#'stats/absolute/6x6_greedy/rql_07': QRQLAgent(0, 0.80, 1.00, 5, 'classical'),
#'stats/absolute/6x6_greedy/rql_08': QRQLAgent(0, 0.80, 1.00, 10, 'classical'),
#'stats/absolute/8x8_greedy/rql_09': QRQLAgent(0, 0.40, 1.00, 100, 'classical'),
#'stats/absolute/8x8_greedy/rql_10': QRQLAgent(3, 0.40, 1.00, 1000, 'classical'),

# Quantum version
#'stats/absolute/6x6_greedy/qrql_01': QRQLAgent(1, 0.20, 1.00, 5, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_02': QRQLAgent(1, 0.20, 1.00, 10, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_03': QRQLAgent(1, 0.40, 1.00, 5, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_04': QRQLAgent(1, 0.40, 1.00, 10, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_05': QRQLAgent(0, 0.60, 1.00, 5, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_06': QRQLAgent(0, 0.60, 1.00, 10, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_07': QRQLAgent(0, 0.80, 1.00, 5, 'quantum'),
#'stats/absolute/6x6_greedy/qrql_08': QRQLAgent(0, 0.80, 1.00, 10, 'quantum'),
#'stats/absolute/8x8_greedy/qrql_09': QRQLAgent(0, 0.40, 1.00, 100, 'quantum'),
#'stats/absolute/8x8_greedy/qrql_10': QRQLAgent(7, 0.40, 1.00, 1000, 'quantum'),

In [None]:
# Experiment cell: Checkers environment with shape (6, 6), opponent 'optimal3', absolute=True.
env = Checkers(shape=(6,6), opponent='optimal3', absolute=True)

# Reference note only (not executed in this cell).
# NOTE(review): presumably the temperature schedule used inside the agent — confirm.
# tau = 0.2 + (20 - 0.2) / (1 + math.e**(0.5*(i_episode / 1000)))

# Maps an output path prefix (no extension) to the agent trained for it.
# NOTE(review): the meaning of the positional QRQLAgent arguments is not visible
# in this notebook — see the agents module.
agents = {
    # Classical version
    'stats/absolute/6x6_optimal3/rql_03': QRQLAgent(0, 0.40, 1.00, 5, '6x6', 'classical'),
    'stats/absolute/6x6_optimal3/rql_04': QRQLAgent(0, 0.40, 1.00, 10, '6x6', 'classical'),

    # Quantum version
    'stats/absolute/6x6_optimal3/qrql_03': QRQLAgent(1, 0.40, 1.00, 5, '6x6', 'quantum'),
    'stats/absolute/6x6_optimal3/qrql_04': QRQLAgent(1, 0.40, 1.00, 10, '6x6', 'quantum'),
}

# Wall-clock timing of the whole parallel training run.
start = time.time()
        
# Each (path, agent) item is handed to base_run(env, 20000, item) through p_uimap
# (num_cpus=8); results are consumed as the iterator yields them. base_run returns
# the path plus one result per repetition, and each repetition is written to
# '<path>_<i>.pkl' by save_stats.
for n, stats in p_uimap(partial(base_run, env, 20000), list(agents.items()), num_cpus=8):
    for i, stat in enumerate(stats):
        save_stats(n + f'_{i}', stat)
    
end = time.time()

# Printed as H:MM:SS(.ffffff) via datetime.timedelta.
print("Training time:", str(datetime.timedelta(seconds=end - start)))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))

In [None]:
# Experiment cell: Checkers environment with shape (6, 6), opponent 'optimal5', absolute=True.
env = Checkers(shape=(6,6), opponent='optimal5', absolute=True)

# Reference note only (not executed in this cell).
# NOTE(review): presumably the temperature schedule used inside the agent — confirm.
# tau = 0.2 + (20 - 0.2) / (1 + math.e**(0.5*(i_episode / 1000)))

# Maps an output path prefix (no extension) to the agent trained for it.
# NOTE(review): the meaning of the positional QRQLAgent arguments is not visible
# in this notebook — see the agents module.
agents = {
    # Classical version
    'stats/absolute/6x6_optimal5/rql_03': QRQLAgent(0, 0.40, 1.00, 5, '6x6', 'classical'),
    'stats/absolute/6x6_optimal5/rql_04': QRQLAgent(0, 0.40, 1.00, 10, '6x6', 'classical'),

    # Quantum version
    'stats/absolute/6x6_optimal5/qrql_03': QRQLAgent(1, 0.40, 1.00, 5, '6x6', 'quantum'),
    'stats/absolute/6x6_optimal5/qrql_04': QRQLAgent(1, 0.40, 1.00, 10, '6x6', 'quantum'),
}

# Wall-clock timing of the whole parallel training run.
start = time.time()
        
# Each (path, agent) item is handed to base_run(env, 20000, item) through p_uimap
# (num_cpus=8); results are consumed as the iterator yields them. base_run returns
# the path plus one result per repetition, and each repetition is written to
# '<path>_<i>.pkl' by save_stats.
for n, stats in p_uimap(partial(base_run, env, 20000), list(agents.items()), num_cpus=8):
    for i, stat in enumerate(stats):
        save_stats(n + f'_{i}', stat)
    
end = time.time()

# Printed as H:MM:SS(.ffffff) via datetime.timedelta.
print("Training time:", str(datetime.timedelta(seconds=end - start)))

In [None]:
# Experiment cell: Checkers environment with shape (8, 8), opponent 'greedy_to_optimal3', absolute=True.
env = Checkers(shape=(8,8), opponent='greedy_to_optimal3', absolute=True)

# Maps an output path prefix (no extension) to the agent trained for it.
# Commented-out entries are disabled configurations kept for reference; some of
# them use a shorter QRQLAgent argument list than the active entries.
# NOTE(review): the meaning of the positional QRQLAgent arguments is not visible
# in this notebook — see the agents module.
agents = {
    # Classical version
    #'stats/absolute/8x8_greedy/rql_03': QRQLAgent(0, 0.40, 1.00, 5, '8x8', 'classical'),
    'stats/absolute/8x8_changing3/rql_04': QRQLAgent(1, 0.40, 1.00, 10, '8x8_changing', 'classical'),
    'stats/absolute/8x8_changing3/rql_09': QRQLAgent(2, 0.40, 1.00, 100, '8x8_changing', 'classical'),
    #'stats/absolute/8x8_greedy/rql_10': QRQLAgent(3, 0.40, 1.00, 1000, '8x8', 'classical'),
    
    # Quantum version
    #'stats/absolute/8x8_greedy/qrql_03': QRQLAgent(4, 0.40, 1.00, 5, 'quantum'),
    'stats/absolute/8x8_changing3/qrql_04': QRQLAgent(3, 0.40, 1.00, 10, '8x8_changing', 'quantum'),
    'stats/absolute/8x8_changing3/qrql_09': QRQLAgent(4, 0.40, 1.00, 100, '8x8_changing', 'quantum'),
    #'stats/absolute/8x8_greedy/qrql_10': QRQLAgent(7, 0.40, 1.00, 1000, 'quantum'),
}

# Wall-clock timing of the whole parallel training run.
start = time.time()
        
# Each (path, agent) item is handed to base_run(env, 37500, item) through p_uimap
# (num_cpus=4, disable=None); results are consumed as the iterator yields them.
# Files are written as '<path>_<5 + i>.pkl', i.e. indices start at 5 instead of 0.
# NOTE(review): presumably so these runs sit next to an earlier batch saved as
# _0.._4 without overwriting it — confirm.
for n, stats in p_uimap(partial(base_run, env, 37500), list(agents.items()), num_cpus=4, disable=None):
    for i, stat in enumerate(stats):
        save_stats(n + f'_{5 + i}', stat)
    
end = time.time()

# Printed as H:MM:SS(.ffffff) via datetime.timedelta.
print("Training time:", str(datetime.timedelta(seconds=end - start)))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4.0), HTML(value='')))