# Imports and Overview

In [1]:
# !pip install tensorflow tensorflow_probability tf_agents numpy gym highway-env tqdm wandb
# Neural Network
import tensorflow as tf
from tensorflow.keras.models import Model
# Layer
from tensorflow.keras.layers import Dense, Layer, Conv2DTranspose, Conv2D, GlobalAveragePooling2D, Reshape, BatchNormalization, GRUCell, MaxPooling2D, Flatten, RNN
from tensorflow.keras.losses import CategoricalCrossentropy, KLDivergence
import tensorflow_probability as tfp 



# Buffer 
from tf_agents.replay_buffers import tf_uniform_replay_buffer

# Further support
import numpy as np
from typing import NamedTuple
from tqdm import tqdm
import wandb
wandb.init(settings=wandb.Settings(_disable_stats=True))

# Environment
import gym
import highway_env
import random





[34m[1mwandb[0m: Currently logged in as: [33mpeterkeffer[0m ([33mcogsci[0m). Use [1m`wandb login --relogin`[0m to force relogin


# Experience Replay Buffer

In [2]:
# Has to save (Observation, action, reward, terminal state)
from numpy import float32


class Buffer:

    def __init__(
        self,
        batch_size=1,
        buffer_length=1000, 
        observation_size=(128,32,1),
        action_size=1
    ):
        """
        Create replay buffer

        Buffer size = batch_size * buffer_length

        """
        # Save batch size for other functions of buffer
        # NOT the usual batch size in Deep Learning
        # Batches in Uniform Replay Buffer describe size of input added to the buffer
        self.batch_size = batch_size

        # Tell buffer what data & which size to expect
        self.data_spec = (
            tf.TensorSpec(
                shape= observation_size,
                dtype=tf.dtypes.float32,
                name="Observation"
            ),
            tf.TensorSpec(
                shape=observation_size,
                dtype=tf.dtypes.float32,
                name="Next state"
            ),
            tf.TensorSpec(
                shape=[action_size],
                dtype=tf.dtypes.float32,
                name="Action"
            ),
            tf.TensorSpec(
                # Reward size
                shape=[1, ],
                dtype=tf.dtypes.float32,
                name="Reward"
            ),
            tf.TensorSpec(
                shape=[1, ],
                # Either 0 or 1 
                dtype=tf.dtypes.float32,
                name="Non-Terminal State"
            )
        )

        # Create the buffer 
        self.buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            self.data_spec, batch_size, buffer_length
        )

    def obtain_buffer_specs(self):
        return self.data_spec

    def add(self, items):
        """
        length of items must be equal to batch size

        items: list or tuple of batched data from (50, 5)


        """
        # Combine all values from "items" in tensor
        # Not sure wether we need tf.nest.map_structure
        batched_values = tf.nest.map_structure(
            lambda t: tf.stack([t] * self.batch_size),
            items
        )
        
        # Add to batch
        self.buffer.add_batch(batched_values)

    def sample(self, batch_size, prefetch_size):
        data = self.buffer.as_dataset(single_deterministic_pass=True)

                
        # data = data.map(lambda img, target: (tf.cast(img, tf.float32), target))
        # # normalize inputs from 0/255 to -1/1
        # data = data.map(lambda img, target: ((img/128.)-1, target))
        # # create one-hot vector for targets
        # data = data.map(lambda img, target: (img, tf.one_hot(target, depth=10)))

        # normalize inputs from 0/255 to -1/1
        data = data.map(lambda buffer_content, _: (((buffer_content[0]/128.)-1, (buffer_content[1]/128.)-1, buffer_content[2], buffer_content[3], buffer_content[4]), _))
        data = data.cache()
        data = data.batch(batch_size).prefetch(prefetch_size)
        #later we want these to be sequences (Do we though)
        return data




# Environment

In [3]:
class EnvironmentInteractor:

  def __init__(self, config, buffer, environment_name = "highway-fast-v0"):
    self.config = config

    self.env = gym.make(environment_name)    
    self.env.configure(config)

    self.buffer = buffer
    # Save sizes of the stupid tensors
    self.data_spec = self.buffer.obtain_buffer_specs()
  

  

  def create_trajectories(self, iterations):
    state = self.env.reset()
    
    for _ in range(iterations):
        action = self.env.action_space.sample()
        next_state, reward, done, _ = self.env.step(action)
        
        self.buffer.add((
          tf.cast(tf.constant(state, shape=self.data_spec[0].shape.as_list()), tf.float32),
          tf.cast(tf.constant(next_state, shape=self.data_spec[1].shape.as_list()), tf.float32),
          tf.cast(tf.constant(action, shape=self.data_spec[2].shape.as_list()), tf.float32),
          tf.cast(tf.constant(reward, shape=self.data_spec[3].shape.as_list()), tf.float32),
          tf.cast(tf.constant(1-done, shape=self.data_spec[4].shape.as_list()), tf.float32)
        ))

        state = next_state
        
        if done:
          state = self.env.reset()


  def __del__(self):
    self.env.close()

# Parameters

In [4]:
# Image size
image_shape = (128,32, 1)

# Long term memory of GRU
hidden_unit_size = 200

# Z in paper
stochastic_state_shape = (32,32)
stochastic_state_size = stochastic_state_shape[0] * stochastic_state_shape[1]

#
action_size = 1
horizon = 15
discount_factor = 0.995

#
mlp_hidden_layer_size = 100
batch_size = 50

# TODO different variable names for network inp/outp sizes




In [None]:
class RSSMState(NamedTuple):
    logits: tf.Tensor = tf.zeros(shape=(stochastic_state_size,))
    stochastic_state_z: tf.Tensor = tf.zeros(shape=(stochastic_state_size,))
    hidden_rnn_state: tf.Tensor = tf.zeros(shape=(hidden_unit_size,))

    @classmethod
    def from_list(cls, rssm_states):
        logits = tf.stack([rssm_state.logits for rssm_state in rssm_states])
        stochastic_state_z = tf.stack([rssm_state.stochastic_state_z for rssm_state in rssm_states])
        hidden_rnn_state = tf.stack([rssm_state.hidden_rnn_state for rssm_state in rssm_states])

        return cls(logits, stochastic_state_z, hidden_rnn_state)

    def get_hidden_state_h_and_stochastic_state_z(self):
        hidden_state_h_and_stochastic_state_z = tf.concat([self.stochastic_state_z, self.hidden_rnn_state], axis=-1)

        return hidden_state_h_and_stochastic_state_z

    @classmethod
    def convert_sequences_to_batches(cls, rssm_state):
        logits = cls.convert_sequence_to_batch(rssm_state.logits)
        stochastic_state_z = cls.convert_sequence_to_batch(rssm_state.stochastic_state_z)
        hidden_rnn_state = cls.convert_sequence_to_batch(rssm_state.hidden_rnn_state)

        return cls(logits, stochastic_state_z, hidden_rnn_state)

    @classmethod
    def convert_sequence_to_batch(cls, sequence):
        batch = tf.reshape(sequence, (sequence.shape[0] * sequence.shape[1], *sequence.shape[2:]))
        return batch


# World model

In [1]:
class WorldModel:


    def __init__(self) -> None:
        super().__init__()

        self.encoder = self.create_encoder()
        self.decoder = self.create_decoder()
        self.reward_model = self.create_reward_predictor()
        self.discount_model = self.create_discount_predictor()
        self.actor = self.create_actor()
        self.critic = self.create_critic()
        self.target_critic = tf.keras.models.clone_model(self.critic)



    def create_encoder(self, input_size=image_shape, output_size=hidden_unit_size):
        # Third dimension might be obsolete
        encoder_input = tf.keras.Input(shape=input_size)
        x = Conv2D(16, (3, 3), activation="elu", padding="same")(encoder_input) # 16 layers of filtered 192x48 features
        x = MaxPooling2D((2, 2), padding="same")(x) # 64 / 96x24
        x = Conv2D(32, (3, 3), activation="elu", padding="same")(x) # 64 / 96x24
        x = MaxPooling2D((2, 2), padding="same")(x) # 64 / 96x24
        x = Conv2D(64, (3, 3), activation="elu", padding="same")(x) # 64 / 48x12
        x = MaxPooling2D((2, 2), padding="same")(x) # 64 / 48x12
        x = GlobalAveragePooling2D()(x) # 64
        encoder_output = Dense(output_size, activation = "elu")(x)

        encoder = tf.keras.Model(encoder_input, encoder_output, name="Encoder")

        return encoder


    # Input size = 1024(z:32x32) + 200(size of hidden state)
    # Output size = game frame
    def create_decoder(
        self, 
        input_size=stochastic_state_size + hidden_unit_size, 
        output_size=image_shape
    ):
        # Third dimension might be obsolete
        decoder_input = tf.keras.Input(shape=input_size)
        # TODO WIE SCHLIMM IST EIN MLP HIER?
        x = Dense(256, activation= "elu")(decoder_input)
        x = Reshape((32, 8, 1))(x) 
        # TODO Check whether correct reshape happens
        #tf.debugging.assert_equal(x)
        x = Conv2DTranspose(16, (3, 3), strides=2, activation="elu", padding="same")(x)
        x = BatchNormalization()(x)
        x = Conv2DTranspose(1, (3, 3), strides=2, activation="linear", padding="same")(x)
       # x = Conv2DTranspose(1, (3, 3), strides=2, activation="elu", padding="same")(x)
        x = Flatten()(x)
        # Might needs shape as Tensor  #event_shape=output_size

        # decoder_output = tfp.layers.IndependentNormal(event_shape=output_size)(x)


        decoder = tf.keras.Model(
            decoder_input,
            x,
            name="Decoder"
        )

        return decoder
    

        # Input: concatination of h and z
    # Output: float predicting the obtained reward
    def create_reward_predictor(
        self, 
        input_size=hidden_unit_size+stochastic_state_size,
        output_size=1
    ):
        reward_predictor_input = tf.keras.Input(shape=input_size)
        x = Dense(mlp_hidden_layer_size, activation="elu")(reward_predictor_input)
        x = Dense(mlp_hidden_layer_size, activation="elu")(x)
        x = Dense(mlp_hidden_layer_size)(x)
        # Creates indipendent normal distribution
        # Hope is that it learns to output variables over reward space [0,1]
        #reward_predictor_output = tfp.layers.IndependentNormal()(x)

        reward_predictor = tf.keras.Model(
            reward_predictor_input,
            x,
            name="create_reward_predictor"
        )

        return reward_predictor
    

        # Input: concatination of h and z
    # Output: float predicting the obtained reward
    def create_discount_predictor(
        self, 
        input_size=hidden_unit_size+stochastic_state_size,
        output_size=1
    ):
        discount_predictor_input = tf.keras.Input(shape=input_size)
        x = Dense(mlp_hidden_layer_size, activation="elu")(discount_predictor_input)
        x = Dense(mlp_hidden_layer_size, activation="elu")(x)
        # Create 1 output sampled from bernoulli distribution
        #discount_predictor_output = tfp.layers.IndependentBernoulli()(x)

        discount_predictor = tf.keras.Model(
            discount_predictor_input,
            x,
            name="create_discount_predictor"
        )

        return discount_predictor

    def compute_actor_critic_loss(self, posterior_rssm_state: RSSMState):

        # TODO At the moment we are using only batches and not batches of sequences
        batched_posterior_rssm_states = tf.stop_gradient(posterior_rssm_state)

        dreamed_rssm_states, dreamed_log_probabilities, dreamed_policy_entropies = self.rssm(horizon, self.actor, batched_posterior_rssm_states)

        dreamed_hidden_state_h_and_stochastic_state_z = dreamed_rssm_states.get_hidden_state_h_and_stochastic_state_z()

        # TODO models definieren self.world_list+self.value_list+[self.TargetValueModel]+[self.DiscountModel]
        self.set_trainable_models(models, False)

        reward_logits = world_model.reward_model(hidden_state_h_and_stochastic_state_z)
        reward_distribution = tfp.distributions.Independent(tfp.distributions.Normal(reward_logits, 1))
        dreamed_reward = reward_distribution.mean()

        discount_logits = world_model.discount_model(hidden_state_h_and_stochastic_state_z)
        discount_distribution = tfp.distributions.Independent(tfp.distributions.Bernoulli(logits=discount_logits))
        dreamed_discount = discount_factor * discount_distribution * tf.round(discount_distribution.prob(discount_distribution.mean()))

        target_value_logits = self.target_critic(hidden_state_h_and_stochastic_state_z)
        target_value_distribution = tfp.distributions.Independent(tfp.distributions.Normal(target_value_logits, 1))
        dreamed_value = target_value_distribution.mean()

        self.set_trainable_models(models, True)


        actor_loss, discount, lambda_returns = self.actor_loss(dreamed_reward, dreamed_value, dreamed_discount, dreamed_log_probabilities, dreamed_policy_entropies)
        critic_loss = self.critic_loss(dreamed_hidden_state_h_and_stochastic_state_z[:-1], discount, lambda_returns)

        return actor_loss, critic_loss


    def actor_loss(self, dreamed_reward, dreamed_value, dreamed_discount, dreamed_log_probabilities, dreamed_policy_entropies, actor_entropy_scale=0.001, lmbda = 0.95):
        lambda_returns = self.compute_return(dreamed_reward[:-1], dreamed_value[:-1], dreamed_discount[:-1], bootstrap=dreamed_value[-1], lmbda=lmbda)

        advantage = tf.stop_gradient(lambda_returns - dreamed_value[:-1])
        objective = dreamed_log_probabilities[1:] * advantage

        discounts = tf.concat([tf.ones_like(dreamed_discount[:1]), dreamed_discount[1:]])
        discount = tf.math.cumprod(discounts[:-1], 0)
        policy_entropy = dreamed_policy_entropies[1:]
        actor_loss = -tf.math.reduce_sum(tf.math.reduce_mean(discount * (objective + actor_entropy_scale * policy_entropy), dim=1))
        return actor_loss, discount, lambda_returns


    def critic_loss(self, dreamed_hidden_state_h_and_stochastic_state_z, discount, lambda_returns):

        critic_distribution = self.critic(tf.stop_gradient(dreamed_hidden_state_h_and_stochastic_state_z))
        critic_loss = -tf.reduce_mean(tf.stop_gradient(discount) * tf.stop_gradient(critic_distribution.log_prob(lambda_returns)))

        return critic_loss

    def compute_return(self, reward,
                    value,
                    discount,
                    bootstrap,
                    lmbda):

        next_values = tf.concat([value[1:], bootstrap[None]], 0)
        target = reward + discount + next_values * (1 - lmbda)
        timesteps = list(range(reward.shape[0] - 1, -1, -1))
        outputs = []
        accumulated_reward = bootstrap
        for timestep in timesteps:
            inp = target[timestep]
            discount_factor = discount[timestep]
            accumulated_reward = inp + discount_factor * lmbda * accumulated_reward
            outputs.append(accumulated_reward)
        returns = tf.reverse(tf.stack(outputs), [0])
        return returns



    def set_trainable_models(self, models, trainable: bool):
        for model in models:
            model.trainable = trainable

    def compute_log_loss(self, distribution, target):
        """
        Computes loss for:
        - Image log loss(Output decoder, frame timestep t)
        - Reward log loss(Output reward network, obtained reward timestep t)
        - Discount log loss(Output of discount network, terminal state timestep t)
        """
        # TODO check whether distribution.log_prob  (target) matches target size
        # histogram von wahrsch. distribution /
        return -tf.math.reduce_mean(distribution.log_prob(target))


    def compute_kl_loss(self, prior_rssm_states, posterior_rssm_states, alpha=0.8):
        """
        alpha: weigh between training the prior toward the representations & regularizing
         the representations towards the prior
        prior: Z
        posterior: Z^
        """
        prior_distribution = tfp.distributions.Independent(tfp.distributions.OneHotCategorical(logits=prior_rssm_states.logits), 1)
        posterior_distribution = tfp.distributions.Independent(tfp.distributions.OneHotCategorical(logits=posterior_rssm_states.logits), 1)

        prior_distribution_detached = tfp.distributions.Independent(tfp.distributions.OneHotCategorical(logits=tf.stop_gradient(prior_rssm_states.logits)), 1)
        posterior_distribution_detached = tfp.distributions.Independent(tfp.distributions.OneHotCategorical(logits=tf.stop_gradient(posterior_rssm_states.logits)), 1)

        # Loss with KL Balancing
        # TODO check reihenfolge, reduce_mean hat Gradients?!!?
        return alpha * tf.math.reduce_mean(tfp.distributions.kl_divergence(posterior_distribution_detached, prior_distribution)) + (1-alpha) * tf.math.reduce_mean(tfp.distributions.kl_divergence(posterior_distribution, prior_distribution_detached))




NameError: name 'image_shape' is not defined

In [5]:


class RSSM:

    def __init__(self) -> None:
        super().__init__()

        self.state_action_embedder = self.create_stochastic_state_action_embedder()
        self.rnn = self.create_rnn()
        self.prior_model = self.create_prior_stochastic_state_embedder()
        self.posterior_model = self.create_posterior_stochastic_state_embedder()

    def create_stochastic_state_action_embedder(
        self,
        input_size=(stochastic_state_size + action_size,),
        output_size=hidden_unit_size
    ):
        state_action_input = tf.keras.Input(shape=input_size)
        state_action_output = Dense(output_size, activation = "elu")(state_action_input)

        stochastic_state_action_embedder = tf.keras.Model(
            state_action_input,
            state_action_output,
            name="stochastic_state_action_embedder"
        )

        return stochastic_state_action_embedder

    # Contains GRU cell
    def create_rnn(
        self,
        input_size=(hidden_unit_size, ),
        output_size=hidden_unit_size
    ):
        return RNN(GRUCell(output_size))

        rnn_input = tf.keras.Input(shape=input_size)
       # rnn_hidden_state_placeholder = tf.keras.Input(shape=(hidden_unit_size,))
        rnn_output = rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(output_size))(rnn_input)


        rnn = tf.keras.Model(
            rnn_input,
            rnn_output,
            name="rnn"
        )

        return rnn

    # Z^ in paper
    def create_prior_stochastic_state_embedder(
        self,
        input_size=hidden_unit_size,
        output_size=stochastic_state_size
    ):
        state_embedder_input = tf.keras.Input(shape=input_size)
        x = Dense(mlp_hidden_layer_size, activation="elu")(state_embedder_input)
        # Activation function removed
        state_embedder_output = Dense(output_size)(x)

        create_prior_stochastic_state_embedder = tf.keras.Model(
            state_embedder_input,
            state_embedder_output,
            name="create_prior_stochastic_state_embedder"
        )

        return create_prior_stochastic_state_embedder

    # Z in paper
    # Input size = concatenated output of RNN with output of CNN
    def create_posterior_stochastic_state_embedder(
        self,
        input_size=hidden_unit_size+hidden_unit_size,
        output_size=stochastic_state_size
    ):
        state_embedder_input = tf.keras.Input(shape=input_size)
        x = Dense(mlp_hidden_layer_size, activation="elu")(state_embedder_input)
        # Activation function removed
        state_embedder_output = Dense(output_size)(x)

        create_posterior_stochastic_state_embedder = tf.keras.Model(
            state_embedder_input,
            state_embedder_output,
            name="create_posterior_stochastic_state_embedder"
        )

        return create_posterior_stochastic_state_embedder

    def sample_stochastic_state(self, logits):
        """
        Gets probabilities for each element of class in each category.
        Used to generate embeddings from logits.
        """

        # Logit Outputs from MLP
        logits = tf.reshape(logits, shape=(-1, *stochastic_state_shape))
        # OneHot distribution over logits
        logits_distribution = tfp.distributions.OneHotCategorical(logits)
        # Sample from OneHot distribution
        sample = tf.cast(logits_distribution.sample(), tf.float32)
        # TODO observe logits_distribution.prob(sample) after few iterations
        sample += logits_distribution.prob(sample) - tf.stop_gradient(logits_distribution.prob(sample))

        return sample

    def dream(self, previous_rssm_state: RSSMState, previous_action: tf.Tensor, non_terminal=True):
        """
        Creates Z^
        """
        # TODO invert terminal states (terminal state = 1 if episode ended, needs to be 0)
        # Embedding of concatenation prior z and action (t-1)
        state_action_embedding = self.state_action_embedder(tf.concat([previous_rssm_state.stochastic_state_z * non_terminal, previous_action], axis=1))
        # TODO Remove Squeeze
        # Create h from GRU with old h (t-1) and the embedding
        state_action_embedding = tf.reshape(state_action_embedding, shape=(-1, 200, 1))

        hidden_rnn_state = self.rnn(state_action_embedding, previous_rssm_state.hidden_rnn_state * non_terminal)

        # Logits created from h (with MLP) to create Z^
        prior_logits = self.prior_model(hidden_rnn_state)
        # Create Z^
        prior_stochastic_state_z = self.sample_stochastic_state(prior_logits)
        # Save logits for Z^, Z^ and h
        prior_rssm_state = RSSMState(prior_logits, tf.reshape(prior_stochastic_state_z, (-1, stochastic_state_size)), hidden_rnn_state)

        return prior_rssm_state

    def dreaming_rollout(self, horizon: int, actor: tf.keras.Model, previous_rssm_state: RSSMState):
        """
        Rollout only Z
        """
        rssm_state = previous_rssm_state

        next_rssm_states = []
        action_entropies = []
        image_log_probabilities = []

        for timestep in range(horizon):
            action, action_distribution = actor(tf.stop_gradient(rssm_state.get_hidden_state_h_and_stochastic_state_z()))
            rssm_state = self.dream(rssm_state, action)

            next_rssm_states.append(rssm_state)
            action_entropies.append(action_distribution.entropy(action))
            image_log_probabilities(action_distribution.log_prob(action))

        next_rssm_states = RSSMState.from_list(next_rssm_states)
        action_entropies = tf.stack(action_entropies, dim=0)
        image_log_probabilities = tf.stack(image_log_probabilities, dim=0)

        return next_rssm_states, image_log_probabilities, action_entropies


    def observe(self, encoded_state: tf.Tensor, previous_action: tf.Tensor, previous_non_terminal: tf.Tensor, previous_rssm_state: RSSMState):
        """
        Creates Z' and Z
        """
        # Obtain Z^
        prior_rssm_state = self.dream(previous_rssm_state, previous_action, previous_non_terminal)

        # concatenates h and the output of our CNN (encoded input frame X)
        encoded_state_and_hidden_state = tf.concat([prior_rssm_state.hidden_rnn_state, encoded_state], axis=1)

        # Logits created from concat of h and encoded frame X (with MLP) to create Z
        posterior_logits = self.posterior_model(encoded_state_and_hidden_state)
        # Create Z
        posterior_stochastic_state_z = self.sample_stochastic_state(posterior_logits)
        # Saves logits for Z, Z, and h
        posterior_rssm_state = RSSMState(posterior_logits, tf.reshape(posterior_stochastic_state_z, (-1, stochastic_state_size)), prior_rssm_state.hidden_rnn_state)

        return prior_rssm_state, posterior_rssm_state

    def observing_rollout(self, encoded_states: tf.Tensor, actions: tf.Tensor, non_terminals: tf.Tensor, previous_rssm_state: RSSMState):
        prior_rssm_states = []
        posterior_rssm_states = []

        for encoded_state, action, non_terminal in zip(encoded_states, actions, non_terminals):
            # TODO remove islandsolution
            encoded_state = tf.expand_dims(encoded_state, axis=0)
            action = tf.expand_dims(action, axis=0)
            non_terminal = tf.expand_dims(non_terminal, axis=0)
            #?? 0 if terminal state is reached
            previous_action = action * non_terminal
            # Z^, Z
            prior_rssm_state, posterior_rssm_state = self.observe(encoded_state, previous_action, non_terminal, previous_rssm_state)

            # Save Z^, Z
            prior_rssm_states.append(prior_rssm_state)
            posterior_rssm_states.append(posterior_rssm_state)

            # Z for next iteration
            previous_rssm_state = posterior_rssm_state
        prior_rssm_states = RSSMState.from_list(prior_rssm_states)
        posterior_rssm_states = RSSMState.from_list(posterior_rssm_states)

        return prior_rssm_states, posterior_rssm_states


# Test Everything

In [7]:
from importlib_metadata import distribution

# TODO move hyperparams
epochs = 32

optimizer = tf.keras.optimizers.Adam(0.0002)

buffer = Buffer(batch_size=1)
config = {
        "observation": {
            "type": "GrayscaleObservation",
            "observation_shape": (128, 32),
            "stack_size": 1,
            # weights for RGB conversion
            "weights": [0.01, 0.01, 0.98],  
            "scaling": 1.5,
        },
        # was at 2
        "policy_frequency": 1 
    }

environment_interactor = EnvironmentInteractor(config, buffer)

# Sample from buffer

world_model = WorldModel()
rssm = RSSM()

models = (
        world_model.encoder,
        world_model.decoder,
        world_model.reward_model,
        world_model.discount_model,
        rssm.state_action_embedder,
        rssm.rnn,
        rssm.prior_model,
        rssm.posterior_model)

for episode in range(50):
    environment_interactor.create_trajectories(1000)

    data = buffer.sample(batch_size=50, prefetch_size=5)

    for sequence in data:
        state, next_state, action, reward, non_terminal = sequence[0]

            # use tf.gradientTape to compute loss, then gradients and apply these to the model to modify the parameters
        combined_trainable_variables = models[0].trainable_variables
        for i in range(len(models)):
            if i+1 >= len(models):
                break
            combined_trainable_variables += models[i+1].trainable_variables

        with tf.GradientTape() as tape:
            encoded_state = world_model.encoder(state)
            initial_rssm_state = RSSMState()
            prior_rssm_states, posterior_rssm_states = rssm.observing_rollout(encoded_state, action, non_terminal, initial_rssm_state)
            hidden_state_h_and_stochastic_state_z = tf.concat([posterior_rssm_states.stochastic_state_z, posterior_rssm_states.hidden_rnn_state], axis=-1)

            # TODO ÄNDERN
            hidden_state_h_and_stochastic_state_z = tf.reshape(hidden_state_h_and_stochastic_state_z, (-1,stochastic_state_size + hidden_unit_size))

            decoder_logits = world_model.decoder(hidden_state_h_and_stochastic_state_z)

            # TODO ÄNDERN
            decoder_logits = tf.reshape(decoder_logits, (-1, image_shape[0], image_shape[1], image_shape[2]))

            decoder_distribution = tfp.distributions.Independent(tfp.distributions.Normal(decoder_logits, 1))
            reward_logits = world_model.reward_model(hidden_state_h_and_stochastic_state_z)
            reward_distribution = tfp.distributions.Independent(tfp.distributions.Normal(reward_logits, 1))
            discount_logits = world_model.discount_model(hidden_state_h_and_stochastic_state_z)
            discount_distribution = tfp.distributions.Independent(tfp.distributions.Bernoulli(logits=discount_logits))

            image_log_loss = compute_log_loss(decoder_distribution, state)
            reward_log_loss = compute_log_loss(reward_distribution, reward)
            discount_log_loss = compute_log_loss(discount_distribution, non_terminal)
            kl_loss = compute_kl_loss(prior_rssm_states, posterior_rssm_states)



            loss = image_log_loss + reward_log_loss + discount_log_loss + kl_loss
            print(f"Image Log Loss: {image_log_loss} Reward Log Loss: {reward_log_loss} Discount Log Loss {discount_log_loss} KL Loss {kl_loss}")
            wandb.log({"Image Log Loss": image_log_loss, "Reward Log Loss": reward_log_loss, "Discount Log Loss": discount_log_loss, "KL Loss": kl_loss, "Loss": loss})
            # TODO maybe in Gradienttape??
            gradients = tape.gradient(loss, combined_trainable_variables)

        optimizer.apply_gradients(zip(gradients, combined_trainable_variables))

  df = df.append(pd.DataFrame.from_records(


Instructions for updating:
Please pass an integer value for `reinterpreted_batch_ndims`. The current behavior corresponds to `reinterpreted_batch_ndims=tf.size(distribution.batch_shape_tensor()) - 1`.
Image Log Loss: 4098.0341796875 Reward Log Loss: 127.16785430908203 Discount Log Loss 69.48236846923828 KL Loss 0.00021589998505078256
Image Log Loss: 4099.5986328125 Reward Log Loss: 123.47802734375 Discount Log Loss 69.29757690429688 KL Loss 0.00022834003902971745
Image Log Loss: 4099.32666015625 Reward Log Loss: 122.35476684570312 Discount Log Loss 69.08842468261719 KL Loss 0.00024108559591695666
Image Log Loss: 4097.93701171875 Reward Log Loss: 124.6875991821289 Discount Log Loss 69.00785827636719 KL Loss 0.0002294034493388608
Image Log Loss: 4097.6494140625 Reward Log Loss: 123.93911743164062 Discount Log Loss 68.76273345947266 KL Loss 0.0002383828687015921
Image Log Loss: 4096.54833984375 Reward Log Loss: 123.92155456542969 Discount Log Loss 68.9259262084961 KL Loss 0.00023135272203

KeyboardInterrupt: 

# Training Loop

In [None]:
iterator = iter(dataset)
print("Iterator trajectories:")
trajectories = []
for _ in range(3):
  t, _ = next(iterator)
  trajectories.append(t)


    train_step((
        world_model.encoder,
        world_model.decoder,
        world_model.reward_model,
        world_model.discount_model,
        rssm.state_action_embedder,
        rssm.rnn,
        rssm.prior_model,
        rssm.posterior_model)
    )

#print(tf.nest.map_structure(lambda t: t.shape, trajectories))

# World ModelTraining Loop

In [None]:
g = tf.keras.layers.GRUCell(32)
r = tf.random.uniform((64,32), 0,3)
g(r,r)

# Actor Critic

# World model & agent training loops

# Hyperparam inits
Agent Data collection in environment + adding data to ERB (+ measure at which reward loop stops?) 
World model loop on data sampled from ERB
Agent training loop with world model feedback
 

# Function execution

In [None]:
# Instantiate environment and network objects
# Loop:
# Pass respective inputs to networks
# Collect outputs
# Compute individuall losses
# Add together to 1 big loss
# Propagate with gradient Tape through network


# compute the loss of an input for the model and optimize/tweak according the parameters
def train_step(model, input, target, loss_function, optimizer):
    # use tf.gradientTape to compute loss, then gradients and apply these to the model to modify the parameters
    with tf.GradientTape() as tape:
        prediction = model(input)
        loss = loss_function(target, prediction)
        gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


# TODO move hyperparams to the rest
epochs = 32

# define loss-function and optimizer
cross_entropy_loss = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

for epoch in range(epochs): 


    for world_model_input in tqdm(data):
        train_loss = train_step()
