-
Notifications
You must be signed in to change notification settings - Fork 0
/
command.go
165 lines (138 loc) · 4.67 KB
/
command.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package update
import (
"context"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/google/go-github/v37/github"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
)
// defaultOptions returns the baseline configuration that Option values are
// applied on top of: all three loggers discard their output, an asset is
// considered compatible when its name contains runtime.GOOS, and the GitHub
// token is looked up in the GITHUB_TOKEN environment variable.
func defaultOptions() *optionCtx {
	// Each call builds a separate logger so that redirecting one (SetOutput)
	// does not affect the others.
	silent := func(prefix string, flags int) *log.Logger {
		return log.New(ioutil.Discard, prefix, flags)
	}
	nameMentionsOS := func(asset *github.ReleaseAsset) bool {
		return strings.Contains(asset.GetName(), runtime.GOOS)
	}
	return &optionCtx{
		githubTokenEnvironmentVariableName: "GITHUB_TOKEN",
		assetIsCompatibleFunc:              nameMentionsOS,
		logger:                             silent("", 0),
		debugLogger:                        silent("DEBUG: ", log.Lmsgprefix),
		errorLogger:                        silent("ERROR: ", log.Lmsgprefix),
	}
}
// Command returns an "update" subcommand that downloads the latest GitHub
// release of owner/repo and installs it over the running executable.
//
// options customize the configuration built by defaultOptions (loggers,
// asset selection, and the name of the environment variable that holds the
// GitHub token). The command registers two flags: --debug, which shows debug
// logs, and --force, which re-downloads even when the latest release name
// already equals the root command's Version.
func Command(owner, repo string, options ...Option) *cobra.Command {
	// The options are applied here only so the help text can name the token
	// environment variable; update applies them again to a fresh optionCtx
	// each time the command runs.
	oc := defaultOptions()
	for _, o := range options {
		o.apply(oc)
	}
	cmd := &cobra.Command{
		Use:   "update",
		Short: "Download the latest release from GitHub",
		Long: fmt.Sprintf(`Download the latest release from GitHub and install it in-place.
If the %[1]s environment variable is set, it will be used for any GitHub API requests.
%[1]s is required for private repositories.`, oc.githubTokenEnvironmentVariableName),
		Run: func(cmd *cobra.Command, args []string) {
			// NOTE(review): the error is printed rather than returned (Run, not
			// RunE), so the process exit status does not reflect a failed update.
			if err := update(cmd, owner, repo, options); err != nil {
				// Terminate the line so the shell prompt doesn't follow the message.
				fmt.Fprintf(cmd.OutOrStderr(), "Error: %v\n", err)
			}
		},
	}
	cmd.Flags().Bool("debug", false, "show debug logs")
	cmd.Flags().Bool("force", false, "force a re-download even if already up-to-date")
	return cmd
}
// update fetches the latest GitHub release of owner/repo and, unless its name
// already equals the root command's Version (and --force is not set),
// downloads the first asset accepted by the configured compatibility check
// and writes it over the running executable.
//
// The default loggers are pointed at the command's output streams before the
// options are applied, so an Option can still replace them. If a GitHub token
// is present in the configured environment variable, it authenticates the API
// requests. The returned error describes the first step that failed. When the
// install itself fails, the previous executable is restored (see
// replaceExecutable).
func update(cmd *cobra.Command, owner, repo string, options []Option) error {
	ctx := context.Background()

	oc := defaultOptions()
	oc.logger.SetOutput(cmd.OutOrStdout())
	oc.errorLogger.SetOutput(cmd.OutOrStderr())
	for _, o := range options {
		o.apply(oc)
	}
	// GetBool only fails for an unregistered flag; treat that as "not set".
	if debug, _ := cmd.Flags().GetBool("debug"); debug {
		oc.debugLogger.SetOutput(cmd.OutOrStdout())
	}
	force, _ := cmd.Flags().GetBool("force")
	logger := oc.logger
	debugLogger := oc.debugLogger
	errorLogger := oc.errorLogger
	isCompatible := oc.assetIsCompatibleFunc

	// Without a token tc stays nil and go-github uses an unauthenticated client.
	var tc *http.Client
	if token := os.Getenv(oc.githubTokenEnvironmentVariableName); token != "" {
		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
		tc = oauth2.NewClient(ctx, ts)
	}
	client := github.NewClient(tc)

	release, resp, err := client.Repositories.GetLatestRelease(ctx, owner, repo)
	if err != nil {
		// resp is nil when no HTTP response was received (e.g. a network
		// error), so it must be checked before reading the status code.
		if resp != nil && resp.StatusCode == http.StatusNotFound {
			return fmt.Errorf("getting latest release (if this is a private repo, you need to set the %s environment variable): %w", oc.githubTokenEnvironmentVariableName, err)
		}
		return fmt.Errorf("getting latest release: %w", err)
	}

	currentVersion := cmd.Root().Version
	debugLogger.Printf("Found release %s, current version is %s", release.GetName(), currentVersion)
	if release.GetName() == currentVersion {
		logger.Printf("Already up to date\n")
		if !force {
			return nil
		}
	}
	logger.Printf("Updating to %s\n", release.GetName())

	// Pick the first compatible asset; 0 means none was found.
	var assetID int64
	for _, asset := range release.Assets {
		debugLogger.Printf("Asset name: %s, id: %d, download at: %s", asset.GetName(), asset.GetID(), asset.GetBrowserDownloadURL())
		if isCompatible(asset) {
			debugLogger.Printf("Will download asset %d", asset.GetID())
			assetID = asset.GetID()
			break
		}
	}
	if assetID == 0 {
		return fmt.Errorf("could not find a suitable release to download, use --debug flag for details")
	}

	assetReader, _, err := client.Repositories.DownloadReleaseAsset(ctx, owner, repo, assetID, http.DefaultClient)
	if err != nil {
		return fmt.Errorf("download release asset: %w", err)
	}
	defer assetReader.Close()

	outPath, err := os.Executable()
	if err != nil {
		return fmt.Errorf("get executable path: %w", err)
	}
	debugLogger.Printf("Got asset response, will write to %s", outPath)
	return replaceExecutable(outPath, repo+"-bak-", assetReader, debugLogger, errorLogger)
}

