In [1]:
%load_ext autoreload
%autoreload 2

import torch
from alpha_connect import AlphaZeroModelConnect4, Trainer
import os

In [2]:
# Select the compute device. Apple-silicon MPS is preferred (the original
# target of this notebook); fall back to CUDA and then CPU so the cell does
# not crash on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyper-parameters and file locations handed to the Trainer as one dict.
args = {
    "batch_size": 128,
    "numIters": 500,  # Total number of training iterations
    "num_simulations": 500,  # Total number of MCTS simulations to run when deciding on a move to play
    "numEps": 125,  # Number of full games (episodes) to run during each iteration for each thread
    "numThreads": 1,  # Number of threads running simulations
    # NOTE(review): presumably the number of most recent iterations whose
    # self-play examples are kept for training -- confirm against Trainer.
    "numItersForTrainExamplesHistory": 20,
    "epochs": 3,  # Number of epochs of training per iteration
    "checkpoint_path": "../data/latest.pth",  # location to save latest set of weights
    "loss_history_path": "../data/loss_history.csv",  # location to save loss history
    "lr": 0.00017,  # learning rate
    "lr_decay": 0.99,  # learning rate decay
    # NOTE(review): presumably the move-sampling temperature used during
    # self-play -- confirm against Trainer / MCTS.
    "temperature": 1,
}

# Build the network once; the weights are overwritten below if a checkpoint
# from a previous run exists.
model = AlphaZeroModelConnect4()
if os.path.exists(args["checkpoint_path"]):
    # Deserialize onto the CPU first: without map_location, torch.load tries
    # to restore tensors on the device they were saved from, which fails when
    # that backend is not available here. The model is moved to `device`
    # right after.
    model.load_state_dict(torch.load(args["checkpoint_path"], map_location="cpu"))
    print("Loaded model from checkpoint")
model.to(device)

# Run the training loop (long-running: each iteration plays numEps games).
trainer = Trainer(model, args)
trainer.learn()

Loaded model from checkpoint
Resuming training from iteration 126 with best loss 0.3676766900932544
126/500 with lr 0.00029699999999999996


100%|██████████| 125/125 [37:04<00:00, 17.80s/it] 



Loss: 0.9006267404465964
Policy Loss 0.4265494190833785
Value Loss 0.4740773213632179
33 batches processed
Examples:
tensor([ 0.0236,  1.0126, -0.0094,  0.0079, -0.0059, -0.0151, -0.0112])
tensor([0., 1., 0., 0., 0., 0., 0.])
tensor([0.2413])
tensor(0.)
127/500 with lr 0.00029403


100%|██████████| 125/125 [35:39<00:00, 17.12s/it] 



Loss: 0.913058896549046
Policy Loss 0.4262174293398857
Value Loss 0.48684146720916033
32 batches processed
Examples:
tensor([-0.0224,  0.0032,  0.9458,  0.0831,  0.0032, -0.0042,  0.0058])
tensor([0.0020, 0.0000, 0.9940, 0.0020, 0.0000, 0.0000, 0.0020])
tensor([0.0786])
tensor(1.)
128/500 with lr 0.0002910897


100%|██████████| 125/125 [38:36<00:00, 18.54s/it] 



Loss: 0.627486817849179
Policy Loss 0.2793603828176856
Value Loss 0.3481264350314935
36 batches processed
Examples:
tensor([ 0.3026,  0.2376,  0.0149,  0.0045,  0.4432, -0.0053, -0.0098])
tensor([0.3948, 0.3226, 0.0240, 0.0000, 0.2545, 0.0020, 0.0020])
tensor([0.0957])
tensor(0.)
129/500 with lr 0.000288178803


100%|██████████| 125/125 [38:27<00:00, 18.46s/it] 



Loss: 0.5619053210624877
Policy Loss 0.24591167632709532
Value Loss 0.3159936447353924
34 batches processed
Examples:
tensor([ 2.8981e-02, -1.0087e-02, -1.0471e-03,  1.0076e+00,  3.7242e-03,
        -1.0625e-04,  8.4600e-03])
tensor([0.0020, 0.0000, 0.0000, 0.9980, 0.0000, 0.0000, 0.0000])
tensor([-0.2708])
tensor(-1.)
130/500 with lr 0.00028529701496999996


100%|██████████| 125/125 [37:01<00:00, 17.77s/it] 



Loss: 0.8569464498397076
Policy Loss 0.34474684568968683
Value Loss 0.5121996041500207
33 batches processed
Examples:
tensor([ 0.0191,  0.0029,  0.2285, -0.0073,  0.0074,  0.3794,  0.3616])
tensor([0.0180, 0.0020, 0.6192, 0.0000, 0.0060, 0.0281, 0.3267])
tensor([-0.2278])
tensor(-1.)
131/500 with lr 0.00028244404482029995


100%|██████████| 125/125 [35:42<00:00, 17.14s/it] 



Loss: 0.9219532429706305
Policy Loss 0.4373978634830564
Value Loss 0.4845553794875741
32 batches processed
Examples:
tensor([ 0.0803,  0.0050, -0.0031,  0.7908,  0.1099,  0.0193,  0.0067])
tensor([0.0240, 0.0000, 0.0080, 0.9519, 0.0120, 0.0000, 0.0040])
tensor([-0.0789])
tensor(0.)
132/500 with lr 0.00027961960437209696


100%|██████████| 125/125 [37:07<00:00, 17.82s/it]  



Loss: 0.8166313956025988
Policy Loss 0.38505795621313155
Value Loss 0.43157343938946724
32 batches processed
Examples:
tensor([ 0.0065, -0.0097,  0.0040,  0.9872,  0.0080,  0.0174, -0.0059])
tensor([0.0040, 0.0020, 0.0000, 0.9900, 0.0020, 0.0000, 0.0020])
tensor([0.1160])
tensor(-1.)
133/500 with lr 0.000276823408328376


100%|██████████| 125/125 [37:41<00:00, 18.09s/it] 



Loss: 0.9772176196908249
Policy Loss 0.4163011993993731
Value Loss 0.5609164202914518
34 batches processed
Examples:
tensor([ 0.8842, -0.0063, -0.0045, -0.0198,  0.0024, -0.0154,  0.1486])
tensor([0.4569, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5431])
tensor([0.8640])
tensor(1.)
134/500 with lr 0.0002740551742450922


