In [None]:
import torch

from customize_minigrid.custom_env import CustomEnv
import matplotlib.pyplot as plt
import numpy as np
from mdp_learner import OneHotEncodingMDPLearner
from customize_minigrid.wrappers import FullyObsSB3MLPWrapper
from mdp_graph.mdp_graph import MDPGraph, PolicyGraph, OptimalPolicyGraph

In [None]:
# Single-map setup: the short corridor map, agent starting at (1, 1) with
# direction 0, and no random rotation or flipping of the layout.
file_path = r"./maps/short_corridor.txt"
env = CustomEnv(
    txt_file_path=file_path,
    rand_gen_shape=None,
    display_size=None,
    display_mode="middle",
    agent_start_pos=(1, 1),
    agent_start_dir=0,
    random_rotate=False,
    random_flip=False,
    custom_mission="Find the key and open the door.",
    render_mode=None,
)
env.reset()

# Render the initial frame at the env's own tile size.
# NOTE(review): the first positional arg of get_frame is presumably `highlight`
# (MiniGrid convention) — confirm against CustomEnv.
img = env.get_frame(False, env.tile_size, env.agent_pov)
plt.imshow(img)
plt.show()

In [None]:
from customize_minigrid.wrappers import FullyObsImageWrapper

# Wrap only once: without this guard, re-running the cell would stack a second
# FullyObsImageWrapper on top of the already-wrapped `env`.
if not isinstance(env, FullyObsImageWrapper):
    env = FullyObsImageWrapper(env, to_print=False)

# Learn the one-hot MDP of the corridor; the following cells read
# learner.mdp_graph, learner.start_state and learner.done_states.
learner = OneHotEncodingMDPLearner(env)
learner.learn()

In [None]:
# Load the learned MDP into an OptimalPolicyGraph and initialise it with a
# uniform prior policy before drawing it.
optimal_graph = OptimalPolicyGraph()
optimal_graph.load_graph(learner.mdp_graph)
optimal_graph.uniform_prior_policy()

# Highlight the start state followed by every done state.
highlighted_states = [learner.start_state]
highlighted_states.extend(learner.done_states)
optimal_graph.visualize(
    highlight_states=highlighted_states,
    use_grid_layout=False,
    display_state_name=False,
)

In [None]:
# Discount factors and convergence threshold for the iterations below.
# NOTE(review): the optimal value/policy passes use 0.999 while the
# control-info and value_iteration passes use 1.0 — confirm this is intended.
GAMMA_OPTIMAL = 0.999
GAMMA_UNDISCOUNTED = 1.0
CONVERGENCE_THRESHOLD = 1e-5

optimal_graph.optimal_value_iteration(GAMMA_OPTIMAL, threshold=CONVERGENCE_THRESHOLD)
optimal_graph.compute_optimal_policy(GAMMA_OPTIMAL, threshold=CONVERGENCE_THRESHOLD)
optimal_graph.control_info_iteration(GAMMA_UNDISCOUNTED, threshold=CONVERGENCE_THRESHOLD)
optimal_graph.value_iteration(GAMMA_UNDISCOUNTED, threshold=CONVERGENCE_THRESHOLD)

In [None]:
# Plot the computed policy together with the "value" quantity, highlighting
# the start state and all done states.
highlighted_states = [learner.start_state, *learner.done_states]
optimal_graph.visualize_policy_and_values(
    title="Policy and Values",
    value_type="value",
    highlight_states=highlighted_states,
    use_grid_layout=False,
    display_state_name=False,
)

In [None]:
# Same plot as above, but showing the "control_info" quantity instead of values.
highlighted_states = [learner.start_state, *learner.done_states]
optimal_graph.visualize_policy_and_values(
    title="Policy and Control Info",
    value_type="control_info",
    highlight_states=highlighted_states,
    use_grid_layout=False,
    display_state_name=False,
)

In [None]:
# from stable_baselines3 import PPO
# 
# load_path = r'experiments/mazes-bin-16/run0/saved_models/saved_model_latest.zip'
# model = PPO.load(load_path, env=env)
# feature_model = model.policy.features_extractor
# feature_model.binary_output = True
# feature_model.to("cpu")

