In [7]:
import logging

In [4]:
from src.data_loader import DataLoader

# N_WORKERS = 8
# BATCH_SIZE = 8
# EPOCHS = 3

# data_loader = DataLoader(
#     dataset_dir="../basalt_neurips_data/BuildWaterFall",
#     n_workers=N_WORKERS,
#     batch_size=BATCH_SIZE,
#     n_epochs=EPOCHS,
#     dataset_max_size=10
# )

In [62]:
from src.openai_vpt.preprocessing import MineRLAgentPP

from argparse import ArgumentParser
import pickle
import time
import logging

import gym
import minerl
import torch
import numpy as np
from tqdm import tqdm

from src.openai_vpt.agent import PI_HEAD_KWARGS, MineRLAgent
from src.data_loader import DataLoader
from src.openai_vpt.lib.tree_util import tree_map

logging.basicConfig(level=logging.DEBUG)

# Originally this code was designed for a small dataset of ~20 demonstrations per task.
# The settings might not be the best for the full BASALT dataset (thousands of demonstrations).
# Use this flag to switch between the two settings
USING_FULL_DATASET = False

EPOCHS = 1 if USING_FULL_DATASET else 1
# Needs to be <= number of videos
BATCH_SIZE = 64 if USING_FULL_DATASET else 8
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 100 if USING_FULL_DATASET else 16
DEVICE = "cuda"

LOSS_REPORT_RATE = 100

# Tuned with bit of trial and error
LEARNING_RATE = 0.000181
# OpenAI VPT BC weight decay
# WEIGHT_DECAY = 0.039428
WEIGHT_DECAY = 0.0
# KL loss to the original model was not used in OpenAI VPT
KL_LOSS_WEIGHT = 1.0
MAX_GRAD_NORM = 5.0

# MAX_BATCHES = 2000 if USING_FULL_DATASET else int(1e9)
# MAX_BATCHES = 10

MAX_EPISODES = 16

data_dir = "../basalt_neurips_data/BuildWaterFall/"


In [64]:

data_loader = DataLoader(
    dataset_dir=data_dir,
    n_workers=N_WORKERS,
    batch_size=BATCH_SIZE,
    n_epochs=EPOCHS,
    dataset_max_size=MAX_EPISODES,
    shuffle=False
)


# Keep track of the hidden state per episode/trajectory.
# DataLoader provides unique id for each episode, which will
# be different even for the same trajectory when it is loaded
# up again

episode_hidden_states = {}
dummy_first = torch.from_numpy(np.array((False,))).to(DEVICE)

agent = MineRLAgentPP(
    env="MineRLBasaltMakeWaterfall-v0",
    device=DEVICE,
)


def training_loop(pbar):

    for batch_i, (batch_images, batch_actions, batch_episode_id) in pbar:
        batch_loss = 0
        for image, action, episode_id in zip(
            batch_images, batch_actions, batch_episode_id
        ):

            """
            action={'ESC': 0, 'back': 0, 'drop': 0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': array([ 0, -1]), 'attack': 0, 'use': 0, 'pickItem': 0}

            image.shape=(128, 128, 3)
            """

            if image is None and action is None:
                # A work-item was done. Remove hidden state
                if episode_id in episode_hidden_states:
                    removed_hidden_state = episode_hidden_states.pop(episode_id)
                    del removed_hidden_state
                continue

            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})

            """
            agent_action={'buttons': tensor([[288]], device='cuda:0'), 'camera': tensor([[60]], device='cuda:0')}

            agent_obs.keys()=dict_keys(['img'])
            agent_obs['img'].shape=torch.Size([1, 128, 128, 3])
            """
            agent_action_input = torch.concat((agent_action["buttons"], agent_action["camera"]), dim=1)
            yield agent_obs["img"], agent_action_input, episode_id





In [65]:
from src.end2end_vid_segmentation_rcnn import VideoSegmentationModel

model = VideoSegmentationModel(128, 2, 3).to("cuda")

optimizer = torch.optim.Adam(model.parameters())
loss_fn = torch.nn.CrossEntropyLoss()

loss_sum = 0
pbar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")

model.train()

for image_tensor, action_tensor, episode_id in training_loop(pbar):
    print(image_tensor.shape, action_tensor.shape, episode_id)

    # Initialize batch loss
    batch_loss = 0

    batch_size = 1  # FIXME
    # Initialize hidden state for the current batch
    hidden_state = model.init_hidden(batch_size)  # FIXME

    # Forward pass through the model
    logits, hidden_state = model(image_tensor, action_tensor, hidden_state)

    # Compute the loss
    loss = loss_fn(logits, action_tensor)
    batch_loss += loss.item()

    # Backpropagate the error and update the model parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Update progress bar with average loss for the current batch
    pbar.set_description(f"Avg loss: {batch_loss / len(batch_images):.4f}")
    loss_sum += batch_loss

    # Stop training if maximum number of batches has been reached
    if batch_i > MAX_BATCHES:
        break

