In [1]:
from __future__ import division
import argparse
import os
import sys
import numpy as np
import torch

from agent import Agent
from minecraft import DummyMinecraft, Env, test_policy
from dataset import Dataset, Transition

import pickle
import time
from os.path import join as p_join
from os.path import exists as p_exists

from data_manager import StateManager, ActionManager

from get_dataset import put_data_into_dataset

import minerl
import gym



In [2]:
try:
    from torch.utils.tensorboard import SummaryWriter
except ModuleNotFoundError:
    from tensorboardX import SummaryWriter

In [3]:
def str2bool(v):
    """Parse a command-line style boolean value.

    Leftover helper from the argparse version of this script (it is not called
    anywhere in this notebook). Suitable as an argparse ``type=`` converter.

    Args:
        v: a string such as 'yes'/'no', 'true'/'false', 't'/'f', 'y'/'n', '1'/'0'
           (case-insensitive), or an actual ``bool``, which is returned unchanged
           (argparse passes non-string defaults straight through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is not a recognised boolean spelling.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

In [4]:
# ---------------------------------------------------------------------------
# Paths
# The defaults are absolute paths from the original author's machine; override
# them with environment variables instead of editing the notebook.
# ---------------------------------------------------------------------------
OUTPUT_DIR = os.environ.get('MINERL_OUTPUT_DIR', '/home/ankitagarg/minerl/minerl_imitation_learning/output_2/')
DATASET_DIR = os.environ.get('MINERL_DATA_ROOT', '/home/ankitagarg/minerl/data/')

# ---------------------------------------------------------------------------
# Runtime / reproducibility
# ---------------------------------------------------------------------------
enable_cudnn = True
train = True  # set to False automatically below if model_last.pth already exists
c_action_magnitude = 22.5  # magnitude of discretized camera action
seed = 123
scale_rewards = True

# ---------------------------------------------------------------------------
# Optimiser
# ---------------------------------------------------------------------------
learning_rate = 0.0000625
adam_eps = 1.5e-4

batch_size = 32

# ---------------------------------------------------------------------------
# Network
# ---------------------------------------------------------------------------
net = 'deep_resnet'
hidden_size = 1024
dueling = True

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
trainsteps = 3000000
augment_flip = True
stop_time = None  # wall-clock budget in hours; None = run all trainsteps
test = False      # smoke-test mode (very few train steps)

# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
dataset_path = None  # path of a previously saved dataset to load; None = build from DATASET_DIR
dataset_only_successful = False
dataset_use_max_duration_steps = True
dataset_continuous_action_stacking = 3
dataset_max_reward = 256
add_treechop_data = True

save_dataset_path = '/home/ankitagarg/minerl/minerl_imitation_learning/data/saved_dataset'
quit_after_saving_dataset = False
test = False

In [5]:
# A final checkpoint means an earlier run already completed: disable training.
if p_exists(p_join(OUTPUT_DIR, 'model_last.pth')):
    print("Training already finished")
    train = False

# A tmp-step file means an earlier run stopped on its time budget; resume it.
continue_from_tmp = p_exists(p_join(OUTPUT_DIR, "tmp_time.p"))
if continue_from_tmp:
    print("Detected tmp snapshot, will continue training from there")

In [6]:
#Setup
# Setup
# Seed numpy first. Torch's CPU and CUDA seeds are then drawn from numpy's
# stream, so the whole run is determined by the single `seed` value.
np.random.seed(seed)
torch.manual_seed(np.random.randint(1, 10000))

# Explicit check rather than `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for training, but torch.cuda.is_available() returned False")
torch.cuda.manual_seed(np.random.randint(1, 10000))
torch.backends.cudnn.enabled = enable_cudnn
device = torch.device('cuda')

print(f"Running on {device}")

# Converters between raw MineRL observations/actions and tensors on `device`.
state_manager = StateManager(device)
action_manager = ActionManager(device, c_action_magnitude)

Running on cuda


In [7]:
# TensorBoard event files are written into OUTPUT_DIR, next to the checkpoints
# that agent.save(OUTPUT_DIR, ...) produces later.
writer = SummaryWriter(OUTPUT_DIR)

# status.txt records the run's current phase; later cells overwrite it
# ('running training', then 'finished').
status_path = p_join(OUTPUT_DIR, "status.txt")
with open(status_path, 'w') as status_file:
    status_file.write('running')

# NOTE(review): a previous version installed a sys.excepthook here that wrote
# 'error' to status.txt and closed writer/env on an unhandled exception. IPython
# kernels do not route cell exceptions through sys.excepthook, so that hook
# would not fire inside this notebook. If that behaviour is wanted, consider
# get_ipython().set_custom_exc(...) instead.

In [8]:
#create the environment
# Build the (dummy) Minecraft environment, seed it, then wrap it with the
# state/action managers used everywhere else in this notebook.
env_ = DummyMinecraft()
env_.seed(seed)
env = Env(env_, state_manager, action_manager)
print("started env")

# One reset yields example observations. The next cell derives the network
# input sizes from their shapes.
img, vec = env.reset()
print("env reset")
print("img, vec shapes: ", img.shape, vec.shape)

started env




env reset
img, vec shapes:  torch.Size([1, 3, 64, 64]) torch.Size([1, 216])


  img_torch = torch.tensor(img_list, dtype=torch.float32, device=self.device).div_(255).permute(0, 3, 1, 2)


In [9]:
num_actions = action_manager.num_action_ids_list[0]

# Strip the leading batch dimension. Per the printed output, img is
# [1, 3, 64, 64] and vec is [1, 216].
img_shape = [int(dim) for dim in img.shape[1:]]
image_channels = img.shape[1]

vec_shape = vec.shape[1:]
vec_size = vec.shape[1]

In [10]:
# Replay-style dataset with capacity for 2,000,000 transitions.
dataset = Dataset(device, 2000000, img_shape, vec_shape,
                  state_manager, action_manager,
                  scale_rewards=scale_rewards)

if dataset_path is not None:  # default None
    # Reuse a dataset saved by a previous run.
    print(f"loading dataset {dataset_path}")
    dataset.load(dataset_path)
    print("loaded dataset")

else:  # build the dataset from the raw MineRL demonstrations
    assert DATASET_DIR is not None

    print("creating dataset")

    if dataset_use_max_duration_steps:  # default: True
        max_iron_pickaxe_duration = 6000
        max_diamond_duration = 18000
    else:
        max_iron_pickaxe_duration = None
        max_diamond_duration = None

    # (environment name, max duration in steps or None) for each demonstration source.
    dataset_sources = [
        ('MineRLObtainIronPickaxe-v0', max_iron_pickaxe_duration),
        ('MineRLObtainDiamond-v0', max_diamond_duration),
    ]
    if add_treechop_data:
        dataset_sources.append(('MineRLTreechop-v0', None))

    for env_name, max_duration in dataset_sources:
        put_data_into_dataset(
            env_name, action_manager, dataset, DATASET_DIR,
            dataset_continuous_action_stacking,
            dataset_only_successful,
            max_duration,
            dataset_max_reward,
            test)

    if save_dataset_path is not None:
        dataset.save(save_dataset_path)
        print(f"saved new dataset {save_dataset_path} with {dataset.transitions.index} transitions")
        if quit_after_saving_dataset:
            # Honour the config flag. In Jupyter this raises SystemExit, which
            # stops "Run All" without killing the kernel.
            sys.exit("quit_after_saving_dataset is set: stopping after saving the dataset")

    else:
        print("continuing with new dataset without saving")


creating dataset

 Adding data from MineRLObtainIronPickaxe-v0 



 15%|█▍        | 474/3213 [00:00<00:01, 2421.41it/s]

{'success': False, 'duration_ms': 160650, 'duration_steps': 3213, 'total_reward': 547.0, 'stream_name': 'v3_juvenile_apple_angel-7_212895-216138', 'true_video_frame_count': 3244}


100%|██████████| 3213/3213 [00:01<00:00, 3112.65it/s]


1 / 2, added: 1


 10%|█         | 406/3965 [00:00<00:01, 2024.13it/s]

{'success': False, 'duration_ms': 198250, 'duration_steps': 3965, 'total_reward': 547.0, 'stream_name': 'v3_sticky_chick_pea_gnome-21_46603-50686', 'true_video_frame_count': 4085}


100%|██████████| 3965/3965 [00:01<00:00, 2817.03it/s]

2 / 2, added: 2

 Adding data from MineRLObtainDiamond-v0 




  1%|          | 527/69526 [00:00<00:13, 5251.05it/s]

{'success': False, 'duration_ms': 3476300, 'duration_steps': 69526, 'total_reward': 99.0, 'stream_name': 'v3_self_reliant_fig_doppelganger-1_37451-107047', 'true_video_frame_count': 69598}


100%|██████████| 69526/69526 [00:19<00:00, 3505.20it/s]


1 / 2, added: 1


  0%|          | 0/2018 [00:00<?, ?it/s]

{'success': False, 'duration_ms': 100900, 'duration_steps': 2018, 'total_reward': 35.0, 'stream_name': 'v3_key_nectarine_spirit-1_1619-3682', 'true_video_frame_count': 2065}


100%|██████████| 2018/2018 [00:00<00:00, 2660.37it/s]

2 / 2, added: 2

 Adding data from MineRLTreechop-v0 




 47%|████▋     | 722/1528 [00:00<00:00, 3671.98it/s]

{'success': True, 'duration_ms': 76400, 'duration_steps': 1528, 'total_reward': 64.0, 'stream_name': 'v3_content_squash_angel-3_16074-17640', 'true_video_frame_count': 1567}


100%|██████████| 1528/1528 [00:00<00:00, 3790.33it/s]


1 / 2, added: 1


  0%|          | 0/1680 [00:00<?, ?it/s]

{'success': True, 'duration_ms': 84000, 'duration_steps': 1680, 'total_reward': 64.0, 'stream_name': 'v3_homely_string_bean_djinn-10_514-2235', 'true_video_frame_count': 1722}


100%|██████████| 1680/1680 [00:00<00:00, 4111.18it/s]


2 / 2, added: 2
saved new dataset/home/ankitagarg/minerl/minerl_imitation_learning/data/saved_dataset with 21101 transitions


In [11]:
# Move each stored image/vector tensor into page-locked (pinned) host memory,
# which allows faster host-to-GPU copies during training. Only the first
# `dataset.transitions.index` slots are filled, so only those are rebuilt.
for idx in range(dataset.transitions.index):
    old_transition = dataset.transitions.data[idx]
    pinned_state = old_transition.state.pin_memory()
    pinned_vector = old_transition.vector.pin_memory()
    dataset.transitions.data[idx] = Transition(
        pinned_state,
        pinned_vector,
        old_transition.action,
        old_transition.reward,
        old_transition.nonterminal,
    )

In [12]:
# Arguments are positional and must stay in this order.
agent = Agent(
    num_actions,     # number of action ids (action_manager.num_action_ids_list[0])
    image_channels,  # channels of the image observation
    vec_size,        # length of the vector observation
    writer,          # TensorBoard SummaryWriter
    net,             # network architecture name from config
    batch_size,
    augment_flip,
    hidden_size,
    dueling,
    learning_rate,
    adam_eps,
    device,
)

In [None]:
# Train the agent from the demonstration dataset. If a tmp snapshot was
# detected, resume from it. Skip training entirely when a final model exists.
if not train:
    print("model_last.pth already exists in OUTPUT_DIR, skipping training")
else:
    init_time = time.time()
    if continue_from_tmp:
        # tmp_time.p stores the index of the next train step to run.
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # snapshots written by this script.
        with open(p_join(OUTPUT_DIR, "tmp_time.p"), "rb") as tmp_file:
            start_int = pickle.load(tmp_file)
        print(f"continuing from {start_int} trainstep")
        agent.load(OUTPUT_DIR, "tmp")
    else:
        start_int = 0

    agent.train()

    with open(p_join(OUTPUT_DIR, "status.txt"), 'w') as status_file:
        status_file.write('running training')

    if test:
        trainsteps = 10  # smoke-test mode

    stopped_early = False
    fps_t0 = time.time()

    for i in range(start_int, trainsteps):
        agent.learn(i, dataset, write=(i % 1000 == 0))

        # Numbered snapshot every 500k steps.
        if i and i % 500000 == 0:
            agent.save(OUTPUT_DIR, i // 500000)

        # Wall-clock budget in hours: save a resumable snapshot and stop.
        if stop_time is not None:
            hours_passed = (time.time() - init_time) / 60. / 60.
            if hours_passed > stop_time:
                print(f"{hours_passed} h passed, saving tmp snapshot", flush=True)
                agent.save(OUTPUT_DIR, "tmp")
                with open(p_join(OUTPUT_DIR, "tmp_time.p"), 'wb') as tmp_file:
                    # Step i has already been learned, so resume at i + 1.
                    pickle.dump(int(i) + 1, tmp_file)
                writer.close()
                print('saved')
                stopped_early = True
                break

        if (i + 1) % 5000 == 0:
            # (i + 1 - start_int) steps have completed since fps_t0.
            fps = float(i + 1 - start_int) / (time.time() - fps_t0)
            writer.add_scalar("fps", fps, i)

    # Only write the final checkpoint when the loop actually finished. Saving it
    # after an early stop would make the next run report "Training already finished".
    if not stopped_early:
        agent.save(OUTPUT_DIR, 'last')

In [None]:
# Mark the run as done in the status marker, then release the environment.
status_path = p_join(OUTPUT_DIR, "status.txt")
with open(status_path, 'w') as status_file:
    status_file.write('finished')

env.close()

In [None]:
# Flush and close the TensorBoard writer created at the top of the notebook.
# NOTE(review): it may already be closed if the training loop stopped on its
# time budget; SummaryWriter.close() is expected to tolerate a second call.
writer.close()