In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import A2C

import gymnasium as gym
from environment import trading_env
# Notebook-wide matplotlib defaults: wide 15x6 figures, bold/large titles and
# axis labels, and a grid on every axes.
plt.rc(
    'figure',
    titleweight='bold',
    titlesize='large',
    figsize=(15, 6),
)
plt.rc(
    'axes',
    labelweight='bold',
    labelsize='large',
    titleweight='bold',
    titlesize='large',
    grid=True,
)

In [None]:
# Synthetic price path: three sine periods sampled at 200 points, shifted up
# by 4 so every price lies in [2, 6].
seed = np.linspace(0, 2 * np.pi, 200)
y = 2 * np.sin(3 * seed) + 4

# Dedicated fixed-seed generator so a fresh "Restart & Run All" reproduces the
# same candles without touching NumPy's global random state.
ohlc_rng = np.random.default_rng(42)

# Consecutive samples form one candle: bar i opens at y[i] and closes at y[i+1].
# (Named bar_* rather than open/close so the builtin `open` is not shadowed
# for the rest of the notebook.)
bar_open = y[:-1]
bar_close = y[1:]
n_bars = bar_open.shape[0]

# Wick lengths use |noise|: with signed noise the "High" could end up below the
# candle body and the "Low" above it, which is not a valid OHLC bar.
bar_high = np.maximum(bar_open, bar_close) + 0.2 * np.abs(ohlc_rng.standard_normal(n_bars))
bar_low = np.minimum(bar_open, bar_close) - 0.2 * np.abs(ohlc_rng.standard_normal(n_bars))

z = pd.DataFrame(
    {'Open': bar_open, 'High': bar_high, 'Low': bar_low, 'Close': bar_close}
)

In [None]:
class ql_agent():
    """Tabular Q-learning agent over a discretised, one-dimensional observation.

    The agent reads the following from ``env``: ``observation_min`` and
    ``observation_max`` (to size the state bins), ``df`` (its row count sets
    the number of steps per epoch), ``reset()`` returning ``(observation, info)``,
    ``step(action)`` returning a 5-tuple, and ``coin``/``usd`` which are
    printed after every epoch.
    """

    # Number of discrete actions; the Q-table's last axis and the action sweep
    # in train() both use it. visualize() labels action 0 as 'B' and 2 as 'S'.
    _N_ACTIONS = 3

    def __init__(self,env,qtable_height=40,qtable_width=3,learning_rate=0.1,discount_value=1,epochs=1000,epsilon=0.99,epsilon_decay=0.99,overflow=0.01):
        """Store hyper-parameters; the Q-table itself is built by create_qtable().

        Args:
            env: Trading environment (see class docstring for what is used).
            qtable_height: Number of bins along the state axis.
            qtable_width: Stored but not used by the current 1-D table.
            learning_rate: Step size of the Q-value update.
            discount_value: Discount factor applied to the next state's best Q.
            epochs: Number of passes train() makes over the environment.
            epsilon: Decayed once per epoch in train(); not consulted for
                action selection because train() sweeps every action.
            epsilon_decay: Multiplicative decay applied to ``epsilon``.
            overflow: Relative margin added around the observation range
                before binning.
        """
        self.learning_rate = learning_rate
        self.discount_value = discount_value
        self.epochs = epochs
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.env = env
        self.overflow = overflow
        self.qtable_height = qtable_height
        self.qtable_width = qtable_width
        # Defined up front so visualize() is safe to call before train().
        self.action_list = []

    def create_qtable(self):
        """Allocate a (qtable_height, 3) Q-table initialised uniformly in [0, 1).

        Also computes ``qtable_segment_size``, the width of one state bin over
        the observation range widened by ``overflow`` on both ends.
        """
        self.qtable_size = (self.qtable_height,)
        upper = self.env.observation_max * (1 + self.overflow)
        lower = self.env.observation_min * (1 - self.overflow)
        self.qtable_segment_size = (upper - lower) / np.array(self.qtable_size)
        self.qtable = np.random.uniform(
            low=0, high=1, size=self.qtable_size + (self._N_ACTIONS,)
        )

    def convert_state(self,current_state):
        """Map a raw observation array to a tuple of integer bin indices."""
        lower = self.env.observation_min * (1 - self.overflow)
        return tuple(((current_state - lower) / self.qtable_segment_size).astype(int))

    def update_qtable(self,reward,action):
        """Apply one Q-learning update for ``(self.current_state, action)``.

        Reads ``self.new_state`` as the raw observation returned by the last
        ``env.step`` call. When it is not ``None`` it is discretised in place
        (train() relies on this side effect) and the target bootstraps from the
        best Q-value of that state. When it is ``None`` the transition is
        treated as terminal and the target is the reward alone.

        Args:
            reward: Reward received for taking ``action``.
            action: Index of the action taken (last axis of the Q-table).
        """
        q_index = self.current_state + (action,)
        current_q_value = self.qtable[q_index]

        if self.new_state is None:
            # No successor state: nothing to bootstrap from.
            target = reward
        else:
            # First row of the observation window, then bin it.
            self.new_state = self.convert_state(self.new_state[0])
            target = reward + self.discount_value * np.max(self.qtable[self.new_state])

        self.qtable[q_index] = (
            (1 - self.learning_rate) * current_q_value + self.learning_rate * target
        )

    def train(self):
        """Rebuild the Q-table and run ``self.epochs`` training passes.

        Each epoch resets the environment, then for ``env.df.shape[0] - 1``
        iterations calls ``env.step`` once per action and updates the Q-table
        after every call. ``epsilon`` is decayed at the end of each epoch.
        """
        self.create_qtable()
        self.action_list = []
        self.portforlio = None
        for epoch in range(self.epochs):

            print(f'Epoch :{epoch}')
            observation, _ = self.env.reset()
            self.current_state = self.convert_state(observation[0])
            epoch_reward = 0

            for _ in range(self.env.df.shape[0] - 1):
                # NOTE(review): env.step is invoked once per action, so the
                # environment is stepped three times per iteration, and the
                # terminate/truncate flags are ignored - confirm this matches
                # how trading_env advances.
                for action in range(self._N_ACTIONS):
                    self.new_state, reward, terminate, truncate, _ = self.env.step(action)
                    self.update_qtable(reward, action)

                # update_qtable() has already discretised self.new_state.
                self.current_state = self.new_state
                # Only the reward of the last action in the sweep is counted.
                epoch_reward += reward
            self.epsilon *= self.epsilon_decay
            print(self.env.coin, self.env.usd)

    def visualize(self):
        """Plot ``env.df['Close']`` and annotate recorded actions.

        Entries of ``self.action_list`` equal to 0 are marked 'B' and entries
        equal to 2 are marked 'S', placed at the value in column 0 of
        ``env.df`` for that row.
        """
        fg = plt.figure()
        ax = fg.add_subplot()
        self.env.df['Close'].plot(ax=ax)
        for i, action in enumerate(self.action_list):
            if action == 0:
                ax.text(i, self.env.df.iloc[i, 0], 'B', color='C2')
            elif action == 2:
                ax.text(i, self.env.df.iloc[i, 0], 'S', color='C3')

In [None]:
# Wrap the synthetic OHLC frame in the trading environment (one-row
# observation window) and build a Q-learning agent with a 10-bin state axis.
env = trading_env(df=z, window_size=1)
agent = ql_agent(env, qtable_height=10, epochs=1000)

In [None]:
# Run the full training loop; prints the epoch index and the environment's
# coin/usd balances once per epoch.
agent.train()