Skip to content

Commit

Permalink
refactor: add context parameter to APNS client functions
Browse files Browse the repository at this point in the history
- Add context parameter to `InitAPNSClient` and `PushToIOS` functions
- Update calls to `InitAPNSClient` and `PushToIOS` to include context parameter
- Modify APNs client push method to use `PushWithContext`
- Add context import in `notification_apns.go`
- Update tests to include context parameter in `InitAPNSClient` calls

Signed-off-by: appleboy <appleboy.tw@gmail.com>
  • Loading branch information
appleboy committed Jun 15, 2024
1 parent 228ec1d commit 9702473
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 32 deletions.
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,11 @@ func main() {
return
}

if err := notify.InitAPNSClient(cfg); err != nil {
if err := notify.InitAPNSClient(g.ShutdownContext(), cfg); err != nil {
return
}

if _, err := notify.PushToIOS(req, cfg); err != nil {
if _, err := notify.PushToIOS(g.ShutdownContext(), req, cfg); err != nil {
return
}

Expand Down Expand Up @@ -359,7 +359,7 @@ func main() {
})

if cfg.Ios.Enabled {
if err = notify.InitAPNSClient(cfg); err != nil {
if err = notify.InitAPNSClient(g.ShutdownContext(), cfg); err != nil {
logx.LogError.Fatal(err)
}
}
Expand Down
2 changes: 1 addition & 1 deletion notify/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func SendNotification(

switch v.Platform {
case core.PlatFormIos:
resp, err = PushToIOS(v, cfg)
resp, err = PushToIOS(ctx, v, cfg)
case core.PlatFormAndroid:
resp, err = PushToAndroid(ctx, v, cfg)
case core.PlatFormHuawei:
Expand Down
7 changes: 4 additions & 3 deletions notify/notification_apns.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package notify

import (
"context"
"crypto/ecdsa"
"crypto/tls"
"encoding/base64"
Expand Down Expand Up @@ -58,7 +59,7 @@ type Sound struct {
}

// InitAPNSClient use for initialize APNs Client.
func InitAPNSClient(cfg *config.ConfYaml) error {
func InitAPNSClient(ctx context.Context, cfg *config.ConfYaml) error {
if cfg.Ios.Enabled {
var err error
var authKey *ecdsa.PrivateKey
Expand Down Expand Up @@ -401,7 +402,7 @@ func getApnsClient(cfg *config.ConfYaml, req *PushNotification) (client *apns2.C
}

// PushToIOS provide send notification to APNs server.
func PushToIOS(req *PushNotification, cfg *config.ConfYaml) (resp *ResponsePush, err error) {
func PushToIOS(ctx context.Context, req *PushNotification, cfg *config.ConfYaml) (resp *ResponsePush, err error) {
logx.LogAccess.Debug("Start push notification for iOS")

var (
Expand Down Expand Up @@ -430,7 +431,7 @@ Retry:
notification.DeviceToken = token

// send ios notification
res, err := client.Push(&notification)
res, err := client.PushWithContext(ctx, &notification)
if err != nil || (res != nil && res.StatusCode != http.StatusOK) {
if err == nil {
// error message:
Expand Down
40 changes: 20 additions & 20 deletions notify/notification_apns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,15 +575,15 @@ func TestWrongIosCertificateExt(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = "test"
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)

assert.Error(t, err)
assert.Equal(t, "wrong certificate key extension", err.Error())

cfg.Ios.KeyPath = ""
cfg.Ios.KeyBase64 = "abcd"
cfg.Ios.KeyType = "abcd"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)

assert.Error(t, err)
assert.Equal(t, "wrong certificate key type", err.Error())
Expand All @@ -594,14 +594,14 @@ func TestAPNSClientDevHost(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = "../certificate/certificate-valid.p12"
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)

cfg.Ios.KeyPath = ""
cfg.Ios.KeyBase64 = certificateValidP12
cfg.Ios.KeyType = "p12"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)
}
Expand All @@ -612,14 +612,14 @@ func TestAPNSClientProdHost(t *testing.T) {
cfg.Ios.Enabled = true
cfg.Ios.Production = true
cfg.Ios.KeyPath = testKeyPath
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostProduction, ApnsClient.Host)

cfg.Ios.KeyPath = ""
cfg.Ios.KeyBase64 = certificateValidPEM
cfg.Ios.KeyType = "pem"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostProduction, ApnsClient.Host)
}
Expand All @@ -629,29 +629,29 @@ func TestAPNSClientInvaildToken(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = "../certificate/authkey-invalid.p8"
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Error(t, err)

cfg.Ios.KeyPath = ""
cfg.Ios.KeyBase64 = authkeyInvalidP8
cfg.Ios.KeyType = "p8"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Error(t, err)

// empty key-id or team-id
cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPathP8
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Error(t, err)

cfg.Ios.KeyID = "key-id"
cfg.Ios.TeamID = ""
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Error(t, err)

cfg.Ios.KeyID = ""
cfg.Ios.TeamID = "team-id"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Error(t, err)
}

Expand All @@ -662,12 +662,12 @@ func TestAPNSClientVaildToken(t *testing.T) {
cfg.Ios.KeyPath = testKeyPathP8
cfg.Ios.KeyID = "key-id"
cfg.Ios.TeamID = "team-id"
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.NoError(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)

cfg.Ios.Production = true
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.NoError(t, err)
assert.Equal(t, apns2.HostProduction, ApnsClient.Host)

Expand All @@ -676,12 +676,12 @@ func TestAPNSClientVaildToken(t *testing.T) {
cfg.Ios.KeyPath = ""
cfg.Ios.KeyBase64 = authkeyValidP8
cfg.Ios.KeyType = "p8"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.NoError(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)

cfg.Ios.Production = true
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.NoError(t, err)
assert.Equal(t, apns2.HostProduction, ApnsClient.Host)
}
Expand All @@ -693,7 +693,7 @@ func TestAPNSClientUseProxy(t *testing.T) {
cfg.Ios.KeyPath = "../certificate/certificate-valid.p12"
cfg.Core.HTTPProxy = "http://127.0.0.1:8080"
_ = SetProxy(cfg.Core.HTTPProxy)
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)

Expand All @@ -707,7 +707,7 @@ func TestAPNSClientUseProxy(t *testing.T) {
cfg.Ios.KeyPath = testKeyPathP8
cfg.Ios.TeamID = "example.team"
cfg.Ios.KeyID = "example.key"
err = InitAPNSClient(cfg)
err = InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
assert.Equal(t, apns2.HostDevelopment, ApnsClient.Host)
assert.NotNil(t, ApnsClient.Token)
Expand All @@ -728,7 +728,7 @@ func TestPushToIOS(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPath
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
err = status.InitAppStatus(cfg)
assert.Nil(t, err)
Expand All @@ -741,7 +741,7 @@ func TestPushToIOS(t *testing.T) {
}

// send fail
resp, err := PushToIOS(req, cfg)
resp, err := PushToIOS(context.Background(), req, cfg)
assert.Nil(t, err)
assert.Len(t, resp.Logs, 2)
}
Expand All @@ -751,7 +751,7 @@ func TestApnsHostFromRequest(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPath
err := InitAPNSClient(cfg)
err := InitAPNSClient(context.Background(), cfg)
assert.Nil(t, err)
err = status.InitAppStatus(cfg)
assert.Nil(t, err)
Expand Down
10 changes: 5 additions & 5 deletions router/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestMain(m *testing.M) {
cfg.Android.Enabled = true
cfg.Android.Credential = os.Getenv("FCM_CREDENTIAL")

if _, err := notify.InitFCMClient(cfg); err != nil {
if _, err := notify.InitFCMClient(context.Background(), cfg); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -474,7 +474,7 @@ func TestSenMultipleNotifications(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPath
err := notify.InitAPNSClient(cfg)
err := notify.InitAPNSClient(ctx, cfg)
assert.Nil(t, err)

cfg.Android.Enabled = true
Expand Down Expand Up @@ -510,7 +510,7 @@ func TestDisabledAndroidNotifications(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPath
err := notify.InitAPNSClient(cfg)
err := notify.InitAPNSClient(ctx, cfg)
assert.Nil(t, err)

cfg.Android.Enabled = false
Expand Down Expand Up @@ -546,7 +546,7 @@ func TestSyncModeForNotifications(t *testing.T) {

cfg.Ios.Enabled = true
cfg.Ios.KeyPath = testKeyPath
err := notify.InitAPNSClient(cfg)
err := notify.InitAPNSClient(ctx, cfg)
assert.Nil(t, err)

cfg.Android.Enabled = true
Expand Down Expand Up @@ -658,7 +658,7 @@ func TestDisabledIosNotifications(t *testing.T) {

cfg.Ios.Enabled = false
cfg.Ios.KeyPath = testKeyPath
err := notify.InitAPNSClient(cfg)
err := notify.InitAPNSClient(ctx, cfg)
assert.Nil(t, err)

cfg.Android.Enabled = true
Expand Down

0 comments on commit 9702473

Please sign in to comment.