Fixing tuple and 'UTC has no key' errors in examples/FinRL_PaperTrading_Demo.ipynb (#1186)

* fix: tuple and 'UTC has no key' errors

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
kurone02 and pre-commit-ci[bot] committed Mar 20, 2024
1 parent 9f47d57 commit 7b11491
Showing 2 changed files with 17 additions and 14 deletions.
29 changes: 16 additions & 13 deletions examples/FinRL_PaperTrading_Demo.ipynb
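All of the tuple fixes below stem from the Gymnasium-style environment API adopted by newer gym releases: env.reset() returns (observation, info) and env.step() returns a five-tuple (observation, reward, terminated, truncated, info) instead of the old four-tuple. A minimal sketch of the two conventions (the gymnasium import and the "CartPole-v1" id are illustrative, not taken from the notebook):

import gymnasium as gym

env = gym.make("CartPole-v1")

# New API: reset() returns a pair, step() returns a five-tuple.
state, info = env.reset()
action = env.action_space.sample()
state, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated  # the old single done flag is now split in two

# Old Gym API, which the notebook was originally written against:
#   state = env.reset()
#   state, reward, done, info = env.step(action)

Each hunk below updates one reset() or step() call site accordingly.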
@@ -317,9 +317,9 @@
 " action, logprob = [t.squeeze(0) for t in get_action(state.unsqueeze(0))[:2]]\n",
 "\n",
 " ary_action = convert(action).detach().cpu().numpy()\n",
-" ary_state, reward, done, _ = env.step(ary_action)\n",
+" ary_state, reward, done, _, _ = env.step(ary_action)\n",
 " if done:\n",
-" ary_state = env.reset()\n",
+" ary_state, _ = env.reset()\n",
 "\n",
 " states[i] = state\n",
 " actions[i] = action\n",
@@ -411,11 +411,12 @@
 " self.if_discrete = False # discrete action or continuous action\n",
 "\n",
 " def reset(self) -> np.ndarray: # reset the agent in env\n",
-" return self.env.reset()\n",
+" resetted_env, _ = self.env.reset()\n",
+" return resetted_env\n",
 "\n",
 " def step(self, action: np.ndarray) -> (np.ndarray, float, bool, dict): # agent interacts in env\n",
 " # We suggest that adjust action space to (-1, +1) when designing a custom env.\n",
-" state, reward, done, info_dict = self.env.step(action * 2)\n",
+" state, reward, done, info_dict, _ = self.env.step(action * 2)\n",
 " return state.reshape(self.state_dim), float(reward), done, info_dict\n",
 "\n",
 " \n",
@@ -424,7 +425,9 @@
 "\n",
 " env = build_env(args.env_class, args.env_args)\n",
 " agent = args.agent_class(args.net_dims, args.state_dim, args.action_dim, gpu_id=args.gpu_id, args=args)\n",
-" agent.states = env.reset()[np.newaxis, :]\n",
+"\n",
+" new_env, _ = env.reset()\n",
+" agent.states = new_env[np.newaxis, :]\n",
 "\n",
 " evaluator = Evaluator(eval_env=build_env(args.env_class, args.env_args),\n",
 " eval_per_step=args.eval_per_step,\n",
@@ -502,14 +505,14 @@
 "def get_rewards_and_steps(env, actor, if_render: bool = False) -> (float, int): # cumulative_rewards and episode_steps\n",
 " device = next(actor.parameters()).device # net.parameters() is a Python generator.\n",
 "\n",
-" state = env.reset()\n",
+" state, _ = env.reset()\n",
 " episode_steps = 0\n",
 " cumulative_returns = 0.0 # sum of rewards in an episode\n",
 " for episode_steps in range(12345):\n",
 " tensor_state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)\n",
 " tensor_action = actor(tensor_state)\n",
 " action = tensor_action.detach().cpu().numpy()[0] # not need detach(), because using torch.no_grad() outside\n",
-" state, reward, done, _ = env.step(action)\n",
+" state, reward, done, _, _ = env.step(action)\n",
 " cumulative_returns += reward\n",
 "\n",
 " if if_render:\n",
@@ -525,7 +528,7 @@
 "id": "9tzAw9k26nAC"
 },
 "source": [
-"##DRL Agent Class"
+"## DRL Agent Class"
 ]
 },
 {
@@ -634,7 +637,7 @@
 "\n",
 " # test on the testing env\n",
 " _torch = torch\n",
-" state = environment.reset()\n",
+" state, _ = environment.reset()\n",
 " episode_returns = [] # the cumulative_return / initial_account\n",
 " episode_total_assets = [environment.initial_total_asset]\n",
 " with _torch.no_grad():\n",
@@ -644,7 +647,7 @@
 " action = (\n",
 " a_tensor.detach().cpu().numpy()[0]\n",
 " ) # not need detach(), because with torch.no_grad() outside\n",
-" state, reward, done, _ = environment.step(action)\n",
+" state, reward, done, _, _ = environment.step(action)\n",
 "\n",
 " total_asset = (\n",
 " environment.amount\n",
@@ -1693,8 +1696,8 @@
 "source": [
 "def get_trading_days(start, end):\n",
 " nyse = tc.get_calendar('NYSE')\n",
-" df = nyse.sessions_in_range(pd.Timestamp(start,tz=pytz.UTC),\n",
-" pd.Timestamp(end,tz=pytz.UTC))\n",
+" df = nyse.sessions_in_range(pd.Timestamp(start),\n",
+" pd.Timestamp(end))\n",
 " trading_days = []\n",
 " for day in df:\n",
 " trading_days.append(str(day)[:10])\n",
@@ -1889,7 +1892,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.10"
+"version": "3.10.11"
 },
 "vscode": {
 "interpreter": {
2 changes: 1 addition & 1 deletion requirements.txt
@@ -43,7 +43,7 @@ swig

 tensorboardX
 wheel>=0.33.6
-wrds

 # market data & paper trading API
 yfinance
+wrds