# Install MineRL v1.0

In [None]:
%%capture
# System dependencies: Java 8 (openjdk-8-jdk), a virtual X server (xvfb),
# Xephyr/VNC servers, OpenGL bindings and ffmpeg.
# -y answers apt's confirmation prompts non-interactively; with %%capture the
# prompt is hidden, so without it the purge/install may abort or stall.
# The glob is quoted so the shell passes it to apt instead of expanding it.
!sudo add-apt-repository -y ppa:openjdk-r/ppa
!sudo apt-get purge -y 'openjdk-*'
!sudo apt-get install -y openjdk-8-jdk
!sudo apt-get install -y xvfb
!sudo apt-get install -y xserver-xephyr
!sudo apt-get install -y vnc4server
!sudo apt-get install -y python-opengl
!sudo apt-get install -y ffmpeg

In [None]:
%%capture
# Python helpers: pyvirtualdisplay (headless display used below), colabgymrender
# (provides the Recorder imported below), and imageio pinned to 2.4.1.
!pip3 install pyvirtualdisplay
!pip3 install -U colabgymrender
!pip3 install imageio==2.4.1

In [None]:
# MineRL v1.0 from the repository's default branch.
# NOTE(review): unpinned — the installed code changes as the repo moves; pin a
# tag or commit for reproducibility.
!pip3 install git+https://github.com/minerllabs/minerl

# Install MineRL v0.4.4 (datasets)

In [None]:
%%capture
# System dependencies for the MineRL 0.4.4 (datasets) install: Java 8, xvfb,
# Xephyr/VNC servers, OpenGL bindings and ffmpeg.
# -y answers apt's confirmation prompts non-interactively; with %%capture the
# prompt is hidden, so without it the purge/install may abort or stall.
# The glob is quoted so the shell passes it to apt instead of expanding it.
!sudo add-apt-repository -y ppa:openjdk-r/ppa
!sudo apt-get purge -y 'openjdk-*'
!sudo apt-get install -y openjdk-8-jdk
!sudo apt-get install -y xvfb
!sudo apt-get install -y xserver-xephyr
!sudo apt-get install -y vnc4server
!sudo apt-get install -y python-opengl
!sudo apt-get install -y ffmpeg

In [None]:
%%capture
# Same Python helpers as the v1.0 section: pyvirtualdisplay, colabgymrender
# (provides the Recorder imported below), and imageio pinned to 2.4.1.
!pip3 install pyvirtualdisplay
!pip3 install -U colabgymrender
!pip3 install imageio==2.4.1

In [None]:
%%capture
# Pin older wheel/setuptools before installing minerl 0.4.4.
# NOTE(review): presumably required so minerl 0.4.4's older dependencies can
# still be built with these tools — confirm the exact reason.
!python -m pip install --upgrade pip wheel==0.38.4 setuptools==65.6.1

In [None]:
# MineRL 0.4.4 provides the `minerl.data` dataset API used in the Datasets section.
# NOTE(review): installing this replaces the v1.0 install in the same
# environment, while the env-initialization cell says it needs v1.0 —
# presumably the two sections are run in separate sessions.
!pip3 install minerl==0.4.4

# Import

In [None]:
import os
import numpy as np
import gym
import minerl
import matplotlib
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from colabgymrender.recorder import Recorder
from pyvirtualdisplay import Display
import logging
# Disable every log record at ERROR severity and below (only CRITICAL passes).
logging.disable(logging.ERROR)
import cv2
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
# Last expression: displays the installed numpy version.
np.__version__

In [None]:
# Start an invisible Xvfb virtual display so rendering has an X server to draw to.
disp = Display(visible=0, backend="xvfb")
# Trailing ';' suppresses the cell's output repr.
disp.start();

# Scenario

In [None]:
from minerl.herobraine.env_specs.human_controls import SimpleHumanEmbodimentEnvSpec
from minerl.herobraine.hero.mc import MS_PER_STEP, STEPS_PER_MS
from minerl.herobraine.hero.handler import Handler
from typing import List

import minerl.herobraine
import minerl.herobraine.hero.handlers as handlers
from minerl.herobraine.env_spec import EnvSpec
from minerl.herobraine.hero.mc import MS_PER_STEP, STEPS_PER_MS, ALL_ITEMS
from minerl.herobraine.hero import handlers as H, mc
from minerl.herobraine.hero.handlers.translation import TranslationHandler

MY_TREECHOP_DOC = """
In treechop, the agent must collect 64 `minercaft:log`. This replicates a common scenario in Minecraft, as logs are necessary to craft a large amount of items in the game, and are a key resource in Minecraft.

The agent begins in a forest biome (near many trees) with an iron axe for cutting trees. The agent is given +1 reward for obtaining each unit of wood, and the episode terminates once the agent obtains 64 units.
"""
# Maximum episode length in env steps; also converted to a server time limit
# (TREECHOP_LENGTH * MS_PER_STEP) in MyTreechop.create_server_quit_producers.
TREECHOP_LENGTH = 8000
# Customized-world generator settings (JSON string) passed to
# DefaultWorldGenerator. Notable entries: "fixedBiome":4 (presumably the forest
# biome the doc string describes — confirm), "seaLevel":1, and caves,
# dungeons, villages, temples, ravines and lakes all disabled.
TREECHOP_WORLD_GENERATOR_OPTIONS = """{"coordinateScale":684.412,"heightScale":684.412,"lowerLimitScale":512.0,"upperLimitScale":512.0,"depthNoiseScaleX":200.0,"depthNoiseScaleZ":200.0,"depthNoiseScaleExponent":0.5,"mainNoiseScaleX":80.0,"mainNoiseScaleY":160.0,"mainNoiseScaleZ":80.0,"baseSize":8.5,"stretchY":12.0,"biomeDepthWeight":1.0,"biomeDepthOffset":0.0,"biomeScaleWeight":1.0,"biomeScaleOffset":0.0,"seaLevel":1,"useCaves":false,"useDungeons":false,"dungeonChance":8,"useStrongholds":false,"useVillages":false,"useMineShafts":false,"useTemples":false,"useMonuments":false,"useMansions":false,"useRavines":false,"useWaterLakes":false,"waterLakeChance":4,"useLavaLakes":false,"lavaLakeChance":80,"useLavaOceans":false,"fixedBiome":4,"biomeSize":4,"riverSize":1,"dirtSize":33,"dirtCount":10,"dirtMinHeight":0,"dirtMaxHeight":256,"gravelSize":33,"gravelCount":8,"gravelMinHeight":0,"gravelMaxHeight":256,"graniteSize":33,"graniteCount":10,"graniteMinHeight":0,"graniteMaxHeight":80,"dioriteSize":33,"dioriteCount":10,"dioriteMinHeight":0,"dioriteMaxHeight":80,"andesiteSize":33,"andesiteCount":10,"andesiteMinHeight":0,"andesiteMaxHeight":80,"coalSize":17,"coalCount":20,"coalMinHeight":0,"coalMaxHeight":128,"ironSize":9,"ironCount":20,"ironMinHeight":0,"ironMaxHeight":64,"goldSize":9,"goldCount":2,"goldMinHeight":0,"goldMaxHeight":32,"redstoneSize":8,"redstoneCount":8,"redstoneMinHeight":0,"redstoneMaxHeight":16,"diamondSize":8,"diamondCount":1,"diamondMinHeight":0,"diamondMaxHeight":16,"lapisSize":7,"lapisCount":1,"lapisCenterHeight":16,"lapisSpread":16}"""


