In [1]:
from __future__ import absolute_import, division, print_function

import sys
import os
import argparse
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical

from knowledge_graph import KnowledgeGraph #用于preprocess 预处理
from kg_env import BatchKGEnvironment
from utils import *

logger = None

SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])




In [2]:
class ActorCritic(nn.Module):
    def __init__(self, state_dim, act_dim, gamma=0.99, hidden_sizes=[512, 256]):
        super(ActorCritic, self).__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.gamma = gamma

        self.l1 = nn.Linear(state_dim, hidden_sizes[0])
        self.l2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.actor = nn.Linear(hidden_sizes[1], act_dim)
        self.critic = nn.Linear(hidden_sizes[1], 1)

        self.saved_actions = []
        self.rewards = []
        self.entropy = []

    def forward(self, inputs):
        state, act_mask = inputs  # state: [bs, state_dim], act_mask: [bs, act_dim] //batch_size
        print('forward-state:',state.size())
        print('act_mask:',act_mask.size())


        '''
        forward-state: tensor([[ 0.0045,  0.0047, -0.0045,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0020, -0.0027,  0.0026,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0040, -0.0014, -0.0023,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.0024, -0.0019,  0.0045,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0018,  0.0029,  0.0010,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0026,  0.0008,  0.0005,  ...,  0.0000,  0.0000,  0.0000]])
        len: 32
        size: torch.Size([32, 400])
        '''
        x = self.l1(state)
        print('x = l1(state):',x.size())
        
        x = F.dropout(F.elu(x), p=0.5) #dropout是为了解决大网络但是小训练集的过拟合问题，但是可以一般性的避免过拟合。
        print('x = dropout:',x.size())
        

        out = self.l2(x)
        print('out=l2(x):',out.size())
        
        x = F.dropout(F.elu(out), p=0.5)
        '''
        print('x=drop(out):',x)
        print('size:',x.size())
        '''

        
        actor_logits = self.actor(x)
        print('actor_logits:',actor_logits.size())
        
        actor_logits[1 - act_mask] = -999999.0
        '''
        import torch
        a = torch.tensor([2,3,4,5])
        b = torch.tensor([0,-1,-1,0])
        a[1-b]->[3,4,4,3]
        
        '''
        print('actor_logits_masked:',actor_logits.size())
        act_probs = F.softmax(actor_logits, dim=-1)  # Tensor of [bs, act_dim]
        
        print('act_probs:',act_probs.size())
        #print('...............')
        

        
        
        state_values = self.critic(x)  # Tensor of [bs, 1]
        return act_probs, state_values
    
    def select_action(self, batch_state, batch_act_mask, device):
        state = torch.FloatTensor(batch_state).to(device)  # Tensor [bs, state_dim]
        act_mask = torch.ByteTensor(batch_act_mask).to(device)  # Tensor of [bs, act_dim]
        print('state:',len(state),len(state[0]))
        print('act_mask:',len(act_mask),len(act_mask[0]))
        probs, value = self((state, act_mask))  # act_probs: [bs, act_dim], state_value: [bs, 1] forward
        #prob = act_probs
        #value = state_values
        #self(())调用forward
        '''
        print('probs:',probs) #act_prob
        print('len:',len(probs))
        print('size:',probs.size())
        print('value:',value) #state_value
        print('len:',len(value))
        print('size:',value.size())
        '''

        
        m = Categorical(probs)#32*251
        #若probs = [0.2 0.2 0.3 0.3]
        #那么 m.sample 取 0 1 2 3 的概率就对应probs
        acts = m.sample()  # Tensor of [bs, ], requires_grad=False
        '''
        print('acts extracted from m:')
        print(acts)
        print('size:',acts.size())#[32]
        '''
        
        
        # [CAVEAT] If sampled action is out of action_space, choose the first action in action_space.
        valid_idx = act_mask.gather(1, acts.view(-1, 1)).view(-1)#torch.gather(input=act_mask,dim=1,acts.view(-1,1))
        #.view(-1行,1列)，-1 代表随便几行
        acts[valid_idx == 0] = 0
        print('select_act:',acts.cpu().numpy().tolist())
        #print(len(acts.cpu().numpy().tolist()[0]))
        
        #print('***********************************')
        self.saved_actions.append(SavedAction(m.log_prob(acts), value))
        
        self.entropy.append(m.entropy())
        return acts.cpu().numpy().tolist()

    def update(self, optimizer, device, ent_weight):
        if len(self.rewards) <= 0:
            del self.rewards[:]
            del self.saved_actions[:]
            del self.entropy[:]
            return 0.0, 0.0, 0.0

        batch_rewards = np.vstack(self.rewards).T  # numpy array of [bs, #steps]
        batch_rewards = torch.FloatTensor(batch_rewards).to(device)
        
        num_steps = batch_rewards.shape[1]
        for i in range(1, num_steps):
            batch_rewards[:, num_steps - i - 1] += self.gamma * batch_rewards[:, num_steps - i]

        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        
        for i in range(0, num_steps):
            log_prob, value = self.saved_actions[i]  # log_prob: Tensor of [bs, ], value: Tensor of [bs, 1]
            advantage = batch_rewards[:, i] - value.squeeze(1)  # Tensor of [bs, ]
            actor_loss += -log_prob * advantage.detach()  # Tensor of [bs, ]
            critic_loss += advantage.pow(2)  # Tensor of [bs, ]
            entropy_loss += -self.entropy[i]  # Tensor of [bs, ]
        actor_loss = actor_loss.mean()
        critic_loss = critic_loss.mean()
        entropy_loss = entropy_loss.mean()
        loss = actor_loss + critic_loss + ent_weight * entropy_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del self.rewards[:]
        del self.saved_actions[:]
        del self.entropy[:]
        
        #loss, ploss, vloss, eloss
        return loss.item(), actor_loss.item(), critic_loss.item(), entropy_loss.item()




