In [41]:
# >> tensorboard --logdir=./result/tensorboard
'''
Perform network surgery on a previously trained model, check that its
performance survives the surgery, and then keep training.
'''
import collections
import datetime
import json
import os

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter


class Json_Parser:
    """Read a JSON file once at construction and hand back the parsed data.

    Args:
        file_name: path of the JSON file to load.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """

    def __init__(self, file_name):
        # JSON text is UTF-8 (RFC 8259); do not depend on the platform's
        # locale-dependent default encoding.
        with open(file_name, encoding="utf-8") as json_file:
            self.json_data = json.load(json_file)

    def load_parser(self):
        """Return the parsed JSON object (the same object on every call)."""
        return self.json_data


class Qnet(nn.Module):
    """Three-layer fully connected Q-network.

    The hidden width (``agent.hidden_unit``) and the hidden-layer activation
    (``agent.activation``) are read from a JSON config file.

    Args:
        in_features: length of the input feature vector (default 8).
        out_features: number of Q-values produced, one per action (default 2).
        config_path: path of the JSON config file (default "config.json").

    Raises:
        ValueError: if ``agent.activation`` is not "relu", "sigmoid" or "softmax".
    """

    _ACTIVATIONS = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        # Normalise over the feature axis (last dim), never the batch axis:
        # dim=0 would mix the samples of a minibatch and, for a single-sample
        # batch of shape (1, h), collapse every activation to exactly 1.
        "softmax": lambda t: F.softmax(t, dim=-1),
    }

    def __init__(self, in_features=8, out_features=2, config_path="config.json"):
        super().__init__()
        self.parser = Json_Parser(config_path)
        agent_cfg = self.parser.load_parser()['agent']
        h = agent_cfg['hidden_unit']
        acti = agent_cfg['activation']
        if acti not in self._ACTIVATIONS:
            raise ValueError(
                "unsupported activation {!r}; expected one of {}".format(
                    acti, sorted(self._ACTIVATIONS)))
        # Only the name is stored on the instance so the module stays picklable.
        self.activation = acti
        self.fc1 = nn.Linear(in_features, h)
        self.fc2 = nn.Linear(h, h)
        self.fc3 = nn.Linear(h, out_features)

    def forward(self, x):
        """Return Q-values; the last dim of ``x`` must equal ``in_features``."""
        act = self._ACTIVATIONS[self.activation]
        x = act(self.fc1(x))
        x = act(self.fc2(x))
        return self.fc3(x)


class ReplayMemory:
    """Fixed-capacity replay buffer storing one bounded deque per field.

    Transitions are stored column-wise: ``save`` appends the i-th element of
    ``observations`` to the deque of the i-th key (in the order ``keys`` was
    given). Each deque has ``maxlen=memory_size``, so the oldest entries are
    dropped once the buffer is full.
    """

    def __init__(self, memory_size, keys):
        self.memory = {key: collections.deque(maxlen=memory_size) for key in keys}
        self.memory_size = memory_size

    def save(self, observations):
        """Append one transition; ``observations`` is ordered like ``keys``."""
        for i, key in enumerate(self.memory):
            self.memory[key].append(observations[i])

    def __len__(self):
        """Number of stored transitions (0 when no keys were given)."""
        # save() appends to every deque, so they all share the same length;
        # any one of them gives the count, whatever the key names are.
        for column in self.memory.values():
            return len(column)
        return 0

    def sample(self, idx):
        """Gather the transitions at positions ``idx``.

        Assumes exactly five keys in the order (state, action, reward,
        next state, done), with states stored as equally-shaped tensors.

        Returns:
            (ss, actions, rs, ss_next, dones): states and next states stacked
            into tensors, rewards as a float32 tensor, actions and dones as
            plain lists.
        """
        # NOTE: deque indexing is O(1) at the ends but O(n) in the middle.
        sub_memory = {key: [column[i] for i in idx]
                      for key, column in self.memory.items()}

        ss, actions, rs, ss_next, dones = sub_memory.values()
        ss = torch.stack(ss)
        ss_next = torch.stack(ss_next)
        rs = torch.from_numpy(np.array(rs)).float()

        return (ss, actions, rs, ss_next, dones)