class MyTreechop(SimpleHumanEmbodimentEnvSpec):
    """Custom treechop env spec, registered as 'MyMineRLTreechop-v0' unless a name is given.

    Configures:
      * a 640x360 POV and a rich observation set (equipment, life stats,
        location, selected full stats, flat inventory over ALL_ITEMS);
      * every key in mc.KEYMAP plus the camera as actions;
      * +1 reward per log collected;
      * episode end at 64 logs or after TREECHOP_LENGTH steps' worth of time;
      * 4 oak logs added to the starting inventory.
    """

    def __init__(self, *args, **kwargs):
        # Callers may supply their own env id; otherwise use the notebook default.
        kwargs.setdefault('name', 'MyMineRLTreechop-v0')
        super().__init__(
            *args,
            max_episode_steps=TREECHOP_LENGTH,
            reward_threshold=64.0,
            resolution=[640, 360],
            **kwargs
        )

    def create_observables(self) -> List[Handler]:
        """Base observables followed by equipment, vitals, position, stats and inventory."""
        observables = list(super().create_observables())
        equipment = handlers.EquippedItemObservation(
            items=ALL_ITEMS,
            mainhand=True,
            offhand=True,
            armor=True,
            _default="air",
            _other="air",
        )
        observables.extend([
            equipment,
            handlers.ObservationFromLifeStats(),
            handlers.ObservationFromCurrentLocation(),
        ])
        tracked_stats = (
            "use_item", "drop", "pickup", "break_item", "craft_item",
            "mine_block", "damage_dealt", "entity_killed_by", "kill_entity",
        )
        observables.extend(handlers.ObserveFromFullStats(stat) for stat in tracked_stats)
        observables.append(handlers.FlatInventoryObservation(ALL_ITEMS))
        return observables

    def create_actionables(self) -> List[TranslationHandler]:
        """One key-based command per entry in mc.KEYMAP, followed by the camera action."""
        keyboard = [H.KeybasedCommandAction(key, key) for key in mc.KEYMAP.values()]
        keyboard.append(H.CameraAction())
        return keyboard

    def create_rewardables(self) -> List[Handler]:
        """+1.0 reward for every log collected."""
        per_log = dict(type="log", amount=1, reward=1.0)
        return [handlers.RewardForCollectingItems([per_log])]

    def create_agent_start(self) -> List[Handler]:
        """Base start handlers plus 4 oak logs in the inventory."""
        starting_items = [dict(type="oak_log", quantity=4)]
        return super().create_agent_start() + [handlers.SimpleInventoryAgentStart(starting_items)]

    def create_agent_handlers(self) -> List[Handler]:
        """The agent quits once it holds 64 logs."""
        quit_condition = dict(type="log", amount=64)
        return [handlers.AgentQuitFromPossessingItem([quit_condition])]

    def create_server_world_generators(self) -> List[Handler]:
        """Regenerate the customized world (TREECHOP_WORLD_GENERATOR_OPTIONS) on every reset."""
        generator = handlers.DefaultWorldGenerator(
            force_reset="true",
            generator_options=TREECHOP_WORLD_GENERATOR_OPTIONS,
        )
        return [generator]

    def create_server_quit_producers(self) -> List[Handler]:
        """Stop the server on the time limit or when any agent finishes."""
        time_limit_ms = TREECHOP_LENGTH * MS_PER_STEP
        return [
            handlers.ServerQuitFromTimeUp(time_limit_ms),
            handlers.ServerQuitWhenAnyAgentFinishes(),
        ]

    def create_server_decorators(self) -> List[Handler]:
        """No world decorators."""
        return []

    def create_server_initial_conditions(self) -> List[Handler]:
        """Frozen time of day, mob spawning enabled."""
        frozen_time = handlers.TimeInitialCondition(allow_passage_of_time=False)
        spawning = handlers.SpawningInitialCondition(allow_spawning=True)
        return [frozen_time, spawning]

    def determine_success_from_rewards(self, rewards: list) -> bool:
        """An episode succeeds when its summed reward reaches reward_threshold (64)."""
        total = sum(rewards)
        return total >= self.reward_threshold

    def is_from_folder(self, folder: str) -> bool:
        """Dataset folder name matched by this spec."""
        return 'survivaltreechop' == folder

    def get_docstring(self):
        """Return the human-readable description MY_TREECHOP_DOC."""
        return MY_TREECHOP_DOC

In [None]:
import gym

from minerl.herobraine.env_spec import EnvSpec

# Instantiate the custom spec; its name defaults to 'MyMineRLTreechop-v0'.
MINERL_MY_TREECHOP_V0 = MyTreechop()

