Skip to content

Commit

Permalink
Commit chatbot implementation with Gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
tezan committed Jun 18, 2024
1 parent 5bffebb commit 2a2884b
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 35 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ jobs:
- name: Setup Gradle
uses: gradle/gradle-build-action@v3

- name: Create dummy local.properties file
run: echo "apiKey=CHANGE_ME" > local.properties

- name: Check spotless
run: ./gradlew spotlessCheck

Expand Down
4 changes: 4 additions & 0 deletions app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ plugins {
alias(libs.plugins.hilt)
alias(libs.plugins.kotlinAndroid)
alias(libs.plugins.ksp)
alias(libs.plugins.secrets)
}

kotlin {
Expand Down Expand Up @@ -144,4 +145,7 @@ dependencies {

implementation(libs.coil)
implementation(libs.coil.compose)

implementation(libs.generativeai)
implementation(libs.datastore)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
package com.google.android.samples.socialite

import android.content.Intent
import android.os.Build
import android.os.Bundle
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.enableEdgeToEdge
import androidx.core.content.pm.ShortcutManagerCompat
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import androidx.glance.appwidget.updateAll
Expand All @@ -33,6 +35,10 @@ import kotlinx.coroutines.runBlocking
class MainActivity : ComponentActivity() {
override fun onCreate(savedInstanceState: Bundle?) {
installSplashScreen()
enableEdgeToEdge()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
window.isNavigationBarContrastEnforced = false
}
super.onCreate(savedInstanceState)
runBlocking { SociaLiteAppWidget().updateAll(this@MainActivity) }
setContent {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,5 @@ fun SupportSQLiteDatabase.populateInitialData() {
put("timestamp", now + chatIds[index])
},
)

// Add second message
insert(
table = "Message",
conflictAlgorithm = SQLiteDatabase.CONFLICT_NONE,
values = ContentValues().apply {
put("id", (index * 2).toLong() + 1L)
put("chatId", chatIds[index])
put("senderId", contact.id)
put("text", "I will reply in 5 seconds")
put("timestamp", now + chatIds[index])
},
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,34 @@

package com.google.android.samples.socialite.repository

import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.booleanPreferencesKey
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.preferencesDataStore
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.Content
import com.google.ai.client.generativeai.type.GenerateContentResponse
import com.google.ai.client.generativeai.type.content
import com.google.android.samples.socialite.BuildConfig
import com.google.android.samples.socialite.data.ChatDao
import com.google.android.samples.socialite.data.ContactDao
import com.google.android.samples.socialite.data.MessageDao
import com.google.android.samples.socialite.di.AppCoroutineScope
import com.google.android.samples.socialite.model.ChatDetail
import com.google.android.samples.socialite.model.Message
import com.google.android.samples.socialite.widget.model.WidgetModelRepository
import dagger.hilt.android.qualifiers.ApplicationContext
import javax.inject.Inject
import javax.inject.Singleton
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filterNotNull
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch

@Singleton
Expand All @@ -40,7 +55,15 @@ class ChatRepository @Inject internal constructor(
private val widgetModelRepository: WidgetModelRepository,
@AppCoroutineScope
private val coroutineScope: CoroutineScope,
@ApplicationContext private val appContext: Context,
) {
private val Context.dataStore: DataStore<Preferences> by preferencesDataStore(name = "settings")
private val enableChatbotKey = booleanPreferencesKey("enable_chatbot")
val isBotEnabled = appContext.dataStore.data.map {
preference ->
preference[enableChatbotKey] ?: false
}

private var currentChat: Long = 0L

init {
Expand All @@ -66,40 +89,123 @@ class ChatRepository @Inject internal constructor(
mediaMimeType: String?,
) {
val detail = chatDao.loadDetailById(chatId) ?: return
// Save the message to the database
saveMessageAndNotify(chatId, text, 0L, mediaUri, mediaMimeType, detail, PushReason.OutgoingMessage)

// Create a generative AI Model to interact with the Gemini API.
val generativeModel = GenerativeModel(
modelName = "gemini-1.5-pro-latest",
// Set your Gemini API in as `apiKey` in the local.properties file
apiKey = BuildConfig.apiKey,
// Set a system instruction to set the behavior of the model.
systemInstruction = content {
text("Please respond to this chat conversation like a friendly ${detail.firstContact.replyModel}.")
},
)

coroutineScope.launch {
if (isBotEnabled.firstOrNull() == true) {
// Get the previous messages and them generative model chat
val pastMessages = getMessageHistory(chatId)
val chat = generativeModel.startChat(
history = pastMessages,
)

// Send a message prompt to the model to generate a response
var generateContentResult = try {
chat.sendMessage(text)
} catch (e: Exception) {
e.printStackTrace()
null
}
val response = generateContentResult?.text ?: "GenAI failed :(".trim()

// Save the generated response to the database
saveMessageAndNotify(chatId, response, detail.firstContact.id, null, null, detail, PushReason.IncomingMessage)
} else {
// Simulate a response from the peer.
// The code here is just for demonstration purpose in this sample.
// Real apps will use their server backend and Firebase Cloud Messaging to deliver messages.

// The person is typing...
delay(5000L)
// Receive a reply.
val message = detail.firstContact.reply(text).apply { this.chatId = chatId }.build()
saveMessageAndNotify(message.chatId, message.text, detail.firstContact.id, message.mediaUri, message.mediaMimeType, detail, PushReason.IncomingMessage)
}

// Show notification if the chat is not on the foreground.
if (chatId != currentChat) {
notificationHelper.showNotification(
detail.firstContact,
messageDao.loadAll(chatId),
false,
)
}

widgetModelRepository.updateUnreadMessagesForContact(contactId = detail.firstContact.id, unread = true)
}
}

private suspend fun saveMessageAndNotify(
chatId: Long,
text: String,
senderId: Long,
mediaUri: String?,
mediaMimeType: String?,
detail: ChatDetail,
pushReason: PushReason,
) {
messageDao.insert(
Message(
id = 0L,
chatId = chatId,
// User
senderId = 0L,
senderId = senderId,
text = text,
mediaUri = mediaUri,
mediaMimeType = mediaMimeType,
timestamp = System.currentTimeMillis(),
),
)
notificationHelper.pushShortcut(detail.firstContact, PushReason.OutgoingMessage)
// Simulate a response from the peer.
// The code here is just for demonstration purpose in this sample.
// Real apps will use their server backend and Firebase Cloud Messaging to deliver messages.
coroutineScope.launch {
// The person is typing...
delay(5000L)
// Receive a reply.
messageDao.insert(
detail.firstContact.reply(text).apply { this.chatId = chatId }.build(),
)
notificationHelper.pushShortcut(detail.firstContact, PushReason.IncomingMessage)
// Show notification if the chat is not on the foreground.
if (chatId != currentChat) {
notificationHelper.showNotification(
detail.firstContact,
messageDao.loadAll(chatId),
false,
)
}

private suspend fun getMessageHistory(chatId: Long): List<Content> {
val pastMessages = findMessages(chatId).first().filter { message ->
message.text.isNotEmpty()
}.sortedBy { message ->
message.timestamp
}.fold(initial = mutableListOf<Message>()) { acc, message ->
if (acc.isEmpty()) {
acc.add(message)
} else {
if (acc.last().isIncoming == message.isIncoming) {
val lastMessage = acc.removeLast()
val combinedMessage = Message(
id = lastMessage.id,
chatId = chatId,
// User
senderId = lastMessage.senderId,
text = lastMessage.text + " " + message.text,
mediaUri = null,
mediaMimeType = null,
timestamp = System.currentTimeMillis(),
)
acc.add(combinedMessage)
} else {
acc.add(message)
}
}
widgetModelRepository.updateUnreadMessagesForContact(contactId = detail.firstContact.id, unread = true)
return@fold acc
}

val lastUserMessage = pastMessages.removeLast()

val pastContents = pastMessages.mapNotNull { message: Message ->
val role = if (message.isIncoming) "model" else "user"
return@mapNotNull content(role = role) { text(message.text) }
}
return pastContents
}

suspend fun clearMessages() {
Expand Down Expand Up @@ -143,4 +249,12 @@ class ChatRepository @Inject internal constructor(
val detail = chatDao.loadDetailById(chatId) ?: return false
return notificationHelper.canBubble(detail.firstContact)
}

fun toggleChatbotSetting() {
coroutineScope.launch {
appContext.dataStore.edit { preferences ->
preferences[enableChatbotKey] = (preferences[enableChatbotKey]?.not()) ?: false
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.collectAsState
import androidx.compose.ui.Modifier
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.viewmodel.compose.viewModel
import com.google.android.samples.socialite.R
import kotlinx.coroutines.flow.map

@Composable
fun Settings(
Expand All @@ -59,6 +60,29 @@ fun Settings(
Text(text = stringResource(R.string.clear_message_history))
}
}

val chatbotStatusResource = viewModel.isBotEnabledFlow.map {
if (it) {
R.string.ai_chatbot_setting_enabled
} else {
R.string.ai_chatbot_setting_disabled
}
}.collectAsState(initial = R.string.ai_chatbot_setting_enabled).value

Box(modifier = Modifier.padding(32.dp)) {
Button(
onClick = { viewModel.toggleChatbot() },
modifier = Modifier
.fillMaxWidth()
.heightIn(min = 56.dp),
colors = ButtonDefaults.buttonColors(
containerColor = MaterialTheme.colorScheme.primaryContainer,
contentColor = MaterialTheme.colorScheme.onPrimaryContainer,
),
) {
Text(text = "${stringResource(id = R.string.ai_chatbot_setting)}: ${stringResource(chatbotStatusResource)}")
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,12 @@ class SettingsViewModel @Inject constructor(
).show()
}
}

val isBotEnabledFlow = repository.isBotEnabled

fun toggleChatbot() {
viewModelScope.launch {
repository.toggleChatbotSetting()
}
}
}
3 changes: 3 additions & 0 deletions app/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,8 @@
<string name="ff_title">Fast forward</string>
<string name="rw_title">Rewind</string>
<string name="favorite_contact_widget_name">Favorite Contact</string>
<string name="ai_chatbot_setting">AI Chatbot</string>
<string name="ai_chatbot_setting_enabled">enabled</string>
<string name="ai_chatbot_setting_disabled">disabled</string>

</resources>
7 changes: 7 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ uiautomator = "2.3.0"
window = "1.2.0"
material3-adaptive-navigation-suite = "1.0.0-alpha05"
glance = "1.1.0-alpha01"
secrets = "2.0.1"
generativeai = "0.7.0"
datastore = "1.0.0"


[libraries]
accompanist-painter = { group = "com.google.accompanist", name = "accompanist-drawablepainter", version.ref = "accompanist" }
Expand Down Expand Up @@ -105,6 +109,8 @@ turbine = { group = "app.cash.turbine", name = "turbine", version.ref = "turbine
uiautomator = { group = "androidx.test.uiautomator", name = "uiautomator", version.ref = "uiautomator" }
window = { group = "androidx.window", name = "window", version.ref = "window" }
ktlint = "com.pinterest.ktlint:ktlint-cli:1.1.1" # Used in build.gradle.kts
generativeai = { group = "com.google.ai.client.generativeai", name = "generativeai", version.ref = "generativeai"}
datastore = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore"}

[plugins]
androidApplication = { id = "com.android.application", version.ref = "agp" }
Expand All @@ -114,3 +120,4 @@ hilt = { id = "com.google.dagger.hilt.android", version.ref = "hilt" }
kotlinAndroid = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" }
spotless = { id = "com.diffplug.spotless", version.ref = "spotless" }
secrets = { id = "com.google.android.libraries.mapsplatform.secrets-gradle-plugin", version.ref = "secrets" }

0 comments on commit 2a2884b

Please sign in to comment.