class DQNAgent:
    """DQN / Double-DQN agent with support for network surgery.

    All hyper-parameters come from ``config.json``. Intended flow in this
    notebook: build the agent, ``surgery`` in the weights of a previously
    trained model, check it with ``sample_with_surgery``, then keep training
    with ``run``.
    """

    def __init__(self):
        """Read the config and build env, networks, memory, optimizer and logger.

        Side effects: creates ``./result/model`` if it is missing (``run``
        writes checkpoints there) and opens a SummaryWriter under
        ``./result/tensorboard/<run name>``.
        """
        date_time = datetime.datetime.now().strftime("%Y%m%d-%H_%M_%S")
        self.parser = Json_Parser("config.json")
        self.parm = self.parser.load_parser()
        self.method = self.parm['method']
        self.max_step = self.parm['max_step']
        self.discount_factor = self.parm['agent']['discount_factor']
        self.lr = self.parm['optimizer']['learning_rate']
        self.eps = self.parm['optimizer']['eps']
        self.eps_max = self.eps
        self.eps_min = self.parm['optimizer']['eps_min']
        self.eps_mid = self.parm['optimizer']['eps_mid']
        self.eps_anneal = self.parm['optimizer']['eps_anneal']
        self.episode_size = self.parm['episode_size']
        self.minibatch_size = self.parm['minibatch_size']
        self.net_update_period = self.parm['net_update_period']

        self.env = gym.make(self.parm['env_name'])
        self.net = Qnet()
        self.target_net = Qnet()
        self.target_net.load_state_dict(self.net.state_dict())
        self.target_net.eval()

        # Key order must match the tuple order passed to replay_memory.save().
        self.replay_memory = ReplayMemory(
            self.parm['memory_size'], keys=['x', 'a', 'r', 'x_next', 'done'])
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr, weight_decay=0)
        self.loss = nn.SmoothL1Loss()

        save_name = '_'.join([
            self.parm['env_name'],
            self.method,
            self.parm['agent']['activation'],
            str(self.parm['agent']['hidden_unit']),
            'surgery',
            date_time,
        ])
        self.writer = SummaryWriter('./result/tensorboard/' + save_name)
        self.net_save_path = './result/model/model_{}.pth'.format(save_name)
        # torch.save() does not create missing directories.
        os.makedirs(os.path.dirname(self.net_save_path), exist_ok=True)
        self.writer.add_text('config', json.dumps(self.parm))

    def get_action(self, x):
        """Epsilon-greedy action for a single (unbatched) feature vector ``x``."""
        if np.random.rand() < self.eps:
            return np.random.randint(2)
        self.net.eval()
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            q = self.net(x.view(1, -1))
        return int(torch.argmax(q, dim=1).item())

    def epsilon_decaying(self):
        """Two-phase linear epsilon annealing.

        Phase 1 moves eps from above ``eps_mid`` down to ``eps_mid`` in steps of
        ``(eps_max - eps_mid) / eps_anneal``; phase 2 moves it from ``eps_mid``
        down to ``eps_min`` in steps of ``(eps_mid - eps_min) / eps_anneal``.
        Each call takes at most one step and never overshoots the phase floor,
        so eps landing exactly on ``eps_mid`` still enters phase 2.
        """
        if self.eps > self.eps_mid:
            step = (self.eps_max - self.eps_mid) / self.eps_anneal
            self.eps = max(self.eps_mid, self.eps - step)
        elif self.eps > self.eps_min:
            step = (self.eps_mid - self.eps_min) / self.eps_anneal
            self.eps = max(self.eps_min, self.eps - step)

    def train(self, running_loss):
        """One gradient step on a random minibatch from the replay memory.

        Args:
            running_loss: previous exponential moving average of the loss
                (0 means "no history yet").

        Returns:
            The updated moving average (0.99 * old + 0.01 * current loss).

        Raises:
            ValueError: if ``self.method`` is not "double" or "vanilla".
        """
        if self.method not in ("double", "vanilla"):
            raise ValueError(
                "unknown method {!r}; expected 'double' or 'vanilla'".format(self.method))
        self.epsilon_decaying()

        self.net.train()
        minibatch_idx = np.random.choice(len(self.replay_memory), self.minibatch_size)
        ss, actions, rs, ss_next, dones = self.replay_memory.sample(minibatch_idx)

        with torch.no_grad():
            if self.method == "double":
                # Double DQN: the online net picks the next action,
                # the target net evaluates it.
                self.net.eval()
                next_actions = self.net(ss_next).argmax(dim=1, keepdim=True)
                v_next = self.target_net(ss_next).gather(1, next_actions).squeeze(1)
                self.net.train()
            else:
                v_next = self.target_net(ss_next).max(dim=1).values
            # Terminal transitions have no bootstrapped future value.
            v_next = v_next.masked_fill(torch.as_tensor(dones, dtype=torch.bool), 0.0)
            q_target = rs + self.discount_factor * v_next

        self.optimizer.zero_grad()
        q = self.net(ss)
        actions = torch.tensor(actions, dtype=torch.long).view(-1, 1)
        # squeeze(1), not squeeze(): keeps a 1-D result even for a minibatch of 1.
        q_relevant = torch.gather(q, 1, actions).squeeze(1)

        loss = self.loss(q_relevant, q_target)
        loss.backward()
        self.optimizer.step()

        if running_loss == 0:
            return loss.item()
        return 0.99 * running_loss + 0.01 * loss.item()

    def surgery(self, old_link, new_link, old_para):
        '''
        Transplant the parameters of a previously trained network into self.net.

        Parameters whose shape matches are copied as-is. For a 2-D parameter
        whose shape differs (e.g. a first-layer weight gaining input features),
        column ``old_link[f]`` of the old weight is copied to column
        ``new_link[f]`` of the new one for every feature name ``f`` in
        ``old_link``; every other entry is zero-initialised. The target network
        is then synced and eps is reset to ``optimizer.eps_surgery``.

        Values are copied into the existing parameters, so the optimizer keeps
        tracking them and later training never writes back into ``old_para``.

        Args:
            old_link: feature name -> input column in the old network.
            new_link: feature name -> input column in the new network.
            old_para: state dict of the old network (name -> tensor).

        Raises:
            ValueError: if a parameter with a different shape is not 2-D.
        '''
        with torch.no_grad():
            for name, param in self.net.named_parameters():
                old = old_para[name]
                if param.shape == old.shape:
                    param.copy_(old)
                    continue
                if param.dim() != 2 or old.dim() != 2:
                    raise ValueError(
                        "cannot remap parameter {!r}: shape {} -> {}".format(
                            name, tuple(old.shape), tuple(param.shape)))
                new_weight = torch.zeros_like(param)
                n_rows = old.shape[0]
                for feature, old_column in old_link.items():
                    new_weight[:n_rows, new_link[feature]] = old[:, old_column]
                param.copy_(new_weight)
        self.target_net.load_state_dict(self.net.state_dict())
        self.eps_surgery = self.parm['optimizer']['eps_surgery']
        self.eps = self.eps_surgery

    def sample_with_surgery(self, sample_n=10000, show=False):
        '''
        Play ``sample_n`` greedy episodes (eps = 0) to check that performance
        is kept after surgery.

        Each episode lasts at most ``max_step`` steps. Every transition is also
        pushed into the replay memory. ``self.eps`` is restored afterwards,
        even if an exception interrupts the loop.

        Args:
            sample_n: number of episodes to play.
            show: if True, print each episode's score.
        '''
        saved_eps = self.eps
        self.eps = 0
        try:
            for episode in range(sample_n):
                s_now = self.env.reset()
                s_prev = s_now
                score = 0
                for step in range(self.max_step):
                    x = torch.from_numpy(np.concatenate((s_now, s_now - s_prev))).float()
                    a = self.get_action(x)
                    s_next, r, done, _ = self.env.step(a)
                    score += 1
                    x_next = torch.from_numpy(
                        np.concatenate((s_next, s_next - s_now))).float()
                    self.replay_memory.save((x, a, r, x_next, done))
                    if done:
                        break
                    s_prev = s_now
                    s_now = s_next
                if show:
                    print("sample episode: {} | score: {:3.1f}".format(episode, score))
        finally:
            self.eps = saved_eps

    def run(self):
        """Train until ``episode_size`` outer iterations elapse or the task is solved.

        Each outer iteration runs at most ``max_step`` environment steps. Every
        step stores its transition and, once the replay memory holds more than
        ``minibatch_size`` transitions, performs one gradient step; the target
        network is synced every ``net_update_period`` gradient steps. On every
        100th iteration that ends an environment episode, progress is printed
        and the online network is checkpointed; training stops once the mean of
        the last 100 episode scores exceeds ``max_step - 4``.
        """
        backprops_total = 0
        running_loss = 0
        latest_scores = collections.deque(maxlen=100)
        pass_score = self.max_step - 4

        s_now = self.env.reset()
        s_prev = s_now
        score = 0
        terminal_flag = False

        try:
            for episode in range(1, self.episode_size + 1):
                # NOTE(review): the environment is only reset on `done`. If
                # max_step steps pass without `done`, the next iteration keeps
                # playing the same environment episode and `score` keeps counting.
                for step in range(self.max_step):
                    x = torch.from_numpy(np.concatenate((s_now, s_now - s_prev))).float()
                    a = self.get_action(x)

                    s_next, r, done, _ = self.env.step(a)
                    score += 1

                    x_next = torch.from_numpy(
                        np.concatenate((s_next, s_next - s_now))).float()
                    self.replay_memory.save((x, a, r, x_next, done))

                    if done:
                        latest_scores.append(score)
                        score = 0
                        s_now = self.env.reset()
                        s_prev = s_now
                    else:
                        s_prev = s_now
                        s_now = s_next

                    if len(self.replay_memory) > self.minibatch_size:
                        running_loss = self.train(running_loss)
                        backprops_total += 1
                        if backprops_total % self.net_update_period == 0:
                            self.target_net.load_state_dict(self.net.state_dict())

                    if done and episode % 100 == 0:
                        avg_score = np.mean(latest_scores)
                        print("episode: {} | memory_size: {:5d} | eps: {:.3f} | running_loss: {:.3f} | last 100 avg score: {:3.1f}".
                              format(episode, len(self.replay_memory), self.eps, running_loss, avg_score))
                        torch.save(self.net.state_dict(), self.net_save_path)

                        if avg_score > pass_score:
                            print('Latest 100 average score: {}, pass score: {}, test is passed'.format(
                                avg_score, pass_score))
                            terminal_flag = True

                    if done:
                        break

                # One TensorBoard point per outer iteration so each x value has a
                # single sample; avg_score is skipped until an episode has finished
                # because np.mean of an empty deque yields NaN plus a RuntimeWarning.
                self.writer.add_scalar('memory_size', len(self.replay_memory), episode)
                self.writer.add_scalar('epsilon', self.eps, episode)
                self.writer.add_scalar('running_loss', running_loss, episode)
                if latest_scores:
                    self.writer.add_scalar('avg_score', np.mean(latest_scores), episode)

                if terminal_flag:
                    break
        finally:
            self.env.close()
            self.writer.flush()


