Skip to content

Commit fefcfd9

Browse files
committed
feat: support IPv6 CIDR for network restrictions
1 parent a5a5241 commit fefcfd9

File tree

4 files changed

+146
-66
lines changed

4 files changed

+146
-66
lines changed

cmd/restrictions.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cmd
22

33
import (
4-
"github.com/spf13/afero"
54
"github.com/spf13/cobra"
65
"github.com/supabase/cli/internal/restrictions/get"
76
"github.com/supabase/cli/internal/restrictions/update"
@@ -22,15 +21,15 @@ var (
2221
Use: "update",
2322
Short: "Update network restrictions",
2423
RunE: func(cmd *cobra.Command, args []string) error {
25-
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks, afero.NewOsFs())
24+
return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks)
2625
},
2726
}
2827

2928
restrictionsGetCmd = &cobra.Command{
3029
Use: "get",
3130
Short: "Get the current network restrictions",
3231
RunE: func(cmd *cobra.Command, args []string) error {
33-
return get.Run(cmd.Context(), flags.ProjectRef, afero.NewOsFs())
32+
return get.Run(cmd.Context(), flags.ProjectRef)
3433
},
3534
}
3635
)
@@ -41,6 +40,5 @@ func init() {
4140
restrictionsUpdateCmd.Flags().BoolVar(&bypassCidrChecks, "bypass-cidr-checks", false, "Bypass some of the CIDR validation checks.")
4241
restrictionsCmd.AddCommand(restrictionsGetCmd)
4342
restrictionsCmd.AddCommand(restrictionsUpdateCmd)
44-
4543
rootCmd.AddCommand(restrictionsCmd)
4644
}

internal/restrictions/get/get.go

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,20 @@ import (
55
"fmt"
66

77
"github.com/go-errors/errors"
8-
"github.com/spf13/afero"
98
"github.com/supabase/cli/internal/utils"
109
)
1110

12-
func Run(ctx context.Context, projectRef string, fsys afero.Fs) error {
13-
// 1. Sanity checks.
14-
// 2. get network restrictions
15-
{
16-
resp, err := utils.GetSupabase().GetNetworkRestrictionsWithResponse(ctx, projectRef)
17-
if err != nil {
18-
return errors.Errorf("failed to retrieve network restrictions: %w", err)
19-
}
20-
if resp.JSON200 == nil {
21-
return errors.New("failed to retrieve network restrictions; received: " + string(resp.Body))
22-
}
23-
24-
fmt.Printf("DB Allowed CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrs)
25-
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == "applied")
26-
return nil
11+
func Run(ctx context.Context, projectRef string) error {
12+
resp, err := utils.GetSupabase().GetNetworkRestrictionsWithResponse(ctx, projectRef)
13+
if err != nil {
14+
return errors.Errorf("failed to retrieve network restrictions: %w", err)
15+
}
16+
if resp.JSON200 == nil {
17+
return errors.New("failed to retrieve network restrictions; received: " + string(resp.Body))
2718
}
19+
20+
fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrs)
21+
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrsV6)
22+
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == "applied")
23+
return nil
2824
}

internal/restrictions/update/update.go

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,49 +6,42 @@ import (
66
"net"
77

88
"github.com/go-errors/errors"
9-
"github.com/spf13/afero"
109
"github.com/supabase/cli/internal/utils"
1110
"github.com/supabase/cli/pkg/api"
1211
)
1312

14-
func validateCidrs(cidrs []string, bypassChecks bool) error {
15-
for _, cidr := range cidrs {
13+
func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool) error {
14+
// 1. separate CIDR to v4 and v6
15+
body := api.ApplyNetworkRestrictionsJSONRequestBody{
16+
DbAllowedCidrs: &[]string{},
17+
DbAllowedCidrsV6: &[]string{},
18+
}
19+
for _, cidr := range dbCidrsToAllow {
1620
ip, _, err := net.ParseCIDR(cidr)
1721
if err != nil {
1822
return errors.Errorf("failed to parse IP: %s", cidr)
1923
}
20-
if ip.IsPrivate() && !bypassChecks {
24+
if ip.IsPrivate() && !bypassCidrChecks {
2125
return errors.Errorf("private IP provided: %s", cidr)
2226
}
23-
if ip.To4() == nil {
24-
return errors.Errorf("only IPv4 supported at the moment: %s", cidr)
25-
}
26-
}
27-
return nil
28-
}
29-
30-
func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool, fsys afero.Fs) error {
31-
// 1. sanity checks
32-
{
33-
err := validateCidrs(dbCidrsToAllow, bypassCidrChecks)
34-
if err != nil {
35-
return err
27+
if ip.To4() != nil {
28+
*body.DbAllowedCidrs = append(*body.DbAllowedCidrs, cidr)
29+
} else {
30+
*body.DbAllowedCidrsV6 = append(*body.DbAllowedCidrsV6, cidr)
3631
}
3732
}
3833

3934
// 2. update restrictions
40-
{
41-
resp, err := utils.GetSupabase().ApplyNetworkRestrictionsWithResponse(ctx, projectRef, api.ApplyNetworkRestrictionsJSONRequestBody{
42-
DbAllowedCidrs: dbCidrsToAllow,
43-
})
44-
if err != nil {
45-
return errors.Errorf("failed to apply network restrictions: %w", err)
46-
}
47-
if resp.JSON201 == nil {
48-
return errors.New("failed to update network restrictions: " + string(resp.Body))
49-
}
50-
fmt.Printf("DB Allowed CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrs)
51-
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == "applied")
52-
return nil
35+
resp, err := utils.GetSupabase().ApplyNetworkRestrictionsWithResponse(ctx, projectRef, body)
36+
if err != nil {
37+
return errors.Errorf("failed to apply network restrictions: %w", err)
5338
}
39+
if resp.JSON201 == nil {
40+
return errors.New("failed to apply network restrictions: " + string(resp.Body))
41+
}
42+
43+
fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrs)
44+
fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrsV6)
45+
fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == "applied")
46+
return nil
5447
}
Lines changed: 109 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,121 @@
11
package update
22

