In [1]:
from __future__ import division
import argparse
import os
import sys
import numpy as np
import torch

from agent import Agent
from minecraft import DummyMinecraft, Env, test_policy
from dataset import Dataset, Transition

import pickle
import time
from os.path import join as p_join
from os.path import exists as p_exists

from data_manager import StateManager, ActionManager

from get_dataset import put_data_into_dataset

import minerl
import gym



In [2]:
try:
    from torch.utils.tensorboard import SummaryWriter
except ModuleNotFoundError:
    from tensorboardX import SummaryWriter

In [3]:
def str2bool(v):
    """Parse a human-friendly boolean string, case-insensitively.

    'yes', 'true', 't', 'y' and '1' map to True; 'no', 'false', 'f', 'n' and
    '0' map to False. Any other value raises argparse.ArgumentTypeError, so the
    function can be used directly as an argparse ``type=`` callable.
    """
    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

In [4]:
# ---- Paths -------------------------------------------------------------------
# Defaults are the original machine-specific locations. Set the environment
# variables to run the notebook elsewhere without editing it.
# OUTPUT_DIR receives TensorBoard logs, status.txt and model snapshots.
OUTPUT_DIR = os.environ.get(
    'MINERL_OUTPUT_DIR', '/home/ankitagarg/minerl/minerl_imitation_learning/output/')
# Root directory of the MineRL human demonstration data.
DATASET_DIR = os.environ.get('MINERL_DATA_DIR', '/home/ankitagarg/minerl/data/')

# ---- Runtime -----------------------------------------------------------------
enable_cudnn = True   # assigned to torch.backends.cudnn.enabled
train = True          # set to False by the resume-detection cell when model_last.pth exists
c_action_magnitude = 22.5  # magnitude of discretized camera action (passed to ActionManager)
seed = 123            # master seed; torch CPU/CUDA seeds are drawn from numpy after seeding it
scale_rewards = True  # forwarded to Dataset(scale_rewards=...)

# ---- Optimisation (forwarded to Agent) ---------------------------------------
learning_rate = 0.0000625
adam_eps = 1.5e-4

batch_size = 32

# ---- Network (forwarded to Agent) --------------------------------------------
net = 'deep_resnet'
hidden_size = 1024
dueling = True

# ---- Dataset -----------------------------------------------------------------
# Path of a previously saved dataset to load instead of rebuilding from DATASET_DIR.
dataset_path = None

trainsteps = 3000000  # total number of learn() steps
augment_flip = True   # forwarded to Agent

dataset_only_successful = False       # presumably: skip trajectories that never reach the final reward
dataset_use_max_duration_steps = True  # cap trajectory length (6000 IronPickaxe / 18000 Diamond steps)
dataset_continuous_action_stacking = 3  # consecutive states whose camera actions are summed
dataset_max_reward = 256              # presumably: drop trajectory parts beyond this reward

# Where a freshly built dataset is written; None skips saving.
save_dataset_path = os.environ.get(
    'MINERL_SAVED_DATASET', '/home/ankitagarg/minerl/minerl_imitation_learning/data/saved_dataset')
quit_after_saving_dataset = False  # intended to stop the run right after a new dataset is saved

add_treechop_data = False  # also add MineRLTreechop-v0 demonstrations

# ---- Run control -------------------------------------------------------------
stop_time = 1   # wall-clock budget in hours; None disables the time limit
test = False    # debug mode: mini dataset and only 10 train steps

In [5]:
# Inspect artifacts left in OUTPUT_DIR by earlier runs:
# - model_last.pth means training already completed, so `train` is cleared;
# - tmp_time.p is written when a run hits its time limit, so training resumes from it.
if p_exists(p_join(OUTPUT_DIR, 'model_last.pth')):
    print("Training already finished")
    train = False

continue_from_tmp = p_exists(p_join(OUTPUT_DIR, "tmp_time.p"))
if continue_from_tmp:
    print("Detected tmp snapshot, will continue training from there")

In [6]:
# Setup: seed the RNGs, require a GPU, and build the state/action managers.
np.random.seed(seed)
# The torch CPU and CUDA seeds are drawn from the freshly seeded numpy stream,
# so both are derived from `seed`.
torch.manual_seed(np.random.randint(1, 10000))

