In [8]:
# import torch
import joblib
import numpy as np
import os 
import os
import numpy as np
import argparse
# from sklearn.svm import OneClassSVM
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
# from sklearn.model_selection import train_test_split
# from sklearn.metrics import precision_recall_fscore_support, classification_report
# from sklearn.metrics import confusion_matrix
import joblib
from skill_helpers import * 
import sys 

from top_down_env import CraftaxTopDownEnv
import os
import gzip
import pickle
from pathlib import Path
from tqdm import tqdm
import numpy as np
import joblib
import argparse
import imageio

In [9]:
import torch
import torch.nn.functional as F
import numpy as np

# --- must match your training definitions ---
class ImageNormalizer:
    """Per-channel standardisation for a 3-channel image tensor.

    ``mean`` and ``std`` are stored as float32 tensors of shape [3,1,1] so that
    they broadcast over a [3,H,W] image. ``std`` is floored at 1e-3 so the
    division in ``__call__`` never divides by (near) zero.

    Args:
        mean: three per-channel means (sequence, numpy array or tensor).
        std: three per-channel standard deviations (same accepted types).
    """
    def __init__(self, mean, std):
        self.mean = self._to_channel_stat(mean)
        self.std  = torch.clamp(self._to_channel_stat(std), min=1e-3)

    @staticmethod
    def _to_channel_stat(values):
        """Return ``values`` as a detached float32 tensor of shape [3,1,1]."""
        # as_tensor accepts an existing tensor without the copy-construct
        # warning that torch.tensor(tensor) emits; clone() keeps the stored
        # statistics independent of the caller's buffer.
        return torch.as_tensor(values, dtype=torch.float32).detach().clone().reshape(3, 1, 1)

    def __call__(self, x):  # x: [3,H,W] in [0,1]
        # The statistics may live on a different device than the image (for
        # example when they were loaded as tensors with a map_location), so
        # align them with x before broadcasting.
        mean = self.mean.to(x.device)
        std = self.std.to(x.device)
        return (x - mean) / std

class ConvBlock(torch.nn.Module):
    """Bias-free 2D convolution followed by batch normalisation and GELU.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
    """
    def __init__(self, c_in, c_out, k=3, s=1, p=1):
        super().__init__()
        # Attribute names (conv / bn / act) determine the state_dict keys, so
        # they have to stay in sync with whatever produced the checkpoints.
        self.conv = torch.nn.Conv2d(
            in_channels=c_in,
            out_channels=c_out,
            kernel_size=k,
            stride=s,
            padding=p,
            bias=False,
        )
        # Swap for GroupNorm here if the training definition was changed.
        self.bn = torch.nn.BatchNorm2d(num_features=c_out)
        self.act = torch.nn.GELU()

    def forward(self, x):
        features = self.conv(x)
        features = self.bn(features)
        return self.act(features)

class PolicyCNN(torch.nn.Module):
    """Convolutional policy network mapping an image batch to action logits.

    The network is a stem (stride-2 7x7 ConvBlock, 3x3 ConvBlock, 2x2 max-pool)
    followed by three two-ConvBlock stages (64, 128 and 256 channels; the first
    two end in a 2x2 max-pool), global average pooling and a linear head.

    Args:
        n_actions: size of the output logit vector.
    """
    def __init__(self, n_actions=16):
        super().__init__()
        # Submodule names and ordering define the state_dict layout and must
        # match the training-time definition.
        self.stem = torch.nn.Sequential(
            ConvBlock(3, 32, k=7, s=2, p=3),
            ConvBlock(32, 32),
            torch.nn.MaxPool2d(2),
        )
        self.stage2 = self._make_stage(32, 64, pool=True)
        self.stage3 = self._make_stage(64, 128, pool=True)
        self.stage4 = self._make_stage(128, 256, pool=False)
        self.head = torch.nn.Linear(256, n_actions)

    @staticmethod
    def _make_stage(c_in, c_out, pool):
        """Two ConvBlocks (c_in->c_out, c_out->c_out), optionally followed by MaxPool2d(2)."""
        layers = [ConvBlock(c_in, c_out), ConvBlock(c_out, c_out)]
        if pool:
            layers.append(torch.nn.MaxPool2d(2))
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        for stage in (self.stem, self.stage2, self.stage3, self.stage4):
            x = stage(x)
        # Collapse the spatial dimensions to 1x1, then drop them: [B,256,1,1] -> [B,256].
        pooled = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.head(pooled)

