diff --git a/src/handlers.go b/src/handlers.go index 42b10fc5..382578ba 100644 --- a/src/handlers.go +++ b/src/handlers.go @@ -422,12 +422,20 @@ func RemoveSave(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json;charset=UTF-8") vars := mux.Vars(r) - saveName := vars["save"] + name := vars["save"] + + save, err := findSave(name) + if err != nil { + resp.Data = fmt.Sprintf("Error removing save: %s", err) + if err := json.NewEncoder(w).Encode(resp); err != nil { + log.Printf("Error removing save %s", err) + } + } - err = rmSave(saveName) + err = save.remove() if err == nil { // save was removed - resp.Data = fmt.Sprintf("Removed save: %s", saveName) + resp.Data = fmt.Sprintf("Removed save: %s", save.Name) resp.Success = true if err := json.NewEncoder(w).Encode(resp); err != nil { log.Printf("Error removing save %s", err) diff --git a/src/saves.go b/src/saves.go index 9c0cded3..84e27f63 100644 --- a/src/saves.go +++ b/src/saves.go @@ -3,10 +3,8 @@ package main import ( "errors" "fmt" - "io/ioutil" - "log" "os" - "strings" + "path/filepath" "time" ) @@ -17,56 +15,41 @@ type Save struct { } func (s Save) String() string { - return fmt.Sprintf("%s", s.Name) + return s.Name } // Lists save files in factorio/saves -func listSaves(saveDir string) ([]Save, error) { - result := []Save{} - - files, err := ioutil.ReadDir(saveDir) - if err != nil { - log.Printf("Error listing save directory: %s", err) - return result, err - } - - for _, f := range files { - save := Save{f.Name(), f.ModTime(), f.Size()} - result = append(result, save) - } - - return result, nil +func listSaves(saveDir string) (saves []Save, err error) { + err = filepath.Walk(saveDir, func(path string, info os.FileInfo, err error) error { + saves = append(saves, Save{ + info.Name(), + info.ModTime(), + info.Size(), + }) + return nil + }) + return } -func rmSave(saveName string) error { - removed := false - if saveName == "" { - return errors.New("No save name provided") - } - +func findSave(name string) (*Save, error) { saves, err := listSaves(config.FactorioSavesDir) if err != nil { - log.Printf("Error in remove save: %s", err) - return err + return nil, fmt.Errorf("error listing saves: %v", err) } for _, save := range saves { - log.Printf("Checking if %s in %s", save, saveName) - if strings.Contains(save.Name, saveName) { - err := os.Remove(config.FactorioSavesDir + "/" + save.Name) - if err != nil { - log.Printf("Error removing save %s: %s", saveName, err) - return err - } - log.Printf("Deleted save: %s", save) - removed = true + if save.Name == name { + return &save, nil } } - if !removed { - log.Printf("Did not remove save: %s", saveName) - return errors.New(fmt.Sprintf("Did not remove save: %s", saveName)) + return nil, errors.New("save not found") +} + +func (s *Save) remove() error { + if s.Name == "" { + return errors.New("save name cannot be blank") } - return nil + return os.Remove(filepath.Join(config.FactorioSavesDir, s.Name)) }