Navigation Menu

Skip to content

Commit

Permalink
haved seq length, double batch size, now like in paper
Browse files Browse the repository at this point in the history
  • Loading branch information
pbecker93 committed Jun 24, 2019
1 parent 10ad6f0 commit 616493f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions rkn/pendulum_image_imputation.py
Expand Up @@ -63,8 +63,8 @@ def build_decoder_hidden(self):
seed=0,
pendulum_params=pend_params)

train_obs, train_obs_valid, train_targets = generate_imputation_data_set(data, 1000, 150, seed=42)
test_obs, test_obs_valid, test_targets = generate_imputation_data_set(data, 250, 150, seed=23541)
train_obs, train_obs_valid, train_targets = generate_imputation_data_set(data, 2000, 75, seed=42)
test_obs, test_obs_valid, test_targets = generate_imputation_data_set(data, 1000, 75, seed=23541)

# Build Model
rkn = PendulumImageImputationRKN(observation_shape=train_obs.shape[-3:], latent_observation_dim=15,
Expand All @@ -74,7 +74,7 @@ def build_decoder_hidden(self):

# Train Model
rkn.fit((train_obs, train_obs_valid),
train_targets, batch_size=25, epochs=1000,
train_targets, batch_size=50, epochs=500,
validation_data=((test_obs, test_obs_valid), test_targets))


Expand Down
6 changes: 3 additions & 3 deletions rkn/pendulum_state_estimation.py
Expand Up @@ -47,16 +47,16 @@ def build_var_decoder_hidden(self):
seed=0,
pendulum_params=pend_params)

train_obs, train_targets = generate_pendulum_filter_dataset(data, 1000, 150, 42)
test_obs, test_targets = generate_pendulum_filter_dataset(data, 250, 150, 12312)
train_obs, train_targets = generate_pendulum_filter_dataset(data, 2000, 75, np.random.randint(100000000))
test_obs, test_targets = generate_pendulum_filter_dataset(data, 1000, 75, np.random.randint(10000000))

# Build Model
rkn = PendulumStateEstemRKN(observation_shape=train_obs.shape[-3:], latent_observation_dim=15, output_dim=2, num_basis=15,
bandwidth=3, never_invalid=True)
rkn.compile(optimizer=k.optimizers.Adam(clipnorm=5.0), loss=rkn.gaussian_nll, metrics=[rkn.rmse])

# Train Model
rkn.fit(train_obs, train_targets, batch_size=25, epochs=1000, validation_data=(test_obs, test_targets))
rkn.fit(train_obs, train_targets, batch_size=50, epochs=500, validation_data=(test_obs, test_targets), verbose=2)



Expand Down

0 comments on commit 616493f

Please sign in to comment.