Skip to content

Commit

Permalink
feat(config): only download when changed
Browse files Browse the repository at this point in the history
resolves #5176
  • Loading branch information
JanDeDobbeleer committed Jun 28, 2024
1 parent 12a732d commit 0449aa8
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/engine/migrate_glyphs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"strings"
"time"

"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)

type codePoints map[uint64]uint64
Expand All @@ -24,7 +24,7 @@ func getGlyphCodePoints() (codePoints, error) {
return codePoints, &ConnectionError{reason: err.Error()}
}

response, err := platform.Client.Do(request)
response, err := net.HTTPClient.Do(request)
if err != nil {
return codePoints, err
}
Expand Down
4 changes: 2 additions & 2 deletions src/font/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"net/http"
"net/url"

"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)

func Download(fontPath string) ([]byte, error) {
Expand Down Expand Up @@ -42,7 +42,7 @@ func getRemoteFile(location string) (data []byte, err error) {
if err != nil {
return nil, err
}
resp, err := platform.Client.Do(req)
resp, err := net.HTTPClient.Do(req)
if err != nil {
return
}
Expand Down
4 changes: 2 additions & 2 deletions src/font/fonts.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
"time"

"github.com/jandedobbeleer/oh-my-posh/src/platform"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)

type release struct {
Expand Down Expand Up @@ -57,7 +57,7 @@ func fetchFontAssets(repo string) ([]*Asset, error) {
}

req.Header.Add("Accept", "application/vnd.github.v3+json")
response, err := platform.Client.Do(req)
response, err := net.HTTPClient.Do(req)
if err != nil || response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to get %s release", repo)
}
Expand Down
19 changes: 10 additions & 9 deletions src/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,27 @@ func Plain() {
plain = true
}

func Info(message string) {
if !enabled {
return
}
log.WriteString(message)
}

func Trace(start time.Time, args ...string) {
if !enabled {
return
}

elapsed := time.Since(start)
fn, _ := funcSpec()
header := fmt.Sprintf("%s(%s) - %s", fn, strings.Join(args, " "), Text(elapsed.String()).Yellow().Plain())

printLn(trace, header)
}

func Debug(message string) {
func Debug(message ...string) {
if !enabled {
return
}

fn, line := funcSpec()
header := fmt.Sprintf("%s:%d", fn, line)
printLn(debug, header, message)

printLn(debug, header, strings.Join(message, " "))
}

func Error(err error) {
Expand All @@ -54,6 +51,7 @@ func Error(err error) {
}
fn, line := funcSpec()
header := fmt.Sprintf("%s:%d", fn, line)

printLn(bug, header, err.Error())
}

Expand All @@ -66,11 +64,14 @@ func funcSpec() (string, int) {
if !OK {
return "", 0
}

fn := runtime.FuncForPC(pc).Name()
fn = fn[strings.LastIndex(fn, ".")+1:]
file = filepath.Base(file)

if strings.HasPrefix(fn, "func") {
return file, line
}

return fmt.Sprintf("%s:%s", file, fn), line
}
118 changes: 118 additions & 0 deletions src/platform/config/download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package config

import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"

"github.com/jandedobbeleer/oh-my-posh/src/log"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
)

func Download(cachePath, url string) (string, error) {
defer log.Trace(time.Now(), cachePath, url)

configPath, shouldUpdate := shouldUpdate(cachePath, url)
if !shouldUpdate {
return configPath, nil
}

log.Debug("downloading config from ", url, " to ", configPath)

ctx, cncl := context.WithTimeout(context.Background(), time.Second*time.Duration(5))
defer cncl()

request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
log.Error(err)
return "", err
}

response, err := net.HTTPClient.Do(request)
if err != nil {
log.Error(err)
return "", err
}

defer response.Body.Close()

if response.StatusCode != http.StatusOK {
err := fmt.Errorf("unexpected status code: %d", response.StatusCode)
log.Error(err)
return "", err
}

if len(configPath) == 0 {
configPath = formatConfigPath(url, response.Header.Get("Etag"), cachePath)
log.Debug("config path not set yet, using ", configPath)
}

out, err := os.Create(configPath)
if err != nil {
log.Error(err)
return "", err
}

defer out.Close()

_, err = io.Copy(out, response.Body)
if err != nil {
log.Error(err)
return "", err
}

log.Debug("config updated to ", configPath)

return configPath, nil
}

func shouldUpdate(cachePath, url string) (string, bool) {
defer log.Trace(time.Now(), cachePath, url)

ctx, cncl := context.WithTimeout(context.Background(), time.Second*time.Duration(5))
defer cncl()

request, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
log.Error(err)
return "", true
}

response, err := net.HTTPClient.Do(request)
if err != nil {
log.Error(err)
return "", true
}

defer response.Body.Close()

etag := response.Header.Get("Etag")
if len(etag) == 0 {
log.Debug("no etag found, updating config")
return "", true
}

configPath := formatConfigPath(url, etag, cachePath)

_, err = os.Stat(configPath)
if err != nil {
log.Debug("configfile ", configPath, " doest not exist, updating config")
return configPath, true
}

