In [1]:
import h5py
import numpy as np
from dlchess.agents.zero import ZeroAgent
from dlchess.encoders.zero import ZeroEncoder
from dlchess.rl.experience import load_experience
from tensorflow.keras.models import load_model

weights = [0.25, 0.5, 0.75, 1, 1.25, 1.5, 2, 3, 4, 5, 10]
num_exp_files = 16

file_indices = np.random.permutation(range(num_exp_files))
results = {}

time: 1.28 s (started: 2021-05-27 16:27:15 -06:00)


In [2]:
for policy_weight in weights:
    for value_weight in weights:
        model = load_model("dlchess/models/zero_0.h5", compile=False)
        encoder = ZeroEncoder()
        agent = ZeroAgent(model, encoder)

        loss_weights = (policy_weight, value_weight)
        print(loss_weights)

        histories = [None] * num_exp_files
        for i in file_indices:
            experience = load_experience(
                h5py.File(f"data/zero_exp_0.{i}.h5", mode="r"), zero=True
            )
            history = agent.train(
                experience,
                epochs=1,
                learning_rate=0.01,
                batch_size=1024,
                loss_weights=list(loss_weights),
            )
            histories[i] = history.history

        results[loss_weights] = histories
        print("")
        print(histories)
        print("")

2.36170196533203], 'policy_output_loss': [5.196845054626465], 'value_output_loss': [0.1213928610086441]}, {'loss': [54.70576095581055], 'policy_output_loss': [5.436028957366943], 'value_output_loss': [0.08994394540786743]}]

(10, 2)

