In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import math
import numpy as np
import torch as T
import matplotlib.pyplot as plt

In [3]:
# Pickled rollouts for MiniGrid-SimpleCrossingS11N5-v0; the filename suggests they come from the
# CrossingS11N5_A2C_Fullgrid agent (the "-10" suffix is presumably the episode count — confirm).
dataset_path = "datasets/MiniGrid-SimpleCrossingS11N5-v0-CrossingS11N5_A2C_Fullgrid-10.pickle"

In [4]:
# Load the dataset. As used below, data[episode][step] is a dict with (at least) 'obs',
# 'action', 'direction' and 'location' entries.
# NOTE(review): pickle.load can execute arbitrary code — only load dataset files you generated or trust.
with open(dataset_path, 'rb') as handle:
    data = pickle.load(handle)

In [5]:
# Observation modalities stored for a single step.
data[9][3]['obs'].keys()

dict_keys(['image', 'direction', 'mission', 'partial_image', 'pixels', 'partial_pixels'])

In [6]:
# Each episode is a dict keyed by step index (0..N-1).
data[0].keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34])

In [7]:
# Action taken at every step of episode 0.
print([data[0][i]['action'] for i in range(len(data[0]))])

[2, 2, 2, 2, 2, 4, 5, 6, 2, 5, 5, 1, 2, 5, 4, 2, 3, 0, 2, 1, 2, 5, 2, 5, 2, 2, 0, 4, 2, 2, 1, 2, 3, 3, 2]


In [8]:
# Direction recorded at steps 11 and 12 (action 1 was taken at step 11, per the list above).
data[0][11]['direction']

0

In [9]:
data[0][12]['direction']

1

In [10]:
# Does the partial view differ between those two steps? (It does.)
np.allclose(data[0][11]['obs']['partial_image'], data[0][12]['obs']['partial_image'])

False

In [11]:
# Channel 0 of the full-grid image. The values look like MiniGrid object-type ids
# (presumably 1 = empty, 2 = wall, 8 = goal, 10 = agent) — confirm against minigrid's OBJECT_TO_IDX.
data[0][11]['obs']['image'][:,:,0]

