# Variables

In [None]:
# Number of consecutive frames grouped into one clip; each DataLoader batch is
# scored by MineCLIP as a single video.
batch_size = 16
# Every frame is resized to this (height, width), which is also the resolution
# passed to MineCLIP.
frame_height = 160
frame_width = 256
# Frame rate of ./MC_Clip.mp4 (assumed here, not read from the file).
original_fps = 30
# Rate the decoded video is subsampled to before scoring.
desired_fps = 30
# "avg" or "attn": selects both the MineCLIP pooling variant and the
# checkpoint file ./{pool_type}.pth.
pool_type = "avg"

assert batch_size >= 1, "batch_size must be at least 1"
# A zero/negative rate would otherwise only fail later, while subsampling frames.
assert desired_fps > 0, "desired_fps must be positive"
assert desired_fps <= original_fps, "desired_fps can't be higher than original_fps"
assert pool_type in ("avg", "attn"), "pool_type must be either avg or attn"

# Installing Python packages

In [None]:
!pip install --upgrade importlib_resources==5.12.0 --quiet
!pip install --upgrade setuptools==65.4.1 --quiet
!pip install wheel==0.38.4 --quiet
!pip install av --quiet
!pip install git+https://github.com/MineDojo/MineCLIP --quiet

# Imports

In [None]:
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from mineclip import MineCLIP
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd "drive/MyDrive/MineCLIP_prompt_comparison"