# Register every EnvSpec instance found in the notebook namespace with gym,
# skipping ids that are already registered (e.g. when this cell is re-run).
ENVS = list(filter(lambda candidate: isinstance(candidate, EnvSpec), list(locals().values())))
for env in ENVS:
    if env.name in gym.envs.registry.env_specs:
        continue
    env.register()

# Network

In [None]:
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Cnn(nn.Module):
    """Two-headed convolutional network used for navigation.

    The conv trunk (32/64/64 channels, kernels 8/4/3, strides 4/2/1, each
    followed by BatchNorm + ReLU) feeds an MLP down to 512 features. Two
    linear heads then emit 5 camera outputs and 6 button outputs. The same
    network is used as the behaviour-cloning classifier and as the DQN
    policy/target network.

    Args:
        input_shape: (channels, height, width) of the observation tensor. The
            first Linear layer is sized from height/width (1024 for 64x64).
    """

    def __init__(self, input_shape=(3, 64, 64)):
        super().__init__()
        n_input_channels = input_shape[0]
        flat_size = self._conv_flat_size(input_shape)
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
        )

        self.flat_size = flat_size
        self.camera = nn.Linear(512, 5)
        self.buttons = nn.Linear(512, 6)

    @staticmethod
    def _conv_flat_size(input_shape):
        """Flattened feature count after the three unpadded conv layers."""
        if len(input_shape) < 3:
            # Only the channel count was given: keep the original 64x64 sizing.
            return 1024
        height, width = input_shape[1], input_shape[2]
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1
        return 64 * height * width

    def forward(self, observations):
        """Return (camera_output, buttons_output), shaped (batch, 5) and (batch, 6)."""
        cnn_output = self.cnn(observations)

        camera_output = self.camera(cnn_output)
        buttons_output = self.buttons(cnn_output)

        return camera_output, buttons_output

    def initial_hidden_state(self):
        """Zero (h0, c0) pair shaped (2, 1, 512) on `device`.

        NOTE(review): this network has no recurrent layer; the state is only
        consumed by callers that pass it around.
        """
        h0 = torch.zeros(2, 1, 512).to(device)
        c0 = torch.zeros(2, 1, 512).to(device)
        hidden = (h0, c0)
        return hidden

    def preprocess(self, img, img_size=(64, 64)):
        """Resize an HxWx3 frame, scale to [0, 1], return a (1, 3, h, w) float32 tensor on `device`."""
        pov = cv2.resize(img, img_size).astype(np.float32)
        state = pov.transpose(2, 0, 1) / 255.0
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        return state

    def get_result(self, camera_logits, button_logits, t=1.2):
        """Sample (camera_index, button_index) from temperature-`t` softmaxes of batch row 0."""
        camera_probabilities = torch.softmax(camera_logits / t, dim=1)[0].detach().cpu().numpy()
        button_probabilities = torch.softmax(button_logits / t, dim=1)[0].detach().cpu().numpy()

        # Each head must sample from its own distribution: the button choice is
        # 6-way, so feeding it the 5-way camera probabilities raises ValueError.
        camera_action = np.random.choice(len(camera_probabilities), p=camera_probabilities)
        button_action = np.random.choice(len(button_probabilities), p=button_probabilities)

        return camera_action, button_action


In [None]:
class CNN(nn.Module):
    """Single-head convolutional classifier producing `output_dim` logits.

    Uses the same conv trunk layout as `Cnn` (kernels 8/4/3, strides 4/2/1,
    BatchNorm + ReLU) with a smaller MLP: flat -> 512 -> output_dim.

    Args:
        input_shape: (channels, height, width) of the observation tensor; the
            first Linear layer is sized from height/width (1024 for 64x64).
        output_dim: number of output logits.
    """

    def __init__(self, input_shape, output_dim):
        super().__init__()
        n_input_channels = input_shape[0]
        flat_size = self._conv_flat_size(input_shape)
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_size, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    @staticmethod
    def _conv_flat_size(input_shape):
        """Flattened feature count after the three unpadded conv layers."""
        if len(input_shape) < 3:
            # Only the channel count was given: keep the original 64x64 sizing.
            return 1024
        height, width = input_shape[1], input_shape[2]
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1
        return 64 * height * width

    def forward(self, observations):
        """Return logits shaped (batch, output_dim)."""
        return self.cnn(observations)

# Action

In [None]:
def dataset_action_to_agent(dataset_actions, camera_margin=5):
    """Convert a batch of MineRL dataset actions into (camera, button) class labels.

    Camera label (checked in this priority order):
        1 if camera[0] < -margin, 2 if camera[0] > margin,
        3 if camera[1] > margin,  4 if camera[1] < -margin, else 0.
    Button label (priority order): 3 jump, 1 left, 2 right, 4 sprint, else 0.
    Label 0 covers forward/no-op. These indices line up with
    ActionShaping._camera_actions / _button_actions.

    Args:
        dataset_actions: mapping with "camera" (..., 2) and "jump", "left",
            "right", "sprint" arrays. Leading axes (batch, seq_len) are
            flattened together. NOTE(review): callers use seq_len=1.
        camera_margin: dead zone for camera movement.

    Returns:
        (ca, ba): int arrays with one label per sample.
    """
    # reshape instead of squeeze(): squeeze also drops the batch axis when the
    # batch holds a single sample, which broke the per-sample indexing.
    camera = np.asarray(dataset_actions["camera"]).reshape(-1, 2)
    jump = np.asarray(dataset_actions["jump"]).reshape(-1)
    left = np.asarray(dataset_actions["left"]).reshape(-1)
    right = np.asarray(dataset_actions["right"]).reshape(-1)
    sprint = np.asarray(dataset_actions["sprint"]).reshape(-1)

    pitch, yaw = camera[:, 0], camera[:, 1]
    # np.select picks the first matching condition, reproducing the if/elif priority.
    ca = np.select(
        [pitch < -camera_margin, pitch > camera_margin, yaw > camera_margin, yaw < -camera_margin],
        [1, 2, 3, 4],
        default=0,
    ).astype(int)
    ba = np.select(
        [jump == 1, left == 1, right == 1, sprint == 1],
        [3, 1, 2, 4],
        default=0,
    ).astype(int)

    return ca, ba