log.Debug("config found at", configPath, " skipping update")
return configPath, false
}

func formatConfigPath(url, etag, cachePath string) string {
ext := filepath.Ext(url)
etag = strings.TrimLeft(etag, `W/`)
etag = strings.Trim(etag, `"`)
filename := fmt.Sprintf("config.%s.omp%s", etag, ext)
return filepath.Join(cachePath, filename)
}
5 changes: 3 additions & 2 deletions src/platform/httpclient.go → src/platform/net/http.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package platform
package net

import (
"net"
Expand Down Expand Up @@ -31,5 +31,6 @@ var (
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
}
Client httpClient = &http.Client{Transport: defaultTransport}

HTTPClient httpClient = &http.Client{Transport: defaultTransport}
)
59 changes: 25 additions & 34 deletions src/platform/shell.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package platform

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand All @@ -22,6 +20,8 @@ import (
"github.com/jandedobbeleer/oh-my-posh/src/log"
"github.com/jandedobbeleer/oh-my-posh/src/platform/battery"
"github.com/jandedobbeleer/oh-my-posh/src/platform/cmd"
"github.com/jandedobbeleer/oh-my-posh/src/platform/config"
"github.com/jandedobbeleer/oh-my-posh/src/platform/net"
"github.com/jandedobbeleer/oh-my-posh/src/regex"

disk "github.com/shirou/gopsutil/v3/disk"
Expand Down Expand Up @@ -239,65 +239,46 @@ func (env *Shell) Init() {

func (env *Shell) resolveConfigPath() {
defer env.Trace(time.Now())

if len(env.CmdFlags.Config) == 0 {
env.CmdFlags.Config = env.Getenv("POSH_THEME")
}

if len(env.CmdFlags.Config) == 0 {
env.Debug("No config set, fallback to default config")
return
}

if strings.HasPrefix(env.CmdFlags.Config, "https://") {
if err := env.downloadConfig(env.CmdFlags.Config); err != nil {
// make it use default config when download fails
filePath, err := config.Download(env.CachePath(), env.CmdFlags.Config)
if err != nil {
env.Error(err)
env.CmdFlags.Config = ""
return
}

env.CmdFlags.Config = filePath
return
}

// Cygwin path always needs the full path as we're on Windows but not really.
// Doing filepath actions will convert it to a Windows path and break the init script.
if env.Platform() == WINDOWS && env.Shell() == "bash" {
env.Debug("Cygwin detected, using full path for config")
return
}

configFile := env.CmdFlags.Config
if strings.HasPrefix(configFile, "~") {
configFile = strings.TrimPrefix(configFile, "~")
configFile = filepath.Join(env.Home(), configFile)
}

if !filepath.IsAbs(configFile) {
configFile = filepath.Join(env.Pwd(), configFile)
}
env.CmdFlags.Config = filepath.Clean(configFile)
}

func (env *Shell) downloadConfig(location string) error {
defer env.Trace(time.Now(), location)
ext := filepath.Ext(location)
fileHash := base64.StdEncoding.EncodeToString([]byte(location))
filename := fmt.Sprintf("config.%s.omp%s", fileHash, ext)
configPath := filepath.Join(env.CachePath(), filename)
cfg, err := env.HTTPRequest(location, nil, 5000)
if err != nil {
if _, osErr := os.Stat(configPath); !os.IsNotExist(osErr) {
// use the already cached config
env.CmdFlags.Config = configPath
return nil
}

return err
}
out, err := os.Create(configPath)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, bytes.NewReader(cfg))
if err != nil {
return err
}
env.CmdFlags.Config = configPath
return nil
env.CmdFlags.Config = filepath.Clean(configFile)
}

func (env *Shell) Trace(start time.Time, args ...string) {
Expand Down Expand Up @@ -639,38 +620,48 @@ func (env *Shell) unWrapError(err error) error {

func (env *Shell) HTTPRequest(targetURL string, body io.Reader, timeout int, requestModifiers ...HTTPRequestModifier) ([]byte, error) {
defer env.Trace(time.Now(), targetURL)

ctx, cncl := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
defer cncl()

request, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, body)
if err != nil {
return nil, err
}

for _, modifier := range requestModifiers {
modifier(request)
}

if env.CmdFlags.Debug {
dump, _ := httputil.DumpRequestOut(request, true)
env.Debug(string(dump))
}
response, err := Client.Do(request)

response, err := net.HTTPClient.Do(request)
if err != nil {
env.Error(err)
return nil, env.unWrapError(err)
}

// anything inside the range [200, 299] is considered a success
if response.StatusCode < 200 || response.StatusCode >= 300 {
message := "HTTP status code " + strconv.Itoa(response.StatusCode)
err := errors.New(message)
env.Error(err)
return nil, err
}

defer response.Body.Close()

responseBody, err := io.ReadAll(response.Body)
if err != nil {
env.Error(err)
return nil, err
}

env.Debug(string(responseBody))

return responseBody, nil
}

Expand Down
Loading

0 comments on commit 0449aa8

Please sign in to comment.