array([[ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       [ 2,  1,  2,  1,  2,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  1,  2,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  1,  2,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       [ 2,  1,  2,  1,  2,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2],
       [ 2, 10,  1,  1,  2,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  1,  1,  1,  1,  1,  2,  1,  2],
       [ 2,  1,  2,  1,  2,  1,  1,  1,  1,  8,  2],
       [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2]], dtype=uint8)

In [None]:
 data[0][12]['obs']['image'][:,:,0]

In [None]:
 data[0][12]['location']

In [None]:
T.Tensor([data[episode][step]['direction'] for episode in range(len(data)) for step in range(len(data[episode]))]).shape

In [None]:
# Flatten every (episode, step) of the dataset into one batch of 'partial_pixels' observations,
# labelled with the index of the episode each observation came from.
support_trajectories = T.Tensor(np.array([data[episode][step]['obs']['partial_pixels'] for episode in range(len(data)) for step in range(len(data[episode]))]))
support_targets = T.tensor(np.array([episode for episode in range(len(data)) for step in range(len(data[episode]))]))

In [None]:
# Shape of the observation stack contributed by each episode (first dim = number of steps).
for target in support_targets.unique():
    trajectory=support_trajectories[support_targets==target]
    print(trajectory.shape)

In [None]:
from generative_contrastive_modelling.gcm_encoder import GCMEncoder
from generative_contrastive_modelling.gcm import GenerativeContrastiveModelling

# Build the GCM model and its encoder from the support batch shape. The two positional 64s
# are presumably hidden/embedding sizes — confirm against the constructors.
GCM = GenerativeContrastiveModelling(support_trajectories.shape, 64, 64)
gcm_encoder = GCMEncoder(support_trajectories.shape, 64, 64)

In [None]:
# Smoke-test the encoder on the full support batch.
# NOTE(review): if GCMEncoder is an nn.Module, prefer gcm_encoder(x) over .forward(x) so hooks run.
gcm_encoder.forward(support_trajectories)

In [None]:
from generative_contrastive_modelling.proto_encoder import ProtoEncoder
from generative_contrastive_modelling.protonet import PrototypicalNetwork

# Prototypical-network baseline, built with the same arguments as the GCM model above.
proto_encoder = ProtoEncoder(support_trajectories.shape, 64, 64)
protonet = PrototypicalNetwork(support_trajectories.shape, 64, 64)

In [None]:
# Smoke-test the prototype encoder (same NOTE about .forward as above applies).
proto_encoder.forward(support_trajectories)

In [None]:
indices = T.randperm(support_trajectories.shape[0])[:5]
query_observations = support_trajectories[indices]
query_targets = support_targets[indices]

In [None]:
query_targets

In [None]:
observations=T.cat([support_trajectories, query_observations], dim=0)

In [None]:
GCM.compute_loss(support_trajectories=support_trajectories, support_targets=support_targets, query_observations=query_observations, query_targets=query_targets)

In [None]:
protonet.compute_loss(support_trajectories=support_trajectories, support_targets=support_targets, query_observations=query_observations, query_targets=query_targets)

In [None]:
support_targets.squeeze_(0).shape

In [None]:
support_targets.shape

In [None]:
data[0][11]['obs']['partial_pixels'].shape

In [None]:
plt.imshow(data[0][11]['obs']['partial_pixels'].transpose(1,2,0))

In [None]:
for i in range(len(data[0])):
    plt.imshow(data[0][i]['obs']['partial_pixels'].transpose(1,2,0))
    plt.show()

In [None]:
import torch.nn.functional as F
plt.imshow(F.interpolate(T.Tensor(data[0][11]['obs']['pixels'].transpose(2,1,0)), size=(11,11))/255.0)

In [None]:
data[0][11]['obs']['partial_pixels'].transpose(2,1,0).shape

In [None]:
from generate_trajectories import generate_data

In [None]:
import minigrid_rl_starter.utils as utils
# Load environment

# Two environments built with identical arguments (the trailing 0 is presumably the seed):
# `env` for the trained agent's rollouts, `env_copy` for the exploratory agent's.
env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
# NOTE(review): range(0) makes these reset loops no-ops. Presumably a knob for advancing
# the environment by N resets; set a positive count to use it.
for _ in range(0):
    env.reset()
env_copy = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
for _ in range(0):
    env_copy.reset()
print("Environment loaded\n")

# Load agent

# Model directories (resolved under minigrid_rl_starter) for the task-trained agent and the
# "state_bonus" (presumably exploration-bonus) agent.
trained_model_dir = utils.get_model_dir("CrossingS11N5_A2C_Fullgrid", storage_dir="minigrid_rl_starter")
exploratory_model_dir = utils.get_model_dir("CrossingS11N5_A2C_state_bonus_fullgrid", storage_dir="minigrid_rl_starter")

# Both agents share the same settings: argmax=True (presumably greedy action selection), no memory, no text.
trained_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    trained_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Trained agent loaded\n")

exploratory_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    exploratory_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Exploratory agent loaded\n")

In [None]:
# Roll out 5 episodes with each agent: the trained agent's rollouts form the training data,
# the exploratory agent's (in its own environment copy) provide the queries.
train_dataset = generate_data(env=env, agent=trained_agent, episodes=5, render=False)

In [None]:
query_dataset = generate_data(env=env_copy, agent=exploratory_agent, episodes=5, render=False)

In [None]:
# Flatten the query rollouts over (episode, step) into 'partial_pixels' observations plus
# episode-index labels, then keep 5 random steps as the query observations.
query_trajectories = T.Tensor(
            np.array(
                [
                    query_dataset[episode][step]["obs"]["partial_pixels"]
                    for episode in range(len(query_dataset))
                    for step in range(len(query_dataset[episode]))
                ]
            )
        )
query_targets = T.tensor(
            np.array(
                [
                    episode
                    for episode in range(len(query_dataset))
                    for step in range(len(query_dataset[episode]))
                ]
            )
        )
# `indices` is reused by the full-grid render cell below to pick the same 5 steps.
indices = T.randperm(query_trajectories.shape[0])[: 5]
query_observations = query_trajectories[indices]
query_targets = query_targets[indices]

In [None]:
# Show each sampled query: channels moved last and values divided by 255 for imshow.
for i in range(query_observations.shape[0]):
    plt.imshow(query_observations[i].permute(1,2,0)/255.0)
    plt.show()

In [None]:
query_targets

In [None]:
T.randperm(4)

In [None]:
# Full-grid 'pixels' renders for the same steps chosen by `indices` in the query-sampling cell above
# (hidden dependency: that cell must run first).
query_environments = T.Tensor(
            np.array(
                [
                    query_dataset[episode][step]["obs"]["pixels"]
                    for episode in range(len(query_dataset))
                    for step in range(len(query_dataset[episode]))
                ]
            )
        )
query_environments = query_environments[indices]

In [None]:
for i in range(query_environments.shape[0]):
    plt.imshow(query_environments[i].permute(1,2,0)/255.0)
    plt.show()

In [None]:
# Per-step agent directions for the training and query rollouts, as int64 tensors.
train_directions=T.Tensor(np.array(
                [
                    train_dataset[episode][step]["direction"]
                    for episode in range(len(train_dataset))
                    for step in range(len(train_dataset[episode]))
                ]
 )).to(T.int64)

query_directions=T.Tensor(np.array(
                [
                    query_dataset[episode][step]["direction"]
                    for episode in range(len(query_dataset))
                    for step in range(len(query_dataset[episode]))
                ]
 )).to(T.int64)

In [None]:
# Resample the renders to 56x56 (F.interpolate's default mode) and take every 5th pixel
# starting at 2. That gives 11x11 samples, presumably one per grid cell.
# NOTE(review): 56 is not a multiple of the 11 cells, so the sample points drift within cells;
# 55 (= 11 * 5) would give a regular lattice — confirm intent.
query_environments_resampled=F.interpolate(query_environments, size=(56, 56))
query_environments_resampled=query_environments_resampled[:,:,2::5,2::5]
print(query_environments_resampled.shape)
plt.imshow(query_environments_resampled[0].permute(1,2,0)/255.0)

In [None]:
import torchvision.transforms as transforms
# Apply one random resized crop to the batch of sampled renders, shape (N, C, H, W).
# A single call draws one set of crop parameters for the whole batch.
# The previous unsqueeze(0) produced a leading dim of size 1, so indexing [1] raised IndexError,
# and the crop was then applied a second time.
resize_cropper = transforms.RandomResizedCrop(size=(352,352))
cropped_environments = resize_cropper(query_environments)
plt.imshow(cropped_environments[1].permute(1,2,0)/255.0)
print(query_environments.shape)

In [None]:
directions = T.Tensor(
    np.array(
        [
            query_dataset[episode][step]["direction"]
            for episode in range(len(query_dataset))
            for step in range(len(query_dataset[episode]))
        ]
    )
).to(T.int64)

In [None]:
indices = T.randperm(directions.shape[0])[:5]
directions[directions!=directions[indices]]

In [None]:
from process_trajectories import data_to_tensors, sample_views

In [None]:
# Convert the nested episode/step dataset into stacked tensors (a dict with an 'observations' entry).
train_trajectory=data_to_tensors(train_dataset)

In [None]:
# NOTE(review): sample_views' result is unpacked here as (indices, query_views), but a later cell
# unpacks it as (query_views, support_trajectories). At most one matches the current
# signature — confirm in process_trajectories.
indices, query_views= sample_views(train_trajectory,5)

In [None]:
# Count (query view, training observation) pairs that are numerically equal (T.allclose),
# then print the total number of training observations for comparison.
count=0
for view in query_views['observations']:
    for i in range(len(train_trajectory['observations'])):
        if T.allclose(train_trajectory['observations'][i],view):
            count+=1
print(count)
print(len(train_trajectory['observations']))

In [None]:
for i in range(500,569):
    plt.imshow(train_trajectory['observations'][i].transpose(2,0)/255.0)
    plt.show()

In [None]:
plt.imshow(query_views['observations'][0].transpose(2,0)/255.0)

In [None]:
T.allclose(query_views['observations'][0], train_trajectory['observations'][366])

In [None]:
# Support observations: every step of every training episode except steps where step % 5 == 0.
support_obs = T.Tensor(np.array(
                [
                    train_dataset[episode][step]["obs"]["partial_pixels"] 
                    for episode in range(len(train_dataset)) 
                    for step in range(len(train_dataset[episode])) if step%5!=0
                ]
            )
        )

In [None]:
support_obs.shape

In [None]:
# NOTE(review): support_obs already excludes step % 5 == 0, so these index splits (on flattened
# position, not step) re-split the filtered set rather than selecting the dropped steps — confirm intent.
query_idx = np.array([i for i in range(len(support_obs)) if i%5==0])

In [None]:
support_idx = np.array([i for i in range(len(support_obs)) if i%5!=0])

In [None]:
len(support_idx)

In [None]:
len(query_idx)

In [None]:
# Sample 5 random positions of support_obs.
indices = T.randperm(support_obs.shape[0])[:5]
indices

In [None]:
# All positions of support_obs that were not sampled above.
remaining_indices = T.tensor([i for i in range(support_obs.shape[0]) if i not in indices])

In [None]:
len(remaining_indices)

In [None]:
# NOTE(review): this rebinds `support_trajectories`, earlier a tensor, to sample_views' second
# result (indexed as a dict below); cells above that use the tensor break if re-run after this.
query_views, support_trajectories = sample_views(train_trajectory,100)

In [None]:
query_views['observations'].shape

In [None]:
# Print 'ok' for every query view that also appears (numerically equal) among the support observations.
for i in range(query_views['observations'].shape[0]):
    for j in range(support_trajectories['observations'].shape[0]):
        if T.allclose(query_views['observations'][i], support_trajectories['observations'][j]):
            print('ok')

In [None]:
import minigrid_rl_starter.utils as utils
# Load environment

# Rebuild both environments with the same arguments as before (trailing 0 presumably the seed).
env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
# NOTE(review): range(0) makes these reset loops no-ops.
for _ in range(0):
    env.reset()
env_copy = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
for _ in range(0):
    env_copy.reset()
print("Environment loaded\n")

# Load agent

# Same setup as before, but using the *_fullgrid_navigation models.
trained_model_dir = utils.get_model_dir("CrossingS11N5_A2C_fullgrid_navigation", storage_dir="minigrid_rl_starter")
exploratory_model_dir = utils.get_model_dir("CrossingS11N5_A2C_fullgrid_navigation_state_bonus", storage_dir="minigrid_rl_starter")

trained_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    trained_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Trained agent loaded\n")

exploratory_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    exploratory_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Exploratory agent loaded\n")

In [None]:
# Regenerate 5-episode rollouts with the navigation agents (overwrites the earlier datasets).
train_dataset = generate_data(env=env, agent=trained_agent, episodes=5, render=False)

In [None]:
query_dataset = generate_data(env=env_copy, agent=exploratory_agent, episodes=5, render=False)

In [None]:
plt.imshow(query_dataset[0][14]['obs']['pixels'].transpose(1,2,0))

In [None]:
query_dataset[0][14]['direction']

In [None]:
import torch.nn.functional as F
# Example: one-hot encoding of direction 2 over 4 classes.
F.one_hot(T.Tensor([2]).to(T.int64),4)

In [None]:
query_trajectories = data_to_tensors(query_dataset)

In [None]:
print(query_dataset[0][4]['direction'])
plt.imshow(query_dataset[0][4]['obs']['pixels'].transpose(1,2,0)/255.0)

In [None]:
# FIXME: `orientated_trajectories` is only defined by a later cell (after orientate_observations),
# so this cell fails on Restart & Run All — move it below that cell.
# NOTE(review): if orientate_observations modifies its argument in place, the second image here is
# a rotation of an already-rotated observation.
plt.imshow(orientated_trajectories['observations'][4].permute(1,2,0)/255.0)
plt.show()
plt.imshow(T.rot90(query_trajectories['observations'][4], 2, [1,2]).permute(1,2,0)/255.0)

In [None]:
def orientate_observations(trajectories):
    """Rotate each observation according to its one-hot direction vector.

    For step i, the first slot of ``trajectories['directions'][i]`` equal to 1
    (checked in order 0, 1, 2, 3) selects the rotation applied with ``T.rot90``
    over dims 1 and 2 of the observation:
    slot 0 -> k=-1, slot 1 -> k=2, slot 2 -> k=1, slot 3 -> unchanged.
    If none of slots 0-3 is set, the observation is left unchanged.

    Args:
        trajectories: dict with 'observations' (per-step tensors, or a stacked
            tensor, whose dims 1 and 2 are rotated — presumably (C, H, W) with
            H == W) and 'directions' (per-step one-hot vectors with >= 4 slots).

    Returns:
        A new dict with the same entries as ``trajectories`` and rotated
        'observations'. The input dict and its observations are NOT modified,
        so re-running the calling cell never rotates an observation twice.
    """
    # Quarter turns for T.rot90 per one-hot slot. Dict order is the precedence,
    # so the first set slot wins, matching the original if/elif chain.
    quarter_turns = {0: -1, 1: 2, 2: 1, 3: 0}
    observations = trajectories['observations']
    # Write into a copy so the caller's observations stay untouched.
    if isinstance(observations, T.Tensor):
        rotated = observations.clone()
    else:
        rotated = list(observations)
    for i, direction in enumerate(trajectories['directions']):
        for slot, k in quarter_turns.items():
            if direction[slot] == 1:
                if k != 0:
                    rotated[i] = T.rot90(observations[i], k, [1, 2])
                break
    oriented = dict(trajectories)
    oriented['observations'] = rotated
    return oriented

In [None]:
orientated_trajectories = orientate_observations(query_trajectories)

In [None]:
orientated_trajectories['observations'][3]

In [None]:
# Recreate both environments with identical arguments and render them side by side to check
# that they start from the same layout.
env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
# NOTE(review): range(0) makes these reset loops no-ops.
for _ in range(0):
    env.reset()
env_copy = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
for _ in range(0):
    env_copy.reset()
print("Environment loaded\n")

In [None]:
plt.imshow(env.render("rgb_array"))

In [None]:
plt.imshow(env_copy.render("rgb_array"))

In [None]:
env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
env.reset()

In [None]:
# Two environments built with the same second argument (presumably the seed), each reset twice:
# do their 'image' observations after the second reset agree?
env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 1)
env_copy = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 1)
env_copy.reset()
env.reset()
np.allclose(env.reset()['image'], env_copy.reset()['image'])