class ActionShaping():
    """Discrete action adapter and shaped navigation reward for the two-headed `Cnn`.

    Maps a camera index (0-4) and a button index (0-4) onto a MineRL action
    dict built from `env.action_space.noop()`. It also provides a
    distance-based reward toward a fixed (x, z) destination.
    """

    def __init__(self, env, camera_angle=10):
        self.env = env
        self.camera_angle = camera_angle
        # Order must match the button labels produced by dataset_action_to_agent.
        self._button_actions = [
            dict(forward=1),
            dict(left=1),
            dict(right=1),
            dict(forward=1, jump=1),
            dict(forward=1, sprint=1),
        ]
        self._camera_actions = [
            dict(),
            dict(camera=[-self.camera_angle, 0]),
            dict(camera=[self.camera_angle, 0]),
            dict(camera=[0, self.camera_angle]),
            dict(camera=[0, -self.camera_angle]),
        ]

        self.min_distance = None
        self.last_distance = None

    @staticmethod
    def _probabilities(logits, t, n):
        """Temperature softmax over the first `n` logits of batch row 0, as float64 summing to 1."""
        # The Cnn button head has 6 outputs but there are only 5 button actions
        # (label 5 is never produced in training), so trailing logits are dropped.
        scaled = logits[:, :n] / t
        probs = torch.softmax(scaled, dim=1)[0].detach().cpu().numpy().astype(np.float64)
        return probs / probs.sum()

    def get_action(self, ca, ba, t=1.2):
        """Sample a combined camera + button action.

        Args:
            ca: camera logits, shape (1, >=5).
            ba: button logits, shape (1, >=5).
            t: softmax temperature.

        Returns:
            (act, ci, bi): MineRL action dict, camera index, button index.
        """
        ca_probabilities = self._probabilities(ca, t, len(self._camera_actions))
        ba_probabilities = self._probabilities(ba, t, len(self._button_actions))
        # Each head samples from its own distribution.
        ci = int(np.random.choice(len(self._camera_actions), p=ca_probabilities))
        bi = int(np.random.choice(len(self._button_actions), p=ba_probabilities))

        act = self.env.action_space.noop()
        act.update(self._camera_actions[ci])
        act.update(self._button_actions[bi])

        return act, ci, bi

    def get_reward(self, obs, destination=(-205, 258)):
        """Shaped reward based on the planar (x, z) distance to `destination`.

        Returns 0 on the first call after a reset (initializes tracking).
        Returns 100 within 10 blocks of the goal, and clears tracking.
        Returns 5 when the best distance so far improves by more than 0.1.
        Returns -0.5 when moving more than 0.1 further away than the previous
        step. Otherwise returns 0.
        """
        location = np.array([obs['location_stats']['xpos'].item(), obs['location_stats']['zpos'].item()])
        target = np.array(destination)
        distance = np.sqrt(np.sum((location - target) ** 2))

        if self.min_distance is None:
            self.min_distance = distance
            self.last_distance = distance
            return 0

        # Check the goal before the "improved" branch; otherwise entering the
        # goal radius while still approaching only earned +5.
        if distance < 10:
            self.min_distance = None
            self.last_distance = None
            return 100

        if self.min_distance - distance > 0.1:
            self.min_distance = distance
            self.last_distance = distance
            return 5

        if distance - self.last_distance > 0.1:
            self.last_distance = distance
            return -0.5

        self.last_distance = distance
        return 0

    def reset(self, seed=10):
        """Seed and reset the env, clear distance tracking, and return the first observation."""
        self.env.seed(seed)
        self.min_distance = None
        self.last_distance = None
        obs = self.env.reset()
        return obs

    def step(self, action):
        """Forward `action` to the wrapped env and return its (obs, reward, done, info)."""
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info

    def action(self, ac, t=1.2):
        """Sample a button-only action dict from logits `ac`, shaped (1, >=5).

        NOTE(review): samples from the button action set — confirm this is the
        intended action list for callers of this helper.
        """
        probabilities = self._probabilities(ac, t, len(self._button_actions))
        index = int(np.random.choice(len(self._button_actions), p=probabilities))
        act = self.env.action_space.noop()
        act.update(self._button_actions[index])
        return act


In [None]:
class ActionShaping_2(gym.ActionWrapper):
    """Discrete wrapper that exposes seven fixed MineRL actions as `Discrete(7)`.

    Indices: 0 attack, 1 forward, 2 jump, 3-6 camera turns of +/-camera_angle
    on either camera axis. Every prebuilt action also holds `attack` down.
    """

    def __init__(self, env, camera_angle=10):
        super().__init__(env)
        self.camera_angle = camera_angle
        turn = self.camera_angle
        self._actions = [
            [('attack', 1)],
            [('forward', 1)],
            [('jump', 1)],
            [('camera', [-turn, 0])],
            [('camera', [turn, 0])],
            [('camera', [0, turn])],
            [('camera', [0, -turn])],
        ]

        def build(pairs):
            # Start from a no-op, overlay this action's keys, and keep attack pressed.
            composed = self.env.action_space.noop()
            for key, value in pairs:
                composed[key] = value
                composed['attack'] = 1
            return composed

        self.actions = [build(pairs) for pairs in self._actions]
        self.action_space = gym.spaces.Discrete(len(self.actions))

    def action(self, action):
        """Translate a discrete index into its prebuilt MineRL action dict."""
        return self.actions[action]

# Datasets

In [None]:
# Get data, find datasets in https://minerl.readthedocs.io/en/v0.4.4/environments/index.html.
# Download the MineRLNavigate-v0 demonstration dataset into ./data, then open a
# loader over it with 2 workers (requires the minerl 0.4.4 install).
minerl.data.download(directory='data', environment='MineRLNavigate-v0')
data = minerl.data.make("MineRLNavigate-v0", data_dir='data', num_workers=2)

# Behavior cloning

