In [1]:
import argparse
import os
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from prae.plots import plot_rn, plot_actions, plot_vertices
%matplotlib notebook

In [2]:
def visualize(agg_states, rewards, actions, colors=None):
    """Draw the abstract states and their action displacements in 3-D.

    Every entry of ``actions`` is added to ``agg_states`` to obtain a
    displaced copy of the states. The original states and all displaced
    copies are stacked along axis 0, projected onto three principal
    components with a single PCA fit, and handed to the ``prae.plots``
    helpers for drawing.

    Args:
        agg_states: array of abstract states; ``agg_states.shape[0]`` is
            taken as the number of vertices.
        rewards: values used to colour the vertices when ``colors`` is not
            supplied.
        actions: iterable of action vectors; each one must be addable to
            ``agg_states`` with ``+``.
        colors: optional vertex colouring that replaces ``rewards``.

    Returns:
        The ``(ax, fig)`` pair returned by ``plot_vertices``.
    """
    fig, ax = plot_rn()

    # Original states first, followed by one displaced copy per action.
    stacked = np.concatenate(
        [agg_states] + [agg_states + shift for shift in actions],
        axis=0,
    )

    # One PCA fit over the whole stack keeps the states and their displaced
    # copies in the same projected coordinate frame.
    projected = PCA(n_components=3).fit_transform(stacked)

    n_states = agg_states.shape[0]
    base_points = projected[:n_states]

    ax = plot_actions(ax, base_points, projected, n_states, actions)

    vertex_colors = rewards if colors is None else colors
    ax, fig = plot_vertices(base_points, fig=fig, ax=ax, colors=vertex_colors)
    return ax, fig


In [3]:
from prae.runner import Runner
from prae.helpers import set_seeds, make_env
from argparse import Namespace

# All run settings live on one Namespace, which is handed to make_env here and
# to Runner / Evaluator in later cells.
args = Namespace(
    # log_dir is reused further down to build paths of the form
    # runs/<log_dir>/<seed>.
    log_dir="f_trans",
    data_folder="f_trans",

    n_episodes=1000,
    model_train_epochs=100,
    n_neg_samples=5,
    batch_size=1024,
    batch_size_sample=1024,
    lr=0.001,
    hinge=1.0,
    env="fashion_translations",
    z_dim=50,
    trans_tau=1.0,
    prune_off=True,

    n_sweeps=500,
    gamma=0.9,
    objects=1,
    test_goals=False,
    test_set=False,
    seed=0,
    cpu=False,
)

# Seed before the environment is created.
set_seeds(int(args.seed))
env = make_env(args)

In [4]:
# Build the runner from the environment and settings defined above and run its
# loop. Judging by the log printed below, this samples train/valid
# trajectories and then trains the model epoch by epoch.
runner = Runner(env, args)                                                  
runner.loop() 

Sampling train trajectories 0 to 1000
Sampling valid trajectories 0 to 200
Number of samples=67534
Number of samples=13687
Train epoch 0
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 2.216708208575393 6.476146697998047 s
valid loss 0.9580818116664886█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 1
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.7812604010105133 6.258250713348389 s
valid loss 0.593166879245213██████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 2
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.6247406945084081

valid loss 0.4144691697188786█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 24
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.41366702860051935 6.344847917556763 s
valid loss 0.404595639024462██████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 25
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.40442109694986633 6.406969785690308 s
valid loss 0.40705567811216625████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 26
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.398624

valid loss 0.40125509032181333████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 48
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.3739172098311511 6.4260780811309814 s
valid loss 0.4091938180582864█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 49
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.37625462223183026 6.498592376708984 s
valid loss 0.4084123394318989█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 50
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.372229

valid loss 0.41316630371979307████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 72
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.3645683517961791 6.471348285675049 s
valid loss 0.4063133810247694█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 73
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.3617570662137234 6.353683233261108 s
valid loss 0.4072771945169994█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 74
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.36307725