In [None]:
env.reset()

In [None]:
import wandb
api = wandb.Api()

# Copy the config of run `tkclcaor` onto run `2on3lyxa` and point it at the exploratory agent.
run_exp = api.run("adamjelley/gen-con-rl/2on3lyxa")
run_sin= api.run("adamjelley/gen-con-rl/tkclcaor")

# Copy rather than alias: assigning run_sin.config directly would make the key update below
# also change run_sin's in-memory config.
run_exp.config = dict(run_sin.config)
run_exp.config["exploratory_agent"] = "CrossingS11N5_A2C_fullgrid_navigation_state_bonus"
# NOTE: this pushes the modified config of run_exp back to the wandb server.
run_exp.update()

In [None]:
import torch as T
def euclidian_distances(prototypes, embeddings):
    """Squared Euclidean distance between every prototype and every embedding.

    ``prototypes`` gets a new axis at position 2 and ``embeddings`` a new axis
    at position 1; the broadcast difference is squared and summed over the
    last axis. For inputs of shape (B, P, D) and (B, E, D) the result has
    shape (B, P, E).

    Note: despite the name, no square root is taken.
    """
    pairwise_diff = prototypes[:, :, None] - embeddings[:, None]
    return pairwise_diff.pow(2).sum(dim=-1)

