In [1]:
import os
# Select the GPU for this notebook. This runs in the first cell, ahead of the
# `import jax` in the next cell, so the settings are already in the process
# environment when the CUDA-backed libraries are imported.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # enumerate devices by PCI bus id
    "CUDA_VISIBLE_DEVICES": "3",         # expose only device index 3
})

In [2]:
import jax
import jax.numpy as jnp
import numpy as onp
import chex
from icecream import ic
from sentence_transformers import SentenceTransformer
from functools import partial
from typing import Optional, Tuple, Union, Any

from craftax.craftax_classic.constants import *
from gymnax.environments import environment, spaces
from sklearn.manifold import TSNE

from crafter_constants import blocks_labels, mobs_labels, inventory_labels

def embedding_crafter(embedding_model, next_obs):
    """Describe each flat observation in English and embed the descriptions.

    The observation is split into a ``7 * 9 * 21`` map part followed by
    metadata columns: 12 inventory, 4 intrinsics (health, food, drink,
    energy), 4 direction, 1 light level and the remaining is-sleeping part.
    Map channels are summed over the 7x9 window to get per-type counts; the
    first 17 channels are indexed into ``blocks_labels`` and the last 4 into
    ``mobs_labels``. Inventory and intrinsics are multiplied by 10 before
    use, so they are expected on a 0..1 scale.

    Args:
        embedding_model: object exposing ``encode(list_of_str)``.
        next_obs: array of shape ``(batch, 7*9*21 + metadata)``.

    Returns:
        ``(embeddings, sentences)`` where ``embeddings`` is
        ``jnp.array(embedding_model.encode(sentences))`` and ``sentences`` is
        the list of generated descriptions, one per batch row.
    """
    batch_size = next_obs.shape[0]
    maps, metadata = jnp.split(next_obs, [7 * 9 * 21], axis=1)

    # For each block/mob type, count how many are in frame.
    maps = jnp.reshape(maps, [-1, 7, 9, 21])
    maps = jnp.transpose(maps, [0, 3, 1, 2])
    maps = jnp.reshape(maps, [-1, 21, 7 * 9])
    counts = jnp.round(maps.sum(axis=2)).astype(jnp.int32)
    blocks, mobs = jnp.split(counts, [17], axis=1)

    # Extract and rescale metadata; direction, light level and sleeping flag
    # are not used in the description.
    inventory, intrinsics, _direction, _light_level, _is_sleeping = jnp.split(
        metadata, [12, 16, 20, 21], axis=1
    )
    inventory = jnp.round(inventory * 10.0).astype(jnp.int32)
    intrinsics = intrinsics * 10.0

    # Copy the small per-sample tables to host memory once, instead of
    # triggering a device sync for every argwhere/.item() call in the loop.
    blocks = onp.asarray(blocks)
    mobs = onp.asarray(mobs)
    inventory = onp.asarray(inventory)
    intrinsics = onp.asarray(intrinsics)

    full_threshold = 10.0 - 1e-4
    low_health_threshold = 5.0 - 1e-4

    sentences = []
    for row in range(batch_size):
        health, food, drink, energy = (float(v) for v in intrinsics[row])

        desc = ['You are playing a Minecraft-like survival game.']

        status = []
        if food < full_threshold:
            status.append("hungry")
        if drink < full_threshold:
            status.append("thirsty")
        if energy < full_threshold:
            status.append("tired")
        # Only mention feelings when there are any; otherwise the sentence
        # would degenerate to "You feel .".
        if status:
            desc.append('You feel {}.'.format(', '.join(status)))

        if health < low_health_threshold:
            desc.append('You are at low health.')
        elif health < full_threshold:
            desc.append('You are at moderate health.')
        else:
            desc.append('You are at full health.')

        visible_blocks = [blocks_labels[int(b)] for b in onp.flatnonzero(blocks[row] > 0)]
        visible_mobs = [mobs_labels[int(m)] for m in onp.flatnonzero(mobs[row] > 0)]
        held_items = [inventory_labels[int(k)] for k in onp.flatnonzero(inventory[row] > 0)]

        if visible_blocks:
            desc.append('You see {}.'.format(', '.join(visible_blocks)))
        if visible_mobs:
            desc.append('You see {}.'.format(', '.join(visible_mobs)))
        if held_items:
            desc.append('You have in your inventory {}.'.format(', '.join(held_items)))

        sentences.append(' '.join(desc))

    embeddings = embedding_model.encode(sentences)
    embeddings = jnp.array(embeddings)
    return embeddings, sentences

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from functools import partial
from jax import jit
from tqdm import tqdm

