In [None]:
# Show which Java version is currently on the machine.
!java -version
# Remove every installed package whose name matches openjdk-*.
!sudo apt-get purge openjdk-*
# Check the Java version again after the purge.
!java -version
# Install the OpenJDK 8 JDK (presumably the Java version MineRL expects — confirm).
!sudo apt-get install openjdk-8-jdk

In [None]:
# Install MineRL, upgrading it if a version is already present.
!pip3 install --upgrade minerl
# System packages for a virtual X display (Xvfb, Xephyr, VNC server).
!sudo apt-get install xvfb xserver-xephyr vnc4server
# Python wrapper used later in the notebook to start the virtual display.
!sudo pip install pyvirtualdisplay

In [None]:
#@title Change python code { output-height: 10 , form-width: 10}
%%writefile /usr/local/lib/python3.6/dist-packages/minerl/herobraine/env_specs/navigate_specs.py
import sys
from typing import List

import minerl.herobraine
import minerl.herobraine.hero.handlers as handlers
from minerl.herobraine.env_specs.simple_env_spec import SimpleEnvSpec


class Navigate(SimpleEnvSpec):
    """Environment spec for the MineRLNavigate task family.

    The ``dense`` and ``extreme`` flags pick the variant: they determine the
    environment name, the mission XML file name, the data folder the spec
    matches, and whether the walking-towards-target reward handler is added.
    """

    def __init__(self, dense, extreme):
        """Derive the environment name and XML file name from the two flags."""
        variant = ('Extreme' if extreme else '') + ('Dense' if dense else '')
        # The flags are stored before the base initialiser runs.
        self.dense, self.extreme = dense, extreme
        super().__init__(
            'MineRLNavigate{}-v0'.format(variant),
            'navigation{}.xml'.format(variant),
        )

    def is_from_folder(self, folder: str) -> bool:
        """Return whether ``folder`` is the data folder for this variant."""
        expected_folder = 'navigateextreme' if self.extreme else 'navigate'
        return folder == expected_folder

    def create_mission_handlers(self) -> List[minerl.herobraine.hero.AgentHandler]:
        """Build the reward/decorator handlers; dense variants get one extra handler."""
        # NOTE(review): the argument below is a set literal, not a dict —
        # confirm what RewardForTouchingBlock actually expects.
        touch_reward = handlers.RewardForTouchingBlock({"diamond_block", 100.0})
        target_reward = handlers.NavigateTargetReward()
        decorator = handlers.NavigationDecorator(
            min_radius=64,
            max_radius=64,
            randomize_compass_target=True,
        )
        result = [touch_reward, target_reward, decorator]
        if not self.dense:
            return result

        shaping = handlers.RewardForWalkingTwardsTarget(
            reward_per_block=1, reward_schedule="PER_TICK"
        )
        result.append(shaping)
        return result

    def determine_success_from_rewards(self, rewards: list) -> bool:
        """Return True when the summed rewards reach the variant's threshold.

        The threshold is 100.0 for sparse variants and 160.0 for dense ones.
        """
        required = 160.0 if self.dense else 100.0
        return sum(rewards) >= required

    def create_observables(self) -> List[minerl.herobraine.hero.AgentHandler]:
        """Extend the base observables with compass, death and dirt-inventory handlers."""
        observables = super().create_observables()
        return observables + [
            handlers.CompassObservation(),
            handlers.DeathObservation(),
            handlers.FlatInventoryObservation(['dirt']),
        ]

    def create_actionables(self) -> List[minerl.herobraine.hero.AgentHandler]:
        """Extend the base actionables with a place-block action for 'none'/'dirt'."""
        actionables = super().create_actionables()
        return actionables + [handlers.PlaceBlock(['none', 'dirt'])]

    def get_docstring(self):
        """Return the generated documentation text for this variant."""
        top = "extreme" if self.extreme else "normal"
        return make_navigate_text(top=top, dense=self.dense)


