In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import torch.multiprocessing as mp
from torch.multiprocessing import Queue, Value, Lock, Process
import plotly.graph_objects as go
import plotly.express as px
import os
import time
import copy

import sys
sys.path.append('..')
from Env import Env
# Device that AC.act / AC.act_max move their input tensors to.
device = torch.device('cpu')
# Ask OpenMP-backed libraries to use a single thread in this process.
# NOTE(review): this is assigned after numpy and torch have already been
# imported above; confirm the setting is still honoured at this point, or set
# it before those imports / call torch.set_num_threads(1) explicitly.
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
class Memory:
    """Rollout buffer made of parallel Python lists, one list per recorded quantity.

    Each attribute named in ``_FIELDS`` starts out as an empty list; callers
    append to the lists directly.
    """

    # Attribute names, in the order they are created on the instance.
    _FIELDS = ("state_img", "state_stat", "action", "critic",
               "log_prob", "f", "is_done")

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, [])

    def clear(self):
        """Empty every buffer in place, keeping the same list objects alive."""
        for field in self._FIELDS:
            getattr(self, field).clear()
        
        
class AC(nn.Module):
    """Actor-critic network fed by a single-channel image and a flat feature vector.

    The image goes through a small CNN and the feature vector through an MLP;
    both are projected to ``latent_num`` features and concatenated. From that
    joint encoding:

    * the actor branch (``pi`` -> ``actor``) outputs a softmax over 3 actions,
    * ``pi`` -> ``f`` outputs one sigmoid-squashed scalar in (0, 1),
    * the critic ``V`` outputs one unbounded scalar.

    Args:
        latent_num: width of each encoder's output.
        cnn_chanel_num: base number of convolution channels.
        stat_dim: length of the flat feature vector.

    Note:
        The final image-encoder layer is ``nn.Linear(192, latent_num)``, so the
        image size and ``cnn_chanel_num`` must be chosen such that the
        flattened convolution output has exactly 192 features.
    """

    def __init__(self, latent_num, cnn_chanel_num, stat_dim):
        super().__init__()

        # Encode
        self.encode_img = nn.Sequential(
            nn.Conv2d(1, cnn_chanel_num, 4, stride=2), nn.ReLU(), nn.MaxPool2d(3, stride=2),
            nn.Conv2d(cnn_chanel_num, 2*cnn_chanel_num, 4, stride=2), nn.ReLU(), nn.MaxPool2d(3, stride=2),
            nn.Conv2d(2*cnn_chanel_num, cnn_chanel_num, 1, stride=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(192, latent_num), nn.ReLU())

        self.encode_stat = nn.Sequential(
            nn.Linear(stat_dim, latent_num), nn.ReLU(),
            nn.Linear(latent_num, latent_num), nn.ReLU())

        # Actor
        self.pi = nn.Sequential(
            nn.Linear(latent_num*2, latent_num), nn.ReLU(),
            nn.Linear(latent_num, latent_num//2), nn.ReLU())
        self.actor = nn.Sequential(
            nn.Linear(latent_num//2, 3), nn.Softmax(dim=-1))
        self.f = nn.Sequential(
            nn.Linear(latent_num//2, 1), nn.Sigmoid())

        # Critic
        self.V = nn.Sequential(
            nn.Linear(latent_num*2, latent_num), nn.ReLU(),
            nn.Linear(latent_num, latent_num//2), nn.ReLU(),
            nn.Linear(latent_num//2, 1))

    def _encode(self, img, stat):
        """Encode a batch of images and feature vectors and join them on dim 1."""
        return torch.cat([self.encode_img(img), self.encode_stat(stat)], dim=1)

    def _encode_observation(self, img, stat):
        """Turn one numpy observation into tensors plus its joint encoding.

        Returns ``(img_t, stat_t, catted)``: ``img_t`` is the float image with
        one leading axis added, ``stat_t`` the float feature vector, and
        ``catted`` the batch-of-one encoding computed on the module-level
        ``device``.
        """
        img = torch.from_numpy(img).float().unsqueeze(0)
        stat = torch.from_numpy(stat).float()
        catted = self._encode(img.unsqueeze(0).to(device), stat.unsqueeze(0).to(device))
        return img, stat, catted

    # Only used for visualization
    def forward(self, img, stat):
        """Batched pass over tensors.

        Args:
            img: float tensor shaped (N, 1, H, W).
            stat: float tensor shaped (N, stat_dim).

        Returns:
            ``(probs, f, value)`` shaped (N, 3), (N, 1) and (N, 1).
        """
        catted = self._encode(img, stat)
        hid = self.pi(catted)
        probs = self.actor(hid)
        f = self.f(hid)
        value = self.V(catted)
        return probs, f, value

    def act(self, img, stat, memory):
        """Sample an action for one observation and record the step in ``memory``.

        Args:
            img: numpy array holding one image (presumably 2-D (H, W); two
                leading axes are added before the CNN).
            stat: numpy array holding one feature vector (presumably 1-D).
            memory: buffer whose ``state_img``, ``state_stat``, ``action``,
                ``critic``, ``f`` and ``log_prob`` lists are appended to.

        Returns:
            ``(action, f)`` as a Python int and float.

        The value, ``f`` and log-probability are stored as tensors that still
        carry their autograd graph (unless the caller disables grad).
        """
        img, stat, catted = self._encode_observation(img, stat)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # collapse the feature axis when latent_num // 2 == 1.
        hid = self.pi(catted).squeeze(0)
        p = Categorical(self.actor(hid))
        f = self.f(hid)
        value = self.V(catted).squeeze()
        action = p.sample()
        memory.state_img.append(img)
        memory.state_stat.append(stat)
        memory.action.append(action)
        memory.critic.append(value)
        memory.f.append(f)
        memory.log_prob.append(p.log_prob(action))
        return action.item(), f.item()

    def act_max(self, img, stat):
        """Return the highest-probability action and ``f`` for one observation.

        Takes the same numpy inputs as :meth:`act` but records nothing and
        returns only Python scalars, so the pass runs under ``torch.no_grad()``
        to avoid building an autograd graph that would never be used.
        """
        with torch.no_grad():
            _, _, catted = self._encode_observation(img, stat)
            hid = self.pi(catted).squeeze(0)
            action = self.actor(hid).argmax()
            f = self.f(hid)
        return action.item(), f.item()

<All keys matched successfully>

In [3]:
# Build the evaluation environment from the "<code>_<freq>_test.csv" file.
# `code` and `freq` are presumably the instrument code and the bar
# frequency - confirm against the data directory.
code = 'M8888.XDCE'
freq = '5m'
data = pd.read_csv(f"../data/{code}_{freq}_test.csv")
# NOTE(review): total_step, img_shape, charge, init_money and threshold are not
# defined in any cell visible here; on a fresh kernel this line raises
# NameError unless they are defined in a cell that is not shown.
env = Env(data, total_step, img_shape, charge, init_money, threshold)

Env_started


In [4]:
# Run 10 evaluation episodes and record, for each, the final money as a
# fraction of the initial money.
# NOTE(review): `ac` is not created in any cell visible here; it is presumably
# an AC instance with loaded weights - confirm it exists before this cell runs.
money = []
steps = []
# AC.act requires a buffer to record into; this one is emptied after each episode.
eval_memory = Memory()
for step in range(10):
    img, stat = env.reset()
    done = False
    while not done:
        # AC.act returns (action, f). Only the sampled action is forwarded; the
        # second argument of env.step stays fixed at 1 as before. No gradients
        # are needed for evaluation.
        with torch.no_grad():
            action, _ = ac.act(img, stat, eval_memory)
        (img, stat), done = env.step(action, 1)
    eval_memory.clear()
    steps.append(step)
    money.append(env.money/env.init_money)
    print(env.money/env.init_money)
px.line(money).show()
np.mean(money)

1.1673202312000024