Adding highest performant solution to archive #2

Closed
vpariza opened this issue Mar 29, 2022 · 4 comments

Comments

@vpariza

vpariza commented Mar 29, 2022

Hello, I have been working with QDax and I think I have found an issue with which solutions are stored in the archive based on their evaluation scores. That is, it seems that the highest-performing solution is not always the one stored in the archive.

leaf = jax.tree_leaves(repertoire.archive)[i].at[bd_insertion].set(weight)

More precisely, please run the following code and check which solution ends up in the archive produced by this line:

new_archive = jax.tree_unflatten(jax.tree_structure(repertoire.archive), leaves)

The code:

import jax
import jax.numpy as jnp
from qdax.qd_utils import grid_archive

key = jax.random.PRNGKey(0)
params_size = 5
batch_size = 3
grid_shape = (30, 30)
min_bd = 0
max_bd = 1
repertoire = grid_archive.Repertoire.create(
    jax.random.normal(key, shape=(params_size,)),
    min=min_bd,
    max=max_bd,
    grid_shape=grid_shape,
)

# Three candidate solutions, one row of parameters each.
params = jnp.array([[-0.4518846,  -2.0728214,   0.02437184,  0.56900173, -2.0105903 ],
 [ 0.31103376, -0.29348192, -0.27793083, -1.2343968,   1.6130152 ],
 [ 1.997377,   -0.9525061,  -0.57822144,  0.8413021,  -2.02012   ]])

# Behavioural descriptors of the three candidates.
bds = jnp.array([[-0.16706778, -0.5440059 ],
 [-0.47653008, -1.6869655 ],
 [-0.9096347,  -0.07636569]])
# Evaluation scores: the second candidate has the highest score.
objs = jnp.array([2, 10, 2])

dead = jnp.zeros(batch_size)

repertoire = repertoire.add_to_archive(
    repertoire=repertoire,
    pop_p=params,
    bds=bds,
    eval_scores=objs,
    dead=dead,
)

The result I get:

bds= [[-0.16706778 -0.5440059 ]
 [-0.47653008 -1.6869655 ]
 [-0.9096347  -0.07636569]]

pop_p= [[-0.4518846  -2.0728214   0.02437184  0.56900173 -2.0105903 ]
 [ 0.31103376 -0.29348192 -0.27793083 -1.2343968   1.6130152 ]
 [ 1.997377   -0.9525061  -0.57822144  0.8413021  -2.02012   ]]

eval_scores= [ 2 10  2]

bd_insertion= [0 0 0]

repertoire.archive= [[ 1.997377   -0.9525061  -0.57822144  0.8413021  -2.02012   ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 .
 .
 .
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.        ]]

You can change the behavioural descriptors to any other values that map to the same position and the issue still occurs. That is, the problem does not seem to be limited to position 0 of the archive array.

@maxiallard
Collaborator

maxiallard commented Mar 29, 2022

@valentinosPariza Thanks a lot for opening the issue!

I ran a few tests with your code and there is indeed a bug that produces the behaviour you described, but only if no other individuals have been added to the cell previously (i.e. the stored fitness is NaN). If the cell was filled previously, the behaviour is as expected (you can test this by adding all the individuals twice for bd 0).

I am not sure why you think lines 75 or 78 contain a bug, since it's more related to the indexing.
I fixed this issue so that the best individual is correctly added when no other solution has been found yet, and I will push the fix soon. Thanks a lot for pointing us towards that!

The fix you proposed is indeed what I did.

@vpariza
Author

vpariza commented Mar 29, 2022

I am just including the solution I recommended earlier.
Replace this line:

bd_insertion = bd_indexes * mult_to_be_added

with the following code:

mask = jnp.where(maximum_fitness.at[bd_indexes].get()==eval_scores,True,False)
to_be_added = jnp.logical_and(better_fitness + current_fitness_nan, mask)
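
For readers following along, here is a minimal, self-contained sketch of what those two replacement lines compute. The toy values (`num_cells`, `current_fitness`, and building `maximum_fitness` with `jax.ops.segment_max`) are assumptions made for illustration and are not taken from the QDax source; only the variable names come from the snippet above.

```python
import jax
import jax.numpy as jnp

# Toy batch: three candidates that all land in cell 0 (assumed values).
eval_scores = jnp.array([2.0, 10.0, 2.0])
bd_indexes = jnp.array([0, 0, 0])
num_cells = 4  # hypothetical archive size

# Fitness currently stored per cell; NaN marks an empty cell (assumption).
current_fitness = jnp.full((num_cells,), jnp.nan)

# Best score among the batch candidates for each cell.
maximum_fitness = jax.ops.segment_max(eval_scores, bd_indexes, num_segments=num_cells)

# True only for the candidate(s) holding the best score of their cell.
mask = maximum_fitness[bd_indexes] == eval_scores

stored = current_fitness[bd_indexes]
current_fitness_nan = jnp.isnan(stored)
better_fitness = eval_scores > stored  # always False against NaN

# Logical OR of the two conditions (the snippet above writes it as `+`),
# restricted to the per-cell winners.
to_be_added = jnp.logical_and(
    jnp.logical_or(better_fitness, current_fitness_nan), mask
)
print(to_be_added)  # [False  True False]
```

Without `mask`, all three candidates would pass the "cell is empty" test at once, which is the situation described in this issue.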

Thanks for your response @maxiallard

At first I thought it was an issue with the indexing, but after further investigation I found exactly what you mentioned, and I was about to correct my description of the issue.

I think my point about indexing was still valid, since in the end the code allowed solutions with a non-maximum evaluation score to be mapped to the same index. That is, their index was not converted to the out-of-bounds index.

But, as you mentioned, the issue was not exactly in lines 75 and 78, and that was the misconception I had at the beginning.

@maxiallard
Collaborator

Excellent! Thanks a lot for the quick reply, @valentinosPariza. I pushed the fix; maybe you can check whether it works for you now.
I added additional lines above to account for a special case with the 0 cell and changed the bd_insertion line to:

mult_to_be_added = jnp.where(to_be_added, 0, 100000)
bd_insertion = bd_indexes + mult_to_be_added
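
As an illustration of the additive offset, the sketch below pushes every rejected candidate to an index far outside a toy archive so that the scatter update skips it. The archive shape, the `pop_p` values and the explicit `mode="drop"` argument are assumptions made for this example, not a copy of the pushed fix.

```python
import jax.numpy as jnp

# Toy archive with 4 cells and 2 parameters per solution (assumed shape).
archive = jnp.zeros((4, 2))
pop_p = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
bd_indexes = jnp.array([0, 0, 0])
to_be_added = jnp.array([False, True, False])

# Rejected candidates get a large offset, moving them out of bounds.
mult_to_be_added = jnp.where(to_be_added, 0, 100000)
bd_insertion = bd_indexes + mult_to_be_added  # [100000, 0, 100000]

# mode="drop" makes it explicit that out-of-bounds updates are discarded.
new_archive = archive.at[bd_insertion].set(pop_p, mode="drop")
print(new_archive[0])  # [2. 2.]
```

One arithmetic observation, which the thread does not confirm as the maintainers' reasoning: a multiplicative offset cannot move index 0 anywhere, because 0 times any factor is still 0, whereas an additive offset can.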

Regarding the indexing: sorting wouldn't have solved our issue, which is why we used the jax.ops.segment_max operation instead, but then we missed this bit. So thanks again for pointing it out!
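
For context, a small sketch of how a segment-wise maximum groups scores by cell. The scores, cell indices and `num_cells` below are made-up values for illustration; how QDax actually wires this operation into `add_to_archive` is not shown in this thread.

```python
import jax
import jax.numpy as jnp

# Five candidates spread over cells 0 and 2 of a 4-cell archive (assumed).
eval_scores = jnp.array([2.0, 10.0, 2.0, 7.0, 5.0])
bd_indexes = jnp.array([0, 0, 0, 2, 2])
num_cells = 4

# One maximum per cell, computed over all candidates assigned to that cell.
best_per_cell = jax.ops.segment_max(eval_scores, bd_indexes, num_segments=num_cells)

# Cells that received no candidate hold -inf for floating-point inputs,
# so they need to be handled separately by the caller.
empty_cells = jnp.isneginf(best_per_cell)
print(best_per_cell[0], best_per_cell[2])  # 10.0 7.0
```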

@vpariza
Copy link
Author

vpariza commented Mar 29, 2022

Great, thanks a lot for the update and for the fixes.
I will try them out.

Yes, you are right. Indexing does not behave the same in jax.numpy as in NumPy.
In the beginning, I thought that simultaneous assignment to the same indices in jax.numpy works as in NumPy (applying only the last update rather than all of them), and thus that sorting would have solved the issue. But that's not the case.
I can only say that I am still learning the differences between NumPy and jax.numpy.
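
A short sketch of that difference, using an in-place add because its result is deterministic in both libraries. The array values are made up for the example.

```python
import numpy as np
import jax.numpy as jnp

idx = np.array([0, 0, 0])  # the same position, three times

# NumPy: buffered fancy-index update, the repeated index is applied once.
x_np = np.zeros(3)
x_np[idx] += 1.0
print(x_np)  # [1. 0. 0.]

# JAX: every update to the repeated index is applied.
x_jax = jnp.zeros(3).at[idx].add(1.0)
print(x_jax)  # [3. 0. 0.]
```

For `.at[...].set(...)` with repeated indices, the JAX documentation describes the order in which conflicting updates are applied as implementation-defined, so sorting the candidates beforehand does not guarantee which one ends up stored.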

Thanks again.
