In [1]:
import torch
import gym
import numpy as np

from gym_atom_array.env import ArrayEnv, Config

## Specify the experiment (wandb run) name and model version

In [2]:
# Identify the wandb run directory and which saved checkpoint of that run to load.
wandb_name, model_version = "run-20230213_191538-i6xmt1q1", "final"

# Checkpoint location relative to the notebook's working directory.
model_path = f"wandb/{wandb_name}/files/agent-{model_version}.pt"

## Configure environment and load model

In [3]:
from argparse import Namespace
from clean_agents.ppo import make_env
from clean_agents.networks import MaskedAgent as Agent

# Environment settings handed to `make_env`; bundled in a Namespace so they can be
# read as attributes (e.g. `args.TimeLimit`, `args.ArraySize`) elsewhere in the notebook.
args = Namespace(
    Render=False,
    TargetSize=4,
    ArraySize=6,
    DefaultPenalty=-0.1,
    TargetPickUp=-5,
    TargetRelease=10,
    TimeLimit=200,
)

# A vectorised wrapper around a single environment instance.
envs = gym.vector.SyncVectorEnv([make_env(1, args)])

# Everything in this notebook runs on the CPU.
device = torch.device("cpu")

In [4]:
# Build the agent for the vectorised env, then restore the saved weights.
agent = Agent(envs, device)

# `map_location` remaps every stored tensor onto `device`, so a checkpoint written
# from a CUDA run still loads on this CPU-only setup.
# NOTE: `torch.load` unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)
agent.load_state_dict(state_dict)

# Switch to inference mode; as the last expression this also displays the module layout.
agent.eval()

MaskedAgent(
  (extractor): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): Flatten(start_dim=1, end_dim=-1)
    (5): Linear(in_features=512, out_features=64, bias=True)
    (6): ReLU()
  )
  (critic): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (3): ReLU()
      (4): Flatten(start_dim=1, end_dim=-1)
      (5): Linear(in_features=512, out_features=64, bias=True)
      (6): ReLU()
    )
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Tanh()
    (3): Linear(in_features=64, out_features=32, bias=True)
    (4): Tanh()
    (5): Linear(in_features=32, out_features=1, bias=True)
  )
  (actor): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padd

# Comparison

In [13]:
def eval_agent(rl_agent, next_obs, vec_envs=None, time_limit=None):
    """Roll out one episode with ``rl_agent`` and count its pick-and-place moves.

    Args:
        rl_agent: Object exposing ``get_action_and_value(obs)`` whose first return
            value is a tensor of actions (one per environment).
        next_obs: Initial batched observation; converted with ``torch.Tensor``.
        vec_envs: Vectorised environment to step. Defaults to the notebook-level
            ``envs``. Only the first environment's action and done flag are tracked.
        time_limit: Maximum number of steps. Defaults to ``args.TimeLimit``.

    Returns:
        ``(num_moves, moves_length)``: the number of completed pick-up/release
        pairs, and the number of steps taken while holding (the pick-up and
        release steps themselves are not counted).
    """
    # Resolve the notebook-level defaults lazily so that explicit arguments work
    # without any global state being defined.
    if vec_envs is None:
        vec_envs = envs
    if time_limit is None:
        time_limit = args.TimeLimit

    # Action ids that toggle the holding state (4 starts holding, 5 ends it).
    pick_up, release = 4, 5

    next_obs = torch.Tensor(next_obs)
    num_moves, moves_length = 0, 0
    holding = False
    done = False
    steps = 0

    with torch.no_grad():
        while not done and steps < time_limit:
            action, _, _, _ = rl_agent.get_action_and_value(next_obs)
            action = action.cpu().numpy()

            next_obs, _, dones, _ = vec_envs.step(action)
            next_obs = torch.Tensor(next_obs)
            done = bool(dones[0])
            steps += 1

            act = action[0]
            if holding and act == release:
                # A release while holding completes one move.
                holding = False
                num_moves += 1
            elif holding:
                # Any other step taken while holding adds to the travelled length.
                moves_length += 1
            elif act == pick_up:
                holding = True

    if not done:
        print("Timed out")
    return num_moves, moves_length

