Skip to content

Commit

Permalink
Add config
Browse files Browse the repository at this point in the history
  • Loading branch information
178inaba committed Aug 13, 2017
1 parent 01e3d22 commit 4969344
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 56 deletions.
115 changes: 75 additions & 40 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,15 @@ import (
"github.com/google/go-github/github"
tty "github.com/mattn/go-tty"
homedir "github.com/mitchellh/go-homedir"
toml "github.com/pelletier/go-toml"
uuid "github.com/satori/go.uuid"
)

const defaultTokenFilePath = "gistup/token"
const (
configDirName = "gistup"
configFileName = "config.toml"
tokenFileName = "token"
)

var (
isAnonymous = flag.Bool("a", false, "Create anonymous gist")
Expand All @@ -42,6 +47,12 @@ var (
}
)

type config struct {
APIRawurl string `toml:"url"`
APIURL *url.URL `toml:"-"`
IsInsecure bool `toml:"insecure"`
}

func main() {
log.SetFlags(0)
log.SetPrefix(fmt.Sprintf("%s: ", os.Args[0]))
Expand Down Expand Up @@ -74,14 +85,14 @@ func run() int {
cancel()
}()

tokenFilePath, err := getTokenFilePath()
confDirPath, err := getConfigDir()
if err != nil {
log.Print(err)
return 1
}

reAuth:
c, err := getClientWithToken(ctx, tokenFilePath)
c, err := getClientWithToken(ctx, confDirPath)
if err != nil {
log.Print(err)
return 1
Expand All @@ -93,7 +104,7 @@ reAuth:
if errResp, ok := err.(*github.ErrorResponse); ok &&
errResp.Response.StatusCode == http.StatusUnauthorized {
// Remove bad token file.
if err := os.Remove(tokenFilePath); err != nil {
if err := os.RemoveAll(confDirPath); err != nil {
log.Print(err)
return 1
}
Expand All @@ -113,69 +124,97 @@ reAuth:
return 0
}

func getTokenFilePath() (string, error) {
func getConfigDir() (string, error) {
if runtime.GOOS == "windows" {
return filepath.Join(os.Getenv("APPDATA"), defaultTokenFilePath), nil
return filepath.Join(os.Getenv("APPDATA"), configDirName), nil
}
home, err := homedir.Dir()
if err != nil {
return "", err
}
return filepath.Join(home, ".config", defaultTokenFilePath), nil
return filepath.Join(home, ".config", configDirName), nil
}

func getClientWithToken(ctx context.Context, tokenFilePath string) (*github.Client, error) {
var apiURL *url.URL
if *apiRawurl != "" {
var err error
apiURL, err = url.Parse(*apiRawurl)
if err != nil {
return nil, err
}
func getClientWithToken(ctx context.Context, confDirPath string) (*github.Client, error) {
// Read config.
conf, err := getConfig(confDirPath)
if err != nil {
return nil, err
}

if *isAnonymous {
c := github.NewClient(nil)
if apiURL != nil {
c.BaseURL = apiURL
if conf.APIURL != nil {
c.BaseURL = conf.APIURL
}
return c, nil
}

token, err := readFile(tokenFilePath)
tokenFilePath := filepath.Join(confDirPath, tokenFileName)
bs, err := readFile(tokenFilePath)
token := string(bs)
if err != nil {
token, err = getToken(ctx, apiURL, tokenFilePath)
token, err = getToken(ctx, conf, tokenFilePath)
if err != nil {
return nil, err
}
}

if *isInsecure {
if conf.IsInsecure {
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
ctx = context.WithValue(ctx, oauth2.HTTPClient, &http.Client{Transport: tr})
}
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
c := github.NewClient(oauth2.NewClient(ctx, ts))
if apiURL != nil {
c.BaseURL = apiURL
if conf.APIURL != nil {
c.BaseURL = conf.APIURL
}
return c, nil
}

func getToken(ctx context.Context, apiURL *url.URL, tokenFilePath string) (string, error) {
func getConfig(confDirPath string) (*config, error) {
confFilePath := filepath.Join(confDirPath, configFileName)
var conf config
bs, err := readFile(confFilePath)
if err == nil {
if err := toml.Unmarshal(bs, &conf); err != nil {
return nil, err
}
}

conf.IsInsecure = *isInsecure
if *apiRawurl != "" {
conf.APIRawurl = *apiRawurl
conf.APIURL, err = url.Parse(conf.APIRawurl)
if err != nil {
return nil, err
}

bs, err := toml.Marshal(conf)
if err != nil {
return nil, err
}
if err := save(string(bs), confFilePath); err != nil {
return nil, err
}
}
return &conf, nil
}

func getToken(ctx context.Context, conf *config, tokenFilePath string) (string, error) {
username, password, err := prompt(ctx)
if err != nil {
return "", err
}

t := &github.BasicAuthTransport{Username: username, Password: password}
if *isInsecure {
if conf.IsInsecure {
t.Transport =
&http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
}
c := github.NewClient(t.Client())
if apiURL != nil {
c.BaseURL = apiURL
if conf.APIURL != nil {
c.BaseURL = conf.APIURL
}
a, _, err := c.Authorizations.Create(ctx, &github.AuthorizationRequest{
Scopes: []github.Scope{"gist"},
Expand All @@ -187,7 +226,7 @@ func getToken(ctx context.Context, apiURL *url.URL, tokenFilePath string) (strin
}

token := a.GetToken()
if err := saveToken(token, tokenFilePath); err != nil {
if err := save(token, tokenFilePath); err != nil {
return "", err
}
return token, nil
Expand Down Expand Up @@ -237,15 +276,13 @@ func readString(ctx context.Context, hint string, readFunc func(t *tty.TTY) (str
return s, nil
}

func saveToken(token, configFilePath string) error {
if err := os.MkdirAll(filepath.Dir(configFilePath), 0700); err != nil {
func save(s, saveFilePath string) error {
if err := os.MkdirAll(filepath.Dir(saveFilePath), 0700); err != nil {
return err
}

if err := ioutil.WriteFile(configFilePath, []byte(token), 0600); err != nil {
if err := ioutil.WriteFile(saveFilePath, []byte(s), 0600); err != nil {
return err
}

return nil
}

Expand All @@ -264,13 +301,13 @@ func createGist(ctx context.Context, fileNames []string, stdinContent string, gi
fp = filepath.Join(wd, fileName)
}

content, err := readFile(fp)
bs, err := readFile(fp)
if err != nil {
return nil, err
}

files[github.GistFilename(filepath.Base(fileName))] =
github.GistFile{Content: github.String(content)}
github.GistFile{Content: github.String(string(bs))}
}
} else {
files[github.GistFilename(*stdinFileName)] =
Expand All @@ -289,19 +326,17 @@ func createGist(ctx context.Context, fileNames []string, stdinContent string, gi
return g, nil
}

func readFile(fp string) (string, error) {
func readFile(fp string) ([]byte, error) {
f, err := os.Open(fp)
if err != nil {
return "", err
return nil, err
}
defer f.Close()

bs, err := ioutil.ReadAll(f)
if err != nil {
return "", err
return nil, err
}

return string(bs), nil
return bs, nil
}

func openURL(rawurl string) error {
Expand Down
33 changes: 17 additions & 16 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ import (
uuid "github.com/satori/go.uuid"
)

func TestGetTokenFilePath(t *testing.T) {
fp, err := getTokenFilePath()
func TestGetConfigDir(t *testing.T) {
dir, err := getConfigDir()
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
if !strings.Contains(fp, defaultTokenFilePath) {
if !strings.Contains(dir, configDirName) {
t.Fatalf("%q should be contained in output of config file path: %v",
defaultTokenFilePath, fp)
configDirName, dir)
}
}

Expand All @@ -44,12 +44,12 @@ func TestGetClientWithToken(t *testing.T) {

*isAnonymous = true
*apiRawurl = ts.URL
if _, err := getClientWithToken(context.Background(), ""); err != nil {
fp := filepath.Join(os.TempDir(), uuid.NewV4().String())
if _, err := getClientWithToken(context.Background(), fp); err != nil {
t.Fatalf("should not be fail: %v", err)
}

*isAnonymous = false
fp := filepath.Join(os.TempDir(), uuid.NewV4().String())
readUsername = func(t *tty.TTY) (string, error) { return "", io.EOF }
readPassword = func(t *tty.TTY) (string, error) { return "", nil }
if _, err := getClientWithToken(context.Background(), fp); err == nil {
Expand Down Expand Up @@ -86,10 +86,11 @@ func TestGetToken(t *testing.T) {
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
conf := &config{APIRawurl: ts.URL, APIURL: apiURL}

readUsername = func(t *tty.TTY) (string, error) { return "", nil }
readPassword = func(t *tty.TTY) (string, error) { return "", nil }
if _, err := getToken(context.Background(), apiURL, ""); err == nil {
if _, err := getToken(context.Background(), conf, ""); err == nil {
t.Fatalf("should be fail: %v", err)
}

Expand All @@ -102,12 +103,12 @@ func TestGetToken(t *testing.T) {
t.Fatalf("should not be fail: %v", err)
}
}()
if _, err := getToken(context.Background(), apiURL, filepath.Join(fp, "foo")); err == nil {
if _, err := getToken(context.Background(), conf, filepath.Join(fp, "foo")); err == nil {
t.Fatalf("should be fail: %v", err)
}

*isInsecure = true
token, err := getToken(context.Background(), apiURL, fp)
conf.IsInsecure = true
token, err := getToken(context.Background(), conf, fp)
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
Expand Down Expand Up @@ -152,7 +153,7 @@ func TestPrompt(t *testing.T) {
func TestSaveToken(t *testing.T) {
token := "foobar"
fp := filepath.Join(os.TempDir(), uuid.NewV4().String())
if err := saveToken(token, fp); err != nil {
if err := save(token, fp); err != nil {
t.Fatalf("should not be fail: %v", err)
}
defer func() {
Expand Down Expand Up @@ -185,7 +186,7 @@ func TestSaveToken(t *testing.T) {
t.Fatalf("want %q but %q", token, string(bs))
}

if err := saveToken("", filepath.Join(fp, "foo")); err == nil {
if err := save("", filepath.Join(fp, "foo")); err == nil {
t.Fatalf("should be fail: %v", err)
}

Expand All @@ -198,7 +199,7 @@ func TestSaveToken(t *testing.T) {
t.Fatalf("should not be fail: %v", err)
}
}()
if err := saveToken("", errFP); err == nil {
if err := save("", errFP); err == nil {
t.Fatalf("should be fail: %v", err)
}
}
Expand Down Expand Up @@ -270,12 +271,12 @@ func TestReadFile(t *testing.T) {
t.Fatalf("should not be fail: %v", err)
}
}()
content, err := readFile(fp)
bs, err := readFile(fp)
if err != nil {
t.Fatalf("should not be fail: %v", err)
}
if content != tc {
t.Fatalf("want %q but %q", tc, content)
if string(bs) != tc {
t.Fatalf("want %q but %q", tc, bs)
}

if _, err := readFile(""); err == nil {
Expand Down

0 comments on commit 4969344

Please sign in to comment.