In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
import time
import mani_skill.env
import pandas as pd
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import style

# Run configuration for the actor pre-training notebook.
ENV_NAME = 'OpenCabinetDrawer-v0'   # environment id passed to gym.make
SEED = 123                          # random seed applied to Gym and PyTorch
PATH = 'data.csv'                   # CSV dataset read by PreTrain

class PreTrain(object):
    """Pre-train an actor network by regressing it onto actions stored in a CSV.

    The CSV is loaded into ``self.memory``. ``train`` fits ``self.Actor_eval``
    so that ``Actor_eval(state)`` matches the recorded action (MSE loss, Adam).

    Several attributes (``pointer``, ``gamma``, ``tau``, ``batch_size``,
    ``Reward_record``) are not used by this class. They are kept so that other
    code reading them keeps working.
    """

    class ANet(nn.Module):
        """Actor MLP: s_dim -> 256 -> 512 -> 256 -> a_dim, output in (-a_bound, a_bound)."""

        def __init__(self, s_dim, a_dim, a_bound):
            super().__init__()
            self.a_bound = a_bound
            # Layers are created and initialised one after another (same RNG
            # consumption order as before), so a fixed seed gives the same weights.
            self.fc1 = self._init_layer(nn.Linear(s_dim, 256))
            self.fc2 = self._init_layer(nn.Linear(256, 512))
            self.fc3 = self._init_layer(nn.Linear(512, 256))
            self.out = self._init_layer(nn.Linear(256, a_dim))

        @staticmethod
        def _init_layer(layer):
            """Re-draw the layer's weights from N(0, 0.1); biases keep PyTorch's default."""
            layer.weight.data.normal_(0, 0.1)
            return layer

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = F.relu(self.fc3(x))
            # tanh bounds the output to (-1, 1); scaling gives (-a_bound, a_bound).
            return torch.tanh(self.out(x)) * self.a_bound

    def __init__(self, a_dim, s_dim, a_bound, path=None):
        """Load the dataset and build the actor, its optimiser and loss.

        Args:
            a_dim: action dimension (number of actor outputs).
            s_dim: state dimension (number of actor inputs).
            a_bound: symmetric action bound; actor outputs lie in (-a_bound, a_bound).
            path: CSV file to load; ``None`` uses the module-level ``PATH``.
        """
        # NOTE(review): read_csv treats the first row as a header. Confirm the
        # CSV was written with one; otherwise the first sample is silently dropped.
        data = pd.read_csv(PATH if path is None else path)
        self.memory = data.to_numpy()

        self.a_dim = a_dim
        self.s_dim = s_dim
        self.a_bound = a_bound
        self.pointer = 0        # experience-buffer pointer (unused here)
        self.lr_a = 0.001       # actor learning rate
        self.gamma = 0.9        # reward discount (unused here)
        self.tau = 0.01         # soft-update ratio (unused here)
        # Derived from the loaded data instead of a hard-coded row count.
        self.memory_capacity = len(self.memory)
        self.batch_size = 32    # not used by train(); pass batch_size explicitly
        self.Reward_record = []

        self.Actor_eval = self.ANet(s_dim, a_dim, a_bound)   # main actor network
        self.atrain = torch.optim.Adam(self.Actor_eval.parameters(), lr=self.lr_a)
        self.loss = nn.MSELoss()

    def train(self, batch_size=1, epochs=1, log_every=1000):
        """Fit the actor to the recorded actions by minimising MSE.

        Rows are visited in file order. The default ``batch_size=1`` keeps the
        original one-sample-per-step schedule.

        Args:
            batch_size: rows per gradient step (>= 1).
            epochs: number of sequential passes over ``self.memory``.
            log_every: print progress every this many steps; 0 or None disables.

        Returns:
            list[float]: the loss of every optimisation step, in order.

        Raises:
            ValueError: if ``batch_size < 1`` or ``self.memory`` has fewer than
                ``s_dim + a_dim`` columns.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
        n_cols = self.s_dim + self.a_dim
        if self.memory.ndim != 2 or self.memory.shape[1] < n_cols:
            raise ValueError('memory must have at least s_dim + a_dim = {} columns, '
                             'got shape {}'.format(n_cols, self.memory.shape))

        # Convert the whole dataset once instead of building a tensor per row.
        memory = torch.as_tensor(np.asarray(self.memory, dtype=np.float32))
        # assumes columns are laid out as [state | action | ...] — TODO confirm
        states = memory[:, :self.s_dim]
        actions = memory[:, self.s_dim:n_cols]

        losses = []
        for _ in range(epochs):
            for start in range(0, len(memory), batch_size):
                bs = states[start:start + batch_size]
                ba = actions[start:start + batch_size]
                # MSELoss(input, target): prediction first, recorded action as target.
                loss_a = self.loss(self.Actor_eval(bs), ba)

                self.atrain.zero_grad()
                loss_a.backward()
                self.atrain.step()

                losses.append(loss_a.item())
                if log_every and len(losses) % log_every == 0:
                    print('step {}: loss {:.6f}'.format(len(losses), losses[-1]))
        return losses

    def save(self, folder_name, root='./DPG model'):
        """Save the actor's state_dict to ``<root>/<folder_name>/Actor_eval.h5f``.

        Missing directories are created and an existing checkpoint is overwritten,
        so the notebook can be re-run. The file is written by ``torch.save``;
        the ``.h5f`` name is kept for compatibility, but the file is not HDF5.

        Returns:
            str: the path that was written.
        """
        folder = os.path.join(root, folder_name)
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, 'Actor_eval.h5f')
        torch.save(self.Actor_eval.state_dict(), path)
        return path
            

In [None]:
###############################  training  ####################################
t1 = time.time()                                        # start time for the "Running time" report below
env = gym.make(ENV_NAME)
env = env.unwrapped
env.seed(SEED)                                          # seed Gym's RNG
torch.manual_seed(SEED)                                 # seed PyTorch's RNG (network initialisation)

# Request 'state' observations and 'dense' rewards from the environment.
env.set_env_mode(obs_mode='state', reward_type='dense')

s_dim = env.observation_space.shape[0]                  # state dimension
a_dim = env.action_space.shape[0]                       # action dimension
# The action range is symmetric, so only its upper bound is stored.
# NOTE(review): hard-coded to 1; confirm it matches env.action_space.high.
a_bound = 1
pre = PreTrain(a_dim, s_dim, a_bound)

pre.train()

pre.save('actor_pre')
env.close()                                             # env was only needed for the space dimensions
print('Running time: ', time.time() - t1)