Avg loss: 0.0000:   0%|          | 11/3792 [00:00<00:51, 73.29it/s]

torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 0
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 3
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 7
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 12
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 14
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 3
torch.Size([1, 128, 128, 3]) torch.Size([1, 2]) 14





In [None]:
pbar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")

# Loop through data loader
for batch_i, (batch_images, batch_actions, batch_episode_id) in pbar:
    # Initialize batch loss
    batch_loss = 0

    # Initialize hidden state for the current batch
    hidden_state = model.init_hidden(batch_size)

    # Process each image and action in the batch
    for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
        if image is None and action is None:
            # A work-item was done. Remove hidden state
            if episode_id in episode_hidden_states:
                removed_hidden_state = episode_hidden_states.pop(episode_id)
                del removed_hidden_state
            continue

        # Convert image and action to tensors
        image_tensor = torch.from_numpy(image).to(DEVICE)
        action_tensor = torch.from_numpy(action).to(DEVICE)

        # Forward pass through the model
        logits, hidden_state = model(image_tensor, hidden_state)

        # Compute the loss
        loss = loss_fn(logits, action_tensor)
        batch_loss += loss.item()

        # Backpropagate the error and update the model parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Update progress bar with average loss for the current batch
    pbar.set_description(f"Avg loss: {batch_loss / len(batch_images):.4f}")
    loss_sum += batch_loss

    # Stop training if maximum number of batches has been reached
    if batch_i > MAX_BATCHES:
        break

In [16]:
import numpy as np

action = {'ESC': 1, 'back': 0, 'drop': 0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': np.array([0, 0]), 'attack': 0, 'use': 0, 'pickItem': 0}

In [40]:
def print_raw_action(action):
    print("[[[[[[[[[[[[[[[[")
    for e, f in action.items():
        if (isinstance(f, np.ndarray) and (f[0] or f[1])) or (not isinstance(f, np.ndarray) and f != 0):
            print(f"{e}: {' ' * (10 - len(e))}{f}")
    print("]]]]]]]]]]]]]]]]")


In [41]:

data_loader = DataLoader(
    dataset_dir=data_dir,
    n_workers=N_WORKERS,
    batch_size=BATCH_SIZE,
    n_epochs=EPOCHS,
    dataset_max_size=10,
    shuffle=False
)


# Keep track of the hidden state per episode/trajectory.
# DataLoader provides unique id for each episode, which will
# be different even for the same trajectory when it is loaded
# up again
episode_hidden_states = {}
dummy_first = torch.from_numpy(np.array((False,))).to(DEVICE)

agent = MineRLAgentPP(
    env="MineRLBasaltMakeWaterfall-v0",
    device=DEVICE,
)

loss_sum = 0
pbar = tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")

for batch_i, (batch_images, batch_actions, batch_episode_id) in pbar:
    batch_loss = 0
    for image, action, episode_id in zip(
        batch_images, batch_actions, batch_episode_id
    ):

        """
        action={'ESC': 0, 'back': 0, 'drop': 0, 'forward': 0, 'hotbar.1': 0, 'hotbar.2': 0, 'hotbar.3': 0, 'hotbar.4': 0, 'hotbar.5': 0, 'hotbar.6': 0, 'hotbar.7': 0, 'hotbar.8': 0, 'hotbar.9': 0, 'inventory': 0, 'jump': 0, 'left': 0, 'right': 0, 'sneak': 0, 'sprint': 0, 'swapHands': 0, 'camera': array([ 0, -1]), 'attack': 0, 'use': 0, 'pickItem': 0}

        image.shape=(128, 128, 3)
        """

        if image is None and action is None:
            # A work-item was done. Remove hidden state
            if episode_id in episode_hidden_states:
                removed_hidden_state = episode_hidden_states.pop(episode_id)
                del removed_hidden_state
            continue


        agent_action = agent._env_action_to_agent(
            action, to_torch=True, check_if_null=True
        )

        if agent_action is None:
            continue

        agent_obs = agent._env_obs_to_agent({"pov": image})

        """
        agent_action={'buttons': tensor([[288]], device='cuda:0'), 'camera': tensor([[60]], device='cuda:0')}

        agent_obs.keys()=dict_keys(['img'])
        agent_obs['img'].shape=torch.Size([1, 128, 128, 3])
        """


    if batch_i > MAX_BATCHES:
        break



Avg loss: 0.0000:   0%|          | 3/7029 [00:00<03:55, 29.85it/s]

Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is NOT null:
[[[[[[[[[[[[[[[[
camera:     [-4 24]
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is NOT null:
[[[[[[[[[[[[[[[[
camera:     [ 2 -2]
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]
Action is null:
[[[[[[[[[[[[[[[[
]]]]]]]]]]]]]]]]