# @partial(jit, static_argnames=('batch_size',))
def obs_generator(key, batch_size=256):
    """Sample a batch of synthetic flat observations.

    Column layout of the result (width ``21*7*9 + 12 + 4 + 4 + 1 + 1``):
    map occupancy bits, inventory, intrinsics, one-hot direction, light
    level, is-sleeping flag.

    Inventory and intrinsics are emitted on a 0..1 scale (integer level / 10)
    because ``embedding_crafter`` multiplies both by 10 and treats 10 as
    "full"; emitting raw 0..10 values would make every non-zero intrinsic
    look full after that rescaling.

    Args:
        key: JAX PRNG key.
        batch_size: number of observations to sample.

    Returns:
        ``(key, obs)``: the advanced PRNG key and an array of shape
        ``(batch_size, 1345)``.
    """
    key, k_dir, k_map, k_inv, k_intr, k_light, k_sleep = jax.random.split(key, 7)

    direction = jax.random.randint(k_dir, shape=(batch_size,), minval=0, maxval=4)
    direction = jax.nn.one_hot(direction, num_classes=4)

    # Each map cell/channel is switched on independently with probability 0.03.
    block_mob_counts_1h = jax.random.bernoulli(k_map, 0.03, shape=(batch_size, 21 * 7 * 9))

    # Integer levels in 0..10, normalised to 0..1 to match the consumer's scale.
    inventory_levels = jnp.clip(jax.random.geometric(k_inv, 0.5, shape=(batch_size, 12)) - 1, 0, 10)
    inventory = inventory_levels / 10.0
    intrinsic_levels = jax.random.randint(k_intr, shape=(batch_size, 4), minval=0, maxval=11)
    intrinsics = intrinsic_levels / 10.0

    # NOTE(review): light level is sampled in [0, 10); nothing visible here
    # consumes it, so its intended scale could not be confirmed.
    light_level = jax.random.uniform(k_light, shape=(batch_size,), minval=0.0, maxval=10.0).reshape(batch_size, 1)
    is_sleeping = jax.random.randint(k_sleep, shape=(batch_size,), minval=0, maxval=2).reshape(batch_size, 1)

    obs = jnp.concatenate(
        [block_mob_counts_1h, inventory, intrinsics, direction, light_level, is_sleeping],
        axis=1,
    )
    return key, obs

In [4]:
# Sentence encoder instance; presumably the model handed to
# `embedding_crafter`, which calls `.encode(sentences)` on it.
EMBEDDING_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)



In [5]:
from models import Captioner
from structures import ReplayBuffer
import datetime
import optax

# Width of one flat observation: 7x9 map with 21 channels, then 12 inventory,
# 4 intrinsics, 4 direction, 1 light level and 1 is-sleeping column (= 1345).
OBS_DIM = 7 * 9 * 21 + 12 + 4 + 4 + 1 + 1
# Width of one stored sentence embedding.
EMBED_DIM = 768
CAPTIONER_HIDDEN_SIZE = 768
BUFFER_SIZE = 50000
# Fixed seed so that Restart & Run All reproduces the same initialisation
# (the seed was previously taken from the wall clock).
SEED = 0

dummy_obs = jnp.zeros((1, OBS_DIM))
key = jax.random.PRNGKey(SEED)
key, init_key = jax.random.split(key)

captioner = Captioner(
    hidden_size=CAPTIONER_HIDDEN_SIZE
)
captioner_params = captioner.init(init_key, dummy_obs)
captioner_opt = optax.adam(learning_rate=0.0001)
captioner_opt_state = captioner_opt.init(captioner_params)

