Skip to content

Commit

Permalink
merge handler of servers
Browse files Browse the repository at this point in the history
Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
  • Loading branch information
Asutorufa committed May 28, 2023
1 parent b25dce8 commit 896f12c
Show file tree
Hide file tree
Showing 111 changed files with 864 additions and 849 deletions.
2 changes: 1 addition & 1 deletion cmd/android/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"syscall"

"github.com/Asutorufa/yuhaiin/pkg/log"
"github.com/Asutorufa/yuhaiin/pkg/net/interfaces/proxy"
proxy "github.com/Asutorufa/yuhaiin/pkg/net/interfaces"
"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
)
Expand Down
2 changes: 1 addition & 1 deletion cmd/yuhaiin/main_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"context"
"fmt"

"github.com/Asutorufa/yuhaiin/pkg/net/interfaces/proxy"
proxy "github.com/Asutorufa/yuhaiin/pkg/net/interfaces"
"github.com/Asutorufa/yuhaiin/pkg/net/netlink"
"github.com/Asutorufa/yuhaiin/pkg/utils/yerror"
)
Expand Down
2 changes: 1 addition & 1 deletion example/example.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"net"
"net/http"

"github.com/Asutorufa/yuhaiin/pkg/net/interfaces/proxy"
proxy "github.com/Asutorufa/yuhaiin/pkg/net/interfaces"
"github.com/Asutorufa/yuhaiin/pkg/node/register"
"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
Expand Down
5 changes: 2 additions & 3 deletions internal/http/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"io"
"net/http"
"runtime"
Expand All @@ -21,7 +20,7 @@ type configHandler struct {
}