In [None]:
# Behaviour-cloning model: the two-headed Cnn, trained with one cross-entropy
# term per head (camera, buttons) using Adam at lr 1e-4.
model = Cnn().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Training loop: behaviour cloning on the MineRLNavigate-v0 demonstrations.
model.train()
step = 0
losses = []
for state, action, _, _, _ \
          in tqdm(data.batch_iter(num_epochs=6, batch_size=32, seq_len=1)):
    # Get pov observations. With seq_len=1 the frames presumably arrive as
    # (batch, 1, H, W, 3); collapse every axis before (H, W, 3) rather than
    # squeeze(), which would also drop the batch axis for a batch of one.
    pov = np.asarray(state['pov'])
    obs = pov.reshape(-1, *pov.shape[-3:]).astype(np.float32)
    # Transpose to (batch, C, H, W) and normalize to [0, 1]
    obs = obs.transpose(0, 3, 1, 2) / 255.0

    # Translate batch of actions into (camera, button) class labels
    ca, ba = dataset_action_to_agent(action)
    # Remove samples with no corresponding action (label -1)
    mask = ca != -1
    obs = obs[mask]
    ca = ca[mask]
    ba = ba[mask]

    # Update weights with backprop: one cross-entropy term per head
    camera_output, buttons_output = model(torch.from_numpy(obs).float().to(device))
    camera_loss = criterion(camera_output, torch.from_numpy(ca).long().to(device))
    buttons_loss = criterion(buttons_output, torch.from_numpy(ba).long().to(device))
    total_loss = camera_loss + buttons_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Report the mean loss every 2000 steps
    step += 1
    losses.append(total_loss.item())
    if (step % 2000) == 0:
        mean_loss = sum(losses) / len(losses)
        tqdm.write(f'Step {step:>5} | Training loss = {mean_loss:.3f}')
        losses.clear()

torch.save(model.state_dict(), 'navigation_1.pth')
del data

# Expanded Knowledge Distillation

In [None]:
# VPT
%%capture
!pip3 install gym3
!git clone https://github.com/openai/video-pre-training
%cd video-pre-training

In [None]:
%%capture
weights_file = "https://openaipublic.blob.core.windows.net/minecraft-rl/models/rl-from-early-game-2x.weights"  #@param {type: "string", allow-input:true} ["https://openaipublic.blob.core.windows.net/minecraft-rl/models/foundation-model-1x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/foundation-model-2x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/foundation-model-3x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/bc-house-3x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/bc-early-game-2x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/bc-early-game-3x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/rl-from-foundation-2x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/rl-from-house-2x.weights", "https://openaipublic.blob.core.windows.net/minecraft-rl/models/rl-from-early-game-2x.weights"]
# The architecture file is named after the width multiplier (1x/2x/3x) in the checkpoint name.
multiplier = next((x for x in ["1x", "2x", "3x"] if x in weights_file), None)
if multiplier is None:
    raise ValueError(f"Cannot infer the model size (1x/2x/3x) from {weights_file!r}")
# "model" holds the pickled architecture/config, "weights" the checkpoint.
!wget https://openaipublic.blob.core.windows.net/minecraft-rl/models/{multiplier}.model -O model
!wget {weights_file} -O weights

In [None]:
# Absolute path so the cd also works when the previous cell already entered
# the directory (a relative `%cd video-pre-training` then fails).
%cd /content/video-pre-training

from agent import MineRLAgent
import pickle
import torch
from lib.policy import MinecraftAgentPolicy
from lib.tree_util import tree_map
from lib.action_mapping import CameraHierarchicalMapping
from gym3.types import DictType

class VPTAgent(MineRLAgent):
    """VPT `MineRLAgent` extended for teacher/student experiments.

    Adds:
      * `get_action`, which also returns the `result` dict of
        `policy.act(..., return_pd=True)`;
      * pov preprocessing for a small student network;
      * helpers that turn student outputs into MineRL actions via the base
        class's `_agent_action_to_env`.
    """

    def __init__(self, env, device=None, policy_kwargs=None, pi_head_kwargs=None) -> None:
        # Forward the caller's configuration. Passing literal None for all
        # three would discard the policy/pi-head kwargs loaded from the model
        # file and the requested device.
        super().__init__(env, device=device, policy_kwargs=policy_kwargs,
                         pi_head_kwargs=pi_head_kwargs)

    def get_action(self, minerl_obs):
        """Sample one step from the VPT policy (stochastic) and advance its hidden state.

        Returns:
            (minerl_action, agent_action, result): the env-space action, the
            raw agent-space action, and the `result` dict from `policy.act`,
            which carries the policy distribution under 'pd'.
        """
        agent_input = self._env_obs_to_agent(minerl_obs)

        agent_action, self.hidden_state, result = self.policy.act(
            agent_input, self._dummy_first, self.hidden_state,
            stochastic=True, return_pd=True
        )
        minerl_action = self._agent_action_to_env(agent_action)

        return minerl_action, agent_action, result

    def preprocess(self, obs, shape=(64, 64)):
        """Resize obs['pov'] to `shape`, scale to [0, 1], and return (1, 3, h, w) float32 on `device`."""
        pov = obs['pov']
        pov = cv2.resize(pov, shape).astype(np.float32)
        state = pov.transpose(2, 0, 1) / 255.0
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        return state

    @staticmethod
    def net_action(state, model):
        """Greedy (argmax) camera and button indices from `model(state)`, each shaped (1, 1).

        Declared static because it never used `self`; as a plain method the
        agent instance would have been bound to `state`.
        """
        with torch.no_grad():
            # Run the model once and read both heads from the same output.
            outputs = model(state)
            camera = outputs[0].max(1)[1].view(1, 1)
            buttons = outputs[1].max(1)[1].view(1, 1)
        return camera, buttons

    def get_agent_action(self, obs, model, hidden, deterministic=False, preprocess_shape=(64, 64)):
        """Convert a student model's outputs into a MineRL action.

        The camera is chosen greedily. Buttons are sampled with the Gumbel
        trick, or by argmax when `deterministic` is set.
        NOTE(review): assumes `model(state, hidden)` returns
        (camera, buttons, new_hidden) with logits in VPT's agent action space —
        the `Cnn` in this notebook takes no hidden state; confirm the model.

        Returns:
            (minerl_action, hidden_new)
        """
        state = self.preprocess(obs, preprocess_shape)
        camera_output, buttons_output, hidden_new = model(state, hidden)
        camera_output = camera_output.max(1)[1].view(1, 1)
        buttons_output = self.sample(buttons_output.unsqueeze(0), deterministic)
        actions = {'buttons': buttons_output, 'camera': camera_output}
        minerl_action = super()._agent_action_to_env(actions)

        return minerl_action, hidden_new

    def get_agent_output(self, obs, model, hidden):
        """Return the student's raw outputs as {'buttons', 'camera'} plus its new hidden state.

        NOTE(review): same (camera, buttons, new_hidden) model contract as
        `get_agent_action` — confirm against the student model used.
        """
        state = self.preprocess(obs)
        camera_output, buttons_output, hidden_new = model(state, hidden)
        actions = {'buttons': buttons_output, 'camera': camera_output}

        return actions, hidden_new

    def sample(self, logits: torch.Tensor, deterministic: bool = False):
        """Categorical sample over the last axis via the Gumbel-max trick (argmax if deterministic)."""
        if deterministic:
            return torch.argmax(logits, dim=-1)
        else:
            # Gumbel-Softmax trick.
            u = torch.rand_like(logits)
            # In float16, if you have around 2^{float_mantissa_bits} logits, sometimes you'll sample 1.0
            # Then the log(-log(1.0)) will give -inf when it should give +inf
            # This is a silly hack to get around that.
            # This hack does not skew the probability distribution, because this event can't possibly win the argmax.
            u[u == 1.0] = 0.999

            return torch.argmax(logits - torch.log(-torch.log(u)), dim=-1)

    def gaussian_sample(self, pd_params: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """Sample from N(mean, exp(log_std)^2), with params packed on the last axis as (mean, log_std)."""
        means = pd_params[..., 0]
        log_std = pd_params[..., 1]

        if deterministic:
            return means
        else:
            return torch.randn_like(means) * torch.exp(log_std) + means


[Errno 2] No such file or directory: 'video-pre-training'
/content/video-pre-training


In [None]:
# Read the VPT architecture/config downloaded as "model".
# SECURITY: pickle.load can execute arbitrary code — only load model files
# from a trusted source.
with open("model", "rb") as model_file:
    agent_parameters = pickle.load(model_file)
policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])

