Skip to content

Commit

Permalink
refactor oras download code
Browse files Browse the repository at this point in the history
Signed-off-by: jason yang <jasonyangshadow@gmail.com>
  • Loading branch information
JasonYangShadow committed May 9, 2024
1 parent 823ca7c commit 03704c6
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 47 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ The Singularity Project has been
and re-branded as Apptainer.
For older changes see the [archived Singularity change log](https://github.com/apptainer/singularity/blob/release-3.8/CHANGELOG.md).

## Changes for v1.3.x

- Fixed the issue that oras download progress bar gets stuck
when downloading large images.

## v1.3.1 - \[2024-04-24\]

- Make 'apptainer build' work with signed Docker containers.
Expand Down
19 changes: 14 additions & 5 deletions internal/pkg/client/oras/oras.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ import (
)

// DownloadImage downloads a SIF image specified by an oci reference to a file using the included credentials
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) error {
im, err := remoteImage(ref, ociAuth, noHTTPS, pb)
func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool) error {
rt := client.NewRoundTripper(ctx, nil)
im, err := remoteImage(ref, ociAuth, noHTTPS, rt)
if err != nil {
rt.ProgressShutdown()
return err
}

Expand All @@ -47,6 +49,7 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock
//
manifest, err := im.Manifest()
if err != nil {
rt.ProgressShutdown()
return err
}
if len(manifest.Layers) != 1 {
Expand All @@ -55,12 +58,14 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock
layer := manifest.Layers[0]
if layer.MediaType != SifLayerMediaTypeV1 &&
layer.MediaType != SifLayerMediaTypeProto {
rt.ProgressShutdown()
return fmt.Errorf("invalid layer mediatype: %s", layer.MediaType)
}

// Retrieve image to a temporary OCI layout
tmpDir, err := os.MkdirTemp("", "oras-tmp-")
if err != nil {
rt.ProgressShutdown()
return err
}
defer func() {
Expand All @@ -70,12 +75,17 @@ func DownloadImage(ctx context.Context, path, ref string, ociAuth *ocitypes.Dock
}()
tmpLayout, err := layout.Write(tmpDir, empty.Index)
if err != nil {
rt.ProgressShutdown()
return err
}
if err := tmpLayout.AppendImage(im); err != nil {
rt.ProgressShutdown()
return err
}

rt.ProgressComplete()
rt.ProgressWait()

// Copy SIF blob out from layout to final location
blob, err := tmpLayout.Blob(layer.Digest)
if err != nil {
Expand Down Expand Up @@ -235,7 +245,7 @@ func sha256sum(r io.Reader) (result string, nBytes int64, err error) {
}

// remoteImage returns a v1.Image for the provided remote ref.
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, pb *client.DownloadProgressBar) (v1.Image, error) {
func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, rt *client.RoundTripper) (v1.Image, error) {
ref = strings.TrimPrefix(ref, "oras://")
ref = strings.TrimPrefix(ref, "//")

Expand All @@ -249,8 +259,7 @@ func remoteImage(ref string, ociAuth *ocitypes.DockerAuthConfig, noHTTPS bool, p
return nil, fmt.Errorf("invalid reference %q: %w", ref, err)
}
remoteOpts := []remote.Option{AuthOptn(ociAuth)}
if pb != nil {
rt := client.NewRoundTripper(nil, pb)
if rt != nil {
remoteOpts = append(remoteOpts, remote.WithTransport(rt))
}
im, err := remote.Image(ir, remoteOpts...)
Expand Down
10 changes: 2 additions & 8 deletions internal/pkg/client/oras/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@ import (
"os"

"github.com/apptainer/apptainer/internal/pkg/cache"
"github.com/apptainer/apptainer/internal/pkg/client"
"github.com/apptainer/apptainer/internal/pkg/util/fs"
"github.com/apptainer/apptainer/pkg/sylog"
ocitypes "github.com/containers/image/v5/types"
"golang.org/x/term"
)

// pull will pull an oras image into the cache if directTo="", or a specific file if directTo is set.
Expand All @@ -29,13 +27,9 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
return "", fmt.Errorf("failed to get checksum for %s: %s", pullFrom, err)
}

var pb *client.DownloadProgressBar
if term.IsTerminal(2) {
pb = &client.DownloadProgressBar{}
}
if directTo != "" {
sylog.Infof("Downloading oras image")
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS, pb); err != nil {
if err := DownloadImage(ctx, directTo, pullFrom, ociAuth, noHTTPS); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
imagePath = directTo
Expand All @@ -49,7 +43,7 @@ func pull(ctx context.Context, imgCache *cache.Handle, directTo, pullFrom string
if !cacheEntry.Exists {
sylog.Infof("Downloading oras image")

if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS, pb); err != nil {
if err := DownloadImage(ctx, cacheEntry.TmpPath, pullFrom, ociAuth, noHTTPS); err != nil {
return "", fmt.Errorf("unable to Download Image: %v", err)
}
if cacheFileHash, err := ImageHash(cacheEntry.TmpPath); err != nil {
Expand Down
11 changes: 11 additions & 0 deletions internal/pkg/client/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ import (
"github.com/vbauerster/mpb/v8/decor"
)

var defaultOption = []mpb.BarOption{
mpb.PrependDecorators(
decor.Counters(decor.SizeB1024(0), "%.1f / %.1f"),
),
mpb.AppendDecorators(
decor.Percentage(),
decor.AverageSpeed(decor.SizeB1024(0), " % .1f "),
decor.AverageETA(decor.ET_STYLE_GO),
),
}

func initProgressBar(totalSize int64) (*mpb.Progress, *mpb.Bar) {
p := mpb.New()

Expand Down
77 changes: 43 additions & 34 deletions internal/pkg/client/progress_roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,65 +11,74 @@
package client

import (
"io"
"context"
"net/http"

"github.com/apptainer/apptainer/pkg/sylog"
"github.com/vbauerster/mpb/v8"
"golang.org/x/term"
)

const contentSizeThreshold = 1024
const contentSizeThreshold = 64 * 1024

type RoundTripper struct {
inner http.RoundTripper
pb *DownloadProgressBar
p *mpb.Progress
bars []*mpb.Bar
sizes []int64
}

func NewRoundTripper(inner http.RoundTripper, pb *DownloadProgressBar) *RoundTripper {
func NewRoundTripper(ctx context.Context, inner http.RoundTripper) *RoundTripper {
if inner == nil {
inner = http.DefaultTransport
}

rt := RoundTripper{
inner: inner,
pb: pb,
}

if term.IsTerminal(2) && sylog.GetLevel() >= 0 {
rt.p = mpb.NewWithContext(ctx)
}

return &rt
}

type rtReadCloser struct {
inner io.ReadCloser
pb *DownloadProgressBar
}
func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if t.p == nil || req.Method != http.MethodGet {
return t.inner.RoundTrip(req)
}

func (r *rtReadCloser) Read(p []byte) (int, error) {
return r.inner.Read(p)
resp, err := t.inner.RoundTrip(req)
if resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold {
bar := t.p.AddBar(resp.ContentLength, defaultOption...)
t.bars = append(t.bars, bar)
t.sizes = append(t.sizes, resp.ContentLength)
resp.Body = bar.ProxyReader(resp.Body)
}
return resp, err
}

func (r *rtReadCloser) Close() error {
err := r.inner.Close()
if err == nil {
r.pb.Wait()
} else {
r.pb.Abort(false)
// ProgressComplete overrides all progress bars, setting them to 100% complete.
func (t *RoundTripper) ProgressComplete() {
if t.p != nil {
for i, bar := range t.bars {
bar.SetCurrent(t.sizes[i])
}
}

return err
}

func (t *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if t.pb != nil && req.Body != nil && req.ContentLength >= contentSizeThreshold {
t.pb.Init(req.ContentLength)
req.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(req.Body),
pb: t.pb,
}
// ProgressWait shuts down the mpb Progress container by waiting for all bars to
// complete.
func (t *RoundTripper) ProgressWait() {
if t.p != nil {
t.p.Wait()
}
resp, err := t.inner.RoundTrip(req)
if t.pb != nil && resp != nil && resp.Body != nil && resp.ContentLength >= contentSizeThreshold {
t.pb.Init(resp.ContentLength)
resp.Body = &rtReadCloser{
inner: t.pb.bar.ProxyReader(resp.Body),
pb: t.pb,
}
}

// ProgressShutdown immediately shuts down the mpb Progress container.
func (t *RoundTripper) ProgressShutdown() {
if t.p != nil {
t.p.Shutdown()
}
return resp, err
}

0 comments on commit 03704c6

Please sign in to comment.