# Replay buffer holding (observation, target embedding) pairs; the example
# entry fixes the per-item shapes.
buffer = ReplayBuffer.create({
    'obs': onp.zeros((OBS_DIM,)),
    'embed': onp.zeros((EMBED_DIM,)),
}, size=BUFFER_SIZE)

In [6]:
@jit
def update_captioner(captioner_params, captioner_opt_state, obs, emdeds):
    """Run one optimisation step of the captioner on a batch.

    The loss is the mean squared error between ``captioner.apply(params, obs)``
    and the target embeddings. Gradients are applied through the module-level
    ``captioner_opt`` optimiser.

    Args:
        captioner_params: current captioner parameters.
        captioner_opt_state: optimiser state matching ``captioner_params``.
        obs: batch of observations fed to the captioner.
        emdeds: batch of target embeddings (the misspelt name is kept so
            existing keyword callers keep working).

    Returns:
        ``(captioner_params, captioner_opt_state, loss)`` after the update.
    """
    def loss_fn(params, obs_batch, target_embeds):
        # Use the explicit argument rather than reading the outer variable
        # through the closure.
        preds = captioner.apply(params, obs_batch)
        return optax.losses.squared_error(preds, target_embeds).mean()

    loss, grads = jax.value_and_grad(loss_fn)(captioner_params, obs, emdeds)
    updates, captioner_opt_state = captioner_opt.update(grads, captioner_opt_state)
    captioner_params = optax.apply_updates(captioner_params, updates)
    return captioner_params, captioner_opt_state, loss

In [16]:
TRAIN_STEPS = 10000
TRAIN_BATCH_SIZE = 1000   # samples drawn from the replay buffer per step
LOG_EVERY = 100           # steps between loss reports

for i in tqdm(range(TRAIN_STEPS)):
    sample = buffer.sample(TRAIN_BATCH_SIZE)
    captioner_params, captioner_opt_state, loss = update_captioner(
        captioner_params, captioner_opt_state, sample['obs'], sample['embed']
    )
    if i % LOG_EVERY == 0:
        # tqdm.write keeps the message from breaking up the progress bar,
        # which a bare print() does.
        tqdm.write(f'Train step {i} Loss {loss!s}')

  0%|          | 44/10000 [00:00<00:45, 217.62it/s]

Train step 0 Loss 2.01869e-05


  1%|▏         | 136/10000 [00:00<00:46, 211.45it/s]

Train step 100 Loss 2.0874786e-05


  2%|▏         | 219/10000 [00:01<01:03, 152.92it/s]

Train step 200 Loss 1.975254e-05


  3%|▎         | 327/10000 [00:01<00:51, 188.23it/s]

Train step 300 Loss 1.8757755e-05


  4%|▍         | 442/10000 [00:02<00:44, 216.45it/s]

Train step 400 Loss 1.9784184e-05


  5%|▌         | 515/10000 [00:02<00:44, 212.50it/s]

Train step 500 Loss 2.0249136e-05


  6%|▋         | 630/10000 [00:03<00:44, 209.82it/s]

Train step 600 Loss 1.9359934e-05


  7%|▋         | 726/10000 [00:03<00:40, 227.73it/s]

Train step 700 Loss 1.9713678e-05


  8%|▊         | 839/10000 [00:04<00:43, 210.38it/s]

Train step 800 Loss 1.8827712e-05


  9%|▉         | 925/10000 [00:04<00:49, 182.28it/s]

Train step 900 Loss 1.9382896e-05


 10%|█         | 1028/10000 [00:05<00:44, 199.45it/s]

Train step 1000 Loss 1.9507488e-05


 11%|█▏        | 1148/10000 [00:05<00:37, 234.96it/s]

Train step 1100 Loss 1.9851861e-05


 12%|█▏        | 1218/10000 [00:05<00:44, 198.87it/s]

Train step 1200 Loss 1.896777e-05


 13%|█▎        | 1335/10000 [00:06<00:39, 221.60it/s]

Train step 1300 Loss 1.9029485e-05


 14%|█▍        | 1429/10000 [00:06<00:38, 224.95it/s]

Train step 1400 Loss 1.8856063e-05


 15%|█▌        | 1527/10000 [00:07<00:36, 232.71it/s]