In [3]:
class ACDataLoader(object):
    def __init__(self, uids, batch_size):
        self.uids = np.array(uids)
        self.num_users = len(uids)
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        self._rand_perm = np.random.permutation(self.num_users)
        self._start_idx = 0
        self._has_next = True

    def has_next(self):
        return self._has_next

    def get_batch(self):
        if not self._has_next:
            return None
        # Multiple users per batch
        end_idx = min(self._start_idx + self.batch_size, self.num_users)
        batch_idx = self._rand_perm[self._start_idx:end_idx]
        batch_uids = self.uids[batch_idx]
        self._has_next = self._has_next and end_idx < self.num_users
        self._start_idx = end_idx
        return batch_uids.tolist()




In [4]:
def train(args):
    env = BatchKGEnvironment(args.dataset, args.max_acts, max_path_len=args.max_path_len, state_history=args.state_history)
    uids = list(env.kg(USER).keys())
    dataloader = ACDataLoader(uids, args.batch_size)
    model = ActorCritic(env.state_dim, env.act_dim, gamma=args.gamma, hidden_sizes=args.hidden).to(args.device)
    logger.info('Parameters:' + str([i[0] for i in model.named_parameters()]))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    total_losses, total_plosses, total_vlosses, total_entropy, total_rewards = [], [], [], [], []
    step = 0
    model.train()
    for epoch in range(1, args.epochs + 1):
        ### Start epoch ###
        dataloader.reset()
        while dataloader.has_next():
            batch_uids = dataloader.get_batch()
            print('batch_uids:',len(batch_uids))
            print(batch_uids)
            ### Start batch episodes ###
            batch_state = env.reset(batch_uids)  # numpy array of [bs, state_dim]
            print('batch_state:',batch_state.shape)
            done = False
            while not done:
                batch_act_mask = env.batch_action_mask(dropout=args.act_dropout)  # numpy array of size [bs, act_dim]
                '''select action'''
                batch_act_idx = model.select_action(batch_state, batch_act_mask, args.device)  # int
                batch_state, batch_reward, done = env.batch_step(batch_act_idx)
                model.rewards.append(batch_reward)
            ### End of episodes ###
            lr = args.lr * max(1e-4, 1.0 - float(step) / (args.epochs * len(uids) / args.batch_size))
            for pg in optimizer.param_groups:
                pg['lr'] = lr
            # Update policy
            total_rewards.append(np.sum(model.rewards))
            loss, ploss, vloss, eloss = model.update(optimizer, args.device, args.ent_weight)
            total_losses.append(loss)
            total_plosses.append(ploss)
            total_vlosses.append(vloss)
            total_entropy.append(eloss)
            step += 1
            # Report performance
            if step > 0 and step % 100 == 0:
                avg_reward = np.mean(total_rewards) / args.batch_size
                avg_loss = np.mean(total_losses)
                avg_ploss = np.mean(total_plosses)
                avg_vloss = np.mean(total_vlosses)
                avg_entropy = np.mean(total_entropy)
                total_losses, total_plosses, total_vlosses, total_entropy, total_rewards = [], [], [], [], []
                logger.info('epoch/step={:d}/{:d}'.format(epoch, step) + ' | loss={:.5f}'.format(avg_loss) + 
                            ' | ploss={:.5f}'.format(avg_ploss) + ' | vloss={:.5f}'.format(avg_vloss) + 
                            ' | entropy={:.5f}'.format(avg_entropy) +' | reward={:.5f}'.format(avg_reward))
        ### END of epoch ###

        policy_file = '{}/policy_model_epoch_{}.ckpt'.format(args.log_dir, epoch)
        logger.info("Save model to " + policy_file)
        torch.save(model.state_dict(), policy_file)