# Explicit check instead of `assert`: asserts are stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required but torch.cuda.is_available() returned False")
torch.cuda.manual_seed(np.random.randint(1, 10000))
torch.backends.cudnn.enabled = enable_cudnn
device = torch.device('cuda')

print(f"Running on {device}")

state_manager = StateManager(device)
action_manager = ActionManager(device, c_action_magnitude)

Running on cuda


In [7]:
# TensorBoard writer logging into OUTPUT_DIR.
writer = SummaryWriter(OUTPUT_DIR)

# Record that the run has started in OUTPUT_DIR/status.txt.
_status_path = p_join(OUTPUT_DIR, "status.txt")
with open(_status_path, 'w') as status_file:
    status_file.write('running')

# Extended exception hook.
# On an uncaught exception it writes 'error' to OUTPUT_DIR/status.txt, flushes
# and closes the TensorBoard writer, and closes the environment (if it has been
# created yet). It then hands the exception to Python's default hook so the
# traceback is still printed.
# NOTE: IPython/Jupyter kernels do not route cell exceptions through
# sys.excepthook. This hook therefore only fires when the code runs as a plain
# Python script (e.g. exported with `jupyter nbconvert --to script`).
def handle_exception(exc_type, exc_value, exc_traceback):
    """Best-effort cleanup for an uncaught exception, then defer to the default hook.

    Args:
        exc_type: the exception class.
        exc_value: the exception instance.
        exc_traceback: the traceback object.

    Each cleanup step runs even if an earlier one fails, and
    ``sys.__excepthook__`` is always called. A cleanup error (e.g. `env` not
    created yet) therefore never hides the original traceback.
    """

    def _mark_failed():
        with open(p_join(OUTPUT_DIR, "status.txt"), 'w') as status_file_:
            status_file_.write('error')

    def _close_writer():
        writer.flush()
        writer.close()

    def _close_env():
        # `env` is created in a later cell, so it may not exist yet.
        env_obj = globals().get('env')
        if env_obj is not None:
            env_obj.close()

    try:
        for cleanup in (_mark_failed, _close_writer, _close_env):
            try:
                cleanup()
            except Exception as cleanup_exc:
                print(f"handle_exception: {cleanup.__name__} failed: {cleanup_exc!r}",
                      file=sys.stderr)
    finally:
        sys.__excepthook__(exc_type, exc_value, exc_traceback)


sys.excepthook = handle_exception

In [8]:
# Create the (dummy) Minecraft environment, seed it, and wrap it with the
# state/action managers. reset() yields the first (image, vector) observation.
env_ = DummyMinecraft()
env_.seed(seed)
env = Env(env_, state_manager, action_manager)
print("started env")

reset_obs = env.reset()
img, vec = reset_obs
print("env reset")

print("img, vec shapes: ", *(obs_part.shape for obs_part in (img, vec)))



started env
env reset
img, vec shapes:  torch.Size([1, 3, 64, 64]) torch.Size([1, 216])


  img_torch = torch.tensor(img_list, dtype=torch.float32, device=self.device).div_(255).permute(0, 3, 1, 2)


In [9]:
# Network input/output sizes, read off the first observation (batch dim dropped).
num_actions = action_manager.num_action_ids_list[0]

image_channels = img.shape[1]
img_shape = [int(img.shape[1]), *img.shape[2:]]  # list form, channel count first

vec_shape = vec.shape[1:]
vec_size = vec.shape[1]

In [10]:
# Build the expert-transition dataset: load a saved one if `dataset_path` is
# set, otherwise extract it from the MineRL human data in DATASET_DIR.
# NOTE(review): 2000000 is presumably the dataset capacity -- confirm in Dataset.
dataset = Dataset(device, 2000000, img_shape, vec_shape,
                  state_manager, action_manager,
                  scale_rewards=scale_rewards)

if dataset_path is not None:  # default None
    print(f"loading dataset {dataset_path}")
    dataset.load(dataset_path)
    print("loaded dataset")