Train step 1500 Loss 1.9207762e-05


 16%|█▋        | 1648/10000 [00:07<00:36, 228.07it/s]

Train step 1600 Loss 1.9256202e-05


 17%|█▋        | 1726/10000 [00:08<00:33, 246.02it/s]

Train step 1700 Loss 1.9150813e-05


 18%|█▊        | 1823/10000 [00:08<00:42, 193.74it/s]

Train step 1800 Loss 1.8969613e-05


 19%|█▉        | 1945/10000 [00:09<00:36, 221.34it/s]

Train step 1900 Loss 1.9622115e-05


 20%|██        | 2019/10000 [00:09<00:35, 227.13it/s]

Train step 2000 Loss 1.9182597e-05


 21%|██▏       | 2137/10000 [00:10<00:35, 224.10it/s]

Train step 2100 Loss 1.9399356e-05


 22%|██▏       | 2232/10000 [00:10<00:33, 228.66it/s]

Train step 2200 Loss 1.9443363e-05


 23%|██▎       | 2306/10000 [00:10<00:39, 193.96it/s]

Train step 2300 Loss 1.8958872e-05


 24%|██▍       | 2425/10000 [00:11<00:33, 227.60it/s]

Train step 2400 Loss 1.8432671e-05


 25%|██▌       | 2544/10000 [00:11<00:35, 210.35it/s]

Train step 2500 Loss 1.8608629e-05


 26%|██▋       | 2641/10000 [00:12<00:31, 232.83it/s]

Train step 2600 Loss 1.9020958e-05


 27%|██▋       | 2743/10000 [00:12<00:29, 245.19it/s]

Train step 2700 Loss 1.8920744e-05


 28%|██▊       | 2843/10000 [00:13<00:29, 240.27it/s]

Train step 2800 Loss 1.8157318e-05


 29%|██▉       | 2944/10000 [00:13<00:29, 239.19it/s]

Train step 2900 Loss 1.945733e-05


 30%|███       | 3046/10000 [00:14<00:29, 239.09it/s]

Train step 3000 Loss 1.8965893e-05


 31%|███▏      | 3142/10000 [00:14<00:30, 224.44it/s]

Train step 3100 Loss 1.888853e-05


 32%|███▏      | 3236/10000 [00:14<00:29, 232.56it/s]

Train step 3200 Loss 1.8435461e-05


 33%|███▎      | 3332/10000 [00:15<00:30, 220.44it/s]

Train step 3300 Loss 1.8081593e-05


 34%|███▍      | 3422/10000 [00:15<00:30, 213.92it/s]

Train step 3400 Loss 1.8576937e-05


 35%|███▌      | 3534/10000 [00:16<00:30, 215.44it/s]

Train step 3500 Loss 1.9046453e-05


 36%|███▋      | 3632/10000 [00:16<00:27, 235.83it/s]

Train step 3600 Loss 1.8475024e-05


 37%|███▋      | 3729/10000 [00:17<00:28, 218.03it/s]

Train step 3700 Loss 1.9185982e-05


 38%|███▊      | 3843/10000 [00:17<00:27, 220.15it/s]

Train step 3800 Loss 1.8562992e-05


 39%|███▉      | 3934/10000 [00:18<00:27, 219.65it/s]

Train step 3900 Loss 1.9420844e-05


 40%|████      | 4045/10000 [00:18<00:27, 215.04it/s]

Train step 4000 Loss 1.8946941e-05


 41%|████▏     | 4141/10000 [00:19<00:25, 228.89it/s]

Train step 4100 Loss 1.8459725e-05


 42%|████▏     | 4237/10000 [00:19<00:27, 210.49it/s]

Train step 4200 Loss 1.8125787e-05


 43%|████▎     | 4323/10000 [00:19<00:27, 205.20it/s]

Train step 4300 Loss 1.8334234e-05


 44%|████▍     | 4436/10000 [00:20<00:26, 208.72it/s]

Train step 4400 Loss 1.8250405e-05


 45%|████▌     | 4530/10000 [00:20<00:25, 216.03it/s]

Train step 4500 Loss 1.8502915e-05


 46%|████▌     | 4619/10000 [00:21<00:31, 170.08it/s]

