@@ -2,21 +2,33 @@ package onedrive_sharelink
22
33import (
44 "context"
5+ "fmt"
6+ "io"
7+ "net/http"
58 "strings"
9+ "sync"
610 "time"
711
812 "github.com/OpenListTeam/OpenList/v4/internal/driver"
913 "github.com/OpenListTeam/OpenList/v4/internal/errs"
1014 "github.com/OpenListTeam/OpenList/v4/internal/model"
15+ "github.com/OpenListTeam/OpenList/v4/internal/net"
1116 "github.com/OpenListTeam/OpenList/v4/pkg/cron"
17+ "github.com/OpenListTeam/OpenList/v4/pkg/http_range"
18+ "github.com/OpenListTeam/OpenList/v4/pkg/singleflight"
1219 "github.com/OpenListTeam/OpenList/v4/pkg/utils"
1320 log "github.com/sirupsen/logrus"
1421)
1522
23+ const headerTTL = 25 * time .Minute
24+
1625type OnedriveSharelink struct {
1726 model.Storage
1827 cron * cron.Cron
1928 Addition
29+
30+ headerMu sync.RWMutex
31+ sg singleflight.Group [http.Header ]
2032}
2133
2234func (d * OnedriveSharelink ) Config () driver.Config {
@@ -38,17 +50,20 @@ func (d *OnedriveSharelink) Init(ctx context.Context) error {
3850 d .cron = cron .NewCron (time .Hour * 1 )
3951 d .cron .Do (func () {
4052 var err error
41- d . Headers , err = d .getHeaders (ctx )
53+ h , err : = d .getHeaders (ctx )
4254 if err != nil {
4355 log .Errorf ("%+v" , err )
56+ return
4457 }
58+ d .storeHeaders (h )
4559 })
4660
4761 // Get initial headers
48- d . Headers , err = d .getHeaders (ctx )
62+ h , err : = d .getHeaders (ctx )
4963 if err != nil {
5064 return err
5165 }
66+ d .storeHeaders (h )
5267
5368 return nil
5469}
@@ -76,21 +91,18 @@ func (d *OnedriveSharelink) Link(ctx context.Context, file model.Obj, args model
7691 // Cut the first char and the last char
7792 uniqueId = uniqueId [1 : len (uniqueId )- 1 ]
7893 url := d .downloadLinkPrefix + uniqueId
79- header := d .Headers
8094
81- // If the headers are older than 30 minutes, get new headers
82- if d .HeaderTime < time .Now ().Unix ()- 1800 {
83- var err error
84- log .Debug ("headers are older than 30 minutes, get new headers" )
85- header , err = d .getHeaders (ctx )
86- if err != nil {
87- return nil , err
88- }
95+ header , err := d .getValidHeaders (ctx )
96+ if err != nil {
97+ return nil , err
8998 }
9099
91100 return & model.Link {
92101 URL : url ,
93102 Header : header ,
103+ RangeReader : rangeReaderFunc (func (ctx context.Context , hr http_range.Range ) (io.ReadCloser , error ) {
104+ return d .rangeReadWithRefresh (ctx , url , hr )
105+ }),
94106 }, nil
95107}
96108
@@ -129,3 +141,102 @@ func (d *OnedriveSharelink) Put(ctx context.Context, dstDir model.Obj, stream mo
129141//}
130142
131143var _ driver.Driver = (* OnedriveSharelink )(nil )
144+
145+ // rangeReadWithRefresh tries once with current headers, and if the response
146+ // looks invalid (error status or html login page), it refreshes headers and retries.
147+ func (d * OnedriveSharelink ) rangeReadWithRefresh (ctx context.Context , url string , hr http_range.Range ) (io.ReadCloser , error ) {
148+ tryOnce := func (header http.Header ) (io.ReadCloser , error ) {
149+ h := cloneHeader (header )
150+ if h == nil {
151+ h = http.Header {}
152+ }
153+ h = http_range .ApplyRangeToHttpHeader (hr , h )
154+ resp , err := net .RequestHttp (ctx , http .MethodGet , h , url )
155+ if err != nil {
156+ return nil , err
157+ }
158+ ct := strings .ToLower (resp .Header .Get ("Content-Type" ))
159+ if strings .Contains (ct , "text/html" ) {
160+ _ = resp .Body .Close ()
161+ return nil , fmt .Errorf ("unexpected html response" )
162+ }
163+ return resp .Body , nil
164+ }
165+
166+ header , err := d .getValidHeaders (ctx )
167+ if err != nil {
168+ return nil , err
169+ }
170+ if body , err := tryOnce (header ); err == nil {
171+ return body , nil
172+ }
173+
174+ // refresh and retry once
175+ header , err = d .refreshHeaders (ctx )
176+ if err != nil {
177+ return nil , err
178+ }
179+ return tryOnce (header )
180+ }
181+
182+ type rangeReaderFunc func (ctx context.Context , hr http_range.Range ) (io.ReadCloser , error )
183+
184+ func (f rangeReaderFunc ) RangeRead (ctx context.Context , hr http_range.Range ) (io.ReadCloser , error ) {
185+ return f (ctx , hr )
186+ }
187+
188+ func cloneHeader (header http.Header ) http.Header {
189+ if header == nil {
190+ return nil
191+ }
192+ return header .Clone ()
193+ }
194+
195+ func (d * OnedriveSharelink ) headerSnapshot () http.Header {
196+ d .headerMu .RLock ()
197+ defer d .headerMu .RUnlock ()
198+ return cloneHeader (d .Headers )
199+ }
200+
201+ func (d * OnedriveSharelink ) storeHeaders (header http.Header ) {
202+ if header == nil {
203+ return
204+ }
205+ d .headerMu .Lock ()
206+ d .Headers = header
207+ d .HeaderTime = time .Now ().Unix ()
208+ d .headerMu .Unlock ()
209+ }
210+
211+ func (d * OnedriveSharelink ) headersExpired () bool {
212+ d .headerMu .RLock ()
213+ defer d .headerMu .RUnlock ()
214+ return time .Since (time .Unix (d .HeaderTime , 0 )) > headerTTL
215+ }
216+
217+ func (d * OnedriveSharelink ) refreshHeaders (ctx context.Context ) (http.Header , error ) {
218+ header , err , _ := d .sg .Do ("refresh" , func () (http.Header , error ) {
219+ h , e := d .getHeaders (ctx )
220+ if e != nil {
221+ return nil , e
222+ }
223+ d .storeHeaders (h )
224+ return h , nil
225+ })
226+ return header , err
227+ }
228+
229+ func (d * OnedriveSharelink ) getValidHeaders (ctx context.Context ) (http.Header , error ) {
230+ if h := d .headerSnapshot (); h != nil && ! d .headersExpired () {
231+ return h , nil
232+ }
233+ h , err := d .refreshHeaders (ctx )
234+ if err != nil {
235+ if h2 := d .headerSnapshot (); h2 != nil {
236+ log .Warnf ("onedrive_sharelink: use cached headers after refresh failure: %+v" , err )
237+ return h2 , nil
238+ }
239+ return nil , err
240+ }
241+ return h , nil
242+ }
0 commit comments