In [None]:
# Create the custom treechop env (needs the MineRL v1.0 install), seed it,
# reset, and display the first POV frame.
env = gym.make("MyMineRLTreechop-v0")
env.seed(10)
obs = env.reset()
plt.imshow(obs["pov"])

In [None]:
from IPython.display import clear_output
from matplotlib import pyplot as plt
from tqdm.auto import trange

# Roll out the VPT agent until 12 episodes end with a wooden pickaxe in the
# main hand. Each successful episode's POV frames and policy distributions
# are recorded.
wooden_pickaxe_dataset = []
play_steps = 1000
live_display = False
interrupted = False

while len(wooden_pickaxe_dataset) < 12 and not interrupted:
    print(len(wooden_pickaxe_dataset))
    # Close the env left by the previous iteration/cell before creating a new
    # one so environment instances do not accumulate. (`env` may also be an
    # EnvSpec leaked by the registration loop, which has no close().)
    previous_env = globals().get('env')
    if hasattr(previous_env, 'close'):
        previous_env.close()
    env = gym.make('MyMineRLTreechop-v0')
    env.seed(10)
    obs = env.reset()
    agent = VPTAgent(env, policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs, device="cuda" if torch.cuda.is_available() else "cpu")
    agent.load_weights("weights")
    record = []
    actions = []
    try:
        for _ in trange(play_steps):
            minerl_action, agent_action, result = agent.get_action(obs)
            obs, reward, done, info = env.step(minerl_action)

            record.append(obs['pov'])
            actions.append(result['pd'])

            if obs['equipped_items']['mainhand']['type'] == 'wooden_pickaxe':
                print('Agent got the wooden_pickaxe!')
                wooden_pickaxe_dataset.append({'pov': record, 'pd': actions})
                done = True

            if done:
                break

            if live_display:
                clear_output(wait=True)
                plt.axis("off")
                plt.imshow(obs["pov"])
                plt.show()
    except KeyboardInterrupt:
        # Stop collecting altogether. Swallowing the interrupt inside the
        # while loop would just start another episode, so the cell could not
        # be stopped before 12 successes.
        interrupted = True

In [None]:
# Save POVs and actions as .npz file.
# Entries are dicts {'pov': [...], 'pd': [...]}, so store an explicit 1-D
# object array (pickled inside the .npz; loading needs allow_pickle=True).
dataset_array = np.array(wooden_pickaxe_dataset, dtype=object)
np.savez('dataset.npz', dataset=dataset_array)

In [None]:
# Mount Google Drive at /content/drive (Colab only).
# NOTE(review): no later cell reads from /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def load_dataset(path, index):
    """Load one recorded episode from an .npz written by the dataset-saving cell.

    Args:
        path: path to the .npz file containing a `dataset` object array.
        index: index of the episode in that array.

    Returns:
        (pov, actions): the episode's 'pov' frames and 'pd' policy distributions.
    """
    # The array holds Python dicts, so it is pickled inside the archive and
    # needs allow_pickle=True. Only load trusted files: unpickling can execute
    # arbitrary code. The context manager closes the underlying zip file.
    with np.load(path, allow_pickle=True) as data:
        episode = data['dataset'][index]

    return episode['pov'], episode['pd']

In [None]:
# Online policy distillation: the VPT agent (teacher) acts in the env while the
# small Cnn (student) is trained toward the teacher's policy distribution.
student_model = Cnn().to(device)
from IPython.display import clear_output
from matplotlib import pyplot as plt
from tqdm.auto import trange
import torch.optim.lr_scheduler as lr_scheduler  # NOTE(review): unused in this cell


optimizer = torch.optim.Adam(student_model.parameters(), lr=0.0001)
losses = []

