In [1]:
import jax
import jax.numpy as jnp
import numpy as onp
import chex
from icecream import ic
from sentence_transformers import SentenceTransformer
from functools import partial
from typing import Optional, Tuple, Union, Any

from craftax.craftax_classic.constants import *
from gymnax.environments import environment, spaces
from sklearn.manifold import TSNE

from crafter_constants import blocks_labels, mobs_labels, inventory_labels

def _describe_crafter_row(blocks, mobs, inventory, intrinsics, max_stat, low_health, eps=1e-4):
    """Return the English description of a single observation row.

    ``blocks``/``mobs``/``inventory`` are 1-D host integer arrays of per-type counts;
    ``intrinsics`` is a length-4 host float array (health, food, drink, energy)
    already rescaled by 10. ``eps`` absorbs float rounding in the rescaled values.
    """
    health, food, drink, energy = (float(v) for v in intrinsics)

    desc = []

    # Only emit the status sentence when something applies, so the output never
    # contains the malformed "You feel ." (consistent with the block/mob/inventory
    # sentences below, which are also omitted when empty).
    status = [
        word
        for value, word in ((food, "hungry"), (drink, "thirsty"), (energy, "tired"))
        if value < max_stat - eps
    ]
    if status:
        desc.append('You feel {}.'.format(', '.join(status)))

    if health < low_health - eps:
        desc.append('You are at low health.')
    elif health < max_stat - eps:
        desc.append('You are at moderate health.')
    else:
        desc.append('You are at full health.')

    # .tolist() yields plain Python ints, which work as list indices and dict keys alike.
    seen_blocks = onp.flatnonzero(blocks > 0).tolist()
    seen_mobs = onp.flatnonzero(mobs > 0).tolist()
    held_items = onp.flatnonzero(inventory > 0).tolist()

    if seen_blocks:
        desc.append('You see {}.'.format(', '.join(blocks_labels[b] for b in seen_blocks)))
    if seen_mobs:
        desc.append('You see {}.'.format(', '.join(mobs_labels[m] for m in seen_mobs)))
    if held_items:
        desc.append('You have in your inventory {}.'.format(
            ', '.join(inventory_labels[k] for k in held_items)))

    return ' '.join(desc)


def embedding_crafter(embedding_model, next_obs, max_stat: float = 10.0, low_health: float = 5.0):
    """Describe each Craftax-Classic symbolic observation in English and embed it.

    Args:
        embedding_model: object with an ``encode(list[str])`` method (e.g. a
            SentenceTransformer); its result is wrapped in ``jnp.array``.
        next_obs: 2-D array ``(batch, 7*9*21 + metadata)``: a flattened 7x9 map with
            21 channels (17 block types then mob types), followed by inventory (12),
            intrinsics (4: health, food, drink, energy), direction (4), light level (1)
            and is-sleeping (remaining columns; presumably 1).
        max_stat: food/drink/energy below this are reported as hungry/thirsty/tired,
            and health at or above it is "full". NOTE(review): if Craftax-Classic caps
            these stats at 9 (observation stores value/10), the default of 10 makes
            every agent always hungry/thirsty/tired and never at full health; confirm
            and pass ``max_stat=9.0`` if so.
        low_health: health below this is reported as low.

    Returns:
        ``(embeddings, sentences)``: ``jnp.array(embedding_model.encode(sentences))``
        and the list of per-row descriptions.
    """
    batch_size = next_obs.shape[0]
    maps, metadata = jnp.split(next_obs, [7 * 9 * 21], axis=1)

    # Count how many of the 7x9 view cells contain each of the 21 channel types:
    # (B, 7*9*21) -> (B, 63, 21) -> sum over cells -> (B, 21).
    counts = jnp.reshape(maps, [-1, 7 * 9, 21]).sum(axis=1)
    counts = jnp.round(counts).astype(jnp.int32)
    blocks, mobs = jnp.split(counts, [17], axis=1)  # first 17 channels are blocks, the rest mobs

    inventory, intrinsics, _direction, _light_level, _is_sleeping = jnp.split(
        metadata, [12, 16, 20, 21], axis=1)
    # Presumably stored as value/10 in the observation; rescale back to counts/points.
    inventory = jnp.round(inventory * 10.0).astype(jnp.int32)
    intrinsics = intrinsics * 10.0

    # One device->host transfer per array instead of per-element JAX dispatch in the loop.
    blocks_h, mobs_h, inventory_h, intrinsics_h = (
        onp.asarray(a) for a in (blocks, mobs, inventory, intrinsics))

    sentences = [
        _describe_crafter_row(blocks_h[r], mobs_h[r], inventory_h[r], intrinsics_h[r],
                              max_stat, low_health)
        for r in range(batch_size)
    ]

    embeddings = jnp.array(embedding_model.encode(sentences))
    return embeddings, sentences

  from .autonotebook import tqdm as notebook_tqdm
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated CUDA installation found.
Version JAX was built against: 11080
Minimum supported: 12010
Installed version: 11080
The local installation version must be no lower than 12010.
--------------------------------------------------
Outdated cuBLAS installation found.
Version JAX was built against: 111103
Minimum supported: 120100
Installed version: 111103
The local installation version must be no lower than 120100.
--------------------------------------------------
Outdated cuSPARSE installation found.
Version JAX was built against: 11705
Minimum supported: 12100
Installed version: 11705
The local installation version must be no lower than 12100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Hugging Face id of the pretrained sentence encoder; the loaded model is intended
# to be passed as `embedding_model` to `embedding_crafter`, which calls its `.encode`.
SENTENCE_MODEL_ID = 'sentence-transformers/all-mpnet-base-v2'
embedding_model = SentenceTransformer(SENTENCE_MODEL_ID)