In [None]:
# import torch
# from binary_state_representation.binary2binaryautoencoder import Binary2BinaryFeatureNet
# from minigrid_abstract_encoding import EncodingMDPLearner
# 
# 
# device = "cuda" if torch.cuda.is_available() else "cpu"
# 
# NUM_ACTIONS = int(env.action_space.n)
# OBS_SPACE = int(env.total_features)
# LATENT_DIMS = 24
# print(NUM_ACTIONS, OBS_SPACE, LATENT_DIMS)
# 
# # train hyperparams
# WEIGHTS = {'inv': 1.0, 'dis': 1.0, 'neighbour': 0.0, 'dec': 0.0, 'rwd': 0.1, 'terminate': 1.0}
# BATCH_SIZE = 32
# LR = 1e-4
# 
# model = Binary2BinaryFeatureNet(NUM_ACTIONS, OBS_SPACE, n_latent_dims=LATENT_DIMS, lr=LR, weights=WEIGHTS, device=device, )
# 
# model.load(r'experiments/learn_feature_corridor_24/model_epoch_200.pth')
# model.use_bin = True
# 
# encoder = model.encoder.to(device)

In [None]:
# from minigrid_abstract_encoding import EncodingMDPLearner
# 
# learner = EncodingMDPLearner(env, feature_model, torch.device("cpu"), keep_dims=feature_model.features_dim)
# learner.learn()
# optimal_graph = OptimalPolicyGraph()
# optimal_graph.load_graph(learner.mdp_graph)
# optimal_graph.uniform_prior_policy()
# optimal_graph.visualize(highlight_states=[learner.encoded_start_state, *learner.encoded_done_states], use_grid_layout=False, display_state_name=True)
# 
# max_cols = 3
# for state in learner.encoded_state_set:
#     images = learner.encoded_state_to_unencoded_state_dict[state]
#     num_images = len(images)
#     num_rows = (num_images + max_cols - 1) // max_cols
#     num_cols = min(num_images, max_cols)
#     plt.figure(figsize=(num_cols * 5, num_rows * 5))
#     for i, unencoded in enumerate(images):
#         img = learner.unencoded_state_image_dict[unencoded]
#         plt.subplot(num_rows, num_cols, i + 1)
#         plt.imshow(img)
#         plt.axis('off')
#     plt.suptitle(state)
#     plt.tight_layout()
#     plt.show()

In [None]:
# from minigrid_abstract_encoding import EncodingMDPLearner
# from stable_baselines3 import PPO
# env = CustomEnv(
#         txt_file_path=f'./maps/7-1.txt',
#         rand_gen_shape=None,
#         display_size=7,
#         display_mode="middle",
#         agent_start_pos=(5, 5),
#         agent_start_dir=0,
#         random_rotate=False,
#         random_flip=False,
#         custom_mission="Find the key and open the door.",
#         render_mode=None,
#     )
# env = FullyObsImageWrapper(env, to_print=False)
# load_path = r'experiments/mazes-bin-16/run0/saved_models/saved_model_latest.zip'
# model = PPO.load(load_path, env=env)
# feature_model = model.policy.features_extractor
# feature_model.binary_output = True
# feature_model.to("cpu")
# encoding_learner = EncodingMDPLearner(env, feature_model, torch.device("cpu"), keep_dims=feature_model.features_dim)
# encoding_learner.learn()
# optimal_graph = OptimalPolicyGraph()
# optimal_graph.load_graph(encoding_learner.mdp_graph)
# optimal_graph.uniform_prior_policy()
# optimal_graph.visualize(highlight_states=[encoding_learner.encoded_start_state, *encoding_learner.encoded_done_states], use_grid_layout=False, display_state_name=True)

In [None]:
from minigrid_abstract_encoding import EncodingMDPLearner
from stable_baselines3 import PPO
# Imported here too so this cell does not silently depend on an earlier cell having run.
from customize_minigrid.wrappers import FullyObsImageWrapper

# Max number of raw-state images per row when showing what one encoded state groups together.
MAX_IMAGE_COLS = 5