# ---- inference helpers ----

def load_policy(ckpt_path, device=None):
    """Restore a policy network and its image normalizer from a checkpoint.

    The checkpoint is read with ``torch.load`` and must provide the keys
    ``'n_actions'``, ``'state_dict'``, ``'mean'`` and ``'std'``.

    Args:
        ckpt_path: path of the checkpoint file.
        device: device to load onto; when falsy, CUDA is used if available,
            otherwise CPU.

    Returns:
        Tuple ``(model, normalizer, device, n_actions)`` with the model
        already moved to ``device`` and switched to eval mode.
    """
    if not device:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): torch.load deserialises the file (pickle-based); only point
    # this at checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    n_actions = int(checkpoint['n_actions'])

    policy = PolicyCNN(n_actions=n_actions)
    policy = policy.to(device)
    policy.load_state_dict(checkpoint['state_dict'])
    policy.eval()

    frame_normalizer = ImageNormalizer(checkpoint['mean'], checkpoint['std'])
    return policy, frame_normalizer, device, n_actions

def preprocess_frame(frame_hw3, normalizer, target=256):
    """
    Convert one image frame into a normalised, batched model input.

    frame_hw3: array-like [H,W,3], expected float32 in [0,1] (values are cast
               to float but not rescaled)
    normalizer: callable applied to the resized [3,target,target] tensor
    target: side length of the square output image
    returns torch tensor [1,3,target,target]
    """
    # np.asarray lets array-likes (nested lists, other array types) through and
    # is a no-op for an ndarray.
    frame = np.asarray(frame_hw3)
    assert frame.ndim == 3 and frame.shape[2] == 3, (
        f"expected an [H,W,3] frame, got shape {frame.shape}"
    )
    # torch.from_numpy rejects arrays with negative strides (e.g. a flipped
    # frame), so materialise the channel-first view as a contiguous array.
    chw = np.ascontiguousarray(np.transpose(frame, (2, 0, 1)))
    x = torch.from_numpy(chw).float()   # [3,H,W]
    x = F.interpolate(x.unsqueeze(0), size=(target, target), mode='bilinear', align_corners=False).squeeze(0)  # [3,T,T]
    x = normalizer(x)
    return x.unsqueeze(0)  # [1,3,T,T]

@torch.no_grad()
def act_greedy(model, normalizer, device, frame_hw3):
    """
    Pick the highest-probability action for a single frame.

    The frame is preprocessed with ``preprocess_frame``, moved to ``device``
    and passed through ``model``; a softmax over the logits gives the action
    distribution.

    Returns (action_id, probs) where action_id is a Python int and probs is a
    numpy array of length n_actions.
    """
    batch = preprocess_frame(frame_hw3, normalizer).to(device)   # [1,3,256,256]
    probs = torch.softmax(model(batch), dim=-1).squeeze(0)       # [n_actions]
    best = torch.argmax(probs)
    return int(best.item()), probs.cpu().numpy()

@torch.no_grad()
def act_sample(model, normalizer, device, frame_hw3, temperature=1.0):
    """
    Sample an action for a single frame from the policy's softmax distribution.

    When ``temperature`` differs from 1.0 the logits are divided by it first
    (floored at 1e-6 to avoid dividing by zero or a negative value).

    Returns (action_id, probs) where action_id is a Python int and probs is the
    numpy array of probabilities the action was drawn from.
    """
    batch = preprocess_frame(frame_hw3, normalizer)
    logits = model(batch.to(device)).squeeze(0)
    if temperature != 1.0:
        safe_temperature = max(1e-6, float(temperature))
        logits = logits / safe_temperature
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)
    return int(sampled.item()), probs.cpu().numpy()

In [10]:
# Load the fitted preprocessing artifacts used by get_pca_feat below.
# NOTE(review): joblib.load unpickles the file — only load artifacts from a
# trusted source, ideally with the same scikit-learn version that saved them.
artifacts = joblib.load('Traces/stone_pickaxe_easy/pca_models/pca_model_750.joblib')

