Skip to content

Commit

Permalink
mixpool: Add missing mutex acquire
Browse files Browse the repository at this point in the history
reconsiderOrphans was accessing pool state without the pool mutex held, which
results in data races.

To resolve this, some refactoring is necessary.  reconsiderOrphans calls
acceptKE, which did acquire the mutex, so that needs to be hoisted out.

In this commit, the pre-mutex-acquire sanity checks for PR and KE messages in
acceptPR and acceptKE are moved to separate checkAccept{PR,KE} functions for
the caller, who then becomes responsible for acquiring the mutex.  After
acceptPR/acceptKE return, reconsiderOrphans is then called with the mutex
still held.

This is a backport candidate for 2.0.1.
  • Loading branch information
jrick authored and davecgh committed May 18, 2024
1 parent dcd41c6 commit ac51ffa
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions mixing/mixpool/mixpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,13 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err
var msgtype msgtype
switch msg := msg.(type) {
case *wire.MsgMixPairReq:
if err := p.checkAcceptPR(msg); err != nil {
return nil, err
}

p.mtx.Lock()
defer p.mtx.Unlock()

accepted, err := p.acceptPR(msg, &hash, id)
if err != nil {
return nil, err
Expand All @@ -883,6 +890,13 @@ func (p *Pool) AcceptMessage(msg mixing.Message) (accepted []mixing.Message, err
return allAccepted, nil

case *wire.MsgMixKeyExchange:
if err := p.checkAcceptKE(msg); err != nil {
return nil, err
}

p.mtx.Lock()
defer p.mtx.Unlock()

accepted, err := p.acceptKE(msg, &hash, id)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1008,21 +1022,21 @@ func (p *Pool) removePR(pr *wire.MsgMixPairReq, reason string) {
}
}

func (p *Pool) acceptPR(pr *wire.MsgMixPairReq, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixPairReq, err error) {
func (p *Pool) checkAcceptPR(pr *wire.MsgMixPairReq) error {
switch {
case len(pr.UTXOs) == 0: // Require at least one utxo.
return nil, ruleError(ErrMissingUTXOs)
return ruleError(ErrMissingUTXOs)
case pr.MessageCount == 0: // Require at least one mixed message.
return nil, ruleError(ErrInvalidMessageCount)
return ruleError(ErrInvalidMessageCount)
case pr.InputValue < int64(pr.MessageCount)*pr.MixAmount:
return nil, ruleError(ErrInvalidTotalMixAmount)
return ruleError(ErrInvalidTotalMixAmount)
case pr.Change != nil:
if isDustAmount(pr.Change.Value, p2pkhv0PkScriptSize, feeRate) {
return nil, ruleError(ErrChangeDust)
return ruleError(ErrChangeDust)
}
if !stdscript.IsPubKeyHashScriptV0(pr.Change.PkScript) &&
!stdscript.IsScriptHashScriptV0(pr.Change.PkScript) {
return nil, ruleError(ErrInvalidScript)
return ruleError(ErrInvalidScript)
}
}

Expand All @@ -1032,36 +1046,36 @@ func (p *Pool) acceptPR(pr *wire.MsgMixPairReq, hash *chainhash.Hash, id *idPubK
maxExpiry := mixing.MaxExpiry(uint32(curHeight), p.params)
switch {
case uint32(curHeight) >= pr.Expiry:
return nil, fmt.Errorf("message has expired")
return fmt.Errorf("message has expired")
case pr.Expiry > maxExpiry:
return nil, fmt.Errorf("expiry is too far into future")
return fmt.Errorf("expiry is too far into future")
}

// Require known script classes.
switch mixing.ScriptClass(pr.ScriptClass) {
case mixing.ScriptClassP2PKHv0:
default:
return nil, fmt.Errorf("unsupported mixing script class")
return fmt.Errorf("unsupported mixing script class")
}

// Require enough fee contributed from this mixing participant.
// Size estimation assumes mixing.ScriptClassP2PKHv0 outputs and inputs.
err = checkFee(pr, p.feeRate)
if err != nil {
return nil, err
if err := checkFee(pr, p.feeRate); err != nil {
return err
}

// If able, sanity check UTXOs.
if p.utxoFetcher != nil {
err := p.checkUTXOs(pr, curHeight)
if err != nil {
return nil, err
return err
}
}

p.mtx.Lock()
defer p.mtx.Unlock()
return nil
}

func (p *Pool) acceptPR(pr *wire.MsgMixPairReq, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixPairReq, err error) {
// Check if already accepted.
if _, ok := p.prs[*hash]; ok {
return nil, nil
Expand Down Expand Up @@ -1291,21 +1305,21 @@ func validateOwnerProofP2PKHv0(extractFunc func([]byte) []byte, pkscript, pubkey
return utxoproof.ValidateSecp256k1P2PKH(pubkey, sig, expires)
}

func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixKeyExchange, err error) {
func (p *Pool) checkAcceptKE(ke *wire.MsgMixKeyExchange) error {
// Validate PR order and session ID.
err = mixing.ValidateSession(ke)
if err != nil {
return nil, ruleError(err)
if err := mixing.ValidateSession(ke); err != nil {
return ruleError(err)
}

if ke.Pos >= uint32(len(ke.SeenPRs)) {
err := fmt.Errorf("peer position is an invalid seen PRs position")
return nil, ruleError(err)
return ruleError(err)
}

p.mtx.Lock()
defer p.mtx.Unlock()
return nil
}

func (p *Pool) acceptKE(ke *wire.MsgMixKeyExchange, hash *chainhash.Hash, id *idPubKey) (accepted *wire.MsgMixKeyExchange, err error) {
// Check if already accepted.
if _, ok := p.pool[*hash]; ok {
return nil, nil
Expand Down

0 comments on commit ac51ffa

Please sign in to comment.