In [1]:
from pyvirtualdisplay import Display
# Start an invisible (visible=0) 400x300 virtual X display; presumably needed so the
# MineRL environment can render on a headless machine — confirm for your setup.
display = Display(size=(400, 300), visible=0)
display.start();

In [2]:
import numpy as np
import torch
import gym
import minerl
import random
import torch.nn.functional as F
import math
import random

from torch import nn
from sklearn.cluster import KMeans
from tqdm import tqdm
from minerl.data import BufferedBatchIter
from collections import deque
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter



In [3]:
from src import replay_memory
from src import dueling_network as network

In [4]:
def _transitions_to_tensors(transitions):
    """Stack replay transitions into batched training tensors.

    Each transition is indexed as (state, action, reward, next_state, done), where
    state and next_state are dicts holding an HxWxC "pov" image (values presumably 0-255).

    Returns:
        (state, action, reward, next_state, done): the two image tensors are float32
        (B, C, H, W) scaled by 1/255; action, reward and done are (B, 1) float tensors.
    """
    def to_images(observations):
        # Stack HWC frames into one array, reorder to NCHW and scale to float32 / 255.
        pov = np.array([obs["pov"] for obs in observations])
        return torch.from_numpy(pov.transpose(0, 3, 1, 2).astype(np.float32) / 255)

    def to_column(values):
        # reshape (rather than unsqueeze) so scalar entries and length-1 entries both
        # become (B, 1); a 1-D action tensor breaks torch.cat and Tensor.gather below.
        return torch.Tensor(values).reshape(-1, 1)

    states = to_images([t[0] for t in transitions])
    next_states = to_images([t[3] for t in transitions])
    actions = to_column([t[1] for t in transitions])
    rewards = to_column([t[2] for t in transitions])
    dones = to_column([t[4] for t in transitions])
    return states, actions, rewards, next_states, dones


def sample_from_buffer(online_memory_replay, expert_memory_replay, online_memory_batch_size, expert_memory_batch_size):
    """Sample a mixed SQIL training batch from the agent and expert replay buffers.

    Args:
        online_memory_replay: agent replay buffer; must provide size() and sample(n).
        expert_memory_replay: expert replay buffer; must provide sample(n).
        online_memory_batch_size: number of transitions drawn from the online buffer.
        expert_memory_batch_size: number of transitions drawn from the expert buffer.

    Returns:
        (state, action, reward, next_state, done) tensors. Action, reward and done are
        always (B, 1), so the action tensor can be used directly with Tensor.gather(1, ...).
        When the online buffer is empty only the expert batch is returned; otherwise the
        online samples come first along dim 0, followed by the expert samples.
    """
    # Expert batch is sampled first (before checking the online buffer), as before.
    expert_batch = _transitions_to_tensors(expert_memory_replay.sample(expert_memory_batch_size))
    if online_memory_replay.size() == 0:
        return expert_batch

    online_batch = _transitions_to_tensors(online_memory_replay.sample(online_memory_batch_size))
    return tuple(
        torch.cat([online, expert], dim=0)
        for online, expert in zip(online_batch, expert_batch)
    )
    

In [5]:
# Row i is the continuous action vector sent to the env (as {"vector": row}) when discrete
# action i is chosen; presumably produced offline by KMeans clustering — confirm.
action_centroids = np.load(file='./action_centroids.npy')

In [6]:
DATA_DIR = "data/"  # directory containing the MineRL human demonstration data
ENVIRONMENT = 'MineRLTreechopVectorObf-v0'
NUM_ACTION_CENTROIDS = 64  # number of discrete actions the Q-networks choose between
ONLINE_REPLAY_MEMORY = 20000  # passed to replay_memory.Memory (presumably its capacity)
EXPERT_REPLAY_MEMORY = 40000  # passed to replay_memory.Memory (presumably its capacity)

# The chosen action index is used directly as a row index into action_centroids, so the
# discrete action count must match the number of loaded centroids.
assert action_centroids.shape[0] == NUM_ACTION_CENTROIDS, \
    "NUM_ACTION_CENTROIDS does not match the number of rows in action_centroids.npy"

