Skip to content

Commit

Permalink
Add 30 second rate limit to CreateSnapshot
Browse files Browse the repository at this point in the history
See kubernetes-sigs#1608
See kubernetes-csi/external-snapshotter#778

This does not seek to be a comprehensive rate-limiting solution, but
rather to add a temporary workaround for the bug in the snapshotter
sidecar by refusing to call the CreateSnapshot for a specific volume
unless it has been 30 seconds since the last attempt.

Signed-off-by: Connor Catlett <conncatl@amazon.com>
  • Loading branch information
ConnorJC3 committed May 22, 2023
1 parent 771e745 commit 5e48e93
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
16 changes: 16 additions & 0 deletions pkg/cloud/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"os"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -167,6 +168,10 @@ var (

// VolumeNotBeingModified is returned if volume being described is not being modified
VolumeNotBeingModified = fmt.Errorf("volume is not being modified")

// ErrSnapshotRateLimit is returned when cloud.go refuses to send a CreateSnapshot AWS API request
// Used to work around https://github.com/kubernetes-csi/external-snapshotter/issues/778
ErrSnapshotRateLimit = fmt.Errorf("Refusing to send CreateSnapshot to AWS (see https://github.com/kubernetes-sigs/aws-ebs-csi-driver/issues/1608#issuecomment-1554748900)")
)

// Set during build time via -ldflags
Expand Down Expand Up @@ -776,7 +781,17 @@ func (c *cloud) IsExistInstance(ctx context.Context, nodeID string) bool {
return true
}

var lastCreateSnapshot sync.Map

func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) {
if lastTimeUncast, ok := lastCreateSnapshot.Load(volumeID); ok {
if lastTime, ok := lastTimeUncast.(time.Time); ok {
if time.Since(lastTime) < time.Second*30 {
return nil, ErrSnapshotRateLimit
}
}
}

descriptions := "Created by AWS EBS CSI driver for volume " + volumeID

var tags []*ec2.Tag
Expand All @@ -796,6 +811,7 @@ func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOpt
Description: aws.String(descriptions),
}

lastCreateSnapshot.Store(volumeID, time.Now())
res, err := c.ec2.CreateSnapshotWithContext(ctx, request)
if err != nil {
return nil, fmt.Errorf("error creating snapshot of volume %s: %w", volumeID, err)
Expand Down
38 changes: 38 additions & 0 deletions pkg/cloud/cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"sort"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
Expand Down Expand Up @@ -1081,6 +1082,43 @@ func TestCreateSnapshot(t *testing.T) {
}
}

func TestCreateSnapshotRateLimit(t *testing.T) {
mockCtrl := gomock.NewController(t)
mockEC2 := NewMockEC2API(mockCtrl)
c := newCloud(mockEC2)

volumeId := "createsnapshot-ratelimit-test"

ec2snapshot := &ec2.Snapshot{
SnapshotId: aws.String("snapshot-id"),
VolumeId: aws.String(volumeId),
State: aws.String("completed"),
}

ctx := context.Background()
mockEC2.EXPECT().CreateSnapshotWithContext(gomock.Any(), gomock.Any()).Return(ec2snapshot, nil).Times(2)
mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{ec2snapshot}}, nil).AnyTimes()

_, err := c.CreateSnapshot(ctx, volumeId, &SnapshotOptions{Tags: map[string]string{}})
if err != nil {
t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err)
}
for i := 0; i < 10; i++ {
time.Sleep(2 * time.Second)
_, err := c.CreateSnapshot(ctx, volumeId, &SnapshotOptions{Tags: map[string]string{}})
if err != ErrSnapshotRateLimit {
t.Fatalf("CreateSnapshot() failed: expected ErrSnapshotRateLimit, got: %v", err)
}
}
time.Sleep(10 * time.Second)
_, err = c.CreateSnapshot(ctx, volumeId, &SnapshotOptions{Tags: map[string]string{}})
if err != nil {
t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err)
}

mockCtrl.Finish()
}

func TestEnableFastSnapshotRestores(t *testing.T) {
testCases := []struct {
name string
Expand Down

0 comments on commit 5e48e93

Please sign in to comment.