In [5]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {clothing, cell, beauty, cd}')
parser.add_argument('--name', type=str, default='train_agent', help='directory name.')
parser.add_argument('--seed', type=int, default=123, help='random seed.')
parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
parser.add_argument('--epochs', type=int, default=50, help='Max number of epochs.')
parser.add_argument('--batch_size', type=int, default=32, help='batch size.')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate.')
parser.add_argument('--max_acts', type=int, default=250, help='Max number of actions.')
#
parser.add_argument('--max_path_len', type=int, default=3, help='Max path length.')
parser.add_argument('--gamma', type=float, default=0.99, help='reward discount factor.')
parser.add_argument('--ent_weight', type=float, default=1e-3, help='weight factor for entropy loss')
parser.add_argument('--act_dropout', type=float, default=0.5, help='action dropout rate.')
parser.add_argument('--state_history', type=int, default=1, help='state history length')
parser.add_argument('--hidden', type=int, nargs='*', default=[512, 256], help='number of samples')
args = parser.parse_args(['--dataset',CELL])



In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
if not os.path.isdir(args.log_dir):
    os.makedirs(args.log_dir)

global logger
logger = get_logger(args.log_dir + '/train_log.txt')
logger.info(args)

set_random_seed(args.seed)

[INFO]  Namespace(act_dropout=0.5, batch_size=32, dataset='cell', device='cpu', ent_weight=0.001, epochs=50, gamma=0.99, gpu='0', hidden=[512, 256], log_dir='./tmp/Amazon_Cellphones/train_agent', lr=0.0001, max_acts=250, max_path_len=3, name='train_agent', seed=123, state_history=1)


In [7]:
train(args)

