vector database rag context logic + new prompts (#134)
josancamon19 committed May 8, 2024
2 parents 05174fc + 8f46859 commit b5bb804
Showing 16 changed files with 409 additions and 36 deletions.
22 changes: 10 additions & 12 deletions apps/AppWithWearable/Tasks.md
@@ -12,28 +12,25 @@
structuring but as context, not as part of the structure, so that if there's some reference to a
person, and then you use a pronoun, the LLM understands what you are referring to.
- [ ] Migrate MemoryRecord from SharedPreferences to sqlite
- [X] Implement [similarity search](https://www.pinecone.io/learn/vector-similarity/) locally
- [X] Use the AppStandalone `_ragContext` function as a baseline for creating the query
embedding.
- [X] When a memory is created, compute the vector embedding and store it locally.
- [X] When the user sends a question in the chat, use the AppStandalone `function_calling`
  that determines if the message requires context; if that's the case, retrieve the top 10 most
  similar vectors. ~~ For an initial version we can read all memories from sqlite or
  SharedPreferences and compute the similarity between the query and each vector (a brute-force
  sketch follows after this diff).
- [X] -----
- [ ] Another option is to use one of the vector db libraries available for
  dart: https://github.com/FastCodeAI/DVDB or https://pub.dev/packages/chromadb
- [X] Use that as context, and ask the LLM. Retrieve the prompt from the AppStandalone.
- [ ] Improve the function-call parsing of the text sent to the RAG; GPT should format the input
  better so RAG retrieves better context.
- [ ] Settings: Deepgram + OpenAI keys are forced to be set
- [ ] In case an API key fails, either the Deepgram WebSocket connection or a GPT request, let
  the user know the error message: no more credits, invalid API key, etc.
- [ ] Improve the connected device page UI, including transcription text, and when a memory is
  created after 30 seconds, let the user know.
- [ ] Structure the memory by asking for JSON output `{"title", "summary"}`; that way we get
  better parsed data.
- [x] Test/Implement [speaker diarization](https://developers.deepgram.com/docs/diarization) to
recognize multiple speakers in transcription, use that for better context when creating the
structured memory.
@@ -44,6 +41,7 @@
conversation, also, remove Speaker $i in transcript.
- [ ] Allow users who don't have a GCP bucket to store their recordings locally.
- [ ] Improve recordings player.

---

- [x] Multilanguage option, implement settings selector, and use that for the deepgram websocket
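The "initial version" described in the checklist above amounts to brute-force retrieval: read every stored memory vector and score it against the query embedding with cosine similarity. A minimal sketch of that idea follows; `StoredMemoryVector` and `topKSimilar` are illustrative names only, not part of this commit, and the real storage types (MemoryRecord, vector_db.dart) are not shown in this excerpt.

```dart
import 'dart:math';

// Hypothetical holder for a memory id plus its embedding; stands in for
// whatever sqlite/SharedPreferences row the checklist refers to.
class StoredMemoryVector {
  StoredMemoryVector(this.memoryId, this.embedding);
  final String memoryId;
  final List<double> embedding;
}

double _cosineSimilarity(List<double> a, List<double> b) {
  double dot = 0, normA = 0, normB = 0;
  for (var i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  if (normA == 0 || normB == 0) return 0;
  return dot / (sqrt(normA) * sqrt(normB));
}

/// Brute-force top-k retrieval: score every stored vector against the query.
List<StoredMemoryVector> topKSimilar(
    List<double> query, List<StoredMemoryVector> all, {int k = 10}) {
  final scored = all
      .map((m) => MapEntry(m, _cosineSimilarity(query, m.embedding)))
      .toList()
    ..sort((a, b) => b.value.compareTo(a.value));
  return scored.take(k).map((e) => e.key).toList();
}
```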
6 changes: 5 additions & 1 deletion apps/AppWithWearable/lib/actions/actions.dart
@@ -1,4 +1,5 @@
import 'package:flutter/material.dart';
import 'package:friend_private/backend/storage/vector_db.dart';
import 'package:friend_private/backend/storage/memories.dart';
import 'package:uuid/uuid.dart';
import '/backend/api_requests/api_calls.dart';
@@ -43,8 +44,11 @@ void changeAppStateMemoryCreating() {

// Finalize memory record after processing feedback
Future<void> finalizeMemoryRecord(String rawMemory, String structuredMemory, String? audioFilePath) async {
MemoryRecord createdMemory = await createMemoryRecord(rawMemory, structuredMemory, audioFilePath);
changeAppStateMemoryCreating();
// Compute the embedding for the structured memory and store it in the local vector DB.
List<double> vector = await getEmbeddingsFromInput(structuredMemory);
storeMemoryVector(createdMemory, vector);
}

// Create memory record
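`vector_db.dart` is imported in the file above, but its contents are not part of this excerpt. A plausible sketch of what `storeMemoryVector` might look like on top of the DVDB collection introduced later in this commit; the MemoryRecord field names (`id`, `structuredMemory`) and the `'memories'` collection name are assumptions.

```dart
import 'dart:typed_data';

import 'package:friend_private/backend/storage/dvdb/dvdb_helper.dart';
import 'package:friend_private/backend/storage/memories.dart';

// Sketch only: persist a memory's embedding in the local DVDB collection so it
// can later be retrieved by similarity search. Field names on MemoryRecord are
// assumptions; the real vector_db.dart is not shown in this diff.
void storeMemoryVector(MemoryRecord memory, List<double> vector) {
  final collection = DVDB().collection('memories');
  collection.addDocument(
    memory.id,
    memory.structuredMemory,
    Float64List.fromList(vector),
    metadata: {'memoryId': memory.id},
  );
}
```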
72 changes: 72 additions & 0 deletions apps/AppWithWearable/lib/backend/api_requests/api_calls.dart
@@ -164,3 +164,75 @@ String qaStreamedFullMemories(List<MemoryRecord> memories, List<dynamic> chatHistory
});
return body;
}

// ------

Future<String?> determineRequiresContext(String lastMessage, List<dynamic> chatHistory) async {
var tools = [
{
"type": "function",
"function": {
"name": "retrieve_rag_context",
"description": "Retrieve pieces of user memories as context.",
"parameters": {
"type": "object",
"properties": {
"question": {
"type": "string",
"description": '''
Based on the current conversation, determine if the message is a question and if there's
context that needs to be retrieved from the user recorded audio memories in order to answer that question.
If that's the case, return the question better parsed so that retrieved pieces of context are better.
''',
},
},
},
},
}
];
String message = '''
Conversation:
${chatHistory.map((e) => '${e['role'].toString().toUpperCase()}: ${e['content']}').join('\n')}\n
USER:$lastMessage
'''
    .replaceAll('    ', ''); // Strip the indentation from the raw multi-line prompt string.
debugPrint('determineRequiresContext message: $message');
var response = await gptApiCall(
model: 'gpt-4-turbo',
messages: [
{"role": "user", "content": message}
],
tools: tools);
if (response.toString().contains('retrieve_rag_context')) {
var args = jsonDecode(response[0]['function']['arguments']);
return args['question'];
}
return null;
}

String qaStreamedBody(String context, List<dynamic> chatHistory) {
var prompt = '''
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
If the message doesn't require context, it will be empty, so answer the question casually.
Conversation History:
${chatHistory.map((e) => '${e['role'].toString().toUpperCase()}: ${e['content']}').join('\n')}
Context:
```
$context
```
Answer:
'''
    .replaceAll('    ', ''); // Strip the indentation from the raw multi-line prompt string.
debugPrint(prompt);
var body = jsonEncode({
"model": "gpt-4-turbo",
"messages": [
{"role": "system", "content": prompt}
],
"stream": true,
});
return body;
}
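A sketch of how these two helpers presumably fit together on the chat path; the actual call site is in the streaming-response code below. `buildChatRequestBody` is an illustrative name, and wiring the result through the DVDB `'memories'` collection and `getEmbeddingsFromInput` is an assumption based on the rest of this commit.

```dart
import 'dart:typed_data';

import 'package:friend_private/backend/api_requests/api_calls.dart';
import 'package:friend_private/backend/storage/dvdb/dvdb_helper.dart';
import 'package:friend_private/backend/storage/vector_db.dart';

// Illustrative flow only; names not shown in this diff are hypothetical.
Future<String> buildChatRequestBody(String lastMessage, List<dynamic> chatHistory) async {
  // 1. Ask GPT, via the retrieve_rag_context function call, whether the message needs context.
  String? ragQuestion = await determineRequiresContext(lastMessage, chatHistory);

  String context = '';
  if (ragQuestion != null) {
    // 2. Embed the reformulated question and fetch the most similar stored memories.
    List<double> embedding = await getEmbeddingsFromInput(ragQuestion);
    final results = DVDB()
        .collection('memories')
        .search(Float64List.fromList(embedding), numResults: 10);
    context = results.map((r) => r.text).join('\n');
  }

  // 3. Build the streamed completion body; context may be empty for casual messages.
  return qaStreamedBody(context, chatHistory);
}
```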
@@ -27,8 +27,8 @@ Future streamApiResponse(
'Authorization': 'Bearer $apiKey',
};

String body = qaStreamedBody(context, retrieveMostRecentMessages(FFAppState().chatHistory));
// String body = qaStreamedFullMemories(FFAppState().memories, retrieveMostRecentMessages(FFAppState().chatHistory));
var request = http.Request("POST", Uri.parse(url))
..headers.addAll(headers)
..body = body;
1 change: 1 addition & 0 deletions apps/AppWithWearable/lib/backend/storage/dvdb/README.md
@@ -0,0 +1 @@
https://github.com/FastCodeAI/DVDB
107 changes: 107 additions & 0 deletions apps/AppWithWearable/lib/backend/storage/dvdb/collection.dart
@@ -0,0 +1,107 @@
import 'dart:convert';
import 'dart:core';
import 'dart:io';
import 'dart:typed_data';

import 'package:friend_private/backend/storage/dvdb/document.dart';
import 'package:friend_private/backend/storage/dvdb/math.dart';
import 'package:friend_private/backend/storage/dvdb/search_result.dart';
import 'package:path/path.dart';
import 'package:path_provider/path_provider.dart';
import 'package:uuid/uuid.dart';

class Collection {
Collection(this.name);

final String name;
final Map<String, Document> documents = {};

void addDocument(String? id, String text, Float64List embedding, {Map<String, String>? metadata}) {
var uuid = Uuid();
final Document document = Document(
id: id ?? uuid.v1(),
text: text,
embedding: embedding,
metadata: metadata,
);

documents[document.id] = document;
_writeDocument(document);
}

void addDocuments(List<Document> docs) {
for (final Document doc in docs) {
documents[doc.id] = doc;
_writeDocument(doc);
}
}

void removeDocument(String id) {
if (documents.containsKey(id)) {
documents.remove(id);
_saveAllDocuments(); // Re-saving all documents after removal
}
}

List<SearchResult> search(Float64List query, {int numResults = 10, double? threshold}) {
final List<SearchResult> similarities = <SearchResult>[];
for (var document in documents.values) {
final double similarity = MathFunctions().cosineSimilarity(query, document.embedding);

if (threshold != null && similarity < threshold) {
continue;
}

similarities.add(SearchResult(id: document.id, text: document.text, score: similarity));
}

similarities.sort((SearchResult a, SearchResult b) => b.score.compareTo(a.score));
return similarities.take(numResults).toList();
}

Future<void> _writeDocument(Document document) async {
Directory documentsDirectory = await getApplicationDocumentsDirectory();
String path = join(documentsDirectory.path, '$name.json');
final File file = File(path);

var encodedDocument = json.encode(document.toJson());
List<int> bytes = utf8.encode('$encodedDocument\n');

file.writeAsBytesSync(bytes, mode: FileMode.append);
}

Future<void> _saveAllDocuments() async {
Directory documentsDirectory = await getApplicationDocumentsDirectory();
String path = join(documentsDirectory.path, '$name.json');
final File file = File(path);

file.writeAsStringSync(''); // Clearing the file
for (var document in documents.values) {
await _writeDocument(document); // Await each append so documents are rewritten in order.
}
}

Future<void> load() async {
Directory documentsDirectory = await getApplicationDocumentsDirectory();
String path = join(documentsDirectory.path, '$name.json');
final File file = File(path);

if (!file.existsSync()) {
documents.clear();
return;
}

final lines = file.readAsLinesSync();

for (var line in lines) {
var decodedDocument = json.decode(line) as Map<String, dynamic>;
var document = Document.fromJson(decodedDocument);
documents[document.id] = document;
}
}

void clear() {
documents.clear();
_saveAllDocuments();
}
}
43 changes: 43 additions & 0 deletions apps/AppWithWearable/lib/backend/storage/dvdb/document.dart
@@ -0,0 +1,43 @@
import 'dart:math';
import 'dart:typed_data';
import 'package:uuid/uuid.dart';

class Document {
Document({String? id, required this.text, required this.embedding, Map<String, String>? metadata})
: id = id ?? _generateUuid(),
magnitude = _calculateMagnitude(embedding),
metadata = metadata ?? Map<String, String>();

final String id;
final String text;
final Float64List embedding;
final double magnitude;
final Map<String, String> metadata;

static String _generateUuid() {
return Uuid().v1();
}

static double _calculateMagnitude(Float64List embedding) {
return sqrt(embedding.fold(0, (num sum, double element) => sum + element * element));
}

Map<String, dynamic> toJson() {
return {
'id': id,
'text': text,
'embedding': embedding,
'magnitude': magnitude,
'metadata': metadata,
};
}

factory Document.fromJson(Map<String, dynamic> json) {
return Document(
id: json['id'],
text: json['text'],
embedding: Float64List.fromList(json['embedding'].cast<double>()),
metadata: Map<String, String>.from(json['metadata'])
);
}
}
39 changes: 39 additions & 0 deletions apps/AppWithWearable/lib/backend/storage/dvdb/dvdb_helper.dart
@@ -0,0 +1,39 @@
import 'package:friend_private/backend/storage/dvdb/collection.dart';

class DVDB {
DVDB._internal();

static final DVDB _shared = DVDB._internal();

factory DVDB() {
return _shared;
}

final Map<String, Collection> _collections = {};

Collection collection(String name) {
if (_collections.containsKey(name)) {
return _collections[name]!;
}

final Collection collection = Collection(name);
_collections[name] = collection;
collection.load(); // Fire-and-forget: documents become available once the async load completes.
return collection;
}

Collection? getCollection(String name) {
return _collections[name];
}

void releaseCollection(String name) {
_collections.remove(name);
}

void reset() {
for (final Collection collection in _collections.values) {
collection.clear();
}
_collections.clear();
}
}
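For reference, a small usage example of the DVDB singleton and the Collection API added above; the collection name, document text, and embeddings are made up for illustration.

```dart
import 'dart:typed_data';

import 'package:friend_private/backend/storage/dvdb/dvdb_helper.dart';

void dvdbExample(Float64List queryEmbedding, Float64List docEmbedding) {
  // Collections are created lazily and loaded from <name>.json on first access.
  final memories = DVDB().collection('memories');

  memories.addDocument(
    null, // null id -> a UUID is generated
    'Discussed the Q3 roadmap with Ana',
    docEmbedding,
    metadata: {'source': 'transcript'},
  );

  // Top-5 most similar documents above a cosine-similarity threshold of 0.7.
  final results = memories.search(queryEmbedding, numResults: 5, threshold: 0.7);
  for (final r in results) {
    print('${r.score.toStringAsFixed(3)} ${r.text}');
  }

  // Drop the in-memory handle when done (the on-disk file is kept).
  DVDB().releaseCollection('memories');
}
```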
22 changes: 22 additions & 0 deletions apps/AppWithWearable/lib/backend/storage/dvdb/errors.dart
@@ -0,0 +1,22 @@
enum VectorDBError {
collectionAlreadyExists,
}

class CollectionError implements Exception {
CollectionError._(this.message);

final String message;

factory CollectionError.fileNotFound() {
return CollectionError._("File not found.");
}

factory CollectionError.loadFailed(String errorMessage) {
return CollectionError._("Load failed: $errorMessage");
}

@override
String toString() {
return message;
}
}
