-
Notifications
You must be signed in to change notification settings - Fork 1
/
helpers.go
440 lines (367 loc) · 12.3 KB
/
helpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
package helpers
import (
"encoding/json"
"encoding/xml"
"errors"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/golang-jwt/jwt"
"github.com/manifoldco/promptui"
"github.com/neicnordic/crypt4gh/keys"
log "github.com/sirupsen/logrus"
"github.com/vbauerster/mpb/v8"
"golang.org/x/exp/slices"
"gopkg.in/ini.v1"
)
//
// Helper functions used by more than one module
//
// FileExists reports whether os.Stat succeeds for filename, i.e. whether
// something exists at that path. It does not distinguish regular files from
// directories and does not check that the path is readable.
func FileExists(filename string) bool {
	if _, statErr := os.Stat(filename); statErr != nil {
		return false
	}
	return true
}
// FileIsReadable checks that a file exists, and is readable by the program.
func FileIsReadable(filename string) bool {
fileInfo, err := os.Stat(filename)
if err != nil || fileInfo.IsDir() {
return false
}
// Check readability by simply trying to open the file and read one byte
inFile, err := os.Open(filepath.Clean(filename))
if err != nil {
return false
}
defer func() {
if err := inFile.Close(); err != nil {
log.Errorf("Error closing file: %s\n", err)
}
}()
test := make([]byte, 1)
_, err = inFile.Read(test)
return err == nil
}
// FormatSubcommandUsage rearranges a subcommand usage template so that the help
// text is printed first and the "USAGE:" line follows it, indented.
//
// The template must contain a %s or %v verb, which is filled with os.Args[0];
// a template without either verb is returned untouched. After substitution,
// the text is returned as-is (substituted, not trimmed) when, once trimmed, it
// has fewer than two lines or its first line does not start with "USAGE:".
// Otherwise the first line becomes the usage line and the second line is
// skipped (presumably a blank separator — confirm in callers). The remaining
// lines are emitted before the usage line.
func FormatSubcommandUsage(usageString string) string {
	hasVerb := strings.Contains(usageString, "%s") || strings.Contains(usageString, "%v")
	if !hasVerb {
		return usageString
	}

	formatted := fmt.Sprintf(usageString, os.Args[0])

	rows := strings.Split(strings.TrimSpace(formatted), "\n")
	if len(rows) < 2 || !strings.HasPrefix(rows[0], "USAGE:") {
		// Not in the expected shape; hand back the substituted text.
		return formatted
	}

	helpText := strings.Join(rows[2:], "\n")
	return "\n" + helpText + "\n\n " + rows[0] + "\n\n"
}
// PromptPassword shows an interactive promptui prompt labelled with message and
// echoes every typed character as '*'. It returns exactly what the prompt's
// Run method returns: the entered text and any prompt error.
func PromptPassword(message string) (password string, err error) {
	passwordPrompt := promptui.Prompt{Mask: '*'}
	passwordPrompt.Label = message

	password, err = passwordPrompt.Run()
	return password, err
}
// ParseS3ErrorResponse checks if reader stream is xml encoded and if yes unmarshals
// the xml response and returns it.
func ParseS3ErrorResponse(respBody io.Reader) (string, error) {
respMsg, err := io.ReadAll(respBody)
if err != nil {
return "", fmt.Errorf("failed to read from response body, reason: %v", err)
}
if !strings.Contains(string(respMsg), `xml version`) {
return "", fmt.Errorf("cannot parse response body, reason: not xml")
}
xmlErrorResponse := XMLerrorResponse{}
err = xml.Unmarshal(respMsg, &xmlErrorResponse)
if err != nil {
return "", fmt.Errorf("failed to unmarshal xml response, reason: %v", err)
}
return fmt.Sprintf("%+v", xmlErrorResponse), nil
}
// getPositional separates the positional arguments in args from the flags.
//
// args[0] (the command name) is always kept as the first element of the
// returned rest slice. Each later element is classified as follows:
//   - "" and a lone "-" are positional (the flag package also treats "-" as a
//     non-flag argument);
//   - a known boolean flag (see boolFlags) is kept and takes no value;
//   - a flag written as -name=value is kept; its value is inline;
//   - any other element starting with '-' is kept together with the element
//     that follows it, which is taken to be the flag's value;
//   - everything else is positional.
//
// It returns (positional, rest). Relative order is preserved in both slices.
// args itself is not modified. Flags not listed in boolFlags are assumed to
// take exactly one value.
func getPositional(args []string) ([]string, []string) {
	if len(args) == 0 {
		return nil, args
	}

	boolFlags := []string{"-r", "--r", "--force-overwrite", "-force-overwrite", "--force-unencrypted", "-force-unencrypted"}

	var positional []string
	rest := make([]string, 1, len(args))
	rest[0] = args[0]

	for i := 1; i < len(args); i++ {
		arg := args[i]
		// len > 1 avoids indexing an empty string and keeps "-" positional.
		isFlag := len(arg) > 1 && arg[0] == '-'

		switch {
		case !isFlag:
			positional = append(positional, arg)
		case slices.Contains(boolFlags, arg), strings.Contains(arg, "="):
			// Boolean flag, or value supplied inline: nothing more to consume.
			rest = append(rest, arg)
		default:
			// Flag whose value is the next argument (if there is one).
			rest = append(rest, arg)
			if i+1 < len(args) {
				i++
				rest = append(rest, args[i])
			}
		}
	}

	return positional, rest
}
func ParseArgs(args []string, argFlags *flag.FlagSet) error {
var pos []string
pos, args = getPositional(args)
// append positional args back at the end of args
args = append(args, pos...)
err := argFlags.Parse(args[1:])
return err
}
//
// shared structs
//
// EncryptionFileSet pairs an unencrypted file with its encrypted counterpart.
// It is used to track input and output files during encryption and
// decryption.
type EncryptionFileSet struct {
	Unencrypted string // the plaintext (unencrypted) file
	Encrypted   string // the encrypted file
}
// XMLerrorResponse receives the Code, Message and Resource elements of an XML
// error response returned by an S3 server (see ParseS3ErrorResponse).
type XMLerrorResponse struct {
	Code     string `xml:"Code"`
	Message  string `xml:"Message"`
	Resource string `xml:"Resource"`
}
// progress bar definitions