100%|██████████| 125/125 [34:57<00:00, 16.78s/it] 



Loss: 0.8884922069807848
Policy Loss 0.3833339206874371
Value Loss 0.5051582862933477
30 batches processed
Examples:
tensor([ 0.1045, -0.0276,  0.1966,  0.0054,  0.5890, -0.0081,  0.1267])
tensor([0.0080, 0.0000, 0.0060, 0.0000, 0.9800, 0.0000, 0.0060])
tensor([0.7455])
tensor(1.)
135/500 with lr 0.00027131462250264127


100%|██████████| 125/125 [32:37<00:00, 15.66s/it] 



Loss: 0.8187352012971352
Policy Loss 0.41989403603405784
Value Loss 0.39884116526307734
29 batches processed
Examples:
tensor([ 0.5415,  0.0778, -0.0103,  0.0675,  0.2726,  0.0397, -0.0050])
tensor([0.3267, 0.1343, 0.0020, 0.0140, 0.3327, 0.1784, 0.0120])
tensor([-0.8583])
tensor(-1.)
136/500 with lr 0.0002686014762776149


100%|██████████| 125/125 [40:29<00:00, 19.43s/it] 



Loss: 0.5361289575489031
Policy Loss 0.25939156394451857
Value Loss 0.27673739360438454
36 batches processed
Examples:
tensor([-8.1093e-04,  1.0000e+00,  2.2496e-03,  3.6439e-03,  1.6963e-03,
         2.2233e-03,  2.3692e-03])
tensor([0.0000, 0.9980, 0.0000, 0.0000, 0.0020, 0.0000, 0.0000])
tensor([-0.0261])
tensor(0.)
137/500 with lr 0.00026591546151483875


100%|██████████| 125/125 [38:31<00:00, 18.49s/it] 



Loss: 0.7485832793309408
Policy Loss 0.3473123103818473
Value Loss 0.40127096894909353
34 batches processed
Examples:
tensor([0.1574, 0.0552, 0.1250, 0.0667, 0.1835, 0.3347, 0.1390])
tensor([0.1743, 0.0261, 0.0140, 0.0060, 0.0421, 0.7094, 0.0281])
tensor([-0.8521])
tensor(-1.)
138/500 with lr 0.00026325630689969036


100%|██████████| 125/125 [35:44<00:00, 17.16s/it] 



Loss: 0.7977274258155376
Policy Loss 0.3781974215526134
Value Loss 0.4195300042629242
32 batches processed
Examples:
tensor([ 1.3987e-02, -6.1406e-04,  1.8779e-01,  3.8464e-03,  1.7399e-02,
         6.2851e-01,  1.9398e-01])
tensor([0.0020, 0.0020, 0.0681, 0.0080, 0.0080, 0.8758, 0.0361])
tensor([-0.3418])
tensor(-1.)
139/500 with lr 0.00026062374383069347


100%|██████████| 125/125 [36:06<00:00, 17.33s/it] 



Loss: 0.9813907277403455
Policy Loss 0.4359981891783801
Value Loss 0.5453925385619655
33 batches processed
Examples:
tensor([0.0250, 0.0005, 0.1235, 0.0578, 0.0014, 0.3313, 0.4486])
tensor([0.0060, 0.0020, 0.0160, 0.0060, 0.0020, 0.5070, 0.4609])
tensor([-0.5782])
tensor(-1.)
140/500 with lr 0.00025801750639238655


100%|██████████| 125/125 [37:56<00:00, 18.22s/it] 



