From 03704c6a4c5991a1a799af81b9226c9dcb6626eb Mon Sep 17 00:00:00 2001 From: jason yang Date: Wed, 8 May 2024 17:17:20 +0900 Subject: [PATCH] refactor oras download code Signed-off-by: jason yang --- CHANGELOG.md | 5 ++ internal/pkg/client/oras/oras.go | 19 ++++-- internal/pkg/client/oras/pull.go | 10 +-- internal/pkg/client/progress.go | 11 ++++ internal/pkg/client/progress_roundtrip.go | 77 +++++++++++++---------- 5 files changed, 75 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a777af2a6..bc0d800df5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ The Singularity Project has been and re-branded as Apptainer. For older changes see the [archived Singularity change log](https://github.com/apptainer/singularity/blob/release-3.8/CHANGELOG.md). +## Changes for v1.3.x + +- Fixed the issue that oras download progress bar gets stuck + when downloading large images. + ## v1.3.1 - \[2024-04-24\] - Make 'apptainer build' work with signed Docker containers. diff --git a/internal/pkg/client/oras/oras.go b/internal/pkg/client/oras/oras.go index d373f305c9..8ad40c2ed3 100644 --- a/internal/pkg/client/oras/oras.go +++ b/internal/pkg/client/oras/oras.go @@ -33,9 +33,11 @@ import ( ) // DownloadImage downloads a SIF image specified by an oci reference to a file using the included credentials -func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) error { - im, err := remoteImage(ref, ociAuth, noHTTPS, pb) +func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) error { + rt := client.NewRoundTripper(ctx, nil) + im, err := remoteImage(ref, ociAuth, noHTTPS, rt) if err != nil { + rt.ProgressShutdown() return err } @@ -47,6 +49,7 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock // manifest, err := im.Manifest() if err != nil { + rt.ProgressShutdown() return err } if len(manifest.Layers) != 1 { @@ -55,12 +58,14 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock layer := manifest.Layers[0] if layer.MediaType != SifLayerMediaTypeV1 && layer.MediaType != SifLayerMediaTypeProto { + rt.ProgressShutdown() return fmt.Errorf("invalid layer mediatype: %s", layer.MediaType) } // Retrieve image to a temporary OCI layout tmpDir, err := os.MkdirTemp("", "oras-tmp-") if err != nil { + rt.ProgressShutdown() return err } defer func() { @@ -70,12 +75,17 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock }() tmpLayout, err := layout.Write(tmpDir, empty.Index) if err != nil { + rt.ProgressShutdown() return err } if err := tmpLayout.AppendImage(im); err != nil { + rt.ProgressShutdown() return err } + rt.ProgressComplete() + rt.ProgressWait() + // Copy SIF blob out from layout to final location blob, err := tmpLayout.Blob(layer.Digest) if err != nil { @@ -235,7 +245,7 @@ func sha256sum(r io.Reader) (result string, nBytes int64, err error) { } // remoteImage returns a v1.Image for the provided remote ref. -func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) (v1.Image, error) { +func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, rt *client.RoundTripper) (v1.Image, error) { ref = strings.TrimPrefix(ref, "oras://") ref = strings.TrimPrefix(ref, "//") @@ -249,8 +259,7 @@ func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, p return nil, fmt.Errorf("invalid reference %q: %w", ref, err) } remoteOpts := []remote.Option{AuthOptn(ociAuth)} - if pb != nil { - rt := client.NewRoundTripper(nil, pb) + if rt != nil { remoteOpts = append(remoteOpts, remote.WithTransport(rt)) } im, err := remote.Image(ir, remoteOpts...) diff --git a/internal/pkg/client/oras/pull.go b/internal/pkg/client/oras/pull.go index 58edc0a154..1d19fc01f8 100644 --- a/internal/pkg/client/oras/pull.go +++ b/internal/pkg/client/oras/pull.go @@ -15,11 +15,9 @@ import ( "os" "github.com/apptainer/apptainer/internal/pkg/cache" - "github.com/apptainer/apptainer/internal/pkg/client" "github.com/apptainer/apptainer/internal/pkg/util/fs" "github.com/apptainer/apptainer/pkg/sylog" ocitypes "github.com/containers/image/v5/types" - "golang.org/x/term" ) // pull will pull an oras image into the cache if directTo="", or a specific file if directTo is set. @@ -29,13 +27,9 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string return "", fmt.Errorf("failed to get checksum for %s: %s", pullFrom, err) } - var pb *client.DownloadProgressBar - if term.IsTerminal(2) { - pb = &client.DownloadProgressBar{} - } if directTo != "" { sylog.Infof("Downloading oras image") - if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS, pb); err != nil { + if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS); err != nil { return "", fmt.Errorf("unable to Download Image: %v", err) } imagePath = directTo @@ -49,7 +43,7 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string if !cacheEntry.Exists { sylog.Infof("Downloading oras image") - if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS, pb); err != nil { + if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS); err != nil { return "", fmt.Errorf("unable to Download Image: %v", err) } if cacheFileHash, err := ImageHash(cacheEntry.TmpPath); err != nil { diff --git a/internal/pkg/client/progress.go b/internal/pkg/client/progress.go index 79cfc27479..38ff2ac7ea 100644 --- a/internal/pkg/client/progress.go +++ b/internal/pkg/client/progress.go @@ -18,6 +18,17 @@ import ( "github.com/vbauerster/mpb/v8/decor" ) +var defaultOption = []mpb.BarOption{ + mpb.PrependDecorators( + decor.Counters(decor.SizeB1024(0), "%.1f / %.1f"), + ), + mpb.AppendDecorators( + decor.Percentage(), + decor.AverageSpeed(decor.SizeB1024(0), " % .1f "), + decor.AverageETA(decor.ET_STYLE_GO), + ), +} + func initProgressBar(totalSize int64) (*mpb.Progress, *mpb.Bar) { p := mpb.New() diff --git a/internal/pkg/client/progress_roundtrip.go b/internal/pkg/client/progress_roundtrip.go index b00a128ab1..aa292d6c0c 100644 --- a/internal/pkg/client/progress_roundtrip.go +++ b/internal/pkg/client/progress_roundtrip.go @@ -11,65 +11,74 @@ package client import ( - "io" + "context" "net/http" + + "github.com/apptainer/apptainer/pkg/sylog" + "github.com/vbauerster/mpb/v8" + "golang.org/x/term" ) -const contentSizeThreshold = 1024 +const contentSizeThreshold = 64 * 1024 type RoundTripper struct { inner http.RoundTripper - pb *DownloadProgressBar + p *mpb.Progress + bars []*mpb.Bar + sizes []int64 } -func NewRoundTripper(inner http.RoundTripper, pb *DownloadProgressBar) *RoundTripper { +func NewRoundTripper(ctx context.Context, inner http.RoundTripper) *RoundTripper { if inner == nil { inner = http.DefaultTransport } rt := RoundTripper{ inner: inner, - pb: pb, + } + + if term.IsTerminal(2) && sylog.GetLevel() >= 0 { + rt.p = mpb.NewWithContext(ctx) } return &rt } -type rtReadCloser struct { - inner io.ReadCloser - pb *DownloadProgressBar -} +func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if t.p == nil || req.Method != http.MethodGet { + return t.inner.RoundTrip(req) + } -func (r *rtReadCloser) Read(p []byte) (int, error) { - return r.inner.Read(p) + resp, err := t.inner.RoundTrip(req) + if resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold { + bar := t.p.AddBar(resp.ContentLength, defaultOption...) + t.bars = append(t.bars, bar) + t.sizes = append(t.sizes, resp.ContentLength) + resp.Body = bar.ProxyReader(resp.Body) + } + return resp, err } -func (r *rtReadCloser) Close() error { - err := r.inner.Close() - if err == nil { - r.pb.Wait() - } else { - r.pb.Abort(false) +// ProgressComplete overrides all progress bars, setting them to 100% complete. +func (t *RoundTripper) ProgressComplete() { + if t.p != nil { + for i, bar := range t.bars { + bar.SetCurrent(t.sizes[i]) + } } - - return err } -func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if t.pb != nil && req.Body != nil && req.ContentLength >= contentSizeThreshold { - t.pb.Init(req.ContentLength) - req.Body = &rtReadCloser{ - inner: t.pb.bar.ProxyReader(req.Body), - pb: t.pb, - } +// ProgressWait shuts down the mpb Progress container by waiting for all bars to +// complete. +func (t *RoundTripper) ProgressWait() { + if t.p != nil { + t.p.Wait() } - resp, err := t.inner.RoundTrip(req) - if t.pb != nil && resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold { - t.pb.Init(resp.ContentLength) - resp.Body = &rtReadCloser{ - inner: t.pb.bar.ProxyReader(resp.Body), - pb: t.pb, - } +} + +// ProgressShutdown immediately shuts down the mpb Progress container. +func (t *RoundTripper) ProgressShutdown() { + if t.p != nil { + t.p.Shutdown() } - return resp, err }