# IMPORTS #

In [3]:
import numpy as np
import plotly
import pandas as pd
import gym
from gym import spaces

# DATA #

In [65]:
apple_stock_df = pd.read_csv("AAPL.csv")
apple_stock_df.name = "Apple"
shell_stock_df = pd.read_csv("SHEL.csv")
shell_stock_df.name = "Shell"

# apple_stock_df = apple_stock_df.set_index("Date")
# shell_stock_df = shell_stock_df.set_index("Date")

In [47]:
apple_stock_df.shape

(1257, 7)

# GYM ENV

In [49]:
class SDSTEnv(gym.Env):
    """
    Implement the Simple Double Stock Trading Environnement
    """

    def __init__(self, stock1: pd.DataFrame, stock2: pd.DataFrame) -> None:
        super(SDSTEnv, self).__init__()

        ##########
        # RMQ: Don't we add the balance and stocks possessed?
        ##########
        # Observation Space: s in {0,1} x {0,1} (i.e. stock trends, 0 for going down, 1 for going up)
        self.observation_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(2)))

        ##########
        # RMQ: Don't we add the action "hold"?
        # What means "buy" and "sell", what amount, under which conditions, where do we implement them?
        # Idea: Buy/sell 10 shares per action inn the limit of the balance left
        ##########
        # Action Space: a in {0,1} x {-1,1} (i.e. first is for stock choice, second is for buy or sell 10 shares)
        self.action_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(2)))

        # Data
        assert stock1["Date"].equals(
            stock2["Date"]
        ), "Stocks data should be timely synchronized, make sure the dates of stock prices are identical"
        self.stock1 = stock1
        self.stock2 = stock2

        # Date
        self.timeStep = 0

        # Portfolio
        self.totalPortfolio = 100
        self.stocksShare = [0, 0]
        self.balanceLeft = 100

    def __str__(self):
        stockValues = [
            self.stock1.loc[self.timeStep]["Open"],
            self.stock2.loc[self.timeStep]["Open"],
        ]
        info = """The environment is a Simple Double Stock Trading Problem.\n 
        It is using the stocks: {}, {}.\n 
        The episode is at the timestep {}\n
        The stock prices are {}\n
        Amount of shares held by the agent: {}\n
        Left balance: {}
        """.format(self.stock1.name, self.stock2.name, self.timeStep, stockValues, self.stocksShare, self.balanceLeft)

    def reset(self):
        # Reset the environment and return a random initial state
        return self.observation_space.sample()

    def step(self, action: tuple[int, int]):
        """
        Take a step in the environment

        Args:
            action: tuple[int, int] - The action the agent has taken

        Returns:
            observation: tuple[int, int] - The next observation state
            reward: float - The reward the agent is granted
            terminated: bool - Is the episode terminated
            truncated: bool - Is the episode truncated
            info: dict - Additional info
        """
        stock = action[0]
        trade = action[1]

        assert stock in [0, 1] and trade in [-1, 1], "Invalid action"

        stockValues = [
            self.stock1.loc[self.timeStep]["Open"],
            self.stock2.loc[self.timeStep]["Open"],
        ]

        ##########
        # RMQ: How does the agent do not sell shares it does not possess?
        ##########
        self.stocksShare[stock] += trade * 10
        self.balanceLeft -= trade * stockValues[stock]
        newtotalPortfolio = np.dot(self.stocksShare, stockValues) + self.balanceLeft

        # Update state
        self.timeStep += 1

        # Returned values
        terminated = self.timeStep == self.stock1.shape[0]
        if not terminated:
            obs = self.get_observation_from_data()
        else:
            obs = None
        reward = newtotalPortfolio - self.totalPortfolio
        self.totalPortfolio = newtotalPortfolio
        truncated = False
        info = {}

        return obs, reward, terminated, truncated, info

    def get_observation_from_data(self):
        if not self.timeStep:
            return self.observation_space.sample()

        ##########
        # RMQ: What value do we take for the stock price? Open, Close, High, Low, Mean? Whatever is fine IMO
        ##########
        trend1 = np.sign(
            self.stock1.loc[self.timeStep]["Open"]
            - self.stock1.loc[self.timeStep - 1]["Open"]
        )
        trend2 = np.sign(
            self.stock2.loc[self.timeStep]["Open"]
            - self.stock2.loc[self.timeStep - 1]["Open"]
        )

        return (trend1, trend2)

In [64]:
myEnv = SDSTEnv(apple_stock_df, shell_stock_df)

print(myEnv)

<SDSTEnv instance>