# Per the original notes: scaler is a StandardScaler(with_std=False) and
# pca is a PCA(n_components=512) — confirm against the training script.
scaler, pca = artifacts['scaler'], artifacts['pca']

# Number of raw (flattened) features the scaler was fitted on.
n_features_expected = scaler.mean_.shape[0]

def get_pca_feat(obs):
    """Project a single observation through the module-level scaler and PCA.

    ``obs`` is converted to float32 and flattened into one row of shape
    (1, n_features); the row is passed through ``scaler.transform`` and then
    ``pca.transform``, and the PCA output is returned.
    """
    flat_row = np.asarray(obs, dtype=np.float32).reshape(1, -1)
    return pca.transform(scaler.transform(flat_row))

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [11]:
# Skills for which a behaviour-cloning checkpoint is loaded.
skill_list = ['wood', 'stone', 'wood_pickaxe', 'stone_pickaxe', 'table']

# skill name -> (model, normalizer, device, n_actions) as returned by load_policy
bc_models = {}
for skill in skill_list:
    ckpt_name = f'{skill}_policy_cnn.pt'
    ckpt_path = os.path.join('Traces/stone_pickaxe_easy', 'bc_checkpoints', ckpt_name)
    bc_models[skill] = load_policy(ckpt_path)

In [12]:
# Load the PU "start" models from the given directory.
# NOTE(review): load_pu_models is not defined in this notebook; presumably it is
# provided by the `from skill_helpers import *` star import — confirm.
pu_start_models = load_pu_models('Traces/stone_pickaxe_easy/pu_start_models')

In [13]:
# env = CraftaxTopDownEnv(seed=1)
# obs = env.reset()
# done = False
# total_reward = 0

# all_obs = [obs.copy()]
# all_actions = []

# end_models_path = 'Traces/stone_pickaxe_easy/pu_end_models'

# MAX_STEPS = 100
# MAX_STEPS_PER_SKILL = 200  # optional guard; tune or remove
# t = 0

# def pick_executable_skill(pu_start_models, pca_obs, end_models_path):
#     """
#     Return the first candidate skill whose end condition is NOT true on pca_obs.
#     If no candidate is executable, return None.
#     """
#     candidates = applicable_pu_start_models(
#         pu_start_models, pca_obs, return_details=False
#     )
#     if not candidates:
#         return None

#     # Iterate candidates in the provided (assumed ranked) order
#     for cand in candidates:
#         is_end = end_state_prob_pu(end_models_path, cand, pca_obs)['is_end']
#         if not is_end:
#             return cand

#     # All candidates already 'ended' for current state
#     return None

# while t < MAX_STEPS and not done:
#     # 1) Choose a skill that is actually executable from the current state
#     pca_obs = get_pca_feat(obs)
#     skill_name = pick_executable_skill(pu_start_models, pca_obs, end_models_path)

#     if skill_name is None:
#         print("No executable skill found; breaking to avoid spin.")
#         break

#     # 2) Execute the chosen skill until its end condition fires (or step caps/done)
#     skill_steps = 0
#     while t < MAX_STEPS and not done:
#         print("Executing skill:", skill_name)
#         print("Step:", t)

#         pca_obs = get_pca_feat(obs)
#         if end_state_prob_pu(end_models_path, skill_name, pca_obs)['is_end']:
#             print("Skill ended:", skill_name)
#             break

#         # Retrieve policy components for this skill
#         model, normalizer, device, n_actions = bc_models[skill_name]

#         # IMPORTANT: Ensure act_greedy expects 'obs' (raw image) vs transformed features.
#         # If it needs normalized PCA features, replace 'obs' with that here.
#         action, probs = act_greedy(model, normalizer, device, obs)

#         # Step the environment
#         obs, reward, done, info = env.step(action)
#         all_obs.append(obs.copy())
#         all_actions.append(action)
#         total_reward += reward

#         t += 1
#         skill_steps += 1

#         # Optional per-skill safety cap
#         if skill_steps >= MAX_STEPS_PER_SKILL:
#             print(f"Per-skill step cap hit for '{skill_name}'. Re-selecting skill.")
#             break