def make_navigate_text(top, dense):
    """Build the reStructuredText description of a Navigate environment variant.

    Args:
        top: ``"normal"`` selects the survival-map wording and un-prefixed
            image names; any other value selects the extreme-hills wording
            and ``extreme``-prefixed image names.
        dense: truthy appends the dense-reward sentence, falsy the sparse one.

    Returns:
        The description string with its four image-name placeholders filled in.
    """
    navigate_text = """
.. image:: ../assets/navigate{}1.mp4.gif
    :scale: 100 %
    :alt: 
.. image:: ../assets/navigate{}2.mp4.gif
    :scale: 100 %
    :alt: 
.. image:: ../assets/navigate{}3.mp4.gif
    :scale: 100 %
    :alt: 
.. image:: ../assets/navigate{}4.mp4.gif
    :scale: 100 %
    :alt: 
In this task, the agent must move to a goal location denoted by a diamond block. This represents a basic primitive used in many tasks throughout Minecraft. In addition to standard observations, the agent has access to a “compass” observation, which points near the goal location, 64 meters from the start location. The goal has a small random horizontal offset from the compass location and may be slightly below surface level. On the goal location is a unique block, so the agent must find the final goal by searching based on local visual features.
The agent is given a sparse reward (+100 upon reaching the goal, at which point the episode terminates). """
    if dense:
        navigate_text += "**This variant of the environment is dense reward-shaped where the agent is given a reward every tick for how much closer (or negative reward for farther) the agent gets to the target.**\n"
    else:
        navigate_text += "**This variant of the environment is sparse.**\n"

    # Compare by value. ``is`` tests object identity, which is only true for
    # strings that happen to be the same interned object as the literal.
    if top == "normal":
        navigate_text += "\nIn this environment, the agent spawns on a random survival map.\n"
        image_prefix = ""
    else:
        navigate_text += "\nIn this environment, the agent spawns in an extreme hills biome.\n"
        image_prefix = "extreme"

    # The same prefix fills each of the four image-name placeholders.
    return navigate_text.format(*([image_prefix] * 4))

In [None]:
from pyvirtualdisplay import Display
# Create a non-visible 640x480 virtual display and start it
# (presumably so Minecraft can render without a physical screen — confirm).
display = Display(visible=0, size=(640, 480))
display.start()

In [None]:
# Load the TensorBoard notebook extension so the %tensorboard magic is available.
%load_ext tensorboard

In [None]:
# Open TensorBoard inline, reading event files from the "runs" directory.
%tensorboard --logdir runs

In [None]:
import torch
import cv2
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer; Trainer.train logs its scalars through it.
# Created with no arguments, so it uses SummaryWriter's default log directory
# (presumably under ./runs, matching the %tensorboard cell — confirm).
writer = SummaryWriter()

import minerl
import gym

In [None]:
class ContinualAgent(nn.Module):
    """Convolutional policy mapping a frame and a context vector to an action and a new context.

    The frame goes through five conv/ReLU/max-pool stages, is flattened, and is
    combined with the context by a bilinear layer. Two linear heads then emit
    the action (tanh-squashed and scaled by ``action_limit``) and the next
    context (tanh-squashed).
    """

    def __init__(self, context_size, hidden_size):
        """Create the layers.

        Args:
            context_size: width of the context vector (input and output).
            hidden_size: width of the bilinear layer's output.
        """
        super().__init__()

        n_actions = 64

        # Actions are tanh outputs multiplied by this bound.
        self.action_limit = 1.0499999523162842

        # Feature extractor: 3x3 convolutions, channels 1 -> 16 -> 32 -> 64 -> 128 -> 256.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # The bilinear layer expects 1024 flattened features
        # (256 channels x 2 x 2 for a 64x64 input halved five times).
        self.bfc1 = nn.Bilinear(1024, context_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_actions)
        self.fc3 = nn.Linear(hidden_size, context_size)

    def forward(self, x, c):
        """Return ``(action, next_context)`` for frame batch ``x`` and context batch ``c``."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            features = self.pool(F.relu(conv(features)))

        # Flatten everything except the batch dimension.
        features = features.view(features.size(0), -1)

        hidden = F.relu(self.bfc1(features, c))
        scaled_action = torch.tanh(self.fc2(hidden)) * self.action_limit
        next_context = torch.tanh(self.fc3(hidden))

        return scaled_action, next_context

class MLP(nn.Module):
    """Small feed-forward head, optionally preceded by a five-stage conv encoder.

    With ``conv=False`` the network is fc1 -> ReLU -> fc2 -> ReLU -> fc3 on the
    flattened input. With ``conv=True`` the input first passes through five
    conv/ReLU/max-pool stages, ``input_size`` is ignored (1024 is used), and
    ``fc2`` is not created, giving fc1 -> ReLU -> fc3.
    """

    def __init__(self, input_size, hidden_size, output_size, conv=False):
        """Create the layers.

        Args:
            input_size: flattened input width; overridden with 1024 when ``conv`` is true.
            hidden_size: width of the hidden layer(s).
            output_size: width of the output.
            conv: whether to prepend the convolutional encoder.
        """
        super().__init__()

        self.conv = conv

        if conv:
            # Same encoder layout as the policy: channels 1 -> 16 -> 32 -> 64 -> 128 -> 256.
            self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
            self.conv4 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

            # The encoder output is flattened to 1024 features.
            input_size = 1024

        self.fc1 = nn.Linear(input_size, hidden_size)
        if not conv:
            self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the un-activated output of the final linear layer for batch ``x``."""
        out = x
        if self.conv:
            for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
                out = self.pool(F.relu(stage(out)))

        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)

        out = F.relu(self.fc1(out))
        if not self.conv:
            out = F.relu(self.fc2(out))

        return self.fc3(out)