In [14]:
from classic_algos.LSAP import LSAPPlanner
from classic_algos.ASA import ASAPlanner

def get_total_len(moves):
    """Return the summed Manhattan (L1) length of all moves.

    Args:
        moves: Iterable of ``(start, end)`` pairs, each endpoint being an
            indexable 2-element coordinate.

    Returns:
        The sum over all moves of ``|start[0]-end[0]| + |start[1]-end[1]|``
        (``0`` for an empty iterable).
    """
    return sum(
        abs(src[0] - dst[0]) + abs(src[1] - dst[1])
        for src, dst in moves
    )

In [16]:
# Start a fresh episode and show its initial layout.
next_obs = envs.reset()
envs.envs[0].render()

# Grid side length, read from the observation space; the first two channels of the
# (single) environment's observation are the atom grid and the target grid.
n_size = envs.single_observation_space.shape[1]
atom_grid, tar_grid, _ = next_obs[0]

# (row, col) coordinates of every occupied cell, scanned in row-major order.
atoms = [
    (r, c)
    for r in range(n_size)
    for c in range(n_size)
    if atom_grid[r][c] == 1
]
targets = [
    (r, c)
    for r in range(n_size)
    for c in range(n_size)
    if tar_grid[r][c] == 1
]

# Classical planner 1: LSAP with the "cityblock" metric.
lsap_agent = LSAPPlanner(n_size, n_size, targets, "cityblock")
moves, num_moves = lsap_agent.get_moves(atoms)
moves_length = get_total_len(moves)
print(f"LSAP: {num_moves}, {moves_length}")

# Classical planner 2: ASA.
asa_agent = ASAPlanner(n_size, n_size, targets)
moves, num_moves = asa_agent.get_moves(atoms)
moves_length = get_total_len(moves)
print(f"ASA: {num_moves}, {moves_length}")

# Learned policy, rolled out from the same initial observation.
num_moves, moves_length = eval_agent(agent, next_obs)
print(f"RL: {num_moves}, {moves_length}")

------------------------------
  1    0    1    1    1    0  

  1   [0]  [1]  [1]  [1]   1  

  1   [0]  [0]  [1]  [1]   1  

  1   [0]  [1]  [0]  [0]   0  

  0   [0]  [0]  [0]  [1]   1  

  0    0    1    1    1    0  
  ↑                                
------------------------------
LSAP: 11, 16
ASA: 12, 16
RL: 9, 138


In [42]:
# Inspect the planner's move list: (start, end) coordinate pairs.
# NOTE(review): in file order `moves` was last assigned from the ASA planner, but this
# cell's execution count is out of sequence -- re-run top to bottom to confirm.
moves

[((3, 0), (3, 1)),
 ((3, 5), (3, 4)),
 ((5, 3), (4, 3)),
 ((5, 4), (4, 4)),
 ((1, 2), (2, 2)),
 ((0, 2), (1, 2)),
 ((0, 3), (1, 4))]

# Start visualization

In [16]:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [17]:
# Select an interactive backend so the animation below can redraw in place.
# NOTE(review): the `notebook` backend targets the classic Notebook front end; if this is
# run in JupyterLab / Notebook 7, `%matplotlib widget` (ipympl) is the usual choice -- confirm.
%matplotlib notebook

In [18]:
def obs_to_plots(obs, atoms_plt, mt_plt):
    """Push one observation into the atom and tweezer plot artists.

    Args:
        obs: Batched observation. Only ``obs[0]`` is used; it is unpacked into
            three 2-D grids (atoms, targets, tweezer).
        atoms_plt: Artist exposing ``set_data``; receives the indices of every
            atom-grid cell equal to 1.
        mt_plt: Artist exposing ``set_data`` and ``set_marker``; placed on the
            non-zero tweezer-grid cell. Its marker is ``'x'`` when that cell's
            value is 1 and ``'o'`` otherwise.

    Returns:
        The same ``(atoms_plt, mt_plt)`` pair, after updating them in place.
    """
    atom_grid, _tar_grid, mt_grid = obs[0]
    # Take the scan bounds from the grid itself rather than from a notebook global.
    n_rows, n_cols = atom_grid.shape

    rows, cols = [], []
    # Three-element default so the marker lookup below stays valid even when the
    # tweezer grid has no non-zero cell.
    mt_pos = (0, 0, 0)
    for i in range(n_rows):
        for j in range(n_cols):
            if mt_grid[i, j] != 0:
                mt_pos = (i, j, mt_grid[i, j])
            if atom_grid[i, j] == 1:
                rows.append(i)
                cols.append(j)

    atoms_plt.set_data(rows, cols)
    # Pass sequences, not bare scalars, to set_data.
    mt_plt.set_data([mt_pos[0]], [mt_pos[1]])
    mt_plt.set_marker('x' if mt_pos[2] == 1 else 'o')

    return atoms_plt, mt_plt

