
Merge pull request #181 from etiennecollin/patch-1
Fixed typos in intro_gfn_continuous_line_simple.ipynb
josephdviviano committed May 28, 2024
2 parents 4387e5b + e024fe1 commit 0764313
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions tutorials/notebooks/intro_gfn_continuous_line_simple.ipynb
@@ -28,7 +28,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this tutorial, we will explore a simple use-case of continuous GFlowNets: sampling from a multinomial Gaussian. This is an exceedingly simple example which is not representative of the complexities inherent with applying this method in real applications, but will highlight some common challenges and tricks useful. But first, please run the cell below to make available some helper functions:"
"In this tutorial, we will explore a simple use case of continuous GFlowNets: sampling from a multinomial Gaussian. This is an exceedingly simple example which is not representative of the complexities inherent with applying this method in real applications, but will highlight some common challenges and tricks useful. But first, please run the cell below to make available some helper functions:"
]
},
{
@@ -143,24 +143,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we are explore Continuous GFlowNets in an exceedingly simple case: from an initial starting point on the number line, sample a set of increments such that we learn to sample from some reward distribution. Here, that reward distribution will be some mixture of Gaussians. Each step will be sampled from Gaussians as well distribution.\n",
"Here, we are exploring Continuous GFlowNets in an exceedingly simple case: from an initial starting point on the number line, sample a set of increments such that we learn to sample from some reward distribution. Here, that reward distribution will be some mixture of Gaussians. Each step will be sampled from Gaussians as well distribution.\n",
"\n",
"The key difference with Continuous GFlowNets is that they sample some *delta* in a continuous space, instead of discrete actions. Typically, this means your GFlowNet uses a function approximator $f(\\cdot)$, which accepts the current state $s_{t}$, to predict the *paramaters of a distribution* $\\rho = \\{p_1, p_2, ..., p_n\\}$. Then your chosen distribution $D(\\rho)$ is used to sample a real-valued tensor $s_{\\Delta} \\sim D(\\rho)$ which is added to your current state to produce a the next step in the state space $s_{t+1} = s_{t} + s_{\\Delta}$ (note, we no longer consider a DAG here, but rather a topological space with distinguished source and sink states).\n",
"The key difference with Continuous GFlowNets is that they sample some *delta* in a continuous space, instead of discrete actions. Typically, this means your GFlowNet uses a function approximator $f(\\cdot)$, which accepts the current state $s_{t}$, to predict the *paramaters of a distribution* $\\rho = \\{p_1, p_2, ..., p_n\\}$. Then your chosen distribution $D(\\rho)$ is used to sample a real-valued tensor $s_{\\Delta} \\sim D(\\rho)$ which is added to your current state to produce the next step in the state space $s_{t+1} = s_{t} + s_{\\Delta}$ (note, we no longer consider a DAG here, but rather a topological space with distinguished source and sink states).\n",
"\n",
"As an aside, note that both $s_{\\delta} and the distribution $D(\\rho)$ can be as complex as you want, but this adds a lot of complexity and room for bugs. So to get you started, we're going to work with $s_{\\delta}$ being a single scalar, and $D(\\rho)$ being a simple Gaussian distribution. At the end, we will point to resources covering more complex settings which involve sampling from mixtures of distributions.\n",
"As an aside, note that both $s_{\\delta}$ and the distribution $D(\\rho)$ can be as complex as you want, but this adds a lot of complexity and room for bugs. So to get you started, we're going to work with $s_{\\delta}$ being a single scalar, and $D(\\rho)$ being a simple Gaussian distribution. At the end, we will point to resources covering more complex settings which involve sampling from mixtures of distributions.\n",
"\n",
"In our case, we want to increment along the number line in such a way that we learn to sample from some arbitrary multi-modal distribution. So we need a distribution from which to sample these steps. Recall the formula of a Gaussian:\n",
"In our case, we want to increment along the number line in such a way that we learn to sample from some arbitrary multimodal distribution. So we need a distribution from which to sample these steps. Recall the formula of a Gaussian:\n",
"\n",
"$$g(x) = \\frac{1}{\\sigma\\sqrt{2\\pi}} exp \\big(-\\frac{1}{2} \\frac{(x-\\mu)^2}{\\sigma^2} \\big)$$\n",
".\n",
"\n",
"To parameterize this, we will need a neural network to predict the parameters of the Gaussian: $\\mu$, the mean, and $\\sigma$, the standard deviation. We're also going to enforce that $ 0.1 <= \\sigma <= 2$ to help with convergence (see the hyperparameters above).\n",
"\n",
"In our setup, we will define a multimodal Gaussian distribution on the 1D line. We will also define a an arbitrary starting point $S_0$ on the number line where all trajectories will start. The GFlowNet must sample increments along the number line such that it samples final values along the number line proportionally to the mixture distribution.\n",
"In our setup, we will define a multimodal Gaussian distribution on the 1D line. We will also define an arbitrary starting point $S_0$ on the number line where all trajectories will start. The GFlowNet must sample increments along the number line such that it samples final values along the number line proportionally to the mixture distribution.\n",
"\n",
"We need to ensure there are no cycles in our state space to follow the theory of GFlowNets, but in this set up, a cycle would be easy to obtain. If we sampled first an increment of $+1$ and then an increment of $-1$, we could produce a cycle, and there are an infinite number of these on the real number line. To do so, let's simply include the count value, $t$, in the state $s_t$. In this setup, the state vector is `[x_position, n_steps]`, and the previous trajectory $[0, 0] \\rightarrow [1, 1] \\rightarrow [0, 2]$ would not be considerd a cycle. This step counter also can be used to know when to terminate this process, otherwise we never sample a final value. In this case, let's always terminate when $t=5$ (see hyperparameters above). There are more sophisitcated ways to handle termination, but they add complexity, and we want to focus this tutorial on only the core concepts.\n",
"We need to ensure there are no cycles in our state space to follow the theory of GFlowNets, but in this setup, a cycle would be easy to obtain. If we sampled first an increment of $+1$ and then an increment of $-1$, we could produce a cycle, and there are an infinite number of these on the real number line. To do so, let's simply include the count value, $t$, in the state $s_t$. In this setup, the state vector is `[x_position, n_steps]`, and the previous trajectory $[0, 0] \\rightarrow [1, 1] \\rightarrow [0, 2]$ would not be considered a cycle. This step counter can also be used to know when to terminate this process, otherwise we never sample a final value. In this case, let's always terminate when $t=5$ (see hyperparameters above). There are more sophisticated ways to handle termination, but they add complexity, and we want to focus this tutorial on only the core concepts.\n",
"\n",
"Since every state reachable by the backward policy must also be reachable by the forward policy, we also need to enforce that the the final transition of the backward policy goes exactly to $S_0$. We'll cover how this happens later."
"Since every state reachable by the backward policy must also be reachable by the forward policy, we also need to enforce that the final transition of the backward policy goes exactly to $S_0$. We'll cover how this happens later."
]
},
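As a concrete illustration of the cell above, here is a minimal sketch (assuming PyTorch, and not the notebook's actual code) of a mixture-of-Gaussians reward on the number line and a policy network $f(s_t)$ that maps the state `(x_position, n_steps)` to the parameters of a Gaussian over the increment, with $\sigma$ squashed into $[0.1, 2]$. The mixture weights, means, and layer sizes are assumptions; only the name `get_policy_dist` and the hyperparameter names `min/max_policy_std` and `hid_dim` come from the tutorial text.

```python
# Hedged sketch of the setup described above -- illustrative, not the notebook's code.
import torch
import torch.nn as nn
from torch.distributions import Categorical, MixtureSameFamily, Normal

min_policy_std, max_policy_std = 0.1, 2.0   # enforced range for sigma
hid_dim = 64                                # assumed hidden width

# Reward: a mixture of Gaussians along the 1D number line (weights/means assumed).
reward_dist = MixtureSameFamily(
    Categorical(probs=torch.tensor([0.5, 0.5])),
    Normal(torch.tensor([-2.0, 2.0]), torch.tensor([0.5, 0.5])),
)

def reward(x: torch.Tensor) -> torch.Tensor:
    """Reward R(x); here simply the mixture density."""
    return reward_dist.log_prob(x).exp()

# Policy network f(s_t): state (x_position, n_steps) -> parameters of D(rho).
policy_net = nn.Sequential(
    nn.Linear(2, hid_dim), nn.ELU(),
    nn.Linear(hid_dim, hid_dim), nn.ELU(),
    nn.Linear(hid_dim, 2),                  # outputs [mu, raw_sigma]
)

def get_policy_dist(states: torch.Tensor) -> Normal:
    """Return N(mu, sigma^2) over the increment for a batch of states of shape (batch, 2)."""
    mu, raw_sigma = policy_net(states).chunk(2, dim=-1)
    # Squash sigma into [min_policy_std, max_policy_std] to help convergence.
    sigma = min_policy_std + (max_policy_std - min_policy_std) * torch.sigmoid(raw_sigma)
    return Normal(mu.squeeze(-1), sigma.squeeze(-1))
```

Sigmoid squashing (rather than hard clamping) keeps gradients flowing when the raw output falls outside the allowed range; the notebook may enforce the bound differently.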
{
@@ -259,13 +259,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To keep things simple, we'll enforce that all trajectories are exactly 5 steps. With probabilisitc exit actions, the logic becomes more tricky, though it is often useful in some applications. \n",
"To keep things simple, we'll enforce that all trajectories are exactly 5 steps. With probabilistic exit actions, the logic becomes more tricky, though it is often useful in some applications. \n",
"\n",
"For each forward action, we will add the action value to the current state, and increment the step counter. A backward action is simply the inverse: we will substract the action value from the current state, and decrement the step counter.\n",
"For each forward action, we will add the action value to the current state, and increment the step counter. A backward action is simply the inverse: we will subtract the action value from the current state, and decrement the step counter.\n",
"\n",
"Given this distribution we retrieve from `get_policy_dist()`, we sample an action $s_{\\Delta} \\sim D(\\rho)$. Recall that our state representation is `(x_position, n_steps)`. In this case, we are sampling $x_{\\Delta} \\sim \\mathcal{N}(\\mu, \\sigma^2)$, and our next state is `(x_position + x_delta, n_steps + 1)`.\n",
"\n",
"We'll also define a function that initalizes a state at $S_0$, which in our case has the `x_position` set to whatever we defined in our environment, and `n_steps` to 0."
"We'll also define a function that initializes a state at $S_0$, which in our case has the `x_position` set to whatever we defined in our environment, and `n_steps` to 0."
]
},
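To make the transition logic concrete, here is a hedged sketch of the forward/backward steps and the state initializer under the `(x_position, n_steps)` representation. The helper names (`init_state`, `step_forward`, `step_backward`) and the value of $S_0$ are assumptions; `get_policy_dist` is the function referenced in the cell above (sketched earlier).

```python
# Hedged sketch of the transition logic described above; names are illustrative.
import torch

S0 = 0.0   # assumed starting x position on the number line

def init_state(batch_size: int) -> torch.Tensor:
    """All trajectories start at (x_position=S0, n_steps=0)."""
    states = torch.zeros(batch_size, 2)
    states[:, 0] = S0
    return states

def step_forward(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Add the sampled increment to x_position and increment n_steps."""
    return torch.stack([states[:, 0] + actions, states[:, 1] + 1], dim=1)

def step_backward(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Inverse move: subtract the increment and decrement n_steps."""
    return torch.stack([states[:, 0] - actions, states[:, 1] - 1], dim=1)

# Example of one forward step, using the policy distribution sketched earlier:
states = init_state(4)
x_delta = get_policy_dist(states).sample()   # x_delta ~ N(mu, sigma^2)
states = step_forward(states, x_delta)       # (x_position + x_delta, n_steps + 1)
```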
{
@@ -542,7 +542,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"What we're seeing here is mode collapse due to on policy training. We can fix this with off policy exploration.\n",
"What we're seeing here is a mode collapse due to on policy training. We can fix this with off policy exploration.\n",
"\n",
"We can go off policy in many ways, but one simple way would be to add some constant to the variance predicted by our forward policy for the normal distribution. We can also decay this constant linearly over training iterations too facilitate convergence.\n",
"\n",
Expand Down Expand Up @@ -575,9 +575,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In the below training loop, let's add the changes needed to allow for off policy exploration. To do so, we need to accomplish a few things:\n",
"In the training loop below, let's add the changes needed to allow for off policy exploration. To do so, we need to accomplish a few things:\n",
"\n",
"1) Define a value to increase the variance by, to encourage exploration. Ideally, this would be on a schedule, i.e,, the value we are adding to the variance of the predicted distribution will decrease over iterations. Let's use the `init_explortation_noise` variable for this.\n",
"1) Define a value to increase the variance by, to encourage exploration. Ideally, this would be on a schedule, i.e., the value we are adding to the variance of the predicted distribution will decrease over iterations. Let's use the `init_explortation_noise` variable for this.\n",
"2) Sample actions from the exploration distribution.\n",
"3) Calculate `logPF` using the log probabilities from the policy distribution."
]
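Putting the three changes together, here is a hedged sketch of one possible Trajectory Balance training loop. It reuses the helpers sketched above (`init_state`, `step_forward`, `get_policy_dist`, `get_exploration_dist`, `reward`, `n_iterations`), assumes a learnable `logZ` with its own learning rate, and, purely for illustration, uses a fixed standard-normal backward policy for all but the final backward step to $S_0$ (which is deterministic). The notebook's actual loop and backward policy may differ.

```python
# Hedged sketch of a Trajectory Balance step with off-policy exploration; illustrative only.
import torch

trajectory_length = 5
batch_size = 64
logZ = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([
    {"params": policy_net.parameters(), "lr": 1e-3},
    {"params": [logZ], "lr": 1e-1},   # the logZ estimate gets its own learning rate
])

for iteration in range(n_iterations):
    states = init_state(batch_size)
    logPF = torch.zeros(batch_size)
    logPB = torch.zeros(batch_size)

    for t in range(trajectory_length):
        policy_dist = get_policy_dist(states)
        behaviour_dist = get_exploration_dist(policy_dist, iteration)  # (1) decaying noise schedule
        actions = behaviour_dist.sample()                              # (2) sample off policy
        logPF = logPF + policy_dist.log_prob(actions)                  # (3) logPF under the policy distribution
        if t > 0:
            # Assumed fixed N(0, 1) backward policy; the step back to S_0 (t == 0)
            # is deterministic, so it contributes nothing learnable here.
            logPB = logPB + torch.distributions.Normal(0.0, 1.0).log_prob(actions)
        states = step_forward(states, actions)

    log_reward = torch.log(reward(states[:, 0]) + 1e-9)
    loss = ((logZ + logPF - logPB - log_reward) ** 2).mean()           # Trajectory Balance

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```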
@@ -762,7 +762,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we have multiple challenges. Our starting point $S_0$ is now closer to Mode 1 than Modes 2 & 3, but the those combined modes have twice the probability mass that Mode 1 has. Furthermore, there is Mode 4, quite far from the initial starting point. To get started, let's modify some of out hyperparamaters to enable better sampling and exploration of this environment, and train for 10k iterations. We'll do this by allowing the policy to sample from Gaussian distributions with larger $\\sigma$ values, and increasing the `init_exploration_noise` value."
"Here, we have multiple challenges. Our starting point $S_0$ is now closer to Mode 1 than Modes 2 & 3, but those combined modes have twice the probability mass that Mode 1 has. Furthermore, there is Mode 4, quite far from the initial starting point. To get started, let's modify some of our hyperparamaters to enable better sampling and exploration of this environment, and train for 10k iterations. We'll do this by allowing the policy to sample from Gaussian distributions with larger $\\sigma$ values, and increasing the `init_exploration_noise` value."
]
},
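For example, a hedged sketch of the kind of hyperparameter changes being described; the keys broadly follow the hyperparameters listed later in the notebook, but the logZ learning-rate key and every value below are assumptions, not the notebook's settings.

```python
# Illustrative settings only -- values are assumptions to be tuned, not the notebook's.
hparams = {
    "trajectory_length": 5,
    "n_iterations": 10_000,           # train longer on the harder reward
    "hid_dim": 64,
    "learning_rate": 1e-3,            # model parameters
    "learning_rate_logZ": 1e-1,       # logZ estimate, tuned separately
    "min_policy_std": 0.1,
    "max_policy_std": 4.0,            # allow larger sigma to reach distant modes
    "init_exploration_noise": 2.0,    # more off-policy exploration early on
}
```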
{
@@ -823,13 +823,13 @@
"source": [
"Things aren't looking good yet. The model has learned to sample positive x positions due to majority of the reward mass being located in this region, but we are nowhere close to sampling from the correct probability distribution.\n",
"\n",
"We likely don't have time to train this during the tutorial, but by playing with the above hyperparameters long enough, we can eventually learn to sample from this reward distribution. Try playing with this notebook at home to build an intution as to how the different hyperparameters affect the results:\n",
"We likely don't have time to train this during the tutorial, but by playing with the above hyperparameters long enough, we can eventually learn to sample from this reward distribution. Try playing with this notebook at home to build an intuition as to how the different hyperparameters affect the results:\n",
"\n",
"+ `trajectory_length`\n",
"+ `init_exploration_noise`\n",
"+ `min/max_policy_std`\n",
"+ `n_iterations`\n",
"+ `learning_rate` (for the model and logZ estimate seperately!)\n",
"+ `learning_rate` (for the model and logZ estimate separately!)\n",
"+ `hid_dim`\n",
"\n",
"A key takeaway here is that the complexity of tuning the hyperparameters for training a continuous GFlowNet quickly grows with the complexity of the environment... even in a very simple case such as this one. "