func (cc *configHandler) Get(w http.ResponseWriter, r *http.Request) error {
c, err := cc.cf.Load(context.TODO(), &emptypb.Empty{})
c, err := cc.cf.Load(r.Context(), &emptypb.Empty{})
if err != nil {
return err
}
Expand Down Expand Up @@ -49,7 +48,7 @@ func (c *configHandler) Post(w http.ResponseWriter, r *http.Request) error {
return err
}

_, err = c.cf.Save(context.TODO(), config)
_, err = c.cf.Save(r.Context(), config)
if err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions internal/http/group.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -24,7 +23,7 @@ type groupHandler struct {

func (g *groupHandler) Get(w http.ResponseWriter, r *http.Request) error {
group := r.URL.Query().Get("name")
ns, err := g.nm.Manager(context.TODO(), &wrapperspb.StringValue{})
ns, err := g.nm.Manager(r.Context(), &wrapperspb.StringValue{})
if err != nil {
return err
}
Expand Down Expand Up @@ -63,7 +62,7 @@ type tag struct {
}

func (t *tag) Get(w http.ResponseWriter, r *http.Request) error {
m, err := t.nm.Manager(context.TODO(), &wrapperspb.StringValue{})
m, err := t.nm.Manager(r.Context(), &wrapperspb.StringValue{})
if err != nil {
return err
}
Expand Down Expand Up @@ -119,7 +118,7 @@ func (t *tag) Post(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("unknown tag type: %v", z["type"])
}

_, err = t.ts.Save(context.TODO(), &snode.SaveTagReq{
_, err = t.ts.Save(r.Context(), &snode.SaveTagReq{
Tag: z["tag"],
Hash: z["hash"],
Type: pt.Type(tYPE),
Expand All @@ -130,7 +129,7 @@ func (t *tag) Post(w http.ResponseWriter, r *http.Request) error {
func (t *tag) Delete(w http.ResponseWriter, r *http.Request) error {
tag := r.URL.Query().Get("tag")

_, err := t.ts.Remove(context.TODO(), &wrapperspb.StringValue{
_, err := t.ts.Remove(r.Context(), &wrapperspb.StringValue{
Value: tag,
})
return err
Expand Down
3 changes: 1 addition & 2 deletions internal/http/latency.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"errors"
"net/http"

Expand Down Expand Up @@ -64,7 +63,7 @@ func (l *latencyHandler) Get(w http.ResponseWriter, r *http.Request) error {
req.Requests = append(req.Requests, l.udp(r))
}

lt, err := l.nm.Latency(context.TODO(), req)
lt, err := l.nm.Latency(r.Context(), req)
if err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions internal/http/node.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"encoding/json"
"errors"
"html/template"
Expand Down Expand Up @@ -53,7 +52,7 @@ func (nn *nodeHandler) Get(w http.ResponseWriter, r *http.Request) error {

hash := r.URL.Query().Get("hash")

n, err := nn.nm.Get(context.TODO(), &wrapperspb.StringValue{Value: hash})
n, err := nn.nm.Get(r.Context(), &wrapperspb.StringValue{Value: hash})
if err != nil {
return err
}
Expand Down Expand Up @@ -122,7 +121,7 @@ func (n *nodeHandler) Delete(w http.ResponseWriter, r *http.Request) error {
return errors.New("hash is empty")
}

_, err := n.nm.Remove(context.TODO(), &wrapperspb.StringValue{Value: hash})
_, err := n.nm.Remove(r.Context(), &wrapperspb.StringValue{Value: hash})
if err != nil {
return err
}
Expand All @@ -143,7 +142,7 @@ func (n *nodeHandler) Post(w http.ResponseWriter, r *http.Request) error {
return err
}

_, err = n.nm.Save(context.TODO(), point)
_, err = n.nm.Save(r.Context(), point)
if err != nil {
return err
}
Expand All @@ -168,7 +167,7 @@ func (n *nodeHandler) Put(w http.ResponseWriter, r *http.Request) error {
req.Udp = true
}

_, err := n.nm.Use(context.TODO(), req)
_, err := n.nm.Use(r.Context(), req)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions internal/http/root.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"net/http"
"unsafe"

Expand All @@ -16,7 +15,7 @@ type rootHandler struct {
}

func (z *rootHandler) Get(w http.ResponseWriter, r *http.Request) error {
point, err := z.nm.Now(context.TODO(), &grpcnode.NowReq{Net: grpcnode.NowReq_tcp})
point, err := z.nm.Now(r.Context(), &grpcnode.NowReq{Net: grpcnode.NowReq_tcp})
if err != nil {
return err
}
Expand All @@ -25,7 +24,7 @@ func (z *rootHandler) Get(w http.ResponseWriter, r *http.Request) error {
return err
}

point, err = z.nm.Now(context.TODO(), &grpcnode.NowReq{Net: grpcnode.NowReq_udp})
point, err = z.nm.Now(r.Context(), &grpcnode.NowReq{Net: grpcnode.NowReq_udp})
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions internal/http/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (c *conn) Delete(w http.ResponseWriter, r *http.Request) error {
return err
}

_, err = c.stt.CloseConn(context.TODO(), &gs.ConnectionsId{Ids: []uint64{i}})
_, err = c.stt.CloseConn(r.Context(), &gs.ConnectionsId{Ids: []uint64{i}})
if err != nil {
return err
}
Expand All @@ -45,7 +45,7 @@ func (cc *conn) Websocket(w http.ResponseWriter, r *http.Request) error {
return websocket.ServeHTTP(w, r, cc.handler)
}

func (cc *conn) handler(c *websocket.Conn) error {
func (cc *conn) handler(ctx context.Context, c *websocket.Conn) error {
defer c.Close()

var tickerStr string
Expand All @@ -59,10 +59,10 @@ func (cc *conn) handler(c *websocket.Conn) error {
return err
}

ctx, cancel := context.WithCancel(context.TODO())
cctx, cancel := context.WithCancel(ctx)
defer cancel()

go cc.stt.Notify(&emptypb.Empty{}, &connectionsNotifyServer{ctx, c})
go cc.stt.Notify(&emptypb.Empty{}, &connectionsNotifyServer{cctx, c})

ticker := time.NewTicker(time.Duration(t) * time.Millisecond)
defer ticker.Stop()
Expand All @@ -72,7 +72,7 @@ func (cc *conn) handler(c *websocket.Conn) error {
cancel()
}()

if err = cc.sendFlow(c); err != nil {
if err = cc.sendFlow(ctx, c); err != nil {
return err
}

Expand All @@ -82,15 +82,15 @@ func (cc *conn) handler(c *websocket.Conn) error {
return nil

case <-ticker.C:
if err = cc.sendFlow(c); err != nil {
if err = cc.sendFlow(ctx, c); err != nil {
return err
}
}
}
}

func (cc *conn) sendFlow(wsConn *websocket.Conn) error {
total, err := cc.stt.Total(context.TODO(), &emptypb.Empty{})
func (cc *conn) sendFlow(ctx context.Context, wsConn *websocket.Conn) error {
total, err := cc.stt.Total(ctx, &emptypb.Empty{})
if err != nil {
return err
}
Expand Down
9 changes: 4 additions & 5 deletions internal/http/sub.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplehttp

import (
"context"
"encoding/json"
"errors"
"net/http"
Expand All @@ -25,7 +24,7 @@ func (s *subHandler) Post(w http.ResponseWriter, r *http.Request) error {
return errors.New("name or link is empty")
}

_, err := s.nm.Save(context.TODO(), &grpcnode.SaveLinkReq{
_, err := s.nm.Save(r.Context(), &grpcnode.SaveLinkReq{
Links: []*subscribe.Link{
{
Name: name,
Expand All @@ -42,7 +41,7 @@ func (s *subHandler) Post(w http.ResponseWriter, r *http.Request) error {
}

func (s *subHandler) Get(w http.ResponseWriter, r *http.Request) error {
links, err := s.nm.Get(context.TODO(), &emptypb.Empty{})
links, err := s.nm.Get(r.Context(), &emptypb.Empty{})
if err != nil {
return err
}
Expand All @@ -69,7 +68,7 @@ func (s *subHandler) Delete(w http.ResponseWriter, r *http.Request) error {
return err
}

_, err := s.nm.Remove(context.TODO(), &grpcnode.LinkReq{Names: names})
_, err := s.nm.Remove(r.Context(), &grpcnode.LinkReq{Names: names})
if err != nil {
return err
}
Expand All @@ -90,7 +89,7 @@ func (s *subHandler) Patch(w http.ResponseWriter, r *http.Request) error {
return err
}

_, err := s.nm.Update(context.TODO(), &grpcnode.LinkReq{Names: names})
_, err := s.nm.Update(r.Context(), &grpcnode.LinkReq{Names: names})
if err != nil {
return err
}
Expand Down
34 changes: 20 additions & 14 deletions internal/shunt/shunt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ import (
"sync"

"github.com/Asutorufa/yuhaiin/pkg/log"
"github.com/Asutorufa/yuhaiin/pkg/net/interfaces/dns"
"github.com/Asutorufa/yuhaiin/pkg/net/interfaces/proxy"
proxy "github.com/Asutorufa/yuhaiin/pkg/net/interfaces"
"github.com/Asutorufa/yuhaiin/pkg/net/mapper"
"github.com/Asutorufa/yuhaiin/pkg/node"
pc "github.com/Asutorufa/yuhaiin/pkg/protos/config"
Expand Down Expand Up @@ -44,11 +43,11 @@ type Shunt struct {

type Opts struct {
DirectDialer proxy.Proxy
DirectResolver dns.DNS
DirectResolver proxy.Resolver
ProxyDialer proxy.Proxy
ProxyResolver dns.DNS
ProxyResolver proxy.Resolver
BlockDialer proxy.Proxy
BLockResolver dns.DNS
BLockResolver proxy.Resolver
DefaultMode bypass.Mode
}

Expand Down Expand Up @@ -138,7 +137,13 @@ func (s *Shunt) Dispatch(ctx context.Context, host proxy.Address) (proxy.Address

func (s *Shunt) dispatch(ctx context.Context, networkMode bypass.Mode, host proxy.Address) (bypass.Mode, proxy.Address) {
// get mode from upstream specified
mode := proxy.Value(host, ForceModeKey{}, bypass.Mode_bypass)

store := proxy.StoreFromContext(ctx)

mode, ok := proxy.Get[bypass.Mode](ctx, ForceModeKey{})
if !ok {
mode = bypass.Mode_bypass
}

if mode == bypass.Mode_bypass && networkMode != bypass.Mode_bypass {
// get mode from network(tcp/udp) rule
Expand All @@ -151,24 +156,25 @@ func (s *Shunt) dispatch(ctx context.Context, networkMode bypass.Mode, host prox

// get tag from bypass rule
if tag := fields.GetTag(); len(tag) != 0 {
host.WithValue(node.TagKey{}, tag)

store.Add(node.TagKey{}, tag)
}

if fields.GetResolveStrategy() == bypass.ResolveStrategy_prefer_ipv6 {
host.WithValue(proxy.PreferIPv6{}, true)
host.PreferIPv6(true)
}
}

host.WithValue(modeMarkKey{}, mode)
store.Add(modeMarkKey{}, mode)
host.WithResolver(s.resolver(mode), true)

if s.resolveProxy && host.Type() == proxy.DOMAIN && mode == bypass.Mode_proxy {
// resolve proxy domain if resolveRemoteDomain enabled
ip, err := host.IP(ctx)
if err == nil {
host.WithValue(DOMAIN_MARK_KEY{}, host.String())
store.Add(DOMAIN_MARK_KEY{}, host.String())
host = host.OverrideHostname(ip.String())
host.WithValue(IP_MARK_KEY{}, host.String())
store.Add(IP_MARK_KEY{}, host.String())
} else {
log.Warn("resolve remote domain failed", "err", err)
}
Expand All @@ -190,7 +196,7 @@ func (s *Shunt) dialer(m bypass.Mode) proxy.Proxy {
}
}

func (s *Shunt) resolver(m bypass.Mode) dns.DNS {
func (s *Shunt) resolver(m bypass.Mode) proxy.Resolver {
switch m {
case bypass.Mode_block:
return s.BLockResolver
Expand All @@ -203,9 +209,9 @@ func (s *Shunt) resolver(m bypass.Mode) dns.DNS {
}
}

var skipResolve = dns.ErrorDNS(func(domain string) error { return mapper.ErrSkipResolve })
var skipResolve = proxy.ErrorResolver(func(domain string) error { return mapper.ErrSkipResolve })

func (s *Shunt) Resolver(ctx context.Context, domain string) dns.DNS {
func (s *Shunt) Resolver(ctx context.Context, domain string) proxy.Resolver {
host := proxy.ParseAddressPort(0, domain, proxy.EmptyPort)
host.WithResolver(skipResolve, true)
return s.resolver(s.mapper.SearchWithDefault(ctx, host, s.DefaultMode).Mode())
Expand Down
Loading

0 comments on commit 896f12c

Please sign in to comment.