// CustomReader wraps an *os.File so that reads made through ReadAt advance an
// mpb progress bar. See https://github.com/vbauerster/mpb for how bars and
// their decorators are configured.
//
// ReadAt treats the first read at any given offset as a request-signing pass
// and does not count it. Later reads at the same offset add their byte count
// to Reads and move the bar. NOTE(review): this presumably mirrors how the AWS
// SDK uploader signs and then sends each part — confirm with the caller.
//
// Fp and Bar must be set before use. Size is the total reported to the bar.
// SignMap is allocated on first use if nil. Mux guards SignMap and Reads, so a
// CustomReader must not be copied after first use.
type CustomReader struct {
	Fp      *os.File
	Size    int64
	Reads   int64
	Bar     *mpb.Bar
	SignMap map[int64]struct{}
	Mux     sync.Mutex
}

// Read delegates to the underlying file; it does not update the progress bar.
func (r *CustomReader) Read(p []byte) (int, error) {
	return r.Fp.Read(p)
}

// ReadAt reads from the underlying file at offset off and updates progress as
// described on CustomReader. It returns the file's (n, err) unchanged.
func (r *CustomReader) ReadAt(p []byte, off int64) (int, error) {
	n, err := r.Fp.ReadAt(p, off)
	// io.ReaderAt allows n > 0 together with a non-nil error (typically io.EOF
	// on the final chunk). Those bytes were delivered, so they are still
	// accounted for. Only skip accounting when nothing was read.
	if err != nil && n == 0 {
		return n, err
	}

	r.Bar.SetTotal(r.Size, false)

	r.Mux.Lock()
	defer r.Mux.Unlock()

	if r.SignMap == nil {
		r.SignMap = make(map[int64]struct{})
	}
	if _, seen := r.SignMap[off]; seen {
		r.Reads += int64(n)
		r.Bar.SetCurrent(r.Reads)
	} else {
		// First read at this offset: the signing pass, not counted.
		r.SignMap[off] = struct{}{}
	}

	return n, err
}

