Skip to content

Commit

Permalink
fix: wrong usage of RLock
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Jul 22, 2024
1 parent fd5b537 commit 4eb13a7
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 20 deletions.
2 changes: 0 additions & 2 deletions adapter/outboundgroup/loadbalance.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ func strategyStickySessions(url string) strategyFn {
proxy := proxies[nowIdx]
if proxy.AliveForTestUrl(url) {
if nowIdx != idx {
lruCache.Delete(key)
lruCache.Set(key, nowIdx)
}

Expand All @@ -215,7 +214,6 @@ func strategyStickySessions(url string) strategyFn {
}
}

lruCache.Delete(key)
lruCache.Set(key, 0)
return proxies[0]
}
Expand Down
32 changes: 32 additions & 0 deletions common/lru/lrucache.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock()
defer c.mu.Unlock()

c.delete(key)
}

func (c *LruCache[K, V]) delete(key K) {
if le, ok := c.cache[key]; ok {
c.deleteElement(le)
}
Expand Down Expand Up @@ -255,6 +259,34 @@ func (c *LruCache[K, V]) Clear() error {
return nil
}

// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored.
func (c *LruCache[K, V]) Compute(
key K,
valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
) (actual V, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()

if el := c.get(key); el != nil {
actual, ok = el.value, true
}
if newValue, del := valueFn(actual, ok); del {
if ok { // data not in cache, so needn't delete
c.delete(key)
}
return lo.Empty[V](), false
} else {
c.set(key, newValue)
return newValue, true
}
}

type entry[K comparable, V any] struct {
key K
value V
Expand Down
4 changes: 2 additions & 2 deletions common/queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ func (q *Queue[T]) Copy() []T {

// Len returns the number of items in this queue.
func (q *Queue[T]) Len() int64 {
q.lock.Lock()
defer q.lock.Unlock()
q.lock.RLock()
defer q.lock.RUnlock()

return int64(len(q.items))
}
Expand Down
4 changes: 2 additions & 2 deletions common/utils/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ func NewCallback[T any]() *Callback[T] {
}

func (c *Callback[T]) Register(item func(T)) io.Closer {
c.mutex.RLock()
defer c.mutex.RUnlock()
c.mutex.Lock()
defer c.mutex.Unlock()
element := c.list.PushBack(item)
return &callbackCloser[T]{
element: element,
Expand Down
20 changes: 6 additions & 14 deletions component/sniffer/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/netip"
"sync"
"time"

"github.com/metacubex/mihomo/common/lru"
Expand All @@ -30,7 +29,6 @@ type SnifferDispatcher struct {
forceDomain *trie.DomainSet
skipSNI *trie.DomainSet
skipList *lru.LruCache[string, uint8]
rwMux sync.RWMutex
forceDnsMapping bool
parsePureIp bool
}
Expand Down Expand Up @@ -85,14 +83,11 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false
}

sd.rwMux.RLock()
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
if count, ok := sd.skipList.Get(dst); ok && count > 5 {
log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
defer sd.rwMux.RUnlock()
return false
}
sd.rwMux.RUnlock()

if host, err := sd.sniffDomain(conn, metadata); err != nil {
sd.cacheSniffFailed(metadata)
Expand All @@ -104,9 +99,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata
return false
}

sd.rwMux.RLock()
sd.skipList.Delete(dst)
sd.rwMux.RUnlock()

sd.replaceDomain(metadata, host, overrideDest)
return true
Expand Down Expand Up @@ -176,14 +169,13 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad
}

func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
sd.rwMux.Lock()
dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
count, _ := sd.skipList.Get(dst)
if count <= 5 {
count++
}
sd.skipList.Set(dst, count)
sd.rwMux.Unlock()
sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
if oldValue <= 5 {
oldValue++
}
return oldValue, false
})
}

func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
Expand Down

0 comments on commit 4eb13a7

Please sign in to comment.