In [42]:
# Build a fresh agent from config.json and load the checkpoint of the previously
# trained model whose weights will be transplanted into it.
# TODO: make this path relative to the project (the agent itself saves under ./result/model).
OLD_MODEL_PATH = '/home/workspace/util/surgery/CartPole_DQN/result/model/model_CartPole-v0_double_relu_256_in220220516-09_55_43.pth'
agent = DQNAgent()
# map_location='cpu' lets a checkpoint saved on a GPU machine load on any host.
# torch.load unpickles the file: only load checkpoints you trust.
old = torch.load(OLD_MODEL_PATH, map_location='cpu')

In [43]:
# Feature name -> input column. The old network's inputs are position/angle
# plus their deltas; the new network's inputs are the full state plus deltas.
old_link = {
    feature: column
    for column, feature in enumerate(
        base + suffix
        for suffix in ("", "_d")
        for base in ("Cart Position", "Pole Angle")
    )
}
new_link = {
    feature: column
    for column, feature in enumerate(
        base + suffix
        for suffix in ("", "_d")
        for base in ("Cart Position", "Cart Velocity", "Pole Angle", "Pole Angular Velocity")
    )
}

In [44]:
# Transplant the old weights into the new network, then play 15 greedy
# episodes to check the behaviour survived (their transitions also go into
# the replay memory).
agent.surgery(old_link=old_link, new_link=new_link, old_para=old)
agent.sample_with_surgery(sample_n=15, show=True)

