Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to chat streaming #510

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/providers/session.dart
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ class Session extends ChangeNotifier {

final stringStream = model.prompt(messages);

await chat.tail.streamIn(stringStream);
await for (var message in stringStream) {
chat.tail.content += message;
notifyListeners();
}

chat.tail.finalised = true;

notifyListeners();
}
Expand Down
60 changes: 40 additions & 20 deletions lib/ui/mobile/pages/character/character_customization_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,46 @@ class _CharacterCustomizationPageState extends State<CharacterCustomizationPage>
late List<TextEditingController> _greetingControllers;
late List<TextEditingController> _exampleControllers;

@override
void dispose() {
super.dispose();
_nameController.dispose();
_descriptionController.dispose();
_personalityController.dispose();
_scenarioController.dispose();
_systemController.dispose();

for (var controller in _greetingControllers) {
controller.dispose();
}
for (var controller in _exampleControllers) {
controller.dispose();
}

SharedPreferences.getInstance().then((prefs) {
final characterString = prefs.getString("last_character");
final character = Character.fromMap(json.decode(characterString ?? "{}"));

final String charactersJson = prefs.getString("characters") ?? '[]';
final List charactersList = json.decode(charactersJson);

List<Character> characters;
characters = charactersList.map((characterMap) {
return Character.fromMap(characterMap);
}).toList();

characters.removeWhere((listCharacter) {
return character.hash == listCharacter.hash;
});
characters.insert(0, character);

final String newCharactersJson =
json.encode(characters.map((character) => character.toMap()).toList());

prefs.setString("characters", newCharactersJson);
});
}