In [None]:
fig, ax = plt.subplots()

n = args.ArraySize
ax.set_xlim(-1, n)
ax.set_ylim(-1, n)
ax.set_title("boobler")

obs = envs.reset()
atom_grid, tar_grid, mt_grid = obs[0]

# Indices of every target cell; drawn once as a static background layer.
dots = [[], []]
for i in range(n):
    for j in range(n):
        if tar_grid[i, j] == 1:
            dots[0].append(i)
            dots[1].append(j)

# NOTE: `atoms` and `targets` are rebound here to plot artists; the comparison cell
# earlier in the notebook uses the same names for coordinate lists.
targets, = ax.plot(dots[0], dots[1], 'gs', markersize=12, markerfacecolor=(1, 1, 0, 0.5))
atoms, = ax.plot([], [], 'bo')
mt, = ax.plot(0, 0, 'ro', markersize=10)

# Show the reset state before the first policy step.
obs_to_plots(obs, atoms, mt)

# Rollout state shared between successive `animate` calls.
next_obs = torch.Tensor(obs)
done = False


def init_animation():
    """Return the animated artists without stepping the environment.

    Without an explicit ``init_func`` FuncAnimation draws its initial frame by
    calling ``animate`` itself, which would consume one extra environment step.
    """
    return atoms, mt


def animate(frame_num):
    """Advance the rollout by one policy step and refresh the artists.

    ``frame_num`` is unused. Once the episode has reported done, the artists are
    returned unchanged so the figure freezes on the last drawn state.
    """
    # These names are reassigned below, so they must be declared global; otherwise
    # Python treats them as locals and the first read raises UnboundLocalError.
    global next_obs, done

    if done:
        # With blit=True the callback must always return an iterable of artists.
        return atoms, mt

    with torch.no_grad():
        action, _, _, _ = agent.get_action_and_value(next_obs)

    step_obs, _, dones, _ = envs.step(action.cpu().numpy())
    # NOTE(review): if the vector env auto-resets on done, `step_obs` may already be
    # the next episode's first observation -- confirm against the gym version in use.
    obs_to_plots(step_obs, atoms, mt)

    next_obs = torch.Tensor(step_obs)
    done = bool(dones[0])
    return atoms, mt


# One frame per environment step, bounded by the configured time limit; repeat=False
# keeps the stateful rollout from being stepped again after the last frame.
anim = FuncAnimation(
    fig,
    animate,
    init_func=init_animation,
    frames=args.TimeLimit,
    interval=1000,
    repeat=False,
    blit=True,
)

plt.show()

