In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import torch as T
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt

In [3]:
from generative_contrastive_modelling.gcm import GenerativeContrastiveModelling
from generate_trajectories import generate_data
from process_trajectories import data_to_tensors, complete_observation_data_to_tensors

ImportError: cannot import name 'complete_observation_data_to_tensors' from 'process_trajectories' (/disk/scratch1/adamjelley/gen-con-rl/process_trajectories.py)

In [None]:
learner = GenerativeContrastiveModelling(
            (3, 56, 56),
            128,
            128,
            True,
            True,
        )

In [None]:
learner.encoder.load_checkpoint('/Users/ajelley/Projects/gen-con-rl/wandb/run-20220705_164807-2on3lyxa/files')

In [None]:
import minigrid_rl_starter.utils as utils
# Load environment

env = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
for _ in range(0):
    env.reset()
env_copy = utils.make_env("MiniGrid-SimpleCrossingS11N5-v0", 0)
for _ in range(0):
    env_copy.reset()
print("Environment loaded\n")

# Load agent

trained_model_dir = utils.get_model_dir("CrossingS11N5_A2C_fullgrid_navigation", storage_dir="minigrid_rl_starter")
exploratory_model_dir = utils.get_model_dir("CrossingS11N5_A2C_fullgrid_navigation_state_bonus", storage_dir="minigrid_rl_starter")

trained_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    trained_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Trained agent loaded\n")

exploratory_agent = utils.Agent(
    env.observation_space,
    env.action_space,
    exploratory_model_dir,
    argmax=True,
    use_memory=False,
    use_text=False,
)
print("Exploratory agent loaded\n")

In [None]:
train_dataset = generate_data(env=env, agent=trained_agent, episodes=5, render=False)

In [None]:
train_trajectories = data_to_tensors(train_dataset)
environments = F.interpolate(T.Tensor(
            np.array([train_dataset[episode][0]["obs"]["pixels"] for episode in train_dataset])
        ), size=32)

In [None]:
means, precisions = learner.encoder.forward(train_trajectories['observations'], train_trajectories['locations'], train_trajectories['directions'])

In [None]:
means[:21].shape

In [None]:
plt.imshow(means[train_trajectories['targets']==3].detach().numpy())

In [None]:
len(train_trajectories['targets'][train_trajectories['targets']==1])

In [None]:
T.argmax(precisions[22])

In [None]:
plt.imshow(environments[2].permute(1,2,0)/255.0)

In [None]:
plt.imshow(environments[0].permute(1,2,0)/255.0)

In [None]:
from generative_contrastive_modelling.environment_decoder import EnvironmentDecoder

In [None]:
decoder = EnvironmentDecoder(128, 128, (3,32,32))
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)

In [None]:
decoder.forward(means, precisions).shape

In [None]:
for episode in range(10000):
    train_dataset = generate_data(env=env, agent=trained_agent, episodes=5, render=False)
    train_trajectories = data_to_tensors(train_dataset)
    environments = F.interpolate(T.Tensor(
                np.array([train_dataset[episode][0]["obs"]["pixels"] for episode in train_dataset])
            ), size=32)/255.0

    means, precisions = learner.encoder.forward(train_trajectories['observations'], train_trajectories['locations'], train_trajectories['directions'])
    means=means.unsqueeze(0).detach()
    precisions=precisions.unsqueeze(0).detach()
    (env_proto_means,
    env_proto_precisions,
    log_env_proto_normalisation,
    ) = learner.inner_gaussian_product(
        means, precisions, train_trajectories['targets'].unsqueeze(0)
        )

    env_reconstructions = decoder.forward(env_proto_means.squeeze(), env_proto_precisions.squeeze())
    reconstruction_loss = F.mse_loss(env_reconstructions, environments)

    decoder_optimizer.zero_grad()
    reconstruction_loss.backward()
    decoder_optimizer.step()

    if episode%10==0:
        print(f'Episode: {episode}, \tLoss: {reconstruction_loss}')
        plt.imshow(torchvision.utils.make_grid(environments).permute(1,2,0))
        plt.show()
        plt.imshow(torchvision.utils.make_grid(env_reconstructions).permute(1,2,0))
        plt.show()

In [None]:
for episode in range(1000):
    train_dataset = generate_data(env=env, agent=trained_agent, episodes=5, render=False)
    train_trajectories = data_to_tensors(train_dataset)
    environments = F.interpolate(T.Tensor(
                np.array([train_dataset[episode][step]["obs"]["partial_pixels"] for episode in train_dataset for step in train_dataset[episode]])
            ), size=32)/255.0

    means, precisions = learner.encoder.forward(train_trajectories['observations'], train_trajectories['locations'], train_trajectories['directions'])
    means=means.unsqueeze(0).detach()
    precisions=precisions.unsqueeze(0).detach()
    # (env_proto_means,
    # env_proto_precisions,
    # log_env_proto_normalisation,
    # ) = learner.inner_gaussian_product(
    #     means, precisions, train_trajectories['targets'].unsqueeze(0)
    #     )

    env_reconstructions = decoder.forward(means.squeeze(), precisions.squeeze())
    reconstruction_loss = F.mse_loss(env_reconstructions, environments)

    decoder_optimizer.zero_grad()
    reconstruction_loss.backward()
    decoder_optimizer.step()

    if episode%10==0:
        print(f'Episode: {episode}, \tLoss: {reconstruction_loss}')
        plt.imshow(torchvision.utils.make_grid(environments[:5]).permute(1,2,0))
        plt.show()
        plt.imshow(torchvision.utils.make_grid(env_reconstructions[:5]).permute(1,2,0))
        plt.show()

In [None]:
import copy
initial_decoder = copy.deepcopy(decoder)

In [None]:
T.ones(4,100)[:,-6:].shape

In [5]:
from generative_contrastive_modelling.unsupervised_gcm import UnsupervisedGenerativeContrastiveModelling
learner = UnsupervisedGenerativeContrastiveModelling(
            input_shape=(3, 56, 56),
            hid_dim=128,
            z_dim=128
            prior_precision=0.01,
            use_location=F
            use_direction=config.use_direction,
            use_coordinates=False,
        )

In [None]:
precisions = np.logspace(-5, 5, 50)
log_likelihoods = []