In [1]:
import torch
import joblib
import numpy as np
from top_down_env import CraftaxTopDownEnv

Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.


In [2]:
class Standardizer:
    """Per-column z-score scaler with a small epsilon guarding zero variance.

    After ``fit``, ``mean`` and ``std`` hold the column statistics with the
    reduced axis kept (``keepdims=True``), so they broadcast against a 2-D
    input of the same width.
    """

    def __init__(self):
        # Both statistics stay ``None`` until ``fit`` has been called.
        self.mean = None
        self.std = None

    def fit(self, X):
        """Compute column statistics of ``X`` along axis 0 and return ``self``."""
        self.mean = X.mean(axis=0, keepdims=True)
        raw_std = X.std(axis=0, keepdims=True)
        # The epsilon keeps constant columns from producing a division by zero.
        self.std = raw_std + 1e-8
        return self

    def transform(self, X):
        """Return ``X`` centered by the stored mean and scaled by the stored std."""
        centered = X - self.mean
        return centered / self.std

    def fit_transform(self, X):
        """Fit on ``X`` and return the standardized ``X`` in one call."""
        self.fit(X)
        return self.transform(X)

class PolicyMLP(torch.nn.Module):
    """MLP that maps a feature vector to one logit per action.

    The ``backbone`` is a ``LayerNorm`` over the input followed by one
    ``Linear -> GELU -> Dropout`` group per entry of ``hidden_sizes``; the
    ``head`` is a final ``Linear`` producing ``n_actions`` outputs.
    """

    def __init__(self, d_in, n_actions=16, hidden_sizes=(384, 252), p_drop=0.1):
        super().__init__()
        nn = torch.nn
        # Consecutive pairs of ``widths`` are the (in, out) sizes of each hidden Linear.
        widths = [d_in, *hidden_sizes]
        modules = [nn.LayerNorm(d_in)]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.GELU(), nn.Dropout(p_drop)])
        self.backbone = nn.Sequential(*modules)
        # With no hidden layers the head reads directly from the normalized input.
        self.head = nn.Linear(widths[-1], n_actions)

    def forward(self, x):
        """Return the action logits for input ``x``."""
        features = self.backbone(x)
        return self.head(features)
    
    

In [3]:
# Load the feature scaler and the trained policy weights for the "wood" skill.
# NOTE(security): joblib.load unpickles arbitrary Python objects; only load
# artifacts from a trusted source. Presumably the pickle holds an instance of
# the Standardizer class defined in this notebook, in which case that class
# must be defined before this cell runs -- confirm against the training code.
scaler = joblib.load("Traces/Test/skill_models/scaler_wood.pkl")

# The policy's input width is the feature-axis length of the fitted mean.
d_in = scaler.mean.shape[1]
model = PolicyMLP(d_in=d_in, n_actions=16, hidden_sizes=(384, 252))

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all load_state_dict needs, so the checkpoint cannot run arbitrary
# code while being loaded.
model.load_state_dict(
    torch.load(
        "Traces/Test/skill_models/model_wood.pt",
        map_location="cpu",
        weights_only=True,
    )
)

# Switch to inference mode (disables dropout); as the cell's last expression
# this also displays the module summary.
model.eval()

PolicyMLP(
  (backbone): Sequential(
    (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=512, out_features=384, bias=True)
    (2): GELU(approximate='none')
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=384, out_features=252, bias=True)
    (5): GELU(approximate='none')
    (6): Dropout(p=0.1, inplace=False)
  )
  (head): Linear(in_features=252, out_features=16, bias=True)
)

In [4]:
# Load the dimensionality-reduction artifacts: a dict holding a centering
# scaler under 'scaler' and a fitted PCA under 'pca'.
# NOTE(security): joblib.load unpickles arbitrary objects -- only load files
# from a trusted source.
artifacts = joblib.load("Traces/Test/pca_model.joblib")
pca_scaler = artifacts['scaler']   # StandardScaler(with_std=False)
pca = artifacts['pca']         # PCA(n_components=512)
# Length of the scaler's fitted mean vector, i.e. the flattened input width
# it expects (presumably h * w * c of one observation image -- confirm).
n_features_expected = pca_scaler.mean_.shape[0]
print(f"Loaded model: PCA comps={pca.n_components_}, expected features={n_features_expected}")

Loaded model: PCA comps=512, expected features=225228


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [5]:
def predict_action(obs):
    """Return the greedy (argmax) action index for one image observation.

    The observation is flattened to a single row, centered with the global
    ``pca_scaler``, projected with the global ``pca``, standardized with the
    global ``scaler`` and fed to the global ``model``.

    Args:
        obs: A single image observation convertible to a 3-D float32 array
            (height, width, channels).

    Returns:
        int: Index of the highest-scoring action logit.

    Raises:
        ValueError: If ``obs`` is not a 3-D array.
    """
    img = np.asarray(obs, dtype=np.float32)
    if img.ndim != 3:
        raise ValueError(
            f"expected a single (h, w, c) image observation, got shape {img.shape}"
        )

    # Flatten the image into one sample row: shape (1, h * w * c).
    X = img.reshape(1, -1)

    # Center, then project onto the PCA components; stays one row per sample.
    X_centered = pca_scaler.transform(X)
    X_feats = pca.transform(X_centered)

    # Standardize the PCA features the same way the policy's inputs were scaled.
    state_std = scaler.transform(X_feats)
    x = torch.from_numpy(state_std).float()

    # Pure inference: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(x)  # shape (1, n_actions)
    return logits.argmax(dim=1).item()

In [6]:
# Upper bound on the number of environment steps taken in this rollout.
MAX_STEPS = 25

env = CraftaxTopDownEnv()
obs = env.reset()  # (274, 274, 3) image
done = False
total_reward = 0

# Keep a copy of every observation (including the initial one) for rendering.
all_obs = [obs.copy()]

for _ in range(MAX_STEPS):
    action = predict_action(obs)
    obs, reward, done, info = env.step(action)
    all_obs.append(obs.copy())
    total_reward += reward
    # Stop once the episode has ended: stepping a finished environment
    # without a reset is not meaningful and would pollute the recorded frames.
    if done:
        break


INFO:2025-09-11 19:20:38,601:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/damionharvey/miniconda3/envs/hisd/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)
2025-09-11 19:20:38,601 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/damionharvey/miniconda3/envs/hisd/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


Loading Craftax-Classic textures from cache.
Textures successfully loaded from cache.


In [7]:
import imageio

# Turn every recorded observation into 8-bit pixels (values are clamped to
# [0, 1] and then scaled by 255) and write the sequence as an animated GIF.
frames = [
    (np.clip(frame, 0, 1) * 255).astype(np.uint8)
    for frame in all_obs
]
imageio.mimsave("craftax_run.gif", frames, fps=5)