Train step 4600 Loss 1.8525365e-05


 47%|████▋     | 4725/10000 [00:21<00:27, 191.61it/s]

Train step 4700 Loss 1.8302146e-05


 48%|████▊     | 4824/10000 [00:22<00:22, 233.63it/s]

Train step 4800 Loss 1.7881868e-05


 49%|████▉     | 4943/10000 [00:22<00:22, 227.27it/s]

Train step 4900 Loss 1.9076047e-05


 50%|█████     | 5041/10000 [00:23<00:21, 235.07it/s]

Train step 5000 Loss 1.8128707e-05


 51%|█████▏    | 5133/10000 [00:23<00:23, 205.86it/s]

Train step 5100 Loss 1.8070363e-05


 52%|█████▏    | 5226/10000 [00:24<00:21, 222.45it/s]

Train step 5200 Loss 1.887354e-05


 53%|█████▎    | 5321/10000 [00:24<00:20, 225.19it/s]

Train step 5300 Loss 1.8553541e-05


 54%|█████▍    | 5442/10000 [00:25<00:19, 228.80it/s]

Train step 5400 Loss 1.8838877e-05


 55%|█████▌    | 5529/10000 [00:25<00:22, 195.02it/s]

Train step 5500 Loss 1.8142733e-05


 56%|█████▋    | 5637/10000 [00:26<00:20, 207.93it/s]

Train step 5600 Loss 1.8524139e-05


 57%|█████▋    | 5726/10000 [00:26<00:19, 214.34it/s]

Train step 5700 Loss 1.8564462e-05


 58%|█████▊    | 5842/10000 [00:27<00:18, 227.68it/s]

Train step 5800 Loss 1.7764933e-05


 59%|█████▉    | 5938/10000 [00:27<00:19, 210.85it/s]

Train step 5900 Loss 1.7966693e-05


 60%|██████    | 6027/10000 [00:27<00:19, 208.96it/s]

Train step 6000 Loss 1.8356357e-05


 61%|██████▏   | 6138/10000 [00:28<00:18, 206.68it/s]

Train step 6100 Loss 1.7717863e-05


 62%|██████▏   | 6228/10000 [00:28<00:18, 204.34it/s]

Train step 6200 Loss 1.8203922e-05


 63%|██████▎   | 6346/10000 [00:29<00:16, 226.43it/s]

Train step 6300 Loss 1.8283097e-05


 64%|██████▍   | 6424/10000 [00:29<00:15, 223.74it/s]

Train step 6400 Loss 1.8116196e-05


 65%|██████▌   | 6539/10000 [00:30<00:16, 214.80it/s]

Train step 6500 Loss 1.7616805e-05


 66%|██████▋   | 6625/10000 [00:30<00:18, 183.37it/s]

Train step 6600 Loss 1.8399214e-05


 67%|██████▋   | 6736/10000 [00:31<00:16, 194.84it/s]

Train step 6700 Loss 1.873966e-05


 68%|██████▊   | 6826/10000 [00:31<00:17, 185.86it/s]

Train step 6800 Loss 1.8912964e-05


 69%|██████▉   | 6939/10000 [00:32<00:14, 212.67it/s]

Train step 6900 Loss 1.8238103e-05


 70%|███████   | 7028/10000 [00:32<00:13, 212.43it/s]

Train step 7000 Loss 1.8276749e-05


 71%|███████▏  | 7130/10000 [00:33<00:11, 243.28it/s]

Train step 7100 Loss 1.885458e-05


 72%|███████▏  | 7226/10000 [00:33<00:16, 169.44it/s]

Train step 7200 Loss 1.8413943e-05


 73%|███████▎  | 7348/10000 [00:34<00:11, 227.98it/s]

Train step 7300 Loss 1.8736857e-05


 74%|███████▍  | 7421/10000 [00:34<00:14, 178.92it/s]

Train step 7400 Loss 1.8658373e-05


 75%|███████▌  | 7512/10000 [00:35<00:12, 193.65it/s]

Train step 7500 Loss 1.8736937e-05


 77%|███████▋  | 7655/10000 [00:35<00:09, 238.15it/s]

