diff --git a/proxy/proxy.go b/proxy/proxy.go index 96054e43e..4a7c43cef 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -93,12 +93,12 @@ func (p *Proxy) accept(conn net.Conn) { clientPacket := clientConn.(wrappers.PacketReadWriteCloser) serverPacket := serverConn.(wrappers.PacketReadWriteCloser) go p.middlePipe(clientPacket, serverPacket, wait, &opts.ReadHacks) - go p.middlePipe(serverPacket, clientPacket, wait, &opts.WriteHacks) + p.middlePipe(serverPacket, clientPacket, wait, &opts.WriteHacks) } else { clientStream := clientConn.(wrappers.StreamReadWriteCloser) serverStream := serverConn.(wrappers.StreamReadWriteCloser) go p.directPipe(clientStream, serverStream, wait, p.conf.ReadBufferSize) - go p.directPipe(serverStream, clientStream, wait, p.conf.WriteBufferSize) + p.directPipe(serverStream, clientStream, wait, p.conf.WriteBufferSize) } wait.Wait() @@ -121,13 +121,8 @@ func (p *Proxy) getTelegramConn(ctx context.Context, cancel context.CancelFunc, return packetConn, nil } -func (p *Proxy) middlePipe(src wrappers.PacketReadCloser, dst io.WriteCloser, - wait *sync.WaitGroup, hacks *mtproto.Hacks) { - defer func() { - src.Close() // nolint: errcheck, gosec - dst.Close() // nolint: errcheck, gosec - wait.Done() - }() +func (p *Proxy) middlePipe(src wrappers.PacketReadCloser, dst io.Writer, wait *sync.WaitGroup, hacks *mtproto.Hacks) { + defer wait.Done() for { hacks.SimpleAck = false @@ -145,13 +140,8 @@ func (p *Proxy) middlePipe(src wrappers.PacketReadCloser, dst io.WriteCloser, } } -func (p *Proxy) directPipe(src wrappers.StreamReadCloser, dst io.WriteCloser, - wait *sync.WaitGroup, bufferSize int) { - defer func() { - src.Close() // nolint: errcheck, gosec - dst.Close() // nolint: errcheck, gosec - wait.Done() - }() +func (p *Proxy) directPipe(src wrappers.StreamReadCloser, dst io.Writer, wait *sync.WaitGroup, bufferSize int) { + defer wait.Done() buffer := make([]byte, bufferSize) if _, err := io.CopyBuffer(dst, src, buffer); err != nil { diff --git a/wrappers/conn.go b/wrappers/conn.go index 157828e92..239af5361 100644 --- a/wrappers/conn.go +++ b/wrappers/conn.go @@ -38,13 +38,6 @@ const ( connTimeoutWrite = 2 * time.Minute ) -type ioResult struct { - n int - err error -} - -type ioFunc func([]byte) (int, error) - // Conn is a basic wrapper for net.Conn providing the most low-level // logic and management as possible. type Conn struct { @@ -61,12 +54,20 @@ type Conn struct { func (c *Conn) Write(p []byte) (int, error) { select { case <-c.ctx.Done(): + c.Close() // nolint: gosec return 0, errors.Annotate(c.ctx.Err(), "Cannot write because context was closed") default: - n, err := c.doIO(c.conn.Write, p, connTimeoutWrite) + if err := c.conn.SetWriteDeadline(time.Now().Add(connTimeoutWrite)); err != nil { + c.Close() // nolint: gosec + return 0, errors.Annotate(err, "Cannot set write deadline to the socket") + } + n, err := c.conn.Write(p) c.logger.Debugw("Write to stream", "bytes", n, "error", err) stats.EgressTraffic(n) + if err != nil { + c.Close() // nolint: gosec + } return n, err } @@ -75,48 +76,30 @@ func (c *Conn) Write(p []byte) (int, error) { func (c *Conn) Read(p []byte) (int, error) { select { case <-c.ctx.Done(): + c.Close() // nolint: gosec return 0, errors.Annotate(c.ctx.Err(), "Cannot read because context was closed") default: - n, err := c.doIO(c.conn.Read, p, connTimeoutRead) + if err := c.conn.SetReadDeadline(time.Now().Add(connTimeoutRead)); err != nil { + c.Close() // nolint: gosec + return 0, errors.Annotate(err, "Cannot set read deadline to the socket") + } + n, err := c.conn.Read(p) c.logger.Debugw("Read from stream", "bytes", n, "error", err) stats.IngressTraffic(n) - - return n, err - } -} - -func (c *Conn) doIO(callback ioFunc, p []byte, timeout time.Duration) (int, error) { - resChan := make(chan ioResult, 1) - timer := time.NewTimer(timeout) - - go func() { - n, err := callback(p) - resChan <- ioResult{n: n, err: err} - }() - - select { - case res := <-resChan: - timer.Stop() - if res.err != nil { + if err != nil { c.Close() // nolint: gosec } - return res.n, res.err - case <-c.ctx.Done(): - timer.Stop() - c.Close() // nolint: gosec - return 0, errors.Annotate(c.ctx.Err(), "Cannot do IO because context is closed") - case <-timer.C: - c.Close() // nolint: gosec - return 0, errors.Annotate(c.ctx.Err(), "Timeout on IO operation") + + return n, err } } // Close closes underlying net.Conn instance. func (c *Conn) Close() error { - defer c.logger.Debugw("Close connection") - + c.logger.Debugw("Close connection") c.cancel() + return c.conn.Close() }