Skip to content

Commit 174eae8

Browse files
authored
perf(stream): optimize CacheFullAndWriter for better memory management (#1584)
* perf(stream): optimize CacheFullAndWriter for better memory management
* fix(stream): ensure proper seek handling in CacheFullAndWriter for improved data integrity
1 parent b9f058f commit 174eae8

File tree

3 files changed

+108
-68
lines changed

3 files changed

+108
-68
lines changed

internal/model/obj.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ type FileStreamer interface {
4848
// for a non-seekable Stream, if Read is called, this function won't work.
4949
// caches the full Stream and writes it to writer (if provided, even if the stream is already cached).
5050
CacheFullAndWriter(up *UpdateProgress, writer io.Writer) (File, error)
51-
SetTmpFile(file File)
5251
// if the Stream is not a File and is not cached, returns nil.
5352
GetFile() File
5453
}

internal/stream/stream.go

Lines changed: 73 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ type FileStream struct {
2828
ForceStreamUpload bool
2929
Exist model.Obj //the file existed in the destination, we can reuse some info since we will overwrite it
3030
utils.Closers
31-
32-
tmpFile model.File //if present, tmpFile has full content, it will be deleted at last
3331
peekBuff *buffer.Reader
3432
size int64
3533
oriReader io.Reader // the original reader, used for caching
@@ -39,12 +37,6 @@ func (f *FileStream) GetSize() int64 {
3937
if f.size > 0 {
4038
return f.size
4139
}
42-
if file, ok := f.tmpFile.(*os.File); ok {
43-
info, err := file.Stat()
44-
if err == nil {
45-
return info.Size()
46-
}
47-
}
4840
return f.Obj.GetSize()
4941
}
5042

@@ -71,14 +63,13 @@ func (f *FileStream) Close() error {
7163
if errors.Is(err1, os.ErrClosed) {
7264
err1 = nil
7365
}
74-
if file, ok := f.tmpFile.(*os.File); ok {
66+
if file, ok := f.Reader.(*os.File); ok {
7567
err2 = os.RemoveAll(file.Name())
7668
if err2 != nil {
7769
err2 = errs.NewErr(err2, "failed to remove tmpFile [%s]", file.Name())
78-
} else {
79-
f.tmpFile = nil
8070
}
8171
}
72+
f.Reader = nil
8273

8374
return errors.Join(err1, err2)
8475
}
@@ -94,50 +85,50 @@ func (f *FileStream) SetExist(obj model.Obj) {
9485
// It's not thread-safe!
9586
func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writer) (model.File, error) {
9687
if cache := f.GetFile(); cache != nil {
88+
_, err := cache.Seek(0, io.SeekStart)
89+
if err != nil {
90+
return nil, err
91+
}
9792
if writer == nil {
9893
return cache, nil
9994
}
100-
_, err := cache.Seek(0, io.SeekStart)
101-
if err == nil {
102-
reader := f.Reader
103-
if up != nil {
104-
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
105-
*up = model.UpdateProgressWithRange(*up, 50, 100)
106-
reader = &ReaderUpdatingProgress{
107-
Reader: &SimpleReaderWithSize{
108-
Reader: reader,
109-
Size: f.GetSize(),
110-
},
111-
UpdateProgress: cacheProgress,
112-
}
113-
}
114-
_, err = utils.CopyWithBuffer(writer, reader)
115-
if err == nil {
116-
_, err = cache.Seek(0, io.SeekStart)
95+
reader := f.Reader
96+
if up != nil {
97+
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
98+
*up = model.UpdateProgressWithRange(*up, 50, 100)
99+
reader = &ReaderUpdatingProgress{
100+
Reader: &SimpleReaderWithSize{
101+
Reader: reader,
102+
Size: f.GetSize(),
103+
},
104+
UpdateProgress: cacheProgress,
117105
}
118106
}
107+
_, err = utils.CopyWithBuffer(writer, reader)
108+
if err == nil {
109+
_, err = cache.Seek(0, io.SeekStart)
110+
}
119111
if err != nil {
120112
return nil, err
121113
}
122114
return cache, nil
123115
}
124116

125117
reader := f.Reader
126-
if up != nil {
127-
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
128-
*up = model.UpdateProgressWithRange(*up, 50, 100)
129-
reader = &ReaderUpdatingProgress{
130-
Reader: &SimpleReaderWithSize{
131-
Reader: reader,
132-
Size: f.GetSize(),
133-
},
134-
UpdateProgress: cacheProgress,
118+
if f.peekBuff != nil {
119+
f.peekBuff.Seek(0, io.SeekStart)
120+
if writer != nil {
121+
_, err := utils.CopyWithBuffer(writer, f.peekBuff)
122+
if err != nil {
123+
return nil, err
124+
}
125+
f.peekBuff.Seek(0, io.SeekStart)
135126
}
127+
reader = f.oriReader
136128
}
137129
if writer != nil {
138130
reader = io.TeeReader(reader, writer)
139131
}
140-
141132
if f.GetSize() < 0 {
142133
if f.peekBuff == nil {
143134
f.peekBuff = &buffer.Reader{}
@@ -174,7 +165,6 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ
174165
}
175166
}
176167
}
177-
178168
tmpF, err := utils.CreateTempFile(reader, 0)
179169
if err != nil {
180170
return nil, err
@@ -191,14 +181,33 @@ func (f *FileStream) CacheFullAndWriter(up *model.UpdateProgress, writer io.Writ
191181
return peekF, nil
192182
}
193183

194-
f.Reader = reader
184+
if up != nil {
185+
cacheProgress := model.UpdateProgressWithRange(*up, 0, 50)
186+
*up = model.UpdateProgressWithRange(*up, 50, 100)
187+
size := f.GetSize()
188+
if f.peekBuff != nil {
189+
peekSize := f.peekBuff.Size()
190+
cacheProgress(float64(peekSize) / float64(size) * 100)
191+
size -= peekSize
192+
}
193+
reader = &ReaderUpdatingProgress{
194+
Reader: &SimpleReaderWithSize{
195+
Reader: reader,
196+
Size: size,
197+
},
198+
UpdateProgress: cacheProgress,
199+
}
200+
}
201+
202+
if f.peekBuff != nil {
203+
f.oriReader = reader
204+
} else {
205+
f.Reader = reader
206+
}
195207
return f.cache(f.GetSize())
196208
}
197209

198210
func (f *FileStream) GetFile() model.File {
199-
if f.tmpFile != nil {
200-
return f.tmpFile
201-
}
202211
if file, ok := f.Reader.(model.File); ok {
203212
return file
204213
}
@@ -234,12 +243,29 @@ func (f *FileStream) RangeRead(httpRange http_range.Range) (io.Reader, error) {
234243

235244
func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
236245
if maxCacheSize > int64(conf.MaxBufferLimit) {
237-
tmpF, err := utils.CreateTempFile(f.Reader, f.GetSize())
246+
size := f.GetSize()
247+
reader := f.Reader
248+
if f.peekBuff != nil {
249+
size -= f.peekBuff.Size()
250+
reader = f.oriReader
251+
}
252+
tmpF, err := utils.CreateTempFile(reader, size)
238253
if err != nil {
239254
return nil, err
240255
}
256+
if f.peekBuff != nil {
257+
f.Add(utils.CloseFunc(func() error {
258+
return errors.Join(tmpF.Close(), os.RemoveAll(tmpF.Name()))
259+
}))
260+
peekF, err := buffer.NewPeekFile(f.peekBuff, tmpF)
261+
if err != nil {
262+
return nil, err
263+
}
264+
f.Reader = peekF
265+
return peekF, nil
266+
}
267+
241268
f.Add(tmpF)
242-
f.tmpFile = tmpF
243269
f.Reader = tmpF
244270
return tmpF, nil
245271
}
@@ -248,7 +274,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
248274
f.peekBuff = &buffer.Reader{}
249275
f.oriReader = f.Reader
250276
}
251-
bufSize := maxCacheSize - int64(f.peekBuff.Size())
277+
bufSize := maxCacheSize - f.peekBuff.Size()
252278
var buf []byte
253279
if conf.MmapThreshold > 0 && bufSize >= int64(conf.MmapThreshold) {
254280
m, err := mmap.Alloc(int(bufSize))
@@ -267,7 +293,7 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
267293
return nil, fmt.Errorf("failed to read all data: (expect =%d, actual =%d) %w", bufSize, n, err)
268294
}
269295
f.peekBuff.Append(buf)
270-
if int64(f.peekBuff.Size()) >= f.GetSize() {
296+
if f.peekBuff.Size() >= f.GetSize() {
271297
f.Reader = f.peekBuff
272298
f.oriReader = nil
273299
} else {
@@ -276,12 +302,6 @@ func (f *FileStream) cache(maxCacheSize int64) (model.File, error) {
276302
return f.peekBuff, nil
277303
}
278304

279-
func (f *FileStream) SetTmpFile(file model.File) {
280-
f.AddIfCloser(file)
281-
f.tmpFile = file
282-
f.Reader = file
283-
}
284-
285305
var _ model.FileStreamer = (*SeekableStream)(nil)
286306
var _ model.FileStreamer = (*FileStream)(nil)
287307

internal/stream/stream_test.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ import (
77
"io"
88
"testing"
99

10-
"github.com/OpenListTeam/OpenList/v4/internal/conf"
1110
"github.com/OpenListTeam/OpenList/v4/internal/model"
1211
"github.com/OpenListTeam/OpenList/v4/pkg/http_range"
12+
"github.com/OpenListTeam/OpenList/v4/pkg/utils"
1313
)
1414

1515
func TestFileStream_RangeRead(t *testing.T) {
16-
conf.MaxBufferLimit = 16 * 1024 * 1024
1716
type args struct {
1817
httpRange http_range.Range
1918
}
@@ -73,16 +72,38 @@ func TestFileStream_RangeRead(t *testing.T) {
7372
}
7473
})
7574
}
76-
t.Run("after", func(t *testing.T) {
77-
if f.GetFile() == nil {
78-
t.Error("not cached")
79-
}
80-
buf2 := make([]byte, len(buf))
81-
if _, err := io.ReadFull(f, buf2); err != nil {
82-
t.Errorf("FileStream.Read() error = %v", err)
83-
}
84-
if !bytes.Equal(buf, buf2) {
85-
t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
86-
}
87-
})
75+
if f.GetFile() == nil {
76+
t.Error("not cached")
77+
}
78+
buf2 := make([]byte, len(buf))
79+
if _, err := io.ReadFull(f, buf2); err != nil {
80+
t.Errorf("FileStream.Read() error = %v", err)
81+
}
82+
if !bytes.Equal(buf, buf2) {
83+
t.Errorf("FileStream.Read() = %s, want %s", buf2, buf)
84+
}
85+
}
86+
87+
func TestFileStream_With_PreHash(t *testing.T) {
88+
buf := []byte("github.com/OpenListTeam/OpenList")
89+
f := &FileStream{
90+
Obj: &model.Object{
91+
Size: int64(len(buf)),
92+
},
93+
Reader: io.NopCloser(bytes.NewReader(buf)),
94+
}
95+
96+
const hashSize int64 = 20
97+
reader, _ := f.RangeRead(http_range.Range{Start: 0, Length: hashSize})
98+
preHash, _ := utils.HashReader(utils.SHA1, reader)
99+
if preHash == "" {
100+
t.Error("preHash is empty")
101+
}
102+
tmpF, fullHash, _ := CacheFullAndHash(f, nil, utils.SHA1)
103+
fmt.Println(fullHash)
104+
fileFullHash, _ := utils.HashFile(utils.SHA1, tmpF)
105+
fmt.Println(fileFullHash)
106+
if fullHash != fileFullHash {
107+
t.Errorf("fullHash and fileFullHash should match: fullHash=%s fileFullHash=%s", fullHash, fileFullHash)
108+
}
88109
}

0 commit comments

Comments (0)