## Testing the masking logic

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [21]:
# Hand-built 3-channel 5x5 observation. The channel order follows the
# `atom_grid, tar_grid, mt_grid = obs[0]` unpacking used elsewhere in this notebook.
obs1 = torch.tensor([
    [   # channel 0
        [0, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 0, 1, 0, 1],
        [0, 1, 1, 0, 0],
        [1, 0, 1, 1, 0],
    ],
    [   # channel 1
        [0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ],
    [   # channel 2: single non-zero cell at (row 1, col 4) with value 1
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
    ],
])

# Same layout, except that the channel-2 cell holds 2 instead of 1.
obs2 = obs1.clone()
obs2[2, 1, 4] = 2

# Stack the tensors directly; building a tensor from a tuple of NumPy arrays takes a
# slow path and emits a UserWarning.
batch = torch.stack((obs1, obs2))

  batch = torch.tensor((obs1.numpy(), obs2.numpy()))


In [22]:
# One 3x3 filter over three input channels (shape 1 x 3 x 3 x 3). Each of the four
# neighbours on channel 0 gets its own power of ten, and the centre cell of
# channel 2 gets 10000, so one convolution output packs all of them into
# separate decimal digits.
kernel = torch.tensor([[
    [   # channel 0: row above -> 1000, column left -> 10, column right -> 1, row below -> 100
        [0, 1000, 0],
        [10, 0, 1],
        [0, 100, 0],
    ],
    [   # channel 1: ignored
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
    ],
    [   # channel 2: centre cell only
        [0, 0, 0],
        [0, 10000, 0],
        [0, 0, 0],
    ],
]])


In [39]:
# Largest value anywhere in each sample (all channels, all cells).
has_atoms = batch.amax(dim=(1, 2, 3))

# Surround every channel with a one-cell border of 2s, then convolve with `kernel`.
# The per-sample maximum of the response is kept as a single packed integer code.
padded = F.pad(batch, (1, 1, 1, 1), value=2)
detect = F.conv2d(padded, kernel)
inter_masks = detect.amax(dim=(1, 2, 3))

# Double the code for samples whose maximum is 2, leave the others unchanged.
inter_masks = inter_masks * (1 + (has_atoms == 2).long())

# Column k (k = 0..3) tests one decimal digit of the code: the thousands, hundreds,
# tens and units digit in that order must each be below 2. The last two columns
# flag whether the sample maximum is 1 or 2.
masks_ = torch.column_stack((
    *(inter_masks % (10 * unit) < 2 * unit for unit in (1000, 100, 10, 1)),
    has_atoms == 1,
    has_atoms == 2,
))

In [40]:
# One row per batch sample, six boolean columns in the order they were stacked.
masks_

tensor([[ True,  True,  True, False,  True, False],
        [False, False, False, False, False,  True]])

In [38]:
# Transposed view of the mask table.
# NOTE(review): the saved output below is identical to the untransposed `masks_`
# (and the execution count is out of order), so it looks stale -- re-run to refresh.
masks_.T

tensor([[ True,  True,  True, False,  True, False],
        [False, False, False, False, False,  True]])

In [129]:
# Per-sample maximum over all channels and cells of `batch`.
has_atoms

tensor([1, 2])

In [116]:
# Display the first hand-built observation. The bare name `obs` only resolved to this
# tensor through leftover kernel state; in file order `obs` is the array returned by
# `envs.reset()`, whereas the saved output below is exactly `obs1`.
obs1

tensor([[[0, 1, 1, 1, 1],
         [0, 1, 1, 1, 1],
         [1, 0, 1, 0, 1],
         [0, 1, 1, 0, 0],
         [1, 0, 1, 1, 0]],

        [[0, 0, 0, 0, 0],
         [0, 1, 1, 1, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]])

In [60]:
# Maximum along the last axis (dim 2) of the 3 x 5 x 5 observation, giving one value
# per (channel, row). Uses `obs1` explicitly: in file order the bare name `obs` is
# the NumPy array from `envs.reset()`, which `torch.amax` does not accept.
torch.amax(obs1, 2)

tensor([[1, 1, 1, 1, 1],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0]])

In [77]:
# Two-sample batch: the hand-built observation and the same observation with every
# value doubled, so the per-sample maxima are 1 and 2 respectively.
# Uses `obs1` (in file order the bare name `obs` is a NumPy array with no `.numpy()`)
# and stacks the tensors directly instead of round-tripping through NumPy.
# NOTE: this rebinds `batch`, replacing the batch built from obs1/obs2.
batch = torch.stack((obs1, obs1 * 2))
has_atom = torch.amax(batch, (1, 2, 3))
has_atom == 1, has_atom == 2

(tensor([ True, False]), tensor([False,  True]))

In [111]:
# Sanity check of the batch layout: (samples, channels, rows, cols).
batch.shape

torch.Size([2, 3, 5, 5])