In [None]:
%pip install gymnasium
%pip install numpy



In [1]:
from typing import Optional
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import matplotlib.pyplot as plt

## BaseEnv

In [2]:
class STBaseEnv(gym.Env):
  def __init__(self, dataframe, ibalance=1e4,
               features=['Open', 'High', 'Low', 'Close', 'Volume'],
               slippage_rate=0.0005,
               min_trade_size=1,
               shares_held=0):
    super(STBaseEnv, self).__init__()
    self.df = dataframe
    self.ibalance = ibalance
    self.balance = ibalance
    self.net_worth = ibalance
    self.current_step = 0
    self.slippage_rate = slippage_rate
    self.shares_held = shares_held
    self.commission_rate = 0.001 # 0.1%q
    self.features = features # Store the list of features
    self.min_trade_size = min_trade_size

    # Actions of the format Buy x%, Sell x%, Hold, etc.
    self.action_space = spaces.Box(
        low=np.array([0, 0]), high=np.array([3, 1]), dtype=np.float16)

    # Shape is based on the number of selected features plus balance
    self.observation_space = spaces.Box(
          low=0, high=np.inf, shape=(len(self.features) + 1,), dtype=np.float32) # Changed high to np.inf and dtype to float32


  def reset(self):
    self.balance = self.ibalance
    self.current_step = 0

    return self._get_observation(), {} # Added info dictionary to reset return

  def step(self, action):
    self._act(action)
    terminated = False

    self.current_step += 1 # Corrected typo currentstep to current_step

    if self.current_step == self.df.shape[0]: # Corrected data to df and changed condition to check if current_step reaches the end of the dataframe
      terminated = True
      # self.current_step = 0 used for loop training

    obs = self._get_observation()
    reward = 0  # calculate reward

    # leave for implementation
    truncated = False
    info = self._get_info()

    return obs, reward, terminated, truncated, info

  def _act(self, action):
    current_price = self.df.loc[self.current_step, 'Close']

    action_type = int(action[0])
    percentage = np.clip(action[1], 0.0, 1.0)

    # calculate effective price with slippage(for simplicity, applied symmetrically for buy/sell)
    effective_buy_price = current_price * (1 + self.slippage_rate)
    effective_sell_price = current_price * (1 - self.slippage_rate)

    # transaction details for logging/info
    shares_traded = 0
    trade_value = 0
    commission_cost = 0

    if action_type == 0: # Buy
      max_shares_affordable = self.balance / effective_buy_price
      shares_to_buy = max(self.min_trade_size, ((self.balance * percentage) / effective_buy_price))

      if shares_to_buy > 0:
        shares_traded = min(max_shares_affordable, shares_to_buy)
        trade_value = shares_traded * current_price
        commission_cost = trade_value * self.commission_rate

        if self.balance > trade_value + commission_cost:
          self.balance -= (trade_value + commission_cost)

    elif action_type == 1: # Sell
      shares_to_sell = self.shares_held * percentage

      # ensure minimal
      shares_to_sell = max(self.min_trade_size, int(shares_to_sell))

      # ensure we don't sell more than we hold
      # possible error here
      shares_to_sell = min(shares_to_sell, self.shares_held)

      if shares_to_sell > 0:
        trade_value = shares_to_sell * effective_sell_price
        commission_cost = trade_value * self.commission_rate

        self.balance += (trade_value - commission_cost)
        self.shares_held -= shares_to_sell


      if shares_to_sell > 0:
        trade_value = shares_to_sell * current_price

    elif action_type == 2: # Hold
      pass

    self.net_worth = self.balance + (self.shares_held * current_price)

  def _get_observation(self): # Corrected indentation of _get_observation
    obs = [self.df.loc[self.current_step, feature] for feature in self.features]
    obs.append(self.balance)
    return np.array(obs, dtype=np.float32) # Ensure observation is a numpy array with specified dtype


  def _get_info(self):
      pass

  def render(self):
      pass

## Wrappers

