Skip to content

Commit

Permalink
抽出readMessage函数,方便对该函数单独benchmark (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Sep 10, 2023
1 parent 3842c68 commit 52687c9
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 154 deletions.
298 changes: 148 additions & 150 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ type Conn struct {
bp bytespool.BytesPool // 实验某些特性加的字段

delayWrite
readHeadArray [enum.MaxFrameHeaderSize]byte
fragmentFramePayload []byte // 存放分片帧的缓冲区
bufioPayload []byte
fragmentFrameHeader *frame.FrameHeader
}

func setNoDelay(c net.Conn, noDelay bool) error {
Expand Down Expand Up @@ -132,10 +136,36 @@ func decode(payload []byte) ([]byte, error) {
return o.Bytes(), nil
}

func (c *Conn) ReadLoop() error {
func (c *Conn) ReadLoop() (err error) {
c.OnOpen(c)

return c.readLoop()
defer func() {
// c.OnClose(c, err)
c.Close()
if c.fr.IsInit() {
defer func() {
c.fr.Release()
c.fr.BufPtr()
}()
}
}()

if c.br != nil {
newSize := int(1024 * c.bufioMultipleTimesPayloadSize)
if c.br.Size() != newSize {
// TODO sync.Pool管理
(*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize))
}
// bufio 模式才会使用payload
c.bufioPayload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
}

for {
err = c.readMessage()
if err != nil {
return err
}
}
}

func (c *Conn) StartReadLoop() {
Expand All @@ -144,7 +174,7 @@ func (c *Conn) StartReadLoop() {
}()
}

func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload *[]byte) (f frame.Frame, err error) {
func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPayload *[]byte) (f frame.Frame, err error) {
if c.readTimeout > 0 {
err = c.c.SetReadDeadline(time.Now().Add(c.readTimeout))
if err != nil {
Expand All @@ -156,7 +186,7 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload
if c.fr.IsInit() {
f, err = frame.ReadFrameFromWindows(&c.fr, headArray, c.windowsMultipleTimesPayloadSize)
} else {
f, err = frame.ReadFrameFromReader(c.br, headArray, payload)
f, err = frame.ReadFrameFromReader(c.br, headArray, bufioPayload)
}
if err != nil {
c.Callback.OnClose(c, err)
Expand All @@ -172,186 +202,154 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload
}

// 读取websocket frame.Frame的循环
func (c *Conn) readLoop() error {
var f frame.Frame
var fragmentFrameHeader *frame.FrameHeader

defer c.Close()

var err error
var op opcode.Opcode

if c.fr.IsInit() {
defer func() {
c.fr.Release()
c.fr.BufPtr()
}()
func (c *Conn) readMessage() (err error) {
// 从网络读取数据
f, err := c.readDataFromNet(&c.readHeadArray, &c.bufioPayload)
if err != nil {
return err
}

var fragmentFrameBuf []byte
var headArray [enum.MaxFrameHeaderSize]byte

var payload []byte
if c.br != nil {
newSize := int(1024 * c.bufioMultipleTimesPayloadSize)
if c.br.Size() != newSize {
// TODO sync.Pool管理
(*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize))
}
// bufio 模式才会使用payload
payload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
op := f.Opcode
if c.fragmentFrameHeader != nil {
op = c.fragmentFrameHeader.Opcode
}

for {
rsv1 := f.GetRsv1()
// 检查Rsv1 rsv2 Rfd, errsv3
if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() {
err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.compression)
return c.writeErrAndOnClose(ProtocolError, err)
}

// 从网络读取数据
f, err = c.readDataFromNet(&headArray, &payload)
if err != nil {
return err
}
fin := f.GetFin()
if c.fragmentFrameHeader != nil && !f.Opcode.IsControl() {
if f.Opcode == 0 {
c.fragmentFramePayload = append(c.fragmentFramePayload, f.Payload...)

op = f.Opcode
if fragmentFrameHeader != nil {
op = fragmentFrameHeader.Opcode
}
// 分段的在这返回
if fin {
// 解压缩
if c.fragmentFrameHeader.GetRsv1() && c.decompression {
tempBuf, err := decode(c.fragmentFramePayload)
if err != nil {
return err
}
c.fragmentFramePayload = tempBuf
}
// 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析
// TODO c.utf8Check 修改成流式解析
if c.fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(c.fragmentFramePayload) {
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}

rsv1 := f.GetRsv1()
// 检查Rsv1 rsv2 Rsv3
if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() {
err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.compression)
return c.writeErrAndOnClose(ProtocolError, err)
c.Callback.OnMessage(c, c.fragmentFrameHeader.Opcode, c.fragmentFramePayload)
c.fragmentFramePayload = c.fragmentFramePayload[0:0]
c.fragmentFrameHeader = nil
}
return nil
}

fin := f.GetFin()
if fragmentFrameHeader != nil && !f.Opcode.IsControl() {
if f.Opcode == 0 {
fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...)

// 分段的在这返回
if fin {
// 解压缩
if fragmentFrameHeader.GetRsv1() && c.decompression {
tempBuf, err := decode(fragmentFrameBuf)
if err != nil {
return err
}
fragmentFrameBuf = tempBuf
}
// 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析
// TODO c.utf8Check 修改成流式解析
if fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(fragmentFrameBuf) {
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode)
return ErrFrameOpcode
}

c.Callback.OnMessage(c, fragmentFrameHeader.Opcode, fragmentFrameBuf)
fragmentFrameBuf = fragmentFrameBuf[0:0]
fragmentFrameHeader = nil
}
continue
if f.Opcode == opcode.Text || f.Opcode == opcode.Binary {
if !fin {
prevFrame := f.FrameHeader
// 第一次分段
if len(c.fragmentFramePayload) == 0 {
c.fragmentFramePayload = append(c.fragmentFramePayload, f.Payload...)
f.Payload = nil
}

c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode)
return ErrFrameOpcode
// 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃
c.fragmentFrameHeader = &prevFrame
return
}

if f.Opcode == opcode.Text || f.Opcode == opcode.Binary {
if !fin {
prevFrame := f.FrameHeader
// 第一次分段
if len(fragmentFrameBuf) == 0 {
fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...)
f.Payload = nil
}

// 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃
fragmentFrameHeader = &prevFrame
continue
if rsv1 && c.decompression {
// 不分段的解压缩
f.Payload, err = decode(f.Payload)
if err != nil {
return err
}
}

if rsv1 && c.decompression {
// 不分段的解压缩
f.Payload, err = decode(f.Payload)
if err != nil {
return err
}
if f.Opcode == opcode.Text {
if !c.utf8Check(f.Payload) {
c.c.Close()
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
}

if f.Opcode == opcode.Text {
if !c.utf8Check(f.Payload) {
c.c.Close()
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
return
}

c.Callback.OnMessage(c, f.Opcode, f.Payload)
continue
if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong {
// 对方发的控制消息太大
if f.PayloadLen > maxControlFrameSize {
c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize)
return ErrMaxControlFrameSize
}
// Close, Ping, Pong 不能分片
if !fin {
c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented)
return ErrNOTBeFragmented
}

if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong {
// 对方发的控制消息太大
if f.PayloadLen > maxControlFrameSize {
c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize)
return ErrMaxControlFrameSize
}
// Close, Ping, Pong 不能分片
if !fin {
c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented)
return ErrNOTBeFragmented
if f.Opcode == Close {
if len(f.Payload) == 0 {
return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall)
}

if f.Opcode == Close {
if len(f.Payload) == 0 {
return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall)
}

if len(f.Payload) < 2 {
return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall)
}

if !c.utf8Check(f.Payload[2:]) {
return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8)
}
if len(f.Payload) < 2 {
return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall)
}

code := binary.BigEndian.Uint16(f.Payload)
if !validCode(code) {
return c.writeErrAndOnClose(ProtocolError, ErrCloseValue)
}
if !c.utf8Check(f.Payload[2:]) {
return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8)
}

// 回敬一个close包
if err := c.WriteTimeout(Close, f.Payload, 2*time.Second); err != nil {
return err
}
code := binary.BigEndian.Uint16(f.Payload)
if !validCode(code) {
return c.writeErrAndOnClose(ProtocolError, ErrCloseValue)
}

err = bytesToCloseErrMsg(f.Payload)
c.Callback.OnClose(c, err)
// 回敬一个close包
if err := c.WriteTimeout(Close, f.Payload, 2*time.Second); err != nil {
return err
}

if f.Opcode == Ping {
// 回一个pong包
if c.replyPing {
if err := c.WriteTimeout(Pong, f.Payload, 2*time.Second); err != nil {
c.Callback.OnClose(c, err)
return err
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
continue
}
}
err = bytesToCloseErrMsg(f.Payload)
c.Callback.OnClose(c, err)
return err
}

if f.Opcode == Pong && c.ignorePong {
continue
if f.Opcode == Ping {
// 回一个pong包
if c.replyPing {
if err := c.WriteTimeout(Pong, f.Payload, 2*time.Second); err != nil {
c.Callback.OnClose(c, err)
return err
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
return
}
}

c.Callback.OnMessage(c, f.Opcode, nil)
continue
if f.Opcode == Pong && c.ignorePong {
return
}
// 检查Opcode
c.writeErrAndOnClose(ProtocolError, ErrOpcode)
return ErrOpcode

c.Callback.OnMessage(c, f.Opcode, nil)
return
}
// 检查Opcode
c.writeErrAndOnClose(ProtocolError, ErrOpcode)
return ErrOpcode
}

type wrapBuffer struct {
Expand Down
4 changes: 0 additions & 4 deletions server_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ var (
strWebSocketKey = "Sec-WebSocket-Key"
)

type ConnOption struct {
Config
}

func writeHeaderVal(w io.Writer, val []byte) (err error) {
if _, err = w.Write(val); err != nil {
return
Expand Down
4 changes: 4 additions & 0 deletions server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ package quickws

type ServerOption func(*ConnOption)

type ConnOption struct {
Config
}

// 1.配置压缩和解压缩
func WithServerDecompressAndCompress() ServerOption {
return func(o *ConnOption) {
Expand Down

0 comments on commit 52687c9

Please sign in to comment.