Load embedding: ./tmp/Amazon_Cellphones/transe_embed.pkl
[INFO]  Parameters:['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'actor.weight', 'actor.bias', 'critic.weight', 'critic.bias']
batch_uids: 32
[8639, 14910, 10146, 23298, 8040, 18852, 11923, 8586, 6391, 848, 4884, 4475, 16206, 27283, 6009, 6830, 24060, 4372, 11941, 7039, 26419, 24441, 16744, 13233, 12395, 805, 7496, 11959, 6365, 9942, 17829, 22348]
batch_state: (32, 400)
state: 32 400
act_mask: 32 251
forward-state: torch.Size([32, 400])
act_mask: torch.Size([32, 251])
x = l1(state): torch.Size([32, 512])
x = dropout: torch.Size([32, 512])
out=l2(x): torch.Size([32, 256])
actor_logits: torch.Size([32, 251])
actor_logits_masked: torch.Size([32, 251])
act_probs: torch.Size([32, 251])
select_act: [2, 3, 6, 3, 3, 9, 9, 0, 10, 4, 3, 3, 1, 16, 1, 10, 10, 7, 1, 6, 10, 9, 0, 4, 6, 5, 5, 0, 5, 1, 5, 12]


TypeError: object of type 'int' has no len()

In [7]:
train(args)

Load embedding: ./tmp/Amazon_Cellphones/transe_embed.pkl
[INFO]  Parameters:['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias', 'actor.weight', 'actor.bias', 'critic.weight', 'critic.bias']
[INFO]  epoch/step=1/100 | loss=0.03575 | ploss=-0.00155 | vloss=0.04593 | entropy=-8.63200 | reward=0.03642
[INFO]  epoch/step=1/200 | loss=0.02726 | ploss=-0.00746 | vloss=0.04332 | entropy=-8.59969 | reward=0.03380
[INFO]  epoch/step=1/300 | loss=0.02667 | ploss=-0.00788 | vloss=0.04316 | entropy=-8.60557 | reward=0.03493
[INFO]  epoch/step=1/400 | loss=0.00226 | ploss=-0.03246 | vloss=0.04333 | entropy=-8.61099 | reward=0.03769
[INFO]  epoch/step=1/500 | loss=0.01120 | ploss=-0.02822 | vloss=0.04792 | entropy=-8.50753 | reward=0.04253
[INFO]  epoch/step=1/600 | loss=0.02836 | ploss=-0.02225 | vloss=0.05907 | entropy=-8.46007 | reward=0.04586
[INFO]  epoch/step=1/700 | loss=0.02852 | ploss=-0.01655 | vloss=0.05347 | entropy=-8.39805 | reward=0.04727
[INFO]  epoch/step=1/800 | loss=0.00033 | ploss=-0

[INFO]  epoch/step=16/13500 | loss=0.28591 | ploss=-0.01708 | vloss=0.30466 | entropy=-1.67439 | reward=0.21281
[INFO]  epoch/step=16/13600 | loss=0.27238 | ploss=-0.03011 | vloss=0.30418 | entropy=-1.68854 | reward=0.21205
[INFO]  epoch/step=16/13700 | loss=0.27563 | ploss=-0.02838 | vloss=0.30568 | entropy=-1.66973 | reward=0.21384
[INFO]  epoch/step=16/13800 | loss=0.28588 | ploss=-0.02396 | vloss=0.31146 | entropy=-1.62210 | reward=0.22485
[INFO]  epoch/step=16/13900 | loss=0.28189 | ploss=-0.01520 | vloss=0.29874 | entropy=-1.65199 | reward=0.20663
[INFO]  Save model to ./tmp/Amazon_Cellphones/train_agent/policy_model_epoch_16.ckpt
[INFO]  epoch/step=17/14000 | loss=0.27647 | ploss=-0.01484 | vloss=0.29296 | entropy=-1.64474 | reward=0.20312
[INFO]  epoch/step=17/14100 | loss=0.27154 | ploss=-0.02177 | vloss=0.29501 | entropy=-1.70294 | reward=0.20738
[INFO]  epoch/step=17/14200 | loss=0.31381 | ploss=-0.00949 | vloss=0.32499 | entropy=-1.68808 | reward=0.22595
[INFO]  epoch/step=

[INFO]  epoch/step=31/26900 | loss=0.27963 | ploss=-0.01720 | vloss=0.29806 | entropy=-1.22639 | reward=0.22043
[INFO]  epoch/step=31/27000 | loss=0.29160 | ploss=-0.01824 | vloss=0.31107 | entropy=-1.23183 | reward=0.23495
[INFO]  Save model to ./tmp/Amazon_Cellphones/train_agent/policy_model_epoch_31.ckpt
[INFO]  epoch/step=32/27100 | loss=0.25874 | ploss=-0.02428 | vloss=0.28426 | entropy=-1.23749 | reward=0.21720
[INFO]  epoch/step=32/27200 | loss=0.28120 | ploss=-0.02238 | vloss=0.30481 | entropy=-1.22463 | reward=0.23388
[INFO]  epoch/step=32/27300 | loss=0.28264 | ploss=-0.02383 | vloss=0.30766 | entropy=-1.19817 | reward=0.22716
[INFO]  epoch/step=32/27400 | loss=0.28087 | ploss=-0.01666 | vloss=0.29873 | entropy=-1.20174 | reward=0.22539
[INFO]  epoch/step=32/27500 | loss=0.29073 | ploss=-0.01796 | vloss=0.30987 | entropy=-1.18036 | reward=0.23129
[INFO]  epoch/step=32/27600 | loss=0.26725 | ploss=-0.01980 | vloss=0.28826 | entropy=-1.20706 | reward=0.21504
[INFO]  epoch/step=

[INFO]  epoch/step=47/40300 | loss=0.26640 | ploss=-0.03844 | vloss=0.30597 | entropy=-1.11851 | reward=0.23442
[INFO]  epoch/step=47/40400 | loss=0.29145 | ploss=-0.01039 | vloss=0.30295 | entropy=-1.10951 | reward=0.23197
[INFO]  epoch/step=47/40500 | loss=0.28833 | ploss=-0.02095 | vloss=0.31043 | entropy=-1.14456 | reward=0.23439
[INFO]  epoch/step=47/40600 | loss=0.28123 | ploss=-0.01990 | vloss=0.30224 | entropy=-1.11935 | reward=0.22772
[INFO]  epoch/step=47/40700 | loss=0.27604 | ploss=-0.02554 | vloss=0.30269 | entropy=-1.11438 | reward=0.23222
[INFO]  epoch/step=47/40800 | loss=0.27538 | ploss=-0.03276 | vloss=0.30926 | entropy=-1.12233 | reward=0.23290
[INFO]  epoch/step=47/40900 | loss=0.29862 | ploss=-0.01471 | vloss=0.31446 | entropy=-1.12279 | reward=0.23842
[INFO]  Save model to ./tmp/Amazon_Cellphones/train_agent/policy_model_epoch_47.ckpt
[INFO]  epoch/step=48/41000 | loss=0.29015 | ploss=-0.01885 | vloss=0.31012 | entropy=-1.12078 | reward=0.23509
[INFO]  epoch/step=

In [16]:
a = [1,2,3]
k = [b for b in a]
print(type(a))
print(type(k))
a.reshape(-1,1)

next_state_batch = next_state_batch.reshape(-1,400)

<class 'list'>
<class 'list'>


AttributeError: 'list' object has no attribute 'reshape'

In [None]:
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default=BEAUTY, help='One of {clothing, cell, beauty, cd}')
    parser.add_argument('--name', type=str, default='train_agent', help='directory name.')
    parser.add_argument('--seed', type=int, default=123, help='random seed.')
    parser.add_argument('--gpu', type=str, default='0', help='gpu device.')
    parser.add_argument('--epochs', type=int, default=50, help='Max number of epochs.')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size.')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate.')
    parser.add_argument('--max_acts', type=int, default=250, help='Max number of actions.')
    parser.add_argument('--max_path_len', type=int, default=3, help='Max path length.')
    parser.add_argument('--gamma', type=float, default=0.99, help='reward discount factor.')
    parser.add_argument('--ent_weight', type=float, default=1e-3, help='weight factor for entropy loss')
    parser.add_argument('--act_dropout', type=float, default=0.5, help='action dropout rate.')
    parser.add_argument('--state_history', type=int, default=1, help='state history length')
    parser.add_argument('--hidden', type=int, nargs='*', default=[512, 256], help='number of samples')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

    args.log_dir = '{}/{}'.format(TMP_DIR[args.dataset], args.name)
    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)

    global logger
    logger = get_logger(args.log_dir + '/train_log.txt')
    logger.info(args)

    set_random_seed(args.seed)
    train(args)