In [None]:
class MAIndicatorWrapper(gym.ObservationWrapper):
  def __init__(self, env, feature='Close', window=5):
    super(MAIndicatorWrapper, self).__init__(env)
    self.window = window
    self.feature = feature

    # Update observation space to include the new MA feature
    obs_shape = self.observation_space.shape[0]
    self.observation_space = spaces.Box(
        low=0, high=np.inf, shape=(obs_shape + 1,), dtype=np.float32
    )


  def observation(self, observation):
    # Calculate the moving average for the specified feature
    start_index = max(0, self.env.current_step - self.window + 1)
    end_index = self.env.current_step + 1
    ma = self.env.df.loc[start_index:end_index, self.feature].mean()

    # Append the moving average to the original observation
    return np.append(observation, ma)

## Finance Data

In [None]:
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt

# 1. Define Stock Ticker and Date Range
ticker_symbol = "TSLA"
start_date = "2023-01-01"
end_date = "2024-01-01" # Data typically excludes the end date, so this fetches all of 2023

print(f"Downloading {ticker_symbol} data from {start_date} to {end_date}...")

# 2. Use yf.download() to fetch data
# yf.download(tickers, start=None, end=None, actions=False, threads=True,
#             group_by='column', auto_adjust=False, prepost=False,
#             proxy=None, rounding=False, tz=None, timeout=None, **kwargs)
try:
    tsla_data = yf.download(ticker_symbol, start=start_date, end=end_date)

    if tsla_data.empty:
        print(f"No historical data found for {ticker_symbol} within the specified date range. Please check the ticker symbol or date range.")
    else:
        print("\nData downloaded successfully!")
        print("First 5 rows of data:")
        print(tsla_data.head())

        print("\nData Information:")
        tsla_data.info()

        # 3. Basic Data Operations
        print(f"\n{ticker_symbol} 2023 Average Close Price: {tsla_data['Close'].mean():.2f}")
        print(f"{ticker_symbol} 2023 Highest Close Price: {tsla_data['Close'].max():.2f}")
        print(f"{ticker_symbol} 2023 Lowest Close Price: {tsla_data['Close'].min():.2f}")
        print(f"{ticker_symbol} 2023 Total Volume: {tsla_data['Volume'].sum()}")

        # 4. Visualize Close Price and Volume
        plt.figure(figsize=(14, 7))

        # Close Price Plot
        plt.subplot(2, 1, 1) # 2 rows, 1 column, 1st plot
        plt.plot(tsla_data.index, tsla_data['Close'], label='TSLA Close Price', color='blue')
        plt.title(f'{ticker_symbol} 2023 Close Price Trend')
        plt.xlabel('Date')
        plt.ylabel('Price (USD)')
        plt.grid(True)
        plt.legend()

        # Volume Plot
        plt.subplot(2, 1, 2) # 2 rows, 1 column, 2nd plot
        plt.bar(tsla_data.index, tsla_data['Volume'], label='TSLA Volume', color='gray')
        plt.title(f'{ticker_symbol} 2023 Trading Volume')
        plt.xlabel('Date')
        plt.ylabel('Volume')
        plt.grid(True)
        plt.legend()

        plt.tight_layout() # Adjust subplot parameters for a tight layout
        plt.show()

except Exception as e:
    print(f"An error occurred while downloading data: {e}")

  tsla_data = yf.download(ticker_symbol, start=start_date, end=end_date)
[*********************100%***********************]  1 of 1 completed

Downloading TSLA data from 2023-01-01 to 2024-01-01...

Data downloaded successfully!
First 5 rows of data:
Price            Close        High         Low        Open     Volume
Ticker            TSLA        TSLA        TSLA        TSLA       TSLA
Date                                                                 
2023-01-03  108.099998  118.800003  104.639999  118.470001  231402800
2023-01-04  113.639999  114.589996  107.519997  109.110001  180389000
2023-01-05  110.339996  111.750000  107.160004  110.510002  157986300
2023-01-06  113.059998  114.389999  101.809998  103.000000  220911100
2023-01-09  119.769997  123.519997  117.110001  118.959999  190284000

Data Information:
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 250 entries, 2023-01-03 to 2023-12-29
Data columns (total 5 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   (Close, TSLA)   250 non-null    float64
 1   (High, TSLA)    250 non-null    float64
 2   (Low,