else:  # creating dataset:
    # Explicit check instead of `assert` (asserts vanish under `python -O`).
    if DATASET_DIR is None:
        raise ValueError("DATASET_DIR must be set when dataset_path is None")

    print("creating dataset")

    # Per-environment trajectory length caps (None = no cap).
    if dataset_use_max_duration_steps:  # default: True
        max_iron_pickaxe_duration = 6000
        max_diamond_duration = 18000
    else:
        max_iron_pickaxe_duration = None
        max_diamond_duration = None

    # (MineRL env name, max duration in steps) for every demonstration source,
    # in the order they are added to the dataset.
    data_sources = [
        ('MineRLObtainIronPickaxe-v0', max_iron_pickaxe_duration),
        ('MineRLObtainDiamond-v0', max_diamond_duration),
    ]
    if add_treechop_data:
        data_sources.append(('MineRLTreechop-v0', None))

    for source_env_name, source_max_duration in data_sources:
        put_data_into_dataset(
            source_env_name, action_manager, dataset, DATASET_DIR,
            dataset_continuous_action_stacking,
            dataset_only_successful,
            source_max_duration,
            dataset_max_reward,
            test)

    if save_dataset_path is not None:
        dataset.save(save_dataset_path)
        print(f"saved new dataset {save_dataset_path} with {dataset.transitions.index} transitions")
        if quit_after_saving_dataset:
            # SystemExit ends a script run and stops "Run All" in Jupyter.
            print("quit_after_saving_dataset is set, stopping")
            raise SystemExit(0)
    else:
        print("continuing with new dataset without saving")


creating dataset

 Adding data from MineRLObtainIronPickaxe-v0 



 21%|██        | 664/3213 [00:00<00:00, 3509.24it/s]

{'success': False, 'duration_ms': 160650, 'duration_steps': 3213, 'total_reward': 547.0, 'stream_name': 'v3_juvenile_apple_angel-7_212895-216138', 'true_video_frame_count': 3244}


100%|██████████| 3213/3213 [00:00<00:00, 3478.86it/s]


1 / 2, added: 1


 11%|█▏        | 450/3965 [00:00<00:01, 2244.37it/s]

{'success': False, 'duration_ms': 198250, 'duration_steps': 3965, 'total_reward': 547.0, 'stream_name': 'v3_sticky_chick_pea_gnome-21_46603-50686', 'true_video_frame_count': 4085}


100%|██████████| 3965/3965 [00:01<00:00, 3147.54it/s]

2 / 2, added: 2

 Adding data from MineRLObtainDiamond-v0 




  1%|          | 566/69526 [00:00<00:12, 5656.90it/s]

{'success': False, 'duration_ms': 3476300, 'duration_steps': 69526, 'total_reward': 99.0, 'stream_name': 'v3_self_reliant_fig_doppelganger-1_37451-107047', 'true_video_frame_count': 69598}


100%|██████████| 69526/69526 [00:17<00:00, 3927.84it/s]


1 / 2, added: 1


  0%|          | 0/2018 [00:00<?, ?it/s]

{'success': False, 'duration_ms': 100900, 'duration_steps': 2018, 'total_reward': 35.0, 'stream_name': 'v3_key_nectarine_spirit-1_1619-3682', 'true_video_frame_count': 2065}


100%|██████████| 2018/2018 [00:00<00:00, 2896.52it/s]


2 / 2, added: 2
saved new dataset/home/ankitagarg/minerl/minerl_imitation_learning/data/saved_dataset with 17955 transitions


In [11]:
# Re-wrap each stored transition so that its image (`state`) and `vector`
# tensors live in pinned (page-locked) host memory. Per the PyTorch docs, this
# speeds up host-to-GPU copies. action/reward/nonterminal are carried over unchanged.
for idx in range(dataset.transitions.index):
    old_transition = dataset.transitions.data[idx]
    dataset.transitions.data[idx] = Transition(
        old_transition.state.pin_memory(),
        old_transition.vector.pin_memory(),
        old_transition.action,
        old_transition.reward,
        old_transition.nonterminal,
    )

In [12]:
# Build the imitation-learning agent. All arguments are positional, so their
# order must match Agent's constructor.
agent = Agent(
    num_actions, image_channels, vec_size, writer,
    net, batch_size, augment_flip, hidden_size, dueling,
    learning_rate, adam_eps, device,
)

