diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index e8809a11..d2b31569 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -1,5 +1,4 @@ - - - + - + + - + + @@ -105,4 +111,4 @@ android:value="subject_segmentation" /> - \ No newline at end of file + diff --git a/core/network/src/main/java/com/android/developers/androidify/RemoteConfigDataSource.kt b/core/network/src/main/java/com/android/developers/androidify/RemoteConfigDataSource.kt index 1ffae662..d2f18bd2 100644 --- a/core/network/src/main/java/com/android/developers/androidify/RemoteConfigDataSource.kt +++ b/core/network/src/main/java/com/android/developers/androidify/RemoteConfigDataSource.kt @@ -27,8 +27,11 @@ interface RemoteConfigDataSource { fun isBackgroundVibesFeatureEnabled(): Boolean fun promptTextVerify(): String fun promptImageValidation(): String + fun promptImageValidationNano(): String fun promptImageDescription(): String + fun promptImageDescriptionNano(): String fun useGeminiNano(): Boolean + fun enabledGeminiNanoModelVersions(): String fun generateBotPrompt(): String fun promptImageGenerationWithSkinTone(): String @@ -77,14 +80,26 @@ class RemoteConfigDataSourceImpl @Inject constructor() : RemoteConfigDataSource return remoteConfig.getString("prompt_image_validation") } + override fun promptImageValidationNano(): String { + return remoteConfig.getString("prompt_image_validation_nano") + } + override fun promptImageDescription(): String { return remoteConfig.getString("prompt_image_description") } + override fun promptImageDescriptionNano(): String { + return remoteConfig.getString("prompt_image_description_nano") + } + override fun useGeminiNano(): Boolean { return remoteConfig.getBoolean("use_gemini_nano") } + override fun enabledGeminiNanoModelVersions(): String { + return remoteConfig.getString("enabled_gemini_nano_model_versions") + } + override fun generateBotPrompt(): String { return remoteConfig.getString("generate_bot_prompt") } diff --git a/core/network/src/main/res/xml/remote_config_defaults.xml b/core/network/src/main/res/xml/remote_config_defaults.xml index eb8daa7e..986e7a3f 100644 --- a/core/network/src/main/res/xml/remote_config_defaults.xml +++ b/core/network/src/main/res/xml/remote_config_defaults.xml @@ -62,10 +62,36 @@ - it cannot contain hate speech or other offensive language -it cannot contain blood or gore or violence. + + prompt_image_validation_nano + [TASK] + You are a Validator. Analyze the attached image and determine its validity based on the + rules. + [RULES] + VALID if AND ONLY if: + 1. PRIMARY subject is a person showing their head and shoulder + 2. The image MUST NOT contain: Nudity, Explicit content, Illegal weapons, Violent + references, Drugs, Illicit substances, Hate speech, Offensive language, Blood, Gore, or + Violence. + [OUTPUT] + Return ONLY one string. + Check sequentially. Output the first failure code that applies: + 1. Is the PRIMARY subject NOT a person (e.g., animal, object, landscape)? -> + "not_a_person" + 2. Is the person present but missing face/head/shoulders or too blurry? -> + "not_enough_detail" + 3. Does the image violate any negative policy (Rule 2)? -> "policy_violation" + 4. If all rules are passed: -> null + + use_gemini_nano false + + enabled_gemini_nano_model_versions + nano-v3 + dancing_droid_gif_link https://services.google.com/fh/files/misc/android_dancing.gif @@ -148,6 +174,45 @@ * Do not say rendered, rendering, or digital. * Only respond with new image description as a paragraph. + + prompt_image_description_nano + ## Role + You are an expert image analyst specializing in generating detailed, objective descriptions of people. + + ## Task + Your task is to describe the person in the provided image in vivid detail, following the guidelines and examples below. + + ## Guidelines + - Start with the overall mood or impression of the person (e.g., serene, joyful, pensive). + - Describe the person's physical appearance, focusing on hair (color, style, length) and any visible facial features. + - Detail the clothing, including the type of garments, style, color, and material. + - Mention any accessories, such as glasses, hats, or jewelry. + - Describe the immediate surroundings, including any objects, animals, or other people interacting with the subject. + + ## Constraints + - The output must be a single, coherent paragraph. + - If no person is visible in the image, state that clearly and do not describe anything else. + - Provide only the description. Do not add any introductory or concluding remarks. + + ## Examples + + ### Example 1: Standard Case + Input: [Image of a person on a picnic blanket with a dog] + Output: A highly detailed and realistic portrayal of a person with a serene and pleasant mood. The figure has short, chin-length, straight dark black hair. No facial hair is present. Blue mirrored sunglasses are resting on top of its head. The figure is wearing a loose-fitting, light gray kimono-like top with a V-neckline and wide, elbow-length sleeves. This top features intricate, colorful embroidery in muted red, green, and yellow floral patterns on the front and sleeves. On its bottom, the figure wears loose-fitting, light gray wide-leg pants made of a soft, flowing material. No footwear is visible. The figure is seated on a red and white checkered picnic blanket. Next to it on the blanket is a clear plastic bottle. It is interacting with a black and white Pomeranian-like dog, which has black fur with distinct white markings on its chest, legs, and face, and a leash attached to its collar. The overall depiction aims for a clear and life-like appearance. + + ### Example 2: Corner Case (No Person) + Input: [Image of an empty park bench] + Output: No person is visible in the image. + + ## Input + {{image}} + + ## Output Reminder + Take a deep breath, read the instructions again, read the inputs again. Each instruction is crucial and must be executed with utmost care and attention to detail. + + Description: + + promo_video_link https://services.google.com/fh/files/misc/androidfy_storyboard_b_v07.mp4 diff --git a/core/testing/src/main/AndroidManifest.xml b/core/testing/src/main/AndroidManifest.xml index 22cc6771..3cb17dfb 100644 --- a/core/testing/src/main/AndroidManifest.xml +++ b/core/testing/src/main/AndroidManifest.xml @@ -17,5 +17,4 @@ - \ No newline at end of file diff --git a/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoDownloader.kt b/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoDownloader.kt new file mode 100644 index 00000000..60bedb0a --- /dev/null +++ b/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoDownloader.kt @@ -0,0 +1,9 @@ +package com.android.developers.testing.data + +import com.android.developers.androidify.data.GeminiNanoDownloader + +class TestGeminiNanoDownloader(val modelDownloaded: Boolean) : GeminiNanoDownloader { + override fun isModelDownloaded(): Boolean { + return modelDownloaded + } +} \ No newline at end of file diff --git a/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoGenerationDataSource.kt b/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoGenerationDataSource.kt index 1336a217..44a8c19b 100644 --- a/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoGenerationDataSource.kt +++ b/core/testing/src/main/java/com/android/developers/testing/data/TestGeminiNanoGenerationDataSource.kt @@ -15,13 +15,27 @@ */ package com.android.developers.testing.data +import android.graphics.Bitmap +import com.android.developers.androidify.data.GeminiNanoDownloader import com.android.developers.androidify.data.GeminiNanoGenerationDataSource +import com.android.developers.androidify.model.ValidatedDescription +import com.android.developers.androidify.model.ValidatedImage -class TestGeminiNanoGenerationDataSource(val promptOutput: String?) : GeminiNanoGenerationDataSource { - override suspend fun initialize() { - } +class TestGeminiNanoGenerationDataSource( + val promptOutput: String?, + val geminiNanoDownloader: GeminiNanoDownloader +) : GeminiNanoGenerationDataSource { override suspend fun generatePrompt(prompt: String): String? { return promptOutput } + + override suspend fun validateImageHasEnoughInformation(image: Bitmap): ValidatedImage? { + return ValidatedImage(true, null) + } + + override suspend fun generateDescriptivePromptFromImage(image: Bitmap): ValidatedDescription? { + if (!geminiNanoDownloader.isModelDownloaded()) return null + return ValidatedDescription(true, "Nano description") + } } diff --git a/core/testing/src/main/java/com/android/developers/testing/network/TestFirebaseAiDataSource.kt b/core/testing/src/main/java/com/android/developers/testing/network/TestFirebaseAiDataSource.kt index a12a4357..e6b2ac4f 100644 --- a/core/testing/src/main/java/com/android/developers/testing/network/TestFirebaseAiDataSource.kt +++ b/core/testing/src/main/java/com/android/developers/testing/network/TestFirebaseAiDataSource.kt @@ -24,7 +24,7 @@ import com.android.developers.androidify.vertexai.FirebaseAiDataSource class TestFirebaseAiDataSource(val promptOutput: List) : FirebaseAiDataSource { override suspend fun validatePromptHasEnoughInformation(inputPrompt: String): ValidatedDescription { - return ValidatedDescription(true, "User description") + return ValidatedDescription(true, "Firebase description") } override suspend fun validateImageHasEnoughInformation(image: Bitmap): ValidatedImage { @@ -32,7 +32,7 @@ class TestFirebaseAiDataSource(val promptOutput: List) : FirebaseAiDataS } override suspend fun generateDescriptivePromptFromImage(image: Bitmap): ValidatedDescription { - return ValidatedDescription(true, "User description") + return ValidatedDescription(true, "Firebase description") } override suspend fun generateImageFromPromptAndSkinTone( diff --git a/core/testing/src/main/java/com/android/developers/testing/network/TestLocalSegmentationDataSource.kt b/core/testing/src/main/java/com/android/developers/testing/network/TestLocalSegmentationDataSource.kt new file mode 100644 index 00000000..b33b4f78 --- /dev/null +++ b/core/testing/src/main/java/com/android/developers/testing/network/TestLocalSegmentationDataSource.kt @@ -0,0 +1,12 @@ +package com.android.developers.testing.network + +import android.graphics.Bitmap +import com.android.developers.androidify.ondevice.LocalSegmentationDataSource +import androidx.core.graphics.createBitmap + +class TestLocalSegmentationDataSource() : LocalSegmentationDataSource { + + override suspend fun removeBackground(bitmap: Bitmap): Bitmap { + return createBitmap(100, 100) + } +} \ No newline at end of file diff --git a/core/testing/src/main/java/com/android/developers/testing/network/TestRemoteConfigDataSource.kt b/core/testing/src/main/java/com/android/developers/testing/network/TestRemoteConfigDataSource.kt index 5cfea0b8..5d826d68 100644 --- a/core/testing/src/main/java/com/android/developers/testing/network/TestRemoteConfigDataSource.kt +++ b/core/testing/src/main/java/com/android/developers/testing/network/TestRemoteConfigDataSource.kt @@ -42,14 +42,26 @@ class TestRemoteConfigDataSource(private val useGeminiNano: Boolean) : RemoteCon TODO("Not yet implemented") } + override fun promptImageValidationNano(): String { + TODO("Not yet implemented") + } + override fun promptImageDescription(): String { TODO("Not yet implemented") } + override fun promptImageDescriptionNano(): String { + TODO("Not yet implemented") + } + override fun useGeminiNano(): Boolean { return useGeminiNano } + override fun enabledGeminiNanoModelVersions(): String { + TODO("Not yet implemented") + } + override fun generateBotPrompt(): String { return "generateBotPrompt" } diff --git a/core/testing/src/main/java/com/android/developers/testing/repository/FakeImageGenerationRepository.kt b/core/testing/src/main/java/com/android/developers/testing/repository/FakeImageGenerationRepository.kt index 47427a70..02c6e45e 100644 --- a/core/testing/src/main/java/com/android/developers/testing/repository/FakeImageGenerationRepository.kt +++ b/core/testing/src/main/java/com/android/developers/testing/repository/FakeImageGenerationRepository.kt @@ -20,11 +20,10 @@ import android.net.Uri import androidx.core.graphics.createBitmap import androidx.core.net.toUri import com.android.developers.androidify.data.ImageGenerationRepository +import com.android.developers.androidify.model.ValidatedDescription import java.io.File class FakeImageGenerationRepository : ImageGenerationRepository { - override suspend fun initialize() { - } var exceptionToThrow: Exception? = null override suspend fun generateFromDescription( @@ -35,6 +34,10 @@ class FakeImageGenerationRepository : ImageGenerationRepository { return createBitmap(1, 1) } + override suspend fun getDescriptionFromImage(file: File): ValidatedDescription { + return ValidatedDescription(true, "") + } + override suspend fun generateFromImage( file: File, skinTone: String, diff --git a/core/testing/src/main/java/com/android/developers/testing/repository/TestTextGenerationRepository.kt b/core/testing/src/main/java/com/android/developers/testing/repository/TestTextGenerationRepository.kt index 0fa76a1a..8b78553a 100644 --- a/core/testing/src/main/java/com/android/developers/testing/repository/TestTextGenerationRepository.kt +++ b/core/testing/src/main/java/com/android/developers/testing/repository/TestTextGenerationRepository.kt @@ -18,9 +18,6 @@ package com.android.developers.testing.repository import com.android.developers.androidify.data.TextGenerationRepository class TestTextGenerationRepository : TextGenerationRepository { - override suspend fun initialize() { - } - override suspend fun getNextGeneratedBotPrompt(): String? { return "Test prompt" } diff --git a/data/build.gradle.kts b/data/build.gradle.kts index 1677d959..aef0bce3 100644 --- a/data/build.gradle.kts +++ b/data/build.gradle.kts @@ -47,6 +47,7 @@ dependencies { implementation(projects.core.network) implementation(projects.core.util) + implementation(libs.androidx.app.startup) implementation(libs.kotlinx.serialization.json) implementation(libs.retrofit) implementation(libs.timber) @@ -56,9 +57,14 @@ dependencies { implementation(libs.androidx.hilt.navigation.compose) implementation(libs.okhttp) implementation(libs.retrofit.kotlin.serialization) - implementation(libs.ai.edge) { - exclude(group = "com.google.guava") - } + implementation(libs.genai.prompt) ksp(libs.hilt.compiler) + + testImplementation(libs.junit) + testImplementation(libs.kotlinx.coroutines.test) + testImplementation(libs.hilt.android.testing) + testImplementation(libs.robolectric) + testImplementation(projects.core.testing) + testImplementation(kotlin("test")) } diff --git a/data/src/main/AndroidManifest.xml b/data/src/main/AndroidManifest.xml index c8073096..7d6d8140 100644 --- a/data/src/main/AndroidManifest.xml +++ b/data/src/main/AndroidManifest.xml @@ -18,5 +18,4 @@ xmlns:tools="http://schemas.android.com/tools"> - \ No newline at end of file diff --git a/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloader.kt b/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloader.kt deleted file mode 100644 index 809a3653..00000000 --- a/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloader.kt +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2025 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.android.developers.androidify.data - -import android.app.Application -import com.google.ai.edge.aicore.DownloadCallback -import com.google.ai.edge.aicore.DownloadConfig -import com.google.ai.edge.aicore.GenerativeAIException -import com.google.ai.edge.aicore.GenerativeModel -import com.google.ai.edge.aicore.generationConfig -import timber.log.Timber -import javax.inject.Inject -import javax.inject.Singleton - -@Singleton -class GeminiNanoDownloader @Inject constructor(private val application: Application) { - var generativeModel: GenerativeModel? = null - private set - - private var modelDownloaded = false - - fun isModelDownloaded() = modelDownloaded - - suspend fun downloadModel() { - Timber.d("downloadModel") - try { - setup() - generativeModel?.prepareInferenceEngine() - } catch (e: Exception) { - Timber.e(e, "Error preparing inference engine") - } - Timber.d("prepare inference engine") - } - - private fun setup() { - val downloadCallback = object : DownloadCallback { - override fun onDownloadStarted(bytesToDownload: Long) { - super.onDownloadStarted(bytesToDownload) - Timber.i("onDownloadStarted for Gemini Nano $bytesToDownload") - } - - override fun onDownloadCompleted() { - super.onDownloadCompleted() - modelDownloaded = true - Timber.i("onDownloadCompleted for Gemini Nano") - } - - override fun onDownloadFailed(failureStatus: String, e: GenerativeAIException) { - super.onDownloadFailed(failureStatus, e) - // downloading the model has failed so make the model null as we can't use it - generativeModel = null - Timber.i("onDownloadFailed for Gemini Nano") - } - } - - val downloadConfig = DownloadConfig(downloadCallback) - - val generationConfig = generationConfig { - context = application - temperature = 0.2f - topK = 16 - maxOutputTokens = 256 - } - - generativeModel = GenerativeModel( - generationConfig = generationConfig, - downloadConfig = downloadConfig, - ) - } -} diff --git a/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloaderImpl.kt b/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloaderImpl.kt new file mode 100644 index 00000000..08607baa --- /dev/null +++ b/data/src/main/java/com/android/developers/androidify/data/GeminiNanoDownloaderImpl.kt @@ -0,0 +1,87 @@ +/* + * Copyright 2025 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.android.developers.androidify.data + +import com.android.developers.androidify.RemoteConfigDataSource +import com.google.mlkit.genai.common.DownloadStatus +import com.google.mlkit.genai.common.FeatureStatus +import com.google.mlkit.genai.prompt.Generation +import timber.log.Timber +import javax.inject.Inject +import javax.inject.Singleton + +interface GeminiNanoDownloader { + fun isModelDownloaded(): Boolean +} + +@Singleton +class GeminiNanoDownloaderImpl @Inject constructor( + private val remoteConfigDataSource: RemoteConfigDataSource, +) : GeminiNanoDownloader { + var generativeModel = Generation.getClient() + + private var modelDownloaded = false + + override fun isModelDownloaded() = modelDownloaded + + suspend fun downloadModel() { + Timber.d("downloadModel") + try { + setup() + generativeModel.warmup() + } catch (e: Exception) { + Timber.e(e, "Error preparing inference engine") + } + Timber.d("prepare inference engine") + } + + private suspend fun setup() { + val nanoStatus = generativeModel.checkStatus() + + if (nanoStatus == FeatureStatus.UNAVAILABLE) { + Timber.d("Nano not available on device") + return + } + + if (nanoStatus == FeatureStatus.DOWNLOADABLE && + remoteConfigDataSource.enabledGeminiNanoModelVersions() + .contains(generativeModel.getBaseModelName()) + ) { + generativeModel.download().collect { status -> + when (status) { + is DownloadStatus.DownloadStarted -> + Timber.d("starting download for Gemini Nano") + + is DownloadStatus.DownloadProgress -> + Timber.d("Nano ${status.totalBytesDownloaded} bytes downloaded") + + DownloadStatus.DownloadCompleted -> { + Timber.d("Gemini Nano download complete") + modelDownloaded = true + } + + is DownloadStatus.DownloadFailed -> { + Timber.e("Nano download failed ${status.e.message}") + } + } + } + } + + if (nanoStatus == FeatureStatus.AVAILABLE) { + modelDownloaded = true + } + } +} diff --git a/data/src/main/java/com/android/developers/androidify/data/GeminiNanoGenerationDataSource.kt b/data/src/main/java/com/android/developers/androidify/data/GeminiNanoGenerationDataSource.kt index cc5e880c..06702afc 100644 --- a/data/src/main/java/com/android/developers/androidify/data/GeminiNanoGenerationDataSource.kt +++ b/data/src/main/java/com/android/developers/androidify/data/GeminiNanoGenerationDataSource.kt @@ -15,36 +15,88 @@ */ package com.android.developers.androidify.data +import android.graphics.Bitmap +import androidx.annotation.VisibleForTesting import com.android.developers.androidify.RemoteConfigDataSource +import com.android.developers.androidify.model.ImageValidationError +import com.android.developers.androidify.model.ValidatedDescription +import com.android.developers.androidify.model.ValidatedImage +import com.google.mlkit.genai.prompt.ImagePart +import com.google.mlkit.genai.prompt.TextPart +import com.google.mlkit.genai.prompt.generateContentRequest import timber.log.Timber import javax.inject.Inject import javax.inject.Singleton interface GeminiNanoGenerationDataSource { - suspend fun initialize() suspend fun generatePrompt(prompt: String): String? + suspend fun validateImageHasEnoughInformation(image: Bitmap): ValidatedImage? + suspend fun generateDescriptivePromptFromImage(image: Bitmap): ValidatedDescription? } @Singleton internal class GeminiNanoGenerationDataSourceImpl @Inject constructor( private val remoteConfigDataSource: RemoteConfigDataSource, - private val downloader: GeminiNanoDownloader) : + private val downloader: GeminiNanoDownloaderImpl, +) : GeminiNanoGenerationDataSource { - override suspend fun initialize() { - if (remoteConfigDataSource.useGeminiNano()) { - downloader.downloadModel() - } - } - /** * Generate a prompt to create an Android bot using Gemini Nano. * If Gemini Nano is not available, return null. */ override suspend fun generatePrompt(prompt: String): String? { if (!downloader.isModelDownloaded()) return null - val response = downloader.generativeModel?.generateContent(prompt) - Timber.d("generatePrompt: ${response?.text}") - return response?.text + val response = downloader.generativeModel.generateContent( + generateContentRequest(TextPart(prompt)) + { + temperature = 0.2f + topK = 16 + candidateCount = 1 + maxOutputTokens = 256 + }, + ) + Timber.d("generatePrompt: ${response.candidates[0].text}") + return response.candidates[0].text + } + + override suspend fun validateImageHasEnoughInformation(image: Bitmap): ValidatedImage? { + if (!downloader.isModelDownloaded()) return null + + val response = downloader.generativeModel.generateContent( + generateContentRequest( + ImagePart(image), + TextPart(remoteConfigDataSource.promptImageValidationNano()), + ) { + temperature = 0.0f + maxOutputTokens = 20 + }, + ).candidates[0].text + + // If the model returns null as the validation error, there was no error found when + // validating the image. + val successValue = response == "null" + return ValidatedImage( + successValue, + ImageValidationError.entries.find { it.description == response }, + ) + } + + override suspend fun generateDescriptivePromptFromImage(image: Bitmap): ValidatedDescription? { + if (!downloader.isModelDownloaded()) return null + + val generatedImageDescription = downloader.generativeModel.generateContent( + generateContentRequest( + ImagePart(image), + TextPart(remoteConfigDataSource.promptImageDescriptionNano()), + ) { + temperature = 0.2f + }, + ) + + return ValidatedDescription( + true, + generatedImageDescription.candidates[0].text, + ) } } diff --git a/data/src/main/java/com/android/developers/androidify/data/ImageGenerationRepository.kt b/data/src/main/java/com/android/developers/androidify/data/ImageGenerationRepository.kt index 7a40a0bc..8365d93f 100644 --- a/data/src/main/java/com/android/developers/androidify/data/ImageGenerationRepository.kt +++ b/data/src/main/java/com/android/developers/androidify/data/ImageGenerationRepository.kt @@ -31,8 +31,8 @@ import javax.inject.Inject import javax.inject.Singleton interface ImageGenerationRepository { - suspend fun initialize() suspend fun generateFromDescription(description: String, skinTone: String): Bitmap + suspend fun getDescriptionFromImage(file: File): ValidatedDescription suspend fun generateFromImage(file: File, skinTone: String): Bitmap suspend fun saveImage(imageBitmap: Bitmap): Uri suspend fun saveImageToExternalStorage(imageBitmap: Bitmap): Uri @@ -52,20 +52,21 @@ internal class ImageGenerationRepositoryImpl @Inject constructor( private val localSegmentationDataSource: LocalSegmentationDataSource, ) : ImageGenerationRepository { - override suspend fun initialize() { - Timber.d("Initializing") - geminiNanoDataSource.initialize() - } - private suspend fun validatePromptHasEnoughInformation(inputPrompt: String): ValidatedDescription = firebaseAiDataSource.validatePromptHasEnoughInformation(inputPrompt) - private suspend fun validateImageIsFullPerson(file: File): ValidatedImage = - firebaseAiDataSource.validateImageHasEnoughInformation( - BitmapFactory.decodeFile( - file.absolutePath, - ), - ) + private suspend fun validateImageIsFullPerson(file: File): ValidatedImage { + val bitmap = BitmapFactory.decodeFile(file.absolutePath) + val validateImageResult = if (remoteConfigDataSource.useGeminiNano()) { + geminiNanoDataSource.validateImageHasEnoughInformation(bitmap) + } else { + null + } + + // If validating image with Nano is not successful, fallback to using Firebase AI + return validateImageResult + ?: firebaseAiDataSource.validateImageHasEnoughInformation(bitmap) + } @Throws(InsufficientInformationException::class) override suspend fun generateFromDescription( @@ -84,19 +85,40 @@ internal class ImageGenerationRepositoryImpl @Inject constructor( ) } - override suspend fun generateFromImage( - file: File, - skinTone: String, - ): Bitmap { + override suspend fun getDescriptionFromImage(file: File): ValidatedDescription { checkInternetConnection() val validatedImage = validateImageIsFullPerson(file) if (!validatedImage.success) { throw ImageValidationException(validatedImage.errorMessage?.toImageValidationError()) } - val imageDescription = firebaseAiDataSource.generateDescriptivePromptFromImage( - BitmapFactory.decodeFile(file.absolutePath), - ) + var imageDescription = if (remoteConfigDataSource.useGeminiNano()) { + geminiNanoDataSource.generateDescriptivePromptFromImage( + BitmapFactory.decodeFile(file.absolutePath), + ) + } else { + null + } + + Timber.d("nano generated image desc ${imageDescription?.userDescription}") + + // If we're not getting a valid result from Nano, try with Firebase AI Logic + if (imageDescription?.success != true) { + Timber.d("generating image description with Firebase AI Logic") + imageDescription = firebaseAiDataSource.generateDescriptivePromptFromImage( + BitmapFactory.decodeFile(file.absolutePath), + ) + } + + return imageDescription + } + + override suspend fun generateFromImage( + file: File, + skinTone: String, + ): Bitmap { + val imageDescription = getDescriptionFromImage(file) + if (!imageDescription.success || imageDescription.userDescription == null) { throw ImageDescriptionFailedGenerationException() } @@ -113,7 +135,8 @@ internal class ImageGenerationRepositoryImpl @Inject constructor( } override suspend fun saveImageToExternalStorage(imageBitmap: Bitmap): Uri { - val cacheFile = localFileProvider.createCacheFile("androidify_image_result_${UUID.randomUUID()}.png") + val cacheFile = + localFileProvider.createCacheFile("androidify_image_result_${UUID.randomUUID()}.png") localFileProvider.saveBitmapToFile(imageBitmap, cacheFile) return localFileProvider.saveToSharedStorage(cacheFile, cacheFile.name, "image/png") } @@ -134,7 +157,7 @@ internal class ImageGenerationRepositoryImpl @Inject constructor( override suspend fun addBackgroundToBot(image: Bitmap, backgroundPrompt: String): Bitmap { val backgroundBotInstructions = remoteConfigDataSource.getBotBackgroundInstructionPrompt() + - "\"" + backgroundPrompt + "\"" + "\"" + backgroundPrompt + "\"" return firebaseAiDataSource.generateImageWithEdit(image, backgroundBotInstructions) } diff --git a/data/src/main/java/com/android/developers/androidify/data/TextGenerationRepository.kt b/data/src/main/java/com/android/developers/androidify/data/TextGenerationRepository.kt index dfca3db0..3f2eb0a4 100644 --- a/data/src/main/java/com/android/developers/androidify/data/TextGenerationRepository.kt +++ b/data/src/main/java/com/android/developers/androidify/data/TextGenerationRepository.kt @@ -21,7 +21,6 @@ import javax.inject.Inject import javax.inject.Singleton interface TextGenerationRepository { - suspend fun initialize() suspend fun getNextGeneratedBotPrompt(): String? } @@ -35,10 +34,6 @@ class TextGenerationRepositoryImpl @Inject constructor( private var currentPrompts: List? = null private var currentPromptIndex = 0 - override suspend fun initialize() { - geminiNanoDataSource.initialize() - } - override suspend fun getNextGeneratedBotPrompt(): String? { val prompts = currentPrompts if (prompts.isNullOrEmpty() || currentPromptIndex >= prompts.size) { diff --git a/data/src/main/java/com/android/developers/androidify/startup/GeminiNanoDownloaderInitializer.kt b/data/src/main/java/com/android/developers/androidify/startup/GeminiNanoDownloaderInitializer.kt new file mode 100644 index 00000000..a686f722 --- /dev/null +++ b/data/src/main/java/com/android/developers/androidify/startup/GeminiNanoDownloaderInitializer.kt @@ -0,0 +1,58 @@ +/* + * Copyright 2025 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.android.developers.androidify.startup + +import android.annotation.SuppressLint +import android.content.Context +import androidx.startup.Initializer +import com.android.developers.androidify.RemoteConfigDataSource +import com.android.developers.androidify.data.GeminiNanoDownloaderImpl +import dagger.hilt.EntryPoint +import dagger.hilt.InstallIn +import dagger.hilt.android.EntryPointAccessors +import dagger.hilt.components.SingletonComponent +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.launch + +@SuppressLint("EnsureInitializerMetadata") // Registered in :app module +class GeminiNanoDownloaderInitializer : Initializer { + + @EntryPoint + @InstallIn(SingletonComponent::class) + interface GeminiNanoDownloaderInitializerEntryPoint { + fun geminiNanoDownloader(): GeminiNanoDownloaderImpl + fun remoteConfigDataSource(): RemoteConfigDataSource + } + + override fun create(context: Context) { + val hiltEntryPoint = EntryPointAccessors.fromApplication( + context, + GeminiNanoDownloaderInitializerEntryPoint::class.java + ) + val remoteConfigDataSource = hiltEntryPoint.remoteConfigDataSource() + if (remoteConfigDataSource.useGeminiNano()) { + val geminiNanoDownloader = hiltEntryPoint.geminiNanoDownloader() + CoroutineScope(Dispatchers.IO).launch { + geminiNanoDownloader.downloadModel() + } + } + } + + override fun dependencies(): List>> { + return listOf(FirebaseRemoteConfigInitializer::class.java) + } +} diff --git a/data/src/test/java/com/android/developers/androidify/data/ImageGenerationRepositoryTest.kt b/data/src/test/java/com/android/developers/androidify/data/ImageGenerationRepositoryTest.kt new file mode 100644 index 00000000..e3a8cba1 --- /dev/null +++ b/data/src/test/java/com/android/developers/androidify/data/ImageGenerationRepositoryTest.kt @@ -0,0 +1,65 @@ +package com.android.developers.androidify.data + +import com.android.developers.testing.data.TestFileProvider +import com.android.developers.testing.data.TestGeminiNanoDownloader +import com.android.developers.testing.data.TestGeminiNanoGenerationDataSource +import com.android.developers.testing.data.TestInternetConnectivityManager +import com.android.developers.testing.network.TestFirebaseAiDataSource +import com.android.developers.testing.network.TestLocalSegmentationDataSource +import com.android.developers.testing.network.TestRemoteConfigDataSource +import kotlinx.coroutines.test.runTest +import org.junit.Test +import org.junit.runner.RunWith +import org.robolectric.RobolectricTestRunner +import java.io.File +import kotlin.test.assertEquals + +@RunWith(RobolectricTestRunner::class) +class ImageGenerationRepositoryTest { + private lateinit var repository: ImageGenerationRepository + + @Test + fun getDescriptionFromImage_NanoDisabled() = runTest { + repository = ImageGenerationRepositoryImpl( + TestFileProvider(), + TestInternetConnectivityManager(true), + TestGeminiNanoGenerationDataSource("", TestGeminiNanoDownloader(false)), + TestFirebaseAiDataSource(listOf("")), + TestRemoteConfigDataSource(false), + TestLocalSegmentationDataSource() + ) + + val description = repository.getDescriptionFromImage(File("")) + assertEquals(description.userDescription, "Firebase description") + } + + @Test + fun getDescriptionFromImage_NanoEnabledAndDownloaded() = runTest { + repository = ImageGenerationRepositoryImpl( + TestFileProvider(), + TestInternetConnectivityManager(true), + TestGeminiNanoGenerationDataSource("", TestGeminiNanoDownloader(true)), + TestFirebaseAiDataSource(listOf("")), + TestRemoteConfigDataSource(true), + TestLocalSegmentationDataSource() + ) + + val description = repository.getDescriptionFromImage(File("")) + assertEquals(description.userDescription, "Nano description") + } + + @Test + fun getDescriptionFromImage_NanoEnabledButNotDownloaded() = runTest { + repository = ImageGenerationRepositoryImpl( + TestFileProvider(), + TestInternetConnectivityManager(true), + TestGeminiNanoGenerationDataSource("", TestGeminiNanoDownloader(false)), + TestFirebaseAiDataSource(listOf("")), + TestRemoteConfigDataSource(true), + TestLocalSegmentationDataSource() + ) + + val description = repository.getDescriptionFromImage(File("")) + assertEquals(description.userDescription, "Firebase description") + } +} diff --git a/feature/creation/src/main/AndroidManifest.xml b/feature/creation/src/main/AndroidManifest.xml index 22cc6771..3cb17dfb 100644 --- a/feature/creation/src/main/AndroidManifest.xml +++ b/feature/creation/src/main/AndroidManifest.xml @@ -17,5 +17,4 @@ - \ No newline at end of file diff --git a/feature/creation/src/main/java/com/android/developers/androidify/creation/CreationViewModel.kt b/feature/creation/src/main/java/com/android/developers/androidify/creation/CreationViewModel.kt index 537ad96a..5f509934 100644 --- a/feature/creation/src/main/java/com/android/developers/androidify/creation/CreationViewModel.kt +++ b/feature/creation/src/main/java/com/android/developers/androidify/creation/CreationViewModel.kt @@ -78,10 +78,6 @@ class CreationViewModel @AssistedInject constructor( init { onImageSelected(originalImageUrl) - viewModelScope.launch { - imageGenerationRepository.initialize() - textGenerationRepository.initialize() - } } fun onImageSelected(uri: Uri?) { diff --git a/feature/home/build.gradle.kts b/feature/home/build.gradle.kts index 68410264..bdfaa1fe 100644 --- a/feature/home/build.gradle.kts +++ b/feature/home/build.gradle.kts @@ -70,10 +70,6 @@ dependencies { } ksp(libs.hilt.compiler) - implementation(libs.ai.edge) { - exclude(group = "com.google.guava") - } - implementation(libs.androidx.xr.compose) implementation(projects.core.xr) diff --git a/feature/home/src/main/AndroidManifest.xml b/feature/home/src/main/AndroidManifest.xml index 22cc6771..3cb17dfb 100644 --- a/feature/home/src/main/AndroidManifest.xml +++ b/feature/home/src/main/AndroidManifest.xml @@ -17,5 +17,4 @@ - \ No newline at end of file diff --git a/feature/results/src/test/kotlin/com/android/developers/androidify/data/TextGenerationRepositoryImplTest.kt b/feature/results/src/test/kotlin/com/android/developers/androidify/data/TextGenerationRepositoryImplTest.kt index 6921d43c..d7f4e4c2 100644 --- a/feature/results/src/test/kotlin/com/android/developers/androidify/data/TextGenerationRepositoryImplTest.kt +++ b/feature/results/src/test/kotlin/com/android/developers/androidify/data/TextGenerationRepositoryImplTest.kt @@ -15,6 +15,7 @@ */ package com.android.developers.androidify.data +import com.android.developers.testing.data.TestGeminiNanoDownloader import com.android.developers.testing.data.TestGeminiNanoGenerationDataSource import com.android.developers.testing.network.TestFirebaseAiDataSource import com.android.developers.testing.network.TestRemoteConfigDataSource @@ -29,7 +30,8 @@ class TextGenerationRepositoryImplTest { fun `Initial prompt generation`() = runTest { val output = "prompt" val remoteConfigDataSource = TestRemoteConfigDataSource(true) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(output) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(output, + TestGeminiNanoDownloader(false)) val firebaseAiDataSource = TestFirebaseAiDataSource(emptyList()) val repository = TextGenerationRepositoryImpl( @@ -48,7 +50,8 @@ class TextGenerationRepositoryImplTest { val output = "prompt" val prompts = listOf("prompt1", "prompt2") val remoteConfigDataSource = TestRemoteConfigDataSource(false) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(output) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(output, + TestGeminiNanoDownloader(false)) val firebaseAiDataSource = TestFirebaseAiDataSource(prompts) val repository = TextGenerationRepositoryImpl( @@ -69,7 +72,8 @@ class TextGenerationRepositoryImplTest { // list is returned. val prompts = listOf("prompt1", "prompt2") val remoteConfigDataSource = TestRemoteConfigDataSource(true) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null, + TestGeminiNanoDownloader(false)) val firebaseAiDataSource = TestFirebaseAiDataSource(prompts) val repository = TextGenerationRepositoryImpl( @@ -92,7 +96,8 @@ class TextGenerationRepositoryImplTest { // prompt in the list val prompts = listOf("prompt1", "prompt2", "prompt3") val remoteConfigDataSource = TestRemoteConfigDataSource(true) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null, + TestGeminiNanoDownloader(false)) val firebaseAiDataSource = TestFirebaseAiDataSource(prompts) val repository = TextGenerationRepositoryImpl( @@ -115,7 +120,8 @@ class TextGenerationRepositoryImplTest { // result, the function falls back to `firebaseAiDataSource.generatePrompt()`. val prompts = listOf("prompt1", "prompt2", "prompt3") val remoteConfigDataSource = TestRemoteConfigDataSource(true) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null, + TestGeminiNanoDownloader(true)) val firebaseAiDataSource = TestFirebaseAiDataSource(prompts) val repository = TextGenerationRepositoryImpl( @@ -134,7 +140,8 @@ class TextGenerationRepositoryImplTest { // `firebaseAiDataSource` return empty or null results, // `generatePrompts()` returns null. val remoteConfigDataSource = TestRemoteConfigDataSource(true) - val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null) + val geminiNanoDataSource = TestGeminiNanoGenerationDataSource(null, + TestGeminiNanoDownloader(false)) val firebaseAiDataSource = TestFirebaseAiDataSource(emptyList()) val repository = TextGenerationRepositoryImpl( diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 3bb97b38..d917c88d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -35,6 +35,7 @@ crashlytics = "3.0.6" datastore = "1.1.7" espressoCore = "3.7.0" firebaseBom = "34.4.0" +genaiPrompt = "1.0.0-alpha1" googleServices = "4.4.4" googleOss = "17.3.0" googleOssPlugin = "0.10.9" @@ -78,7 +79,6 @@ wearCompose = "1.5.0" wearComposeTooling = "1.4.1" wearRemoteInteractions = "1.1.0" window = "1.5.0" -aiEdge = "0.0.1-exp02" lifecycleProcess = "2.9.4" mlkitCommon = "18.11.0" mlkitSegmentation = "16.0.0-beta1" @@ -148,6 +148,7 @@ firebase-bom = { module = "com.google.firebase:firebase-bom", version.ref = "fir firebase-config = { module = "com.google.firebase:firebase-config" } firebase-crashlytics = { module = "com.google.firebase:firebase-crashlytics" } firebase-ai = { module = "com.google.firebase:firebase-ai" } +genai-prompt = { module = "com.google.mlkit:genai-prompt", version.ref = "genaiPrompt" } guava = { module = "com.google.guava:guava", version.ref = "guava" } hilt-android = { group = "com.google.dagger", name = "hilt-android", version.ref = "hiltAndroid" } hilt-android-testing = { group = "com.google.dagger", name = "hilt-android-testing", version.ref = "hiltAndroid" } @@ -173,7 +174,6 @@ ui-tooling = { group = "androidx.compose.ui", name = "ui-tooling", version.ref = androidx-uiautomator = { group = "androidx.test.uiautomator", name = "uiautomator", version.ref = "uiautomator" } androidx-benchmark-macro-junit4 = { group = "androidx.benchmark", name = "benchmark-macro-junit4", version.ref = "benchmarkMacroJunit4" } androidx-profileinstaller = { group = "androidx.profileinstaller", name = "profileinstaller", version.ref = "profileinstaller" } -ai-edge = { group = "com.google.ai.edge.aicore", name = "aicore", version.ref = "aiEdge" } google-oss-licenses = { group = "com.google.android.gms", name = "play-services-oss-licenses", version.ref = "googleOss" } google-oss-licenses-plugin = { group = "com.google.android.gms", name = "oss-licenses-plugin", version.ref = "googleOssPlugin" } androidx-lifecycle-process = { group = "androidx.lifecycle", name = "lifecycle-process", version.ref = "lifecycleProcess" }