// Seek delegates to the underlying file.
func (r *CustomReader) Seek(offset int64, whence int) (int64, error) {
	return r.Fp.Seek(offset, whence)
}
// Config holds the values read from an s3cmd-style configuration file (or the
// .sda-cli-session file). Each field is filled from the ini key named in its
// `ini` tag.
type Config struct {
	// Credentials
	AccessKey   string `ini:"access_key"`
	SecretKey   string `ini:"secret_key"`
	AccessToken string `ini:"access_token"`

	// Endpoint
	HostBucket string `ini:"host_bucket"`
	HostBase   string `ini:"host_base"`

	// Transfer and client behaviour
	MultipartChunkSizeMb int64  `ini:"multipart_chunk_size_mb"`
	GuessMimeType        bool   `ini:"guess_mime_type"`
	Encoding             string `ini:"encoding"`
	CheckSslCertificate  bool   `ini:"check_ssl_certificate"`
	CheckSslHostname     bool   `ini:"check_ssl_hostname"`
	UseHTTPS             bool   `ini:"use_https"`
	SocketTimeout        int    `ini:"socket_timeout"`
	HumanReadableSizes   bool   `ini:"human_readable_sizes"`

	// Crypt4GH public key stored alongside the session
	PublicKey string `ini:"public_key"`
}
// LoadConfigFile loads ini configuration file to the Config struct
func LoadConfigFile(path string) (*Config, error) {
config := &Config{}
cfg, err := ini.Load(path)
if err != nil {
return config, err
}
// ini sees a DEFAULT section by default
var iniSection string
if len(cfg.SectionStrings()) > 1 {
iniSection = cfg.SectionStrings()[1]
} else {
iniSection = cfg.SectionStrings()[0]
}
if err := cfg.Section(iniSection).MapTo(config); err != nil {
return nil, err
}
if config.AccessKey == "" || config.AccessToken == "" {
return nil, errors.New("failed to find credentials in configuration file")
}
if config.HostBase == "" {
return nil, errors.New("failed to find endpoint in configuration file")
}
if config.UseHTTPS {
config.HostBase = "https://" + config.HostBase
}
if config.Encoding == "" {
config.Encoding = "UTF-8"
}
// Where 15 is the default chunk size of the library
if config.MultipartChunkSizeMb <= 15 {
config.MultipartChunkSizeMb = 15
}
return config, nil
}
// GetAuth calls LoadConfig if we have a config file, otherwise try to load .sda-cli-session
func GetAuth(path string) (*Config, error) {
if path != "" {
return LoadConfigFile(path)
}
if FileExists(".sda-cli-session") {
return LoadConfigFile(".sda-cli-session")
}
return nil, errors.New("failed to read the configuration file")
}
// reads the .sda-cli-session file, creates the public key file and returns the name of the file
func GetPublicKeyFromSession() (string, error) {
// Check if the ".sda-cli-session" file exists
if !FileExists(".sda-cli-session") {
return "", errors.New("configuration file (.sda-cli-session) not found")
}
_, err := os.Open(".sda-cli-session")
if err != nil {
return "", err
}
// Load the configuration file
config, err := LoadConfigFile(".sda-cli-session")
if err != nil {
return "", fmt.Errorf("failed to load configuration file: %w", err)
}
// Check if the PublicKey field is present in the config
if config.PublicKey == "" {
return "", errors.New("public key not found in the configuration")
}
pubFile, err := CreatePubFile(config.PublicKey, "key-from-oidc.pub.pem")
if err != nil {
return "", err
}
return pubFile, nil
}
// CreatePubFile writes publicKey as a Crypt4GH X25519 public key to filename
// and returns filename on success (or "" and an error).
//
// The bytes of publicKey are copied into a 32-byte key buffer: longer input is
// truncated to 32 bytes and shorter input is zero-padded. NOTE(review): this
// treats publicKey as raw key bytes rather than base64/PEM — confirm that is
// the format stored in the session file.
//
// The file is created with mode 0600 if missing and truncated if it already
// exists, so no stale bytes from a previous, longer file remain.
func CreatePubFile(publicKey string, filename string) (name string, err error) {
	var publicKeyData [32]byte
	copy(publicKeyData[:], publicKey)

	pubFile, err := os.OpenFile(filepath.Clean(filename), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return "", fmt.Errorf("failed to open or create the public key file: %w", err)
	}
	defer func() {
		cerr := pubFile.Close()
		if cerr == nil {
			return
		}
		// A failed Close on a freshly written file can mean the data did not
		// reach disk, so report it unless an earlier error is being returned.
		if err == nil {
			name, err = "", fmt.Errorf("failed to close the public key file: %w", cerr)
			return
		}
		log.Errorf("Error closing file: %s\n", cerr)
	}()

	if err := keys.WriteCrypt4GHX25519PublicKey(pubFile, publicKeyData); err != nil {
		return "", fmt.Errorf("failed to write the public key data: %w", err)
	}

	return filename, nil
}
// CheckTokenExpiration reads the "exp" claim of accessToken and reports on it.
// It parses the token without verifying its signature.
//
// It returns an error when the token cannot be parsed, when the exp claim is
// missing, null or in an unsupported format, or when the token has already
// expired. Otherwise it prints the remaining lifetime to stderr and returns
// nil; when less than 24 hours remain it also suggests renewing the token.
//
// The exp claim is accepted as a JSON number (float64 or json.Number) or as a
// decimal string, so nonstandard tokens (e.g. test tokens) still work.
func CheckTokenExpiration(accessToken string) error {
	// Signatures are not needed just to read the expiry.
	token, _, err := new(jwt.Parser).ParseUnverified(accessToken, jwt.MapClaims{})
	if err != nil {
		return fmt.Errorf("could not parse token, reason: %w", err)
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return fmt.Errorf("broken token (claims are empty): %v", claims)
	}

	if claims["exp"] == nil {
		return errors.New("could not parse token, reason: no expiration date")
	}

	var expiration time.Time
	switch exp := claims["exp"].(type) {
	case float64:
		expiration = time.Unix(int64(exp), 0)
	case json.Number:
		secs, err := exp.Int64()
		if err != nil {
			return fmt.Errorf("could not parse token, reason: %w", err)
		}
		expiration = time.Unix(secs, 0)
	case string:
		secs, err := strconv.ParseInt(exp, 10, 64)
		if err != nil {
			return fmt.Errorf("could not parse token, reason: %w", err)
		}
		expiration = time.Unix(secs, 0)
	default:
		return errors.New("could not parse token, reason: unknown expiration date format")
	}

	untilExp := time.Until(expiration)
	if untilExp < 0 {
		return errors.New("the provided access token has expired, please renew it")
	}

	fmt.Fprintln(os.Stderr, "The provided access token expires in", untilExp.Truncate(time.Second))
	if untilExp < 24*time.Hour {
		fmt.Fprintln(os.Stderr, "Consider renewing the token.")
	}

	return nil
}
// ListFiles lists the objects in the S3 backend described by config whose
// keys start with "<AccessKey>/<prefix>". The bucket is given as
// "<AccessKey>/". NOTE(review): this presumably reflects how the backend maps
// each user to their own area — confirm before changing.
//
// ListObjectsV2 is called once, so only the first page of results is returned
// (on AWS S3 at most 1000 keys by default). Callers can check
// result.IsTruncated if more objects may exist.
//
// Session-creation and listing errors are returned wrapped with %w; this
// function does not panic.
func ListFiles(config Config, prefix string) (result *s3.ListObjectsV2Output, err error) {
	sess, err := session.NewSession(&aws.Config{
		// The region for the backend is always the specified one
		// and not present in the configuration from auth - hardcoded
		Region:           aws.String("us-west-2"),
		Credentials:      credentials.NewStaticCredentials(config.AccessKey, config.SecretKey, config.AccessToken),
		Endpoint:         aws.String(config.HostBase),
		DisableSSL:       aws.Bool(!config.UseHTTPS),
		S3ForcePathStyle: aws.Bool(true),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create S3 session, reason: %w", err)
	}

	result, err = s3.New(sess).ListObjectsV2(&s3.ListObjectsV2Input{
		Bucket: aws.String(config.AccessKey + "/"),
		Prefix: aws.String(config.AccessKey + "/" + prefix),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to list objects, reason: %w", err)
	}

	return result, nil
}
// invalidFilenameCharsRe matches the characters that CheckValidChars rejects:
// backslash, colon, asterisk, question mark, double quote, angle brackets,
// pipe, and the ASCII control characters 0x00-0x1F and 0x7F. It is compiled
// once at package initialisation instead of on every call.
var invalidFilenameCharsRe = regexp.MustCompile(`[\\:\*\?"<>\|\x00-\x1F\x7F]`)

// CheckValidChars returns nil if filename contains none of the disallowed
// characters. Otherwise it returns an error listing every disallowed character
// found, in order of appearance and including duplicates.
func CheckValidChars(filename string) error {
	disallowed := invalidFilenameCharsRe.FindAllString(filename, -1)
	if len(disallowed) == 0 {
		return nil
	}

	return fmt.Errorf("filepath %v contains disallowed characters: %+v", filename, strings.Join(disallowed, ", "))
}