In [None]:
# Shape check: batch of 1, 10 prototypes vs 100 embeddings, 128-d each -> expected (1, 10, 100).
proto=T.rand(1,10,128)
embeddings=T.rand(1,100,128)
euclidian_distances(proto, embeddings).shape

In [None]:
import wandb
import torch as T
import os
# Load a locally saved model checkpoint.
# os.path.join() with no arguments raised TypeError, and the absolute /Users/... path only exists on
# one machine. Build it relative to the project root instead, which is the working directory the
# notebook's other relative paths (datasets/, minigrid_rl_starter, ./wandb/) assume.
model_path = os.path.join("saved_models", "GCM_Location_Direction_Exploratory_de09eko4", "checkpoint.pt")
checkpoint = T.load(model_path)

In [None]:
# wandb.restore downloads the run's file and returns an open handle to it. Close the handle via
# `with` and load from the path it reports, instead of a hard-coded absolute location.
with wandb.restore("checkpoint.pt", run_path='adamjelley/gen-con-rl/de09eko4') as restored_file:
    restored_path = restored_file.name
T.load(restored_path)

In [None]:
# Load the checkpoint of a wandb run, downloading it only when there is no cached copy under
# ./wandb/saved_checkpoints/<run_id>/, so the cell is safe to re-run.
run_path = 'adamjelley/gen-con-rl/de09eko4'
run_id = run_path.split('/')[-1]
model_save_path = f'./wandb/saved_checkpoints/{run_id}/checkpoint.pt'
if not os.path.exists(model_save_path):
    api = wandb.Api()
    run=api.run(run_path)
    # File.download returns an open handle to the downloaded file, not the checkpoint itself.
    # Close it and load the file from disk below, so `checkpoint` is always the loaded object.
    run.file("checkpoint.pt").download(root=f'./wandb/saved_checkpoints/{run_id}').close()
checkpoint = T.load(model_save_path)
checkpoint['learner_state_dict']

In [None]:
shape = (3, 32, 32)
# Number of elements for this shape. T.prod only accepts tensors (a tuple raises TypeError);
# math.prod multiplies plain ints -> 3072.
math.prod(shape)

In [None]:
# Show one sampled environment render, indexed explicitly. The old code relied on a loop
# variable `i` left over from an earlier cell, which on Restart & Run All can exceed the
# 5 sampled renders.
plt.imshow(query_environments[0].permute(1,2,0)/255.0)