valid loss 0.42181268334388733████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 96
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.3604183833707463 6.478039741516113 s
valid loss 0.41340359832559315████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 97
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.35860414351477765 6.404382944107056 s
valid loss 0.4051208049058914█████████████████████████████████████████████████████████████████████████--------| 92.9% complete
Train epoch 98
Training 822969 params
progress |██████████████████████████████████████████████████████████████████████████████████████████████████--| 98.5% complete
 train loss 0.3563266

In [5]:
from prae.evaluator import plan
from prae.helpers import load_abstract_mdp, get_model, load_network

# Checkpoint index to load.
# NOTE(review): presumably the last of the 100 epochs configured through
# args.model_train_epochs; confirm against how the runner numbers checkpoints.
epoch = 99

# Build a fresh model and pass it to load_network together with the run folder
# (runs/<log_dir>/<seed>) and the epoch chosen above.
network = load_network(
    get_model(env.action_space.n, args),
    f"runs/{args.log_dir}/{args.seed}",
    epoch,
)

In [6]:
# File names of the state / reward / action arrays that a later cell reads
# back from log_path.
s_str, r_str, a_str = "a_states.npy", "a_rewards.npy", "a_actions.npy"

# Per-epoch folder inside the run directory.
log_path = f"runs/{args.log_dir}/{args.seed}/{epoch}"

# The abstract MDP is loaded from the run directory (one level above log_path)
# for the selected epoch.
mdp = load_abstract_mdp(
    os.path.join("runs", args.log_dir, str(args.seed)),
    epoch,
    network,
    args.trans_tau,
    args.gamma,
)

  torch_latents = torch.tensor(latent_states).float().to(device)
  torch_latents = torch.tensor(latent_states).float().to(device)


In [7]:
from prae.evaluator import dump_epoch_mdp
# Reset the environment, then plan towards its goal state.
env.reset()
goal_state = env.get_goal_state()
plans = plan(network, mdp, goal_state, env.action_space.n, args)

# The first two entries of `plans` are detached and moved to the CPU; they are
# used below as state values and Q-values.
values, qvalues = (plans[i].detach().cpu() for i in (0, 1))

dump_epoch_mdp(log_path, mdp, values, qvalues)

# Read the three arrays back from the per-epoch folder.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
# can execute arbitrary code. That is only acceptable if these files are
# produced locally (presumably by dump_epoch_mdp above; confirm).
states, rewards, actions = (
    np.load(os.path.join(log_path, fname), allow_pickle=True)
    for fname in (s_str, r_str, a_str)
)
# Keep only column 0 of the rewards array.
rewards = rewards[:, 0]

In [9]:
# Angled view of the projected abstract states: the overall grid structure is
# recovered, though not perfectly. Passing colors=values makes visualize()
# colour the vertices by the planned values instead of the rewards.
visualize(states, rewards, actions, colors=values)

<IPython.core.display.Javascript object>

(<matplotlib.axes._subplots.Axes3DSubplot at 0x7f604edfa1c0>,
 <Figure size 640x480 with 2 Axes>)

In [10]:
from prae.evaluator import Evaluator
# Extra settings set before the Evaluator is built in the next cell, written
# straight onto the existing Namespace. model_train_epochs is re-stated at the
# value it already had (100).
vars(args).update(
    start_epochs=0,
    model_train_epochs=100,
    eval_eps=100,
    eta=1e-20,
)

In [11]:
# Evaluate with the environment built from the original settings
# (test_goals=False, test_set=False). loop() yields two values, unpacked as
# `returns` and `lens`; the log printed below reports one return/length pair
# per epoch.
evaluator = Evaluator(env, args)                                            
returns, lens = evaluator.loop()

  torch_latents = torch.tensor(latent_states).float().to(device)
  torch_latents = torch.tensor(latent_states).float().to(device)