sample episode: 0 | score: 199.0
sample episode: 1 | score: 199.0
sample episode: 2 | score: 199.0
sample episode: 3 | score: 199.0
sample episode: 4 | score: 199.0
sample episode: 5 | score: 199.0
sample episode: 6 | score: 199.0
sample episode: 7 | score: 199.0
sample episode: 8 | score: 199.0
sample episode: 9 | score: 199.0
sample episode: 10 | score: 199.0
sample episode: 11 | score: 199.0
sample episode: 12 | score: 199.0
sample episode: 13 | score: 199.0
sample episode: 14 | score: 199.0


In [46]:
# Replay-memory fill level after the check episodes, and the eps set by surgery.
(len(agent.replay_memory), agent.eps_surgery)

(2985, 0.25)

In [47]:
# Keep training the operated network; run() stops early once the mean of the
# last 100 episode scores exceeds max_step - 4.
agent.run()

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


episode: 100 | memory_size: 11979 | eps: 0.088 | running_loss: 0.607 | last 100 avg score: 107.1
episode: 200 | memory_size: 21979 | eps: 0.058 | running_loss: 0.524 | last 100 avg score: 166.7
episode: 300 | memory_size: 31979 | eps: 0.028 | running_loss: 0.530 | last 100 avg score: 200.0
Latest 100 average score: 200.0, pass score: 195, test is passed