record= []  # NOTE(review): never appended to (the append below is commented out)
play_steps = 1000
episodes = 10

try:
    for i in range(episodes):
        # Every episode restarts from the same seed.
        env.seed(10)
        obs = env.reset()
        done = False
        # NOTE(review): the teacher is rebuilt and its weights reloaded every episode.
        teacher_model = VPTAgent(env, policy_kwargs=policy_kwargs, pi_head_kwargs=pi_head_kwargs, device="cuda" if torch.cuda.is_available() else "cpu")
        teacher_model.load_weights("weights")

        # Zero (h, c) pair handed to the student on the first step.
        h0 = torch.zeros(2, 1, 512).to(device)
        c0 = torch.zeros(2, 1, 512).to(device)
        hidden = (h0, c0)

        print('Env reset done. Episode: ' + str(i))

        for _ in trange(play_steps):  # The t part will get erased anyway
            # Teacher picks the action actually executed and exposes its policy distribution.
            minerl_action, agent_action, result = teacher_model.get_action(obs)

            # FIXME(review): get_agent_output calls model(state, hidden) and
            # unpacks three values, but Cnn.forward takes only observations and
            # returns (camera, buttons) — this raises TypeError as written.
            actions, hidden_new = teacher_model.get_agent_output(obs, student_model, hidden)
            # Detach so backprop does not reach into earlier steps' graphs.
            hidden_new = (hidden_new[0].detach(), hidden_new[1].detach())
            hidden = hidden_new

            # Update weights with backprop
            # FIXME(review): policy_distillation_loss is defined in a later
            # cell — run that cell first on a fresh kernel. The teacher's 'pd'
            # heads presumably have far more classes than the student's 6/5
            # outputs; confirm the shapes match before this KL is meaningful.
            loss_buttons = policy_distillation_loss(result['pd']['buttons'].squeeze(0), actions['buttons'], temperature=1.5)
            loss_camera = policy_distillation_loss(result['pd']['camera'].squeeze(0), actions['camera'], temperature=1.5)

            total_loss = loss_buttons + loss_camera
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            losses.append(total_loss.item())

            # Next step
            obs, reward, done, info = env.step(minerl_action)
            #record.append(obs)

            if obs['equipped_items']['mainhand']['type'] == 'wooden_pickaxe':
                print('Agent got the wooden_pickaxe!')
                done = True

            if done:
                break

        # Per-episode mean distillation loss.
        mean_loss = sum(losses) / len(losses)
        tqdm.write(f'Training loss = {mean_loss:.3f}')
        losses.clear()

# Ctrl-C / interrupt stops training early while keeping the student's weights.
except KeyboardInterrupt:
    pass

In [None]:
import torch.nn.functional as F

def action_step(action):
    """Apply `action` on top of a no-op to the global `env`, then plot the resulting POV."""
    merged = env.action_space.noop()
    merged.update(action)
    next_obs, _reward, _done, _info = env.step(merged)
    plt.imshow(next_obs["pov"])
    plt.show()

