diff --git a/internal/homekit/api.go b/internal/homekit/api.go index 39fdaa43..306cc2a1 100644 --- a/internal/homekit/api.go +++ b/internal/homekit/api.go @@ -1,60 +1,63 @@ package homekit import ( - "encoding/json" "fmt" + "net/http" + "net/url" + "sync" + "github.com/AlexxIT/go2rtc/internal/app/store" "github.com/AlexxIT/go2rtc/internal/streams" "github.com/AlexxIT/go2rtc/pkg/hap" "github.com/AlexxIT/go2rtc/pkg/hap/mdns" - "net/http" - "net/url" + "github.com/gorilla/websocket" + "strings" + "time" ) +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +var activeConnections int +var activeConnectionsMutex sync.Mutex + func apiHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": - items := make([]any, 0) - for name, src := range store.GetDict("streams") { - if src := src.(string); strings.HasPrefix(src, "homekit") { - u, err := url.Parse(src) - if err != nil { - continue - } - device := Device{ - Name: name, - Addr: u.Host, - Paired: true, - } - items = append(items, device) - } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Error().Err(err).Caller().Send() + _, err = w.Write([]byte(err.Error())) + return } - - for info := range mdns.GetAll() { - if !strings.HasSuffix(info.Name, mdns.Suffix) { - continue - } - name := info.Name[:len(info.Name)-len(mdns.Suffix)] - device := Device{ - Name: strings.ReplaceAll(name, "\\", ""), - Addr: fmt.Sprintf("%s:%d", info.AddrV4, info.Port), + activeConnectionsMutex.Lock() + activeConnections++ + activeConnectionsMutex.Unlock() + + done := make(chan struct{}) + go hkDiscoverDevices(conn, done) + + for { + _, _, err := conn.ReadMessage() + if err != nil { + log.Debug().Err(err).Caller().Send() + _, err = w.Write([]byte(err.Error())) + break } - for _, field := range info.InfoFields { - switch field[:2] { - case "id": - device.ID = field[3:] - case "md": - device.Model = field[3:] - case "sf": - device.Paired = field[3] == '0' - } - } - items = append(items, device) } - _ = json.NewEncoder(w).Encode(items) + close(done) + activeConnectionsMutex.Lock() + activeConnections-- + activeConnectionsMutex.Unlock() + conn.Close() case "POST": // TODO: post params... @@ -76,6 +79,83 @@ func apiHandler(w http.ResponseWriter, r *http.Request) { } } +func hkDiscoverDevices(conn *websocket.Conn, done chan struct{}) { + queryCounter := 0 + for { + select { + case <-done: + return + default: + log.Trace().Int("active connections: ", activeConnections).Msg("[homekit] ") + activeConnectionsMutex.Lock() + if activeConnections <= 0 { + activeConnectionsMutex.Unlock() + return + } + activeConnectionsMutex.Unlock() + + queryCounter++ + timeout := time.Second + if queryCounter%10 == 0 { + timeout = 5 * time.Second + } + + entries := mdns.GetAll(timeout) + + for name, src := range store.GetDict("streams") { + if src := src.(string); strings.HasPrefix(src, "homekit") { + u, err := url.Parse(src) + if err != nil { + continue + } + device := Device{ + Name: name, + Addr: u.Host, + Paired: true, + } + err = conn.WriteJSON(device) + if err != nil { + log.Error().Err(err).Caller().Send() + + return + } + } + } + + for entry := range entries { + if !strings.HasSuffix(entry.Name, mdns.Suffix) { + continue + } + + name := entry.Name[:len(entry.Name)-len(mdns.Suffix)] + device := Device{ + Name: strings.ReplaceAll(name, "\\", ""), + Addr: fmt.Sprintf("%s:%d", entry.AddrV4, entry.Port), + } + for _, field := range entry.InfoFields { + switch field[:2] { + case "id": + device.ID = field[3:] + case "md": + device.Model = field[3:] + case "sf": + device.Paired = field[3] == '0' + } + } + + err := conn.WriteJSON(device) + if err != nil { + log.Debug().Err(err).Caller().Send() + + return + } + } + + time.Sleep(timeout) + } + } +} + func hkPair(deviceID, pin, name string) (err error) { var conn *hap.Conn diff --git a/pkg/hap/mdns/client.go b/pkg/hap/mdns/client.go index c5befa45..f0ae78a6 100644 --- a/pkg/hap/mdns/client.go +++ b/pkg/hap/mdns/client.go @@ -2,16 +2,22 @@ package mdns import ( "fmt" - "github.com/hashicorp/mdns" "strings" + "time" + + "github.com/hashicorp/mdns" ) const Suffix = "._hap._tcp.local." -func GetAll() chan *mdns.ServiceEntry { +func GetAll(timeout_opt ...time.Duration) chan *mdns.ServiceEntry { + timeout := time.Second + if len(timeout_opt) > 0 { + timeout = timeout_opt[0] + } entries := make(chan *mdns.ServiceEntry) params := &mdns.QueryParam{ - Service: "_hap._tcp", Entries: entries, DisableIPv6: true, + Service: "_hap._tcp", Entries: entries, DisableIPv6: true, Timeout: timeout, } go func() { diff --git a/www/add.html b/www/add.html index 0a99facb..b68f4d1c 100644 --- a/www/add.html +++ b/www/add.html @@ -125,29 +125,45 @@