Train step 7600 Loss 1.8006864e-05


 77%|███████▋  | 7731/10000 [00:36<00:09, 238.13it/s]

Train step 7700 Loss 1.8187095e-05


 78%|███████▊  | 7836/10000 [00:36<00:08, 253.80it/s]

Train step 7800 Loss 1.8164925e-05


 79%|███████▉  | 7935/10000 [00:37<00:10, 195.85it/s]

Train step 7900 Loss 1.845243e-05


 80%|████████  | 8029/10000 [00:37<00:09, 201.38it/s]

Train step 8000 Loss 1.9455363e-05


 81%|████████▏ | 8136/10000 [00:38<00:09, 198.51it/s]

Train step 8100 Loss 1.8083303e-05


 82%|████████▏ | 8236/10000 [00:38<00:09, 182.41it/s]

Train step 8200 Loss 1.8558281e-05


 83%|████████▎ | 8334/10000 [00:39<00:07, 215.86it/s]

Train step 8300 Loss 1.8723033e-05


 84%|████████▍ | 8427/10000 [00:39<00:07, 209.58it/s]

Train step 8400 Loss 1.8314362e-05


 85%|████████▌ | 8526/10000 [00:40<00:08, 180.45it/s]

Train step 8500 Loss 1.8099629e-05


 86%|████████▋ | 8643/10000 [00:40<00:06, 222.47it/s]

Train step 8600 Loss 1.8583849e-05


 87%|████████▋ | 8738/10000 [00:41<00:05, 217.86it/s]

Train step 8700 Loss 1.8469302e-05


 88%|████████▊ | 8830/10000 [00:41<00:05, 211.89it/s]

Train step 8800 Loss 1.8160452e-05


 89%|████████▉ | 8918/10000 [00:42<00:05, 180.74it/s]

Train step 8900 Loss 1.82797e-05


 90%|█████████ | 9034/10000 [00:42<00:04, 202.23it/s]

Train step 9000 Loss 1.8196148e-05


 91%|█████████▏| 9144/10000 [00:43<00:03, 217.97it/s]

Train step 9100 Loss 1.8676392e-05


 92%|█████████▏| 9232/10000 [00:43<00:03, 212.29it/s]

Train step 9200 Loss 1.7814125e-05


 93%|█████████▎| 9337/10000 [00:44<00:03, 189.39it/s]

Train step 9300 Loss 1.8325576e-05


 94%|█████████▍| 9446/10000 [00:44<00:02, 213.02it/s]

Train step 9400 Loss 1.8487115e-05


 95%|█████████▌| 9532/10000 [00:44<00:02, 201.59it/s]

Train step 9500 Loss 1.7818133e-05


 96%|█████████▌| 9613/10000 [00:45<00:02, 160.93it/s]

Train step 9600 Loss 1.8158336e-05


 97%|█████████▋| 9715/10000 [00:46<00:01, 188.47it/s]

Train step 9700 Loss 1.8392893e-05


 98%|█████████▊| 9817/10000 [00:46<00:00, 197.89it/s]

Train step 9800 Loss 1.8634928e-05


 99%|█████████▉| 9930/10000 [00:47<00:00, 174.17it/s]

Train step 9900 Loss 1.8179519e-05


100%|██████████| 10000/10000 [00:47<00:00, 209.70it/s]


In [17]:
# Persist the trained captioner parameters to 'crafter_captioner.npy' in the
# current working directory.
# NOTE(review): captioner_params comes from captioner.init / optax.apply_updates,
# so it is presumably a parameter pytree rather than a single array; saving it
# this way likely relies on pickling inside the .npy file — confirm.
jnp.save('crafter_captioner.npy', captioner_params)

In [50]:
# Reload the parameters written by the save cell.
# allow_pickle=True unpickles arbitrary objects: only load files you trust.
# NOTE(review): if the saved object was a dict-like pytree, this presumably
# returns a 0-d object array wrapping it rather than the pytree itself, in
# which case `.item()` is needed before passing it to `captioner.apply` —
# confirm against downstream usage.
lparams = jnp.load('crafter_captioner.npy', allow_pickle=True)