diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go
index a4db858162..b94e9dcba3 100644
--- a/pkg/cloud/cloud.go
+++ b/pkg/cloud/cloud.go
@@ -24,6 +24,7 @@ import (
 	"fmt"
 	"os"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
@@ -167,6 +168,11 @@ var (
 
 	// VolumeNotBeingModified is returned if volume being described is not being modified
 	VolumeNotBeingModified = fmt.Errorf("volume is not being modified")
+
+	// ErrSnapshotRateLimit is returned by CreateSnapshot when it refuses to send a CreateSnapshot AWS API request
+	// because one was already sent for the same volume within the rate-limit interval.
+	// Used to work around https://github.com/kubernetes-csi/external-snapshotter/issues/778
+	ErrSnapshotRateLimit = fmt.Errorf("Refusing to send CreateSnapshot to AWS (see https://github.com/kubernetes-sigs/aws-ebs-csi-driver/issues/1608#issuecomment-1554748900)")
 )
 
 // Set during build time via -ldflags
@@ -776,7 +782,52 @@ func (c *cloud) IsExistInstance(ctx context.Context, nodeID string) bool {
 	return true
 }
 
+var (
+	// createSnapshotMinInterval is the minimum time that must elapse between two
+	// CreateSnapshot requests for the same volume. It is a variable rather than a
+	// constant so that tests can shorten it.
+	createSnapshotMinInterval = 30 * time.Second
+
+	// lastCreateSnapshotMu guards lastCreateSnapshot.
+	lastCreateSnapshotMu sync.Mutex
+
+	// lastCreateSnapshot maps a volume ID to the time of the most recent
+	// CreateSnapshot request that was let through for that volume.
+	lastCreateSnapshot = make(map[string]time.Time)
+)
+
+// reserveCreateSnapshot reports whether a CreateSnapshot request for volumeID
+// may be sent now. When it returns true, the current time has been recorded as
+// that volume's most recent request. The check and the update happen under a
+// single lock so that concurrent callers cannot both pass for the same volume.
+func reserveCreateSnapshot(volumeID string) bool {
+	lastCreateSnapshotMu.Lock()
+	defer lastCreateSnapshotMu.Unlock()
+
+	now := time.Now()
+	// Drop entries that no longer limit anything so the map stays bounded by the
+	// number of volumes snapshotted within the last interval.
+	for id, last := range lastCreateSnapshot {
+		if now.Sub(last) >= createSnapshotMinInterval {
+			delete(lastCreateSnapshot, id)
+		}
+	}
+
+	// Every entry that survived the sweep is still inside its interval.
+	if _, limited := lastCreateSnapshot[volumeID]; limited {
+		return false
+	}
+	lastCreateSnapshot[volumeID] = now
+	return true
+}
+
 func (c *cloud) CreateSnapshot(ctx context.Context, volumeID string, snapshotOptions *SnapshotOptions) (snapshot *Snapshot, err error) {
+	// The request time is recorded before the AWS call is made, so a request
+	// that fails or times out still counts against the limit.
+	if !reserveCreateSnapshot(volumeID) {
+		return nil, ErrSnapshotRateLimit
+	}
+
 	descriptions := "Created by AWS EBS CSI driver for volume " + volumeID
 
 	var tags []*ec2.Tag
diff --git a/pkg/cloud/cloud_test.go b/pkg/cloud/cloud_test.go
index 00241cbc31..ad9c83bc3c 100644
--- a/pkg/cloud/cloud_test.go
+++ b/pkg/cloud/cloud_test.go
@@ -24,6 +24,7 @@ import (
 	"sort"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
@@ -1081,6 +1082,52 @@ func TestCreateSnapshot(t *testing.T) {
 	}
 }
 
+func TestCreateSnapshotRateLimit(t *testing.T) {
+	// Use an interval far longer than the test can run so that every repeated
+	// call below is rate limited without any sleeping; restore it afterwards.
+	origInterval := createSnapshotMinInterval
+	createSnapshotMinInterval = time.Hour
+	defer func() { createSnapshotMinInterval = origInterval }()
+
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+	mockEC2 := NewMockEC2API(mockCtrl)
+	c := newCloud(mockEC2)
+
+	volumeID := "createsnapshot-ratelimit-test"
+
+	ec2snapshot := &ec2.Snapshot{
+		SnapshotId: aws.String("snapshot-id"),
+		VolumeId:   aws.String(volumeID),
+		State:      aws.String("completed"),
+	}
+
+	ctx := context.Background()
+	// Only two requests may reach EC2: the first call and the one made once the
+	// limit no longer applies, hence Times(2).
+	mockEC2.EXPECT().CreateSnapshotWithContext(gomock.Any(), gomock.Any()).Return(ec2snapshot, nil).Times(2)
+	mockEC2.EXPECT().DescribeSnapshotsWithContext(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSnapshotsOutput{Snapshots: []*ec2.Snapshot{ec2snapshot}}, nil).AnyTimes()
+
+	_, err := c.CreateSnapshot(ctx, volumeID, &SnapshotOptions{Tags: map[string]string{}})
+	if err != nil {
+		t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err)
+	}
+	for i := 0; i < 10; i++ {
+		_, err := c.CreateSnapshot(ctx, volumeID, &SnapshotOptions{Tags: map[string]string{}})
+		if err != ErrSnapshotRateLimit {
+			t.Fatalf("CreateSnapshot() failed: expected ErrSnapshotRateLimit, got: %v", err)
+		}
+	}
+
+	// A zero interval makes the earlier request count as expired, standing in
+	// for the passage of time.
+	createSnapshotMinInterval = 0
+	_, err = c.CreateSnapshot(ctx, volumeID, &SnapshotOptions{Tags: map[string]string{}})
+	if err != nil {
+		t.Fatalf("CreateSnapshot() failed: expected no error, got: %v", err)
+	}
+}
+
 func TestEnableFastSnapshotRestores(t *testing.T) {
 	testCases := []struct {
 		name string