In [None]:
class Trainer():
    """Owns the policy, two critics and two context predictors, with one optimizer each.

    ``mem`` holds ``[context, state]`` from the previous environment step. The
    caller must set it before calling :meth:`train`, which unpacks it.
    """

    def __init__(self, context_size, hidden_size):
        """Create every network on ``cuda:0`` and an SGD optimizer (lr=0.001) for each.

        Args:
            context_size: width of the context vector used by the policy, the
                context critic and both predictors.
            hidden_size: hidden width shared by all networks.
        """
        super().__init__()

        self.policy = ContinualAgent(context_size, hidden_size)
        self.context_critic = MLP(context_size, hidden_size, 1)
        self.state_critic = MLP(None, hidden_size, 1, conv=True)
        self.past_predict = MLP(None, hidden_size, context_size, conv=True)
        self.future_predict = MLP(None, hidden_size, context_size, conv=True)

        self.policy.to("cuda:0")
        self.context_critic.to("cuda:0")
        self.state_critic.to("cuda:0")
        self.past_predict.to("cuda:0")
        self.future_predict.to("cuda:0")

        self.policy_optimizer = optim.SGD(self.policy.parameters(), lr=0.001)
        self.context_critic_optimizer = optim.SGD(self.context_critic.parameters(), lr=0.001)
        self.state_critic_optimizer = optim.SGD(self.state_critic.parameters(), lr=0.001)
        # Each predictor optimizer must be built from its own network's
        # parameters. Handing it the policy's parameters would leave the
        # predictors untrained, because an optimizer only updates the
        # parameters it was constructed with.
        self.past_predict_optimizer = optim.SGD(self.past_predict.parameters(), lr=0.001)
        self.future_predict_optimizer = optim.SGD(self.future_predict.parameters(), lr=0.001)

        self.mem = []

    def train(self, c_action, cur_state, actor_action, c_next_state, next_state, done, total_timestep):
        """Run one update of every network from a single transition.

        Args:
            c_action: context emitted by the policy for the previous state.
            cur_state: current frame tensor.
            actor_action: action emitted by the policy for ``cur_state``.
            c_next_state: context emitted by the policy for ``cur_state``.
            next_state: frame tensor observed after the action was applied.
            done: whether the environment reported the episode as finished.
            total_timestep: global step used as the TensorBoard x-axis.

        Scalars are logged through the module-level ``writer``.
        """

        c_state, prev_state = self.mem

        context_value = self.context_critic(c_state)

        state_value = self.state_critic(cur_state)

        next_context_value = self.context_critic(c_next_state)

        next_state_value = self.state_critic(next_state)

        # Both predictors try to recover the same context, one from the
        # previous frame and one from the next frame.
        pred_c_past = torch.tanh(self.past_predict(prev_state))
        pred_c_future = torch.tanh(self.future_predict(next_state))

        past_history_loss = F.mse_loss(pred_c_past, c_action.detach())
        future_history_loss = F.mse_loss(pred_c_future, c_action.detach())

        # Intrinsic reward in (-1, 1): positive when the future predictor's
        # error exceeds the past predictor's error. It is computed from
        # detached losses, so no gradient flows through it.
        reward = torch.tanh(-torch.log(past_history_loss.detach()) + torch.log(future_history_loss.detach()))

        writer.add_scalar("Reward", reward.item(), total_timestep)
        writer.add_scalar("Prediction: Past Loss", past_history_loss.item(), total_timestep)
        writer.add_scalar("Prediction: Future Loss", future_history_loss.item(), total_timestep)

        # One-step bootstrapped target; the bootstrap term is dropped when the episode ended.
        actor_q_value = reward + 0.99*int(not done)*next_state_value

        actor_advantage = actor_q_value - state_value

        policy_loss = (-actor_action*actor_advantage.detach()).mean()
        state_critic_loss = F.mse_loss(state_value, actor_q_value.detach())

        writer.add_scalar("Policy: Loss", policy_loss.item(), total_timestep)
        writer.add_scalar("Critic: State Loss", state_critic_loss.item(), total_timestep)

        # Same construction for the context head, using the context critic.
        context_q_value = reward + 0.99*int(not done)*next_context_value

        context_advantage = context_q_value - context_value

        context_loss = (-c_action*context_advantage.detach()).mean()
        context_critic_loss = F.mse_loss(context_value, context_q_value.detach())

        writer.add_scalar("Context: Loss", context_loss.item(), total_timestep)
        writer.add_scalar("Critic: Context Loss", context_critic_loss.item(), total_timestep)

        # Policy step: gradients of both the action loss and the context loss
        # accumulate into the policy's parameters before the single step.
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=True)
        context_loss.backward()
        self.policy_optimizer.step()

        # Critic steps.
        self.state_critic_optimizer.zero_grad()
        self.context_critic_optimizer.zero_grad()
        state_critic_loss.backward()
        context_critic_loss.backward()
        self.context_critic_optimizer.step()
        self.state_critic_optimizer.step()

        # Predictor steps.
        self.past_predict_optimizer.zero_grad()
        past_history_loss.backward()
        self.past_predict_optimizer.step()

        self.future_predict_optimizer.zero_grad()
        future_history_loss.backward()
        self.future_predict_optimizer.step()

