In [1]:
# import
import gym
import numpy
import torch.optim
import torch
import torch.nn
import torch.nn.functional
import matplotlib.pyplot

In [2]:
# variables
environment = gym.make('Pendulum-v0')
NUM_ACTIONS = environment.action_space.shape[0]
NUM_STATES = environment.observation_space.shape[0]

LEARNING_RATE_ACTOR = 0.0001
LEARNING_RATE_CRITIC = 0.001
TAU = 0.001
BATCH_SIZE = 64
MEMORY_CAPACITY = 10000
GAMMA = 0.99
TRAIN_EPISODE = 1000
TEST_EPISODE = 10
VAR = 3

In [13]:
# class
class Actor(torch.nn.Module):
    def __init__(self, state, action):
        super(Actor, self).__init__()
        self.linear1 = torch.nn.Linear(state, 400)
        self.linear2 = torch.nn.Linear(400, 300)
        self.linear3 = torch.nn.Linear(300, action)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        x = torch.nn.functional.relu(x)
        x = self.linear3(x)
        x = torch.nn.functional.tanh(x)

        return x


class Critic(torch.nn.Module):
    def __init__(self, state, action):
        super(Critic, self).__init__()
        self.linear1 = torch.nn.Linear(state, 400)
        self.linear2 = torch.nn.Linear(400 + action, 300)
        self.linear3 = torch.nn.Linear(300, 1)
        self.linear4 = torch.nn.Linear(1, 1)

    def forward(self, x, act):
        # front
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)

        x = torch.cat([x, act.type_as(x)], 1)

        # end
        x = self.linear2(x)
        x = torch.nn.functional.relu(x)
        x = self.linear3(x)
        x = torch.nn.functional.relu(x)
        x = self.linear4(x)
        x = torch.nn.functional.relu(x)
        return x


class Agent(object):
    def __init__(self, state, action, var):
        self.state = state
        self.action = action

        self.actor = Actor(self.state, self.action)
        self.critic = Critic(self.state, self.action)
        self.actorTarget = Actor(self.state, self.action)
        self.criticTarget = Critic(self.state, self.action)

        self.actorOptimizer = torch.optim.Adam(
            self.actor.parameters(), lr=LEARNING_RATE_ACTOR)
        self.criticOptimizer = torch.optim.Adam(
            self.critic.parameters(), lr=LEARNING_RATE_CRITIC)
        self.criterion = torch.nn.MSELoss()
        self.buffer = []
        self.var = var
        self.memoryCounter = 0
        self.memory = numpy.zeros((MEMORY_CAPACITY, NUM_STATES * 2 + 2))

    def chooseAction(self, state0):
        state0 = torch.tensor(state0, dtype=torch.float)
        action = self.actor(state0)

        action = action.cpu().data.numpy()
        action = numpy.clip(numpy.random.normal(action, self.var), -1, 1)
        return action

    def storeTransition(self, state, action, reward, state2):
        transition = numpy.hstack((state, [action, reward], state2))
        index = self.memoryCounter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memoryCounter += 1

    def softUpdate(self, target, source):
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.copy_((1.0 - TAU) * t.data + TAU * s.data)

    def learn(self):
        if self.memoryCounter <= 5000:
            return

        if self.memoryCounter > MEMORY_CAPACITY:
            sample_index = numpy.random.choice(
                MEMORY_CAPACITY, size=BATCH_SIZE)
        else:
            sample_index = numpy.random.choice(
                self.memoryCounter, size=BATCH_SIZE)

        batchMemory = self.memory[sample_index, :]
        batchState = torch.FloatTensor(batchMemory[:, :NUM_STATES])
        batchAction = torch.LongTensor(
            batchMemory[:, NUM_STATES:NUM_STATES + NUM_ACTIONS])
        batchReward = torch.FloatTensor(
            batchMemory[:, -NUM_STATES - 1:-NUM_STATES])
        batchState2 = torch.FloatTensor(batchMemory[:, -NUM_STATES:])

        batchAction2 = self.actorTarget(batchState2)

        qTarget = batchReward + GAMMA * self.criticTarget(
            batchState2, batchAction2).detach()
        qPredict = self.critic(batchState, batchAction)

        # update critic
        self.criticOptimizer.zero_grad()
        criticLoss = self.criterion(qPredict, qTarget)
        criticLoss.backward()
        self.criticOptimizer.step()

        # update actor
        self.actorOptimizer.zero_grad()
        action = self.actor(batchState)  # action prediction
        actorLoss = -self.critic(batchState, action)  # max -> Q for prediction
        actorLoss = actorLoss.mean()
        actorLoss.backward()
        self.actorOptimizer.step()

        self.softUpdate(self.criticTarget, self.critic)
        self.softUpdate(self.actorTarget, self.actor)

        self.var *= 0.9995

In [15]:
# main
ddpg = Agent(NUM_STATES, NUM_ACTIONS, VAR)