def policy_distillation_loss(outputs_teacher, outputs_student, temperature):
    """Temperature-scaled KL(teacher || student) knowledge-distillation loss.

    Both arguments are logits with classes on dim 1. The KL is averaged over
    the batch ('batchmean') and multiplied by temperature**2, the usual
    Hinton et al. scaling.
    """
    student_log_probs = F.log_softmax(outputs_student / temperature, dim=1)
    teacher_probs = F.softmax(outputs_teacher / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return divergence * temperature * temperature

# Fine-tuning

In [None]:
# NOTE(review): the behaviour-cloning cell saves 'navigation_1.pth'; this
# expects a 'navigation_2.pth' produced elsewhere — confirm the checkpoint.
POLICY_CHECKPOINT = '/content/navigation_2.pth'

policy_net = Cnn().to(device)
# map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
policy_net.load_state_dict(torch.load(POLICY_CHECKPOINT, map_location=device))
policy_net.train()

target_net = Cnn().to(device)
target_net.load_state_dict(policy_net.state_dict())
# The target net only supplies bootstrap values, and its weights/buffers are
# soft-copied from policy_net. Eval mode stops its BatchNorm layers from using
# and overwriting per-batch statistics during target computation.
target_net.eval();

In [None]:
from collections import namedtuple, deque
import random

# One stored environment step; next_state is None for terminal transitions.
Transition = namedtuple('Transition',
                        'state camera_action button_action next_state reward')


class ReplayMemory:
    """Fixed-capacity FIFO buffer of Transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A bounded deque evicts the oldest transition once full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a Transition built from (state, camera_action, button_action, next_state, reward)."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [None]:
BATCH_SIZE = 64
GAMMA = 0.99
LR = 1e-4

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


def optimize_model():
    """Run one DQN-style update of `policy_net` on a random replay minibatch.

    The camera and button heads are treated as independent Q-functions that
    share the same scalar reward. Bootstrap targets come from `target_net`,
    and terminal transitions (next_state None) bootstrap from 0.

    Returns:
        float: summed Huber loss of both heads, or 0 while the replay memory
        holds fewer than BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return 0

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    # (B,) -> (B, 1) so gather(1, ...) picks one action per row. A (1, B)
    # index would read every action from row 0 only and then broadcast to a
    # (B, B) loss against the (B, 1) targets.
    camera_action_batch = torch.cat(batch.camera_action).unsqueeze(1)
    button_action_batch = torch.cat(batch.button_action).unsqueeze(1)
    # Rewards may have been pushed as int or float tensors; compute in float.
    reward_batch = torch.cat([r.float() for r in batch.reward])

    # Q values of the taken actions for both heads from the policy net
    camera_q, buttons_q = policy_net(state_batch)
    camera_state_action_values = camera_q.gather(1, camera_action_batch)
    buttons_state_action_values = buttons_q.gather(1, button_action_batch)

    # Max next-state Q values from the target net (0 for terminal states)
    camera_next_state_values = torch.zeros(BATCH_SIZE, device=device)
    buttons_next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():  # torch.cat([]) fails if every sample is terminal
        with torch.no_grad():
            non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
            camera_next_q, buttons_next_q = target_net(non_final_next_states)
            camera_next_state_values[non_final_mask] = camera_next_q.max(1)[0]
            buttons_next_state_values[non_final_mask] = buttons_next_q.max(1)[0]

    # Expected Q values
    expected_camera_state_action_values = (camera_next_state_values * GAMMA) + reward_batch
    expected_buttons_state_action_values = (buttons_next_state_values * GAMMA) + reward_batch

    # Huber loss per head, summed
    criterion = nn.SmoothL1Loss()
    camera_loss = criterion(camera_state_action_values, expected_camera_state_action_values.unsqueeze(1))
    buttons_loss = criterion(buttons_state_action_values, expected_buttons_state_action_values.unsqueeze(1))
    loss = camera_loss + buttons_loss

    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss.item()

In [None]:
from itertools import count

TAU = 0.005

# Raw POV frames seen during fine-tuning (kept for inspection only).
replay = deque(maxlen=10000)

if torch.cuda.is_available():
    num_episodes = 10
else:
    num_episodes = 2

steps = 0
total_rewards = [0]  # cumulative shaped reward over all steps
play_steps = 6000
T = 100  # sampling temperature, annealed toward 1.2
losses = []

# Close the env created by an earlier cell before launching a new one.
# (`env` may also be an EnvSpec leaked by the registration loop, which has no close().)
previous_env = globals().get('env')
if hasattr(previous_env, 'close'):
    previous_env.close()
env = gym.make('MyMineRLTreechop-v0')
env.seed(10)
env_action = ActionShaping(env)

for i_episode in range(num_episodes):
    print('Episode: ' + str(i_episode), end='')
    # Initialize the environment and get its state
    obs = env_action.reset()
    state = policy_net.preprocess(obs["pov"])
    print(' reset')
    for _ in range(play_steps):
        # Acting needs no autograd graph; gradients come from optimize_model().
        with torch.no_grad():
            camera_output, buttons_output = policy_net(state)
        minerl_action, ci, bi = env_action.get_action(camera_output, buttons_output, t=T)
        obs, reward, done, info = env_action.step(minerl_action)
        steps += 1

        if T > 1.2:
            T = max(T * 0.9995, 1.2)

        # Replace the env reward with the distance-shaped navigation reward.
        reward = env_action.get_reward(obs)
        total_rewards.append(reward + total_rewards[-1])
        # get_reward returns 100 only when the destination is reached.
        if reward == 100:
            done = True
        # float32 keeps the replay batches homogeneous (rewards are int or float).
        reward = torch.tensor([reward], device=device, dtype=torch.float32)

        if done:
            next_state = None
        else:
            replay.append(obs["pov"])
            next_state = policy_net.preprocess(obs["pov"])

        # Store the transition in memory
        camera_action = torch.tensor([ci], device=device)
        button_action = torch.tensor([bi], device=device)
        memory.push(state, camera_action, button_action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        loss = optimize_model()
        losses.append(loss)

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            break

        if steps % 1000 == 0 and losses:
            mean_loss = sum(losses) / len(losses)
            print('Step: ' + str(steps) + ' | Training loss = '+ str(mean_loss))
            losses.clear()

    # None after the goal was reached (get_reward clears tracking).
    print('Best distance: ' + str(env_action.min_distance))

print('Complete')

In [None]:
# Cumulative shaped reward across every fine-tuning step (all episodes).
fig, ax = plt.subplots()
ax.plot(total_rewards)
ax.set(title='Cumulative shaped reward during fine-tuning',
       xlabel='Environment step (all episodes)', ylabel='Cumulative reward');

In [None]:
# Persist the fine-tuned policy weights.
torch.save(obj=policy_net.state_dict(), f='navigation_3.pth')

# Evaluation

In [None]:
rewards = [0]
env_action = ActionShaping_2(env)
action_list = np.arange(env_action.action_space.n)
record = []
log_number = 4  # the agent starts with 4 oak logs (MyTreechop.create_agent_start)
t = 1.2  # sampling temperature

# FIXME(review): this expects `model` to return one (1, n_actions) logit tensor
# matching ActionShaping_2's 7 discrete actions (e.g. CNN((3, 64, 64), 7)).
# The behaviour-cloning `Cnn` returns a (camera, buttons) tuple instead.
model.eval()
# Start from a fresh episode instead of whatever `obs` an earlier cell left behind.
obs = env_action.reset()

for step in tqdm(range(3000)):
    # Get input in the correct format: (1, 3, 64, 64) float in [0, 1] on `device`
    pov_tensor = torch.from_numpy(
        cv2.resize(obs['pov'], (64, 64)).transpose(2, 0, 1)[None].astype(np.float32) / 255
    ).to(device)
    # Turn logits into probabilities
    with torch.no_grad():
        logits = model(pov_tensor)
    probabilities = torch.softmax(logits / t, dim=1)[0].cpu().numpy()
    # Sample action according to the probabilities
    action = np.random.choice(action_list, p=probabilities)

    obs, reward, done, _ = env_action.step(action)
    record.append(obs['pov'])

    # +5 whenever the oak-log count rises above the best seen so far.
    oak_logs = obs['inventory']['oak_log'].item()
    if oak_logs > log_number:
        log_number = oak_logs
        print(str(oak_logs) + ' oak logs')
        rewards.append(rewards[-1] + 5)
    else:
        rewards.append(rewards[-1] + 0)

    # Stepping a finished episode is invalid; stop here.
    if done:
        break

In [None]:
import os
import cv2

# Save the task process as a video
file_path = 'saveVideo.mp4'
size = (320, 180)
fps = 30

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
videoWriter = cv2.VideoWriter(file_path, fourcc, fps, size)
try:
    for item in record:
        # POV frames are shown directly with plt.imshow (RGB order); OpenCV
        # writes BGR, so reverse the channel order.
        img = cv2.cvtColor(item, cv2.COLOR_RGB2BGR)
        img = cv2.resize(img, size)
        videoWriter.write(img)
finally:
    # Always finalize the file, even if a frame fails to convert.
    videoWriter.release()