forked from hashicorp/vault
-
Notifications
You must be signed in to change notification settings - Fork 2
/
cors.go
154 lines (122 loc) · 3.56 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package vault
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
// Values stored in CORSConfig.Enabled to record whether CORS is active.
const (
	// CORSDisabled marks CORS as turned off.
	CORSDisabled uint32 = 0
	// CORSEnabled marks CORS as turned on.
	CORSEnabled uint32 = 1
)
// StdAllowedHeaders is the baseline set of request headers that Vault always
// permits on cross-origin requests once CORS is enabled. User-supplied
// headers are added on top of this list.
var StdAllowedHeaders = []string{
	// Generic HTTP headers.
	"Content-Type",
	"X-Requested-With",

	// Vault-specific request headers.
	"X-Vault-AWS-IAM-Server-ID",
	"X-Vault-MFA",
	"X-Vault-No-Request-Forwarding",
	"X-Vault-Token",
	"X-Vault-Wrap-Format",
	"X-Vault-Wrap-TTL",
	"X-Vault-Policy-Override",
}
// CORSConfig stores the state of the CORS configuration.
//
// In this file the embedded RWMutex guards AllowedOrigins and AllowedHeaders,
// while Enabled is read and written through sync/atomic. The exported fields
// are persisted as JSON; the unexported core back-reference is not.
type CORSConfig struct {
	sync.RWMutex `json:"-"`

	// core is the owning Core; Enable and Disable call its saveCORSConfig.
	core *Core

	// Enabled holds CORSEnabled or CORSDisabled.
	Enabled uint32 `json:"enabled"`

	// AllowedOrigins lists the origins permitted to make cross-origin
	// requests, or the single wildcard "*".
	AllowedOrigins []string `json:"allowed_origins,omitempty"`

	// AllowedHeaders lists the request headers permitted on cross-origin
	// requests.
	AllowedHeaders []string `json:"allowed_headers,omitempty"`
}
// saveCORSConfig writes a JSON snapshot of c.corsConfig to the "cors" key of
// the system barrier's "config/" sub-view.
func (c *Core) saveCORSConfig() error {
	configView := c.systemBarrierView.SubView("config/")

	snapshot := &CORSConfig{}
	snapshot.Enabled = atomic.LoadUint32(&c.corsConfig.Enabled)

	// Copy the list references under the read lock so the origins and
	// headers written out belong to the same configuration.
	c.corsConfig.RLock()
	snapshot.AllowedOrigins, snapshot.AllowedHeaders =
		c.corsConfig.AllowedOrigins, c.corsConfig.AllowedHeaders
	c.corsConfig.RUnlock()

	entry, encodeErr := logical.StorageEntryJSON("cors", snapshot)
	if encodeErr != nil {
		return fmt.Errorf("failed to create CORS config entry: %v", encodeErr)
	}

	if putErr := configView.Put(entry); putErr != nil {
		return fmt.Errorf("failed to save CORS config: %v", putErr)
	}
	return nil
}
// loadCORSConfig reads the persisted CORS configuration from the "cors" key of
// the system barrier's "config/" sub-view and installs it as c.corsConfig.
// If no entry has been stored, c.corsConfig is left unchanged.
//
// This should only be called with the core state lock held for writing, since
// it replaces the c.corsConfig pointer without any further synchronization.
func (c *Core) loadCORSConfig() error {
	view := c.systemBarrierView.SubView("config/")

	// Load the config in
	out, err := view.Get("cors")
	if err != nil {
		return fmt.Errorf("failed to read CORS config: %v", err)
	}
	if out == nil {
		// Nothing persisted yet; keep the existing config.
		return nil
	}

	newConfig := new(CORSConfig)
	if err := out.DecodeJSON(newConfig); err != nil {
		// Wrap with context, consistent with the read and save paths.
		return fmt.Errorf("failed to decode CORS config: %v", err)
	}

	// core is unexported and therefore not part of the stored JSON; restore
	// the back-reference so Enable/Disable can persist future changes.
	newConfig.core = c
	c.corsConfig = newConfig
	return nil
}
// Enable turns on CORS and persists the resulting configuration.
//
// urls must be non-empty and is either the single wildcard "*" or a list of
// origins allowed to make cross-origin requests to Vault. headers lists
// additional request headers to allow on top of StdAllowedHeaders.
//
// Each call replaces the previously configured origins and headers; repeated
// calls do not accumulate duplicate or stale entries.
func (c *CORSConfig) Enable(urls []string, headers []string) error {
	if len(urls) == 0 {
		return errors.New("at least one origin or the wildcard must be provided.")
	}
	if strutil.StrListContains(urls, "*") && len(urls) > 1 {
		return errors.New("to allow all origins the '*' must be the only value for allowed_origins")
	}

	// Copy the caller's origins so later mutation of urls cannot change the
	// stored configuration behind the lock.
	origins := make([]string, len(urls))
	copy(origins, urls)

	// Build the header list from scratch: start with the standard headers
	// Vault accepts, then add the user-supplied ones. Appending onto the
	// existing c.AllowedHeaders (as before) duplicated the standard headers
	// and kept stale custom headers on every re-enable.
	allowedHeaders := make([]string, 0, len(StdAllowedHeaders)+len(headers))
	allowedHeaders = append(allowedHeaders, StdAllowedHeaders...)
	allowedHeaders = append(allowedHeaders, headers...)

	c.Lock()
	c.AllowedOrigins = origins
	c.AllowedHeaders = allowedHeaders
	c.Unlock()

	atomic.StoreUint32(&c.Enabled, CORSEnabled)
	return c.core.saveCORSConfig()
}
// IsEnabled reports whether CORS is currently turned on, i.e. whether the
// Enabled field holds CORSEnabled. The field is read atomically.
func (c *CORSConfig) IsEnabled() bool {
	state := atomic.LoadUint32(&c.Enabled)
	return state == CORSEnabled
}
// Disable turns CORS off, drops every allowed origin and header, and persists
// the resulting configuration.
func (c *CORSConfig) Disable() error {
	// Flip the flag first so IsValidOrigin stops consulting the lists.
	atomic.StoreUint32(&c.Enabled, CORSDisabled)

	c.Lock()
	c.AllowedOrigins, c.AllowedHeaders = nil, nil
	c.Unlock()

	return c.core.saveCORSConfig()
}
// IsValidOrigin reports whether origin may make cross-origin requests under
// this configuration:
//   - CORS disabled: every origin is accepted;
//   - no allowed origins configured: every origin is rejected;
//   - allowed origins is exactly ["*"]: every origin is accepted;
//   - otherwise: origin must appear in AllowedOrigins.
func (c *CORSConfig) IsValidOrigin(origin string) bool {
	if !c.IsEnabled() {
		// CORS enforcement is off, so there is nothing to restrict.
		return true
	}

	c.RLock()
	defer c.RUnlock()

	switch allowed := c.AllowedOrigins; {
	case len(allowed) == 0:
		return false
	case len(allowed) == 1 && allowed[0] == "*":
		return true
	default:
		return strutil.StrListContains(allowed, origin)
	}
}