In [51]:
import torch
import DreamingNets
import matplotlib.pyplot as plt
import numpy as np
import hockey.hockey_env as h_env
import gymnasium as gym
import config

%matplotlib inline

In [None]:
# Reward model under training: run_episode() feeds it (observation, reward)
# batches via RNET.train_episode, and test_run() compares RNET.predict(obs)
# against the true per-step reward.
RNET = DreamingNets.reward_net()

In [53]:
# Single hockey environment instance shared by the rollout functions
# run_episode() and test_run().
env = h_env.HockeyEnv()

In [54]:
# Fix the global RNG seeds so the random subsampling in cut_data (np.random)
# and any subsequent torch-side randomness are reproducible across kernel
# restarts. Seeding happens before env.reset() in case the env draws from
# the global numpy RNG (NOTE(review): not visible here — confirm).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

o, info = env.reset()
# Both sides are driven by the scripted BasicOpponent; their actions generate
# the rollouts used to train and evaluate the reward network.
player1 = h_env.BasicOpponent(weak=True)
player2 = h_env.BasicOpponent(weak=True)

In [55]:

def cut_data(data, drop_prob=0.5, band=1.0):
    """Randomly thin out samples whose reward is close to zero.

    Each ``(obs, reward)`` entry with ``-band < reward < band`` is dropped with
    probability ``drop_prob``. All other entries are always kept. Presumably
    this keeps the many near-zero rewards from dominating the rare large ones.

    The list is filtered in place (and also returned) in a single O(n) pass.
    The previous version made repeated ``list.pop(i)`` calls, which is O(n^2)
    overall.

    Args:
        data: list of ``(obs, reward)`` pairs; modified in place.
        drop_prob: probability of dropping an in-band sample (default 0.5).
        band: half-width of the open reward interval treated as "near zero"
            (default 1.0).

    Returns:
        The same list object, containing only the kept entries in their
        original order.
    """
    kept = []
    for sample in data:
        reward = sample[1]
        # Same short-circuit order as before: the RNG is consumed only for
        # in-band samples, so results for a given seed are unchanged.
        if reward < band and reward > -band and np.random.rand() < drop_prob:
            continue
        kept.append(sample)
    data[:] = kept
    return data

In [56]:
def densify_rewards(data, steps=9, threshold=5):
    """Spread each large (sparse) reward back over the preceding samples.

    For every entry ``j`` whose reward satisfies ``|reward| > threshold``, each
    of the up to ``steps`` preceding entries ``j-1 ... j-steps`` keeps its own
    observation. Its reward is replaced by the average of its own reward and
    the large reward at ``j``.

    Args:
        data: list of ``(obs, reward)`` pairs; modified in place.
        steps: how many preceding entries receive the blended reward
            (default 9, i.e. the original ``range(1, 10)``).
        threshold: magnitude above which a reward counts as "large"
            (default 5; the comparison is strict).

    Returns:
        The same list object, with rewards densified.
    """
    for j in range(len(data)):
        big_reward = data[j][1]
        if big_reward > threshold or big_reward < -threshold:
            for i in range(1, steps + 1):
                k = j - i
                # Stop at the start of the list. A negative index would wrap
                # around and overwrite samples at the end of the list.
                if k < 0:
                    break
                # Keep the observation of slot k itself. The old code copied
                # data[j-1]'s observation (and compounded its reward) into
                # every earlier slot.
                data[k] = (data[k][0], data[k][1] * 0.5 + big_reward * 0.5)
    return data

In [None]:
def run_episode():
    """Collect rollouts from the two scripted players and run one RNET update.

    Plays ``config.TRAIN_SEQUENCES`` episodes of at most ``config.TRAIN_STEPS``
    steps each. Every post-step observation is paired with the reward returned
    by that ``env.step``. The sparse rewards are densified *per episode*, so
    shaped reward never bleeds into a neighbouring episode's samples. The
    near-zero-reward samples are then randomly subsampled, and the result is
    passed to ``RNET.train_episode``.
    """
    batch = []
    for _ in range(config.TRAIN_SEQUENCES):
        obs_buffer = []
        reward_buffer = []
        obs, info = env.reset()
        obs_agent2 = env.obs_agent_two()
        for _ in range(config.TRAIN_STEPS):
            a1 = player1.act(obs)
            a2 = player2.act(obs_agent2)
            obs, r, d, t, info = env.step(np.hstack([a1, a2]))
            obs_buffer.append(obs)
            reward_buffer.append(r)
            obs_agent2 = env.obs_agent_two()
            if d or t:
                break
        # Same array conversion as before, so samples keep the numpy row /
        # numpy scalar types that RNET already receives.
        episode = list(zip(np.asarray(obs_buffer), np.asarray(reward_buffer)))
        batch.extend(densify_rewards(episode))
    batch = cut_data(batch)
    RNET.train_episode(batch)

In [58]:
def test_run():
    """Play one episode and print true vs. predicted reward for every step.

    Both sides are driven by the scripted players. After each ``env.step`` the
    resulting observation and reward are recorded. At the end, each recorded
    reward is printed next to ``RNET.predict`` applied to the matching
    observation.
    """
    observations = []
    rewards = []
    state, _info = env.reset()
    state_p2 = env.obs_agent_two()
    for _ in range(config.TRAIN_STEPS):
        joint_action = np.hstack([player1.act(state), player2.act(state_p2)])
        state, reward, done, truncated, _info = env.step(joint_action)
        observations.append(state)
        rewards.append(reward)
        state_p2 = env.obs_agent_two()
        if done or truncated:
            break
    observations = np.asarray(observations)
    rewards = np.asarray(rewards)
    for expected, observation in zip(rewards, observations):
        print("Expected result ", expected, "Prediciton " , RNET.predict(observation))

In [59]:
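# Release the environment once all rollout cells are finished. It is kept
# disabled here because the training and test cells below still use `env`.
#env.close()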

In [151]:
# Number of collect-and-train rounds; each round calls run_episode() once
# (several rollouts followed by one RNET.train_episode call).
N_EPOCHS = 200
for e in range(N_EPOCHS):
    print("Epoch  ", e)
    run_episode()

Epoch   0
11215
5693
RNET Validation loss:  0.5514425881693956
Epoch   1
10544
5447
RNET Validation loss:  0.45234290549039097
Epoch   2
12618
6492
RNET Validation loss:  0.435703523241929
Epoch   3
13131
6595
RNET Validation loss:  0.43940833970615445
Epoch   4
12827
6533
RNET Validation loss:  0.4658886717892836
Epoch   5
11817
6024
RNET Validation loss:  0.4756178184467599
Epoch   6
12225
6374
RNET Validation loss:  0.4957582495421124
Epoch   7
11924
6101
RNET Validation loss:  0.7840765631104116
Epoch   8
12593
6291
RNET Validation loss:  0.522057706442193
Epoch   9
12246
6330
RNET Validation loss:  0.42232789028198936
Epoch   10
12208
6264
RNET Validation loss:  0.3364595840310363
Epoch   11
12485
6379
RNET Validation loss:  0.3700274477891651
Epoch   12
11852
6028
RNET Validation loss:  0.28749912738358957
Epoch   13
12327
6281
RNET Validation loss:  0.21187660405873715
Epoch   14
12739
6436
RNET Validation loss:  0.22951595943929554
Epoch   15
12964
6600
RNET Validation loss:  0

In [150]:
# Roll out one episode and print the true reward next to RNET's prediction
# for every step.
test_run()

Expected result  -0.07272180438944181 Prediciton  -0.05637522041797638
Expected result  -0.06983958716740218 Prediciton  -0.05261506140232086
Expected result  -0.06580087452743087 Prediciton  -0.05059942603111267
Expected result  -0.060805165433238106 Prediciton  -0.04669293761253357
Expected result  -0.055052322182679934 Prediciton  -0.044605761766433716
Expected result  -0.048742001196187414 Prediciton  -0.041895076632499695
Expected result  -0.04211471018388952 Prediciton  -0.03859470784664154
Expected result  -0.035349845991578394 Prediciton  -0.037399739027023315
Expected result  -0.02868208550003534 Prediciton  -0.020895421504974365
Expected result  -0.02473975438895847 Prediciton  -0.03718556463718414
Expected result  -0.020830850581192224 Prediciton  -0.034801751375198364
Expected result  0.0 Prediciton  -0.025088131427764893
Expected result  0.0 Prediciton  -0.024539023637771606
Expected result  0.0 Prediciton  -0.01717446744441986
Expected result  0.0 Prediciton  -0.013843789