Skip to content

Commit

Permalink
upstream: use mutex. imp logging
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jan 23, 2024
1 parent edb394b commit cea34d5
Showing 1 changed file with 36 additions and 25 deletions.
61 changes: 36 additions & 25 deletions upstream/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type dnsOverQUIC struct {
quicConfigMu sync.Mutex

// connMu protects conn.
connMu sync.RWMutex
connMu sync.Mutex

// bytesPoolGuard protects bytesPool.
bytesPoolMu sync.Mutex
Expand Down Expand Up @@ -214,7 +214,7 @@ func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg) (resp *dns.Msg, err error) {
var stream quic.Stream
stream, err = p.openStream(conn)
if err != nil {
return nil, err
return nil, fmt.Errorf("opening stream: %w", err)
}

_, err = stream.Write(proxyutil.AddPrefix(buf))
Expand All @@ -226,7 +226,10 @@ func (p *dnsOverQUIC) exchangeQUIC(req *dns.Msg) (resp *dns.Msg, err error) {
// indicate through the STREAM FIN mechanism that no further data will
// be sent on that stream. Note, that stream.Close() closes the
// write-direction of the stream, but does not prevent reading from it.
_ = stream.Close()
err = stream.Close()
if err != nil {
log.Debug("closing quic stream: %s", err)
}

return p.readMsg(stream)
}
Expand Down Expand Up @@ -259,29 +262,30 @@ func (p *dnsOverQUIC) getBytesPool() (pool *sync.Pool) {
// argument controls whether we should try to use the existing cached
// connection. If it is false, we will forcibly create a new connection and
// close the existing one if needed.
func (p *dnsOverQUIC) getConnection(useCached bool) (quic.Connection, error) {
func (p *dnsOverQUIC) getConnection(useCached bool) (c quic.Connection, err error) {
var conn quic.Connection
p.connMu.RLock()
conn = p.conn
if conn != nil && useCached {
p.connMu.RUnlock()

return conn, nil
}
if conn != nil {
// we're recreating the connection, let's create a new one.
_ = conn.CloseWithError(QUICCodeNoError, "")
}
p.connMu.RUnlock()

p.connMu.Lock()
defer p.connMu.Unlock()

var err error
conn = p.conn
if conn != nil {
if useCached {
return conn, nil
}

// We're recreating the connection, let's create a new one.
err = conn.CloseWithError(QUICCodeNoError, "")
if err != nil {
log.Debug("closing stale connection: %s", err)
}
}

conn, err = p.openConnection()
if err != nil {
return nil, err
}

p.conn = conn

return conn, nil
Expand Down Expand Up @@ -320,7 +324,9 @@ func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
defer cancel()

stream, err := conn.OpenStreamSync(ctx)
if err == nil {
if err != nil {
log.Debug("opening quic stream: %s", err)
} else {
return stream, nil
}

Expand All @@ -330,30 +336,35 @@ func (p *dnsOverQUIC) openStream(conn quic.Connection) (quic.Stream, error) {
if err != nil {
return nil, err
}

// Open a new stream.
return newConn.OpenStreamSync(ctx)
}

// openConnection opens a new QUIC connection.
// openConnection dials a new QUIC connection.
func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {
dialContext, err := p.getDialer()
if err != nil {
return nil, fmt.Errorf("failed to bootstrap QUIC connection: %w", err)
return nil, fmt.Errorf("bootstrapping %s: %w", p.addr, err)
}

// we're using bootstrapped address instead of what's passed to the function
// it does not create an actual connection, but it helps us determine
// what IP is actually reachable (when there're v4/v6 addresses).
rawConn, err := dialContext(context.Background(), "udp", "")
if err != nil {
return nil, fmt.Errorf("failed to open a QUIC connection: %w", err)
return nil, fmt.Errorf("dialing raw connection to %s: %w", p.addr, err)
}

// It's never actually used.
err = rawConn.Close()
if err != nil {
log.Debug("closing raw connection for %s: %s", p.addr, err)
}
// It's never actually used
_ = rawConn.Close()

udpConn, ok := rawConn.(*net.UDPConn)
if !ok {
return nil, fmt.Errorf("failed to open connection to %s", p.addr)
return nil, fmt.Errorf("unexpected type %T of connection; should be %T", rawConn, udpConn)
}

addr := udpConn.RemoteAddr().String()
Expand All @@ -363,7 +374,7 @@ func (p *dnsOverQUIC) openConnection() (conn quic.Connection, err error) {

conn, err = quic.DialAddrEarly(ctx, addr, p.tlsConf.Clone(), p.getQUICConfig())
if err != nil {
return nil, fmt.Errorf("opening quic connection to %s: %w", p.addr, err)
return nil, fmt.Errorf("dialing quic connection to %s: %w", p.addr, err)
}

return conn, nil
Expand Down

0 comments on commit cea34d5

Please sign in to comment.