Skip to content

Conversation

@mcognetta
Copy link

During inference, logits corresponding to illegal moves are supposed to be masked out so that they have probability 0 after softmax. Right now, this is done by multiplying the logit with the legal moves mask before the softmax layer. However, the legal move mask is a set of 0s and 1s and the logits are in unnormalized log space. This means that multiplication by 0 just converts the logit to be 1 in unnormalized real space, so the invalid moves have non-zero probability after softmax. The correct way is to add -inf to all invalid logits so that they have probability 0 in real space.

See the following example:

>>> x = torch.rand((3, 5)); x
tensor([[0.8924, 0.8796, 0.4099, 0.8877, 0.2908],
        [0.3853, 0.9005, 0.4000, 0.5488, 0.3874],
        [0.0757, 0.2595, 0.5019, 0.2247, 0.0422]])

>>> mask = torch.rand((3, 5)) > 0.6; mask
tensor([[False, False, False,  True, False],
        [ True,  True, False, False,  True],
        [False,  True, False,  True,  True]])

# without masking
>>> torch.softmax(x, dim = -1)
tensor([[0.2411, 0.2380, 0.1488, 0.2400, 0.1321],
        [0.1704, 0.2852, 0.1729, 0.2007, 0.1708],
        [0.1706, 0.2050, 0.2613, 0.1980, 0.1650]])

# current masking method; notice that everything is non-zero and invalid moves all have the same probability
>>> torch.softmax(x * mask, dim = -1)
tensor([[0.1555, 0.1555, 0.1555, 0.3779, 0.1555],
        [0.1986, 0.3324, 0.1351, 0.1351, 0.1990],
        [0.1789, 0.2318, 0.1789, 0.2239, 0.1866]])

# new masking method; masked out values are given 0 probability
>>> torch.softmax(x + mask.log(), dim = -1)
tensor([[0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
        [0.2720, 0.4554, 0.0000, 0.0000, 0.2726],
        [0.0000, 0.3610, 0.0000, 0.3486, 0.2904]])

@mcognetta
Copy link
Author

mcognetta commented Oct 8, 2025

This doesn't affect the results too much in most cases, but it can cause some big leakages. On the example testset in the README, the worst case is that 5.1% of the probability mass is leaked to invalid moves.

That position is 'rn1q1rk1/ppp2ppp/4bn2/3p3P/4p3/P3P3/1PPPBPPb/RNBQK3 w Q - 0 11' with elos (1500, 1498). The move probs are

{'a3a4': 0.1074, 'e1f1': 0.1011, 'h5h6': 0.0953, 'g2g3': 0.0947, 'b1c3': 0.0733, 'd2d3': 0.0564, 'e2f1': 0.0533, 'a1a2': 0.049, 'g2g4': 0.0477, 'b2b4': 0.04, 'd2d4': 0.0365, 'e2g4': 0.0322, 'e2b5': 0.0312, 'b2b3': 0.0279, 'f2f3': 0.0269, 'f2f4': 0.0216, 'c2c4': 0.0201, 'c2c3': 0.0118, 'e2f3': 0.0069, 'e2a6': 0.0055, 'e2c4': 0.0051, 'e2d3': 0.0051}

with sum(move_probs.values()) = .949.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant