Skip to content

Commit

Permalink
feat: sniffer support
Browse files Browse the repository at this point in the history
sniffer:
  enable: true
  force: false # Overwrite domain
  sniffing:
    - tls
  • Loading branch information
Skyxim committed Apr 9, 2022
1 parent 07906c0 commit 544e0f1
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 58 deletions.
104 changes: 104 additions & 0 deletions component/sniffer/dispatcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package sniffer

import (
"errors"
"net"

CN "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)

var (
ErrorUnsupportedSniffer = errors.New("unsupported sniffer")
)

var Dispatcher SnifferDispatcher

type SnifferDispatcher struct {
enable bool
force bool
sniffers []C.Sniffer
}

func (sd *SnifferDispatcher) Tcp(conn net.Conn, metadata *C.Metadata) {
bufConn, ok := conn.(*CN.BufferedConn)
if !ok {
return
}

if sd.force {
sd.cover(bufConn, metadata)
} else {
if metadata.Host != "" {
return
}

sd.cover(bufConn, metadata)
}
}

func (sd *SnifferDispatcher) Enable() bool {
return sd.enable
}

func (sd *SnifferDispatcher) cover(conn *CN.BufferedConn, metadata *C.Metadata) {
for _, sniffer := range sd.sniffers {
if sniffer.SupportNetwork() == C.TCP {
conn.Peek(1)
len := conn.Buffered()
bytes, err := conn.Peek(len)
if err != nil {
log.Warnln("the data lenght not enough")
continue
}

host, err := sniffer.SniffTCP(bytes)
if err != nil {
log.Warnln("Sniff data failed on Sniffer[%s]", sniffer.Protocol())
continue
}

metadata.Host = host
metadata.DstIP = nil
metadata.AddrType = C.AtypDomainName
if resolver.FakeIPEnabled() {
metadata.DNSMode = C.DNSFakeIP
} else {
metadata.DNSMode = C.DNSMapping
}

resolver.InsertHostByIP(metadata.DstIP, host)
break
}
}
}

func NewSnifferDispatcher(needSniffer []C.SnifferType, force bool) (SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{
enable: true,
force: force,
}

for _, snifferName := range needSniffer {
sniffer, err := NewSniffer(snifferName)
if err != nil {
log.Errorln("Sniffer name[%s] is error", snifferName)
return SnifferDispatcher{enable: false}, err
}

dispatcher.sniffers = append(dispatcher.sniffers, sniffer)
}

return dispatcher, nil
}

func NewSniffer(name C.SnifferType) (C.Sniffer, error) {
switch name {
case C.TLS:
return &TLSSniffer{}, nil
default:
return nil, ErrorUnsupportedSniffer
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package tls
package sniffer

import (
"testing"
Expand Down Expand Up @@ -142,7 +142,7 @@ func TestTLSHeaders(t *testing.T) {
}

for _, test := range cases {
header, err := SniffTLS(test.input)
domain, err := SniffTLS(test.input)
if test.err {
if err == nil {
t.Errorf("Exepct error but nil in test %v", test)
Expand All @@ -151,8 +151,8 @@ func TestTLSHeaders(t *testing.T) {
if err != nil {
t.Errorf("Expect no error but actually %s in test %v", err.Error(), test)
}
if header.Domain() != test.domain {
t.Error("expect domain ", test.domain, " but got ", header.Domain())
if *domain != test.domain {
t.Error("expect domain ", test.domain, " but got ", domain)
}
}
}
Expand Down
79 changes: 44 additions & 35 deletions common/snifer/tls/sniff.go → component/sniffer/tls_sniffer.go
Original file line number Diff line number Diff line change
@@ -1,129 +1,139 @@
package tls
package sniffer

import (
"encoding/binary"
"errors"
"strings"

C "github.com/Dreamacro/clash/constant"
)

var ErrNoClue = errors.New("not enough information for making a decision")
var (
errNotTLS = errors.New("not TLS header")
errNotClientHello = errors.New("not client hello")
ErrNoClue = errors.New("not enough information for making a decision")
)

type SniffHeader struct {
domain string
type TLSSniffer struct {
}

func (h *SniffHeader) Protocol() string {
func (tls *TLSSniffer) Protocol() string {
return "tls"
}

func (h *SniffHeader) Domain() string {
return h.domain
func (tls *TLSSniffer) SupportNetwork() C.NetWork {
return C.TCP
}

var (
errNotTLS = errors.New("not TLS header")
errNotClientHello = errors.New("not client hello")
)
func (tls *TLSSniffer) SniffTCP(bytes []byte) (string, error) {
domain, err := SniffTLS(bytes)
if err == nil {
return *domain, nil
} else {
return "", err
}
}

func IsValidTLSVersion(major, minor byte) bool {
return major == 3
}

// ReadClientHello returns server name (if any) from TLS client hello message.
// https://github.com/golang/go/blob/master/src/crypto/tls/handshake_messages.go#L300
func ReadClientHello(data []byte, h *SniffHeader) error {
func ReadClientHello(data []byte) (*string, error) {
if len(data) < 42 {
return ErrNoClue
return nil, ErrNoClue
}
sessionIDLen := int(data[38])
if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
return ErrNoClue
return nil, ErrNoClue
}
data = data[39+sessionIDLen:]
if len(data) < 2 {
return ErrNoClue
return nil, ErrNoClue
}
// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
// they are uint16s, the number must be even.
cipherSuiteLen := int(data[0])<<8 | int(data[1])
if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
return errNotClientHello
return nil, errNotClientHello
}
data = data[2+cipherSuiteLen:]
if len(data) < 1 {
return ErrNoClue
return nil, ErrNoClue
}
compressionMethodsLen := int(data[0])
if len(data) < 1+compressionMethodsLen {
return ErrNoClue
return nil, ErrNoClue
}
data = data[1+compressionMethodsLen:]

if len(data) == 0 {
return errNotClientHello
return nil, errNotClientHello
}
if len(data) < 2 {
return errNotClientHello
return nil, errNotClientHello
}

extensionsLength := int(data[0])<<8 | int(data[1])
data = data[2:]
if extensionsLength != len(data) {
return errNotClientHello
return nil, errNotClientHello
}

for len(data) != 0 {
if len(data) < 4 {
return errNotClientHello
return nil, errNotClientHello
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return errNotClientHello
return nil, errNotClientHello
}

if extension == 0x00 { /* extensionServerName */
d := data[:length]
if len(d) < 2 {
return errNotClientHello
return nil, errNotClientHello
}
namesLen := int(d[0])<<8 | int(d[1])
d = d[2:]
if len(d) != namesLen {
return errNotClientHello
return nil, errNotClientHello
}
for len(d) > 0 {
if len(d) < 3 {
return errNotClientHello
return nil, errNotClientHello
}
nameType := d[0]
nameLen := int(d[1])<<8 | int(d[2])
d = d[3:]
if len(d) < nameLen {
return errNotClientHello
return nil, errNotClientHello
}
if nameType == 0 {
serverName := string(d[:nameLen])
// An SNI value may not include a
// trailing dot. See
// https://tools.ietf.org/html/rfc6066#section-3.
if strings.HasSuffix(serverName, ".") {
return errNotClientHello
return nil, errNotClientHello
}
h.domain = serverName
return nil

return &serverName, nil
}

d = d[nameLen:]
}
}
data = data[length:]
}

return errNotTLS
return nil, errNotTLS
}

func SniffTLS(b []byte) (*SniffHeader, error) {
func SniffTLS(b []byte) (*string, error) {
if len(b) < 5 {
return nil, ErrNoClue
}
Expand All @@ -139,10 +149,9 @@ func SniffTLS(b []byte) (*SniffHeader, error) {
return nil, ErrNoClue
}

h := &SniffHeader{}
err := ReadClientHello(b[5:5+headerLen], h)
domain, err := ReadClientHello(b[5 : 5+headerLen])
if err == nil {
return h, nil
return domain, nil
}
return nil, err
}
Loading

0 comments on commit 544e0f1

Please sign in to comment.