In [48]:
# Peek at the first 10 rows of the first parameter tensor of the trained network.
name, g = next(iter(agent.net.named_parameters()))
print(name, g[:10])

fc1.weight tensor([[-7.3905e-01, -1.1381e-01,  1.7621e-02, -7.5021e-02,  4.4969e-02,
          2.4720e-02,  2.0213e-01, -3.1201e-02],
        [ 8.9196e-02,  0.0000e+00, -3.1708e-01,  0.0000e+00, -4.6932e-01,
          0.0000e+00,  2.5036e-01,  0.0000e+00],
        [ 4.6823e-01,  1.0666e-02, -3.4330e-01, -8.2394e-03,  1.0487e+00,
          1.1871e-02, -8.7368e-01, -7.6681e-03],
        [-7.6476e-01, -1.3465e-01, -2.9084e-01, -1.3447e-01, -7.6827e-01,
         -1.1589e-02, -8.3233e-01,  6.1097e-03],
        [-1.3532e-01, -1.2255e-01, -6.3762e-01, -8.3289e-02,  4.8578e-01,
          3.0598e-02, -2.1476e+00, -3.1516e-02],
        [-3.4078e-01, -5.0454e-02, -2.3279e-01,  1.3385e-02, -4.2977e-01,
         -2.3186e-02,  1.3525e+00,  1.9072e-02],
        [-7.4251e-02,  6.8741e-03,  5.4974e-01,  2.4864e-02,  3.8359e-01,
         -9.5824e-03,  1.4494e+00,  1.4727e-02],
        [ 3.2487e-01,  4.9868e-02,  2.8924e-01, -6.7122e-04,  4.6704e-02,
         -3.3893e-02, -3.8517e-01,  3.2998e-02],
     