From 5f8f514f03e92eeb33ee4b934d823c076ab8344b Mon Sep 17 00:00:00 2001 From: Pete Miller Date: Thu, 21 Mar 2024 14:46:32 -0700 Subject: [PATCH] AI Chat: sniff subresource content via throttle to detect new content metadata for same-page navigations (#22334) * AI Chat: sniff subresource content via throttle to detect new content metadata for same-page navigations * optimization: don't parse yt metadata (or fetch transcript) until an ai chat message is sent by the user --- browser/brave_content_browser_client.cc | 16 + .../chrome_content_renderer_client.cc | 4 +- .../content/browser/ai_chat_tab_helper.cc | 53 ++- .../content/browser/ai_chat_tab_helper.h | 21 ++ .../content/browser/page_content_fetcher.cc | 37 +- .../core/browser/conversation_driver.cc | 54 +-- .../core/browser/conversation_driver.h | 9 + .../common/mojom/page_content_extractor.mojom | 9 + components/ai_chat/renderer/BUILD.gn | 31 ++ .../ai_chat_resource_sniffer_throttle.cc | 60 ++++ .../ai_chat_resource_sniffer_throttle.h | 57 ++++ ..._chat_resource_sniffer_throttle_delegate.h | 32 ++ ...chat_resource_sniffer_throttle_unittest.cc | 318 ++++++++++++++++++ .../ai_chat_resource_sniffer_url_loader.cc | 100 ++++++ .../ai_chat_resource_sniffer_url_loader.h | 62 ++++ .../renderer/page_content_extractor.cc | 111 ++++-- .../ai_chat/renderer/page_content_extractor.h | 30 +- components/ai_chat/renderer/yt_util.cc | 101 ++++++ components/ai_chat/renderer/yt_util.h | 36 ++ .../ai_chat/renderer/yt_util_unittest.cc | 189 +++++++++++ renderer/DEPS | 1 + ...brave_url_loader_throttle_provider_impl.cc | 30 ++ test/BUILD.gn | 1 + 23 files changed, 1284 insertions(+), 78 deletions(-) create mode 100644 components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.cc create mode 100644 components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h create mode 100644 components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h create mode 100644 
components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_unittest.cc create mode 100644 components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.cc create mode 100644 components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h create mode 100644 components/ai_chat/renderer/yt_util.cc create mode 100644 components/ai_chat/renderer/yt_util.h create mode 100644 components/ai_chat/renderer/yt_util_unittest.cc diff --git a/browser/brave_content_browser_client.cc b/browser/brave_content_browser_client.cc index a3975c757f5b1..3fce3a4bc079a 100644 --- a/browser/brave_content_browser_client.cc +++ b/browser/brave_content_browser_client.cc @@ -153,10 +153,12 @@ using extensions::ChromeContentBrowserClientExtensionsPart; #if BUILDFLAG(ENABLE_AI_CHAT) #include "brave/browser/ui/webui/ai_chat/ai_chat_ui.h" +#include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" #include "brave/components/ai_chat/content/browser/ai_chat_throttle.h" #include "brave/components/ai_chat/core/browser/utils.h" #include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "brave/components/ai_chat/core/common/mojom/settings_helper.mojom.h" #if BUILDFLAG(IS_ANDROID) #include "brave/components/ai_chat/core/browser/android/ai_chat_iap_subscription_android.h" @@ -602,6 +604,19 @@ void BraveContentBrowserClient:: &render_frame_host)); #endif +#if BUILDFLAG(ENABLE_AI_CHAT) + // AI Chat page content extraction renderer -> browser interface + associated_registry.AddInterface( + base::BindRepeating( + [](content::RenderFrameHost* render_frame_host, + mojo::PendingAssociatedReceiver< + ai_chat::mojom::PageContentExtractorHost> receiver) { + ai_chat::AIChatTabHelper::BindPageContentExtractorHost( + render_frame_host, std::move(receiver)); + }, + &render_frame_host)); +#endif + ChromeContentBrowserClient:: 
RegisterAssociatedInterfaceBindersForRenderFrameHost(render_frame_host, associated_registry); @@ -831,6 +846,7 @@ void BraveContentBrowserClient::RegisterBrowserInterfaceBindersForFrame( user_prefs::UserPrefs::Get(render_frame_host->GetBrowserContext()); if (ai_chat::IsAIChatEnabled(prefs) && brave::IsRegularProfile(render_frame_host->GetBrowserContext())) { + // WebUI -> Browser interface content::RegisterWebUIControllerInterfaceBinder(map); #if !BUILDFLAG(IS_ANDROID) diff --git a/chromium_src/chrome/renderer/chrome_content_renderer_client.cc b/chromium_src/chrome/renderer/chrome_content_renderer_client.cc index b29972fa89254..ebb545d73ee7b 100644 --- a/chromium_src/chrome/renderer/chrome_content_renderer_client.cc +++ b/chromium_src/chrome/renderer/chrome_content_renderer_client.cc @@ -6,6 +6,7 @@ #include "brave/components/ai_chat/core/common/buildflags/buildflags.h" #include "brave/components/content_settings/renderer/brave_content_settings_agent_impl.h" #include "chrome/common/chrome_isolated_world_ids.h" +#include "chrome/renderer/chrome_render_thread_observer.h" #include "components/dom_distiller/content/renderer/distillability_agent.h" #include "components/feed/content/renderer/rss_link_reader.h" #include "content/public/common/isolated_world_ids.h" @@ -22,7 +23,8 @@ void RenderFrameWithBinderRegistryCreated( service_manager::BinderRegistry* registry) { new feed::RssLinkReader(render_frame, registry); #if BUILDFLAG(ENABLE_AI_CHAT) - if (ai_chat::features::IsAIChatEnabled()) { + if (ai_chat::features::IsAIChatEnabled() && + !ChromeRenderThreadObserver::is_incognito_process()) { new ai_chat::PageContentExtractor(render_frame, registry, content::ISOLATED_WORLD_ID_GLOBAL, ISOLATED_WORLD_ID_BRAVE_INTERNAL); diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.cc b/components/ai_chat/content/browser/ai_chat_tab_helper.cc index 14e678cc7783a..8170e729be57d 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.cc +++ 
b/components/ai_chat/content/browser/ai_chat_tab_helper.cc @@ -5,6 +5,7 @@ #include "brave/components/ai_chat/content/browser/ai_chat_tab_helper.h" +#include #include #include #include @@ -19,6 +20,7 @@ #include "brave/components/ai_chat/content/browser/page_content_fetcher.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "components/favicon/content/content_favicon_driver.h" #include "components/grit/brave_components_strings.h" @@ -31,6 +33,8 @@ #include "content/public/browser/navigation_entry.h" #include "content/public/browser/scoped_accessibility_mode.h" #include "content/public/browser/storage_partition.h" +#include "content/public/browser/web_contents.h" +#include "mojo/public/cpp/bindings/self_owned_receiver.h" #include "pdf/buildflags.h" #include "ui/accessibility/ax_mode.h" #include "ui/base/l10n/l10n_util.h" @@ -59,6 +63,30 @@ void AIChatTabHelper::PDFA11yInfoLoadObserver::AccessibilityEventReceived( AIChatTabHelper::PDFA11yInfoLoadObserver::~PDFA11yInfoLoadObserver() = default; +// static +void AIChatTabHelper::BindPageContentExtractorHost( + content::RenderFrameHost* rfh, + mojo::PendingAssociatedReceiver receiver) { + CHECK(rfh); + if (!rfh->IsInPrimaryMainFrame()) { + DVLOG(4) << "Not binding extractor host to non-main frame"; + return; + } + auto* sender = content::WebContents::FromRenderFrameHost(rfh); + if (!sender) { + DVLOG(1) << "Cannot bind extractor host, no valid WebContents"; + return; + } + auto* tab_helper = AIChatTabHelper::FromWebContents(sender); + if (!tab_helper) { + DVLOG(1) << "Cannot bind extractor host, no AIChatTabHelper - " + << sender->GetVisibleURL(); + return; + } + DVLOG(4) << "Binding extractor host to AIChatTabHelper"; + tab_helper->BindPageContentExtractorReceiver(std::move(receiver)); +} + 
AIChatTabHelper::AIChatTabHelper( content::WebContents* web_contents, AIChatMetrics* ai_chat_metrics, @@ -127,6 +155,7 @@ void AIChatTabHelper::DidFinishNavigation( // and treating it as a "fresh page". is_same_document_navigation_ = navigation_handle->IsSameDocument(); pending_navigation_id_ = navigation_handle->GetNavigationId(); + // Experimentally only call |OnNewPage| for same-page navigations _if_ // it results in a page title change (see |TtileWasSet|). if (!is_same_document_navigation_) { @@ -137,7 +166,7 @@ void AIChatTabHelper::DidFinishNavigation( void AIChatTabHelper::TitleWasSet(content::NavigationEntry* entry) { DVLOG(3) << __func__ << entry->GetTitle(); if (is_same_document_navigation_) { - DVLOG(3) << "Same document navigation detected new \"page\" - calling " + DVLOG(2) << "Same document navigation detected new \"page\" - calling " "OnNewPage()"; // Page title modification after same-document navigation seems as good a // time as any to assume meaningful changes occured to the content. @@ -190,6 +219,22 @@ void AIChatTabHelper::OnFaviconUpdated( OnFaviconImageDataChanged(); } +// mojom::PageContentExtractorHost +void AIChatTabHelper::OnInterceptedPageContentChanged() { + // Maybe mark that the page changed, if we didn't detect it already via title + // change after a same-page navigation. This is the main benefit of this + // function. + if (is_same_document_navigation_) { + DVLOG(2) << "Same document navigation detected new \"page\" - calling " + "OnNewPage()"; + // Intercepted subresource content after a same-document navigation is a + // strong signal that meaningful changes occurred to the page content. 
+ OnNewPage(pending_navigation_id_); + // Don't respond to further TitleWasSet + is_same_document_navigation_ = false; + } +} + // ai_chat::ConversationDriver GURL AIChatTabHelper::GetPageURL() const { @@ -214,6 +259,12 @@ std::u16string AIChatTabHelper::GetPageTitle() const { return web_contents()->GetTitle(); } +void AIChatTabHelper::BindPageContentExtractorReceiver( + mojo::PendingAssociatedReceiver receiver) { + page_content_extractor_receiver_.reset(); + page_content_extractor_receiver_.Bind(std::move(receiver)); +} + WEB_CONTENTS_USER_DATA_KEY_IMPL(AIChatTabHelper); } // namespace ai_chat diff --git a/components/ai_chat/content/browser/ai_chat_tab_helper.h b/components/ai_chat/content/browser/ai_chat_tab_helper.h index 377ba6a7bd265..8fb43c273972e 100644 --- a/components/ai_chat/content/browser/ai_chat_tab_helper.h +++ b/components/ai_chat/content/browser/ai_chat_tab_helper.h @@ -15,11 +15,16 @@ #include "brave/components/ai_chat/core/browser/conversation_driver.h" #include "brave/components/ai_chat/core/browser/engine/engine_consumer.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" +#include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "components/favicon/core/favicon_driver_observer.h" #include "components/prefs/pref_change_registrar.h" #include "content/public/browser/navigation_handle.h" +#include "content/public/browser/render_frame_host.h" #include "content/public/browser/web_contents_observer.h" #include "content/public/browser/web_contents_user_data.h" +#include "mojo/public/cpp/bindings/associated_receiver.h" +#include "mojo/public/cpp/bindings/pending_associated_receiver.h" +#include "mojo/public/cpp/bindings/receiver_set.h" #include "services/data_decoder/public/cpp/data_decoder.h" class PrefService; @@ -34,15 +39,24 @@ class AIChatMetrics; // Provides context to an AI Chat conversation in the form of the Tab's content class AIChatTabHelper : public content::WebContentsObserver, public 
content::WebContentsUserData, + public mojom::PageContentExtractorHost, public favicon::FaviconDriverObserver, public ConversationDriver { public: + static void BindPageContentExtractorHost( + content::RenderFrameHost* rfh, + mojo::PendingAssociatedReceiver + receiver); + AIChatTabHelper(const AIChatTabHelper&) = delete; AIChatTabHelper& operator=(const AIChatTabHelper&) = delete; ~AIChatTabHelper() override; void SetOnPDFA11yInfoLoadedCallbackForTesting(base::OnceClosure cb); + // mojom::PageContentExtractorHost + void OnInterceptedPageContentChanged() override; + private: friend class content::WebContentsUserData; @@ -92,6 +106,10 @@ class AIChatTabHelper : public content::WebContentsObserver, std::string_view invalidation_token) override; std::u16string GetPageTitle() const override; + void BindPageContentExtractorReceiver( + mojo::PendingAssociatedReceiver + receiver); + raw_ptr ai_chat_metrics_; bool is_same_document_navigation_ = false; @@ -105,6 +123,9 @@ class AIChatTabHelper : public content::WebContentsObserver, // A scoper only used for PDF viewing. 
std::unique_ptr scoped_accessibility_mode_; + mojo::AssociatedReceiver + page_content_extractor_receiver_{this}; + base::WeakPtrFactory weak_ptr_factory_{this}; WEB_CONTENTS_USER_DATA_KEY_DECL(); }; diff --git a/components/ai_chat/content/browser/page_content_fetcher.cc b/components/ai_chat/content/browser/page_content_fetcher.cc index 04dbf33c6ab25..5df821df37304 100644 --- a/components/ai_chat/content/browser/page_content_fetcher.cc +++ b/components/ai_chat/content/browser/page_content_fetcher.cc @@ -108,11 +108,13 @@ net::NetworkTrafficAnnotationTag GetGithubNetworkTrafficAnnotationTag() { class PageContentFetcher { public: + explicit PageContentFetcher( + scoped_refptr url_loader_factory) + : url_loader_factory_(url_loader_factory) {} + void Start(mojo::Remote content_extractor, std::string_view invalidation_token, - scoped_refptr url_loader_factory, FetchPageContentCallback callback) { - url_loader_factory_ = url_loader_factory; content_extractor_ = std::move(content_extractor); if (!content_extractor_) { DeleteSelf(); @@ -129,9 +131,7 @@ class PageContentFetcher { void StartGithub( GURL patch_url, - scoped_refptr url_loader_factory, FetchPageContentCallback callback) { - url_loader_factory_ = url_loader_factory; auto request = std::make_unique(); request->url = patch_url; request->load_flags = net::LOAD_DO_NOT_SAVE_COOKIES; @@ -153,17 +153,6 @@ class PageContentFetcher { std::move(on_response), 2 * 1024 * 1024); } - private: - void DeleteSelf() { delete this; } - - void SendResultAndDeleteSelf(FetchPageContentCallback callback, - std::string content = "", - std::string invalidation_token = "", - bool is_video = false) { - std::move(callback).Run(content, is_video, invalidation_token); - delete this; - } - void OnTabContentResult(FetchPageContentCallback callback, std::string_view invalidation_token, mojom::PageContentPtr data) { @@ -226,6 +215,17 @@ class PageContentFetcher { std::move(on_response), 2 * 1024 * 1024); } + private: + void DeleteSelf() { delete 
this; } + + void SendResultAndDeleteSelf(FetchPageContentCallback callback, + std::string content = "", + std::string invalidation_token = "", + bool is_video = false) { + std::move(callback).Run(content, is_video, invalidation_token); + delete this; + } + void OnYoutubeTranscriptXMLParsed( FetchPageContentCallback callback, std::string invalidation_token, @@ -499,16 +499,16 @@ void FetchPageContent(content::WebContents* web_contents, } } #endif - auto* fetcher = new PageContentFetcher(); auto* loader = url_loader_factory_for_test.get() ? url_loader_factory_for_test.get() : web_contents->GetBrowserContext() ->GetDefaultStoragePartition() ->GetURLLoaderFactoryForBrowserProcess() .get(); + auto* fetcher = new PageContentFetcher(loader); auto patch_url = GetGithubPatchURLForPRURL(url); if (patch_url) { - fetcher->StartGithub(patch_url.value(), loader, std::move(callback)); + fetcher->StartGithub(patch_url.value(), std::move(callback)); return; } @@ -516,8 +516,7 @@ void FetchPageContent(content::WebContents* web_contents, // GetRemoteInterfaces() cannot be null if the render frame is created. 
primary_rfh->GetRemoteInterfaces()->GetInterface( extractor.BindNewPipeAndPassReceiver()); - fetcher->Start(std::move(extractor), invalidation_token, loader, - std::move(callback)); + fetcher->Start(std::move(extractor), invalidation_token, std::move(callback)); } } // namespace ai_chat diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc index ae6f8eaf5c908..e99dccef822e7 100644 --- a/components/ai_chat/core/browser/conversation_driver.cc +++ b/components/ai_chat/core/browser/conversation_driver.cc @@ -472,17 +472,42 @@ void ConversationDriver::OnGeneratePageContentComplete( std::string contents_text, bool is_video, std::string invalidation_token) { - VLOG(1) << "OnGeneratePageContentComplete"; - VLOG(4) << "Contents(is_video=" << is_video - << ", invalidation_token=" << invalidation_token - << "): " << contents_text; + DVLOG(1) << "OnGeneratePageContentComplete"; + DVLOG(4) << "Contents(is_video=" << is_video + << ", invalidation_token=" << invalidation_token + << "): " << contents_text; if (navigation_id != current_navigation_id_) { VLOG(1) << __func__ << " for a different navigation. Ignoring."; return; } - is_page_text_fetch_in_progress_ = false; + // Ignore if we received content from observer in the meantime + if (!is_page_text_fetch_in_progress_) { + DVLOG(1) << __func__ + << " but already received contents from observer. Ignoring."; + return; + } + OnPageContentUpdated(contents_text, is_video, invalidation_token); + + std::move(callback).Run(article_text_, is_video_, + content_invalidation_token_); +} + +void ConversationDriver::OnExistingGeneratePageContentComplete( + GetPageContentCallback callback) { + // Don't need to check navigation ID since existing event will be + // deleted when there's a new conversation. 
+ DVLOG(1) << "Existing page content fetch completed, proceeding with " + "the results of that operation."; + std::move(callback).Run(article_text_, is_video_, + content_invalidation_token_); +} + +void ConversationDriver::OnPageContentUpdated(std::string contents_text, + bool is_video, + std::string invalidation_token) { + is_page_text_fetch_in_progress_ = false; // If invalidation token matches existing token, then // content was not re-fetched and we can use our existing cache. if (!invalidation_token.empty() && @@ -500,27 +525,12 @@ void ConversationDriver::OnGeneratePageContentComplete( OnPageHasContentChanged(BuildSiteInfo()); } - on_page_text_fetch_complete_->Signal(); - on_page_text_fetch_complete_ = std::make_unique(); - if (contents_text.empty()) { VLOG(1) << __func__ << ": No data"; } - VLOG(4) << "calling callback with text: " << article_text_; - - std::move(callback).Run(article_text_, is_video_, - content_invalidation_token_); -} - -void ConversationDriver::OnExistingGeneratePageContentComplete( - GetPageContentCallback callback) { - // Don't need to check navigation ID since existing event will be - // deleted when there's a new conversation. 
- VLOG(1) << "Existing page content fetch completed, proceeding with " - "the results of that operation."; - std::move(callback).Run(article_text_, is_video_, - content_invalidation_token_); + on_page_text_fetch_complete_->Signal(); + on_page_text_fetch_complete_ = std::make_unique(); } void ConversationDriver::OnNewPage(int64_t navigation_id) { diff --git a/components/ai_chat/core/browser/conversation_driver.h b/components/ai_chat/core/browser/conversation_driver.h index c4e0670df5845..10d18ba3175ea 100644 --- a/components/ai_chat/core/browser/conversation_driver.h +++ b/components/ai_chat/core/browser/conversation_driver.h @@ -146,6 +146,15 @@ class ConversationDriver { virtual void OnFaviconImageDataChanged(); + // Implementer should call this when the content is updated in a way that + // will not be detected by the on-demand techniques used by GetPageContent. + // For example for sites where GetPageContent does not read the live DOM but + // reads static JS from HTML that doesn't change for same-page navigation and + // we need to intercept new JS data from subresource loads. + void OnPageContentUpdated(std::string content, + bool is_video, + std::string invalidation_token); + // To be called when a page navigation is detected and a new conversation // is expected. void OnNewPage(int64_t navigation_id); diff --git a/components/ai_chat/core/common/mojom/page_content_extractor.mojom b/components/ai_chat/core/common/mojom/page_content_extractor.mojom index 2aaf6d516d6d1..b52d00e302b3d 100644 --- a/components/ai_chat/core/common/mojom/page_content_extractor.mojom +++ b/components/ai_chat/core/common/mojom/page_content_extractor.mojom @@ -29,3 +29,12 @@ struct PageContent { interface PageContentExtractor { ExtractPageContent() => (PageContent? page_content); }; + +// Allows the renderer to notify the browser process of meaningful changes to +// the content. 
+interface PageContentExtractorHost { + // We don't send page content here due to an optimization for the majority of + // renderers without active AIChat conversations. We wait until the host + // requests it via PageContentExtractor.ExtractPageContent. + OnInterceptedPageContentChanged(); +}; diff --git a/components/ai_chat/renderer/BUILD.gn b/components/ai_chat/renderer/BUILD.gn index 4c0509cf7fe7d..0e217faea9adb 100644 --- a/components/ai_chat/renderer/BUILD.gn +++ b/components/ai_chat/renderer/BUILD.gn @@ -9,15 +9,23 @@ assert(enable_ai_chat) static_library("renderer") { sources = [ + "ai_chat_resource_sniffer_throttle.cc", + "ai_chat_resource_sniffer_throttle.h", + "ai_chat_resource_sniffer_throttle_delegate.h", + "ai_chat_resource_sniffer_url_loader.cc", + "ai_chat_resource_sniffer_url_loader.h", "page_content_extractor.cc", "page_content_extractor.h", "page_text_distilling.cc", "page_text_distilling.h", + "yt_util.cc", + "yt_util.h", ] deps = [ "//base", "//brave/components/ai_chat/core/common/mojom", + "//brave/components/body_sniffer", "//content/public/renderer", "//gin", "//mojo/public/cpp/bindings", @@ -29,3 +37,26 @@ static_library("renderer") { "//v8", ] } + +if (!is_ios) { + source_set("unit_tests") { + testonly = true + sources = [ + "ai_chat_resource_sniffer_throttle_unittest.cc", + "yt_util_unittest.cc", + ] + + deps = [ + "//base/test:test_support", + "//brave/components/ai_chat/renderer", + "//content/test:test_support", + "//mojo/public/cpp/bindings", + "//mojo/public/cpp/system", + "//services/data_decoder/public/cpp:test_support", + "//services/network:test_support", + "//services/network/public/cpp:cpp", + "//testing/gtest:gtest", + "//third_party/blink/public:blink", + ] + } +} diff --git a/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.cc b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.cc new file mode 100644 index 0000000000000..8220cd48fc87a --- /dev/null +++ 
b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.cc @@ -0,0 +1,60 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include +#include + +#include "base/containers/contains.h" +#include "base/strings/string_util.h" +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h" +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h" +#include "brave/components/ai_chat/renderer/yt_util.h" + +namespace ai_chat { + +namespace { +constexpr char kYouTubePlayerAPIPath[] = "/youtubei/v1/player"; +} + +std::unique_ptr +AIChatResourceSnifferThrottle::MaybeCreateThrottleFor( + base::WeakPtr delegate, + const GURL& url, + scoped_refptr task_runner) { + DCHECK(delegate); + // TODO(petemill): Allow some kind of config to be passed in to determine + // which hosts and paths to sniff, and how to parse it to a + // |mojom::PageContent|. 
+ if (url.SchemeIsHTTPOrHTTPS() && base::Contains(kYouTubeHosts, url.host()) && + base::EqualsCaseInsensitiveASCII(url.path(), kYouTubePlayerAPIPath)) { + return std::make_unique(task_runner, + delegate); + } + return nullptr; +} + +AIChatResourceSnifferThrottle::AIChatResourceSnifferThrottle( + scoped_refptr task_runner, + base::WeakPtr delegate) + : task_runner_(task_runner), delegate_(delegate) {} + +AIChatResourceSnifferThrottle::~AIChatResourceSnifferThrottle() = default; + +void AIChatResourceSnifferThrottle::WillProcessResponse( + const GURL& response_url, + network::mojom::URLResponseHead* response_head, + bool* defer) { + mojo::PendingRemote new_remote; + mojo::PendingReceiver new_receiver; + raw_ptr sniffer_loader = nullptr; + + std::tie(new_remote, new_receiver, sniffer_loader) = + AIChatResourceSnifferURLLoader::CreateLoader( + AsWeakPtr(), std::move(delegate_), task_runner_, response_url); + BodySnifferThrottle::InterceptAndStartLoader( + std::move(new_remote), std::move(new_receiver), sniffer_loader); +} + +} // namespace ai_chat diff --git a/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h new file mode 100644 index 0000000000000..88b9eceb98eb2 --- /dev/null +++ b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h @@ -0,0 +1,57 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#ifndef BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_H_ +#define BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_H_ + +#include + +#include "base/task/sequenced_task_runner.h" +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h" +#include "brave/components/body_sniffer/body_sniffer_throttle.h" + +namespace ai_chat { + +class AIChatResourceSnifferThrottleDelegate; + +// ResourceSnifferThrottle is an interceptor which reads the content of various +// resources and sends it to an AI Chat delegate for content updates. +class AIChatResourceSnifferThrottle : public body_sniffer::BodySnifferThrottle { + public: + explicit AIChatResourceSnifferThrottle( + scoped_refptr task_runner, + base::WeakPtr delegate); + ~AIChatResourceSnifferThrottle() override; + AIChatResourceSnifferThrottle& operator=( + const AIChatResourceSnifferThrottle&) = delete; + + static std::unique_ptr MaybeCreateThrottleFor( + base::WeakPtr delegate, + const GURL& url, + scoped_refptr task_runner); + + protected: + // blink::URLLoaderThrottle via body_sniffer::BodySnifferThrottle + void WillProcessResponse(const GURL& response_url, + network::mojom::URLResponseHead* response_head, + bool* defer) override; + + private: + friend class AIChatResourceSnifferThrottleTest; + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, NoBody); + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, Body_NonJson); + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, Body_InvalidJson); + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, + Body_ValidNonYTJson); + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, Body_ValidYTJson); + FRIEND_TEST_ALL_PREFIXES(AIChatResourceSnifferThrottleTest, Abort_NoBodyPipe); + scoped_refptr task_runner_; + + base::WeakPtr delegate_; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_H_ diff --git 
a/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h new file mode 100644 index 0000000000000..19ac160087dd8 --- /dev/null +++ b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h @@ -0,0 +1,32 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#ifndef BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_DELEGATE_H_ +#define BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_DELEGATE_H_ + +#include +#include + +namespace ai_chat { + +class AIChatResourceSnifferThrottleDelegate { + public: + enum class InterceptedContentType { + kYouTubeMetadataString, + }; + struct InterceptedContent { + InterceptedContentType type; + std::string content; + }; + virtual void OnInterceptedPageContentChanged( + std::unique_ptr content) = 0; + + protected: + virtual ~AIChatResourceSnifferThrottleDelegate() = default; +}; + +} // namespace ai_chat + +#endif // BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_THROTTLE_DELEGATE_H_ diff --git a/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_unittest.cc b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_unittest.cc new file mode 100644 index 0000000000000..1920235799785 --- /dev/null +++ b/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_unittest.cc @@ -0,0 +1,318 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#include +#include +#include +#include + +#include "base/functional/bind.h" +#include "base/memory/weak_ptr.h" +#include "base/run_loop.h" +#include "base/test/task_environment.h" +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h" +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h" +#include "mojo/public/cpp/bindings/pending_receiver.h" +#include "mojo/public/cpp/bindings/pending_remote.h" +#include "mojo/public/cpp/bindings/remote.h" +#include "mojo/public/cpp/system/data_pipe_utils.h" +#include "services/network/public/mojom/url_response_head.mojom.h" +#include "services/network/test/test_url_loader_client.h" +#include "services/network/test/test_url_loader_factory.h" +#include "testing/gmock/include/gmock/gmock.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/blink/public/common/loader/url_loader_throttle.h" +#include "third_party/googletest/src/googletest/include/gtest/gtest.h" +#include "url/gurl.h" + +namespace ai_chat { + +namespace { + +class MojoDataPipeSender { + public: + explicit MojoDataPipeSender(mojo::ScopedDataPipeProducerHandle handle) + : handle_(std::move(handle)), + watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::AUTOMATIC) {} + + void Start(std::string data, base::OnceClosure done_callback) { + data_ = std::move(data); + done_callback_ = std::move(done_callback); + watcher_.Watch(handle_.get(), + MOJO_HANDLE_SIGNAL_WRITABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED, + base::BindRepeating(&MojoDataPipeSender::OnWritable, + base::Unretained(this))); + } + + void OnWritable(MojoResult) { + uint32_t sending_bytes = data_.size() - sent_bytes_; + MojoResult result = handle_->WriteData( + data_.c_str() + sent_bytes_, &sending_bytes, MOJO_WRITE_DATA_FLAG_NONE); + switch (result) { + case MOJO_RESULT_OK: + break; + case MOJO_RESULT_FAILED_PRECONDITION: + // Finished unexpectedly. 
+ std::move(done_callback_).Run(); + return; + case MOJO_RESULT_SHOULD_WAIT: + // Just wait until OnWritable() is called by the watcher. + return; + default: + NOTREACHED(); + return; + } + sent_bytes_ += sending_bytes; + if (data_.size() == sent_bytes_) { + std::move(done_callback_).Run(); + } + } + + mojo::ScopedDataPipeProducerHandle ReleaseHandle() { + return std::move(handle_); + } + + bool has_succeeded() const { return data_.size() == sent_bytes_; } + + private: + mojo::ScopedDataPipeProducerHandle handle_; + mojo::SimpleWatcher watcher_; + base::OnceClosure done_callback_; + std::string data_; + uint32_t sent_bytes_ = 0; +}; + +class MockAIChatResourceSnifferThrottleDelegate + : public AIChatResourceSnifferThrottleDelegate { + public: + MOCK_METHOD(void, OnInterceptedPageContentChanged_Data, (std::string)); + + void OnInterceptedPageContentChanged( + std::unique_ptr + content) override { + ASSERT_EQ(content->type, + AIChatResourceSnifferThrottleDelegate::InterceptedContentType:: + kYouTubeMetadataString); + OnInterceptedPageContentChanged_Data(content->content); + } + + base::WeakPtrFactory weak_factory_{ + this}; +}; + +class MockDelegate : public blink::URLLoaderThrottle::Delegate { + public: + // Implements blink::URLLoaderThrottle::Delegate. + void CancelWithError(int error_code, + std::string_view custom_reason) override { + NOTIMPLEMENTED(); + } + void Resume() override { + is_resumed_ = true; + // Resume from OnReceiveResponse() with a customized response header. 
+ destination_loader_client()->OnReceiveResponse( + std::move(updated_response_head_), std::move(body_), absl::nullopt); + } + + void UpdateDeferredResponseHead( + network::mojom::URLResponseHeadPtr new_response_head, + mojo::ScopedDataPipeConsumerHandle body) override { + updated_response_head_ = std::move(new_response_head); + body_ = std::move(body); + } + void PauseReadingBodyFromNet() override { NOTIMPLEMENTED(); } + void ResumeReadingBodyFromNet() override { NOTIMPLEMENTED(); } + void InterceptResponse( + mojo::PendingRemote new_loader, + mojo::PendingReceiver + new_client_receiver, + mojo::PendingRemote* original_loader, + mojo::PendingReceiver* + original_client_receiver, + mojo::ScopedDataPipeConsumerHandle* body) override { + is_intercepted_ = true; + + destination_loader_remote_.Bind(std::move(new_loader)); + ASSERT_TRUE( + mojo::FusePipes(std::move(new_client_receiver), + mojo::PendingRemote( + destination_loader_client_.CreateRemote()))); + pending_receiver_ = original_loader->InitWithNewPipeAndPassReceiver(); + + *original_client_receiver = + source_loader_client_remote_.BindNewPipeAndPassReceiver(); + + if (no_body_) { + return; + } + + DCHECK(!source_body_handle_); + mojo::ScopedDataPipeConsumerHandle consumer; + EXPECT_EQ(MOJO_RESULT_OK, + mojo::CreateDataPipe(nullptr, source_body_handle_, consumer)); + *body = std::move(consumer); + } + + void LoadResponseBody(const std::string& body) { + MojoDataPipeSender sender(std::move(source_body_handle_)); + base::RunLoop loop; + sender.Start(body, loop.QuitClosure()); + loop.Run(); + + EXPECT_TRUE(sender.has_succeeded()); + source_body_handle_ = sender.ReleaseHandle(); + } + + void CompleteResponse() { + source_loader_client_remote()->OnComplete( + network::URLLoaderCompletionStatus()); + source_body_handle_.reset(); + } + + uint32_t ReadResponseBody(uint32_t size) { + std::vector buffer(size); + MojoResult result = destination_loader_client_.response_body().ReadData( + buffer.data(), &size, 
MOJO_READ_DATA_FLAG_NONE); + switch (result) { + case MOJO_RESULT_OK: + return size; + case MOJO_RESULT_FAILED_PRECONDITION: + return 0; + case MOJO_RESULT_SHOULD_WAIT: + return 0; + default: + NOTREACHED(); + } + return 0; + } + + void ResetProducer() { source_body_handle_.reset(); } + + bool is_intercepted() const { return is_intercepted_; } + bool is_resumed() const { return is_resumed_; } + void set_no_body() { no_body_ = true; } + + network::TestURLLoaderClient* destination_loader_client() { + return &destination_loader_client_; + } + + mojo::Remote& source_loader_client_remote() { + return source_loader_client_remote_; + } + + private: + bool is_intercepted_ = false; + bool is_resumed_ = false; + bool no_body_ = false; + network::mojom::URLResponseHeadPtr updated_response_head_; + mojo::ScopedDataPipeConsumerHandle body_; + + // A pair of a loader and a loader client for destination of the response. + mojo::Remote destination_loader_remote_; + network::TestURLLoaderClient destination_loader_client_; + + // A pair of a receiver and a remote for source of the response. 
+ mojo::PendingReceiver pending_receiver_; + mojo::Remote source_loader_client_remote_; + + mojo::ScopedDataPipeProducerHandle source_body_handle_; +}; + +} // namespace + +class AIChatResourceSnifferThrottleTest : public testing::Test { + public: + std::unique_ptr MaybeCreateThrottleForUrl( + GURL url) { + return AIChatResourceSnifferThrottle::MaybeCreateThrottleFor( + ai_chat_throttle_delegate_.weak_factory_.GetWeakPtr(), url, + task_environment_.GetMainThreadTaskRunner()); + } + + void InterceptBodyRequestFor(const std::string& body) { + GURL url("https://www.youtube.com/youtubei/v1/player"); + auto throttle = MaybeCreateThrottleForUrl(url); + auto delegate = std::make_unique(); + throttle->set_delegate(delegate.get()); + + auto response_head = network::mojom::URLResponseHead::New(); + bool defer = false; + throttle->WillProcessResponse(url, response_head.get(), &defer); + EXPECT_FALSE(defer); + EXPECT_TRUE(delegate->is_intercepted()); + + delegate->LoadResponseBody(body); + delegate->CompleteResponse(); + task_environment_.RunUntilIdle(); + EXPECT_TRUE(delegate->destination_loader_client()->has_received_response()); + } + + protected: + base::test::TaskEnvironment task_environment_; + testing::NiceMock + ai_chat_throttle_delegate_; +}; + +TEST_F(AIChatResourceSnifferThrottleTest, ThrottlesYTPlayerAPI) { + EXPECT_NE(nullptr, MaybeCreateThrottleForUrl(GURL( + "http://www.youtube.com/youtubei/v1/player?example"))); +} + +TEST_F(AIChatResourceSnifferThrottleTest, DoesNotThrottleYTOther) { + EXPECT_EQ(nullptr, + MaybeCreateThrottleForUrl(GURL( + "http://www.youtube.com/youtubei/v1/somethingelse?example"))); +} + +TEST_F(AIChatResourceSnifferThrottleTest, DoesNotThrottleNonYT) { + EXPECT_EQ(nullptr, MaybeCreateThrottleForUrl(GURL( + "http://www.example.com/youtubei/v1/player?example"))); +} + +TEST_F(AIChatResourceSnifferThrottleTest, DoesNotThrottleNonHTTP) { + EXPECT_EQ(nullptr, MaybeCreateThrottleForUrl(GURL( + 
"wss://www.youtube.com/youtubei/v1/player?example"))); +} + +TEST_F(AIChatResourceSnifferThrottleTest, Body_NonJson) { + // AIChatResourceSnifferThrottle doesn't parse the json as an optimization + // since it might not get used until an AIChat conversation message is about + // to be sent, so any body content should be passed to the delegate, we don't + // need to test for valid JSON + std::string body = "\x89PNG\x0D\x0A\x1A\x0A"; + EXPECT_CALL(ai_chat_throttle_delegate_, + OnInterceptedPageContentChanged_Data(body)) + .Times(1); + InterceptBodyRequestFor(body); +} + +TEST_F(AIChatResourceSnifferThrottleTest, Body_ValidYTJson) { + std::string body = R"({ + "captions": { + "playerCaptionsTracklistRenderer": { + "captionTracks": [ + { + "baseUrl": "https://www.example.com/caption1" + } + ] + } + } + })"; + EXPECT_CALL(ai_chat_throttle_delegate_, + OnInterceptedPageContentChanged_Data(body)) + .Times(1); + InterceptBodyRequestFor(body); +} + +TEST_F(AIChatResourceSnifferThrottleTest, LongBody) { + std::string body = "This should be long enough..."; + body.resize(2048, 'a'); + EXPECT_CALL(ai_chat_throttle_delegate_, + OnInterceptedPageContentChanged_Data(body)) + .Times(1); + InterceptBodyRequestFor(body); +} + +} // namespace ai_chat diff --git a/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.cc b/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.cc new file mode 100644 index 0000000000000..ba1136d043dce --- /dev/null +++ b/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h" + +#include +#include +#include + +#include "base/json/json_reader.h" +#include "base/logging.h" +#include "base/types/expected.h" +#include "base/values.h" +#include "brave/components/body_sniffer/body_sniffer_url_loader.h" +#include "mojo/public/cpp/bindings/self_owned_receiver.h" + +namespace ai_chat { + +namespace { + +constexpr uint32_t kReadBufferSize = 37000; // average subresource size + +} // namespace + +// static +std::tuple, + mojo::PendingReceiver, + AIChatResourceSnifferURLLoader*> +AIChatResourceSnifferURLLoader::CreateLoader( + base::WeakPtr throttle, + base::WeakPtr delegate, + scoped_refptr task_runner, + const GURL& response_url) { + mojo::PendingRemote url_loader; + mojo::PendingRemote url_loader_client; + mojo::PendingReceiver + url_loader_client_receiver = + url_loader_client.InitWithNewPipeAndPassReceiver(); + + auto loader = base::WrapUnique(new AIChatResourceSnifferURLLoader( + std::move(throttle), delegate, std::move(url_loader_client), + std::move(task_runner), response_url)); + AIChatResourceSnifferURLLoader* loader_rawptr = loader.get(); + mojo::MakeSelfOwnedReceiver(std::move(loader), + url_loader.InitWithNewPipeAndPassReceiver()); + return std::make_tuple(std::move(url_loader), + std::move(url_loader_client_receiver), loader_rawptr); +} + +AIChatResourceSnifferURLLoader::AIChatResourceSnifferURLLoader( + base::WeakPtr throttle, + base::WeakPtr delegate, + mojo::PendingRemote + destination_url_loader_client, + scoped_refptr task_runner, + const GURL& response_url) + : body_sniffer::BodySnifferURLLoader( + throttle, + response_url, + std::move(destination_url_loader_client), + task_runner), + delegate_(delegate), + response_url_(response_url) {} + +AIChatResourceSnifferURLLoader::~AIChatResourceSnifferURLLoader() = default; + +void AIChatResourceSnifferURLLoader::OnBodyReadable(MojoResult) { + DCHECK_EQ(State::kLoading, state_); + + if 
(!BodySnifferURLLoader::CheckBufferedBody(kReadBufferSize)) { + return; + } + + body_consumer_watcher_.ArmOrNotify(); +} + +void AIChatResourceSnifferURLLoader::OnBodyWritable(MojoResult r) { + DCHECK_EQ(State::kSending, state_); + if (bytes_remaining_in_buffer_ > 0) { + SendBufferedBodyToClient(); + } else { + CompleteSending(); + } +} + +void AIChatResourceSnifferURLLoader::CompleteLoading(std::string body) { + if (!body.empty()) { + auto content = std::make_unique< + AIChatResourceSnifferThrottleDelegate::InterceptedContent>(); + content->type = AIChatResourceSnifferThrottleDelegate:: + InterceptedContentType::kYouTubeMetadataString; + content->content = body; + delegate_->OnInterceptedPageContentChanged(std::move(content)); + } + body_sniffer::BodySnifferURLLoader::CompleteLoading(std::move(body)); +} + +} // namespace ai_chat diff --git a/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h b/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h new file mode 100644 index 0000000000000..4530b333953ca --- /dev/null +++ b/components/ai_chat/renderer/ai_chat_resource_sniffer_url_loader.h @@ -0,0 +1,62 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. 
+
+#ifndef BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_URL_LOADER_H_
+#define BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_URL_LOADER_H_
+
+#include
+#include
+
+#include "base/memory/scoped_refptr.h"
+#include "base/memory/weak_ptr.h"
+#include "base/task/sequenced_task_runner.h"
+#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h"
+#include "brave/components/body_sniffer/body_sniffer_url_loader.h"
+
+namespace body_sniffer {
+class BodySnifferThrottle;
+}  // namespace body_sniffer
+
+namespace ai_chat {
+
+// A BodySnifferURLLoader that, once a response body has finished loading,
+// passes a copy of any non-empty body to an
+// AIChatResourceSnifferThrottleDelegate tagged as kYouTubeMetadataString, and
+// then lets the base class continue with the response.
+class AIChatResourceSnifferURLLoader
+    : public body_sniffer::BodySnifferURLLoader {
+ public:
+  ~AIChatResourceSnifferURLLoader() override;
+
+  // Creates a loader owned by a self-owned mojo receiver: the returned
+  // mojo::PendingRemote controls the lifetime of the loader, and the returned
+  // raw pointer is non-owning.
+  // |delegate| is the object notified of intercepted content.
+  // |throttle|, |task_runner| and |response_url| are forwarded to the
+  // BodySnifferURLLoader base class.
+  static std::tuple,
+                    mojo::PendingReceiver,
+                    AIChatResourceSnifferURLLoader*>
+  CreateLoader(base::WeakPtr throttle,
+               base::WeakPtr delegate,
+               scoped_refptr task_runner,
+               const GURL& response_url);
+
+ private:
+  // Private: instances are only created through CreateLoader().
+  AIChatResourceSnifferURLLoader(
+      base::WeakPtr throttle,
+      base::WeakPtr delegate,
+      mojo::PendingRemote
+          destination_url_loader_client,
+      scoped_refptr task_runner,
+      const GURL& response_url);
+
+  // body_sniffer::BodySnifferURLLoader
+  void OnBodyReadable(MojoResult) override;
+  void OnBodyWritable(MojoResult) override;
+
+  // Notifies |delegate_| about a non-empty |body| before deferring to the
+  // base class implementation.
+  void CompleteLoading(std::string body) override;
+  // Recipient of intercepted content. Held weakly — presumably because the
+  // delegate's lifetime is independent of this loader; verify against caller.
+  base::WeakPtr delegate_;
+
+  // NOTE(review): only initialized in the constructor; no reads are visible in
+  // the accompanying .cc — confirm whether it is still needed.
+  GURL response_url_;
+
+  // NOTE(review): no GetWeakPtr() call is visible in the accompanying .cc —
+  // confirm whether it is still needed.
+  base::WeakPtrFactory weak_ptr_factory_{this};
+};
+
+}  // namespace ai_chat
+
+#endif  // BRAVE_COMPONENTS_AI_CHAT_RENDERER_AI_CHAT_RESOURCE_SNIFFER_URL_LOADER_H_
diff --git a/components/ai_chat/renderer/page_content_extractor.cc b/components/ai_chat/renderer/page_content_extractor.cc
index a78de1edbce9e..3c31b14d79fd2 100644
--- a/components/ai_chat/renderer/page_content_extractor.cc
+++ b/components/ai_chat/renderer/page_content_extractor.cc
@@ -5,6 +5,7 @@ #include "brave/components/ai_chat/renderer/page_content_extractor.h" +#include #include
#include #include @@ -14,13 +15,18 @@ #include "base/containers/fixed_flat_set.h" #include "base/containers/span.h" #include "base/functional/bind.h" +#include "base/memory/ptr_util.h" #include "base/values.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" #include "brave/components/ai_chat/renderer/page_text_distilling.h" +#include "brave/components/ai_chat/renderer/yt_util.h" #include "content/public/renderer/render_frame.h" #include "content/public/renderer/render_frame_observer.h" +#include "mojo/public/cpp/bindings/associated_remote.h" #include "net/base/url_util.h" +#include "third_party/blink/public/common/associated_interfaces/associated_interface_provider.h" +#include "third_party/blink/public/common/browser_interface_broker_proxy.h" #include "third_party/blink/public/platform/web_string.h" #include "third_party/blink/public/web/web_local_frame.h" #include "third_party/blink/public/web/web_script_source.h" @@ -28,29 +34,14 @@ #include "url/url_constants.h" #include "v8/include/v8-isolate.h" +namespace ai_chat { + namespace { const char16_t kYoutubeTranscriptUrlExtractionScript[] = - // TODO(petemill): Consider user's language uR"JS( (function() { - const tracks = ytplayer?.config?.args?.raw_player_response?.captions?.playerCaptionsTracklistRenderer?.captionTracks - if (!tracks.length) { - return null - } - const langTracks = tracks.filter(track => track.languageCode === "en") - if (langTracks.length) { - const nonAutoGeneratedTrack = langTracks.find(track => track.kind !== 'asr') - if (nonAutoGeneratedTrack) { - return nonAutoGeneratedTrack.baseUrl - } - return langTracks[0].baseUrl - } - const nonAutoGeneratedTrack = tracks.find(track => track.kind !== 'asr') - if (nonAutoGeneratedTrack) { - return nonAutoGeneratedTrack.baseUrl - } - return tracks[0].baseUrl + return 
ytplayer?.config?.args?.raw_player_response?.captions?.playerCaptionsTracklistRenderer?.captionTracks })() )JS"; @@ -72,13 +63,6 @@ const char16_t kVideoTrackTranscriptUrlExtractionScript[] = })() )JS"; -constexpr auto kYouTubeHosts = - base::MakeFixedFlatSet(base::sorted_unique, - { - "m.youtube.com", - "www.youtube.com", - }); - // TODO(petemill): Use heuristics to determine if page's main focus is // a video, and not a hard-coded list of Url hosts. constexpr auto kVideoTrackHosts = @@ -89,14 +73,13 @@ constexpr auto kVideoTrackHosts = } // namespace -namespace ai_chat { - PageContentExtractor::PageContentExtractor( content::RenderFrame* render_frame, service_manager::BinderRegistry* registry, int32_t global_world_id, int32_t isolated_world_id) : content::RenderFrameObserver(render_frame), + RenderFrameObserverTracker(render_frame), global_world_id_(global_world_id), isolated_world_id_(isolated_world_id), weak_ptr_factory_(this) { @@ -117,9 +100,43 @@ void PageContentExtractor::OnDestruct() { delete this; } +base::WeakPtr PageContentExtractor::GetWeakPtr() { + return weak_ptr_factory_.GetWeakPtr(); +} + void PageContentExtractor::ExtractPageContent( mojom::PageContentExtractor::ExtractPageContentCallback callback) { VLOG(1) << "AI Chat renderer has been asked for page content."; + // When content has been pushed to this class from a throttle via + // OnInterceptedPageContentChanged, use that content instead of fetching it + // from the page. 
+ if (intercepted_content_) { + auto intercepted_content = std::move(intercepted_content_); + intercepted_content_.reset(); + DVLOG(1) << "Using intercepted content."; + DCHECK_EQ(intercepted_content->type, + InterceptedContentType::kYouTubeMetadataString) + << "Unexpected intercepted content type"; + // Parse the YT metadata and extract the most appropriate caption Url + auto maybe_caption_url = + ParseAndChooseCaptionTrackUrl(intercepted_content->content); + if (maybe_caption_url.has_value()) { + GURL caption_url = + render_frame()->GetWebFrame()->GetDocument().CompleteURL( + blink::WebString::FromASCII(maybe_caption_url.value())); + if (caption_url.is_valid()) { + mojom::PageContentPtr content_update = mojom::PageContent::New(); + content_update->type = mojom::PageContentType::VideoTranscriptYouTube; + content_update->content = + mojom::PageContentData::NewContentUrl(caption_url); + std::move(callback).Run(std::move(content_update)); + return; + } + } + std::move(callback).Run({}); + return; + } + blink::WebLocalFrame* main_frame = render_frame()->GetWebFrame(); GURL origin = url::Origin(((const blink::WebFrame*)main_frame)->GetSecurityOrigin()) @@ -185,6 +202,26 @@ void PageContentExtractor::ExtractPageContent( weak_ptr_factory_.GetWeakPtr(), std::move(callback))); } +void PageContentExtractor::OnInterceptedPageContentChanged( + std::unique_ptr + content_update) { + DCHECK_EQ(content_update->type, + AIChatResourceSnifferThrottleDelegate::InterceptedContentType:: + kYouTubeMetadataString) + << " - unexpected content type"; + DCHECK(!content_update->content.empty()); + + // Store the new content for later, when we're asked for it, so that we + // don't have to do any parsing or fetching when there's no active + // conversation. + intercepted_content_ = std::move(content_update); + // Let the host know that new content was received so that it may record a + // "page" change. 
+ mojo::AssociatedRemote host; + render_frame()->GetRemoteAssociatedInterfaces()->GetInterface(&host); + host->OnInterceptedPageContentChanged(); +} + void PageContentExtractor::BindReceiver( mojo::PendingReceiver receiver) { VLOG(1) << "AIChat PageContentExtractor handler bound."; @@ -222,15 +259,31 @@ void PageContentExtractor::OnJSTranscriptUrlResult( DVLOG(2) << "Video transcript Url extraction script completed and took" << (base::TimeTicks::Now() - start_time).InMillisecondsF() << "ms" << "\nResult: " << (value ? value->DebugString() : "[undefined]"); + // Handle no result from script - if (!value.has_value() || !value->is_string()) { + if (!value.has_value()) { + std::move(callback).Run({}); + return; + } + + // Optional parsing + std::string url; + if (type == mojom::PageContentType::VideoTranscriptYouTube) { + auto maybe_url = ChooseCaptionTrackUrl(value->GetIfList()); + if (maybe_url.has_value()) { + url = maybe_url.value(); + } + } else if (value->is_string()) { + url = value->GetString(); + } + if (url.empty()) { std::move(callback).Run({}); return; } // Handle invalid url GURL transcript_url = render_frame()->GetWebFrame()->GetDocument().CompleteURL( - blink::WebString::FromASCII(value->GetString())); + blink::WebString::FromASCII(url)); if (!transcript_url.is_valid() || !transcript_url.SchemeIs(url::kHttpsScheme)) { DVLOG(1) << "Invalid Url for transcript: " << transcript_url.spec(); diff --git a/components/ai_chat/renderer/page_content_extractor.h b/components/ai_chat/renderer/page_content_extractor.h index 6bb71921f4ac8..c2fd55517b6c4 100644 --- a/components/ai_chat/renderer/page_content_extractor.h +++ b/components/ai_chat/renderer/page_content_extractor.h @@ -7,28 +7,38 @@ #define BRAVE_COMPONENTS_AI_CHAT_RENDERER_PAGE_CONTENT_EXTRACTOR_H_ #include +#include #include #include #include "base/values.h" #include "brave/components/ai_chat/core/common/mojom/page_content_extractor.mojom.h" +#include 
"brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle_delegate.h" #include "content/public/renderer/render_frame.h" #include "content/public/renderer/render_frame_observer.h" +#include "content/public/renderer/render_frame_observer_tracker.h" #include "mojo/public/cpp/bindings/receiver.h" +#include "mojo/public/cpp/bindings/remote.h" namespace ai_chat { -class PageContentExtractor : public ai_chat::mojom::PageContentExtractor, - public content::RenderFrameObserver { +class PageContentExtractor + : public ai_chat::mojom::PageContentExtractor, + public content::RenderFrameObserver, + public content::RenderFrameObserverTracker, + public AIChatResourceSnifferThrottleDelegate { public: - explicit PageContentExtractor(content::RenderFrame* render_frame, - service_manager::BinderRegistry* registry, - int32_t global_world_id, - int32_t isolated_world_id); + PageContentExtractor(content::RenderFrame* render_frame, + service_manager::BinderRegistry* registry, + int32_t global_world_id, + int32_t isolated_world_id); + PageContentExtractor(const PageContentExtractor&) = delete; PageContentExtractor& operator=(const PageContentExtractor&) = delete; ~PageContentExtractor() override; + base::WeakPtr GetWeakPtr(); + private: void OnJSTranscriptUrlResult( mojom::PageContentExtractor::ExtractPageContentCallback callback, @@ -47,6 +57,11 @@ class PageContentExtractor : public ai_chat::mojom::PageContentExtractor, mojom::PageContentExtractor::ExtractPageContentCallback callback) override; + // AIChatResourceSnifferThrottleDelegate + void OnInterceptedPageContentChanged( + std::unique_ptr + content_update) override; + void BindReceiver( mojo::PendingReceiver receiver); @@ -55,6 +70,9 @@ class PageContentExtractor : public ai_chat::mojom::PageContentExtractor, int32_t global_world_id_; int32_t isolated_world_id_; + std::unique_ptr + intercepted_content_; + base::WeakPtrFactory weak_ptr_factory_{this}; }; diff --git a/components/ai_chat/renderer/yt_util.cc 
b/components/ai_chat/renderer/yt_util.cc new file mode 100644 index 0000000000000..7cb941638909a --- /dev/null +++ b/components/ai_chat/renderer/yt_util.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2024 The Brave Authors. All rights reserved. +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "brave/components/ai_chat/renderer/yt_util.h" + +#include +#include + +#include "base/json/json_reader.h" +#include "base/logging.h" +#include "base/ranges/algorithm.h" +#include "base/values.h" + +namespace ai_chat { + +std::optional ChooseCaptionTrackUrl( + const base::Value::List* caption_tracks) { + if (!caption_tracks || caption_tracks->empty()) { + return std::nullopt; + } + if (caption_tracks->empty()) { + return std::nullopt; + } + const base::Value::Dict* track; + // When only single track, use that + if (caption_tracks->size() == 1) { + track = caption_tracks->front().GetIfDict(); + } else { + // When multiple tracks, favor english (due to ai_chat models), then first + // english auto-generated track, then settle for anything. + // TODO(petemill): Consider preferring user's language. 
+ auto iter = base::ranges::find_if( + *caption_tracks, [](const base::Value& track_raw) { + const base::Value::Dict* language_track = track_raw.GetIfDict(); + auto* kind = language_track->FindString("kind"); + if (kind && *kind == "asr") { + return false; + } + auto* lang = language_track->FindString("languageCode"); + if (lang && *lang == "en") { + return true; + } + return false; + }); + if (iter == caption_tracks->end()) { + iter = base::ranges::find_if( + *caption_tracks, [](const base::Value& track_raw) { + const base::Value::Dict* language_track = track_raw.GetIfDict(); + auto* lang = language_track->FindString("languageCode"); + if (lang && *lang == "en") { + return true; + } + return false; + }); + } + if (iter == caption_tracks->end()) { + iter = caption_tracks->begin(); + } + track = iter->GetIfDict(); + } + if (!track) { + return std::nullopt; + } + const std::string* caption_url_raw = track->FindString("baseUrl"); + + if (!caption_url_raw) { + return std::nullopt; + } + return *caption_url_raw; +} + +std::optional ParseAndChooseCaptionTrackUrl( + std::string_view body) { + if (!body.size()) { + return std::nullopt; + } + + auto result_value = + base::JSONReader::ReadAndReturnValueWithError(body, base::JSON_PARSE_RFC); + + if (!result_value.has_value() || result_value->is_string()) { + DVLOG(1) << __func__ << ": parsing error: " << result_value.ToString(); + return std::nullopt; + } else if (!result_value->is_dict()) { + DVLOG(1) << __func__ << ": parsing error: not a dict"; + return std::nullopt; + } + + auto* caption_tracks = result_value->GetDict().FindListByDottedPath( + "captions.playerCaptionsTracklistRenderer.captionTracks"); + if (!caption_tracks) { + DVLOG(1) << __func__ << ": no caption tracks found"; + return std::nullopt; + } + + return ChooseCaptionTrackUrl(caption_tracks); +} + +} // namespace ai_chat diff --git a/components/ai_chat/renderer/yt_util.h b/components/ai_chat/renderer/yt_util.h new file mode 100644 index 
0000000000000..db0bdfc14ca12
--- /dev/null
+++ b/components/ai_chat/renderer/yt_util.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2024 The Brave Authors. All rights reserved.
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+
+#ifndef BRAVE_COMPONENTS_AI_CHAT_RENDERER_YT_UTIL_H_
+#define BRAVE_COMPONENTS_AI_CHAT_RENDERER_YT_UTIL_H_
+
+#include
+#include
+#include
+
+#include "base/containers/fixed_flat_set.h"
+#include "base/values.h"
+
+namespace ai_chat {
+
+// Host names treated as YouTube.
+// NOTE(review): the base::sorted_unique tag presumably requires the entries to
+// already be sorted and unique — keep them in order when adding hosts.
+inline constexpr auto kYouTubeHosts =
+    base::MakeFixedFlatSet(base::sorted_unique,
+                                             {
+                                                 "m.youtube.com",
+                                                 "www.youtube.com",
+                                             });
+
+// Extract a caption url from an array of YT caption tracks, from the YT page
+// API. |caption_tracks| may be null. Preference order: a non-auto-generated
+// ("asr") English track, then any English track, then the first track.
+// Returns std::nullopt when the list is null or empty, or when the chosen
+// track has no string "baseUrl". Entries are expected to be dicts.
+std::optional ChooseCaptionTrackUrl(
+    const base::Value::List* caption_tracks);
+
+// Parse YT metadata json string and choose the most appropriate caption track
+// url. The tracks are read from the dotted path
+// "captions.playerCaptionsTracklistRenderer.captionTracks". Returns
+// std::nullopt when |body| is empty, is not a JSON dict, or lacks that list.
+std::optional ParseAndChooseCaptionTrackUrl(std::string_view body);
+
+}  // namespace ai_chat
+
+#endif  // BRAVE_COMPONENTS_AI_CHAT_RENDERER_YT_UTIL_H_
diff --git a/components/ai_chat/renderer/yt_util_unittest.cc b/components/ai_chat/renderer/yt_util_unittest.cc
new file mode 100644
index 0000000000000..a62391b81fd1a
--- /dev/null
+++ b/components/ai_chat/renderer/yt_util_unittest.cc
@@ -0,0 +1,189 @@
+// Copyright (c) 2023 The Brave Authors. All rights reserved.
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this file,
+// You can obtain one at https://mozilla.org/MPL/2.0/.
+ +#include "brave/components/ai_chat/renderer/yt_util.h" + +#include +#include + +#include "base/json/json_reader.h" +#include "base/logging.h" +#include "base/values.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace ai_chat { + +TEST(YTCaptionTrackTest, ChoosesENCaptionTrackUrl) { + std::string body = R"([ + { + "kind": "captions", + "languageCode": "de", + "baseUrl": "http://example.com/caption_de.vtt" + }, + { + "kind": "captions", + "languageCode": "en", + "baseUrl": "http://example.com/caption_en.vtt" + }, + { + "kind": "captions", + "languageCode": "es", + "baseUrl": "http://example.com/caption_es.vtt" + } + ])"; + auto result_value = base::JSONReader::Read(body, base::JSON_PARSE_RFC); + ASSERT_TRUE(result_value.has_value()); + ASSERT_TRUE(result_value->is_list()); + + auto result = ChooseCaptionTrackUrl(result_value->GetIfList()); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), "http://example.com/caption_en.vtt"); +} + +TEST(YTCaptionTrackTest, PrefersNonASR) { + std::string body = R"([ + { + "kind": "captions", + "languageCode": "de", + "baseUrl": "http://example.com/caption_de.vtt" + }, + { + "kind": "asr", + "languageCode": "en", + "baseUrl": "http://example.com/caption_en_asr.vtt" + }, + { + "kind": "captions", + "languageCode": "en", + "baseUrl": "http://example.com/caption_en.vtt" + }, + { + "kind": "captions", + "languageCode": "es", + "baseUrl": "http://example.com/caption_es.vtt" + } + ])"; + auto result_value = base::JSONReader::Read(body, base::JSON_PARSE_RFC); + ASSERT_TRUE(result_value.has_value()); + ASSERT_TRUE(result_value->is_list()); + + auto result = ChooseCaptionTrackUrl(result_value->GetIfList()); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), "http://example.com/caption_en.vtt"); +} + +TEST(YTCaptionTrackTest, PrefersEnIfASR) { + std::string body = R"([ + { + "kind": "captions", + "languageCode": "de", + "baseUrl": "http://example.com/caption_de.vtt" + }, + { + "kind": "asr", + 
"languageCode": "en", + "baseUrl": "http://example.com/caption_en_asr.vtt" + }, + { + "kind": "captions", + "languageCode": "es", + "baseUrl": "http://example.com/caption_es.vtt" + } + ])"; + auto result_value = base::JSONReader::Read(body, base::JSON_PARSE_RFC); + ASSERT_TRUE(result_value.has_value()); + ASSERT_TRUE(result_value->is_list()); + + auto result = ChooseCaptionTrackUrl(result_value->GetIfList()); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), "http://example.com/caption_en_asr.vtt"); +} + +TEST(YTCaptionTrackTest, FallbackToFirst) { + std::string body = R"([ + { + "kind": "captions", + "languageCode": "de", + "baseUrl": "http://example.com/caption_de.vtt" + }, + { + "kind": "captions", + "languageCode": "ja", + "baseUrl": "http://example.com/caption_ja.vtt" + }, + { + "kind": "captions", + "languageCode": "es", + "baseUrl": "http://example.com/caption_es.vtt" + } + ])"; + auto result_value = base::JSONReader::Read(body, base::JSON_PARSE_RFC); + + ASSERT_TRUE(result_value.has_value()); + ASSERT_TRUE(result_value->is_list()); + + auto result = ChooseCaptionTrackUrl(result_value->GetIfList()); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), "http://example.com/caption_de.vtt"); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_NonJson) { + std::string body = "\x89PNG\x0D\x0A\x1A\x0A"; + auto result = ParseAndChooseCaptionTrackUrl(body); + EXPECT_FALSE(result.has_value()); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_EmptyJson) { + std::string body = "[]"; + auto result = ParseAndChooseCaptionTrackUrl(body); + EXPECT_FALSE(result.has_value()); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_InvalidJson) { + std::string body = "{"; + auto result = ParseAndChooseCaptionTrackUrl(body); + EXPECT_FALSE(result.has_value()); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_ValidNonYTJson) { + std::string body = R"({ + "captions": [] + })"; + auto result = ParseAndChooseCaptionTrackUrl(body); + 
EXPECT_FALSE(result.has_value()); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_ValidYTJson) { + std::string body = R"({ + "captions": { + "playerCaptionsTracklistRenderer": { + "captionTracks": [ + { + "baseUrl": "https://www.example.com/caption1" + } + ] + } + } + })"; + auto result = ParseAndChooseCaptionTrackUrl(body); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), "https://www.example.com/caption1"); +} + +TEST(YTCaptionTrackTest, ParseAndGetTrackUrl_ValidNoStructure) { + // Not the correct structure + std::string body = R"([ + { + "kind": "captions", + "languageCode": "de", + "baseUrl": "http://example.com/caption_de.vtt" + } + ])"; + auto result = ParseAndChooseCaptionTrackUrl(body); + EXPECT_FALSE(result.has_value()); +} + +} // namespace ai_chat diff --git a/renderer/DEPS b/renderer/DEPS index 0000df873e57d..8baa08c719140 100644 --- a/renderer/DEPS +++ b/renderer/DEPS @@ -5,6 +5,7 @@ include_rules = [ "+content/public/renderer", "+media/base", "+mojo/public/cpp/bindings", + "+services/network/public/cpp", "+third_party/blink/public/common", "+third_party/blink/public/platform", "+third_party/blink/public/public_buildflags.h", diff --git a/renderer/brave_url_loader_throttle_provider_impl.cc b/renderer/brave_url_loader_throttle_provider_impl.cc index 7e5bf927931cb..459ae50eee8c7 100644 --- a/renderer/brave_url_loader_throttle_provider_impl.cc +++ b/renderer/brave_url_loader_throttle_provider_impl.cc @@ -7,8 +7,18 @@ #include +#include "brave/components/ai_chat/core/common/buildflags/buildflags.h" +#include "brave/components/ai_chat/core/common/features.h" +#include "brave/components/ai_chat/renderer/page_content_extractor.h" #include "brave/components/tor/buildflags/buildflags.h" #include "brave/renderer/brave_content_renderer_client.h" +#include "content/public/renderer/render_frame.h" +#include "services/network/public/cpp/resource_request.h" +#include "third_party/blink/public/web/web_local_frame.h" + +#if BUILDFLAG(ENABLE_AI_CHAT) 
+#include "brave/components/ai_chat/renderer/ai_chat_resource_sniffer_throttle.h" +#endif // ENABLE_AI_CHAT #if BUILDFLAG(ENABLE_TOR) #include "brave/components/tor/renderer/onion_domain_throttle.h" @@ -78,5 +88,25 @@ BraveURLLoaderThrottleProviderImpl::CreateThrottles( throttles.emplace_back(std::move(onion_domain_throttle)); } #endif + // AI Chat +#if BUILDFLAG(ENABLE_AI_CHAT) + if (ai_chat::features::IsAIChatEnabled() && local_frame_token.has_value() && + content::RenderThread::IsMainThread()) { + content::RenderFrame* render_frame = content::RenderFrame::FromWebFrame( + blink::WebLocalFrame::FromFrameToken(local_frame_token.value())); + auto* page_content_delegate = + ai_chat::PageContentExtractor::Get(render_frame); + if (page_content_delegate) { + std::unique_ptr + ai_chat_resource_throttle = + ai_chat::AIChatResourceSnifferThrottle::MaybeCreateThrottleFor( + page_content_delegate->GetWeakPtr(), request.url, + base::SequencedTaskRunner::GetCurrentDefault()); + if (ai_chat_resource_throttle) { + throttles.emplace_back(std::move(ai_chat_resource_throttle)); + } + } + } +#endif // ENABLE_AI_CHAT return throttles; } diff --git a/test/BUILD.gn b/test/BUILD.gn index 06d31790c585c..3881bd7e0b988 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -357,6 +357,7 @@ test("brave_unit_tests") { "//brave/browser/ai_chat:unit_tests", "//brave/components/ai_chat/core/browser:unit_tests", "//brave/components/ai_chat/core/common:unit_tests", + "//brave/components/ai_chat/renderer:unit_tests", ] }