writer = SummaryWriter('logs/sqil')
# NOTE(review): `data` is not referenced in the visible cells; kept in case it is needed elsewhere.
data = minerl.data.make(ENVIRONMENT, data_dir=DATA_DIR)

In [7]:
class Policy(nn.Module):
    """Convolutional soft-Q network over image observations.

    forward() returns one Q-value per discrete action. The soft state value is
    V(s) = alpha * logsumexp(Q(s, .) / alpha), and the induced policy is softmax(Q / alpha).
    """

    def __init__(self, input_shape, output_dim, alpha):
        """
        Args:
            input_shape: (channels, height, width) of a single observation.
            output_dim: number of discrete actions (one Q-value per action).
            alpha: temperature of the soft value / Boltzmann policy.
        """
        super(Policy, self).__init__()
        self.alpha = alpha
        n_input_channels = input_shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened CNN feature size with one dummy forward pass.
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, *input_shape)).shape[1]

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, output_dim) for a (batch, C, H, W) input."""
        return self.linear(self.cnn(x))

    def getV(self, q_value):
        """Soft value alpha * log(sum(exp(Q / alpha))) over dim 1, kept as (batch, 1).

        torch.logsumexp subtracts the max before exponentiating; the previous
        exp -> sum -> log form overflowed to inf once Q / alpha exceeded ~88 (float32).
        """
        return self.alpha * torch.logsumexp(q_value / self.alpha, dim=1, keepdim=True)

    def choose_action(self, state, epsilon):
        """Pick an action index for a single observation (a batch of one).

        With probability 1 - epsilon the greedy argmax action is taken; otherwise an
        action is sampled from the Boltzmann policy softmax(Q / alpha).
        """
        state = torch.as_tensor(state, dtype=torch.float32)
        with torch.no_grad():
            # softmax(Q / alpha) == exp((Q - V) / alpha), computed stably.
            probs = F.softmax(self.forward(state) / self.alpha, dim=1)
            if epsilon < random.uniform(0, 1):
                action = torch.argmax(probs)
            else:
                action = Categorical(probs=probs).sample()
        return action.item()

In [8]:
def test(network, action_centroids, num_episodes):
    
    num_actions = action_centroids.shape[0]
    action_list = np.arange(num_actions)

    episode_rewards = []
    for episode in range(num_episodes):
        obs = env.reset()
        done = False
        total_reward = 0
        steps = 0

        while not done:
            network_state = torch.from_numpy(obs['pov'].transpose(2, 0, 1)[None].astype(np.float32) / 255)
            selected_action = network.choose_action(network_state, 0)
            action = action_centroids[selected_action]
            minerl_action = {"vector": action}

            obs, reward, done, info = env.step(minerl_action)
            total_reward += reward
            steps += 1
            if steps > 18000: 
                break

        episode_rewards.append(total_reward)
    
    return episode_rewards

In [9]:
# Build the online and target soft-Q networks (online first) with identical constructor
# arguments (the 3 is presumably the observation channel count — confirm), then make the
# target an exact copy of the online weights.
onlineQNetwork, targetQNetwork = (network.SoftQNetwork(NUM_ACTION_CENTROIDS, 3) for _ in range(2))
targetQNetwork.load_state_dict(onlineQNetwork.state_dict())
env = gym.make(ENVIRONMENT)

In [10]:
def _load_memory(capacity_arg, path):
    """Construct a replay_memory.Memory and call .load(path) on it (expert first, then online below)."""
    memory = replay_memory.Memory(capacity_arg)
    memory.load(path)
    return memory


expert_memory_replay = _load_memory(EXPERT_REPLAY_MEMORY, 'expert_memory_replay')
# NOTE(review): restoring the online buffer from a previous run makes results depend on
# files left on disk — confirm this is intended for a fresh training run.
online_memory_replay = _load_memory(ONLINE_REPLAY_MEMORY, 'online_memory_replay')

In [11]:
# Training schedule
NUM_OF_EPOCHS = 100        # number of env.reset() episodes in the training loop
NUM_STEPS = 200            # max env steps per episode
GAMMA = 0.99               # discount factor for the soft Bellman target
REPLAY_START_SIZE = 5000   # learning starts once the online buffer holds more than this

# Exploration and optimisation
epsilon, decay, min_epsilon = 1, 0.9, 0.1
update_steps = 3000        # learn steps between target-network syncs
batch_size = 32            # transitions drawn from EACH replay buffer per update
learning_rate = 0.0001

In [None]:
# SQIL training loop: agent transitions are stored with reward 0 and mixed with expert
# transitions from expert_memory_replay; the online soft-Q network is regressed onto
# r + gamma * (1 - done) * V_target(s').

# Create Adam ONCE. Re-creating it every learn step (as before) threw away its moment
# estimates, so each update behaved like the first step of a freshly initialised Adam.
optimizer = torch.optim.Adam(onlineQNetwork.parameters(), lr=learning_rate)

learn_steps = 0
begin_learn = False
training_loss = []
training_qvalue = []
training_return = []

for epoch in range(NUM_OF_EPOCHS):
    state = env.reset()
    episode_reward = 0
    loss_values = []
    q_values = []
    # test() below rolls out episodes in this same `env`, which resets it; once that
    # happens `state` no longer matches the env, so the current episode must end.
    env_reused_for_eval = False

    for _ in range(NUM_STEPS):
        network_state = torch.from_numpy(state['pov'].transpose(2, 0, 1)[None].astype(np.float32) / 255)
        selected_action = onlineQNetwork.choose_action(network_state, epsilon)
        minerl_action = {"vector": action_centroids[selected_action]}

        next_state, reward, done, _ = env.step(minerl_action)
        episode_reward += reward

        # SQIL: the agent's own transitions are stored with reward 0.
        online_memory_replay.add((state, selected_action, 0, next_state, done))

        if online_memory_replay.size() > REPLAY_START_SIZE:
            if not begin_learn:
                print('learn begin!')
                begin_learn = True
            learn_steps += 1
            if learn_steps % update_steps == 0:
                print('updating target network')
                targetQNetwork.load_state_dict(onlineQNetwork.state_dict())

            batch_state, batch_action, batch_reward, batch_next_state, batch_done = sample_from_buffer(
                online_memory_replay, expert_memory_replay, batch_size, batch_size)

            with torch.no_grad():
                next_q = targetQNetwork(batch_next_state)
                next_v = targetQNetwork.getV(next_q)
                y = batch_reward + (1 - batch_done) * GAMMA * next_v

            loss = F.huber_loss(onlineQNetwork(batch_state).gather(1, batch_action.long()), y)
            loss_value = loss.item()
            target_mean = y.mean().item()
            loss_values.append(loss_value)
            q_values.append(target_mean)

            optimizer.zero_grad()
            loss.backward()
            # Clamp each gradient element to [-1, 1]; unlike a manual clamp_ loop this
            # skips parameters whose .grad is None instead of raising.
            torch.nn.utils.clip_grad_value_(onlineQNetwork.parameters(), 1)
            optimizer.step()

            if learn_steps % 100 == 0:
                mean_loss = np.mean(loss_values)
                mean_q = np.mean(q_values)
                print('learn_step: ', learn_steps, ' loss: ', mean_loss, ' q_value: ', mean_q)
                training_loss.append(mean_loss)
                training_qvalue.append(mean_q)

            if learn_steps % 10000 == 0:
                # Evaluates the target network, which lags the online one by up to update_steps.
                mean_return = np.mean(np.array(test(targetQNetwork, action_centroids, 5)))
                training_return.append(mean_return)
                print('training return', mean_return)
                writer.add_scalar('training return', mean_return, global_step=learn_steps)
                env_reused_for_eval = True

            writer.add_scalar('training loss', loss_value, global_step=learn_steps)
            writer.add_scalar('training q-value', target_mean, global_step=learn_steps)

        if done or env_reused_for_eval:
            break

        state = next_state

    # epsilon is kept constant across episodes (no decay schedule is applied).
    writer.add_scalar('episode reward', episode_reward, global_step=epoch)
    torch.save(onlineQNetwork, 'sqil-policy-2.pth')
    online_memory_replay.save('online_memory_replay')

In [13]:
# Persist the entire target network module (pickled object, not just its state_dict).
torch.save(obj=targetQNetwork, f='target_model.pth')