Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@ go install github.com/happyhackingspace/dit/cmd/dit@latest
```go
import "github.com/happyhackingspace/dit"

// Load classifier (finds model.json automatically)
// Load classifier. On first call, if no model.json is found in the current
// directory (walked up to the nearest go.mod) or in ~/.dit/, the pretrained
// model is downloaded from Hugging Face to ~/.dit/model.json (~93MB, one-time)
// and reused on subsequent calls.
c, _ := dit.New()

// Or load an explicit file (no network, no search).
c, _ := dit.Load("path/to/model.json")

// Classify page type
page, _ := c.ExtractPageType(htmlString)
fmt.Println(page.Type) // "login"
Expand Down
67 changes: 64 additions & 3 deletions dit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,26 @@
package dit

import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"time"

"github.com/happyhackingspace/dit/captcha"
"github.com/happyhackingspace/dit/classifier"
"github.com/happyhackingspace/dit/internal/htmlutil"
)

// downloadTimeout bounds the total time spent fetching the model.
const downloadTimeout = 1 * time.Minute

// ModelURL is the canonical download location for the pretrained model.
const ModelURL = "https://huggingface.co/datasets/happyhackingspace/dit/resolve/main/model.json"

// Classifier wraps the form and field type classification models.
type Classifier struct {
fc *classifier.FormFieldClassifier
Expand Down Expand Up @@ -59,12 +70,62 @@ type PageResultProba struct {

// New loads the classifier from "model.json", searching the current directory
// and parent directories up to the module root, then ~/.dit/model.json.
// If no model is found locally, it is downloaded from ModelURL to
// ~/.dit/model.json and loaded from there. The download is a one-time cost
// per machine; subsequent calls reuse the cached file.
func New() (*Classifier, error) {
path, err := FindModel("model.json")
if err != nil {
if path, err := FindModel("model.json"); err == nil {
return Load(path)
}

dest := filepath.Join(ModelDir(), "model.json")
slog.Info("Model not found, downloading", "url", ModelURL, "dest", dest)
if err := Download(dest); err != nil {
return nil, fmt.Errorf("dit: %w", err)
}
return Load(path)
return Load(dest)
}

// Download fetches the pretrained model from ModelURL and writes it to dest,
// creating parent directories as needed. A partial file is removed on error.
func Download(dest string) error {
if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
return fmt.Errorf("create model dir: %w", err)
}

ctx, cancel := context.WithTimeout(context.Background(), downloadTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ModelURL, nil)
if err != nil {
return fmt.Errorf("download model: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("download model: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download model: HTTP %d", resp.StatusCode)
}

f, err := os.Create(dest)
if err != nil {
return fmt.Errorf("create model file: %w", err)
}

written, err := io.Copy(f, resp.Body)
if err != nil {
_ = f.Close()
_ = os.Remove(dest)
return fmt.Errorf("download model: %w", err)
}
if err := f.Close(); err != nil {
return fmt.Errorf("close model file: %w", err)
}

slog.Info("Model downloaded", "size", fmt.Sprintf("%.1fMB", float64(written)/1024/1024))
return nil
}

// ModelDir returns the default model storage directory (~/.dit).
Expand Down
5 changes: 3 additions & 2 deletions internal/cli/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"path/filepath"
"strings"

"github.com/happyhackingspace/dit"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -115,8 +116,8 @@ func dataDownload(dataFolder string) error {
}
slog.Info("Training data extracted", "files", count, "folder", dataFolder)

slog.Info("Downloading model", "url", modelURL)
modelResp, err := http.Get(modelURL)
slog.Info("Downloading model", "url", dit.ModelURL)
modelResp, err := http.Get(dit.ModelURL)
if err != nil {
return fmt.Errorf("download model: %w", err)
}
Expand Down
46 changes: 3 additions & 43 deletions internal/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"time"

Expand All @@ -17,8 +16,6 @@ import (
"github.com/spf13/cobra"
)

const modelURL = "https://huggingface.co/datasets/happyhackingspace/dit/resolve/main/model.json"

func (c *CLI) newRunCommand() *cobra.Command {
var modelPath string
var threshold float64
Expand Down Expand Up @@ -93,7 +90,7 @@ func (c *CLI) newRunCommand() *cobra.Command {
slog.Debug("HTML fetched", "target", target, "bytes", len(htmlContent))

start := time.Now()
cl, err := loadOrDownloadModel(modelPath)
cl, err := loadModel(modelPath)
if err != nil {
return err
}
Expand Down Expand Up @@ -159,49 +156,12 @@ func isStdinTerminal() bool {
return fi.Mode()&os.ModeCharDevice != 0
}

func loadOrDownloadModel(modelPath string) (*dit.Classifier, error) {
func loadModel(modelPath string) (*dit.Classifier, error) {
if modelPath != "" {
slog.Debug("Loading custom model", "path", modelPath)
return dit.Load(modelPath)
}

cl, err := dit.New()
if err == nil {
return cl, nil
}

dest := filepath.Join(dit.ModelDir(), "model.json")
slog.Info("Model not found, downloading", "url", modelURL, "dest", dest)

if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
return nil, fmt.Errorf("create model dir: %w", err)
}

resp, err := http.Get(modelURL)
if err != nil {
return nil, fmt.Errorf("download model: %w", err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("download model: HTTP %d", resp.StatusCode)
}

f, err := os.Create(dest)
if err != nil {
return nil, fmt.Errorf("create model file: %w", err)
}

written, err := io.Copy(f, resp.Body)
if err != nil {
_ = f.Close()
_ = os.Remove(dest)
return nil, fmt.Errorf("download model: %w", err)
}
_ = f.Close()

slog.Info("Model downloaded", "size", fmt.Sprintf("%.1fMB", float64(written)/1024/1024))
return dit.Load(dest)
return dit.New()
}

type fetchOptions struct {
Expand Down
2 changes: 1 addition & 1 deletion internal/cli/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (c *CLI) selfUpdate() error {
modelDest := filepath.Join(dit.ModelDir(), "model.json")
if _, err := os.Stat(modelDest); err == nil {
slog.Info("Updating cached model")
modelResp, err := http.Get(modelURL)
modelResp, err := http.Get(dit.ModelURL)
if err == nil {
defer func() { _ = modelResp.Body.Close() }()
if modelResp.StatusCode == http.StatusOK {
Expand Down