Skip to content

Commit

Permalink
Merge pull request #507 from danemadsen/main
Browse files Browse the repository at this point in the history
work on simplifying chat node logic
  • Loading branch information
danemadsen committed Apr 23, 2024
2 parents 86c82cf + 6680e8c commit fa9835c
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 115 deletions.
33 changes: 1 addition & 32 deletions lib/providers/session.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,16 @@ import 'package:shared_preferences/shared_preferences.dart';

class Session extends ChangeNotifier {
Key _key = UniqueKey();
bool _busy = false;
LargeLanguageModel model = LlamaCppModel();
ChatNodeTree chat = ChatNodeTree();

String _name = "";

bool get isBusy => _busy;

String get name => _name;

Key get key => _key;

set busy(bool value) {
_busy = value;
notifyListeners();
}

Expand Down Expand Up @@ -114,9 +110,6 @@ class Session extends ChangeNotifier {
}

void prompt(BuildContext context) async {
_busy = true;
notifyListeners();

final user = context.read<User>();
final character = context.read<Character>();

Expand All @@ -136,12 +129,8 @@ class Session extends ChangeNotifier {

final stringStream = model.prompt(messages);

await for (var message in stringStream) {
stream(message);
}
await chat.tail.streamIn(stringStream);

_busy = false;
finalise();
notifyListeners();
}

Expand Down Expand Up @@ -171,31 +160,11 @@ class Session extends ChangeNotifier {

void stop() {
(model as LlamaCppModel).stop();
_busy = false;
Logger.log('Local generation stopped');
finalise();
notifyListeners();
}

void stream(String? message) async {
if (message == null) {
finalise();
} else {
chat.buffer += message;

if (!(chat.tail.messageController.isClosed)) {
chat.tail.messageController.add(chat.buffer);
chat.buffer = "";
}

chat.tail.content += message;
}
}

void finalise() {
_busy = false;

chat.tail.messageController.close();

SharedPreferences.getInstance().then((prefs) {
prefs.setString("last_session", json.encode(toMap()));
Expand Down
26 changes: 17 additions & 9 deletions lib/ui/mobile/pages/home_page.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import 'dart:convert';
import 'dart:math';

import 'package:flutter/material.dart';
Expand All @@ -11,6 +12,7 @@ import 'package:maid/ui/mobile/widgets/chat_widgets/chat_field.dart';
import 'package:maid/ui/mobile/widgets/appbars/home_app_bar.dart';
import 'package:maid/ui/mobile/widgets/home_drawer.dart';
import 'package:provider/provider.dart';
import 'package:shared_preferences/shared_preferences.dart';

class HomePage extends StatefulWidget {
final String title;
Expand All @@ -36,23 +38,29 @@ class HomePageState extends State<HomePage> {
Widget _buildBody() {
return Consumer3<Session, User, Character>(
builder: (context, session, user, character, child) {
Map<Key, ChatRole> history = session.chat.getHistory();
if (history.isEmpty && character.useGreeting) {
SharedPreferences.getInstance().then((prefs) {
prefs.setString("last_session", json.encode(session.toMap()));
});

List<ChatNode> chat = session.chat.getChat();
if (chat.isEmpty && character.useGreeting) {
final newKey = UniqueKey();
final index = Random().nextInt(character.greetings.length);
session.chat.add(
newKey,

final message = ChatNode(
key: newKey,
role: ChatRole.assistant,
content: Utilities.formatPlaceholders(character.greetings[index], user.name, character.name),
role: ChatRole.assistant
);
history = {newKey: ChatRole.assistant};

session.chat.addNode(message);
chat = [message];
}

chatWidgets.clear();
for (var key in history.keys) {
for (final message in chat) {
chatWidgets.add(ChatMessage(
key: key,
role: history[key] ?? ChatRole.assistant,
node: message,
));
}

Expand Down
6 changes: 3 additions & 3 deletions lib/ui/mobile/widgets/chat_widgets/chat_field.dart
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class _ChatFieldState extends State<ChatField> {
padding: const EdgeInsets.all(8.0),
child: Row(
children: [
if (session.isBusy &&
if (!session.chat.tail.finalised &&
session.model.type != LargeLanguageModelType.ollama)
IconButton(
onPressed: session.stop,
Expand All @@ -109,14 +109,14 @@ class _ChatFieldState extends State<ChatField> {
),
IconButton(
onPressed: () {
if (!session.isBusy) {
if (session.chat.tail.finalised ) {
send();
}
},
iconSize: 50,
icon: Icon(
Icons.arrow_circle_right,
color: session.isBusy
color: !session.chat.tail.finalised
? Theme.of(context).colorScheme.onPrimary
: Theme.of(context).colorScheme.secondary,
)),
Expand Down
90 changes: 24 additions & 66 deletions lib/ui/mobile/widgets/chat_widgets/chat_message.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import 'package:flutter/material.dart';
import 'package:maid_llm/src/chat_node.dart';
import 'package:maid_llm/maid_llm.dart';
import 'package:maid/providers/character.dart';
import 'package:maid/providers/session.dart';
import 'package:maid/providers/user.dart';
Expand All @@ -8,54 +8,20 @@ import 'package:maid_ui/maid_ui.dart';
import 'package:provider/provider.dart';

class ChatMessage extends StatefulWidget {
final ChatRole role;
final ChatNode node;

const ChatMessage({
required super.key,
this.role = ChatRole.assistant,
super.key,
required this.node,
});

@override
ChatMessageState createState() => ChatMessageState();
}

class ChatMessageState extends State<ChatMessage>
with SingleTickerProviderStateMixin {
late Session session;
final TextEditingController _messageController = TextEditingController();
String _message = "";
bool _finalised = false;
class ChatMessageState extends State<ChatMessage> with SingleTickerProviderStateMixin {
bool _editing = false;

@override
void initState() {
super.initState();
session = context.read<Session>();

if (session.chat.messageOf(widget.key!).isNotEmpty) {
_message = session.chat.messageOf(widget.key!);
_finalised = true;
} else {
session.chat.getMessageStream(widget.key!).stream.listen((textChunk) {
setState(() {
_message += textChunk;
});
}).onDone(() {
_message = _message.trim();

session.chat.add(
widget.key!,
content: _message,
role: widget.role
);

session.notify();

_finalised = true;
});
}
}

Widget _messageBuilder(String message) {
List<Widget> widgets = [];
List<String> parts = message.split('```');
Expand Down Expand Up @@ -91,17 +57,16 @@ class ChatMessageState extends State<ChatMessage>
Widget build(BuildContext context) {
return Consumer3<Session, User, Character>(
builder: (context, session, user, character, child) {
int currentIndex = session.chat.indexOf(widget.key!);
int siblingCount = session.chat.siblingCountOf(widget.key!);
bool busy = session.isBusy;
int currentIndex = session.chat.indexOf(widget.node.key);
int siblingCount = session.chat.siblingCountOf(widget.node.key);

return Column(crossAxisAlignment: CrossAxisAlignment.start, children: [
Row(
mainAxisAlignment: MainAxisAlignment.start,
children: [
const SizedBox(width: 10.0),
FutureAvatar(
image: widget.role == ChatRole.user ? user.profile : character.profile,
image: widget.node.role == ChatRole.user ? user.profile : character.profile,
radius: 16,
),
const SizedBox(width: 10.0),
Expand All @@ -118,7 +83,7 @@ class ChatMessageState extends State<ChatMessage>
blendMode: BlendMode
.srcIn, // This blend mode applies the shader to the text color.
child: Text(
widget.role == ChatRole.user ? user.name : character.name,
widget.node.role == ChatRole.user ? user.name : character.name,
style: const TextStyle(
fontWeight: FontWeight.normal,
color: Colors
Expand All @@ -128,16 +93,16 @@ class ChatMessageState extends State<ChatMessage>
),
),
const Expanded(child: SizedBox()), // Spacer
if (_finalised) ..._messageOptions(),
if (widget.node.finalised) ..._messageOptions(),
Row(
mainAxisSize: MainAxisSize.max,
mainAxisAlignment: MainAxisAlignment.spaceEvenly,
children: <Widget>[
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
if (busy) return;
session.chat.last(widget.key!);
if (!session.chat.tail.finalised) return;
session.chat.last(widget.node.key);
session.notify();
},
icon: Icon(Icons.arrow_left,
Expand All @@ -147,8 +112,8 @@ class ChatMessageState extends State<ChatMessage>
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
if (busy) return;
session.chat.next(widget.key!);
if (!session.chat.tail.finalised) return;
session.chat.next(widget.node.key);
session.notify();
},
icon: Icon(Icons.arrow_right,
Expand All @@ -171,18 +136,16 @@ class ChatMessageState extends State<ChatMessage>
}

List<Widget> _messageOptions() {
return widget.role == ChatRole.user ? _userOptions() : _assistantOptions();
return widget.node.role == ChatRole.user ? _userOptions() : _assistantOptions();
}

List<Widget> _userOptions() {
return [
IconButton(
onPressed: () {
if (session.isBusy) return;
if (!context.read<Session>().chat.tail.finalised) return;
setState(() {
_messageController.text = _message;
_editing = true;
_finalised = false;
});
},
icon: const Icon(Icons.edit),
Expand All @@ -194,8 +157,8 @@ class ChatMessageState extends State<ChatMessage>
return [
IconButton(
onPressed: () {
if (session.isBusy) return;
session.regenerate(widget.key!, context);
if (!context.read<Session>().chat.tail.finalised) return;
context.read<Session>().regenerate(widget.node.key, context);
setState(() {});
},
icon: const Icon(Icons.refresh),
Expand All @@ -204,11 +167,11 @@ class ChatMessageState extends State<ChatMessage>
}

List<Widget> _editingColumn() {
final busy = context.watch<Session>().isBusy;
final messageController = TextEditingController(text: widget.node.content);

return [
TextField(
controller: _messageController,
controller: messageController,
autofocus: true,
cursorColor: Theme.of(context).colorScheme.secondary,
style: Theme.of(context).textTheme.bodyMedium,
Expand All @@ -224,23 +187,18 @@ class ChatMessageState extends State<ChatMessage>
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
if (busy) return;
final inputMessage = _messageController.text;
if (!context.watch<Session>().chat.tail.finalised) return;
setState(() {
_messageController.text = _message;
_editing = false;
_finalised = true;
});
session.edit(widget.key!, inputMessage, context);
context.read<Session>().edit(widget.node.key, messageController.text, context);
},
icon: const Icon(Icons.done)),
IconButton(
padding: const EdgeInsets.all(0),
onPressed: () {
setState(() {
_messageController.text = _message;
_editing = false;
_finalised = true;
});
},
icon: const Icon(Icons.close))
Expand All @@ -250,10 +208,10 @@ class ChatMessageState extends State<ChatMessage>

List<Widget> _standardColumn() {
return [
if (!_finalised && _message.isEmpty)
if (!widget.node.finalised && widget.node.content.isEmpty)
const TypingIndicator() // Assuming TypingIndicator is a custom widget you've defined.
else
_messageBuilder(_message),
_messageBuilder(widget.node.content),
];
}
}
4 changes: 2 additions & 2 deletions lib/ui/mobile/widgets/home_drawer.dart
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class _HomeDrawerState extends State<HomeDrawer> {
),
FilledButton(
onPressed: () {
if (session.isBusy) return;
if (!session.chat.tail.finalised) return;
setState(() {
final newSession = Session();
sessions.add(newSession);
Expand All @@ -144,7 +144,7 @@ class _HomeDrawerState extends State<HomeDrawer> {
return SessionTile(
session: sessions[index],
onDelete: () {
if (session.isBusy) return;
if (!session.chat.tail.finalised) return;
setState(() {
if (sessions[index].key == session.key) {
session.from(sessions.firstOrNull ?? Session());
Expand Down
2 changes: 1 addition & 1 deletion lib/ui/mobile/widgets/session_busy_overlay.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class _SessionBusyOverlayState extends State<SessionBusyOverlay> {
return Stack(
children: [
widget.child,
if (context.watch<Session>().isBusy)
if (!context.watch<Session>().chat.tail.finalised)
Positioned.fill(
child: Container(
color: Colors.black.withOpacity(0.4),
Expand Down
Loading

0 comments on commit fa9835c

Please sign in to comment.