@override
Widget build(BuildContext context) {
return Scaffold(
Expand Down Expand Up @@ -59,26 +99,6 @@ class _CharacterCustomizationPageState extends State<CharacterCustomizationPage>

SharedPreferences.getInstance().then((prefs) {
prefs.setString("last_character", json.encode(character.toMap()));

final String charactersJson = prefs.getString("characters") ?? '[]';
final List charactersList = json.decode(charactersJson);

List<Character> characters;
characters = charactersList.map((characterMap) {
return Character.fromMap(characterMap);
}).toList();

characters.removeWhere((listCharacter) {
print("Character Hash: ${character.hash}");
print("List Character Hash: ${listCharacter.hash}");
return character.hash == listCharacter.hash;
});
characters.insert(0, character);

final String newCharactersJson =
json.encode(characters.map((character) => character.toMap()).toList());

prefs.setString("characters", newCharactersJson);
});

return SessionBusyOverlay(
Expand Down
7 changes: 3 additions & 4 deletions lib/ui/mobile/pages/home_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import 'dart:convert';
import 'dart:math';

import 'package:flutter/material.dart';
import 'package:maid_llm/src/chat_node.dart';
import 'package:maid_llm/maid_llm.dart';
import 'package:maid/providers/user.dart';
import 'package:maid/providers/character.dart';
import 'package:maid/providers/session.dart';
Expand Down Expand Up @@ -44,11 +44,10 @@ class HomePageState extends State<HomePage> {

List<ChatNode> chat = session.chat.getChat();
if (chat.isEmpty && character.useGreeting) {
final newKey = UniqueKey();
final index = Random().nextInt(character.greetings.length);

final message = ChatNode(
key: newKey,
key: UniqueKey(),
role: ChatRole.assistant,
content: Utilities.formatPlaceholders(character.greetings[index], user.name, character.name),
);
Expand All @@ -60,7 +59,7 @@ class HomePageState extends State<HomePage> {
chatWidgets.clear();
for (final message in chat) {
chatWidgets.add(ChatMessage(
node: message,
key: message.key,
));
}

Expand Down
60 changes: 30 additions & 30 deletions lib/ui/mobile/widgets/chat_widgets/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@ import 'package:maid_ui/maid_ui.dart';
import 'package:provider/provider.dart';

class ChatMessage extends StatefulWidget {
final ChatNode node;

const ChatMessage({
super.key,
required this.node,
required super.key,
});

@override
ChatMessageState createState() => ChatMessageState();
State<ChatMessage> createState() => _ChatMessageState();
}

class ChatMessageState extends State<ChatMessage> with SingleTickerProviderStateMixin {
bool _editing = false;
class _ChatMessageState extends State<ChatMessage> with SingleTickerProviderStateMixin {
late ChatNode node;
bool editing = false;

Widget _messageBuilder(String message) {
Widget messageBuilder(String message) {
List<Widget> widgets = [];
List<String> parts = message.split('```');

Expand Down Expand Up @@ -57,16 +55,18 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
Widget build(BuildContext context) {
return Consumer3<Session, User, Character>(
builder: (context, session, user, character, child) {
int currentIndex = session.chat.indexOf(widget.node.key);
int siblingCount = session.chat.siblingCountOf(widget.node.key);
node = session.chat.find(widget.key!)!;

int currentIndex = session.chat.indexOf(widget.key!);
int siblingCount = session.chat.siblingCountOf(widget.key!);

return Column(crossAxisAlignment: CrossAxisAlignment.start, children: [
Row(
mainAxisAlignment: MainAxisAlignment.start,
children: [
const SizedBox(width: 10.0),
FutureAvatar(
image: widget.node.role == ChatRole.user ? user.profile : character.profile,
image: node.role == ChatRole.user ? user.profile : character.profile,
radius: 16,
),
const SizedBox(width: 10.0),
Expand All @@ -83,7 +83,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
blendMode: BlendMode
.srcIn, // This blend mode applies the shader to the text color.
child: Text(
widget.node.role == ChatRole.user ? user.name : character.name,
node.role == ChatRole.user ? user.name : character.name,
style: const TextStyle(
fontWeight: FontWeight.normal,
color: Colors
Expand All @@ -93,7 +93,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
),
),
const Expanded(child: SizedBox()), // Spacer
if (widget.node.finalised) ..._messageOptions(),
if (node.finalised) ...messageOptions(),
Row(
mainAxisSize: MainAxisSize.max,
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
Expand All @@ -102,7 +102,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.all(0),
onPressed: () {
if (!session.chat.tail.finalised) return;
session.chat.last(widget.node.key);
session.chat.last(node.key);
session.notify();
},
icon: Icon(Icons.arrow_left,
Expand All @@ -113,7 +113,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.all(0),
onPressed: () {
if (!session.chat.tail.finalised) return;
session.chat.next(widget.node.key);
session.chat.next(node.key);
session.notify();
},
icon: Icon(Icons.arrow_right,
Expand All @@ -128,46 +128,46 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.fromLTRB(20, 10, 20, 10),
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: _editing ? _editingColumn() : _standardColumn(),
children: editing ? editingColumn() : standardColumn(),
))
]);
},
);
}

List<Widget> _messageOptions() {
return widget.node.role == ChatRole.user ? _userOptions() : _assistantOptions();
List<Widget> messageOptions() {
return node.role == ChatRole.user ? userOptions() : assistantOptions();
}

List<Widget> _userOptions() {
List<Widget> userOptions() {
return [
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
setState(() {
_editing = true;
editing = true;
});
},
icon: const Icon(Icons.edit),
),
];
}

List<Widget> _assistantOptions() {
List<Widget> assistantOptions() {
return [
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
context.read<Session>().regenerate(widget.node.key, context);
context.read<Session>().regenerate(node.key, context);
setState(() {});
},
icon: const Icon(Icons.refresh),
),
];
}

List<Widget> _editingColumn() {
final messageController = TextEditingController(text: widget.node.content);
List<Widget> editingColumn() {
final messageController = TextEditingController(text: node.content);

return [
TextField(
Expand All @@ -189,29 +189,29 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
onPressed: () {
if (!context.watch<Session>().chat.tail.finalised) return;
setState(() {
_editing = false;
editing = false;
});
context.read<Session>().edit(widget.node.key, messageController.text, context);
context.read<Session>().edit(node.key, messageController.text, context);
},
icon: const Icon(Icons.done)),
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
setState(() {
_editing = false;
editing = false;
});
},
icon: const Icon(Icons.close))
])
];
}

List<Widget> _standardColumn() {
List<Widget> standardColumn() {
return [
if (!widget.node.finalised && widget.node.content.isEmpty)
if (!node.finalised && node.content.isEmpty)
const TypingIndicator() // Assuming TypingIndicator is a custom widget you've defined.
else
_messageBuilder(widget.node.content),
messageBuilder(node.content),
];
}
}
2 changes: 1 addition & 1 deletion packages/maid_llm