Skip to content

Commit

Permalink
Add Callback Segment List
Browse files Browse the repository at this point in the history
  • Loading branch information
tung.tq committed Oct 11, 2023
1 parent 754347a commit 9c45710
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 11 deletions.
121 changes: 110 additions & 11 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package memproxy

import "time"
import (
"sync"
"time"
)

type sessionProviderImpl struct {
nowFn func() time.Time
Expand Down Expand Up @@ -61,6 +64,7 @@ func newSession(
lower: nil,
higher: higher,
}

if higher != nil {
higher.lower = s
s.isDirty = higher.isDirty
Expand All @@ -70,7 +74,7 @@ func newSession(

type sessionImpl struct {
provider *sessionProviderImpl
nextCalls []CallbackFunc
nextCalls callbackList
heap delayedCallHeap

isDirty bool // an optimization
Expand Down Expand Up @@ -99,10 +103,7 @@ func setDirtyRecursive(s *sessionImpl) {
// AddNextCall ...
func (s *sessionImpl) AddNextCall(fn CallbackFunc) {
setDirtyRecursive(s)
if s.nextCalls == nil {
s.nextCalls = make([]CallbackFunc, 0, 32)
}
s.nextCalls = append(s.nextCalls, fn)
s.nextCalls.append(fn)
}

// AddDelayedCall ...
Expand Down Expand Up @@ -145,11 +146,15 @@ func (s *sessionImpl) GetLower() Session {
}

func (s *sessionImpl) executeNextCalls() {
for len(s.nextCalls) > 0 {
nextCalls := s.nextCalls
s.nextCalls = nil
for _, call := range nextCalls {
call.Call()
for !s.nextCalls.isEmpty() {
it := s.nextCalls.getIterator()

for {
fn, ok := it.getNext()
if !ok {
break
}
fn.Call()
}
}
}
Expand All @@ -174,3 +179,97 @@ MainLoop:
}
}
}

// ===============================
// callback list
// ===============================

type callbackList struct {
head *callbackSegment
tail *callbackSegment
}

type callbackSegment struct {
next *callbackSegment // linked list of callback
size int
funcs [16]CallbackFunc
}

func (s *callbackList) append(fn CallbackFunc) {
if s.tail == nil {
s.head = getCallbackSegment()
s.tail = s.head
} else if s.tail.size >= len(s.tail.funcs) {
newTail := getCallbackSegment()
s.tail.next = newTail
s.tail = newTail
}

n := s.tail
n.funcs[n.size] = fn
n.size++
}

func (s *callbackList) isEmpty() bool {
return s.head == nil
}

type callbackListIterator struct {
current *callbackSegment
index int
}

// getIterator also clears the list
func (s *callbackList) getIterator() callbackListIterator {
it := callbackListIterator{
current: s.head,
index: 0,
}

s.head = nil
s.tail = nil

return it
}

func (it *callbackListIterator) getNext() (CallbackFunc, bool) {
if it.current == nil {
return CallbackFunc{}, false
}

if it.index >= it.current.size {
prev := it.current
it.current = it.current.next

putCallbackSegment(prev)

it.index = 0

if it.current == nil {
return CallbackFunc{}, false
}
}

fn := it.current.funcs[it.index]
it.index++
return fn, true
}

// ===============================
// Pool of Callback Segments
// ===============================

var callbackSegmentPool = sync.Pool{
New: func() any {
return &callbackSegment{}
},
}

func getCallbackSegment() *callbackSegment {
return callbackSegmentPool.Get().(*callbackSegment)
}

func putCallbackSegment(s *callbackSegment) {
*s = callbackSegment{}
callbackSegmentPool.Put(s)
}
121 changes: 121 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package memproxy

import (
"fmt"
"testing"
"time"
"unsafe"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -695,3 +697,122 @@ func TestEmpty(t *testing.T) {

assert.Equal(t, 1, calls)
}

func iterateCallbackSegment(l *callbackList) {
it := l.getIterator()
for {
fn, ok := it.getNext()
if !ok {
break
}
fn.Call()
}
}

func newCallbackListTest() *callbackList {
var l callbackList
return &l
}

func TestCallbackSegment(t *testing.T) {
t.Run("empty", func(t *testing.T) {
l := newCallbackListTest()
assert.Equal(t, true, l.isEmpty())
iterateCallbackSegment(l)
})

t.Run("single", func(t *testing.T) {
l := newCallbackListTest()

var values []string

fn1 := &callMock{
fn: func() {
values = append(values, "fn1")
},
}
l.append(fn1.get())

assert.Equal(t, false, l.isEmpty())
iterateCallbackSegment(l)
assert.Equal(t, true, l.isEmpty())

assert.Equal(t, 1, fn1.count)
assert.Equal(t, []string{"fn1"}, values)
})

t.Run("multiple", func(t *testing.T) {
l := newCallbackListTest()

var values []string

fn1 := &callMock{
fn: func() {
values = append(values, "fn1")
},
}
fn2 := &callMock{
fn: func() {
values = append(values, "fn2")
},
}
fn3 := &callMock{
fn: func() {
values = append(values, "fn3")
},
}
l.append(fn1.get())
l.append(fn2.get())
l.append(fn3.get())

iterateCallbackSegment(l)

assert.Equal(t, 1, fn1.count)
assert.Equal(t, 1, fn2.count)
assert.Equal(t, 1, fn3.count)

assert.Equal(t, []string{"fn1", "fn2", "fn3"}, values)
})

t.Run("multiples of 16", func(t *testing.T) {
l := newCallbackListTest()

var values []string

for i := 0; i < 16*3; i++ {
index := i
fn := &callMock{
fn: func() {
values = append(values, fmt.Sprintf("fn%02d", index))
},
}
l.append(fn.get())
}

iterateCallbackSegment(l)

expected := make([]string, 16*3)
for i := range expected {
expected[i] = fmt.Sprintf("fn%02d", i)
}

assert.Equal(t, 16*3, len(values))
assert.Equal(t, expected, values)
})
}

func TestCallbackSegmentPool(t *testing.T) {
t.Run("normal", func(t *testing.T) {
x := getCallbackSegment()
x.size = 14

oldPtr := unsafe.Pointer(x)

putCallbackSegment(x)

x = getCallbackSegment()
assert.Equal(t, 0, x.size)

fmt.Println("POINTERS EQUAL:", oldPtr == unsafe.Pointer(x))
})
}

0 comments on commit 9c45710

Please sign in to comment.