From 2806aa7b5915fd3cb64ec612f17a9883ea799215 Mon Sep 17 00:00:00 2001 From: wangjunwei Date: Fri, 22 May 2026 10:52:20 +0800 Subject: [PATCH] feat(handler): add configurable direct responses --- README.md | 6 ++ auth/auth.go | 39 +++++++++++- auth/rejectstatic.go | 4 ++ auth/response_test.go | 37 ++++++++++++ handler/config.go | 6 ++ handler/direct.go | 11 ++++ handler/direct_test.go | 134 +++++++++++++++++++++++++++++++++++++++++ handler/handler.go | 56 +++++++++++------ main.go | 36 +++++++++-- 9 files changed, 305 insertions(+), 24 deletions(-) create mode 100644 auth/response_test.go create mode 100644 handler/direct.go create mode 100644 handler/direct_test.go diff --git a/README.md b/README.md index e1646e3..e635f0b 100644 --- a/README.md +++ b/README.md @@ -308,6 +308,8 @@ Authentication parameters are passed as URI via `-auth` parameter. Scheme of URI * `else` - optional URL specifying the next auth provider to chain to, if authentication failed. * `lookup` - optional URL specifying another auth provider queried for session validity (typically `basicfile` or some Redis-backed password auth). Queries to this lookup provider ask for validity of session providing hexadecimal session ID as username and empty string as password. +`static` can also be used with `-direct-response` to respond to direct non-proxy HTTP requests. It accepts the same `code`, `body`, and `headers` parameters as `reject-static`. `reject-http`, `reject-https`, and `reject-static` can also be used with `-access-reject` to respond to requests denied by access filters. + ## Scripting With the dumbproxy, it is possible to modify request processing behaviour using simple scripts written in the JavaScript programming language. @@ -537,6 +539,8 @@ Configuration format is [RFC 4180](https://www.rfc-editor.org/rfc/rfc4180.html) ``` $ ~/go/bin/dumbproxy -h Usage of /home/user/go/bin/dumbproxy: + -access-reject string + reject response parameters for requests denied by access filters -auth string auth parameters (default "none://") -autocert @@ -589,6 +593,8 @@ Usage of /home/user/go/bin/dumbproxy: colon-separated list of enabled key exchange curves -deny-dst-addr value comma-separated list of CIDR prefixes of forbidden IP addresses (default 127.0.0.0/8, 0.0.0.0/32, 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 169.254.0.0/16, ::1/128, ::/128, fe80::/10) + -direct-response string + response parameters for direct HTTP requests -disable-http2 disable HTTP2 -dns-cache-neg-ttl duration diff --git a/auth/auth.go b/auth/auth.go index cedceb9..60354d2 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -38,12 +38,47 @@ func NewAuth(paramstr string, logger *clog.CondLogger) (Auth, error) { case "none": return NoAuth{}, nil case "reject-http", "reject-https": - return NewRejectHTTPAuth(url, logger) + return newRejectAuthFromURL(url, logger) case "reject-static": - return NewStaticRejectAuth(url, logger) + return newRejectAuthFromURL(url, logger) case "tlscookie": return NewTLSCookieAuth(url, logger) default: return nil, errors.New("Unknown auth scheme") } } + +// NewRejectAuth constructs an auth provider which always responds and rejects. +func NewRejectAuth(paramstr string, logger *clog.CondLogger) (Auth, error) { + url, err := url.Parse(paramstr) + if err != nil { + return nil, err + } + + return newRejectAuthFromURL(url, logger) +} + +func newRejectAuthFromURL(url *url.URL, logger *clog.CondLogger) (Auth, error) { + switch strings.ToLower(url.Scheme) { + case "reject-http", "reject-https": + return NewRejectHTTPAuth(url, logger) + case "reject-static": + return NewStaticRejectAuth(url, logger) + default: + return nil, errors.New("Unknown reject scheme") + } +} + +func NewResponse(paramstr string, logger *clog.CondLogger) (Auth, error) { + url, err := url.Parse(paramstr) + if err != nil { + return nil, err + } + + switch strings.ToLower(url.Scheme) { + case "static": + return NewStaticResponse(url, logger) + default: + return nil, errors.New("Unknown response scheme") + } +} diff --git a/auth/rejectstatic.go b/auth/rejectstatic.go index d052bb5..bf81556 100644 --- a/auth/rejectstatic.go +++ b/auth/rejectstatic.go @@ -22,6 +22,10 @@ type StaticRejectAuth struct { } func NewStaticRejectAuth(u *url.URL, logger *clog.CondLogger) (*StaticRejectAuth, error) { + return NewStaticResponse(u, logger) +} + +func NewStaticResponse(u *url.URL, logger *clog.CondLogger) (*StaticRejectAuth, error) { values, err := url.ParseQuery(u.RawQuery) if err != nil { return nil, err diff --git a/auth/response_test.go b/auth/response_test.go new file mode 100644 index 0000000..59efb50 --- /dev/null +++ b/auth/response_test.go @@ -0,0 +1,37 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewResponseStatic(t *testing.T) { + response, err := NewResponse("static://?code=204", nil) + if err != nil { + t.Fatalf("NewResponse returned error: %v", err) + } + + rr := httptest.NewRecorder() + _, ok := response.Validate(context.Background(), rr, httptest.NewRequest(http.MethodGet, "/", nil)) + + if ok { + t.Fatalf("static response should not authorize requests") + } + if rr.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusNoContent) + } +} + +func TestNewResponseRejectStaticIsInvalid(t *testing.T) { + if _, err := NewResponse("reject-static://?code=204", nil); err == nil { + t.Fatalf("NewResponse accepted reject-static scheme") + } +} + +func TestNewRejectAuthStaticIsInvalid(t *testing.T) { + if _, err := NewRejectAuth("static://?code=403", nil); err == nil { + t.Fatalf("NewRejectAuth accepted static scheme") + } +} diff --git a/handler/config.go b/handler/config.go index 10c7820..d21a912 100644 --- a/handler/config.go +++ b/handler/config.go @@ -12,6 +12,12 @@ type Config struct { // Auth optionally specifies request validator used to verify users // and return their username. Auth auth.Auth + // DirectResponse optionally specifies a response for direct HTTP requests + // that do not use proxy request form. + DirectResponse auth.Auth + // AccessReject optionally specifies a response for proxy requests + // denied by an access filter. + AccessReject auth.Auth // Logger specifies optional custom logger. Logger *clog.CondLogger // Forward optionally specifies custom connection pairing function diff --git a/handler/direct.go b/handler/direct.go new file mode 100644 index 0000000..1ab9aec --- /dev/null +++ b/handler/direct.go @@ -0,0 +1,11 @@ +package handler + +import "net/http" + +func isDirectRequest(req *http.Request) bool { + if req == nil || req.URL == nil || req.Method == http.MethodConnect || req.Method == "GETRANDOM" { + return false + } + + return req.URL.Scheme == "" && req.URL.Host == "" +} diff --git a/handler/direct_test.go b/handler/direct_test.go new file mode 100644 index 0000000..0d549b4 --- /dev/null +++ b/handler/direct_test.go @@ -0,0 +1,134 @@ +package handler + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + derrors "github.com/SenseUnit/dumbproxy/dialer/errors" +) + +type staticReject struct { + status int + body string +} + +func (r staticReject) Validate(_ context.Context, wr http.ResponseWriter, _ *http.Request) (string, bool) { + wr.WriteHeader(r.status) + _, _ = wr.Write([]byte(r.body)) + return "", false +} + +func (staticReject) Close() error { + return nil +} + +type deniedDialer struct{} + +func (deniedDialer) DialContext(_ context.Context, _, _ string) (net.Conn, error) { + return nil, derrors.ErrAccessDenied{Err: errors.New("denied")} +} + +func TestIsDirectRequest(t *testing.T) { + tests := []struct { + name string + req *http.Request + want bool + }{ + { + name: "origin form get", + req: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Path: "/"}, + Host: "web.nacl.one", + }, + want: true, + }, + { + name: "absolute form get", + req: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Scheme: "http", Host: "openrouter.ai", Path: "/"}, + Host: "openrouter.ai", + }, + want: false, + }, + { + name: "connect", + req: &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: "openrouter.ai:443"}, + Host: "openrouter.ai:443", + }, + want: false, + }, + { + name: "trust tunnel random", + req: &http.Request{ + Method: "GETRANDOM", + URL: &url.URL{Path: "/32"}, + Host: "web.nacl.one", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isDirectRequest(tt.req); got != tt.want { + t.Fatalf("isDirectRequest() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDirectResponse(t *testing.T) { + proxy := NewProxyHandler(&Config{ + DirectResponse: staticReject{status: http.StatusOK, body: "direct response"}, + }) + rr := httptest.NewRecorder() + req := &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Path: "/"}, + Host: "web.nacl.one", + RemoteAddr: "198.51.100.7:1234", + } + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusOK) + } + if rr.Body.String() != "direct response" { + t.Fatalf("body = %q, want direct response", rr.Body.String()) + } +} + +func TestAccessReject(t *testing.T) { + proxy := NewProxyHandler(&Config{ + Dialer: deniedDialer{}, + AccessReject: staticReject{status: http.StatusTeapot, body: "access response"}, + }) + rr := httptest.NewRecorder() + req := &http.Request{ + Method: http.MethodConnect, + URL: &url.URL{Host: "openrouter.ai:443"}, + RequestURI: "openrouter.ai:443", + Host: "openrouter.ai:443", + RemoteAddr: "198.51.100.7:1234", + ProtoMajor: 1, + } + + proxy.ServeHTTP(rr, req) + + if rr.Code != http.StatusTeapot { + t.Fatalf("status = %d, want %d", rr.Code, http.StatusTeapot) + } + if rr.Body.String() != "access response" { + t.Fatalf("body = %q, want access response", rr.Body.String()) + } +} diff --git a/handler/handler.go b/handler/handler.go index 40c164f..79d4b2f 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -32,14 +32,16 @@ type HandlerDialer interface { type ForwardFunc = func(ctx context.Context, username string, incoming, outgoing io.ReadWriteCloser, network, address string) error type ProxyHandler struct { - auth auth.Auth - logger *clog.CondLogger - dialer HandlerDialer - forward ForwardFunc - httptransport http.RoundTripper - outbound map[string]string - outboundMux sync.RWMutex - userIPHints bool + auth auth.Auth + directResponse auth.Auth + accessReject auth.Auth + logger *clog.CondLogger + dialer HandlerDialer + forward ForwardFunc + httptransport http.RoundTripper + outbound map[string]string + outboundMux sync.RWMutex + userIPHints bool } func NewProxyHandler(config *Config) *ProxyHandler { @@ -64,13 +66,15 @@ func NewProxyHandler(config *Config) *ProxyHandler { f = forward.PairConnections } return &ProxyHandler{ - auth: a, - logger: l, - dialer: d, - forward: f, - httptransport: httptransport, - outbound: make(map[string]string), - userIPHints: config.UserIPHints, + auth: a, + directResponse: config.DirectResponse, + accessReject: config.AccessReject, + logger: l, + dialer: d, + forward: f, + httptransport: httptransport, + outbound: make(map[string]string), + userIPHints: config.UserIPHints, } } @@ -80,7 +84,7 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request, u var accessErr derrors.ErrAccessDenied if errors.As(err, &accessErr) { s.logger.Warning("Access denied: %v", err) - http.Error(wr, "Access denied", http.StatusForbidden) + s.rejectAccess(wr, req) return } s.logger.Error("Can't satisfy CONNECT request: %v", err) @@ -179,7 +183,7 @@ func (s *ProxyHandler) HandleRequest(wr http.ResponseWriter, req *http.Request, var accessErr derrors.ErrAccessDenied if errors.As(err, &accessErr) { s.logger.Warning("Access denied: %v", err) - http.Error(wr, "Access denied", http.StatusForbidden) + s.rejectAccess(wr, req) return } s.logger.Error("HTTP fetch error: %v", err) @@ -229,6 +233,11 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { return } + if s.directResponse != nil && isDirectRequest(req) { + s.reject(s.directResponse, wr, req) + return + } + ctx := req.Context() username, ok := s.auth.Validate(ctx, wr, req) localAddr := getLocalAddr(req.Context()) @@ -266,6 +275,19 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } } +func (s *ProxyHandler) rejectAccess(wr http.ResponseWriter, req *http.Request) { + if s.accessReject == nil { + http.Error(wr, "Access denied", http.StatusForbidden) + return + } + + s.reject(s.accessReject, wr, req) +} + +func (s *ProxyHandler) reject(reject auth.Auth, wr http.ResponseWriter, req *http.Request) { + _, _ = reject.Validate(req.Context(), wr, req) +} + func trimAddrPort(addrPort string) string { res, _, err := net.SplitHostPort(addrPort) if err != nil { diff --git a/main.go b/main.go index b746df9..fd8d99a 100644 --- a/main.go +++ b/main.go @@ -285,6 +285,8 @@ type CLIArgs struct { unixSockMode modeArg mode proxyModeArg auth string + directResponse string + accessReject string verbosity int cert, key, cafile string list_ciphers bool @@ -387,6 +389,8 @@ func parse_args() *CLIArgs { flag.Var(&args.unixSockMode, "unix-sock-mode", "set file mode for bound unix socket") flag.Var(&args.mode, "mode", "proxy operation mode (http/socks5/stdio/port-forward)") flag.StringVar(&args.auth, "auth", "none://", "auth parameters") + flag.StringVar(&args.directResponse, "direct-response", "", "response parameters for direct HTTP requests") + flag.StringVar(&args.accessReject, "access-reject", "", "reject response parameters for requests denied by access filters") flag.IntVar(&args.verbosity, "verbosity", 20, "logging verbosity "+ "(10 - debug, 20 - info, 30 - warning, 40 - error, 50 - critical)") flag.StringVar(&args.cert, "cert", "", "enable TLS and use certificate") @@ -641,6 +645,26 @@ func run() int { } defer authProvider.Close() + var directResponse auth.Auth + if args.directResponse != "" { + directResponse, err = auth.NewResponse(args.directResponse, authLogger) + if err != nil { + mainLogger.Critical("Failed to instantiate direct request response: %v", err) + return 3 + } + defer directResponse.Close() + } + + var accessReject auth.Auth + if args.accessReject != "" { + accessReject, err = auth.NewRejectAuth(args.accessReject, authLogger) + if err != nil { + mainLogger.Critical("Failed to instantiate access reject response: %v", err) + return 3 + } + defer accessReject.Close() + } + // setup access filters var filterRoot access.Filter = access.AlwaysAllow{} if args.jsAccessFilter != "" { @@ -899,11 +923,13 @@ func run() int { case proxyModeHTTP: server := http.Server{ Handler: handler.NewProxyHandler(&handler.Config{ - Dialer: dialerRoot, - Auth: authProvider, - Logger: proxyLogger, - UserIPHints: args.userIPHints, - Forward: forwarder, + Dialer: dialerRoot, + Auth: authProvider, + DirectResponse: directResponse, + AccessReject: accessReject, + Logger: proxyLogger, + UserIPHints: args.userIPHints, + Forward: forwarder, }), ErrorLog: log.New(logWriter, "HTTPSRV : ", log.LstdFlags|log.Lshortfile), ReadTimeout: 0,