Skip to content

Commit

Permalink
Merge pull request #357 from danemadsen/main
Browse files Browse the repository at this point in the history
refactor
  • Loading branch information
danemadsen committed Feb 22, 2024
2 parents a67df3f + 2979069 commit 96885ee
Show file tree
Hide file tree
Showing 35 changed files with 416 additions and 454 deletions.
96 changes: 55 additions & 41 deletions lib/classes/generation_options.dart
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import 'dart:math';

import 'package:maid/providers/character.dart';
import 'package:maid/providers/model.dart';
import 'package:maid/providers/ai_platform.dart';
import 'package:maid/providers/session.dart';
import 'package:maid/static/logger.dart';
import 'package:llama_cpp_dart/llama_cpp_dart.dart';
Expand Down Expand Up @@ -107,19 +107,20 @@ class GenerationOptions {
return map;
}

String replaceCaseInsensitive(String original, String from, String replaceWith) {
String replaceCaseInsensitive(
String original, String from, String replaceWith) {
// This creates a regular expression that ignores case (case-insensitive)
RegExp exp = RegExp(RegExp.escape(from), caseSensitive: false);
return original.replaceAll(exp, replaceWith);
}

GenerationOptions({
required Model model,
required AiPlatform ai,
required Character character,
required Session session,
}) {
try {
Logger.log(model.toMap().toString());
Logger.log(ai.toMap().toString());
Logger.log(character.toMap().toString());
Logger.log(session.toMap().toString());

Expand All @@ -129,53 +130,66 @@ class GenerationOptions {
_messages.addAll(session.getMessages());
}

_remoteUrl = model.parameters["remote_url"];
_promptFormat = model.format;
_apiType = model.apiType;
_apiKey = model.parameters["api_key"];
_remoteModel = model.parameters["remote_model"];
_path = model.parameters["path"];
_remoteUrl = ai.parameters["remote_url"];
_promptFormat = ai.format;
_apiType = ai.apiType;
_apiKey = ai.parameters["api_key"];
_remoteModel = ai.parameters["remote_model"];
_path = ai.parameters["path"];

_description = replaceCaseInsensitive(character.description, "{{char}}", character.name);
_description = replaceCaseInsensitive(_description, "<BOT>", character.name);
_description = replaceCaseInsensitive(_description, "{{user}}", session.userName);
_description = replaceCaseInsensitive(_description, "<USER>", session.userName);
_description = replaceCaseInsensitive(
character.description, "{{char}}", character.name);
_description =
replaceCaseInsensitive(_description, "<BOT>", character.name);
_description =
replaceCaseInsensitive(_description, "{{user}}", session.userName);
_description =
replaceCaseInsensitive(_description, "<USER>", session.userName);

_personality = replaceCaseInsensitive(character.personality, "{{char}}", character.name);
_personality = replaceCaseInsensitive(_personality, "<BOT>", character.name);
_personality = replaceCaseInsensitive(_personality, "{{user}}", session.userName);
_personality = replaceCaseInsensitive(_personality, "<USER>", session.userName);
_personality = replaceCaseInsensitive(
character.personality, "{{char}}", character.name);
_personality =
replaceCaseInsensitive(_personality, "<BOT>", character.name);
_personality =
replaceCaseInsensitive(_personality, "{{user}}", session.userName);
_personality =
replaceCaseInsensitive(_personality, "<USER>", session.userName);

_scenario = replaceCaseInsensitive(character.scenario, "{{char}}", character.name);
_scenario = replaceCaseInsensitive(
character.scenario, "{{char}}", character.name);
_scenario = replaceCaseInsensitive(_scenario, "<BOT>", character.name);
_scenario = replaceCaseInsensitive(_scenario, "{{user}}", session.userName);
_scenario =
replaceCaseInsensitive(_scenario, "{{user}}", session.userName);
_scenario = replaceCaseInsensitive(_scenario, "<USER>", session.userName);

_system = replaceCaseInsensitive(character.system, "{{char}}", character.name);
_system =
replaceCaseInsensitive(character.system, "{{char}}", character.name);
_system = replaceCaseInsensitive(_system, "<BOT>", character.name);
_system = replaceCaseInsensitive(_system, "{{user}}", session.userName);
_system = replaceCaseInsensitive(_system, "<USER>", session.userName);

_nKeep = model.parameters["n_keep"];
_seed = model.parameters["random_seed"] ? Random().nextInt(1000000) : model.parameters["seed"];
_nPredict = model.parameters["n_predict"];
_topK = model.parameters["top_k"];
_topP = model.parameters["top_p"];
_minP = model.parameters["min_p"];
_tfsZ = model.parameters["tfs_z"];
_typicalP = model.parameters["typical_p"];
_penaltyLastN = model.parameters["penalty_last_n"];
_temperature = model.parameters["temperature"];
_penaltyRepeat = model.parameters["penalty_repeat"];
_penaltyPresent = model.parameters["penalty_present"];
_penaltyFreq = model.parameters["penalty_freq"];
_mirostat = model.parameters["mirostat"];
_mirostatTau = model.parameters["mirostat_tau"];
_mirostatEta = model.parameters["mirostat_eta"];
_penalizeNewline = model.parameters["penalize_nl"];
_nCtx = model.parameters["n_ctx"];
_nBatch = model.parameters["n_batch"];
_nThread = model.parameters["n_threads"];
_nKeep = ai.parameters["n_keep"];
_seed = ai.parameters["random_seed"]
? Random().nextInt(1000000)
: ai.parameters["seed"];
_nPredict = ai.parameters["n_predict"];
_topK = ai.parameters["top_k"];
_topP = ai.parameters["top_p"];
_minP = ai.parameters["min_p"];
_tfsZ = ai.parameters["tfs_z"];
_typicalP = ai.parameters["typical_p"];
_penaltyLastN = ai.parameters["penalty_last_n"];
_temperature = ai.parameters["temperature"];
_penaltyRepeat = ai.parameters["penalty_repeat"];
_penaltyPresent = ai.parameters["penalty_present"];
_penaltyFreq = ai.parameters["penalty_freq"];
_mirostat = ai.parameters["mirostat"];
_mirostatTau = ai.parameters["mirostat_tau"];
_mirostatEta = ai.parameters["mirostat_eta"];
_penalizeNewline = ai.parameters["penalize_nl"];
_nCtx = ai.parameters["n_ctx"];
_nBatch = ai.parameters["n_batch"];
_nThread = ai.parameters["n_threads"];
} catch (e) {
Logger.log(e.toString());
}
Expand Down
30 changes: 11 additions & 19 deletions lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import 'dart:io';

import 'package:flutter/material.dart';
import 'package:maid/pages/home_page.dart';
import 'package:maid/providers/model.dart';
import 'package:maid/providers/ai_platform.dart';
import 'package:maid/providers/session.dart';
import 'package:maid/providers/character.dart';
import 'package:maid/static/themes.dart';
Expand All @@ -24,7 +24,7 @@ void main() {
MultiProvider(
providers: [
ChangeNotifierProvider(create: (context) => MainProvider()),
ChangeNotifierProvider(create: (context) => Model()),
ChangeNotifierProvider(create: (context) => AiPlatform()),
ChangeNotifierProvider(create: (context) => Character()),
ChangeNotifierProvider(create: (context) => Session()),
],
Expand Down Expand Up @@ -69,30 +69,22 @@ class MaidApp extends StatefulWidget {
class MaidAppState extends State<MaidApp> {
@override
Widget build(BuildContext context) {
return Consumer4<MainProvider, Model, Character, Session>(
builder: (
context,
mainProvider,
model,
character,
session,
child
) {
return Consumer4<MainProvider, AiPlatform, Character, Session>(
builder: (context, mainProvider, ai, character, session, child) {
if (!mainProvider.initialised) {
mainProvider.init();
model.init();
ai.init();
character.init();
session.init();
}

return MaterialApp(
debugShowCheckedModeBanner: false,
title: 'Maid',
theme: Themes.lightTheme(),
darkTheme: Themes.darkTheme(),
themeMode: mainProvider.themeMode,
home: const HomePage(title: "Maid")
);
debugShowCheckedModeBanner: false,
title: 'Maid',
theme: Themes.lightTheme(),
darkTheme: Themes.darkTheme(),
themeMode: mainProvider.themeMode,
home: const HomePage(title: "Maid"));
},
);
}
Expand Down
33 changes: 19 additions & 14 deletions lib/pages/character_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import 'package:maid/static/logger.dart';
import 'package:maid/widgets/dialogs.dart';
import 'package:maid/widgets/double_button_row.dart';
import 'package:maid/widgets/text_field_list_tile.dart';
import 'package:maid/widgets/toggleable_text_field_list_tile.dart';
import 'package:provider/provider.dart';
import 'package:shared_preferences/shared_preferences.dart';

Expand All @@ -34,7 +33,8 @@ class _CharacterPageState extends State<CharacterPage> {
super.initState();
WidgetsBinding.instance.addPostFrameCallback((_) async {
final prefs = await SharedPreferences.getInstance();
final loadedCharacters = json.decode(prefs.getString("characters") ?? "{}");
final loadedCharacters =
json.decode(prefs.getString("characters") ?? "{}");
_characters.addAll(loadedCharacters);
setState(() {});
});
Expand Down Expand Up @@ -298,7 +298,7 @@ class _CharacterPageState extends State<CharacterPage> {
String oldName = character.name;
Logger.log(
"Updating character $oldName ====> $value");
character.setName(value);
character.name = value;
_characters.remove(oldName);
}
},
Expand All @@ -312,7 +312,7 @@ class _CharacterPageState extends State<CharacterPage> {
labelText: 'Description',
controller: _descriptionController,
onChanged: (value) {
character.setDescription(value);
character.description = value;
},
multiline: true,
),
Expand All @@ -321,7 +321,7 @@ class _CharacterPageState extends State<CharacterPage> {
labelText: 'Personality',
controller: _personalityController,
onChanged: (value) {
character.setPersonality(value);
character.personality = value;
},
multiline: true,
),
Expand All @@ -330,7 +330,7 @@ class _CharacterPageState extends State<CharacterPage> {
labelText: 'Scenario',
controller: _scenarioController,
onChanged: (value) {
character.setScenario(value);
character.scenario = value;
},
multiline: true,
),
Expand All @@ -339,7 +339,7 @@ class _CharacterPageState extends State<CharacterPage> {
labelText: 'System Prompt',
controller: _systemController,
onChanged: (value) {
character.setSystem(value);
character.system = value;
},
multiline: true,
),
Expand All @@ -352,7 +352,7 @@ class _CharacterPageState extends State<CharacterPage> {
title: const Text('Use Greeting'),
value: character.useGreeting,
onChanged: (value) {
character.setUseGreeting(value);
character.useGreeting = value;
},
),
if (character.useGreeting) ...[
Expand Down Expand Up @@ -388,24 +388,28 @@ class _CharacterPageState extends State<CharacterPage> {
title: const Text('Use Examples'),
value: character.useExamples,
onChanged: (value) {
character.setUseExamples(value);
character.useExamples = value;
},
),
if (character.useExamples) ...[
DoubleButtonRow(
leftText: "Add Example",
// Adding a new example
leftOnPressed: () {
_exampleControllers.add(TextEditingController()); // For the user part
_exampleControllers.add(TextEditingController()); // For the assistant part
_exampleControllers.add(
TextEditingController()); // For the user part
_exampleControllers.add(
TextEditingController()); // For the assistant part
character.newExample();
},
rightText: "Remove Example",
// Removing the last example
rightOnPressed: () {
if (_exampleControllers.length >= 2) {
_exampleControllers.removeLast(); // Remove assistant part controller
_exampleControllers.removeLast(); // Remove user part controller
_exampleControllers
.removeLast(); // Remove assistant part controller
_exampleControllers
.removeLast(); // Remove user part controller
character.removeLastExample();
}
},
Expand All @@ -414,7 +418,8 @@ class _CharacterPageState extends State<CharacterPage> {
if (character.examples.isNotEmpty) ...[
for (int i = 0; i < character.examples.length; i++)
TextFieldListTile(
headingText:'${character.examples[i]["role"]} content',
headingText:
'${character.examples[i]["role"]} content',
labelText: character.examples[i]["role"],
controller: _exampleControllers[i],
onChanged: (value) {
Expand Down

0 comments on commit 96885ee

Please sign in to comment.