Skip to content

Commit be4b190

Browse files
authored
ctxkey: add type-safe context keys (#123)
1 parent ae7d17e commit be4b190

10 files changed

Lines changed: 299 additions & 52 deletions

File tree

.devtools/config.txtar

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
-- copyright/exclusions.json --
22
[
3+
"ctxkey/ctxkey.go",
4+
"ctxkey/ctxkey_test.go",
35
"web/internal/hashfs/*.go",
46
"web/internal/unionfs/*.go",
57
"rr/*.go",

cli/cli.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ import (
8282
"strings"
8383
"time"
8484

85+
"go.astrophena.name/base/ctxkey"
8586
"go.astrophena.name/base/logger"
8687
"go.astrophena.name/base/syncx"
8788
"go.astrophena.name/base/version"
@@ -179,24 +180,19 @@ func (f AppFunc) Run(ctx context.Context) error {
179180
return f(ctx)
180181
}
181182

182-
type ctxKey int
183-
184-
var envKey ctxKey
183+
var envKey = ctxkey.New[*Env]("cli.Env", nil)
185184

186185
// GetEnv retrieves the application's environment from a context.
187186
// If the context has no environment, it returns one based on the current OS.
188187
func GetEnv(ctx context.Context) *Env {
189-
e, ok := ctx.Value(envKey).(*Env)
190-
if !ok {
191-
return OSEnv()
188+
if e, ok := envKey.ValueOk(ctx); ok {
189+
return e
192190
}
193-
return e
191+
return OSEnv()
194192
}
195193

196194
// WithEnv returns a new context that carries the provided [Env].
197-
func WithEnv(ctx context.Context, e *Env) context.Context {
198-
return context.WithValue(ctx, envKey, e)
199-
}
195+
func WithEnv(ctx context.Context, e *Env) context.Context { return envKey.WithValue(ctx, e) }
200196

201197
// Env encapsulates the application's environment, including arguments,
202198
// standard I/O streams, and environment variables.

ctxkey/LICENSE

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
BSD 3-Clause License
2+
3+
Copyright (c) 2020 Tailscale Inc & contributors.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions are met:
7+
8+
1. Redistributions of source code must retain the above copyright notice, this
9+
list of conditions and the following disclaimer.
10+
11+
2. Redistributions in binary form must reproduce the above copyright notice,
12+
this list of conditions and the following disclaimer in the documentation
13+
and/or other materials provided with the distribution.
14+
15+
3. Neither the name of the copyright holder nor the names of its
16+
contributors may be used to endorse or promote products derived from
17+
this software without specific prior written permission.
18+
19+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

ctxkey/ctxkey.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Copyright (c) Tailscale Inc & contributors
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
// ctxkey provides type-safe key-value pairs for use with [context.Context].
5+
//
6+
// Example usage:
7+
//
8+
// // Create a context key.
9+
// var TimeoutKey = ctxkey.New("mapreduce.Timeout", 5*time.Second)
10+
//
11+
// // Store a context value.
12+
// ctx = mapreduce.TimeoutKey.WithValue(ctx, 10*time.Second)
13+
//
14+
// // Load a context value.
15+
// timeout := mapreduce.TimeoutKey.Value(ctx)
16+
// ... // use timeout of type time.Duration
17+
//
18+
// This is inspired by https://go.dev/issue/49189.
19+
package ctxkey
20+
21+
import (
22+
"context"
23+
"fmt"
24+
"reflect"
25+
)
26+
27+
// Key is a generic key type associated with a specific value type.
28+
//
29+
// A zero Key is valid where the Value type itself is used as the context key.
30+
// This pattern should only be used with locally declared Go types,
31+
// otherwise different packages risk producing key conflicts.
32+
//
33+
// Example usage:
34+
//
35+
// type peerInfo struct { ... } // peerInfo is a locally declared type
36+
// var peerInfoKey ctxkey.Key[peerInfo]
37+
// ctx = peerInfoKey.WithValue(ctx, info) // store a context value
38+
// info = peerInfoKey.Value(ctx) // load a context value
39+
type Key[Value any] struct {
40+
name *stringer[string]
41+
defVal *Value
42+
}
43+
44+
// New constructs a new context key with an associated value type
45+
// where the default value for an unpopulated value is the provided value.
46+
//
47+
// The provided name is an arbitrary name only used for human debugging.
48+
// As a convention, it is recommended that the name be the dot-delimited
49+
// combination of the package name of the caller with the variable name.
50+
// If the name is not provided, then the name of the Value type is used.
51+
// Every key is unique, even if provided the same name.
52+
//
53+
// Example usage:
54+
//
55+
// package mapreduce
56+
// var NumWorkersKey = ctxkey.New("mapreduce.NumWorkers", runtime.NumCPU())
57+
func New[Value any](name string, defaultValue Value) Key[Value] {
58+
// Allocate a new stringer to ensure that every invocation of New
59+
// creates a universally unique context key even for the same name
60+
// since newly allocated pointers are globally unique within a process.
61+
key := Key[Value]{name: new(stringer[string])}
62+
if name == "" {
63+
name = reflect.TypeFor[Value]().String()
64+
}
65+
key.name.v = name
66+
if v := reflect.ValueOf(defaultValue); v.IsValid() && !v.IsZero() {
67+
key.defVal = &defaultValue
68+
}
69+
return key
70+
}
71+
72+
// contextKey returns the context key to use.
73+
func (key Key[Value]) contextKey() any {
74+
if key.name == nil {
75+
// Use the reflect.Type of the Value (implies key not created by New).
76+
return reflect.TypeFor[Value]()
77+
} else {
78+
// Use the name pointer directly (implies key created by New).
79+
return key.name
80+
}
81+
}
82+
83+
// WithValue returns a copy of parent in which the value associated with key is val.
84+
//
85+
// It is a type-safe equivalent of [context.WithValue].
86+
func (key Key[Value]) WithValue(parent context.Context, val Value) context.Context {
87+
return context.WithValue(parent, key.contextKey(), stringer[Value]{val})
88+
}
89+
90+
// ValueOk returns the value in the context associated with this key
91+
// and also reports whether it was present.
92+
// If the value is not present, it returns the default value.
93+
func (key Key[Value]) ValueOk(ctx context.Context) (v Value, ok bool) {
94+
vv, ok := ctx.Value(key.contextKey()).(stringer[Value])
95+
if !ok && key.defVal != nil {
96+
vv.v = *key.defVal
97+
}
98+
return vv.v, ok
99+
}
100+
101+
// Value returns the value in the context associated with this key.
102+
// If the value is not present, it returns the default value.
103+
func (key Key[Value]) Value(ctx context.Context) (v Value) {
104+
v, _ = key.ValueOk(ctx)
105+
return v
106+
}
107+
108+
// Has reports whether the context has a value for this key.
109+
func (key Key[Value]) Has(ctx context.Context) (ok bool) {
110+
_, ok = key.ValueOk(ctx)
111+
return ok
112+
}
113+
114+
// String returns the name of the key.
115+
func (key Key[Value]) String() string {
116+
if key.name == nil {
117+
return reflect.TypeFor[Value]().String()
118+
}
119+
return key.name.String()
120+
}
121+
122+
// stringer implements [fmt.Stringer] on a generic T.
123+
//
124+
// This assists in debugging such that printing a context prints key and value.
125+
// Note that the [context] package lacks a dependency on [reflect],
126+
// so it cannot print arbitrary values. By implementing [fmt.Stringer],
127+
// we functionally teach a context how to print itself.
128+
//
129+
// Wrapping values within a struct has an added bonus that interface kinds
130+
// are properly handled. Without wrapping, we would be unable to distinguish
131+
// between a nil value that was explicitly set or not.
132+
// However, the presence of a stringer indicates an explicit nil value.
133+
type stringer[T any] struct{ v T }
134+
135+
func (v stringer[T]) String() string { return fmt.Sprint(v.v) }

ctxkey/ctxkey_test.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) Tailscale Inc & contributors
2+
// SPDX-License-Identifier: BSD-3-Clause
3+
4+
package ctxkey
5+
6+
import (
7+
"fmt"
8+
"io"
9+
"regexp"
10+
"testing"
11+
"time"
12+
13+
"go.astrophena.name/base/testutil"
14+
)
15+
16+
func TestKey(t *testing.T) {
17+
ctx := t.Context()
18+
19+
// Test keys with the same name as being distinct.
20+
k1 := New("same.Name", "")
21+
testutil.AssertEqual(t, k1.String(), "same.Name")
22+
k2 := New("same.Name", "")
23+
testutil.AssertEqual(t, k2.String(), "same.Name")
24+
testutil.AssertEqual(t, k1 == k2, false)
25+
ctx = k1.WithValue(ctx, "hello")
26+
testutil.AssertEqual(t, k1.Has(ctx), true)
27+
testutil.AssertEqual(t, k1.Value(ctx), "hello")
28+
testutil.AssertEqual(t, k2.Has(ctx), false)
29+
testutil.AssertEqual(t, k2.Value(ctx), "")
30+
ctx = k2.WithValue(ctx, "goodbye")
31+
testutil.AssertEqual(t, k1.Has(ctx), true)
32+
testutil.AssertEqual(t, k1.Value(ctx), "hello")
33+
testutil.AssertEqual(t, k2.Has(ctx), true)
34+
testutil.AssertEqual(t, k2.Value(ctx), "goodbye")
35+
36+
// Test default value.
37+
k3 := New("mapreduce.Timeout", time.Hour)
38+
testutil.AssertEqual(t, k3.Has(ctx), false)
39+
testutil.AssertEqual(t, k3.Value(ctx), time.Hour)
40+
ctx = k3.WithValue(ctx, time.Minute)
41+
testutil.AssertEqual(t, k3.Has(ctx), true)
42+
testutil.AssertEqual(t, k3.Value(ctx), time.Minute)
43+
44+
// Test incomparable value.
45+
k4 := New("slice", []int(nil))
46+
testutil.AssertEqual(t, k4.Has(ctx), false)
47+
testutil.AssertEqual(t, k4.Value(ctx), []int(nil))
48+
ctx = k4.WithValue(ctx, []int{1, 2, 3})
49+
testutil.AssertEqual(t, k4.Has(ctx), true)
50+
testutil.AssertEqual(t, k4.Value(ctx), []int{1, 2, 3})
51+
52+
// Accessors should be allocation free.
53+
testutil.AssertEqual(t, testing.AllocsPerRun(100, func() {
54+
k1.Value(ctx)
55+
k1.Has(ctx)
56+
k1.ValueOk(ctx)
57+
}), 0.0)
58+
59+
// Test keys that are created without New.
60+
var k5 Key[string]
61+
testutil.AssertEqual(t, k5.String(), "string")
62+
testutil.AssertEqual(t, k1 == k5, false) // should be different from key created by New
63+
testutil.AssertEqual(t, k5.Has(ctx), false)
64+
ctx = k5.WithValue(ctx, "fizz")
65+
testutil.AssertEqual(t, k5.Value(ctx), "fizz")
66+
var k6 Key[string]
67+
testutil.AssertEqual(t, k6.String(), "string")
68+
testutil.AssertEqual(t, k5 == k6, true)
69+
testutil.AssertEqual(t, k6.Has(ctx), true)
70+
ctx = k6.WithValue(ctx, "fizz")
71+
72+
// Test interface value types.
73+
var k7 Key[any]
74+
testutil.AssertEqual(t, k7.Has(ctx), false)
75+
ctx = k7.WithValue(ctx, "whatever")
76+
testutil.AssertEqual(t, k7.Value(ctx), "whatever")
77+
ctx = k7.WithValue(ctx, []int{1, 2, 3})
78+
testutil.AssertEqual(t, k7.Value(ctx), []int{1, 2, 3})
79+
ctx = k7.WithValue(ctx, nil)
80+
testutil.AssertEqual(t, k7.Has(ctx), true)
81+
testutil.AssertEqual(t, k7.Value(ctx), nil)
82+
k8 := New[error]("error", io.EOF)
83+
testutil.AssertEqual(t, k8.Has(ctx), false)
84+
testutil.AssertEqual(t, k8.Value(ctx), io.EOF)
85+
ctx = k8.WithValue(ctx, nil)
86+
testutil.AssertEqual(t, k8.Value(ctx), nil)
87+
testutil.AssertEqual(t, k8.Has(ctx), true)
88+
err := fmt.Errorf("read error: %w", io.ErrUnexpectedEOF)
89+
ctx = k8.WithValue(ctx, err)
90+
testutil.AssertEqual(t, k8.Value(ctx), err)
91+
testutil.AssertEqual(t, k8.Has(ctx), true)
92+
}
93+
94+
func TestStringer(t *testing.T) {
95+
ctx := t.Context()
96+
assertMatches(t, fmt.Sprint(New("foo.Bar", "").WithValue(ctx, "baz")), regexp.MustCompile("foo.Bar.*baz"))
97+
assertMatches(t, fmt.Sprint(New("", []int{}).WithValue(ctx, []int{1, 2, 3})), regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", []int{1, 2, 3})))
98+
assertMatches(t, fmt.Sprint(New("", 0).WithValue(ctx, 5)), regexp.MustCompile("int.*5"))
99+
assertMatches(t, fmt.Sprint(Key[time.Duration]{}.WithValue(ctx, time.Hour)), regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", time.Hour)))
100+
}
101+
102+
func assertMatches(t *testing.T, got string, re *regexp.Regexp) {
103+
t.Helper()
104+
if !re.MatchString(got) {
105+
t.Fatalf("value does not match regexp:\ngot: %q\nregexp: %q", got, re)
106+
}
107+
}

logger/logger.go

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ import (
1010
"io"
1111
"log/slog"
1212
"sync"
13-
)
1413

15-
type ctxKey string
14+
"go.astrophena.name/base/ctxkey"
15+
)
1616

17-
const loggerKey ctxKey = "logger"
17+
var loggerKey = ctxkey.New("logger.Logger", defaultLogger)
1818

1919
// multiHandler fans out log records to multiple handlers.
2020
type multiHandler struct {
@@ -135,20 +135,13 @@ func newDefaultLogger() *Logger {
135135
}
136136

137137
// Put returns a new context with the provided [Logger].
138-
func Put(ctx context.Context, l *Logger) context.Context {
139-
return context.WithValue(ctx, loggerKey, l)
140-
}
138+
func Put(ctx context.Context, l *Logger) context.Context { return loggerKey.WithValue(ctx, l) }
141139

142140
// Get retrieves the [Logger] from the context.
143141
//
144142
// If the context has no [Logger], it returns a default [Logger] that discards all
145143
// messages.
146-
func Get(ctx context.Context) *Logger {
147-
if l, ok := ctx.Value(loggerKey).(*Logger); ok {
148-
return l
149-
}
150-
return defaultLogger
151-
}
144+
func Get(ctx context.Context) *Logger { return loggerKey.Value(ctx) }
152145

153146
// IsDefault returns true if l is the default [Logger].
154147
func IsDefault(l *Logger) bool { return l == defaultLogger }
@@ -157,12 +150,7 @@ func IsDefault(l *Logger) bool { return l == defaultLogger }
157150
//
158151
// If the context has no [Logger], it returns a [slog.LevelVar] for a default
159152
// [Logger].
160-
func LevelVar(ctx context.Context) *slog.LevelVar {
161-
if l, ok := ctx.Value(loggerKey).(*Logger); ok {
162-
return l.Level
163-
}
164-
return defaultLogger.Level
165-
}
153+
func LevelVar(ctx context.Context) *slog.LevelVar { return loggerKey.Value(ctx).Level }
166154

167155
// Debug logs a debug message.
168156
func Debug(ctx context.Context, msg string, attrs ...slog.Attr) {

web/debug.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (s *Server) xffDebugHandler() http.Handler {
180180
parts = append(parts, item)
181181
}
182182

183-
connNetwork, _ := r.Context().Value(connNetworkContextKey).(string)
183+
connNetwork := connNetworkContextKey.Value(r.Context())
184184
resp := xffDebugResponse{
185185
RemoteAddr: r.RemoteAddr,
186186
RemoteHost: host,

0 commit comments

Comments
 (0)