Skip to content

[WIP] Pull progress per layer #74

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions commands/compose.go
Original file line number Diff line number Diff line change
@@ -97,8 +97,8 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string
}
return false
}) {
_, _, err = desktopClient.Pull(model, func(s string) {
_ = sendInfo(s)
_, _, err = desktopClient.Pull(model, func(msg *desktop.ProgressMessage) {
_ = sendInfo(msg.Message)
})
if err != nil {
_ = sendErrorf("Failed to pull model: %v", err)
253 changes: 253 additions & 0 deletions commands/progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
package commands

import (
"fmt"
"os"
"sort"
"strings"
"sync"

"github.com/docker/model-cli/desktop"
"golang.org/x/term"
)

// LayerState represents the state of a layer download
type LayerState struct {
ID string
Status string
Size uint64
Current uint64
Complete bool
}

// ProgressTracker manages multiple layer progress displays
type ProgressTracker struct {
layers map[string]*LayerState
mutex sync.RWMutex
lastLines int
isActive bool
}

// NewProgressTracker creates a new progress tracker
func NewProgressTracker() *ProgressTracker {
return &ProgressTracker{
layers: make(map[string]*LayerState),
isActive: true,
}
}

// UpdateLayer updates the progress for a specific layer
func (pt *ProgressTracker) UpdateLayer(layerID string, size, current uint64, message *desktop.ProgressMessage) {
pt.mutex.Lock()
defer pt.mutex.Unlock()

if !pt.isActive {
return
}

// Determine status from message
status := "Downloading"
complete := false
if message.Type == "success" {
status = "Pull complete"
complete = true
current = size
} else if message.Type == "error" {
status = "Pull failed"
complete = true
current = size
}

// Shorten layer ID to first 12 characters like Docker
shortID := layerID
if len(layerID) > 12 {
if strings.HasPrefix(layerID, "sha256:") {
shortID = layerID[7:19] // Skip "sha256:" and take next 12 chars
} else {
shortID = layerID[:12]
}
}

pt.layers[layerID] = &LayerState{
ID: shortID,
Status: status,
Size: size,
Current: current,
Complete: complete,
}

pt.render()
}

// Stop stops the progress tracker and shows final completion state
func (pt *ProgressTracker) Stop() {
pt.mutex.Lock()
defer pt.mutex.Unlock()
pt.isActive = false

// If we have layers, show the final state
if len(pt.layers) > 0 {
pt.showFinalState()
}
}

// HasLayers returns true if the tracker has any layers
func (pt *ProgressTracker) HasLayers() bool {
pt.mutex.RLock()
defer pt.mutex.RUnlock()
return len(pt.layers) > 0
}

// showFinalState displays the final completion status for all layers
func (pt *ProgressTracker) showFinalState() {
if len(pt.layers) == 0 {
return
}

// Clear current progress display
pt.clearLines()

// Sort layers by ID for consistent display order
var layerIDs []string
for id := range pt.layers {
layerIDs = append(layerIDs, id)
}
sort.Strings(layerIDs)

// Show final status for each layer
for _, id := range layerIDs {
layer := pt.layers[id]
// Force all layers to show as "Pull complete" in final state
fmt.Printf("%s: Pull complete\n", layer.ID)
}
}

// clearLines clears the previously printed progress lines
func (pt *ProgressTracker) clearLines() {
if pt.lastLines > 0 {
// Move cursor up and clear lines
for i := 0; i < pt.lastLines; i++ {
fmt.Print("\033[A\033[K")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this will be portable to cmd.exe-based shells.

}
pt.lastLines = 0
}
}

// render displays the current progress for all layers
func (pt *ProgressTracker) render() {
if !pt.isActive {
return
}

// Clear previous output
pt.clearLines()

// Sort layers by ID for consistent display order
var layerIDs []string
for id := range pt.layers {
layerIDs = append(layerIDs, id)
}
sort.Strings(layerIDs)

lines := 0
for _, id := range layerIDs {
layer := pt.layers[id]
line := pt.formatLayerProgress(layer)
fmt.Println(line)
lines++
}

pt.lastLines = lines
}

// getTerminalWidth returns the terminal width, or 80 as default
func getTerminalWidth() int {
width, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
// Default to 80 columns if we can't detect terminal size
return 80
}
return width
}
Comment on lines +163 to +171
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also double check this function with cmd.exe. Even if it defaults to 80, sometimes cmd.exe will loop output at 1 character before the width. But I'm also not sure if that's true if the terminal is put into raw mode.


// formatLayerProgress formats a single layer's progress line
func (pt *ProgressTracker) formatLayerProgress(layer *LayerState) string {
if layer.Complete {
return fmt.Sprintf("%s: %s", layer.ID, layer.Status)
}

// Format sizes in MB or GB based on size
var currentStr, sizeStr string
const gbThreshold = 1024 * 1024 * 1024 // 1GB in bytes

if layer.Size >= gbThreshold {
currentGB := float64(layer.Current) / 1024 / 1024 / 1024
sizeGB := float64(layer.Size) / 1024 / 1024 / 1024
currentStr = fmt.Sprintf("%.2fGB", currentGB)
sizeStr = fmt.Sprintf("%.2fGB", sizeGB)
} else {
currentMB := float64(layer.Current) / 1024 / 1024
sizeMB := float64(layer.Size) / 1024 / 1024
currentStr = fmt.Sprintf("%.2fMB", currentMB)
sizeStr = fmt.Sprintf("%.2fMB", sizeMB)
}

// Check terminal width to decide format
termWidth := getTerminalWidth()
// Minimum width needed for progress bar format:
// "layerID: Status [===> ] current/size"
// Estimate: 12 + 2 + 12 + 3 + 50 + 3 + 20 = ~102 characters
minWidthForProgressBar := 100

if termWidth < minWidthForProgressBar {
// Use simple format when terminal is too narrow
return fmt.Sprintf("%s: %s %s/%s", layer.ID, layer.Status, currentStr, sizeStr)
}

// Calculate progress percentage
var percent float64
if layer.Size > 0 {
percent = float64(layer.Current) / float64(layer.Size) * 100
}

// Create progress bar (50 characters wide)
barWidth := 50
filled := int(percent / 100 * float64(barWidth))
if filled > barWidth {
filled = barWidth
}

bar := strings.Repeat("=", filled)
if filled < barWidth && filled > 0 {
bar += ">"
}
bar += strings.Repeat(" ", barWidth-len(bar))

return fmt.Sprintf("%s: %s [%s] %s/%s",
layer.ID,
layer.Status,
bar,
currentStr,
sizeStr,
)
}

// MultiLayerTUIProgress creates a progress function that handles multiple layers
func MultiLayerTUIProgress() (func(*desktop.ProgressMessage), *ProgressTracker) {
tracker := NewProgressTracker()

progressFunc := func(msg *desktop.ProgressMessage) {
if msg.Type == "progress" {
if msg.Layer.ID != "" && msg.Layer.Size > 0 {
// Use layer-specific information when available
tracker.UpdateLayer(msg.Layer.ID, msg.Layer.Size, msg.Layer.Current, msg)
} else {
// Fallback: use simple progress display for backward compatibility
// Clear the line and show the progress message
fmt.Print("\r\033[K", msg.Message)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question re: cmd.exe.

}
}
}

return progressFunc, tracker
}
128 changes: 128 additions & 0 deletions commands/progress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package commands

import (
"strings"
"testing"

"github.com/docker/model-cli/desktop"
)

func TestFormatLayerProgress(t *testing.T) {
tracker := NewProgressTracker()

tests := []struct {
name string
layer *LayerState
expectBar bool
description string
}{
{
name: "complete layer",
layer: &LayerState{
ID: "1a12b4ea7c0c",
Status: "Pull complete",
Size: 100 * 1024 * 1024, // 100MB
Current: 100 * 1024 * 1024, // 100MB
Complete: true,
},
expectBar: false,
description: "completed layers should not show progress bars",
},
{
name: "downloading layer",
layer: &LayerState{
ID: "b58ee5cb7152",
Status: "Downloading",
Size: 200 * 1024 * 1024, // 200MB
Current: 50 * 1024 * 1024, // 50MB
Complete: false,
},
expectBar: true, // This depends on terminal width
description: "downloading layers format depends on terminal width",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tracker.formatLayerProgress(tt.layer)

// Check that completed layers don't have progress bars
if tt.layer.Complete {
if strings.Contains(result, "[") || strings.Contains(result, "]") {
t.Errorf("Complete layer should not have progress bar, got: %s", result)
}
expectedFormat := tt.layer.ID + ": " + tt.layer.Status
if result != expectedFormat {
t.Errorf("Expected %q, got %q", expectedFormat, result)
}
} else {
// For incomplete layers, check that we have the size information
if !strings.Contains(result, "MB") {
t.Errorf("Expected size information in MB, got: %s", result)
}

// Check that the layer ID and status are present
if !strings.Contains(result, tt.layer.ID) {
t.Errorf("Expected layer ID %s in result: %s", tt.layer.ID, result)
}
if !strings.Contains(result, tt.layer.Status) {
t.Errorf("Expected status %s in result: %s", tt.layer.Status, result)
}
}
})
}
}

func TestGetTerminalWidth(t *testing.T) {
// Test that getTerminalWidth returns a reasonable value
width := getTerminalWidth()

// Should return at least the default of 80
if width < 80 {
t.Errorf("Expected terminal width >= 80, got %d", width)
}

// Should return a reasonable maximum (most terminals are < 1000 chars wide)
if width > 1000 {
t.Errorf("Expected terminal width <= 1000, got %d", width)
}
}

func TestProgressTrackerBasicFunctionality(t *testing.T) {
tracker := NewProgressTracker()

// Test that tracker starts with no layers
if tracker.HasLayers() {
t.Error("New tracker should have no layers")
}

// Add a layer
tracker.UpdateLayer("sha256:1a12b4ea7c0c123456789", 100*1024*1024, 50*1024*1024, &desktop.ProgressMessage{
Type: "progress",
Message: "Downloading",
})

// Test that tracker now has layers
if !tracker.HasLayers() {
t.Error("Tracker should have layers after UpdateLayer")
}

// Test that layer ID is shortened correctly
tracker.mutex.RLock()
if len(tracker.layers) != 1 {
t.Errorf("Expected 1 layer, got %d", len(tracker.layers))
}

for _, layer := range tracker.layers {
if layer.ID != "1a12b4ea7c0c" {
t.Errorf("Expected shortened ID '1a12b4ea7c0c', got '%s'", layer.ID)
}
if layer.Status != "Downloading" {
t.Errorf("Expected status 'Downloading', got '%s'", layer.Status)
}
if layer.Complete {
t.Error("Layer should not be complete")
}
}
tracker.mutex.RUnlock()
}
Loading