# print("Total steps:", t, "Total reward:", total_reward)
# # Save observations as a GIF
# frames = [(np.clip(f, 0, 1) * 255).astype(np.uint8) for f in all_obs]
# imageio.mimsave(f"craftax_run_auto_{1}_{total_reward}.gif", frames, fps=5)

In [15]:
# Scripted evaluation: for every seed, run a fixed sequence of skills with the
# greedy BC policies and save the visited observations as a GIF.
#
# Relies on names defined in earlier cells:
# - CraftaxTopDownEnv, bc_models, act_greedy
# - (Optional) get_pca_feat if your policy wants features instead of raw obs

seeds = [6, 24, 3214, 32, 321, 4324, 165, 333]

# Rollout settings are identical for every seed, so they are defined once here.
# Fixed skill schedule (up to STEPS_PER_SKILL steps each)
skill_order = [
    "wood",
    "wood",
    "table",
    "wood",
    "wood_pickaxe",
    "stone",
    "wood",
    "stone_pickaxe",
]
STEPS_PER_SKILL = 8
MAX_STEPS = 120  # optional global cap across all skills of one episode

for seed in seeds:
    env = CraftaxTopDownEnv(seed=seed)
    obs = env.reset()
    done, total_reward, t = False, 0, 0

    # Episode history: every observation (including the initial one) and every action.
    all_obs, all_actions = [obs.copy()], []

    for skill_name in skill_order:
        # Stop scheduling further skills once the episode ended or the cap is hit.
        if done or t >= MAX_STEPS:
            break

        # Skip skills that have no loaded policy.
        if skill_name not in bc_models:
            print(f"[WARN] No bc_model found for '{skill_name}'. Skipping.")
            continue

        model, normalizer, device, n_actions = bc_models[skill_name]

        print(f"\n=== Executing skill: {skill_name} ===")
        steps_this_skill = 0

        while (not done) and (t < MAX_STEPS) and (steps_this_skill < STEPS_PER_SKILL):
            print("Skill:", skill_name, "| Step:", t)

            # The policy is fed the raw observation. If it expected features instead:
            # feats = get_pca_feat(obs)
            # action, probs = act_greedy(model, normalizer, device, feats)
            action, probs = act_greedy(model, normalizer, device, obs)

            obs, reward, done, info = env.step(action)
            all_obs.append(obs.copy())
            all_actions.append(action)
            total_reward += reward

            t += 1
            steps_this_skill += 1

    print("\nTotal steps:", t, "Total reward:", total_reward)

    # Save observations as a GIF (frames clipped to [0,1] and scaled to uint8).
    frames = [(np.clip(frame, 0, 1) * 255).astype(np.uint8) for frame in all_obs]
    imageio.mimsave(f"craftax_run_test_seed_{seed}_{total_reward}.gif", frames, fps=5)


=== Executing skill: wood ===
Skill: wood | Step: 0
Skill: wood | Step: 1
Skill: wood | Step: 2
Skill: wood | Step: 3
Skill: wood | Step: 4
Skill: wood | Step: 5
Skill: wood | Step: 6
Skill: wood | Step: 7

=== Executing skill: wood ===
Skill: wood | Step: 8
Skill: wood | Step: 9
Skill: wood | Step: 10
Skill: wood | Step: 11
Skill: wood | Step: 12
Skill: wood | Step: 13
Skill: wood | Step: 14
Skill: wood | Step: 15

=== Executing skill: table ===
Skill: table | Step: 16
Skill: table | Step: 17
Skill: table | Step: 18
Skill: table | Step: 19
Skill: table | Step: 20
Skill: table | Step: 21
Skill: table | Step: 22
Skill: table | Step: 23

=== Executing skill: wood ===
Skill: wood | Step: 24
Skill: wood | Step: 25
Skill: wood | Step: 26
Skill: wood | Step: 27
Skill: wood | Step: 28
Skill: wood | Step: 29
Skill: wood | Step: 30
Skill: wood | Step: 31

=== Executing skill: wood_pickaxe ===
Skill: wood_pickaxe | Step: 32
Skill: wood_pickaxe | Step: 33
Skill: wood_pickaxe | Step: 34
Skill: wo