diff --git a/README.md b/README.md index 1922198..65d8a0b 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,15 @@ go install github.com/happyhackingspace/dit/cmd/dit@latest ```go import "github.com/happyhackingspace/dit" -// Load classifier (finds model.json automatically) +// Load classifier. On first call, if no model.json is found in the current +// directory (walked up to the nearest go.mod) or in ~/.dit/, the pretrained +// model is downloaded from Hugging Face to ~/.dit/model.json (~93MB, one-time) +// and reused on subsequent calls. c, _ := dit.New() +// Or load an explicit file (no network, no search). +c, _ := dit.Load("path/to/model.json") + // Classify page type page, _ := c.ExtractPageType(htmlString) fmt.Println(page.Type) // "login" diff --git a/dit.go b/dit.go index 05ad6a4..a9331d9 100644 --- a/dit.go +++ b/dit.go @@ -15,15 +15,26 @@ package dit import ( + "context" "fmt" + "io" + "log/slog" + "net/http" "os" "path/filepath" + "time" "github.com/happyhackingspace/dit/captcha" "github.com/happyhackingspace/dit/classifier" "github.com/happyhackingspace/dit/internal/htmlutil" ) +// downloadTimeout bounds the total time spent fetching the model. +const downloadTimeout = 1 * time.Minute + +// ModelURL is the canonical download location for the pretrained model. +const ModelURL = "https://huggingface.co/datasets/happyhackingspace/dit/resolve/main/model.json" + // Classifier wraps the form and field type classification models. type Classifier struct { fc *classifier.FormFieldClassifier @@ -59,12 +70,62 @@ type PageResultProba struct { // New loads the classifier from "model.json", searching the current directory // and parent directories up to the module root, then ~/.dit/model.json. +// If no model is found locally, it is downloaded from ModelURL to +// ~/.dit/model.json and loaded from there. The download is a one-time cost +// per machine; subsequent calls reuse the cached file. func New() (*Classifier, error) { - path, err := FindModel("model.json") - if err != nil { + if path, err := FindModel("model.json"); err == nil { + return Load(path) + } + + dest := filepath.Join(ModelDir(), "model.json") + slog.Info("Model not found, downloading", "url", ModelURL, "dest", dest) + if err := Download(dest); err != nil { return nil, fmt.Errorf("dit: %w", err) } - return Load(path) + return Load(dest) +} + +// Download fetches the pretrained model from ModelURL and writes it to dest, +// creating parent directories as needed. A partial file is removed on error. +func Download(dest string) error { + if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { + return fmt.Errorf("create model dir: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelURL, nil) + if err != nil { + return fmt.Errorf("download model: %w", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("download model: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download model: HTTP %d", resp.StatusCode) + } + + f, err := os.Create(dest) + if err != nil { + return fmt.Errorf("create model file: %w", err) + } + + written, err := io.Copy(f, resp.Body) + if err != nil { + _ = f.Close() + _ = os.Remove(dest) + return fmt.Errorf("download model: %w", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("close model file: %w", err) + } + + slog.Info("Model downloaded", "size", fmt.Sprintf("%.1fMB", float64(written)/1024/1024)) + return nil } // ModelDir returns the default model storage directory (~/.dit). diff --git a/internal/cli/data.go b/internal/cli/data.go index 4306653..acf6168 100644 --- a/internal/cli/data.go +++ b/internal/cli/data.go @@ -12,6 +12,7 @@ import ( "path/filepath" "strings" + "github.com/happyhackingspace/dit" "github.com/spf13/cobra" ) @@ -115,8 +116,8 @@ func dataDownload(dataFolder string) error { } slog.Info("Training data extracted", "files", count, "folder", dataFolder) - slog.Info("Downloading model", "url", modelURL) - modelResp, err := http.Get(modelURL) + slog.Info("Downloading model", "url", dit.ModelURL) + modelResp, err := http.Get(dit.ModelURL) if err != nil { return fmt.Errorf("download model: %w", err) } diff --git a/internal/cli/run.go b/internal/cli/run.go index 912cdd0..a2590c4 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -8,7 +8,6 @@ import ( "log/slog" "net/http" "os" - "path/filepath" "strings" "time" @@ -17,8 +16,6 @@ import ( "github.com/spf13/cobra" ) -const modelURL = "https://huggingface.co/datasets/happyhackingspace/dit/resolve/main/model.json" - func (c *CLI) newRunCommand() *cobra.Command { var modelPath string var threshold float64 @@ -93,7 +90,7 @@ func (c *CLI) newRunCommand() *cobra.Command { slog.Debug("HTML fetched", "target", target, "bytes", len(htmlContent)) start := time.Now() - cl, err := loadOrDownloadModel(modelPath) + cl, err := loadModel(modelPath) if err != nil { return err } @@ -159,49 +156,12 @@ func isStdinTerminal() bool { return fi.Mode()&os.ModeCharDevice != 0 } -func loadOrDownloadModel(modelPath string) (*dit.Classifier, error) { +func loadModel(modelPath string) (*dit.Classifier, error) { if modelPath != "" { slog.Debug("Loading custom model", "path", modelPath) return dit.Load(modelPath) } - - cl, err := dit.New() - if err == nil { - return cl, nil - } - - dest := filepath.Join(dit.ModelDir(), "model.json") - slog.Info("Model not found, downloading", "url", modelURL, "dest", dest) - - if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil { - return nil, fmt.Errorf("create model dir: %w", err) - } - - resp, err := http.Get(modelURL) - if err != nil { - return nil, fmt.Errorf("download model: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("download model: HTTP %d", resp.StatusCode) - } - - f, err := os.Create(dest) - if err != nil { - return nil, fmt.Errorf("create model file: %w", err) - } - - written, err := io.Copy(f, resp.Body) - if err != nil { - _ = f.Close() - _ = os.Remove(dest) - return nil, fmt.Errorf("download model: %w", err) - } - _ = f.Close() - - slog.Info("Model downloaded", "size", fmt.Sprintf("%.1fMB", float64(written)/1024/1024)) - return dit.Load(dest) + return dit.New() } type fetchOptions struct { diff --git a/internal/cli/up.go b/internal/cli/up.go index 6494f44..77e7252 100644 --- a/internal/cli/up.go +++ b/internal/cli/up.go @@ -65,7 +65,7 @@ func (c *CLI) selfUpdate() error { modelDest := filepath.Join(dit.ModelDir(), "model.json") if _, err := os.Stat(modelDest); err == nil { slog.Info("Updating cached model") - modelResp, err := http.Get(modelURL) + modelResp, err := http.Get(dit.ModelURL) if err == nil { defer func() { _ = modelResp.Body.Close() }() if modelResp.StatusCode == http.StatusOK {