In [7]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from src.agent import DDPG_Hedger
from src.network import MLP, MLP_debug
from src.env import StockTradingEnv

In [8]:
# --- Reproducibility -------------------------------------------------------
# Seed the NumPy and PyTorch global generators before the environment and the
# networks are constructed, so that anything drawing from those generators
# (e.g. PyTorch weight initialisation) is repeatable on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Environment -----------------------------------------------------------
env = StockTradingEnv(reset_path=True)

# --- Hyper-parameters ------------------------------------------------------
actor_lr, critic_lr = 10**-4, 10**-4
nState, nAction = env.observation_space.shape[0], env.action_space.shape[0]  # 3, 1

# Hidden layer size handed to every network below.
# NOTE(review): the earlier comment here mentioned sizes "32, 64" as used by
# the original author, but only 32 is ever passed — confirm whether a second
# hidden width of 64 was intended.
HIDDEN_SIZE = 32

# A critic maps a (state, action) pair to one scalar value, so its output
# width is 1 regardless of the action dimension. This equals the previous
# value (nAction) for this environment, where nAction == 1.
# NOTE(review): confirm DDPG_Hedger's losses expect a single-column critic output.
Q_OUT_DIM = 1

# --- Networks and agent ----------------------------------------------------
# NOTE(review): the last constructor argument ("Sigmoid" / "") presumably
# selects the output activation — verify against src.network.
actor = MLP_debug(nState, HIDDEN_SIZE, nAction, "Sigmoid")
qnet_1 = MLP(nState + nAction, HIDDEN_SIZE, Q_OUT_DIM, "")
qnet_2 = MLP(nState + nAction, HIDDEN_SIZE, Q_OUT_DIM, "")

# NOTE(review): the meaning of the trailing positional arguments (1, 32) is
# defined by DDPG_Hedger's signature in src.agent, which is not visible here.
agent = DDPG_Hedger(actor, qnet_1, qnet_2, actor_lr, critic_lr, 1, 32)

# Initial exploration parameter; the training cell assigns its own starting
# value before using it.
epsilon = 1

In [3]:
# Training loop: run a fixed number of episodes, and after each episode perform
# a few agent updates, one polyak update of the agent, and one epsilon decay step.

# Exploration parameter passed to agent.act; decayed once per episode below.
epsilon = 0.4

N_EPISODES = 20            # number of episodes to roll out
UPDATES_PER_EPISODE = 5    # agent.update calls after each episode
EPSILON_DECAY = 0.999      # multiplicative decay applied to epsilon per episode

# One list of (rounded) actions per episode, kept for later inspection.
all_actions = []
for i in range(N_EPISODES):
    state = env.reset()
    done = False

    ep_tot_reward = 0
    actions = []

    print(f'\n\n Episode {i}')
    # The loop condition alone ends the episode once env.step reports done.
    while not done:
        # normalize the state
        normalized_state = env.normalize(state)

        # take action given state
        action = agent.act(normalized_state, epsilon, True)

        # take next step of the environment
        next_state, reward, done = env.step(action)

        # record interaction between environment and the agent
        # NOTE(review): the agent acts on the normalized state but the raw
        # state is stored; presumably agent.update normalizes stored samples
        # with the env.price_stat it receives — verify against src.agent.
        agent.store(state, action, reward, next_state, done)

        ep_tot_reward += reward
        state = next_state

        actions.append(np.round(action, 2))

    # Learn from the stored interactions, printing the reported losses each time.
    for _ in range(UPDATES_PER_EPISODE):
        q1_loss, q2_loss, actor_loss = agent.update(env.price_stat, True)
        print(round(q1_loss.item(),2), round(q2_loss.item(),2), actor_loss)

    agent.polyak_update()
    epsilon *= EPSILON_DECAY
    print(f'Episode {i} End. Reward: {ep_tot_reward}')
    all_actions.append(actions)



 Episode 0
Ouputs - FC3: 0.9283, output: 0.5896
Ouputs - FC3: 0.1893, output: 0.5078
Ouputs - FC3: 0.2697, output: 0.5215
Ouputs - FC3: 0.4662, output: 0.5309
Ouputs - FC3: 0.2293, output: 0.5131
Ouputs - FC3: 0.6971, output: 0.566
Ouputs - FC3: 0.5594, output: 0.6347
Ouputs - FC3: 0.5065, output: 0.5643
Ouputs - FC3: 0.1607, output: 0.4865
Ouputs - FC3: 0.3608, output: 0.5277
Ouputs - FC3: 0.1775, output: 0.4849
Ouputs - FC3: 0.4862, output: 0.5331
Ouputs - FC3: 0.8535, output: 0.5689
Ouputs - FC3: 0.0463, output: 0.4059
Ouputs - FC3: 0.2111, output: 0.4704
Ouputs - FC3: 0.2862, output: 0.5146
Ouputs - FC3: 0.0848, output: 0.4238
Ouputs - FC3: 0.1707, output: 0.4532
Ouputs - FC3: 0.11, output: 0.439
Ouputs - FC3: 0.2734, output: 0.5088
Ouputs - FC3: 0.1269, output: 0.4476
Ouputs - FC3: 0.5357, output: 0.5299
Ouputs - FC3: 0.5558, output: 0.5565
Ouputs - FC3: 0.3852, output: 0.6274
Ouputs - FC3: 0.2715, output: 0.5892
Ouputs - FC3: 0.2335, output: 0.5695
Ouputs - FC3: 0.2657, output: