diff --git a/deflate/compression_decode.go b/deflate/compression_decode.go index 7a8b622..a97daf5 100644 --- a/deflate/compression_decode.go +++ b/deflate/compression_decode.go @@ -26,37 +26,6 @@ import ( var tailBytes = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff} -// 无上下文-解压缩 -func newDecompressNoContextTakeover(r io.Reader) io.ReadCloser { - fr, _ := flateReaderPool.Get().(io.ReadCloser) - fr.(flate.Resetter).Reset(io.MultiReader(r, bytes.NewReader(tailBytes)), nil) - return &flateReadWrapper{fr} -} - -// 无上下文-解压缩 -func DecompressNoContextTakeover(payload *[]byte) (*[]byte, error) { - pr := bytes.NewReader(*payload) - r := newDecompressNoContextTakeover(pr) - - // 从池里面拿buf, 这里的2是经验值,解压缩之后是2倍的大小 - decodeBuf := bytespool.GetBytes(len(*payload)*2 + enum.MaxFrameHeaderSize) - // 包装下 - out := bytes.NewBuffer((*decodeBuf)[:0]) - // 解压缩 - if _, err := io.Copy(out, r); err != nil { - return nil, err - } - // 拿到解压缩之后的buf - outBytes := out.Bytes() - // 如果解压缩之后的buf和从池里面拿的buf不一样,就把从池里面拿的buf放回去 - if unsafe.SliceData(*decodeBuf) != unsafe.SliceData(outBytes) { - bytespool.PutBytes(decodeBuf) - } - - r.Close() - return &outBytes, nil -} - // 上下文-解压缩 type DeCompressContextTakeover struct { dict historyDict @@ -71,9 +40,14 @@ func NewDecompressContextTakeover(bit uint8) (*DeCompressContextTakeover, error) } // 解压缩 +// d有值时,上下文接管的情况调用 +// d为nil时, 上下文不接管的情况下调用,利用了go,对象为空,调用函数不会panic的特性 func (d *DeCompressContextTakeover) Decompress(payload *[]byte, maxMessage int64) (outBytes2 *[]byte, err error) { // 获取dict - dict := d.dict.GetData() + var dict []byte + if d != nil { + dict = d.dict.GetData() + } // 拿到解码器 rc, _ := flateReaderPool.Get().(io.Reader) @@ -107,8 +81,11 @@ func (d *DeCompressContextTakeover) Decompress(payload *[]byte, maxMessage int64 if unsafe.SliceData(*decodeBuf) != unsafe.SliceData(outBytes) { bytespool.PutBytes(decodeBuf) } - // 写入dict - d.dict.Write(out.Bytes()) + + if d != nil { + // 写入dict + d.dict.Write(out.Bytes()) + } // 返回解压缩之后的buf return &outBytes, nil } diff --git a/deflate/compression_decode_no_context.go b/deflate/compression_decode_no_context.go deleted file mode 100644 index fc7faf0..0000000 --- a/deflate/compression_decode_no_context.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -package deflate - -import "io" - -type flateReadWrapper struct { - fr io.ReadCloser -} - -func (r *flateReadWrapper) Read(p []byte) (int, error) { - if r.fr == nil { - return 0, io.ErrClosedPipe - } - n, err := r.fr.Read(p) - if err == io.EOF { - // Preemptively place the reader back in the pool. This helps with - // scenarios where the application does not call NextReader() soon after - // this final read. - r.Close() - } - return n, err -} - -func (r *flateReadWrapper) Close() error { - if r.fr == nil { - return io.ErrClosedPipe - } - err := r.fr.Close() - flateReaderPool.Put(r.fr) - r.fr = nil - return err -} diff --git a/deflate/compression_decode_test.go b/deflate/compression_decode_test.go index 9f2d271..44cce43 100644 --- a/deflate/compression_decode_test.go +++ b/deflate/compression_decode_test.go @@ -25,6 +25,7 @@ func TestDecompressNoContextTakeover(t *testing.T) { type args struct { payload []byte } + tests := []struct { name string args args diff --git a/deflate/compression_encode_no_context_test.go b/deflate/compression_encode_no_context_test.go index 4d2203b..864991a 100644 --- a/deflate/compression_encode_no_context_test.go +++ b/deflate/compression_encode_no_context_test.go @@ -30,7 +30,8 @@ func Test_compressNoContextTakeover(t *testing.T) { return } - gotDecode, err := DecompressNoContextTakeover(gotEncodeBuf) + var decCtx *DeCompressContextTakeover + gotDecode, err := decCtx.Decompress(gotEncodeBuf, 0) if (err != nil) != tt.wantErr { t.Errorf("decompressNoContextTakeover() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/deflate/compression_encode_test.go b/deflate/compression_encode_test.go index 279a45a..643a21e 100644 --- a/deflate/compression_encode_test.go +++ b/deflate/compression_encode_test.go @@ -50,7 +50,8 @@ func TestEnCompressContextTakeover_Compress(t *testing.T) { return } - decodePayload, err := DecompressNoContextTakeover(gotEncodePayload) + var decCtx *DeCompressContextTakeover + decodePayload, err := decCtx.Decompress(gotEncodePayload, 0) if (err != nil) != tt.wantErr { t.Errorf("CompressContextTakeover.Compress() error = %v, wantErr %v", err, tt.wantErr) return