Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve polling updater.stop() call when long polling #152

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions ext/botmapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,15 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
w.WriteHeader(http.StatusNotFound)
return
}

b.updateWriterControl.Add(1)
defer b.updateWriterControl.Done()

if b.shouldStopUpdates() {
w.WriteHeader(http.StatusServiceUnavailable)
return
}

headerSecret := r.Header.Get("X-Telegram-Bot-Api-Secret-Token")
if b.webhookSecret != "" && b.webhookSecret != headerSecret {
// Drop any updates from invalid secret tokens.
Expand All @@ -184,10 +190,6 @@ func (m *botMapping) getHandlerFunc(prefix string) func(writer http.ResponseWrit
return
}

if b.isUpdateChannelStopped() {
return
}

b.updateChan <- bytes
}
}
Expand All @@ -213,7 +215,7 @@ func (b *botData) stop() {
close(b.updateChan)
}

func (b *botData) isUpdateChannelStopped() bool {
func (b *botData) shouldStopUpdates() bool {
select {
case <-b.stopUpdates:
// if anything comes in on the closing channel, we know the channel is closed.
Expand Down
4 changes: 2 additions & 2 deletions ext/botmapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ func Test_botData_isUpdateChannelStopped(t *testing.T) {
t.Errorf("bot with token %s should not have failed to be added", b.Token)
return
}
if bData.isUpdateChannelStopped() {
if bData.shouldStopUpdates() {
t.Errorf("bot with token %s should not be stopped yet", b.Token)
return
}

bData.stop()
if !bData.isUpdateChannelStopped() {
if !bData.shouldStopUpdates() {
t.Errorf("bot with token %s should be stopped", b.Token)
return
}
Expand Down
12 changes: 8 additions & 4 deletions ext/updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s
defer bData.updateWriterControl.Done()

for {
// Check if updater loop has been terminated.
if bData.shouldStopUpdates() {
return
}

// Manually craft the getUpdate calls to improve memory management, reduce json parsing overheads, and
// unnecessary reallocation of url.Values in the polling loop.
r, err := bData.bot.Request("getUpdates", v, nil, opts)
Expand Down Expand Up @@ -219,10 +224,6 @@ func (u *Updater) pollingLoop(bData *botData, opts *gotgbot.RequestOpts, v map[s

v["offset"] = strconv.FormatInt(lastUpdate.UpdateId+1, 10)

if bData.isUpdateChannelStopped() {
return
}

for _, updData := range rawUpdates {
temp := updData // use new mem address to avoid loop conflicts
bData.updateChan <- temp
Expand All @@ -240,6 +241,9 @@ func (u *Updater) Idle() {
}

// Stop stops the current updater and dispatcher instances.
//
// When using long polling, Stop() will wait for the getUpdates call to return, which may cause a delay due to the
// request timeout.
func (u *Updater) Stop() error {
// Stop any running servers.
if u.webhookServer != nil {
Expand Down
32 changes: 24 additions & 8 deletions ext/updater_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -98,8 +99,14 @@ func concurrentTest(t *testing.T) {
t.Parallel()

delay := time.Second
server := basicTestServer(t, map[string]testEndpoint{
"getUpdates": {delay: delay, reply: `{"ok": true, "result": [{"message": {"text": "stop"}}]}`},
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {
delay: delay,
replies: []string{
`{"ok": true, "result": [{"message": {"text": "stop"}}]}`,
},
reply: `{"ok": true, "result": []}`,
},
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
})
defer server.Close()
Expand Down Expand Up @@ -290,7 +297,7 @@ func TestUpdater_GetHandlerFunc(t *testing.T) {
}

func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true}`},
"deleteWebhook": {reply: `{"ok": true, "result": true}`},
})
Expand Down Expand Up @@ -329,7 +336,7 @@ func TestUpdaterAllowsWebhookDeletion(t *testing.T) {
}

func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -384,7 +391,7 @@ func TestUpdaterSupportsTwoPollingBots(t *testing.T) {
}

func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -432,7 +439,7 @@ func TestUpdaterThrowsErrorWhenSameLongPollAddedTwice(t *testing.T) {
}

func TestUpdaterSupportsLongPollReAdding(t *testing.T) {
server := basicTestServer(t, map[string]testEndpoint{
server := basicTestServer(t, map[string]*testEndpoint{
"getUpdates": {reply: `{"ok": true, "result": []}`},
})
defer server.Close()
Expand Down Expand Up @@ -484,10 +491,14 @@ func TestUpdaterSupportsLongPollReAdding(t *testing.T) {

type testEndpoint struct {
delay time.Duration
// Will reply these until we run out of replies, at which point we repeat "reply"
replies []string
idx atomic.Int32
// default reply
reply string
}

func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Server {
func basicTestServer(t *testing.T, methods map[string]*testEndpoint) *httptest.Server {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pathItems := strings.Split(r.URL.Path, "/")
lastItem := pathItems[len(pathItems)-1]
Expand All @@ -498,7 +509,12 @@ func basicTestServer(t *testing.T, methods map[string]testEndpoint) *httptest.Se
if out.delay != 0 {
time.Sleep(out.delay)
}
fmt.Fprint(w, out.reply)
count := int(out.idx.Add(1) - 1)
if len(out.replies) != 0 && len(out.replies) > count {
fmt.Fprint(w, out.replies[count])
} else {
fmt.Fprint(w, out.reply)
}
return
}

Expand Down
17 changes: 12 additions & 5 deletions samples/echoMultiBot/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ func main() {

// If we get here, the updater.Idle() has ended.
// This means that updater.Stop() has been called, stopping all bots gracefully.
log.Println("Updater is no longer idling; all bots have been stopped gracefully.")
log.Println("Updater is no longer idling; all bots have been stopped gracefully. Exiting in 1s.")

// We sleep one last second to allow for the "stopall" goroutine to send the shutdown message.
time.Sleep(time.Second)
}

// startLongPollingBots demonstrates how to start multiple bots with long-polling.
Expand Down Expand Up @@ -159,11 +162,14 @@ func stop(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
return fmt.Errorf("failed to echo message: %w", err)
}

if !updater.StopBot(b.Token) {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
return nil
}
go func() {
if !updater.StopBot(b.Token) {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Unable to find bot %d; was it already stopped?", b.Id), nil)
return
}

ctx.EffectiveMessage.Reply(b, "Stopped @"+b.Username, nil)
}()
return nil
}

Expand All @@ -181,6 +187,7 @@ func stopAll(b *gotgbot.Bot, ctx *ext.Context, updater *ext.Updater) error {
ctx.EffectiveMessage.Reply(b, fmt.Sprintf("Failed to stop updater: %s", err.Error()), nil)
return
}
ctx.EffectiveMessage.Reply(b, "All bots have been stopped.", nil)
}()

return nil
Expand Down
Loading