In [13]:
# Train (or resume training) on the expert dataset.
# - A run is time-boxed by `stop_time` (hours). When exceeded, a "tmp" snapshot
#   and the index of the last completed step are saved, so the next run resumes.
# - The final "last" snapshot is written only if the loop finished every step
#   (for/else). A time-limited stop must not look like a finished run.
# - Training is skipped when `train` was cleared because model_last.pth exists.
init_time = time.time()
tmp_step_path = p_join(OUTPUT_DIR, "tmp_time.p")

if test:
    trainsteps = 10

if not train:
    print("model_last.pth found, skipping training")
else:
    if continue_from_tmp:
        # tmp_time.p stores the last *completed* step, so resume after it.
        # NOTE: pickle.load can execute code from crafted files -- only load
        # snapshots written by this notebook.
        with open(tmp_step_path, "rb") as tmp_file:
            start_int = pickle.load(tmp_file) + 1
        print(f"continuing from {start_int} trainstep")
        agent.load(OUTPUT_DIR, "tmp")
    else:
        start_int = 0

    agent.train()

    with open(p_join(OUTPUT_DIR, "status.txt"), 'w') as status_file:
        status_file.write('running training')

    fps_t0 = time.time()

    for i in range(start_int, trainsteps):

        agent.learn(i, dataset, write=(i % 1000 == 0))

        if i and i % 100000 == 0:
            agent.save(OUTPUT_DIR, i // 100000)

        if stop_time is not None:
            elapsed_h = (time.time() - init_time) / 60. / 60.
            if elapsed_h > stop_time:
                print(f"{elapsed_h} h passed, saving tmp snapshot", flush=True)
                agent.save(OUTPUT_DIR, "tmp")
                with open(tmp_step_path, 'wb') as tmp_file:
                    pickle.dump(int(i), tmp_file)
                writer.close()
                print('saved')
                break

        if (i + 1) % 5000 == 0:
            # i + 1 - start_int steps have completed since fps_t0.
            fps = float(i + 1 - start_int) / (time.time() - fps_t0)
            writer.add_scalar("fps", fps, i)
    else:
        # Reached only when the loop was not stopped by the time limit.
        agent.save(OUTPUT_DIR, 'last')

1.0000023639202118 h passed, saving tmp snapshot
saved


In [None]:
# def put_data_into_dataset(env_name, action_manager, dataset, minecraft_human_data_dir,
#                           continuous_action_stacking_amount=3,
#                           only_successful=True, max_duration_steps=None, max_reward=256.,
#                           test=False):
#     """
#     :param env_name: Minecraft env name
#     :param action_manager: expects object of data_manager.ActionManager
#     :param dataset: expects object of dataset.Dataset
#     :param minecraft_human_data_dir: location of Minecraft human data
#     :param continuous_action_stacking_amount: number of consecutive states that are used to get the continuous action
#     (since humans move the camera slowly we add up the continuous actions of multiple consecutive states)
#     :param only_successful: skip trajectories that don't reach final reward when true
#     :param max_duration_steps: skip trajectories that take longer than max_duration_steps to reach the final reward
#     :param max_reward: remove trajectory part beyond the max_reward. Used to remove the "obtain diamond" part, since
#     the imitation policy never obtains diamonds anyway
#     :param test: if true a mini dataset is created for debugging

#     Additionally, every sample that has zero reward, a no_op action, and a non-terminal state is removed.
#     """

#     print(f"\n Adding data from {env_name} \n")

#     treechop_data = env_name == "MineRLTreechop-v0"

#     def is_success(sample):
#         if max_duration_steps is None:
#             return sample[-1]['success']
#         else:
#             return sample[-1]['success'] and sample[-1]['duration_steps'] < max_duration_steps

#     def is_no_op(sample):
#         action = sample[1]
#         a_id = action_manager.get_id(action)
#         assert type(a_id) == int
#         return a_id == 0  # no_op action has id of 0

#     def process_sample(sample, last_reward):
#         """adding single sample to dataset if all conditions are met, expects sample with already stacked
#         camera action"""

#         reward = sample[2]

#         if reward > last_reward:
#             last_reward = reward

#         gatherlog_sample = last_reward < 2.

#         if treechop_data:
#             # fill missing action and state parts with zeros:
#             for key, value in action_manager.zero_action.items():
#                 if key not in sample[1]:
#                     sample[1][key] = value

#             sample[0]['equipped_items'] = OrderedDict([(
#                 'mainhand',
#                 OrderedDict([('damage', 0), ('maxDamage', 0), ('type', 0)])
#             )])

#             sample[0]["inventory"] = OrderedDict([
#                 ('coal', 0),
#                 ('cobblestone', 0),
#                 ('crafting_table', 0),
#                 ('dirt', 0),
#                 ('furnace', 0),
#                 ('iron_axe', 0),
#                 ('iron_ingot', 0),
#                 ('iron_ore', 0),
#                 ('iron_pickaxe', 0),
#                 ('log', 0),
#                 ('planks', 0),
#                 ('stick', 0),
#                 ('stone', 0),
#                 ('stone_axe', 0),
#                 ('stone_pickaxe', 0),
#                 ('torch', 0),
#                 ('wooden_axe', 0),
#                 ('wooden_pickaxe', 0)
#             ])

#         if reward != 0.:
#             if reward > max_reward:
#                 # if a larger reward is encountered, the transition is deleted until previous reward:
#                 counter_change = - dataset.remove_new_data()
#             else:
#                 dataset.append_sample(sample, gatherlog_sample, treechop_data)
#                 dataset.update_last_reward_index()
#                 counter_change = 1
#         else:
#             if not is_no_op(sample) or sample[4]:  # remove no_op transitions, unless it is a terminal state
#                 dataset.append_sample(sample, gatherlog_sample, treechop_data)
#                 counter_change = 1
#             else:
#                 counter_change = 0

#         return counter_change, last_reward

#     data = minerl.data.make(env_name, data_dir=minecraft_human_data_dir)
#     trajs = data.get_trajectory_names()

#     # the ring buffer is used to stack the camera action of multiple consecutive states:
#     sample_que = deque(maxlen=continuous_action_stacking_amount)

#     total_trajs_counter = 0
#     added_sample_counter = 0

#     initial_sample_amount = dataset.transitions.current_size()

#     for n, traj in enumerate(trajs):
#         for j, sample in enumerate(data.load_data(traj, include_metadata=True)):

#             # checking if the trajectory will be used first:
#             if j == 0:
#                 print(sample[-1])

#                 if only_successful:
#                     if not is_success(sample):
#                         print("skipping trajectory")
#                         break

#                 total_trajs_counter += 1
#                 last_reward = 0.

#             sample_que.append(sample)

#             # Only continue when we have enough states to stack the camera actions:
#             if len(sample_que) == continuous_action_stacking_amount:

#                 # Stacking camera action for the oldest sample in the queue:
#                 for i in range(1, continuous_action_stacking_amount):
#                     sample_que[0][1]['camera'] += sample_que[i][1]['camera']

#                     if sample_que[i][2] != 0.:  # (if reward != 0)
#                         break  # no camera action stacking after a reward

#                 added_samples, last_reward = process_sample(sample_que[0], last_reward)

#                 added_sample_counter += added_samples

#         if len(sample_que) > 0:  # otherwise not successful traj
#             # for the last samples in the queue we don't stack the camera actions
#             for i in range(1, continuous_action_stacking_amount):
#                 added_samples, last_reward = process_sample(sample_que[i], last_reward)
#                 added_sample_counter += added_samples

#             # a terminal state could be reached without exceeding max_reward:
#             added_sample_counter -= dataset.remove_new_data()

#             # making sure the last state from trajectory is terminal:
#             last_transition = deepcopy(dataset.transitions.data[dataset.transitions.index - 1])
#             dataset.transitions.data[dataset.transitions.index - 1] = \
#                 Transition(last_transition.state, last_transition.vector,
#                            last_transition.action, last_transition.reward, False)

#         sample_que.clear()

#         print(f"{n+1} / {len(trajs)}, added: {total_trajs_counter}")
#         assert dataset.transitions.current_size() - initial_sample_amount == added_sample_counter

#         if test:
#             if total_trajs_counter >= 2:
#                 assert total_trajs_counter == 2
#                 break
