Investigate @lightvector neural network enhancements #25
All these enhancements are for the policy network, which is generally doing pretty well even without these tweaks. I would be more interested in seeing if this improves the value head, in which you can observe the following issue consistently:
Following the logic posted in lightvector's repository one could reasonably expect Parametric ReLUs and Global Pooling Properties to have little effect on this problem. However Chain Pooling might greatly help with this issue. To help with the performance impact, one could limit the chain pooling to the final convolutional layer in the value head with input
This is only one layer of chain pooling so we expect it to have issues identifying problems with loosely connected groups or certain types of false eyes, but it would probably help in most situations. To implement this we would have to write a custom CUDA kernel, which could probably be done relatively easily using some flood filling technique:

```cuda
// Launched as a single block with one thread per board vertex. `chain` assigns
// each vertex the identifier of the chain it belongs to (per batch example),
// and `N`, `E`, `S`, `W` map a vertex to its four neighbours (edge vertices map
// to themselves). The maximum of each channel is flood filled over every chain.
__global__ void chain_pool(
    float *data,        // [batch_size, num_channels, num_vertices]
    const int *chain,   // [batch_size, num_vertices]
    const int *N, const int *E, const int *S, const int *W,
    int *is_dirty,
    const int batch_size,
    const int num_channels,
    const int num_vertices
)
{
    const int v = threadIdx.x;

    do {
        if (v == 0) {
            *is_dirty = 0;
        }
        __syncthreads(); // barrier

        for (int i = 0; i < batch_size; ++i) {
            const int *ch = chain + i * num_vertices;

            for (int c = 0; c < num_channels; ++c) {
                float *plane = data + (i * num_channels + c) * num_vertices;
                const float original = plane[v];

                if (ch[v] == ch[N[v]]) plane[v] = fmaxf(plane[v], plane[N[v]]);
                if (ch[v] == ch[E[v]]) plane[v] = fmaxf(plane[v], plane[E[v]]);
                if (ch[v] == ch[S[v]]) plane[v] = fmaxf(plane[v], plane[S[v]]);
                if (ch[v] == ch[W[v]]) plane[v] = fmaxf(plane[v], plane[W[v]]);

                if (original < plane[v])
                    *is_dirty = 1;
            }
        }

        __syncthreads(); // barrier
    } while (*is_dirty > 0);
}
```
|
Lightvector updated his blog with some more results. Of special interest is the fact that adding dilation has a similar effect as adding chain pooling [1], but since you only apply dilation over some of the channels, local shape information is not lost. This is very promising since dilation is built into the convolutional operator in most frameworks, including cuDNN. The problem is that, according to the cuDNN documentation, only the For some common filter sizes this gives (rounded down to the closest multiple of 8 for SIMD purposes). We want some balance between local and global shape anyway, so this might provide a good mixture between the two:
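As a concrete illustration of how cheap this is to express (a minimal sketch assuming PyTorch; Dream Go itself talks to cuDNN directly, and the channel split below is purely illustrative):

```python
import torch

# Split the output channels between a normal and a dilated 3x3 convolution and
# concatenate the results; with dilation 2 and padding 2 the board stays 19x19.
x = torch.randn(1, 128, 19, 19)                    # NCHW feature maps
conv_d1 = torch.nn.Conv2d(128, 96, 3, padding=1, dilation=1)
conv_d2 = torch.nn.Conv2d(128, 32, 3, padding=2, dilation=2)
y = torch.cat([conv_d1(x), conv_d2(x)], dim=1)     # back to 128 channels
print(y.shape)                                     # torch.Size([1, 128, 19, 19])
```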
Lightvector's observations about history [2] are also very interesting, not because of the reason he mentions, but because his suggestion to zero out the history channels randomly can act as a training data augmentation, which should help with potential overfitting. I am particularly interested in this because I've observed the same behaviour he cites, where the neural network learns sequences of moves instead of judging each individual board position separately. It is understandable why it does so (humans do this too), but it is not a desirable property and I've been considering getting rid of the history features completely to avoid this. Unfortunately the history features are very important, so this might provide a reasonable in-between. One could even take this a step further: if one were to use one-hot encodings of the history planes, then one could shuffle the history planes (with care, to avoid illegal board positions) in order to provide a sort of tewari-like effect. [1] https://github.com/lightvector/GoNN#dilated-convolutions-mar-2018 |
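A minimal sketch of the random history dropout described above (assuming numpy and an illustrative feature layout, not necessarily Dream Go's actual one):

```python
import numpy as np

def maybe_drop_history(features, history_planes, p=0.1, rng=np.random.default_rng()):
    """With probability `p`, zero out the history feature planes of one training
    example. `features` has shape [channels, 19, 19] and `history_planes` lists
    the channel indices that encode the previous moves."""
    if rng.random() < p:
        features = features.copy()
        features[history_planes] = 0.0
    return features
```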
With a naive cuDNN implementation the performance hit is quite significant; even when computing the normal and dilated convolution in parallel, the neural network with dilation is about 36% slower. However, even when using a network that was only trained for a few hours it passes one of the dead dragon tests (and in the other two the neural network is less certain about who is winning than it has historically been). No other network has managed this to date, so even without any further performance improvements I might be able to squeeze out, dilation is probably worth it:
For reference, this network trained for 5,901 steps (using a batch size of 512) and achieved a 40.27% policy accuracy and a 59.96% value accuracy after about 1 hour and 17 minutes of training. Based on previous experience these numbers should improve significantly with more training. Output from
|
Out of curiosity, is it faster if instead of concatenating you add after the next convolution? Using the following identity or similar: It's probably worse to do it this way because the next convolution ends up split up so you lose benefits of greater 'batching', but if concat is particularly expensive for some reason then there's an off-chance it's better. |
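For reference, a quick numerical check of the identity being referred to (a sketch assuming PyTorch; shapes are arbitrary): convolving the channel-wise concatenation of two tensors with one weight tensor equals convolving each tensor with the matching slice of the weights and summing the results.

```python
import torch
import torch.nn.functional as F

a = torch.randn(1, 96, 19, 19)
b = torch.randn(1, 32, 19, 19)
w = torch.randn(128, 128, 3, 3)             # [out, in, kH, kW] with in = 96 + 32

lhs = F.conv2d(torch.cat([a, b], dim=1), w, padding=1)
rhs = F.conv2d(a, w[:, :96], padding=1) + F.conv2d(b, w[:, 96:], padding=1)
print(torch.allclose(lhs, rhs, atol=1e-3))  # True (up to floating point noise)
```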
I will try it, since my current implementation does not really have good memory access patterns due to the concatenation forcing me to temporarily re-write them as CNHW. Your re-formulation would allow us to keep NCHW the entire way, at the expense of more kernel launches. The sum of convolutions is also really fast to calculate since cuDNN fuses that into the convolution kernel (by allowing you to blend the input and output arrays). I think I also screwed up my SIMD multiplier, since if one looks at the runtime of a Winograd kernel over different numbers of output channels you can see some clear bumps on the graph where the number of output channels is a multiple of 32: If you are curious, my dilation implementation at the moment is pretty much the following, notice how both
For the sake of transparency these are the benchmark numbers before I added dilation:
These are the current benchmark numbers using the algorithm described in the previous section:
|
I finished my mock-up implementation of the two ideas mentioned above, with mixed success. Changing the channel count was slightly better, while avoiding the concatenation does not seem to be worth it (probably due to the lack of a good fused kernel, and the reduced batching).

### Channels as a multiple of 32

This gave a performance improvement of 6%, so nothing groundbreaking but a solid improvement:
For the sake of completion this is a trace of the CUDA calls performed during a single residual block (for a batch size of 256). You can clearly see that the issue is that the convolution and dilation both take the same amount of time, so adding dilation effectively increased the amount of work during each residual block from 2 to 3 convolutions. This matches up with the observations elsewhere, as it would predict a 33% performance loss:
### Alternative formulation of concatenation

This turned out to be a bit problematic to implement, as we needed to not apply a rectified linear unit to the final result of the addition of the two convolutions. This probably does not sound too hard, but we are using the fused operator There is also, as observed by lightvector, less batching with this approach, which is typically bad for performance. Interestingly enough this approach has a systematic advantage for a batch size of one:
The profiling output for this approach suggests the bottleneck is the two non-fused convolutional kernels (note the lack of a
|
Adding a single dilated convolution had some good effects on the global perspective of Dream Go. After training for about 2 days on human games it recognized only two of our test cases as valid. Our previous, non-dilated, version recognized none of the test cases, so this is still an improvement:
### More dilation?

Since a single dilated convolution did not give a large enough effect I figured I could try adding two dilated convolutions (with dilation 2 and 3) to increase the peripheral vision of each residual block even further. With this enhancement each residual block effectively sees a 7x7 area, allowing information to travel from one side of the board to the other in only 3 residual blocks (in theory). With this change each residual block gets this architecture:
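Roughly, and only as a sketch (PyTorch assumed; the 128/32/32 channel split and the placement of batch normalization are illustrative, not necessarily what Dream Go uses), such a block could look like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DilatedResidualBlock(nn.Module):
    """Residual block whose first convolution is split across dilation
    rates 1, 2 and 3, with the branches concatenated back together."""

    def __init__(self, channels=192):
        super().__init__()
        self.conv_d1 = nn.Conv2d(channels, channels - 64, 3, padding=1, dilation=1)
        self.conv_d2 = nn.Conv2d(channels, 32, 3, padding=2, dilation=2)
        self.conv_d3 = nn.Conv2d(channels, 32, 3, padding=3, dilation=3)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        # concatenate the three dilation branches back into `channels` feature maps
        y = torch.cat([self.conv_d1(x), self.conv_d2(x), self.conv_d3(x)], dim=1)
        y = F.relu(self.bn1(y))
        y = self.bn2(self.conv2(y))
        return F.relu(y + x)                        # residual connection

block = DilatedResidualBlock()
print(block(torch.randn(2, 192, 19, 19)).shape)     # torch.Size([2, 192, 19, 19])
```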
As you can observe I also increased the number of channels from 128 to 192, since we were afraid of the local shape information getting lost if we reduced the number of output channels to 64/32/32. This introduces additional variables to take into account when evaluating this change, but historically increasing the number of features has not helped much with the global perspective. This architecture does very well on our test cases; the neural network only fails one of the dead dragon tests. The test that it fails is a game that white should win by 7.5 points because a black dragon has one eye and one false eye; if the neural network misjudges the group as alive then black would win by 72.5 points:
As you can see the neural network judges the game as being pretty close, which suggests that it does not consider the dragon to be fully alive. But considering there is nothing else on the board that is undecided it is still a clear failure. At the time of writing this the neural network has run for 148.5k out of 245.7k steps, so it has not been fully trained and the results may therefore be subject to change. The performance of the neural network is as one would expect from the posts above, not great. It is 66% slower than the original neural network, which again corresponds closely to the expected slowdown of
However, if this is the price we have to pay for good predictions then that is an acceptable trade-off. But I still need to check that this is not an artificial increase in strength (and that trading rollout quantity for quality is actually worth it). |
I also trained a 128 channel version of the architecture described above with 32 channels in total devoted to dilations, so according to the previous diagram:
This is, as expected, in-between the 1-dilation network and the 2-3 dilation network in terms of performance and precision. Unfortunately it completely misjudges the two dead dragons that are marked as
Currently running a tournament between four different programs to determine which version of the program is the best one. The settings are fast (but not blitz) games, with Chinese scoring:
The following programs are part of the test; all of them were trained using the same hyper-parameters but a different random seed:
I will update the following section with the results, but the expected results would be the following ranking based on the assumption that the network sanity tests have some correlation to reality.
This trial was cancelled after 37 games per match-up (so a total of 367 games) had been played, since some match-ups could be eliminated because a winner had already been determined. The most notable of these is all matches against leela, which was performing very badly for some reason (pretty sure it should be stronger than this). The other candidate that could be eliminated is The remaining three candidates were put into another match-up that we can use to determine which of them were worth continuing with:
After these match-ups the following Elo could be estimated; there are not enough games to determine an accurate rank and I consider the top 3 to be essentially equal. It is unknown why
The same argument can be used to explain why The
These two reasons interact: since there are more large patterns than there are local patterns, it has to look at the global scope but has fewer channels to do so. This will result in it having to generalize global shape into local shape, which may not always work out. A few other observations to keep in mind:
So the larger the fraction of channels that are devoted to global thinking, the harder it will be for the network to be able to recognize local shape (because of the regularization factor mentioned above). |
The problem presented above has two issues:

- *Global properties* that have been blended into *local properties* in previous residual blocks.
- The second convolutional layer in each residual block having to consider both dilated and non-dilated features as equals due to the regularization.
The second problem is easy to solve: we could just decrease the regularization coefficient, or drop the second residual blocks from the regularization completely. We could also do some gated architecture as below, using, for example, the batch normalization scale parameter as G₂ and G₃:

```
x
├───┬───╮
D₁  D₂  D₃
│   │   │
│   G₂  G₃
├───┴───╯
C
│
y
```
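A sketch of that gated combination (PyTorch assumed; channel counts illustrative), where the dilated branches are scaled by learnable per-channel gates before the combining convolution C:

```python
import torch
import torch.nn as nn

class GatedDilationCombine(nn.Module):
    """The dilation-2 and dilation-3 branches are scaled by learnable
    per-channel gates G2 and G3 (playing the role of the batch-norm scale
    parameter) before everything is concatenated and passed to C."""

    def __init__(self, channels=128, d2_channels=32, d3_channels=32):
        super().__init__()
        self.d1 = nn.Conv2d(channels, channels - d2_channels - d3_channels, 3, padding=1)
        self.d2 = nn.Conv2d(channels, d2_channels, 3, padding=2, dilation=2)
        self.d3 = nn.Conv2d(channels, d3_channels, 3, padding=3, dilation=3)
        self.g2 = nn.Parameter(torch.ones(1, d2_channels, 1, 1))
        self.g3 = nn.Parameter(torch.ones(1, d3_channels, 1, 1))
        self.c = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        y = torch.cat([self.d1(x), self.g2 * self.d2(x), self.g3 * self.d3(x)], dim=1)
        return self.c(y)
```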
It is not obvious if we have to solve the first problem, or if solving the second is enough for the optimizer to reserve some channels for local properties on its own. The only solution to the first problem we can think of would be to run separate towers for the different dilation levels and then combine them at the final layer, but this has several issues of its own. Some hybrid approaches where only some residual blocks use dilation might also work. |
What do you mean by "local" properties versus "global" properties? If either way a property of the Go position is computed accurately ("this stone belongs to a group that has only one eye within radius 6 of this location") it does not matter if the computation of that property involved convolutions with different dilation levels or not. Some properties may be easier or harder to compute using different mixes of different dilations of course, but I think there's no reason to try to avoid blending them, because there's no such thing in the first place as an intrinsically "dilation 1" feature or a "dilation 2" feature that can only usefully be used by further convolutions with the exact same dilation factor. I'm possibly misunderstanding something?

Also, I'm curious - what regularization are you referring to? Keep in mind that if you're using an L2 penalty on your weights but you're also using relus and batchnorm, then my understanding is that the L2 penalty does *not* have a significant regularization effect to begin with, so it has no relevance to whether any features are on equal footing with others or not. But if you're using a different regularization method things might be different.

It's cool to see these updates. I'd be interested to hear if you have results from your blitz games yet - it's possible the reduced performance is a bigger cost than the gain from better large-scale understanding, but if not, that would be really neat. :)
|
To answer your questions in a random order:

### 1. Blitz games

I can add some blitz games, they should be fast to play. In fact I would not be surprised if this post has some in it, since I started some just as I was typing this sentence and I'm planning on writing a fair bit more. These blitz games are what we internally refer to as policy play games, i.e. they play greedily according to what the neural network suggests with no search. So these results should indicate the quality of the neural network predictions:
The results mirror what my sanity tests suggest, that dilations result in higher quality predictions. The exception is
I claim the reason

### 2. Global and local features / properties

I've been using global properties and local properties somewhat fuzzily, intentionally, since I cannot claim to understand exactly what the neural network computes in the first place, and I am not a professional baduk player. A better word for what I mean might be near peripheral features and far peripheral features, where near peripheral features contain information about stones close to the centre of each convolution, and far peripheral features are about stones far from the centre of each convolution.

The networks with dilated convolutions seem to favour strategies that involve far peripheral features, so things like large scale captures and influence. The networks with dilation also perform worse during life & death problems, which involve mostly near peripheral features. This is all from me skimming some of the games in the archive I linked above, and it is possible I am wrong.

The reason for this behaviour, I think, is that the way we have been adding dilation forces the network to reserve some channels for far peripheral features, whereas before it had some choice in this†. My current concerns about dilation are related to the forcing part, which is explored further in section 3, since different features are easier to compute with input from different dilation levels. If we force the network to always consider the far peripheral features then it will have a harder time computing some features, and vice versa for near peripheral features. For example it may be hard to recognize an eye if you have no choice but to look at 5x5 patterns, since the stones marked with
This is of course a simplified example; the optimizer is not so stupid that it would store all combinations of the 16 question marks. But something similar to this is going on, since you can observe the style differences between the networks.

† Recent research calls into question the claim that it could choose to do so.

### 3. Regularization

I am using L2 regularization, batch normalization, and gradient clipping during the training. The last one is unnecessary at this point, but was useful while I still had a bug during initialization where it would sometimes try to factorize singular matrices (resulting in huge weights that would collapse the entire network to zero without clipping). You are correct in that batch normalization and L2 regularization are somewhat redundant, but my understanding says that L2 regularization achieves two things:
It is the second effect that I am worried about since it says
### 4. Alternative Interpretation

My main reason for worrying about this is the fact that the network with dilation seems to play worse than the network without dilation. The blitz game results suggest that this is mainly due to the lack of search because of time constraints, which would also give rise to the same problem where it fails to notice vital moves in non-trivial situations, as they require reading to spot.

### 5. Conclusion

I will train another network with the following configuration, without L2 regularization, to see if it has any significant effect. I do not believe we will have any problems with overfitting, and if the L2 regularization had no impact then the results should be the same as before:
This mirrors the

PS: My idea about separate towers for different dilations is probably stupid, and does not really make sense upon further inspection. |
### 1. (Blitz games)

Sorry for the confusion: I misread your earlier post, where you said you were running some fast (non-blitz) games. Did those pan out, or were the new neural nets worse once taking into account the shallower search due to the worse performance? For these games though - nice results!

### 2. (Near vs Far features)

I think I understand you. Yes, if the dilated channels are there, then the neural net will use them, and therefore devote some proportion of the non-dilated channels to computing features that are useful for the dilated channels, so long as that improves the overall loss function more than not doing so. That will obviously make it worse at doing whatever the excess non-dilated channels were doing before. I don't think this is a problem though, it simply is a tradeoff of network capacity.

To give a different example - imagine you originally did not provide any history planes as an input feature, and now you add some, but don't increase the number of channels in the rest of the neural net. Then obviously the neural net will get worse at some kinds of tactics, because it will be now devoting some of its channels to processing the new history information, instead of devoting them to whatever local shapes it was doing before. But doing so improves the overall prediction quality, because the new history information is strongly predictive of other things. I don't think there is anything special about dilated/nondilated, it is exactly like adding any other new information or representation capacity. It will cause a tradeoff to use the new capacity, but in a way that (unless you're experiencing major underfitting or overfitting) should be overall better in predictive ability. Of course, better predictive ability does not always mean more strength, because predictive ability and playing strength are two different things.

### 3. (Regularization)

I'm putting this in a separate reply because this gets pretty technical.

### 4. (Alternative Interpretation)

Yes, of course. As in #2 above, I think there is no significant pathology or problem with "mixing" this kind of information, if there is a loss in strength there is a good chance it is due to something like:
|
### 3. (Regularization)

You mentioned this:
But actually, to first order, L2 loss does NOT encourage all weights to be the same relative size as each other and does NOT have a significant effect on overfitting, in the presence of batch normalization, and if you are using gradient descent. I could be mistaken, but I'm reasonably sure about this. This is a surprising fact if you have not thought about it before! Why?
This means there is no regularization effect or avoidance of overfitting. For example, imagine we are at a local minimum in data loss where weight A and weight B serve extremely similar purposes but weight A is twice as large as B. Then after scaling due to the L2 loss step, weight A will still be twice as large as B, and there will still be no data loss gradient to change them, so even as both shrink, A will remain twice as large as B forever.

This is different than if there was no batchnorm. In that case, after scaling there would be a data loss gradient to re-increase both A and B since both are now too small. If A and B serve the same purpose, then A and B would experience the same gradient upward, but since B is only decreased by the L2 loss half as much yet is re-increased just as much, the net effect will over time be to make A = B, as expected.

So with batch norm, the only thing that is affected by L2 loss is the global scale of the weights in each layer. This still does have an effect on the gradient, but only on the scale. Multiplying all the weights in a layer by a constant factor C followed by a batchnorm causes all gradients in that layer to be multiplied by 1/C during the backward pass, which is an effective factor of 1/C^2 in the relative gradient (relative to the magnitude of the weight).

If you are using momentum, then the picture is changed a bit, but broadly I think the same analysis holds. If you are using an entirely different kind of optimizer, such as ADAM, then the above does not hold, but I think L2 still has a very strange effect that is very different than the regularization it has without batchnorm.

### Conclusion

L2 loss causes no regularization effect or avoidance of overfitting when using batch normalization because it has no effect on the predictions or on the relative directions of gradients. Instead, to first order it only affects the scale of weights, so it is approximately the same as training without any L2 but performing a slight increase in the learning rate on every iteration. |
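A quick numerical illustration of the scale argument above (a sketch assuming PyTorch; shapes and constants are arbitrary):

```python
import torch

def conv_grad_norm(scale):
    """Gradient magnitude on conv weights that feed a batch norm, after the
    weights have been multiplied by `scale`."""
    torch.manual_seed(0)                       # identical init and data each call
    x = torch.randn(16, 4, 9, 9)
    conv = torch.nn.Conv2d(4, 8, 3, bias=False)
    bn = torch.nn.BatchNorm2d(8)
    with torch.no_grad():
        conv.weight *= scale
    loss = bn(conv(x)).square().mean()         # any loss downstream of the batch norm
    loss.backward()
    return conv.weight.grad.norm().item()

# Doubling the weights leaves the loss unchanged (the batch norm undoes the scale)
# but roughly halves the gradient, so the relative update shrinks by ~4x.
print(conv_grad_norm(1.0), conv_grad_norm(2.0))
```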
### Blitz games

The results of the fast games are in this reply. I can understand the confusion since I tend to heavily edit my posts, as most of the time they end up just being an experimental log that no one (?) except me reads. In summary the results of the fast games were mixed: adding dilated convolutions produced some better and some worse (!) engines. But there was no significant jump in strength, I suspect mostly because of the performance issues, resulting in fewer rollouts for the networks using dilation.

### Near and Far features

I think we agree here, which features were added / removed / changed does not really matter, and some change in behaviour is to be expected. If this L2 regularization thingy is something (I'll get to your second post later), then it would be a problem for features outside of dilation too. My worry originally comes from the fact that dg-d-192-2-3 performs worse (even during blitz) than any other network using dilation despite being larger. But this could be because it is larger yet was trained for the same number of steps as the other networks, so it received *"less training per weight"* (I am not sure if this is a thing).

### Regularization

I am using [SGD with momentum](https://github.com/Chicoryn/dream-go/blob/master/contrib/trainer/dream_tf/__main__.py#L658) but I do not believe this is important for this discussion as the analysis should turn out the same. I think your analysis is correct, but the assumption that the weight_decay (L2 regularization) and gradient descent will be done in sequence does not hold [in practice](http://pytorch.org/docs/master/_modules/torch/optim/sgd.html#SGD), where they are done independently.

If batch normalization and weight decay are performed independently then I believe the L2 regularization still encourages the weights to be roughly the same magnitude. This is because the SGD update formula with weight_decay and a constant gradient norm puts a hard limit on the size of the weights, and it will be hard for the optimizer to maintain the effort it needs to counter the weight_decay considering how noisy the gradients are in SGD (though less so when using momentum):

next_weights = weights - weight_decay · weights - gradients

If we want next_weights to increase then -gradients must be larger than weight_decay · weights, so weights has an upper bound of -gradients / weight_decay:

next_weights > weights
⇒ weight_decay · weights < -gradients
⇒ weights < -gradients / weight_decay

I might look further into this tomorrow, but it is getting a bit late at the moment.

### Conclusion

Regardless of the reason, we both seem to agree that removing the L2 regularization I am using at the moment is a good idea, since:

- You claim it does nothing except waste GPU cycles during training.
- I claim it might hurt neural network performance.
|
Sounds good.
Regarding the sequential application of L2 and data loss gradient, I agree that doing them not in sequence is a small difference, but it is much smaller than the first-order effect that batch norm takes away (consider that except at the very start of training, a single gradient update usually changes each weight by a minuscule fraction of a percent of the root mean square of the weights, so the second order effects from sequence vs simultaneous are very small).
Keep in mind that all of the analysis I wrote above is about the relative magnitudes of the weights. It doesn't matter if the optimizer can counter the weight decay or not, because decayed weights behave the same as undecayed weights if the next layer is a batch norm, what matters is the relative magnitude.

The weight A and weight B example I gave is a good example to go back to. L2 penalty causes both A and B to shrink proportionally, with A continuing to be twice as large. The regularizing behavior would be to make A/B ~= 1, and if there is no batch norm, that is what you get, because the data provides pressure to increase them again, such that the local optimum has A/B ~= 1. But when there is batch norm, there is no such pressure, the decayed weights are just as good.

If A and B both shrink enough and then random walk due to noise enough so that randomly B could become larger sometimes, then of course you are free to call that "regularization", but it is exactly the same kind of "regularization" as if you removed L2 loss entirely and just turned up the learning rate enough that B could randomly sometimes overtake A. Sometimes the opposite would happen and it would become even smaller or negative compared to A. This is a very different kind of "regularization" than one that actually encourages and converges to precisely A/B ~= 1 as the unique local optimum.
Edit: Fixed some spacing. Also, looking forward to any further updates in the future, this thread has been great and following it has been very interesting! :)
|
@lightvector I think I followed what you wrote; my understanding of how BN works is very weak. Are there any takeaways from this that could apply to some of the other projects like LZGo, minigo, or LZChess? Minigo and LZChess have been having trouble recently. |
In the context of LZ, I think generally having some amount of L2, as is the case right now, is probably good. Not because of regularization, but to maintain the learning rate. Because for LZ, you train the net essentially forever; even new nets you often use net2net to bootstrap. If you don't have any L2 loss (i.e. weight decay), then over time the norm of the weights will drift larger (e.g. a high-dimensional Brownian motion will very reliably move away from the origin proportional to sqrt(time)). This means your effective learning rate will drop over time**, which is bad because, as you are always receiving new and better data, you want to maintain a fixed and high learning rate. But with some amount of L2, you will reach an equilibrium where the outward drift is balanced by the inward decay, and therefore maintain your effective learning rate.

For fixed data sets (e.g. one-shot training a policy net to convergence on a fixed set of pro games), I don't think there is any particular value in L2 with batchnorm, since you want to anneal your learning rate anyways. Except maybe it makes your learning rate easier to think about if you tune it so that the equilibrium is actually roughly equal to the scale of the weights that you initialize with, so that now the only factor that affects your effective learning rate is your literal learning rate, rather than also this subtle weight-growth phenomenon.

I'm not sure if this amounts to any particular takeaway for LZ other than what it's already doing. If you are curious, you can check if you reached equilibrium by simply printing out the norm of the weights every few million training steps and seeing if the weight norm is no longer changing much. LZ almost certainly has.

** If you're having trouble following, it's pretty simple. Consider Z = batchnorm(Y) where Y = W * X. Double the weights W. Then Y is doubled. So batchnorm is now dividing by an extra factor of 2 to undo the doubling. So dZ/dY is cut in half. Therefore dZ/dW is cut in half. Also, because of batchnorm, we know only relative changes to weights matter. If we perform one update W := W - learningrate*dZ/dW, then since W was doubled and dZ/dW is half the size, the relative step is one quarter as large now. So effectively the learning rate has been divided by 4. |
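A toy illustration of the drift-versus-decay equilibrium described above (plain numpy, with arbitrary constants):

```python
import numpy as np

# Under pure gradient noise the weight norm grows like sqrt(time), while adding
# a small decay term makes it level off at an equilibrium norm.
rng = np.random.default_rng(0)
dim, steps, lr, decay = 10_000, 5_000, 1e-2, 1e-2
w_free = np.zeros(dim)
w_decayed = np.zeros(dim)
for _ in range(steps):
    noise = lr * rng.standard_normal(dim)      # noisy "gradient" steps
    w_free += noise
    w_decayed += noise - decay * w_decayed     # same steps plus weight decay
print(np.linalg.norm(w_free))      # keeps growing ~ lr * sqrt(steps * dim)
print(np.linalg.norm(w_decayed))   # levels off ~ lr * sqrt(dim / (2 * decay))
```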
I think you are correct. I mentioned that I started training a new network a few posts ago and the results (so far) match your predictions. The loss was virtually the same in the beginning but it fails to keep up with the other networks towards the end; your explanation about the effect of L2 regularization on the learning rate would explain this behaviour. A takeaway from this is also that I am not training long enough (sigh), since it still benefits from an increasing learning rate. I included a screenshot of the accuracy and loss of the different networks below, the network without L2 regularization is the pink line that is noticeably trailing behind towards the end:

The pink line is hard to see in some of the charts because they all converge within a percentage or two of each other anyway, but the network without L2 regularization is performing worse in all of the metrics.

This raises some interesting questions for me: at the moment I do not use batch normalization nor L2 regularization on the weights in the value head and policy head (I forget if there is a good reason for this). This might be a mistake if we are saying that the combination of batch normalization + L2 regularization is effectively just a dynamic learning rate boost, since then the weights in those heads are missing out on some training towards the end.

PS: Your idea about looking at the norm of the weights as an indication of whether the learning rate is too small or too large is an interesting one, since my norms would suggest my learning rate schedule is not too great (the norms are monotonically decreasing), and it should be relatively trivial to implement a dynamic learning rate based on the previous (maybe a moving average) and the current norm of the weights. A screenshot of the per-variable norms is included below; you can ignore the norm of the offsets (aka biases, or β in the batch normalization formula), which do not have L2 regularization: |
On an unrelated note, did you do any experiments with what percentage of games to "drop" the history inputs during? You mention 5-10% in your repository, but is there any experimental data on what percentage yielded the best results? I am wondering because, looking at the monte-carlo trees, one can clearly observe that the neural network has a strong tendency to memorize sequences of play, which is not a desirable property. The search can still overrule the suggestions of the neural network of course, but that is a waste of GPU cycles if we can just fix the problem in the network instead. An example self-play game where you can see this behaviour is attached (see the fake ko fight in the corner). The game includes all variations considered by the search, so it is pretty big (even if I only did 1,600 rollouts): |
Thanks for asking! No, I didn't test this. I don't have any value net or MCTS component yet, so it's hard for me to do experiments regarding what would improve search. My first project goal needs only policy (neural-net-aided explorations of human biases by players of different ranks), so that's what I've started with. I chose around 10% because it was a simple conservative number large enough to get the behavior I wanted but small enough to be very unlikely to cause much worse prediction when history was still present. If I were to experiment in the future, it would probably be:
(Edit: One minor detail - with my input representation, the neural net can always determine ko legality regardless of history or no-history. I would make sure that this remains the case, because I don't want no-history to also blind the neural net as to what moves are legal in the first place in ko fights, the neural net should of course always be given enough info to determine legality of current move). |
Do you have a source on the linear interpolation behaviour of relu? I would be interested in reading more about that, since I was skimming my list of features to make sure it could still determine whether a move / ko is legal or not without the history planes, when I noticed that according to previous experiments that I have done, the "Am I black or white?" feature only affects the accuracy of the value head by about -2.4% (but it did not affect the policy head accuracy at all). These experiments are not very recent so a lot could have changed, but this suggests said feature is mainly used for komi. The number of testing games* that are won because of komi is up to 13.2% (depending on how you score draws), which suggests the feature doesn't do its job too well†, but if it did then you could implement a dynamic komi between -7.5 and 7.5 by setting the "Am I black or white?" feature plane to
We are still trying to fix the learning rate after dropping the L2 regularization, which is very time consuming. We figured we would try one of the fancy automatic learning rate schedules just for the sake of it, but said schedule is proving unstable, as my training input is a mixture of three different data-sets which can cause some bad mini-batches and, as a consequence, a very noisy loss:
The original reason for mixing different datasets was to expose the neural network to moves that may only appear during reading in higher ranked games, and especially with lower ranked players to get it to distrust the opponent's moves. It is not obvious if it is necessary to do this anymore if we start messing around with the history features, as that would achieve the same thing, and the first point doesn't actually make sense. So we could probably swap to a more monotonic dataset that would result in a less noisy loss, and therefore train faster. |
Mostly my intuitions about linearity come from a few papers like https://arxiv.org/pdf/1412.6572.pdf showing behavior in extrapolation that is linear-like (although later papers point to linearity of behavior as very much not being the sole factor in the existence of adversarial examples). Also posts like http://mlg.eng.cam.ac.uk/yarin/blog_3d801aa532c1ce.html that show that it is possible to give a Bayesian interpretation to some of what neural nets do, and the fact that if along a particular dimension you only have two data points (the value at history = 0 and the value at history = 1), under many sensible priors your posterior will look like a linear interpolation. I have not tested it, so this could be wrong. Maybe in reality it varies in a wiggly or jumpy fashion as you go from 0 to 1, in a way that unpredictably varies between differently-initialized nets trained on the same data. I expect that even if it is approximately linear, it is not exactly so, and will randomly wiggle and curve depending on the position.

I would NOT rely on such an interpolation to handle a variety of komi when training a value net under an AlphaZero-like process. Instead I would just have my self-play games generated with a large variety of different komi, telling it what the komi was in each case. This is much more likely to be reliable, since rather than praying that the neural net interpolates correctly to 0.5 komi when only given -7.5 komi and 7.5 komi examples, you simply train it directly on 0.5 komi games, or 3.5 komi games, etc. Of course, you are unlikely to find such games from human data sets that don't suffer from bias (e.g. 0.5 komi games mean that White was a higher-ranked player), so this does require that you generate your data set yourself.

The reason I had such a crazy hack in mind to try testing for the history feature, which I would definitely not want to use for komi, was because I don't know a reasonable way where you can train with only "halfway providing" the correct history. You can train with only providing half the komi, because that actually changes the nature of your training data (affecting whether some positions are winning and losing), but I don't know how you do that with history. Either you provide history, or you don't. You could try providing it with noise, for example making it 50% likely to be an incorrect history versus a correct one, but that seems very hard to do well, because unless your noise distribution is highly plausible among histories that could have led to this position, the neural net will probably just learn to distinguish incorrect histories from correct ones, and then it will know mostly whether to pay attention to it or not. |
Re-reading Explaining and Harnessing Adversarial Examples suggests to me that the interpolation is only locally linear around each known value, hence the ε. So for real valued features the interpolation will probably be fairly smooth, but it is not obvious how it would behave for binary input features [1]. So using it for komi would almost certainly not work unless we provide it with actual examples, as you suggest. Nevertheless, my survey of articles about adversarial examples has further convinced me that perturbing the history features is actually very important for real-life performance. I will train the following networks, when I've finished re-tuning the learning rate, using the d-128-2-3 network as a base:
You have to be a little bit careful when shuffling to avoid feeding the neural network rubbish, which would probably cause it to just ignore those features. So I suggest the following limitations:
You could go further and provide similar random positions from other games, but I am not sure there is any point since that will not occur in practice.

[1] http://colinraffel.com/publications/iclr2018thermometer.pdf

An alternative solution to the history problem could be based on adversarial networks, as suggested by Stijn Tonk [2]. We would, as part of the training, train an additional network that, given some policy (and some additional data like the current board position?), tries to predict the previous move; we would then include the loss of this adversarial network during training. It is unclear how well this would fit Go, but I am noting it here in case I want to look into it later.

[2] Stijn Tonk, https://blog.godatadriven.com/fairness-in-ml

These are the experimental results of setting the occupied vertices in the history planes to different constants instead of
You can make some interesting observations from the example games (played with the
I also played a blitz tournament between the top 3 engines
I cancelled this blitz tournament after 20 games in each match-up since the results were pretty clear (and very weird), and instead re-started a tournament with 1,600 rollouts to check if this also affects the value head. I will post the results when they are done, said tournament should take a day or so to complete. |
Out of curiosity, I tested this now. I checked what happens if I feed in the history planes at 0.5 weight in my neural net. It uses one-hot indicators of the location of the previous several moves (each move in its own channel) instead of the AlphaGo representation, and as I mentioned before, was specifically trained to behave well at both 0 (history absent) and 1 (history present). Informally, I looked by hand at the colored prediction heatmaps of several dozen positions over several games, comparing between history planes at 0, 0.5, and 1.0 weight.

It definitely does interpolate. In every case I looked at by hand, the resulting 0.5 heatmap looked reasonable and was roughly "in between" the 0 and 1 heatmaps. I did not find any example where it did anything crazy or significantly non-interpolation-like. However, it definitely was NOT a linear or consistent interpolation. While sometimes the 0.5 heatmap was pretty close to an average of the 0 and 1 heatmaps, sometimes it was much closer to either the 0 or the 1 heatmap alone, e.g. more like 0.1 * [0 heatmap] + 0.9 * [1 heatmap]. Also, it was not always uniform between moves on the board - sometimes, as you went 0 -> 0.5 -> 1, move A would light up strongly from 0 -> 0.5 and B only a little, and then from 0.5 -> 1 move A would only light up a bit more, while B would light up strongly. This seemed more common when A and B were far apart on the board and/or differed in whether they were next to the last move or not. Still, often it did give something clearly in-between. So, pretty interesting! :) |
I completely forgot about the interpolation behavior as I was so surprised at the tournament results. No history in this case is

Results from the tournament with playouts (1,600) look as I expected the blitz games to look. The history features seem to contribute to an increased network strength. Since this is the opposite of the blitz results, the obvious conclusion is that the value head collapses without the history features but the policy head can keep going. It is unclear if the value head should benefit from the history features, but since we only compute a single shared representation for both the policy and value head, there is not much choice for the value head other than using the history features if they allow for a lower net loss in the policy head:
|
Since the L2 regularization has almost no effect when combined with batch normalization [1] [2], remove it from the Tensorflow model and replace it with an automatic learning rate schedule that is based on OLS in order to detect when the loss has plateaued [3]. I am still having issues re-producing the accuracy that was achieved with L2 regularization, but this seems to get very close. Also adjust the training script to use the new features introduced in the previous commit [4].

[1] Twan van Laarhoven, "L2 Regularization versus Batch and Weight Normalization", https://arxiv.org/abs/1706.05350
[2] #25 (comment)
[3] http://blog.dlib.net/2018/02/automatic-learning-rate-scheduling-that.html
[4] https://github.com/chicoryn/dream-go/wiki#input-features
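For reference, a simplified sketch of the OLS-based plateau test referenced in [3] (the actual dlib implementation differs in its details; the function and threshold here are illustrative only):

```python
import numpy as np

def loss_has_plateaued(losses, window=1000):
    """Fit a straight line to the most recent `window` loss values and report a
    plateau when the estimated slope is no longer clearly negative."""
    if len(losses) < window:
        return False
    y = np.asarray(losses[-window:])
    t = np.arange(window)
    coeffs = np.polyfit(t, y, 1)                 # ordinary least squares fit
    slope = coeffs[0]
    residuals = y - np.polyval(coeffs, t)
    stderr = residuals.std() / (t.std() * np.sqrt(window))
    return slope + 2.0 * stderr > 0.0            # slope not significantly below zero

# usage sketch: if loss_has_plateaued(history): learning_rate *= 0.1
```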
Just letting you know that I implemented your idea in #25 (comment) for Leela Zero at leela-zero/leela-zero#1599 together with dynamic komi. It was pointed out by @TFiFiE that the formula |
@alreadydone That is very cool. I've not read your entire thread but your implementation seems to be working much better than I expected when I coined the original concept. As for the monotonicity of different networks, my intuition says this has to do with overfitting of the network to minor correlations in the training data, which could be due to too little training data or too low of a learning rate (or other reasons, this is not a solved problem). If you have a lot of spare GPU cycles you might want to consider training a robust network [1] using something like PGD, which should avoid many of the local maxima that break the monotonicity, as they are effectively adversarial examples to your network. [1] Towards Deep Learning Models Resistant to Adversarial Attacks |
Add support for setting komi to values other than `7.5`. This is accomplished by using the colour planes as suggested in #25. There are some additional changes included in this changelist because setting the colour planes to non-binary values causes our feature representation, which is based on bit sets, to break down. Hence I also re-wrote the features to TFRecords. \#25 #25 (comment)
Check out some of the ideas mentioned here to enhance the neural network:
@lightvector https://github.com/lightvector/GoNN
Initial thoughts on the concepts without any further research to back it up:
### Global Pooled Properties

This should be easy to implement and makes a lot of sense. Unclear if we want to add them early, before each residual block, or only in the policy and value head. Need to experiment a bit with this to tell.
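A simplified sketch (PyTorch assumed; channel counts and the linear projection are illustrative, not lightvector's exact formulation) of the general idea, pooling a few channels over the whole board and feeding them back as per-channel biases:

```python
import torch
import torch.nn as nn

class GlobalPooledBias(nn.Module):
    """A few channels are pooled over the whole board and fed back as biases to
    the remaining channels, so every vertex sees some whole-board context."""

    def __init__(self, channels=128, pooled=16):
        super().__init__()
        self.pooled = pooled
        self.project = nn.Linear(2 * pooled, channels - pooled)

    def forward(self, x):                      # x: [N, channels, 19, 19]
        g, rest = x[:, :self.pooled], x[:, self.pooled:]
        stats = torch.cat([g.mean(dim=(2, 3)), g.amax(dim=(2, 3))], dim=1)
        bias = self.project(stats)[:, :, None, None]
        return rest + bias                     # broadcast over the board
```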
### Parametric ReLUs

Unfortunately there is no support for PReLU in cuDNN so we cannot do this. Otherwise I've had similar experiences a long time ago when I was playing around with neural network architectures for Go, where PReLU was just better. I used to experiment with other activation functions, like SELU, but they were not as good.
### Chain Pooling

This probably gives worse results since the pooling destroys the local shape information. Might be interesting to do a dense block approach where each residual block becomes a DenseNet [1]. This way each residual block would benefit from both the pooling and the local shape:

1. Given the block input `x`, compute the chain pooling of each channel and store the result in `yₛ`.
2. Concatenate `x` and `yₛ` channel-wise into `yₓₛ` (so `yₓₛ` has shape `[256, 19, 19]`).
3. `y₁ <- relu(C(W₁, yₓₛ) + b₁)` (`y₁` has shape `[128, 19, 19]`).
4. Concatenate `y₁` and `yₛ` channel-wise into `y₁ₛ` (so `y₁ₛ` has shape `[256, 19, 19]`).
5. `y₂ <- relu(C(W₂, y₁) + x + b₂)`
6. Return `y₂`.
I suspect this is too expensive to do in practice (no good support for it in cuDNN), but it is a very interesting idea.