## Setup Data Fetching

In [56]:
import pandas as pd
import tensortrade.env.default as default

from tensortrade.data.cdd import CryptoDataDownload
from tensortrade.feed.core import Stream, DataFeed
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.instruments import USD, BTC, ETH
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.agents import DQNAgent
from ta import add_all_ta_features


In [57]:
# gather data
def get_feed(n_events=None, exchange="Bitstamp", base="USD", quote="BTC",
             timeframe="1h", fillna=False):
    """Download OHLCV candles, add TA indicators, and wrap them in a DataFeed.

    Parameters
    ----------
    n_events : int, optional
        Number of leading rows to DROP (``data.iloc[n_events:]``), e.g. to
        skip an indicator warm-up period. Note this is *not* the number of
        rows kept. ``None`` keeps every row.
    exchange, base, quote, timeframe : str
        Passed positionally to ``CryptoDataDownload.fetch``; the defaults
        reproduce the original Bitstamp / "USD" / "BTC" / "1h" download.
    fillna : bool
        Passed to ``ta.add_all_ta_features``. ``False`` (the default) keeps
        the original behaviour of leaving undefined indicator values as-is.

    Returns
    -------
    tuple
        ``(data, feed)``: the enriched DataFrame and a compiled ``DataFeed``
        holding one float stream per column from the third column onwards.
    """
    cdd = CryptoDataDownload()
    data = cdd.fetch(exchange, base, quote, timeframe)
    data = add_all_ta_features(data, 'open', 'high', 'low', 'close', 'volume',
                               fillna=fillna)

    if n_events is not None:
        data = data.iloc[n_events:]
    print(len(data))

    # The first two columns are excluded from the observation features;
    # presumably they are timestamp columns -- TODO confirm against the
    # CryptoDataDownload output schema.
    features = [
        Stream.source(list(data[col]), dtype="float").rename(col)
        for col in data.columns[2:]
    ]
    feed = DataFeed(features)
    feed.compile()
    return data, feed

data, feed = get_feed()


invalid value encountered in double_scalars


invalid value encountered in double_scalars



23130


In [82]:
# Create environment
def create_env(config=None):
    """Build a TensorTrade environment from the notebook-level ``data``/``feed``.

    ``config`` is accepted but not used by this function. Prices for the
    simulated exchange come from the global ``data["close"]`` column, and
    observations come from the global ``feed``. Both are read from notebook
    state rather than passed in.
    """
    close_prices = Stream.source(list(data["close"]), dtype="float").rename("USD-BTC")
    bitstamp = Exchange("bitstamp", service=execute_order)(close_prices)

    usd_wallet = Wallet(bitstamp, 10000 * USD)
    btc_wallet = Wallet(bitstamp, 10 * BTC)
    portfolio = Portfolio(USD, [usd_wallet, btc_wallet])

    # Renderer input: the raw date column followed by the OHLCV columns as floats.
    date_stream = Stream.source(list(data["date"])).rename("date")
    ohlcv_streams = [
        Stream.source(list(data[column]), dtype="float").rename(column)
        for column in ("open", "high", "low", "close", "volume")
    ]
    renderer_feed = DataFeed([date_stream] + ohlcv_streams)

    return default.create(
        portfolio=portfolio,
        action_scheme="simple",
        reward_scheme="risk-adjusted",
        feed=feed,
        renderer_feed=renderer_feed,
        renderer=default.renderers.FileLogger(),
        window_size=20,
    )

env = create_env()

## Setup and Train DQN Agent

In [83]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before executing each cell, so edits to library code are picked up live.
# NOTE(review): this cell runs after the imports and its output shows the
# extension was already loaded (i.e. it was re-run out of order). Consider
# moving it to a configuration cell at the top, or using %reload_ext so a
# re-run is idempotent.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
# create agent
def get_agent(env, agent_id=None):
    """Create a DQN agent for ``env``, optionally overriding its id.

    Parameters
    ----------
    env :
        The trading environment the agent acts in.
    agent_id : str, optional
        If given, assigned to ``agent.id`` (presumably used to name saved
        weight files -- TODO confirm). ``None`` keeps the id the agent was
        constructed with.

    Returns
    -------
    DQNAgent
        The newly constructed agent.
    """
    agent = DQNAgent(env)
    if agent_id is not None:
        # Honour the caller-supplied id instead of a hard-coded value.
        agent.id = agent_id
    return agent


agent = get_agent(env=env, agent_id="TEST_AGENT")


In [85]:
# train the agent
# `n_steps` is an integer step budget; floor division keeps it an int
# (len(data) / 100 would give a float such as 231.3 for 23130 rows).
n_train_steps = len(data) // 100

# NOTE(review): `save_every` may have no effect unless a `save_path` is also
# passed -- TODO confirm against DQNAgent.train; weights are saved explicitly
# below either way.
mean_reward = agent.train(n_steps=n_train_steps,
                          n_episodes=1,
                          save_every=1
                         )

agent.save("./")

mean_reward

====      AGENT ID: TEST_AGENT      ====
-1120081.1682655017


In [86]:
# remove the agent
# Drop the in-memory trained agent; the following cells rebuild an agent and
# load its weights from disk instead.
del agent

In [87]:
# we restore the agent
# Build a fresh agent with the same id, then load previously saved policy
# weights into it.
# NOTE(review): the weight filename is hard-coded rather than derived from the
# agent id -- confirm it matches the file actually written by agent.save("./").

agent = get_agent(env=env, agent_id="TEST_AGENT")

agent.restore("./policy_network__TEST_AGENT.hdf5")




In [88]:
# Now that the agent has been restored, save its model again to "./".
agent.save("./")


In [89]:
# we reset the environment
# Reset the agent's environment and display the initial observation.
# The displayed output is a 2-D float32 array whose earlier rows are zeros;
# presumably it is padding up to the observation window -- TODO confirm.

initial_state = agent.env.reset()

initial_state

array([[0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       ...,
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00],
       [8.7338604e+03, 8.7966797e+03, 8.7072803e+03, ..., 2.3184912e+00,
        0.0000000e+00, 0.0000000e+00]], dtype=float32)

In [90]:
# predict our next action
# Ask the agent for an action given the initial observation; the result is an
# action index into the environment's discrete action space.

agent.get_action(state=initial_state)

15

In [91]:
# Inspect the discrete action space the agent's action indices are drawn from.
env.action_space

Discrete(21)