for times in range(1):

    # training
    print('----------TRAINING----------')
    scoreArray = []
    meanArray = []
    for i in range(TRAIN_EPISODE):
        state = environment.reset()
        totalReward = 0
        while True:
            action = ddpg.chooseAction(state)

            state2, reward, done, info = environment.step(action)

            reward = -1 if done else reward

            ddpg.storeTransition(state, action, reward, state2)

            totalReward += reward
            ddpg.learn()

            if done:
                print('[', i, ']', round(totalReward, 2))
                break

            state = state2

        scoreArray.append(totalReward)

        if i % 100 == 0:
            meanArray.append(sum(scoreArray[-100:]) / 100)

    # plot
    matplotlib.pyplot.figure(figsize=(10, 6))
    matplotlib.pyplot.plot(list(range(0, len(scoreArray))), scoreArray, color="#0000FF")
    matplotlib.pyplot.plot(list(range(0, len(scoreArray), len(meanArray))), meanArray)
    matplotlib.pyplot.title('Pendulum')
    matplotlib.pyplot.xlabel('EPISODE')
    matplotlib.pyplot.ylabel('REWARD')
    matplotlib.pyplot.show()

    # testing
    print('----------TESTING----------')
    scoreArray = []
    for i in range(TEST_EPISODE):
        state = environment.reset()
        totalReward = 0

        while True:
            action = ddpg.chooseAction(state)
            state2, reward, done, info = environment.step(action)

            totalReward += reward

            if done:
                print('[', i, ']', totalReward)
                break

            state = state2
        scoreArray.append(totalReward)

    # show result
    print('MEANS', ((sum(scoreArray) / len(scoreArray)) + 700) / 5)

----------TRAINING----------
[ 0 ] -1814.39
[ 1 ] -1068.55
[ 2 ] -777.6
[ 3 ] -1347.09
[ 4 ] -861.94
[ 5 ] -1372.56
[ 6 ] -1392.38
[ 7 ] -1057.3
[ 8 ] -1586.2
[ 9 ] -1598.21
[ 10 ] -1321.86
[ 11 ] -1805.87
[ 12 ] -1527.48
[ 13 ] -1570.61
[ 14 ] -1055.54
[ 15 ] -756.52
[ 16 ] -1115.64
[ 17 ] -1657.87
[ 18 ] -1469.42
[ 19 ] -890.92
[ 20 ] -1388.43
[ 21 ] -904.94
[ 22 ] -1732.98
[ 23 ] -769.29
[ 24 ] -750.5
[ 25 ] -1413.21
[ 26 ] -1583.21
[ 27 ] -1461.3
[ 28 ] -1404.06
[ 29 ] -1384.53
[ 30 ] -1095.44
[ 31 ] -1800.43
[ 32 ] -1177.29
[ 33 ] -1011.25
[ 34 ] -1410.8
[ 35 ] -1436.42
[ 36 ] -925.29
[ 37 ] -1065.12
[ 38 ] -1802.47
[ 39 ] -1049.98
[ 40 ] -1322.6
[ 41 ] -1599.98
[ 42 ] -1757.57
[ 43 ] -1302.53
[ 44 ] -1708.8
[ 45 ] -1625.7
[ 46 ] -835.74
[ 47 ] -1168.49
[ 48 ] -1276.41
[ 49 ] -1613.27
[ 50 ] -1434.52
[ 51 ] -1661.46
[ 52 ] -979.46
[ 53 ] -1305.37
[ 54 ] -1341.28
[ 55 ] -1294.93
[ 56 ] -1097.75
[ 57 ] -1209.34
[ 58 ] -1286.82
[ 59 ] -1321.44
[ 60 ] -1060.71
[ 61 ] -1347.26
[ 62 ] -

[ 493 ] -1504.84
[ 494 ] -1479.6
[ 495 ] -1350.21
[ 496 ] -1362.72
[ 497 ] -1827.02
[ 498 ] -1690.91
[ 499 ] -907.71
[ 500 ] -1066.78
[ 501 ] -1180.55
[ 502 ] -1501.76
[ 503 ] -1348.57
[ 504 ] -1022.72
[ 505 ] -1668.58
[ 506 ] -1363.9
[ 507 ] -1354.65
[ 508 ] -1350.04
[ 509 ] -1296.84
[ 510 ] -1355.3
[ 511 ] -1637.92
[ 512 ] -1361.1
[ 513 ] -1466.26
[ 514 ] -1345.92
[ 515 ] -1258.53
[ 516 ] -1278.91
[ 517 ] -1779.56
[ 518 ] -837.1
[ 519 ] -1193.82
[ 520 ] -1062.92
[ 521 ] -1813.0
[ 522 ] -1035.85
[ 523 ] -1193.84
[ 524 ] -1755.49
[ 525 ] -1347.22
[ 526 ] -1351.07
[ 527 ] -1067.49
[ 528 ] -1347.85
[ 529 ] -1298.61
[ 530 ] -888.87
[ 531 ] -1258.54
[ 532 ] -1290.94
[ 533 ] -1508.69
[ 534 ] -952.27
[ 535 ] -1123.2
[ 536 ] -1158.19
[ 537 ] -1090.13
[ 538 ] -1269.33
[ 539 ] -1459.77
[ 540 ] -1492.12
[ 541 ] -1195.86
[ 542 ] -1650.79
[ 543 ] -1334.9
[ 544 ] -1354.84
[ 545 ] -1357.47
[ 546 ] -1342.42
[ 547 ] -1421.6
[ 548 ] -1652.0
[ 549 ] -1343.0
[ 550 ] -1065.71
[ 551 ] -988.55
[ 552 ] -1735

[ 981 ] -1598.23
[ 982 ] -1798.97
[ 983 ] -1706.57
[ 984 ] -1183.52
[ 985 ] -1739.47
[ 986 ] -1249.61
[ 987 ] -1161.6
[ 988 ] -1443.85
[ 989 ] -1366.09
[ 990 ] -1669.3
[ 991 ] -837.97
[ 992 ] -1091.04
[ 993 ] -1301.47
[ 994 ] -1272.43
[ 995 ] -1780.82
[ 996 ] -1305.11
[ 997 ] -1059.94
[ 998 ] -1512.49
[ 999 ] -1356.16


NameError: name 'plt' is not defined