Epoch 0, return: 0.5606870027910001, length: 22.02
Epoch 1, return: 0.6276220119700002, length: 13.37
Epoch 2, return: 0.6669916581710001, length: 12.84
Epoch 3, return: 0.6796226903900003, length: 7.47
Epoch 4, return: 0.6650735050610003, length: 6.87
Epoch 5, return: 0.6761721750900004, length: 7.51
Epoch 6, return: 0.6766818434519004, length: 6.67
Epoch 7, return: 0.6832116817910004, length: 4.89
Epoch 8, return: 0.6923038695020005, length: 5.66
Epoch 9, return: 0.6862855307910006, length: 4.85
Epoch 10, return: 0.6794407229830004, length: 5.81
Epoch 11, return: 0.7198993372710004, length: 6.14
Epoch 12, return: 0.6910866686, length: 6.47
Epoch 13, return: 0.6777288163710006, length: 5.85
Epoch 14, return: 0.6730714669800002, length: 5.84
Epoch 15, return: 0.7100890226400002, length: 4.55
Epoch 16, return: 0.6928746321800001, length: 5.63
Epoch 17, return: 0.6796002119900002, length: 8.3
Epoch 18, return: 0.6833365370700004, length: 4.88
Epoch 19, return: 0.6680964074950003, length:

In [12]:
# Plot the `lens` sequence from the evaluation on the train goals; the axis
# labels below give its meaning.
fig, ax = plt.subplots()
ax.plot(lens, label="train")
ax.set_title("Performance on train goals")
ax.set_ylabel("Average episode length")
ax.set_xlabel("Number of training epochs")
# The label passed to plot() is only shown once a legend is drawn. The trailing
# semicolon stops the notebook from echoing the last expression's repr.
ax.legend();

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Number of training epochs')

In [13]:
# Turn on the test goals and rebuild the environment from the updated
# settings. This rebinds `env`, so every later cell sees the new environment.
args.test_goals = True
env = make_env(args)

In [14]:
# Same evaluation as before, now against the environment rebuilt with
# test_goals=True. Results go to separate names so the earlier
# `returns` / `lens` are kept.
evaluator = Evaluator(env, args)                                            
returns_test, lens_test = evaluator.loop()

Epoch 0, return: 0.5276002350900001, length: 27.56
Epoch 1, return: 0.6345916653610003, length: 12.48
Epoch 2, return: 0.6873084920800006, length: 8.26
Epoch 3, return: 0.6627216872710002, length: 9.49
Epoch 4, return: 0.6733388769900006, length: 7.56
Epoch 5, return: 0.6745066178010002, length: 6.69
Epoch 6, return: 0.6753275535519002, length: 8.46
Epoch 7, return: 0.6833079943800004, length: 5.76
Epoch 8, return: 0.6862344989900002, length: 6.49
Epoch 9, return: 0.6915847786810003, length: 4.79
Epoch 10, return: 0.6731048983030005, length: 6.73
Epoch 11, return: 0.6725989720429005, length: 5.93
Epoch 12, return: 0.7055324596619001, length: 5.45
Epoch 13, return: 0.7084256324200005, length: 5.48
Epoch 14, return: 0.7276624264810003, length: 6.02
Epoch 15, return: 0.7265533027910003, length: 5.21
Epoch 16, return: 0.6960068927619004, length: 4.72
Epoch 17, return: 0.7229171663910005, length: 4.32
Epoch 18, return: 0.6705896342410003, length: 6.01
Epoch 19, return: 0.6980891863419003, l

In [15]:
# Plot the `lens_test` sequence from the evaluation on the test goals; the
# axis labels below give its meaning.
fig, ax = plt.subplots()
ax.plot(lens_test, label="test")
ax.set_title("Performance on test goals")
ax.set_ylabel("Average episode length")
ax.set_xlabel("Number of training epochs")
# The label passed to plot() is only shown once a legend is drawn. The trailing
# semicolon stops the notebook from echoing the last expression's repr.
ax.legend();

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Number of training epochs')

In [16]:
# Switch to the test set while going back to the original (non-test) goals,
# then rebuild the environment. This rebinds `env` for all later cells.
args.test_set = True
args.test_goals = False
env = make_env(args)

In [17]:
# Evaluate against the environment rebuilt with test_set=True and
# test_goals=False. Results are kept under their own names.
evaluator = Evaluator(env, args)                                            
returns_test_set, lens_test_set = evaluator.loop()