33
import (
4+
"context"
5+
"net/http"
46
"testing"
57

8+
"github.com/go-errors/errors"
69
"github.com/stretchr/testify/assert"
10+
"github.com/supabase/cli/internal/testing/apitest"
11+
"github.com/supabase/cli/internal/utils"
12+
"github.com/supabase/cli/pkg/api"
13+
"gopkg.in/h2non/gock.v1"
714
)
815

9-
func TestPrivateSubnet(t *testing.T) {
10-
err := validateCidrs([]string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
11-
assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
12-
err = validateCidrs([]string{"10.0.0.0/8"}, true)
13-
assert.Nil(t, err, "should bypass private subnet checks")
14-
}
16+
func TestUpdateRestrictionsCommand(t *testing.T) {
17+
projectRef := apitest.RandomProjectRef()
18+
// Setup valid access token
19+
token := apitest.RandomAccessToken(t)
20+
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
21+
22+
t.Run("updates v4 and v6 CIDR", func(t *testing.T) {
23+
// Setup mock api
24+
defer gock.OffAll()
25+
gock.New(utils.DefaultApiHost).
26+
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
27+
MatchType("json").
28+
JSON(api.NetworkRestrictionsRequest{
29+
DbAllowedCidrs: &[]string{"12.3.4.5/32", "1.2.3.1/24"},
30+
DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"},
31+
}).
32+
Reply(http.StatusCreated).
33+
JSON(api.NetworkRestrictionsResponse{
34+
Status: api.NetworkRestrictionsResponseStatus("applied"),
35+
})
36+
// Run test
37+
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
38+
// Check error
39+
assert.NoError(t, err)
40+
assert.Empty(t, apitest.ListUnmatchedRequests())
41+
})
42+
43+
t.Run("throws error on network failure", func(t *testing.T) {
44+
errNetwork := errors.New("network error")
45+
// Setup mock api
46+
defer gock.OffAll()
47+
gock.New(utils.DefaultApiHost).
48+
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
49+
MatchType("json").
50+
JSON(api.NetworkRestrictionsRequest{
51+
DbAllowedCidrs: &[]string{},
52+
DbAllowedCidrsV6: &[]string{},
53+
}).
54+
ReplyError(errNetwork)
55+
// Run test
56+
err := Run(context.Background(), projectRef, []string{}, true)
57+
// Check error
58+
assert.ErrorIs(t, err, errNetwork)
59+
assert.Empty(t, apitest.ListUnmatchedRequests())
60+
})
1561

16-
func TestIpv4(t *testing.T) {
17-
err := validateCidrs([]string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
18-
assert.ErrorContains(t, err, "only IPv4 supported at the moment: 2001:db8:abcd:0012::0/64")
19-
err = validateCidrs([]string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, true)
20-
assert.ErrorContains(t, err, "only IPv4 supported at the moment: 2001:db8:abcd:0012::0/64")
62+
t.Run("throws error on server unavailable", func(t *testing.T) {
63+
// Setup mock api
64+
defer gock.OffAll()
65+
gock.New(utils.DefaultApiHost).
66+
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
67+
MatchType("json").
68+
JSON(api.NetworkRestrictionsRequest{
69+
DbAllowedCidrs: &[]string{},
70+
DbAllowedCidrsV6: &[]string{},
71+
}).
72+
Reply(http.StatusServiceUnavailable)
73+
// Run test
74+
err := Run(context.Background(), projectRef, []string{}, true)
75+
// Check error
76+
assert.ErrorContains(t, err, "failed to apply network restrictions:")
77+
assert.Empty(t, apitest.ListUnmatchedRequests())
78+
})
2179
}
2280

23-
func TestInvalidSubnets(t *testing.T) {
24-
err := validateCidrs([]string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
25-
assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
26-
err = validateCidrs([]string{"100/36"}, true)
27-
assert.ErrorContains(t, err, "failed to parse IP: 100/36")
81+
func TestValidateCIDR(t *testing.T) {
82+
projectRef := apitest.RandomProjectRef()
83+
// Setup valid access token
84+
token := apitest.RandomAccessToken(t)
85+
t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
86+
87+
t.Run("bypasses private subnet checks", func(t *testing.T) {
88+
// Setup mock api
89+
defer gock.OffAll()
90+
gock.New(utils.DefaultApiHost).
91+
Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
92+
MatchType("json").
93+
JSON(api.NetworkRestrictionsRequest{
94+
DbAllowedCidrs: &[]string{"10.0.0.0/8"},
95+
DbAllowedCidrsV6: &[]string{},
96+
}).
97+
Reply(http.StatusCreated).
98+
JSON(api.NetworkRestrictionsResponse{
99+
Status: api.NetworkRestrictionsResponseStatus("applied"),
100+
})
101+
// Run test
102+
err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true)
103+
// Check error
104+
assert.NoError(t, err)
105+
assert.Empty(t, apitest.ListUnmatchedRequests())
106+
})
107+
108+
t.Run("throws error on private subnet", func(t *testing.T) {
109+
// Run test
110+
err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
111+
// Check error
112+
assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
113+
})
114+
115+
t.Run("throws error on invalid subnet", func(t *testing.T) {
116+
// Run test
117+
err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
118+
// Check error
119+
assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
120+
})
28121
}

0 commit comments

Comments
 (0)