Skip to content

Commit

Permalink
Merge pull request #510 from danemadsen/main
Browse files Browse the repository at this point in the history
Fixes to chat streaming
  • Loading branch information
danemadsen committed Apr 24, 2024
2 parents fa9835c + 679df8c commit ee95041
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 56 deletions.
7 changes: 6 additions & 1 deletion lib/providers/session.dart
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ class Session extends ChangeNotifier {

final stringStream = model.prompt(messages);

await chat.tail.streamIn(stringStream);
await for (var message in stringStream) {
chat.tail.content += message;
notifyListeners();
}

chat.tail.finalised = true;

notifyListeners();
}
Expand Down
60 changes: 40 additions & 20 deletions lib/ui/mobile/pages/character/character_customization_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,46 @@ class _CharacterCustomizationPageState extends State<CharacterCustomizationPage>
late List<TextEditingController> _greetingControllers;
late List<TextEditingController> _exampleControllers;

@override
void dispose() {
super.dispose();
_nameController.dispose();
_descriptionController.dispose();
_personalityController.dispose();
_scenarioController.dispose();
_systemController.dispose();

for (var controller in _greetingControllers) {
controller.dispose();
}
for (var controller in _exampleControllers) {
controller.dispose();
}

SharedPreferences.getInstance().then((prefs) {
final characterString = prefs.getString("last_character");
final character = Character.fromMap(json.decode(characterString ?? "{}"));

final String charactersJson = prefs.getString("characters") ?? '[]';
final List charactersList = json.decode(charactersJson);

List<Character> characters;
characters = charactersList.map((characterMap) {
return Character.fromMap(characterMap);
}).toList();

characters.removeWhere((listCharacter) {
return character.hash == listCharacter.hash;
});
characters.insert(0, character);

final String newCharactersJson =
json.encode(characters.map((character) => character.toMap()).toList());

prefs.setString("characters", newCharactersJson);
});
}