def preprocess(state):
    """Convert an RGB frame into a float32 grayscale tensor of shape (1, 64, 64).

    The frame is cast to float32, converted to grayscale with OpenCV, reshaped
    to 64x64x1 and transposed to channel-first order. Pixel values are not
    rescaled here.
    """
    gray = cv2.cvtColor(np.float32(state), cv2.COLOR_RGB2GRAY)
    # Single-channel HWC layout first, then reorder the axes to CHW.
    hwc = gray.reshape((64, 64, 1))
    chw = hwc.transpose(2, 0, 1).astype(np.float32)
    return torch.from_numpy(chw)

In [None]:
# Build the environment and fetch the first observation; the prints mark
# when environment creation starts and when the reset has finished.
print("Hello")
env = gym.make('MineRLNavigateVectorObf-v0')
obs  = env.reset()
print("Hi")

# context_size=512, hidden_size=512.
trainer = Trainer(512, 512)

# All-zero context used as the "previous context" for the very first step.
first_c = torch.zeros(1, 512).to("cuda:0")

timestep = 0
living = True  # never changed below, so only the timestep bound ends the loop
final_death = 15000
done = False

# Runs for at most final_death + 1 iterations.
while living and timestep <= final_death:

    # `done` comes from the previous env.step; start a new episode when it was set.
    if done:
      print("I died")
      obs = env.reset()

    if timestep >= 1:
        # Context and frame stored at the end of the previous iteration.
        c_state, prev_state = trainer.mem

        # Frame as a (1, 1, 64, 64) tensor on the GPU, divided by 255.
        cur_state = (preprocess(obs["pov"]).unsqueeze(0)/255).to("cuda:0")

        # Recompute the context from the previous frame; the action output is discarded.
        _ , c_action = trainer.policy(prev_state, c_state)

        # Act on the current frame using that context, detached from its graph.
        actor_action, c_next_state = trainer.policy(cur_state, c_action.detach())

        env_action = dict(vector=actor_action.detach().cpu().numpy())

        # The environment's own reward is ignored; Trainer.train computes its own.
        obs, _ , done, info = env.step(env_action)

        if info != {}:
            print(info)

        next_state = (preprocess(obs["pov"]).unsqueeze(0)/255).to("cuda:0")

        trainer.train(c_action, cur_state, actor_action, c_next_state, next_state, done, timestep)

        # Store what the next iteration treats as "previous".
        trainer.mem = [c_action.detach(), cur_state]
    
    else:
        # First iteration: no memory yet, so take a random action and seed
        # the memory with the zero context and the first frame.
        cur_state = (preprocess(obs["pov"]).unsqueeze(0)/255).to("cuda:0")

        env_action = env.action_space.sample()

        print(env_action)

        obs, _ , done, info = env.step(env_action)

        trainer.mem = [first_c, cur_state]


    timestep += 1