// replaceExecutable overwrites the file at exePath with the contents of src,
// keeping the original file's permission bits.
//
// The existing file is first renamed into a fresh directory created next to
// exePath (backupDirPattern is passed to os.MkdirTemp). If creating or writing
// the new file fails, the backup is renamed back to exePath; a failure to
// restore is reported on errorLogger. The backup directory is removed before
// returning.
func replaceExecutable(exePath, backupDirPattern string, src io.Reader, debugLogger, errorLogger *log.Logger) (err error) {
	info, err := os.Stat(exePath)
	if err != nil {
		return fmt.Errorf("stat executable: %w", err)
	}

	// Keep the backup on the same filesystem as the executable: os.Rename
	// cannot move a file across devices, and the system temp directory is
	// often a separate mount.
	tmpDir, err := os.MkdirTemp(filepath.Dir(exePath), backupDirPattern)
	if err != nil {
		return fmt.Errorf("make backup dir: %w", err)
	}
	defer os.RemoveAll(tmpDir)

	bakFile := filepath.Join(tmpDir, filepath.Base(exePath))
	debugLogger.Printf("Creating backup file: %s", bakFile)
	if rerr := os.Rename(exePath, bakFile); rerr != nil {
		return fmt.Errorf("rename old executable: %w", rerr)
	}
	// Registered after the RemoveAll defer so it runs first: on failure, move
	// the backup back over whatever (possibly partial) file is at exePath.
	defer func() {
		if err == nil {
			return
		}
		if rerr := os.Rename(bakFile, exePath); rerr != nil {
			errorLogger.Printf("failed to restore backup file after installation error: %v", rerr)
		}
	}()

	f, err := os.OpenFile(exePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode().Perm())
	if err != nil {
		return fmt.Errorf("create new executable: %w", err)
	}
	if _, werr := io.Copy(f, src); werr != nil {
		f.Close() // best effort; the write error is the one worth reporting
		return fmt.Errorf("write new executable: %w", werr)
	}
	// A Close error can mean buffered data never reached disk.
	if cerr := f.Close(); cerr != nil {
		return fmt.Errorf("close new executable: %w", cerr)
	}
	return nil
}