In [5]:
import os
import torch
import json5
import random
import datetime
import numpy as np
from env import Porfolio_Env
from data_preprocess import data_process2
from agent import DDPG_multitask
import shutil
from utils import (
    seed_everything,
    ReplayBuffer,
    stock_preview,
    EvaALL,
    result_plot,
    metric,
    copy_current_script_to_folder,
    mvddpg_alg,
    CustomSummaryWriter,
    plot_agent_statistics,
    ews_reward_df,
)
import net
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

def load_config(config_path="config.jsonc"):
    """Read a JSON5/JSONC configuration file and return its parsed contents.

    Args:
        config_path: Path of the configuration file. Defaults to
            ``"config.jsonc"`` relative to the current working directory.

    Returns:
        The object produced by ``json5.load`` for the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so non-ASCII text in the file is parsed the
    # same way on every machine instead of depending on the locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json5.load(f)
def create_multi_task_agent(config, in_features):
    """Assemble a ``DDPG_multitask`` agent from freshly constructed networks.

    Three modules are built and moved onto ``config["device"]``:
    a ``net.PolicyNet2`` actor, a ``net.Critic2`` critic, and an LSRE module
    (``net.BatchLSRE`` when ``config["use_batch_lsre"] == 1``, otherwise
    ``net.LSRE``). They are then handed to ``DDPG_multitask`` together with
    the unmodified ``config`` and ``writer=None``.

    Args:
        config: Mapping that provides the keys ``"N_stock"``, ``"embed_dim"``,
            ``"hidden_size"``, ``"portfolio_size"``, ``"window_size"``,
            ``"use_batch_lsre"`` and ``"device"``.
        in_features: Value forwarded as ``in_features`` to every network.

    Returns:
        A 2-tuple ``(agent, None)``; the second slot is always ``None``.
    """
    n_assets = config["N_stock"]
    # One action per stock plus one extra slot (presumably cash — TODO confirm).
    n_actions = n_assets + 1

    policy_net = net.PolicyNet2(
        in_channels=n_assets,
        in_features=in_features,
        embed_dim=config["embed_dim"],
        num_actions=n_actions,
        hidden_size=config["hidden_size"],
        portfolio_size=config["portfolio_size"],
    )
    policy_net = policy_net.to(config["device"])

    value_net = net.Critic2(
        in_channels=n_assets,
        in_features=in_features,
        embed_dim=config["embed_dim"],
        num_actions=n_actions,
        hidden_size=config["hidden_size"],
    )
    value_net = value_net.to(config["device"])

    # Both LSRE variants take the same constructor arguments, so only the
    # class to instantiate depends on the config flag.
    lsre_cls = net.BatchLSRE if config["use_batch_lsre"] == 1 else net.LSRE
    lsre_net = lsre_cls(
        window_size=config["window_size"],
        in_features=in_features,
        embed_dim=config["embed_dim"],
        num_actions=n_actions,
    )
    lsre_net = lsre_net.to(config["device"])

    multitask_agent = DDPG_multitask(
        actor=policy_net,
        critic=value_net,
        lsre=lsre_net,
        writer=None,
        config=config,
    )
    return multitask_agent, None

In [6]:
# NOTE(review): absolute, machine-specific path — adjust for your environment.
config_path = "/home/psdz/Lin PJ/demo/config.jsonc"
config = load_config(config_path)

# Feature dimension forwarded to every network. Presumably it has to agree
# with the value used when the checkpoint restored later was trained —
# TODO confirm that 18 is correct for this checkpoint.
IN_FEATURES = 18

# create_multi_task_agent returns (agent, None). The placeholder is bound to a
# named throwaway rather than `_`, because assigning `_` in IPython overrides
# the "last output" variable for the rest of the session.
agent, _placeholder = create_multi_task_agent(config, in_features=IN_FEATURES)

In [9]:
# NOTE(review): absolute, machine-specific path to the trained weights.
model_path = "/home/psdz/Lin PJ/rlpm/ddpg_cnn/resultsave/2025_03_03_08_43_O/ddpg_multitask_experiment_model.pth"

# torch.load unpickles the file, so only open checkpoints from a trusted source.
# map_location remaps tensors that were saved on another device onto the
# configured one, so a checkpoint written on a GPU still loads on a machine
# without that device.
checkpoint = torch.load(model_path, map_location=torch.device(config["device"]))

# Show which state dicts the checkpoint actually contains.
checkpoint.keys()

dict_keys(['critic_state_dict', 'target_critic_state_dict', 'actor_state_dict', 'target_actor_state_dict'])

In [None]:
import warnings

# Checkpoint entry -> sub-module of the agent that should receive it.
state_dict_targets = {
    "critic_state_dict": agent.critic,
    "target_critic_state_dict": agent.target_critic,
    "actor_state_dict": agent.actor,
    "target_actor_state_dict": agent.target_actor,
    "lsre_state_dict": agent.lsre,
}

# Restore every entry that is present. Indexing the checkpoint blindly raises
# KeyError when an entry is absent — the checkpoint inspected in this notebook
# lists no "lsre_state_dict" key — so missing entries are collected and
# reported instead of aborting the cell half-way through.
missing_entries = []
for entry_name, module in state_dict_targets.items():
    if entry_name in checkpoint:
        module.load_state_dict(checkpoint[entry_name])
    else:
        missing_entries.append(entry_name)

if missing_entries:
    warnings.warn(
        "Checkpoint has no entries for "
        + ", ".join(missing_entries)
        + "; the corresponding modules keep their freshly initialised weights."
    )

In [1]:
# pandas is imported again here so that this cell also runs on its own.
import pandas as pd

# Read the Feather file from the current working directory and preview the
# first rows.
data = pd.read_feather("000300.ftr")
data.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,return_rate,low,open,volume,num_trades,high,total_turnover,size,non_linear_size,momentum,liquidity,book_to_price,leverage,growth,earnings_yield,beta,residual_volatility
order_book_id,date,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
000001.XSHE,2011-01-13,-0.003672,0.998168,1.008586,-0.448063,-0.547287,1.012868,-0.564419,-1.193669,1.493911,-1.454161,0.721873,1.599249,-0.963708,0.738409,0.191991,-0.602553,-0.736749
000001.XSHE,2011-01-14,-0.021411,0.999369,1.018766,-0.397247,-0.348588,1.018766,-0.534309,-1.260087,1.658144,-1.503734,0.746739,1.622667,-0.909416,0.681104,0.177222,-0.467178,-0.578394
000001.XSHE,2011-01-17,-0.038142,0.998687,1.027953,-0.147253,0.027248,1.035761,-0.353538,-1.271434,1.803348,-1.501349,0.804132,1.61883,-0.87599,0.668763,0.210132,-0.253452,-0.526066
000001.XSHE,2011-01-18,-0.002603,0.994781,1.003903,-0.602126,-0.583892,1.007171,-0.722224,-1.224987,1.72125,-1.430636,0.754798,1.635774,-0.778135,0.635342,0.193834,-0.25568,-0.69757
000001.XSHE,2011-01-19,0.009781,0.983213,0.990943,-0.547305,-0.571029,1.001933,-0.675754,-1.455619,1.940742,-1.392701,0.680336,1.679367,-0.813065,0.63379,0.208335,-0.329167,-0.643791