def _show_mdp_graph(mdp_graph, highlight_states, display_state_name):
    """Draw ``mdp_graph`` as an OptimalPolicyGraph initialised with a uniform prior policy.

    Args:
        mdp_graph: the ``mdp_graph`` attribute of a learner after ``learn()``.
        highlight_states: states passed to ``visualize`` for highlighting.
        display_state_name: forwarded unchanged to ``visualize``.

    Returns:
        The populated ``OptimalPolicyGraph`` so it can still be inspected after plotting.
    """
    graph = OptimalPolicyGraph()
    graph.load_graph(mdp_graph)
    graph.uniform_prior_policy()
    graph.visualize(
        highlight_states=highlight_states,
        use_grid_layout=False,
        display_state_name=display_state_name,
    )
    return graph


def _show_state_images(state, imgs, max_cols=MAX_IMAGE_COLS):
    """Plot every image in ``imgs`` in a grid whose title names the encoded ``state``.

    Args:
        state: encoded state key; only used in the figure title.
        imgs: list of images (anything ``Axes.imshow`` accepts).
        max_cols: maximum number of images per row.
    """
    n_images = len(imgs)
    if n_images == 0:
        return  # nothing to draw; plt.subplots(0, 0) would raise
    n_cols = min(n_images, max_cols)
    n_rows = (n_images + max_cols - 1) // max_cols  # ceil division
    # squeeze=False always returns a 2-D array of Axes (even for 1x1 or 1xN),
    # replacing the manual np.array([axes]) wrapping.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    fig.suptitle(f"State: {state}")
    for i, ax in enumerate(axes.flat):
        if i < n_images:
            ax.imshow(imgs[i])
        ax.axis('off')  # also blanks the unused slots in the last row
    plt.show()
    plt.close(fig)  # release the figure; one figure is created per encoded state


# ---- Build one wrapped environment per 7-* maze map and show its initial frame ----
file_paths = [f'./maps/7-{i}.txt' for i in range(1, 7)]
env_list = []
for file_path in file_paths:
    env = CustomEnv(
        txt_file_path=file_path,
        rand_gen_shape=None,
        display_size=7,
        display_mode="middle",
        agent_start_pos=(5, 5),
        agent_start_dir=0,
        random_rotate=False,
        random_flip=False,
        custom_mission="Find the key and open the door.",
        render_mode=None,
    )
    env.reset()
    img = env.get_frame(False, env.tile_size, env.agent_pov)
    plt.imshow(img)
    plt.title(file_path)  # identify which map each frame belongs to
    plt.show()
    env = FullyObsImageWrapper(env, to_print=False)
    env_list.append(env)

# ---- Load the trained PPO model; its features extractor is used as the state encoder ----
load_path = r'experiments/mazes-bin-8/run0/saved_models/saved_model_latest.zip'
# Pass an env explicitly instead of relying on `env` leaking out of the loop above.
# NOTE(review): assumes all maze envs share observation/action spaces — confirm.
model = PPO.load(load_path, env=env_list[-1])
feature_model = model.policy.features_extractor
feature_model.binary_output = True
feature_model.to("cpu")

# ---- Per map: draw the one-hot MDP graph and the encoded MDP graph, and collect
# ---- the raw-state images behind each encoded state (accumulated over all maps).
encoded_state_set = set()  # union of encoded states over all maps
image_dict = {}  # encoded state -> list of raw-state images mapped to it
for env in env_list:
    one_hot_learner = OneHotEncodingMDPLearner(env)
    one_hot_learner.learn()
    optimal_graph = _show_mdp_graph(
        one_hot_learner.mdp_graph,
        [one_hot_learner.start_state, *one_hot_learner.done_states],
        display_state_name=False,
    )

    encoding_learner = EncodingMDPLearner(env, feature_model, torch.device("cpu"), keep_dims=feature_model.features_dim)
    encoding_learner.learn()
    optimal_graph = _show_mdp_graph(
        encoding_learner.mdp_graph,
        [encoding_learner.encoded_start_state, *encoding_learner.encoded_done_states],
        display_state_name=True,
    )

    encoded_state_set |= set(encoding_learner.encoded_state_set)
    for state in encoding_learner.encoded_state_set:
        state_imgs = [
            encoding_learner.unencoded_state_image_dict[image_code]
            for image_code in encoding_learner.encoded_state_to_unencoded_state_dict[state]
        ]
        if state_imgs:  # only register states that actually have images
            image_dict.setdefault(state, []).extend(state_imgs)

# ---- For each encoded state, show every raw state that was mapped to it ----
for state, imgs in image_dict.items():
    _show_state_images(state, imgs)
