From 91a836ea89a81e4d3a2df24e904e4b818e62a4ca Mon Sep 17 00:00:00 2001 From: guonaihong Date: Mon, 13 May 2024 21:23:40 +0800 Subject: [PATCH] =?UTF-8?q?+=E6=B5=8B=E8=AF=95=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bytespool/bytes_pool.go | 3 + deflate/compression_decode.go | 2 +- deflate/compression_decode_test.go | 60 +++++++++++++++++ deflate/compression_encode.go | 8 +-- deflate/compression_encode_no_context.go | 4 +- deflate/compression_encode_no_context_test.go | 2 +- deflate/compression_encode_test.go | 64 +++++++++++++++++++ 7 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 deflate/compression_decode_test.go create mode 100644 deflate/compression_encode_test.go diff --git a/bytespool/bytes_pool.go b/bytespool/bytes_pool.go index f1be934..2db4885 100644 --- a/bytespool/bytes_pool.go +++ b/bytespool/bytes_pool.go @@ -77,6 +77,9 @@ func GetBytes(n int) (rv *[]byte) { } func PutBytes(bytes *[]byte) { + if cap(*bytes) == 0 { + return + } if cap(*bytes) < enum.MaxFrameHeaderSize { panic("putBytes: bytes is too small") } diff --git a/deflate/compression_decode.go b/deflate/compression_decode.go index d9073ab..2c123f2 100644 --- a/deflate/compression_decode.go +++ b/deflate/compression_decode.go @@ -33,7 +33,7 @@ func decompressNoContextTakeoverInner(r io.Reader) io.ReadCloser { } // 无上下文-解压缩 -func decompressNoContextTakeover(payload []byte) (*[]byte, error) { +func DecompressNoContextTakeover(payload []byte) (*[]byte, error) { pr := bytes.NewReader(payload) r := decompressNoContextTakeoverInner(pr) diff --git a/deflate/compression_decode_test.go b/deflate/compression_decode_test.go new file mode 100644 index 0000000..9a9e459 --- /dev/null +++ b/deflate/compression_decode_test.go @@ -0,0 +1,60 @@ +// Copyright 2021-2024 antlabs. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package deflate + +import ( + "reflect" + "testing" +) + +func TestDecompressNoContextTakeover(t *testing.T) { + type args struct { + payload []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {name: "测试1", args: args{payload: []byte("hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello")}, want: []byte("hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello"), wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 压缩下一段数据 + gotPayload, err := CompressNoContextTakeover(tt.args.payload, 1) + if (err != nil) != tt.wantErr { + t.Errorf("CompressNoContextTakeover() error = %v, wantErr %v", err, tt.wantErr) + return + } + // 新建上下文解压缩 + de, err := NewDecompressContextTakeover(8) + if (err != nil) != tt.wantErr { + t.Errorf("NewDecompressContextTakeover() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 解压 + gotPayload2, err := de.Decompress(*gotPayload, 0) + if (err != nil) != tt.wantErr { + t.Errorf("Decompress() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(*gotPayload2, tt.want) { + t.Errorf("DecompressNoContextTakeover() = %v, want %v", gotPayload2, tt.want) + } + }) + } +} diff --git a/deflate/compression_encode.go b/deflate/compression_encode.go index 6edb965..8e6a752 100644 --- a/deflate/compression_encode.go +++ b/deflate/compression_encode.go @@ -23,7 +23,7 @@ import ( "github.com/klauspost/compress/flate" ) -type EnCompressContextTakeover struct { +type CompressContextTakeover struct { dict historyDict w *flate.Writer } @@ -38,18 +38,18 @@ func (w *wrapBuffer) Close() error { var enTail = []byte{0, 0, 0xff, 0xff} -func NewEncompressContextTakeover(bit uint8) (en *EnCompressContextTakeover, err error) { +func NewCompressContextTakeover(bit uint8) (en *CompressContextTakeover, err error) { size := 1 << bit w, err := flate.NewWriterWindow(nil, size) if err != nil { return nil, err } - en = &EnCompressContextTakeover{w: w} + en = &CompressContextTakeover{w: w} en.dict.InitHistoryDict(size) return en, nil } -func (e *EnCompressContextTakeover) Compress(payload []byte) (encodePayload *[]byte, err error) { +func (e *CompressContextTakeover) Compress(payload []byte) (encodePayload *[]byte, err error) { encodeBuf := bytespool.GetBytes(len(payload) + enum.MaxFrameHeaderSize) diff --git a/deflate/compression_encode_no_context.go b/deflate/compression_encode_no_context.go index 0748879..450d48f 100644 --- a/deflate/compression_encode_no_context.go +++ b/deflate/compression_encode_no_context.go @@ -24,7 +24,7 @@ var ( const ( minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 maxCompressionLevel = flate.BestCompression - defaultCompressionLevel = 1 + DefaultCompressionLevel = 1 ) var ( @@ -56,7 +56,7 @@ func CompressNoContextTakeover(payload []byte, level int) (encodeBuf *[]byte, er encodeBuf = bytespool.GetBytes(len(payload) + enum.MaxFrameHeaderSize) out := wrapBuffer{Buffer: bytes.NewBuffer((*encodeBuf)[:0])} - w := compressNoContextTakeoverInner(&out, defaultCompressionLevel) + w := compressNoContextTakeoverInner(&out, DefaultCompressionLevel) if _, err = io.Copy(w, bytes.NewReader(payload)); err != nil { return nil, err } diff --git a/deflate/compression_encode_no_context_test.go b/deflate/compression_encode_no_context_test.go index 5558359..c543619 100644 --- a/deflate/compression_encode_no_context_test.go +++ b/deflate/compression_encode_no_context_test.go @@ -30,7 +30,7 @@ func Test_compressNoContextTakeover(t *testing.T) { return } - gotDecode, err := decompressNoContextTakeover(*gotEncodeBuf) + gotDecode, err := DecompressNoContextTakeover(*gotEncodeBuf) if (err != nil) != tt.wantErr { t.Errorf("decompressNoContextTakeover() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/deflate/compression_encode_test.go b/deflate/compression_encode_test.go new file mode 100644 index 0000000..e460de9 --- /dev/null +++ b/deflate/compression_encode_test.go @@ -0,0 +1,64 @@ +// Copyright 2021-2024 antlabs. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package deflate + +import ( + "reflect" + "testing" + + "github.com/klauspost/compress/flate" +) + +func TestEnCompressContextTakeover_Compress(t *testing.T) { + type fields struct { + dict historyDict + w *flate.Writer + } + type args struct { + payload []byte + } + tests := []struct { + name string + fields fields + args args + wantEncodePayload *[]byte + wantErr bool + }{ + {name: "压缩测试1", args: args{payload: []byte("hello world 12345678910")}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e, err := NewCompressContextTakeover(8) + if (err != nil) != tt.wantErr { + t.Errorf("NewCompressContextTakeover() error = %v, wantErr %v", err, tt.wantErr) + return + } + gotEncodePayload, err := e.Compress(tt.args.payload) + if (err != nil) != tt.wantErr { + t.Errorf("CompressContextTakeover.Compress() error = %v, wantErr %v", err, tt.wantErr) + return + } + + decodePayload, err := DecompressNoContextTakeover(*gotEncodePayload) + if (err != nil) != tt.wantErr { + t.Errorf("CompressContextTakeover.Compress() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !reflect.DeepEqual(tt.args.payload, *decodePayload) { + t.Errorf("CompressContextTakeover.Compress() = %v, want %v", tt.args.payload, *decodePayload) + } + }) + } +}