Loss: 0.7299395109362462
Policy Loss 0.3697920803199796
Value Loss 0.3601474306162666
34 batches processed
Examples:
tensor([-0.0199,  1.0525, -0.0164, -0.0060, -0.0219,  0.0106, -0.0014])
tensor([0.0020, 0.9980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
tensor([-0.0017])
tensor(0.)
141/500 with lr 0.0002554373313284627


100%|██████████| 125/125 [36:04<00:00, 17.32s/it] 



Loss: 0.8684017595369369
Policy Loss 0.38238876429386437
Value Loss 0.4860129952430725
32 batches processed
Examples:
tensor([0.0708, 0.0838, 0.1930, 0.0237, 0.5846, 0.0483, 0.0442])
tensor([0.0080, 0.0040, 0.3006, 0.0060, 0.6814, 0.0000, 0.0000])
tensor([-0.6754])
tensor(-1.)
142/500 with lr 0.00025288295801517805


100%|██████████| 125/125 [31:17<00:00, 15.02s/it]



Loss: 0.8014484625309706
Policy Loss 0.3593368979969195
Value Loss 0.4421115645340511
28 batches processed
Examples:
tensor([ 0.0074, -0.0306,  0.1473,  0.0061,  0.8420,  0.0027, -0.0052])
tensor([0.0000, 0.0000, 0.1283, 0.0000, 0.8717, 0.0000, 0.0000])
tensor([0.1090])
tensor(0.)
143/500 with lr 0.0002503541284350263


100%|██████████| 125/125 [33:56<00:00, 16.29s/it] 



Loss: 0.8275584797064464
Policy Loss 0.3978290766477585
Value Loss 0.42972940305868784
30 batches processed
Examples:
tensor([0.1218, 0.1163, 0.0128, 0.0380, 0.0141, 0.6014, 0.1105])
tensor([0.0180, 0.0421, 0.0000, 0.0461, 0.0000, 0.7355, 0.1583])
tensor([0.3740])
tensor(0.)
144/500 with lr 0.00024785058715067604


100%|██████████| 125/125 [34:57<00:00, 16.78s/it] 



Loss: 0.9008427103680949
Policy Loss 0.380637742338642
Value Loss 0.5202049680294529
31 batches processed
Examples:
tensor([0.0306, 0.1030, 0.0391, 0.0588, 0.6884, 0.0388, 0.0644])
tensor([0.0000, 0.0281, 0.0261, 0.2124, 0.6994, 0.0000, 0.0341])
tensor([0.6331])
tensor(-1.)
145/500 with lr 0.00024537208127916926


100%|██████████| 125/125 [32:22<00:00, 15.54s/it] 



Loss: 0.8515705202839203
Policy Loss 0.4063418797616447
Value Loss 0.4452286405222757
28 batches processed
Examples:
tensor([ 4.2655e-02,  2.9077e-03,  2.5480e-02,  1.0165e+00,  6.0778e-04,
        -3.5615e-02, -3.3788e-02])
tensor([0.0200, 0.0040, 0.0000, 0.9760, 0.0000, 0.0000, 0.0000])
tensor([0.8045])
tensor(1.)
146/500 with lr 0.00024291836046637757


100%|██████████| 125/125 [35:22<00:00, 16.98s/it] 



Loss: 0.7378984319511801
Policy Loss 0.33781985216774046
Value Loss 0.40007857978343964
32 batches processed
Examples:
tensor([0.0165, 0.2342, 0.2696, 0.0037, 0.0307, 0.3983, 0.0801])
tensor([0.0000, 0.0862, 0.5090, 0.0000, 0.0000, 0.1703, 0.2345])
tensor([-0.4770])
tensor(-1.)
147/500 with lr 0.00024048917686171379


100%|██████████| 125/125 [2:44:09<00:00, 78.79s/it]   



Loss: 0.9658481432124972
Policy Loss 0.3772404631599784
Value Loss 0.5886076800525188
32 batches processed
Examples:
tensor([ 0.2246,  0.0165, -0.0029,  0.0056, -0.0097,  0.0108,  0.7583])
tensor([0.0401, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9599])
tensor([0.9534])
tensor(1.)
148/500 with lr 0.00023808428509309665


100%|██████████| 125/125 [36:05<00:00, 17.33s/it] 



Loss: 1.1158176795128854
Policy Loss 0.589969027426935
Value Loss 0.5258486520859503
31 batches processed
Examples:
tensor([-0.0335, -0.0083, -0.0227,  0.2737,  0.0127,  0.6968,  0.0179])
tensor([0.0000, 0.0000, 0.0000, 0.5451, 0.0000, 0.4549, 0.0000])
tensor([0.9736])
tensor(1.)
149/500 with lr 0.0002357034422421657


100%|██████████| 125/125 [35:36<00:00, 17.09s/it] 



Loss: 0.940770905344717
Policy Loss 0.4904983615682971
Value Loss 0.45027254377641984
31 batches processed
Examples:
tensor([-0.0140,  0.0472,  0.0708,  0.0382,  0.0577,  0.7760,  0.0363])
tensor([0.0020, 0.0060, 0.0120, 0.0020, 0.0120, 0.9559, 0.0100])
tensor([-0.7050])
tensor(1.)
150/500 with lr 0.00023334640781974404


100%|██████████| 125/125 [38:20<00:00, 18.40s/it] 



Loss: 0.8050276345628149
Policy Loss 0.3506932493080111
Value Loss 0.45433438525480385
34 batches processed
Examples:
tensor([0.1796, 0.0149, 0.0685, 0.0767, 0.3824, 0.0705, 0.1828])
tensor([0.3387, 0.0080, 0.0220, 0.0741, 0.3727, 0.0080, 0.1764])
tensor([0.2075])
tensor(1.)
151/500 with lr 0.00023101294374154658


100%|██████████| 125/125 [1:22:40<00:00, 39.68s/it]  



Loss: 0.6977169984020293
Policy Loss 0.30681478371843696
Value Loss 0.3909022146835923
32 batches processed
Examples:
tensor([ 4.8153e-03,  7.5803e-02, -6.6675e-04,  9.3891e-03,  9.2858e-01,
         4.2025e-03, -2.0861e-02])
tensor([0.0120, 0.0822, 0.0000, 0.0020, 0.8998, 0.0000, 0.0040])
tensor([0.3256])
tensor(1.)
152/500 with lr 0.0002287028143041311


100%|██████████| 125/125 [37:30<00:00, 18.01s/it] 



Loss: 1.1106924816514505
Policy Loss 0.48334111273288727
Value Loss 0.6273513689185634
33 batches processed
Examples:
tensor([0.0787, 0.0260, 0.0303, 0.1924, 0.4631, 0.2007, 0.0699])
tensor([0.0361, 0.0000, 0.0020, 0.0000, 0.4008, 0.5511, 0.0100])
tensor([-0.1328])
tensor(-1.)
153/500 with lr 0.00022641578616108978


100%|██████████| 125/125 [52:42<00:00, 25.30s/it]  



Loss: 0.7752167231896344
Policy Loss 0.36590392975246205
Value Loss 0.40931279343717236
34 batches processed
Examples:
tensor([-0.0085,  0.9643,  0.0016,  0.0406, -0.0068, -0.0055, -0.0188])
tensor([0., 1., 0., 0., 0., 0., 0.])
tensor([-0.2230])
tensor(0.)
154/500 with lr 0.00022415162829947887


100%|██████████| 125/125 [35:35<00:00, 17.09s/it] 



Loss: 0.7937353755677901
Policy Loss 0.37348370902961303
Value Loss 0.42025166653817697
31 batches processed
Examples:
tensor([ 0.3646,  0.1462,  0.0385,  0.0518,  0.3695, -0.0230,  0.0179])
tensor([0.4068, 0.0200, 0.0020, 0.0281, 0.5311, 0.0020, 0.0100])
tensor([-0.6131])
tensor(-1.)
155/500 with lr 0.00022191011201648408


100%|██████████| 125/125 [37:11<00:00, 17.85s/it]  



Loss: 1.1153992008079183
Policy Loss 0.5664302219044078
Value Loss 0.5489689789035104
33 batches processed
Examples:
tensor([ 0.0127,  0.0502,  0.8688,  0.0378,  0.0408,  0.0172, -0.0132])
tensor([0.0060, 0.0160, 0.8737, 0.0521, 0.0341, 0.0140, 0.0040])
tensor([0.0768])
tensor(1.)
156/500 with lr 0.00021969101089631924


100%|██████████| 125/125 [30:55<00:00, 14.84s/it] 



Loss: 0.9556329446258368
Policy Loss 0.523815023402373
Value Loss 0.43181792122346385
27 batches processed
Examples:
tensor([ 0.1741,  0.1074,  0.0073,  0.0938,  0.3580, -0.0131,  0.2796])
tensor([0.0982, 0.0060, 0.0000, 0.0261, 0.7635, 0.0060, 0.1002])
tensor([0.8097])
tensor(1.)
157/500 with lr 0.00021749410078735604


100%|██████████| 125/125 [36:07<00:00, 17.34s/it] 



Loss: 0.5958392420603382
Policy Loss 0.30480455318766253
Value Loss 0.29103468887267575
31 batches processed
Examples:
tensor([ 0.3199,  0.1927, -0.0102,  0.0616,  0.4921, -0.0241, -0.0047])
tensor([0.2605, 0.3687, 0.0020, 0.0220, 0.3287, 0.0000, 0.0180])
tensor([-0.4767])
tensor(-1.)
158/500 with lr 0.00021531915977948248


100%|██████████| 125/125 [36:48<00:00, 17.67s/it] 



Loss: 0.6404280461138114
Policy Loss 0.2182884671492502
Value Loss 0.4221395789645612
32 batches processed
Examples:
tensor([ 0.0555,  0.0106,  0.1095, -0.0178,  0.0783,  0.1236,  0.5864])
tensor([0.1423, 0.0000, 0.5551, 0.0000, 0.0100, 0.0621, 0.2305])
tensor([-0.0377])
tensor(1.)
159/500 with lr 0.00021316596818168765


100%|██████████| 125/125 [35:58<00:00, 17.27s/it]



Loss: 0.5418711536874373
Policy Loss 0.20881285108625888
Value Loss 0.33305830260117847
30 batches processed
Examples:
tensor([ 0.6574,  0.0104,  0.1251, -0.0192, -0.0209,  0.1950, -0.0175])
tensor([0.6693, 0.0000, 0.1683, 0.0000, 0.0000, 0.1623, 0.0000])
tensor([-0.8231])
tensor(-1.)
160/500 with lr 0.00021103430849987078


100%|██████████| 125/125 [37:59<00:00, 18.23s/it] 



Loss: 0.5009415684035048
Policy Loss 0.23646827589254826
Value Loss 0.2644732925109565
32 batches processed
Examples:
tensor([-0.0484,  0.1573, -0.0456,  0.1306,  0.0269, -0.0204,  0.7252])
tensor([0.0000, 0.0040, 0.0020, 0.2144, 0.0060, 0.0000, 0.7735])
tensor([-0.5968])
tensor(-1.)
161/500 with lr 0.00020892396541487206


100%|██████████| 125/125 [37:30<00:00, 18.00s/it] 



Loss: 0.479586053872481
Policy Loss 0.1947483450639993
Value Loss 0.2848377088084817
32 batches processed
Examples:
tensor([ 0.0157,  0.2015,  0.5162, -0.0054,  0.0106,  0.2179,  0.0627])
tensor([0.0180, 0.2565, 0.5190, 0.0000, 0.0000, 0.1844, 0.0220])
tensor([-0.9795])
tensor(-1.)
162/500 with lr 0.00020683472576072333


100%|██████████| 125/125 [41:03<00:00, 19.71s/it]  



Loss: 0.5471597781138761
Policy Loss 0.23117180796606202
Value Loss 0.31598797014781405
35 batches processed
Examples:
tensor([ 0.1280, -0.0063,  0.0286,  0.0029, -0.0098,  0.0418,  0.7604])
tensor([0.0040, 0.0000, 0.0040, 0.0000, 0.0000, 0.0040, 0.9880])
tensor([0.2797])
tensor(1.)
163/500 with lr 0.0002047663785031161


100%|██████████| 125/125 [40:42<00:00, 19.54s/it]



Loss: 0.6150465775281191
Policy Loss 0.20914667943382964
Value Loss 0.40589989809428945
34 batches processed
Examples:
tensor([ 1.0071e-02,  3.7115e-03,  9.7578e-01,  4.6914e-03, -2.3769e-02,
        -2.5028e-04,  1.6498e-02])
tensor([0.0200, 0.0000, 0.9218, 0.0000, 0.0000, 0.0040, 0.0541])
tensor([-0.9596])
tensor(-1.)
164/500 with lr 0.00020271871471808493


100%|██████████| 125/125 [35:35<00:00, 17.09s/it] 



Loss: 0.4534065086128456
Policy Loss 0.2458048677071929
Value Loss 0.20760164090565272
28 batches processed
Examples:
tensor([-0.0039, -0.0022,  0.0068,  0.9980, -0.0138,  0.0058, -0.0057])
tensor([0., 0., 0., 1., 0., 0., 0.])
tensor([0.9434])
tensor(1.)
165/500 with lr 0.00020069152757090408


100%|██████████| 125/125 [39:51<00:00, 19.13s/it] 



Loss: 0.7064411502444383
Policy Loss 0.27852114967324515
Value Loss 0.4279200005711931
33 batches processed
Examples:
tensor([ 0.0755,  0.1057,  0.0278, -0.0194,  0.1149,  0.6738,  0.0203])
tensor([0.3226, 0.5832, 0.0040, 0.0000, 0.0261, 0.0621, 0.0020])
tensor([-0.6861])
tensor(-1.)
166/500 with lr 0.00019868461229519504


100%|██████████| 125/125 [34:46<00:00, 16.69s/it] 



Loss: 0.7794076796492626
Policy Loss 0.2935038040681132
Value Loss 0.4859038755811494
29 batches processed
Examples:
tensor([ 0.0326,  0.6810,  0.0575, -0.0178,  0.0553,  0.0140,  0.1224])
tensor([0.0000, 0.4589, 0.0842, 0.0000, 0.0501, 0.0000, 0.4068])
tensor([0.9897])
tensor(1.)
167/500 with lr 0.0001966977661722431


100%|██████████| 125/125 [39:17<00:00, 18.86s/it] 



Loss: 0.6603962407717782
Policy Loss 0.26191634648749906
Value Loss 0.3984798942842791
31 batches processed
Examples:
tensor([ 0.0293,  0.0011,  0.0045,  0.9321,  0.0112, -0.0045,  0.0079])
tensor([0.0180, 0.0060, 0.0020, 0.9519, 0.0020, 0.0020, 0.0180])
tensor([-0.2861])
tensor(-1.)
168/500 with lr 0.00019473078851052066


100%|██████████| 125/125 [35:50<00:00, 17.20s/it] 



Loss: 0.9208668455481529
Policy Loss 0.46872321516275406
Value Loss 0.45214363038539884
30 batches processed
Examples:
tensor([ 0.0700, -0.0358,  0.0090,  0.0195,  0.8905,  0.0102,  0.0303])
tensor([0.0160, 0.0020, 0.2184, 0.0000, 0.6774, 0.0200, 0.0661])
tensor([-0.8861])
tensor(-1.)
169/500 with lr 0.00019278348062541544


100%|██████████| 125/125 [35:49<00:00, 17.20s/it] 



Loss: 0.5465731475502252
Policy Loss 0.2264657963067293
Value Loss 0.32010735124349593
30 batches processed
Examples:
tensor([ 0.0195,  0.0020, -0.0065,  0.9822,  0.0028, -0.0038,  0.0051])
tensor([0.0080, 0.0020, 0.0020, 0.9860, 0.0020, 0.0000, 0.0000])
tensor([0.5365])
tensor(1.)
170/500 with lr 0.00019085564581916128


100%|██████████| 125/125 [35:38<00:00, 17.11s/it] 



Loss: 1.0118636352320511
Policy Loss 0.41117082759737966
Value Loss 0.6006928076346715
30 batches processed
Examples:
tensor([-1.4216e-02, -1.1828e-02,  7.2078e-04,  9.9375e-01,  7.6293e-04,
         9.8978e-03, -1.2239e-02])
tensor([0.0060, 0.0020, 0.0020, 0.9880, 0.0000, 0.0000, 0.0020])
tensor([0.2870])
tensor(-1.)
171/500 with lr 0.00018894708936096965


100%|██████████| 125/125 [39:17<00:00, 18.86s/it] 



Loss: 0.7881833084604957
Policy Loss 0.3422216055068103
Value Loss 0.4459617029536854
33 batches processed
Examples:
tensor([ 0.0755,  0.1144,  0.0636, -0.0214,  0.4632,  0.0244,  0.2317])
tensor([0.0060, 0.6172, 0.0261, 0.0000, 0.1984, 0.0000, 0.1523])
tensor([0.3144])
tensor(1.)
172/500 with lr 0.00018705761846735995


100%|██████████| 125/125 [38:22<00:00, 18.42s/it] 



Loss: 0.8348881462041069
Policy Loss 0.40952281478573294
Value Loss 0.425365331418374
34 batches processed
Examples:
tensor([0.1033, 0.0171, 0.0518, 0.1563, 0.0257, 0.3575, 0.3136])
tensor([0.1182, 0.0040, 0.0160, 0.0441, 0.0020, 0.6493, 0.1663])
tensor([-0.3253])
tensor(-1.)
173/500 with lr 0.00018518704228268635


100%|██████████| 125/125 [37:33<00:00, 18.03s/it] 



Loss: 1.0347808845566981
Policy Loss 0.5007905303077265
Value Loss 0.5339903542489717
33 batches processed
Examples:
tensor([-0.0092,  0.1523,  0.6253,  0.0205,  0.1217,  0.1190, -0.0237])
tensor([0.0000, 0.3347, 0.3607, 0.0000, 0.1663, 0.1383, 0.0000])
tensor([0.9486])
tensor(1.)
174/500 with lr 0.00018333517185985948


100%|██████████| 125/125 [36:44<00:00, 17.63s/it] 



Loss: 0.7874439459513216
Policy Loss 0.35804856393267126
Value Loss 0.4293953820186503
34 batches processed
Examples:
tensor([ 0.0039, -0.0216,  0.1727,  0.0319,  0.7493,  0.0444, -0.0083])
tensor([0.0020, 0.0000, 0.0341, 0.0000, 0.9619, 0.0020, 0.0000])
tensor([0.1896])
tensor(1.)
175/500 with lr 0.00018150182014126088


100%|██████████| 125/125 [35:09<00:00, 16.88s/it] 



Loss: 0.9542840418796386
Policy Loss 0.3897737694844123
Value Loss 0.5645102723952262
31 batches processed
Examples:
tensor([ 0.4419,  0.1705,  0.0162, -0.0244,  0.2171,  0.0123,  0.1621])
tensor([0.4168, 0.2265, 0.0000, 0.0000, 0.2265, 0.0000, 0.1303])
tensor([-0.8225])
tensor(-1.)
176/500 with lr 0.00017968680193984827


100%|██████████| 125/125 [36:42<00:00, 17.62s/it] 



Loss: 0.8190355421975255
Policy Loss 0.3506971411406994
Value Loss 0.4683384010568261
32 batches processed
Examples:
tensor([ 0.0864,  0.0145, -0.0150,  0.0102, -0.0046,  0.0157,  0.9302])
tensor([0.0080, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9920])
tensor([-0.0342])
tensor(0.)
177/500 with lr 0.0001778899339204498


100%|██████████| 125/125 [38:43<00:00, 18.58s/it]  



Loss: 0.7166386629071306
Policy Loss 0.3329910422291826
Value Loss 0.383647620677948
34 batches processed
Examples:
tensor([ 0.0251,  0.0370,  0.0195, -0.0154, -0.0297,  1.0067, -0.0277])
tensor([0., 0., 0., 0., 0., 1., 0.])
tensor([0.7880])
tensor(1.)
178/500 with lr 0.0001761110345812453


100%|██████████| 125/125 [37:44<00:00, 18.11s/it]  



Loss: 0.878852492480567
Policy Loss 0.41419754109599377
Value Loss 0.46465495138457324
33 batches processed
Examples:
tensor([0.0179, 0.0021, 0.0228, 0.0074, 0.9182, 0.0276, 0.0050])
tensor([0.0140, 0.0381, 0.0020, 0.0100, 0.9279, 0.0040, 0.0040])
tensor([-0.1646])
tensor(0.)
179/500 with lr 0.00017434992423543284


100%|██████████| 125/125 [37:46<00:00, 18.13s/it] 



Loss: 0.7079865957299869
Policy Loss 0.30445321107452566
Value Loss 0.4035333846554612
33 batches processed
Examples:
tensor([ 0.2365,  0.0140, -0.0140,  0.0618,  0.0179,  0.2577,  0.4819])
tensor([0.2465, 0.0100, 0.0000, 0.0401, 0.0020, 0.2705, 0.4309])
tensor([0.0973])
tensor(0.)
180/500 with lr 0.0001726064249930785


100%|██████████| 125/125 [39:01<00:00, 18.73s/it] 



Loss: 0.8290507543612928
Policy Loss 0.32001866619376573
Value Loss 0.5090320881675271
34 batches processed
Examples:
tensor([ 0.0808,  0.2775,  0.0333, -0.0062,  0.0828,  0.4777,  0.0815])
tensor([0.0120, 0.0501, 0.0000, 0.0000, 0.0060, 0.9138, 0.0180])
tensor([-0.0435])
tensor(0.)
181/500 with lr 0.00017088036074314772


100%|██████████| 125/125 [1:24:11<00:00, 40.42s/it] 



Loss: 0.9953403577208519
Policy Loss 0.4721254274249077
Value Loss 0.5232149302959442
30 batches processed
Examples:
tensor([-0.0056, -0.0148,  0.9031,  0.0242,  0.0064,  0.0368,  0.0051])
tensor([0.0040, 0.0020, 0.9820, 0.0000, 0.0040, 0.0000, 0.0080])
tensor([0.4247])
tensor(1.)
182/500 with lr 0.00016917155713571625


100%|██████████| 125/125 [50:44<00:00, 24.36s/it] 



Loss: 0.7706564924482143
Policy Loss 0.3026463667100126
Value Loss 0.4680101257382017
33 batches processed
Examples:
tensor([-0.0088,  0.0091, -0.0017,  0.9824,  0.0108,  0.0220,  0.0019])
tensor([0.0000, 0.0000, 0.0000, 0.9960, 0.0000, 0.0020, 0.0020])
tensor([0.2099])
tensor(1.)
183/500 with lr 0.00016747984156435908


100%|██████████| 125/125 [37:24<00:00, 17.95s/it] 



Loss: 0.9359361620679979
Policy Loss 0.3400351400336912
Value Loss 0.5959010220343067
31 batches processed
Examples:
tensor([-0.0313,  0.7867, -0.0250, -0.0244, -0.0353, -0.0853,  0.3085])
tensor([0.0000, 0.9780, 0.0000, 0.0000, 0.0000, 0.0000, 0.0220])
tensor([0.2014])
tensor(0.)
184/500 with lr 0.0001658050431487155


100%|██████████| 125/125 [38:27<00:00, 18.46s/it] 



Loss: 0.8713162017657476
Policy Loss 0.3807738828746711
Value Loss 0.49054231889107647
34 batches processed
Examples:
tensor([ 0.0046, -0.0193,  0.0117,  0.0161,  0.9581,  0.0094, -0.0198])
tensor([0., 0., 0., 0., 1., 0., 0.])
tensor([0.7939])
tensor(1.)
185/500 with lr 0.00016414699271722832


100%|██████████| 125/125 [37:49<00:00, 18.15s/it] 



Loss: 0.7011125192922705
Policy Loss 0.3193339863244225
Value Loss 0.3817785329678479
34 batches processed
Examples:
tensor([ 0.0148,  0.0138,  0.0172,  0.8833, -0.0055, -0.0152,  0.0624])
tensor([0.0040, 0.0020, 0.0381, 0.9419, 0.0020, 0.0000, 0.0120])
tensor([0.0628])
tensor(0.)
186/500 with lr 0.00016250552279005603


100%|██████████| 125/125 [38:20<00:00, 18.41s/it] 



Loss: 0.8467643899899541
Policy Loss 0.38232120126485825
Value Loss 0.4644431887250958
33 batches processed
Examples:
tensor([ 0.0587,  0.0976,  0.0740, -0.0195,  0.7587, -0.0134,  0.0123])
tensor([0.1543, 0.0922, 0.0741, 0.0000, 0.6794, 0.0000, 0.0000])
tensor([-0.4382])
tensor(-1.)
187/500 with lr 0.00016088046756215546


100%|██████████| 125/125 [36:34<00:00, 17.56s/it] 



Loss: 0.9783551267558529
Policy Loss 0.40536884506863935
Value Loss 0.5729862816872136
31 batches processed
Examples:
tensor([ 0.0015,  0.0115,  0.0045,  0.9847,  0.0121, -0.0141,  0.0168])
tensor([0.0040, 0.0020, 0.0040, 0.9840, 0.0000, 0.0000, 0.0060])
tensor([0.0293])
tensor(1.)
188/500 with lr 0.0001592716628865339


100%|██████████| 125/125 [37:24<00:00, 17.96s/it] 



Loss: 0.9620668860999021
Policy Loss 0.4445571411739696
Value Loss 0.5175097449259325
33 batches processed
Examples:
tensor([0.4036, 0.4017, 0.1049, 0.0102, 0.0099, 0.0935, 0.0096])
tensor([0.5150, 0.3547, 0.0842, 0.0000, 0.0000, 0.0461, 0.0000])
tensor([-0.9615])
tensor(-1.)
189/500 with lr 0.00015767894625766857


100%|██████████| 125/125 [37:32<00:00, 18.02s/it] 



Loss: 0.9324842668843991
Policy Loss 0.39239941808310425
Value Loss 0.5400848488012949
33 batches processed
Examples:
tensor([ 0.3513,  0.0602,  0.0191, -0.0106,  0.0567,  0.2912,  0.3201])
tensor([0.1764, 0.0000, 0.0000, 0.0000, 0.0601, 0.4128, 0.3507])
tensor([-0.9387])
tensor(-1.)
190/500 with lr 0.0001561021567950919


100%|██████████| 125/125 [38:25<00:00, 18.45s/it] 



Loss: 0.7760026273043716
Policy Loss 0.3076202634941129
Value Loss 0.4683823638102588
34 batches processed
Examples:
tensor([ 0.0114,  0.0103,  0.9547,  0.0283, -0.0025, -0.0179,  0.0277])
tensor([0.0080, 0.0020, 0.9399, 0.0160, 0.0020, 0.0020, 0.0301])
tensor([0.0776])
tensor(-1.)
191/500 with lr 0.00015454113522714096


100%|██████████| 125/125 [41:53<00:00, 20.11s/it] 



Loss: 0.39715458195958586
Policy Loss 0.18282995982145941
Value Loss 0.21432462213812647
37 batches processed
Examples:
tensor([-0.0108,  0.0080, -0.0033,  1.0185, -0.0041, -0.0026, -0.0148])
tensor([0.0000, 0.0000, 0.0000, 0.9920, 0.0000, 0.0060, 0.0020])
tensor([0.2373])
tensor(1.)
192/500 with lr 0.00015299572387486956


100%|██████████| 125/125 [42:09<00:00, 20.24s/it] 



Loss: 0.41452628716423706
Policy Loss 0.18295252786294836
Value Loss 0.2315737593012887
37 batches processed
Examples:
tensor([ 4.1247e-03, -4.3554e-03,  1.7225e-01,  1.6899e-02,  7.1674e-04,
        -2.4525e-02,  8.5963e-01])
tensor([0.0020, 0.0020, 0.1824, 0.0080, 0.0000, 0.0000, 0.8056])
tensor([-0.1597])
tensor(1.)
193/500 with lr 0.00015146576663612087


100%|██████████| 125/125 [42:21<00:00, 20.33s/it] 



Loss: 0.47637545420891714
Policy Loss 0.1766877652456363
Value Loss 0.2996876889632808
36 batches processed
Examples:
tensor([ 0.9180, -0.0013,  0.0076,  0.0157, -0.0016, -0.0020,  0.0609])
tensor([0.9880, 0.0020, 0.0000, 0.0040, 0.0020, 0.0020, 0.0020])
tensor([-0.6840])
tensor(-1.)
194/500 with lr 0.00014995110896975965


100%|██████████| 125/125 [41:02<00:00, 19.70s/it] 



Loss: 0.6292671571884836
Policy Loss 0.24275403640099935
Value Loss 0.3865131207874843
35 batches processed
Examples:
tensor([0.1270, 0.0378, 0.0364, 0.0225, 0.2430, 0.5187, 0.0257])
tensor([0.0822, 0.0180, 0.0040, 0.0000, 0.0541, 0.8337, 0.0080])
tensor([-0.5195])
tensor(-1.)
195/500 with lr 0.00014845159788006205


100%|██████████| 125/125 [3:29:43<00:00, 100.67s/it]



Loss: 0.42570430583187513
Policy Loss 0.2072879539004394
Value Loss 0.21841635193143571
35 batches processed
Examples:
tensor([-8.0466e-07,  8.7560e-01,  6.9753e-03,  7.4277e-03,  6.8685e-02,
         4.7617e-02,  7.7740e-03])
tensor([0.0000, 0.9238, 0.0000, 0.0000, 0.0441, 0.0321, 0.0000])
tensor([-0.9851])
tensor(-1.)
196/500 with lr 0.00014696708190126143


100%|██████████| 125/125 [43:27<00:00, 20.86s/it] 



Loss: 0.32044981471780276
Policy Loss 0.13857269628594318
Value Loss 0.18187711843185955
36 batches processed
Examples:
tensor([ 0.0100,  1.0014,  0.0057, -0.0082, -0.0060,  0.0121,  0.0208])
tensor([0.0000, 0.9920, 0.0000, 0.0000, 0.0000, 0.0080, 0.0000])
tensor([0.2195])
tensor(0.)
New best loss: 0.32044981471780276
197/500 with lr 0.0001454974110822488


100%|██████████| 125/125 [45:27<00:00, 21.82s/it] 



Loss: 0.24289413624884265
Policy Loss 0.11448108978373439
Value Loss 0.12841304646510826
38 batches processed
Examples:
tensor([ 0.0740,  0.0039, -0.0024, -0.0044,  0.0239, -0.0014,  0.8893])
tensor([0.0501, 0.0040, 0.0000, 0.0000, 0.0180, 0.0000, 0.9279])
tensor([-0.0168])
tensor(0.)
New best loss: 0.24289413624884265
198/500 with lr 0.00014404243697142632


100%|██████████| 125/125 [41:55<00:00, 20.13s/it] 



Loss: 0.5524264054638999
Policy Loss 0.18540834954806737
Value Loss 0.36701805591583253
35 batches processed
Examples:
tensor([ 3.2704e-03,  5.4107e-03, -1.6289e-04, -6.9659e-03,  9.0404e-01,
         7.9390e-03,  2.4324e-02])
tensor([0.0080, 0.0040, 0.0000, 0.0000, 0.7094, 0.0000, 0.2786])
tensor([0.4599])
tensor(1.)
199/500 with lr 0.00014260201260171205


100%|██████████| 125/125 [42:27<00:00, 20.38s/it]



Loss: 0.4843612664068739
Policy Loss 0.19428721349686384
Value Loss 0.29007405291001004
36 batches processed
Examples:
tensor([-0.0079,  0.0056, -0.0048,  0.0103,  0.0337, -0.0127,  0.9814])
tensor([0.0000, 0.0220, 0.0000, 0.0000, 0.0020, 0.0000, 0.9760])
tensor([0.2798])
tensor(0.)
200/500 with lr 0.00014117599247569492


100%|██████████| 125/125 [42:23<00:00, 20.35s/it] 



Loss: 0.38296355441626573
Policy Loss 0.1721976458405455
Value Loss 0.21076590857572025
36 batches processed
Examples:
tensor([0.1089, 0.0490, 0.0229, 0.1298, 0.0349, 0.0304, 0.6479])
tensor([0.0180, 0.0060, 0.0000, 0.0421, 0.0020, 0.0120, 0.9198])
tensor([-0.6185])
tensor(0.)
201/500 with lr 0.00013976423255093796


100%|██████████| 125/125 [39:56<00:00, 19.17s/it] 



Loss: 0.6046629934845602
Policy Loss 0.27368602693519173
Value Loss 0.3309769665493685
34 batches processed
Examples:
tensor([ 0.0033,  0.6632,  0.0097, -0.0192,  0.0419, -0.0039,  0.3291])
tensor([0.0000, 0.0401, 0.0000, 0.0000, 0.0000, 0.0000, 0.9599])
tensor([0.8714])
tensor(1.)
202/500 with lr 0.00013836659022542857


100%|██████████| 125/125 [37:14<00:00, 17.87s/it] 



Loss: 0.5433185514993966
Policy Loss 0.2549389712512493
Value Loss 0.28837958024814725
32 batches processed
Examples:
tensor([-0.0079,  0.0082, -0.0049,  0.0868, -0.0108,  0.0102,  0.9385])
tensor([0.0020, 0.0020, 0.0020, 0.0461, 0.0000, 0.0020, 0.9459])
tensor([-0.2484])
tensor(0.)
203/500 with lr 0.0001369829243231743


100%|██████████| 125/125 [35:45<00:00, 17.16s/it] 



Loss: 0.5088942290594181
Policy Loss 0.19579045139253143
Value Loss 0.31310377766688663
30 batches processed
Examples:
tensor([-1.4223e-02,  9.7958e-01, -1.8009e-03,  9.9033e-03, -6.2670e-03,
         1.5069e-02,  6.5067e-04])
tensor([0.0000, 0.9940, 0.0000, 0.0000, 0.0000, 0.0020, 0.0040])
tensor([0.8097])
tensor(1.)
204/500 with lr 0.00013561309507994257


100%|██████████| 125/125 [34:55<00:00, 16.76s/it]



Loss: 0.7002425896624724
Policy Loss 0.33114714697003367
Value Loss 0.36909544269243877
30 batches processed
Examples:
tensor([ 0.0454,  0.0136, -0.0112,  0.0195,  0.2858,  0.1841,  0.5544])
tensor([0.0401, 0.0040, 0.0000, 0.0200, 0.0361, 0.7415, 0.1583])
tensor([0.9281])
tensor(1.)
205/500 with lr 0.00013425696412914314


100%|██████████| 125/125 [32:48<00:00, 15.75s/it] 



Loss: 1.1454863045364618
Policy Loss 0.5451264102011919
Value Loss 0.6003598943352699
28 batches processed
Examples:
tensor([ 0.0023,  0.0571,  0.1140,  0.6440,  0.0511,  0.1008, -0.0010])
tensor([0.0020, 0.0561, 0.1383, 0.7535, 0.0040, 0.0441, 0.0020])
tensor([0.2271])
tensor(-1.)
206/500 with lr 0.0001329143944878517


100%|██████████| 125/125 [37:28<00:00, 17.99s/it] 



Loss: 0.8935280607053728
Policy Loss 0.39985598420554946
Value Loss 0.49367207649982336
33 batches processed
Examples:
tensor([ 0.0059, -0.0058, -0.0118,  0.0110,  0.0173, -0.0124,  0.9725])
tensor([0.0000, 0.0000, 0.0000, 0.0401, 0.0000, 0.0000, 0.9599])
tensor([-0.1599])
tensor(-1.)
207/500 with lr 0.0001315852505429732


100%|██████████| 125/125 [38:29<00:00, 18.48s/it] 



Loss: 0.7919122302795158
Policy Loss 0.35375577189466534
Value Loss 0.43815645838485046
34 batches processed
Examples:
tensor([0.0425, 0.0051, 0.7863, 0.0161, 0.0242, 0.0976, 0.0043])
tensor([0.0060, 0.0080, 0.9118, 0.0000, 0.0180, 0.0521, 0.0040])
tensor([0.2684])
tensor(0.)
208/500 with lr 0.00013026939803754346


100%|██████████| 125/125 [39:00<00:00, 18.72s/it] 



Loss: 0.8040905305567909
Policy Loss 0.31400536526651945
Value Loss 0.4900851652902715
34 batches processed
Examples:
tensor([ 0.1193, -0.0061,  0.4444,  0.2817,  0.0171,  0.1007,  0.0848])
tensor([0.9158, 0.0000, 0.0140, 0.0521, 0.0040, 0.0040, 0.0100])
tensor([0.5155])
tensor(1.)
209/500 with lr 0.00012896670405716803


100%|██████████| 125/125 [38:26<00:00, 18.45s/it] 



Loss: 0.6721164963262922
Policy Loss 0.3014457026386962
Value Loss 0.37067079368759603
34 batches processed
Examples:
tensor([0.0198, 0.5141, 0.0417, 0.0166, 0.3664, 0.0404, 0.0157])
tensor([0.0000, 0.7194, 0.0000, 0.0000, 0.2806, 0.0000, 0.0000])
tensor([-0.9995])
tensor(-1.)
210/500 with lr 0.00012767703701659636


100%|██████████| 125/125 [36:28<00:00, 17.51s/it]



Loss: 0.8012091871351004
Policy Loss 0.2606321658939123
Value Loss 0.540577021241188
30 batches processed
Examples:
tensor([ 0.0097,  0.9738,  0.0207,  0.0171, -0.0185, -0.0021,  0.0294])
tensor([0.0000, 0.9960, 0.0000, 0.0020, 0.0000, 0.0000, 0.0020])
tensor([0.6514])
tensor(0.)
211/500 with lr 0.0001264002666464304


100%|██████████| 125/125 [37:56<00:00, 18.21s/it] 



Loss: 0.7339884439716116
Policy Loss 0.26304340513888747
Value Loss 0.4709450388327241
32 batches processed
Examples:
tensor([-0.0062,  0.0012,  0.9930, -0.0078,  0.0114,  0.0116,  0.0053])
tensor([0.0020, 0.0000, 0.9900, 0.0060, 0.0000, 0.0000, 0.0020])
tensor([0.4609])
tensor(1.)
212/500 with lr 0.0001251362639799661


100%|██████████| 125/125 [35:34<00:00, 17.07s/it] 



Loss: 0.8371990899244945
Policy Loss 0.37520954310894017
Value Loss 0.4619895468155543
30 batches processed
Examples:
tensor([ 0.0048, -0.0082,  0.7808,  0.0111,  0.0032,  0.2197,  0.0115])
tensor([0.0020, 0.0020, 0.7234, 0.0100, 0.0020, 0.2585, 0.0020])
tensor([-0.4422])
tensor(-1.)
Early stopping at iteration 212