Epoch 0, return: 0.6219914504619003, length: 20.32
Epoch 1, return: 0.6912615913900004, length: 9.05
Epoch 2, return: 0.6759577869920004, length: 9.26
Epoch 3, return: 0.7144277772610004, length: 5.33
Epoch 4, return: 0.6555533514910004, length: 11.22
Epoch 5, return: 0.6731436132509003, length: 5.85
Epoch 6, return: 0.6971519609820004, length: 4.71
Epoch 7, return: 0.6911565397820003, length: 6.5
Epoch 8, return: 0.6783141949309004, length: 8.41
Epoch 9, return: 0.7025571033910005, length: 5.46
Epoch 10, return: 0.6811421936920002, length: 5.78
Epoch 11, return: 0.6969814149920005, length: 4.68
Epoch 12, return: 0.6463717616319004, length: 8.01
Epoch 13, return: 0.6624300191800002, length: 7.73
Epoch 14, return: 0.6974829419920003, length: 4.7
Epoch 15, return: 0.6676821167330005, length: 5.15
Epoch 16, return: 0.6931384598710004, length: 5.68
Epoch 17, return: 0.6946162655810006, length: 4.75
Epoch 18, return: 0.7013265101020003, length: 5.49
Epoch 19, return: 0.6939553534710003, len

In [18]:
# Plot the `lens_test_set` sequence from the evaluation on the test set
# images; the axis labels below give its meaning.
fig, ax = plt.subplots()
ax.plot(lens_test_set, label="test")
ax.set_title("Performance on test set images")
ax.set_ylabel("Average episode length")
ax.set_xlabel("Number of training epochs")
# The label passed to plot() is only shown once a legend is drawn. The trailing
# semicolon stops the notebook from echoing the last expression's repr.
ax.legend();

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Number of training epochs')

In [19]:
# Enable both the test set and the test goals, then rebuild the environment.
# test_set is already True from the previous configuration; it is re-stated so
# this cell does not depend on that. This rebinds `env` for all later cells.
args.test_set = True
args.test_goals = True
env = make_env(args)

In [20]:
# Evaluate against the environment rebuilt with test_set=True and
# test_goals=True. Results are kept under their own names.
evaluator = Evaluator(env, args)                                            
returns_test_test, lens_test_test = evaluator.loop()

Epoch 0, return: 0.5014817526810004, length: 29.67
Epoch 1, return: 0.6365988144620003, length: 11.57
Epoch 2, return: 0.6254945737519005, length: 10.0
Epoch 3, return: 0.6816185332509004, length: 4.88
Epoch 4, return: 0.6906142126720003, length: 6.53
Epoch 5, return: 0.6688099507427002, length: 6.87
Epoch 6, return: 0.7132520924000003, length: 5.31
Epoch 7, return: 0.6819677090820002, length: 5.81
Epoch 8, return: 0.6895309437800008, length: 4.76
Epoch 9, return: 0.6685275793810003, length: 5.94
Epoch 10, return: 0.7108944515000004, length: 4.51
Epoch 11, return: 0.7077555591900005, length: 4.49
Epoch 12, return: 0.6981773786810004, length: 4.72
Epoch 13, return: 0.6876888731820001, length: 4.9
Epoch 14, return: 0.7231584523810005, length: 4.37
Epoch 15, return: 0.6739844357610008, length: 5.09
Epoch 16, return: 0.7021292531800003, length: 4.6
Epoch 17, return: 0.6627038819700006, length: 6.92
Epoch 18, return: 0.7168896665810006, length: 4.44
Epoch 19, return: 0.6990556676209002, len

In [21]:
# Plot the `lens_test_test` sequence from the evaluation on test set images
# with test goals; the axis labels below give its meaning.
fig, ax = plt.subplots()
ax.plot(lens_test_test, label="test")
ax.set_title("Performance on test set images, test goals")
ax.set_ylabel("Average episode length")
ax.set_xlabel("Number of training epochs")
# The label passed to plot() is only shown once a legend is drawn. The trailing
# semicolon stops the notebook from echoing the last expression's repr.
ax.legend();

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Number of training epochs')