forked from hashicorp/vault
-
Notifications
You must be signed in to change notification settings - Fork 2
/
cors.go
127 lines (101 loc) · 2.84 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package vault
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
)
// Values stored in CORSConfig.Enabled. They are typed uint32 so they can be
// passed directly to the sync/atomic load and store functions.
const (
	CORSDisabled uint32 = 0 // CORS is off; stored by Disable
	CORSEnabled  uint32 = 1 // CORS is on; stored by Enable
)
// CORSConfig stores the state of the CORS configuration.
//
// Enabled holds CORSDisabled or CORSEnabled and is read and written with
// sync/atomic, so IsEnabled can check it without taking the lock.
// AllowedOrigins is guarded by the embedded RWMutex. saveCORSConfig persists
// this struct as JSON; the mutex (tagged "-") and the unexported core pointer
// are not serialized.
type CORSConfig struct {
	// RWMutex guards AllowedOrigins; excluded from the JSON encoding.
	sync.RWMutex `json:"-"`
	// core is the owning Core, used by Enable/Disable to persist changes.
	core *Core
	// Enabled is CORSEnabled or CORSDisabled; access it via sync/atomic.
	Enabled uint32 `json:"enabled"`
	// AllowedOrigins lists the origins accepted by IsValidOrigin when CORS is
	// enabled; a single "*" entry accepts any origin.
	AllowedOrigins []string `json:"allowed_origins,omitempty"`
}
// saveCORSConfig persists the current CORS configuration as the "cors" entry
// of the system barrier view's "config/" sub-view.
//
// The enabled flag and the allowed origins are both read inside a single
// read-locked section. Writers that update the two fields under the write
// lock therefore cannot have their changes split across the stored entry.
func (c *Core) saveCORSConfig() error {
	view := c.systemBarrierView.SubView("config/")

	// Take one consistent snapshot of the live config. The origins slice is
	// only ever replaced, never mutated in place, so sharing it is safe.
	c.corsConfig.RLock()
	localConfig := &CORSConfig{
		Enabled:        atomic.LoadUint32(&c.corsConfig.Enabled),
		AllowedOrigins: c.corsConfig.AllowedOrigins,
	}
	c.corsConfig.RUnlock()

	entry, err := logical.StorageEntryJSON("cors", localConfig)
	if err != nil {
		return fmt.Errorf("failed to create CORS config entry: %v", err)
	}

	if err := view.Put(entry); err != nil {
		return fmt.Errorf("failed to save CORS config: %v", err)
	}

	return nil
}
// loadCORSConfig reads the persisted "cors" entry from the system barrier
// view's "config/" sub-view and installs it as c.corsConfig. If no entry has
// been stored, c.corsConfig is left unchanged.
//
// This should only be called with the core state lock held for writing.
func (c *Core) loadCORSConfig() error {
	view := c.systemBarrierView.SubView("config/")

	// Load the config in
	out, err := view.Get("cors")
	if err != nil {
		return fmt.Errorf("failed to read CORS config: %v", err)
	}
	if out == nil {
		// Nothing persisted yet; keep whatever config is already in place.
		return nil
	}

	newConfig := new(CORSConfig)
	if err := out.DecodeJSON(newConfig); err != nil {
		return fmt.Errorf("failed to decode CORS config: %v", err)
	}

	// core is not serialized, so re-attach it so Enable/Disable can persist.
	newConfig.core = c
	c.corsConfig = newConfig

	return nil
}
// Enable turns CORS on and restricts cross-origin requests to the given
// origins, then persists the change via the owning Core.
//
// urls must be non-empty. To allow every origin it must be exactly
// []string{"*"}; mixing "*" with other entries is rejected.
//
// The slice is copied, so later changes to the caller's slice do not affect
// the stored configuration. The enabled flag is updated under the same write
// lock as AllowedOrigins, so the two change together.
func (c *CORSConfig) Enable(urls []string) error {
	if len(urls) == 0 {
		return errors.New("the list of allowed origins cannot be empty")
	}
	if strutil.StrListContains(urls, "*") && len(urls) > 1 {
		return errors.New("to allow all origins the '*' must be the only value for allowed_origins")
	}

	origins := make([]string, len(urls))
	copy(origins, urls)

	c.Lock()
	c.AllowedOrigins = origins
	// Still stored atomically because IsEnabled reads it without the lock.
	atomic.StoreUint32(&c.Enabled, CORSEnabled)
	c.Unlock()

	// The lock must be released first: saveCORSConfig takes the read lock.
	return c.core.saveCORSConfig()
}
// IsEnabled reports whether the Enabled flag currently holds CORSEnabled.
// The flag is loaded atomically, so no lock is taken.
func (c *CORSConfig) IsEnabled() bool {
	state := atomic.LoadUint32(&c.Enabled)
	return state == CORSEnabled
}
// Disable turns CORS off and clears the allowed origins, then persists the
// change via the owning Core. The enabled flag is updated under the same
// write lock as AllowedOrigins, so the two change together.
func (c *CORSConfig) Disable() error {
	c.Lock()
	// Still stored atomically because IsEnabled reads it without the lock.
	atomic.StoreUint32(&c.Enabled, CORSDisabled)
	c.AllowedOrigins = nil
	c.Unlock()

	// The lock must be released first: saveCORSConfig takes the read lock.
	return c.core.saveCORSConfig()
}
// IsValidOrigin reports whether a request from origin may make cross-origin
// requests under this configuration.
//
// If CORS is not enabled, every origin is accepted. Otherwise:
//   - an empty allow-list rejects everything;
//   - a list consisting solely of "*" accepts everything;
//   - any other list must contain origin.
func (c *CORSConfig) IsValidOrigin(origin string) bool {
	if atomic.LoadUint32(&c.Enabled) != CORSEnabled {
		// CORS is off, so there is nothing to restrict.
		return true
	}

	c.RLock()
	defer c.RUnlock()

	allowed := c.AllowedOrigins
	switch {
	case len(allowed) == 0:
		return false
	case len(allowed) == 1 && allowed[0] == "*":
		return true
	default:
		return strutil.StrListContains(allowed, origin)
	}
}