@override
Widget build(BuildContext context) {
return Scaffold(
Expand Down Expand Up @@ -59,26 +99,6 @@ class _CharacterCustomizationPageState extends State<CharacterCustomizationPage>

SharedPreferences.getInstance().then((prefs) {
prefs.setString("last_character", json.encode(character.toMap()));

final String charactersJson = prefs.getString("characters") ?? '[]';
final List charactersList = json.decode(charactersJson);

List<Character> characters;
characters = charactersList.map((characterMap) {
return Character.fromMap(characterMap);
}).toList();

characters.removeWhere((listCharacter) {
print("Character Hash: ${character.hash}");
print("List Character Hash: ${listCharacter.hash}");
return character.hash == listCharacter.hash;
});
characters.insert(0, character);

final String newCharactersJson =
json.encode(characters.map((character) => character.toMap()).toList());

prefs.setString("characters", newCharactersJson);
});

return SessionBusyOverlay(
Expand Down
7 changes: 3 additions & 4 deletions lib/ui/mobile/pages/home_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import 'dart:convert';
import 'dart:math';

import 'package:flutter/material.dart';
import 'package:maid_llm/src/chat_node.dart';
import 'package:maid_llm/maid_llm.dart';
import 'package:maid/providers/user.dart';
import 'package:maid/providers/character.dart';
import 'package:maid/providers/session.dart';
Expand Down Expand Up @@ -44,11 +44,10 @@ class HomePageState extends State<HomePage> {

List<ChatNode> chat = session.chat.getChat();
if (chat.isEmpty && character.useGreeting) {
final newKey = UniqueKey();
final index = Random().nextInt(character.greetings.length);

final message = ChatNode(
key: newKey,
key: UniqueKey(),
role: ChatRole.assistant,
content: Utilities.formatPlaceholders(character.greetings[index], user.name, character.name),
);
Expand All @@ -60,7 +59,7 @@ class HomePageState extends State<HomePage> {
chatWidgets.clear();
for (final message in chat) {
chatWidgets.add(ChatMessage(
node: message,
key: message.key,
));
}

Expand Down
60 changes: 30 additions & 30 deletions lib/ui/mobile/widgets/chat_widgets/chat_message.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@ import 'package:maid_ui/maid_ui.dart';
import 'package:provider/provider.dart';

class ChatMessage extends StatefulWidget {
final ChatNode node;

const ChatMessage({
super.key,
required this.node,
required super.key,
});

@override
ChatMessageState createState() => ChatMessageState();
State<ChatMessage> createState() => _ChatMessageState();
}

class ChatMessageState extends State<ChatMessage> with SingleTickerProviderStateMixin {
bool _editing = false;
class _ChatMessageState extends State<ChatMessage> with SingleTickerProviderStateMixin {
late ChatNode node;
bool editing = false;

Widget _messageBuilder(String message) {
Widget messageBuilder(String message) {
List<Widget> widgets = [];
List<String> parts = message.split('```');

Expand Down Expand Up @@ -57,16 +55,18 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
Widget build(BuildContext context) {
return Consumer3<Session, User, Character>(
builder: (context, session, user, character, child) {
int currentIndex = session.chat.indexOf(widget.node.key);
int siblingCount = session.chat.siblingCountOf(widget.node.key);
node = session.chat.find(widget.key!)!;

int currentIndex = session.chat.indexOf(widget.key!);
int siblingCount = session.chat.siblingCountOf(widget.key!);

return Column(crossAxisAlignment: CrossAxisAlignment.start, children: [
Row(
mainAxisAlignment: MainAxisAlignment.start,
children: [
const SizedBox(width: 10.0),
FutureAvatar(
image: widget.node.role == ChatRole.user ? user.profile : character.profile,
image: node.role == ChatRole.user ? user.profile : character.profile,
radius: 16,
),
const SizedBox(width: 10.0),
Expand All @@ -83,7 +83,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
blendMode: BlendMode
.srcIn, // This blend mode applies the shader to the text color.
child: Text(
widget.node.role == ChatRole.user ? user.name : character.name,
node.role == ChatRole.user ? user.name : character.name,
style: const TextStyle(
fontWeight: FontWeight.normal,
color: Colors
Expand All @@ -93,7 +93,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
),
),
const Expanded(child: SizedBox()), // Spacer
if (widget.node.finalised) ..._messageOptions(),
if (node.finalised) ...messageOptions(),
Row(
mainAxisSize: MainAxisSize.max,
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
Expand All @@ -102,7 +102,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.all(0),
onPressed: () {
if (!session.chat.tail.finalised) return;
session.chat.last(widget.node.key);
session.chat.last(node.key);
session.notify();
},
icon: Icon(Icons.arrow_left,
Expand All @@ -113,7 +113,7 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.all(0),
onPressed: () {
if (!session.chat.tail.finalised) return;
session.chat.next(widget.node.key);
session.chat.next(node.key);
session.notify();
},
icon: Icon(Icons.arrow_right,
Expand All @@ -128,46 +128,46 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
padding: const EdgeInsets.fromLTRB(20, 10, 20, 10),
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: _editing ? _editingColumn() : _standardColumn(),
children: editing ? editingColumn() : standardColumn(),
))
]);
},
);
}

List<Widget> _messageOptions() {
return widget.node.role == ChatRole.user ? _userOptions() : _assistantOptions();
List<Widget> messageOptions() {
return node.role == ChatRole.user ? userOptions() : assistantOptions();
}

List<Widget> _userOptions() {
List<Widget> userOptions() {
return [
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
setState(() {
_editing = true;
editing = true;
});
},
icon: const Icon(Icons.edit),
),
];
}

List<Widget> _assistantOptions() {
List<Widget> assistantOptions() {
return [
IconButton(
onPressed: () {
if (!context.read<Session>().chat.tail.finalised) return;
context.read<Session>().regenerate(widget.node.key, context);
context.read<Session>().regenerate(node.key, context);
setState(() {});
},
icon: const Icon(Icons.refresh),
),
];
}

List<Widget> _editingColumn() {
final messageController = TextEditingController(text: widget.node.content);
List<Widget> editingColumn() {
final messageController = TextEditingController(text: node.content);

return [
TextField(
Expand All @@ -189,29 +189,29 @@ class ChatMessageState extends State<ChatMessage> with SingleTickerProviderState
onPressed: () {
if (!context.watch<Session>().chat.tail.finalised) return;
setState(() {
_editing = false;
editing = false;
});
context.read<Session>().edit(widget.node.key, messageController.text, context);
context.read<Session>().edit(node.key, messageController.text, context);
},
icon: const Icon(Icons.done)),
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
setState(() {
_editing = false;
editing = false;
});
},
icon: const Icon(Icons.close))
])
];
}

List<Widget> _standardColumn() {
List<Widget> standardColumn() {
return [
if (!widget.node.finalised && widget.node.content.isEmpty)
if (!node.finalised && node.content.isEmpty)
const TypingIndicator() // Assuming TypingIndicator is a custom widget you've defined.
else
_messageBuilder(widget.node.content),
messageBuilder(node.content),
];
}
}
2 changes: 1 addition & 1 deletion packages/maid_llm

0 comments on commit ee95041

Please sign in to comment.