[{'loss': [57.286869049072266], 'policy_output_loss': [5.6813273429870605], 'value_output_loss': [0.13168664276599884]}, {'loss': [62.651344299316406], 'policy_output_loss': [6.220754623413086], 'value_output_loss': [0.11728399246931076]}, {'loss': [73.00775909423828], 'policy_output_loss': [7.263417720794678], 'value_output_loss': [0.08284103125333786]}, {'loss': [68.07415008544922], 'policy_output_loss': [6.773642539978027], 'value_output_loss': [0.06469705700874329]}, {'loss': [69.59111022949219], 'policy_output_loss': [6.921422004699707], 'value_output_loss': [0.08436447381973267]}, {'loss': [53.68414306640625], 'policy_output_loss': [5.330533504486084], 'value_output_loss': [0.0839906632900238]}, {'loss': [58.12631607055664], 'policy_output_loss': [5.

In [11]:
final_losses = {w: l[file_indices[-1]] for w, l in results.items()}

best_losses = {}
for weights, histories in results.items():
    lows = (
        min([x["loss"][0] for x in histories]),
        min([x["policy_output_loss"][0] for x in histories]),
        min([x["value_output_loss"][0] for x in histories]),
    )
    best_losses[weights] = lows

best_policy_weight = None
best_value_weight = None
best_total_weight = None
best_policy_loss = np.inf
best_value_loss = np.inf
best_total_loss = np.inf  # policy + value, not the calculated loss value
for weights, histories in results.items():
    for x in histories:
        p = x["policy_output_loss"][0]
        v = x["value_output_loss"][0]
        t = p + v
        if p < best_policy_loss:
            best_policy_loss = p
            best_policy_weight = weights
        if v < best_value_loss:
            best_value_loss = v
            best_value_weight = weights
        if t < best_total_loss:
            best_total_loss = t
            best_total_weight = weights

time: 4.81 ms (started: 2021-05-27 21:32:03 -06:00)


In [13]:
print("\nFinal Losses")
for k, v in final_losses.items():
    print(k, v, end="\n\n")


Final Losses
(0.25, 0.25) {'loss': [2.086679458618164], 'policy_output_loss': [7.463397979736328], 'value_output_loss': [0.05400645732879639]}

(0.25, 0.5) {'loss': [2.101181983947754], 'policy_output_loss': [7.467381000518799], 'value_output_loss': [0.054032377898693085]}

(0.25, 0.75) {'loss': [2.117943048477173], 'policy_output_loss': [7.470255374908447], 'value_output_loss': [0.05741892382502556]}

(0.25, 1) {'loss': [2.1303634643554688], 'policy_output_loss': [7.470057964324951], 'value_output_loss': [0.05553631857037544]}

(0.25, 1.25) {'loss': [2.1443867683410645], 'policy_output_loss': [7.469871997833252], 'value_output_loss': [0.05568673461675644]}

(0.25, 1.5) {'loss': [2.1588857173919678], 'policy_output_loss': [7.473797798156738], 'value_output_loss': [0.05541880428791046]}

(0.25, 2) {'loss': [2.188763380050659], 'policy_output_loss': [7.480268478393555], 'value_output_loss': [0.055697258561849594]}

(0.25, 3) {'loss': [2.2327218055725098], 'policy_output_loss': [7.483279

In [15]:
print("\nBest Losses")
for k, v in best_losses.items():
    print(k, v, end="\n\n")
print("")


Best Losses
(0.25, 0.25) (2.086679458618164, 7.463397979736328, 0.05400645732879639)

(0.25, 0.5) (2.101181983947754, 7.467381000518799, 0.054032377898693085)

(0.25, 0.75) (2.117943048477173, 7.470255374908447, 0.05741892382502556)

(0.25, 1) (2.1303634643554688, 7.470057964324951, 0.05553631857037544)

(0.25, 1.25) (2.1443867683410645, 7.469871997833252, 0.05568673461675644)

(0.25, 1.5) (2.1588857173919678, 7.473797798156738, 0.05541880428791046)

(0.25, 2) (2.188763380050659, 7.480268478393555, 0.055697258561849594)

(0.25, 3) (2.2327218055725098, 7.483279228210449, 0.05153488740324974)

(0.25, 4) (2.2889010906219482, 7.4825005531311035, 0.05273868516087532)

(0.25, 5) (2.332249641418457, 7.4874348640441895, 0.050612643361091614)

(0.25, 10) (2.947026491165161, 7.51134729385376, 0.0861768126487732)

(0.5, 0.25) (3.9079127311706543, 7.373918056488037, 0.054405033588409424)

(0.5, 0.5) (3.922377109527588, 7.375911712646484, 0.05414971709251404)

(0.5, 0.75) (3.9391419887542725, 7.37

In [16]:
print("Best Policy Weights", best_policy_weight, best_policy_loss)
print("Best Value Weights", best_value_weight, best_value_loss)
print("Best Total Weights", best_total_weight, best_total_loss)

Best Policy Weights (10, 0.5) 5.009276866912842
Best Value Weights (2, 10) 0.04264706000685692
Best Total Weights (10, 1) 5.0701508186757565
time: 548 µs (started: 2021-05-27 21:35:53 -06:00)


In [19]:
for i in range(num_exp_files):
    print(f"\nStep {i}")
    result = [(w, l[i]) for w, l in results.items()]
    result.sort(
        key=lambda x: x[1]["policy_output_loss"][0] + x[1]["value_output_loss"][0]
    )
    for r in result:
        print(r[0], r[1])



Step 0
(10, 0.75) {'loss': [56.78391647338867], 'policy_output_loss': [5.647197246551514], 'value_output_loss': [0.13552957773208618]}
(10, 0.5) {'loss': [56.73786544799805], 'policy_output_loss': [5.645760536193848], 'value_output_loss': [0.13987991213798523]}
(10, 1) {'loss': [56.89402770996094], 'policy_output_loss': [5.655165672302246], 'value_output_loss': [0.13210433721542358]}
(10, 1.25) {'loss': [57.058433532714844], 'policy_output_loss': [5.668367385864258], 'value_output_loss': [0.131608247756958]}
(10, 2) {'loss': [57.286869049072266], 'policy_output_loss': [5.6813273429870605], 'value_output_loss': [0.13168664276599884]}
(10, 1.5) {'loss': [57.24113464355469], 'policy_output_loss': [5.683352470397949], 'value_output_loss': [0.13158363103866577]}
(10, 0.25) {'loss': [56.84132766723633], 'policy_output_loss': [5.659201622009277], 'value_output_loss': [0.15591876208782196]}
(10, 3) {'loss': [57.70501708984375], 'policy_output_loss': [5.709357261657715], 'value_output_loss': [