In [None]:
class VideoDataset(Dataset):
    """Map-style dataset over an already-loaded, indexable sequence of frames.

    Item ``index`` is ``data[index]`` returned unchanged, and the length is
    ``len(data)``. The wrapped sequence stays accessible as ``self.data``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

# Initializing MineCLIP

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running with", device)

# MineCLIP pool_type spec for each checkpoint variant allowed by the config cell.
POOL_TYPE_SPECS = {"avg": "avg", "attn": "attn.d2.nh8.glusw"}

model = MineCLIP(
    arch="vit_base_p16_fz.v2.t2",
    resolution=[frame_height, frame_width],
    pool_type=POOL_TYPE_SPECS[pool_type],
    image_feature_dim=512,
    # NOTE(review): "v3-1.t2" was noted as an alternative adapter spec - confirm
    # which one matches each checkpoint.
    mlp_adapter_spec="v0-2.t0",
    hidden_dim=512,
).to(device)
model.load_ckpt(f"./{pool_type}.pth")
# The model is only used for inference: switch off train-mode behavior
# (e.g. dropout) in any submodule that has it.
model.eval()

# Setting prompts

In [None]:
# Prompt groups: each inner array holds paraphrases of one sub-task. The last
# group contains tasks that do not appear in the clip.
prompts = [
    np.array([
        "approach tree"
    ]), np.array([
        "chop tree",
        "farm wood",
        "farm logs"
    ]), np.array([
        "craft wooden planks",
        "craft planks",
        "make wooden planks",
        "make planks",
        "craft planks out of the logs"
    ]), np.array([
        "craft a crafting table",
        "craft a workbench",
        "make a crafting table",
        "make a workbench",
        "craft a crafting table with 4 wood"
    ]), np.array([
        "place crafting table",
        "place workbench"
    ]), np.array([
        "craft sticks",
        "make sticks",
        "craft sticks with 2 wood"
    ]), np.array([
        "craft a wooden pickaxe",
        "craft a pickaxe",
        "make a wooden pickaxe",
        "craft a pickaxe with 2 sticks and 3 wood"
    ]), np.array([
        "use an anvil",
        "defeat the enderdragon"
    ])
]
# Flatten the groups, in order, into a list of plain str. The plotting cells slice
# the score matrix back into groups using the length of each group.
prompts_flattened = [str(prompt) for group in prompts for prompt in group]
assert len(prompts_flattened) == sum(len(group) for group in prompts)

# Text features are only used for inference, so don't build an autograd graph
# that would stay alive as long as prompt_feats does.
with torch.no_grad():
    prompt_feats = model.encode_text(prompts_flattened)

# Loading the video

In [None]:
# Decode ./MC_Clip.mp4, subsample it from original_fps to desired_fps, resize each
# kept frame to the MineCLIP input resolution and batch the frames in order.
video_object = torchvision.io.VideoReader("./MC_Clip.mp4", "video")
video_object.set_current_stream("video")
transform = T.Resize((frame_height, frame_width), antialias=False)


def _keep_frame(index, src_fps, dst_fps):
    """Return True if frame ``index`` of a ``src_fps`` stream survives resampling to ``dst_fps``.

    A frame is kept whenever the output-frame counter ``index * dst_fps // src_fps``
    advances. When ``src_fps`` is a multiple of ``dst_fps``, this keeps every
    ``src_fps // dst_fps``-th frame starting at 0. For other ratios it still keeps
    ``dst_fps`` frames per second on average; integer truncation of the stride
    would instead keep too many frames (e.g. 30 -> 20 fps kept all of them).
    """
    if index == 0:
        return True
    return (index * dst_fps) // src_fps > ((index - 1) * dst_fps) // src_fps


frames = [
    transform(frame["data"])
    for i, frame in enumerate(video_object)
    if _keep_frame(i, original_fps, desired_fps)
]
# torch.stack would fail with an obscure error on an empty list.
assert frames, "no frames were decoded from ./MC_Clip.mp4"
loader = DataLoader(
    VideoDataset(torch.stack(frames, 0)),
    batch_size=batch_size,
    shuffle=False,
)
print(f"Loaded video with {len(frames)} frames")

# Evaluation

In [None]:
# Score each batch of consecutive frames (one clip) against every prompt, and track
# for each prompt the best score seen and the clip that produced it.
rewards = []  # one array of per-prompt scores per clip
max_scores = [-float("inf")] * len(prompts_flattened)  # best score per prompt (float)
best_batches = [None] * len(prompts_flattened)  # CPU copy of the best clip per prompt
n_batches = len(loader)  # exact batch count, including a final partial batch
for batch_count, frames_batch in enumerate(loader, start=1):
    print(f"Evaluating batch {batch_count}/{n_batches}")
    # Add a leading dim so the whole batch is treated as a single video.
    video = torch.unsqueeze(frames_batch, dim=0).to(device)
    with torch.no_grad():
        reward, _ = model(video, text_tokens=prompt_feats, is_video_features=False)
    # Scores of the single video against every prompt; move to CPU once per clip.
    reward_np = reward[0].cpu().numpy()
    clip_cpu = video[0].cpu()
    for i in range(len(prompts_flattened)):
        score = float(reward_np[i])
        if max_scores[i] < score:
            max_scores[i] = score
            best_batches[i] = clip_cpu
    rewards.append(reward_np)

# Plots

In [None]:
import os

# One line plot per prompt group: score of each prompt in the group per clip.
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories

# rewards holds one per-prompt score vector per clip -> (n_prompts, n_clips) after .T
rewards_reshaped = np.array(rewards).T
prompt_count = 0
for i, prompts_specific in enumerate(prompts):
    group_size = len(prompts_specific)
    rewards_specific = rewards_reshaped[prompt_count:prompt_count + group_size].T
    fig, ax = plt.subplots(figsize=(7.50, 3.50))
    ax.set_title(f"Prompt group {i}")
    ax.set_xlabel("Batches")
    ax.set_ylabel("Score")
    ax.plot(rewards_specific)
    ax.legend(list(prompts_specific))
    fig.savefig(f"plots/prompt_comparison_{i}.png", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop
    prompt_count += group_size

In [None]:
import os

# For every prompt, render the frames of the clip that achieved its best score.
os.makedirs("plots", exist_ok=True)  # savefig does not create missing directories

for i, prompt in enumerate(prompts_flattened):
    best_clip = best_batches[i]
    if best_clip is None:
        # No clip ever beat the initial -inf (e.g. empty loader), so there is nothing to draw.
        print(f"No scored batch for prompt: \"{prompt}\", skipping")
        continue
    # NOTE(review): this figure is frame_width inches wide; the image-stitching cell
    # at the end depends on the pixel size of the resulting PNGs.
    # squeeze=False keeps axs 2-D, so axs[0, j] also works when batch_size == 1.
    fig, axs = plt.subplots(
        1, batch_size, figsize=(frame_width, frame_height/batch_size), squeeze=False
    )
    fig.suptitle(f"batch with max score of {round(float(max_scores[i]), 2)} for prompt: \"{prompt}\"", fontsize=100)
    # The last clip may contain fewer than batch_size frames; the extra axes stay empty.
    for j, frame in enumerate(best_clip[:batch_size]):
        axs[0, j].imshow(frame.permute(1, 2, 0))
    fig.savefig(f"plots/best_batch_{i}.png", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # these figures are very large; free them after showing

In [None]:
# Unweighted mean over prompts of each prompt's best clip score.
average_max_score = np.mean(max_scores)
print(f"average max-score for all tasks: {average_max_score}")

# Results


*   Prompt "chop tree" performs a lot better than "farm wood" and "farm logs"
*   Phrase "craft" is better than "make"
*   Phrase "crafting table" is better than "workbench" (probably more used in training data because it is the official term)
*   Adding extra information about the recipe to the prompt (e.g. "with 4 wood") leads to higher scores
*   Scores for the prompts "use an anvil" and "defeat the enderdragon" are only slightly lower than scores for the other prompts, even though the agent could not complete these tasks
*   The prompt "make a wooden pickaxe" achieves its highest score while the agent is looking at a tree, rather than while it is actually crafting a wooden pickaxe (though not every time)
*   Prompts for crafting an object sometimes achieve the highest score for looking at the finished object, rather than crafting it
*   Prompts for crafting an object yield higher scores whenever the inventory or the crafting table menu is open, regardless of what actually gets crafted
*   Using a lower FPS doesn't seem to have much of an impact on the scores (it is probably best to use an FPS high enough that a single batch doesn't contain too many different actions)
*   Pool types "avg" and "attn" seem to achieve similar relative results (although avg produces lower scores overall). With the avg pool type, however, the "use an anvil" and "defeat the enderdragon" prompts have lower scores
*   All in all, the model does a good job of detecting when the specific actions happen in the video

In [None]:
import os

from PIL import Image

# Stack all plots/best_batch_{i}.png images vertically into one image.
# NOTE(review): 19896 x 855 is presumably the pixel size of each saved best-batch
# PNG - TODO confirm. Every tile is resized to that size / TILE_DOWNSCALE regardless.
TILE_SRC_WIDTH, TILE_SRC_HEIGHT = 19896, 855
TILE_DOWNSCALE = 3
tile_size = (int(TILE_SRC_WIDTH / TILE_DOWNSCALE), int(TILE_SRC_HEIGHT / TILE_DOWNSCALE))
n_tiles = len(prompts_flattened)  # one best_batch_{i}.png per prompt

canvas = Image.new("RGB", (tile_size[0], n_tiles * tile_size[1]))
for i in range(n_tiles):
    tile_path = f"plots/best_batch_{i}.png"
    if not os.path.exists(tile_path):
        # Leave this band black instead of aborting the whole stitch.
        print(f"missing {tile_path}, leaving its slot empty")
        continue
    # The context manager closes the file once the resized copy has been made.
    with Image.open(tile_path) as img:
        tile = img.resize(tile_size)
    canvas.paste(tile, (0, i * tile_size[1]))
canvas.save("plots/best_batch.png")