if __name__ == '__main__':
    main()



In [9]:
import torch
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = torch.tensor([[0,1],[1,0]])



print(b.view(-1,1))
print(a.gather(1, b))
c = a.gather(1, b.view(-1, 1)).view(-1)
print(c)

tensor([[0],
        [1],
        [1],
        [0]])


RuntimeError: Expected tensor [2, 2], src [3, 3] and index [2, 2] to have the same size apart from dimension 1

In [10]:
t = torch.Tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))

tensor([[1., 1.],
        [4., 3.]])

In [12]:
import numpy as np
t = np.array(t)

In [None]:
class ActorCritic(nn.Module):
    def __init__(self, state_dim, act_dim, gamma=0.99, hidden_sizes=[512, 256]):
        super(ActorCritic, self).__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.gamma = gamma

        self.l1 = nn.Linear(state_dim, hidden_sizes[0])
        self.l2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.actor = nn.Linear(hidden_sizes[1], act_dim)
        self.critic = nn.Linear(hidden_sizes[1], 1)

        self.saved_actions = []
        self.rewards = []
        self.entropy = []

    def forward(self, inputs):
        state, act_mask = inputs  # state: [bs, state_dim], act_mask: [bs, act_dim] //batch_size

        x = self.l1(state)

        x = F.dropout(F.elu(x), p=0.5) #dropout是为了解决大网络但是小训练集的过拟合问题，但是可以一般性的避免过拟合。
        out = self.l2(x)

        x = F.dropout(F.elu(out), p=0.5)
        actor_logits = self.actor(x)
        actor_logits[1 - act_mask] = -999999.0
        act_probs = F.softmax(actor_logits, dim=-1)  # Tensor of [bs, act_dim]
        
        state_values = self.critic(x)  # Tensor of [bs, 1]
        return act_probs, state_values
    
    def select_action(self, batch_state, batch_act_mask, device):
        state = torch.FloatTensor(batch_state).to(device)  # Tensor [bs, state_dim]
        act_mask = torch.ByteTensor(batch_act_mask).to(device)  # Tensor of [bs, act_dim]
        probs, value = self((state, act_mask))  # act_probs: [bs, act_dim], state_value: [bs, 1] forward

        m = Categorical(probs)#32*251
        #若probs = [0.2 0.2 0.3 0.3]
        #那么 m.sample 取 0 1 2 3 的概率就对应probs
        acts = m.sample()  # Tensor of [bs, ], requires_grad=False
        
        # [CAVEAT] If sampled action is out of action_space, choose the first action in action_space.
        valid_idx = act_mask.gather(1, acts.view(-1, 1)).view(-1)#torch.gather(input=act_mask,dim=1,acts.view(-1,1))
        #.view(-1行,1列)，-1 代表随便几行
        acts[valid_idx == 0] = 0
        #print('***********************************')
        self.saved_actions.append(SavedAction(m.log_prob(acts), value))
        
        self.entropy.append(m.entropy())
        
        return acts.cpu().numpy().tolist()

    def update(self, optimizer, device, ent_weight):
        if len(self.rewards) <= 0:
            del self.rewards[:]
            del self.saved_actions[:]
            del self.entropy[:]
            return 0.0, 0.0, 0.0

        batch_rewards = np.vstack(self.rewards).T  # numpy array of [bs, #steps]
        batch_rewards = torch.FloatTensor(batch_rewards).to(device)
        
        num_steps = batch_rewards.shape[1]
        for i in range(1, num_steps):
            batch_rewards[:, num_steps - i - 1] += self.gamma * batch_rewards[:, num_steps - i]

        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        
        for i in range(0, num_steps):
            log_prob, value = self.saved_actions[i]  # log_prob: Tensor of [bs, ], value: Tensor of [bs, 1]
            advantage = batch_rewards[:, i] - value.squeeze(1)  # Tensor of [bs, ]
            actor_loss += -log_prob * advantage.detach()  # Tensor of [bs, ]
            critic_loss += advantage.pow(2)  # Tensor of [bs, ]
            entropy_loss += -self.entropy[i]  # Tensor of [bs, ]
        actor_loss = actor_loss.mean()
        critic_loss = critic_loss.mean()
        entropy_loss = entropy_loss.mean()
        loss = actor_loss + critic_loss + ent_weight * entropy_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        del self.rewards[:]
        del self.saved_actions[:]
        del self.entropy[:]
        
        #loss, ploss, vloss, eloss
        return loss.item(), actor_loss.item(), critic_loss.item(), entropy_loss.item()


