diff --git a/internal/leetcode/base.go b/internal/leetcode/base.go index 05886ed10..1f5e6d9bb 100644 --- a/internal/leetcode/base.go +++ b/internal/leetcode/base.go @@ -3,6 +3,8 @@ package leetcode import ( "bytes" "encoding/json" + "errors" + "fmt" "io/ioutil" "net/http" "net/http/cookiejar" @@ -12,18 +14,24 @@ import ( "path" ) -func checkErr(err error) { - if err != nil { - panic(err) +var err error + +func init() { + http.DefaultClient.Jar, err = cookiejar.New(nil) + checkErr(err) + http.DefaultClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + req.Header.Set("Referer", req.URL.String()) + fmt.Println(req.URL.String()) + if len(via) >= 3 { + return errors.New("stopped after 3 redirects") + } + return nil } } -func client() *http.Client { - jar, err := cookiejar.New(nil) - // jar.SetCookies() - checkErr(err) - return &http.Client{ - Jar: jar, +func checkErr(err error) { + if err != nil { + panic(err) } } @@ -36,29 +44,45 @@ func getCsrfToken(cookies []*http.Cookie) string { return "" } -func getPath(f string) string { - dir := os.TempDir() +func getCachePath(f string) string { + dir, err := os.UserCacheDir() + checkErr(err) u, err := user.Current() if err == nil && u.HomeDir != "" { dir = path.Join(u.HomeDir, ".leetcode") } - err = os.MkdirAll(dir, 0755) - checkErr(err) return path.Join(dir, f) } -func saveCookies(cookies []*http.Cookie) { - data, err := json.Marshal(cookies) +func getFilePath(filename string) string { + if dir := path.Dir(filename); dir != "" { + if err := os.MkdirAll(dir, 0755); err != nil { + checkErr(err) + } + } + return filename +} + +func filePutContents(filename string, data []byte) { + filename = getFilePath(filename) + err = ioutil.WriteFile(filename, data, 0644) + checkErr(err) +} + +func jsonEncode(v interface{}) []byte { + data, err := json.Marshal(v) checkErr(err) dst := bytes.Buffer{} err = json.Indent(&dst, data, "", "\t") checkErr(err) - err = ioutil.WriteFile(getPath(cookiesFile), dst.Bytes(), 0755) - checkErr(err) + return dst.Bytes() +} +func saveCookies(cookies []*http.Cookie) { + filePutContents(getCachePath(cookiesFile), jsonEncode(cookies)) } func getCookies() (cookies []*http.Cookie) { - b, err := ioutil.ReadFile(getPath(cookiesFile)) + b, err := ioutil.ReadFile(getCachePath(cookiesFile)) checkErr(err) err = json.Unmarshal(b, &cookies) checkErr(err) @@ -66,7 +90,13 @@ func getCookies() (cookies []*http.Cookie) { } func saveCredential(data url.Values) { - u := url.UserPassword(data.Get("login"), data.Get("password")) - err := ioutil.WriteFile(getPath(credentialsFile), []byte(u.String()), 0755) + filePutContents(getCachePath(credentialsFile), jsonEncode(data)) +} + +func getCredential() (data url.Values) { + b, err := ioutil.ReadFile(getCachePath(credentialsFile)) + checkErr(err) + err = json.Unmarshal(b, &data) checkErr(err) + return } diff --git a/internal/leetcode/config.go b/internal/leetcode/config.go index 0e20f0a91..69703c68d 100644 --- a/internal/leetcode/config.go +++ b/internal/leetcode/config.go @@ -11,6 +11,6 @@ const ( const cookiesFile = "cookies.json" -const credentialsFile = "credentials" +const credentialsFile = "credentials.json" const problemsAllFile = "problems_all.json" diff --git a/internal/leetcode/login.go b/internal/leetcode/login.go index f1286407e..451f40da9 100644 --- a/internal/leetcode/login.go +++ b/internal/leetcode/login.go @@ -1,36 +1,41 @@ package leetcode import ( + "fmt" "net/http" "net/url" - "strings" ) func AccountsLogin(username, password string) (*http.Response, error) { resp, err := http.Head(AccountsLoginUrl) checkErr(err) defer resp.Body.Close() - cookies := resp.Cookies() - saveCookies(cookies) - csrftoken := getCsrfToken(cookies) + saveCookies(resp.Cookies()) + csrftoken := getCsrfToken(resp.Cookies()) data := url.Values{ "login": {username}, "password": {password}, "csrfmiddlewaretoken": {csrftoken}, } - req, err := http.NewRequest("POST", AccountsLoginUrl, strings.NewReader(data.Encode())) - checkErr(err) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Referer", AccountsLoginUrl) - for _, cookie := range cookies { - req.AddCookie(cookie) - } - resp, err = http.DefaultClient.Do(req) + http.PostForm(AccountsLoginUrl, data) checkErr(err) defer resp.Body.Close() saveCookies(resp.Cookies()) if resp.StatusCode == 200 { saveCredential(data) + } else { + fmt.Println("login error: ", resp.Status) } return resp, err } + +func AutoLogin() (*http.Response, error) { + data := getCredential() + if data.Get("login") == "" { + fmt.Println("can't get username") + } + if data.Get("password") == "" { + fmt.Println("can't get password") + } + return AccountsLogin(data.Get("login"), data.Get("password")) +} diff --git a/internal/leetcode/problem.go b/internal/leetcode/problem.go index d83677d3e..4e87727df 100644 --- a/internal/leetcode/problem.go +++ b/internal/leetcode/problem.go @@ -70,8 +70,7 @@ func ProblemsAll() (pa ProblemsAllType) { dst := bytes.Buffer{} err = json.Indent(&dst, body, "", "\t") checkErr(err) - err = ioutil.WriteFile(getPath(problemsAllFile), dst.Bytes(), 0755) - checkErr(err) + filePutContents(getCachePath(problemsAllFile), dst.Bytes()) err = json.Unmarshal(body, &pa) checkErr(err) return