diff --git a/lib/classes/generation_options.dart b/lib/classes/generation_options.dart index 916c6045..e1fa8244 100644 --- a/lib/classes/generation_options.dart +++ b/lib/classes/generation_options.dart @@ -1,7 +1,7 @@ import 'dart:math'; import 'package:maid/providers/character.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/static/logger.dart'; import 'package:llama_cpp_dart/llama_cpp_dart.dart'; @@ -107,19 +107,20 @@ class GenerationOptions { return map; } - String replaceCaseInsensitive(String original, String from, String replaceWith) { + String replaceCaseInsensitive( + String original, String from, String replaceWith) { // This creates a regular expression that ignores case (case-insensitive) RegExp exp = RegExp(RegExp.escape(from), caseSensitive: false); return original.replaceAll(exp, replaceWith); } GenerationOptions({ - required Model model, + required AiPlatform ai, required Character character, required Session session, }) { try { - Logger.log(model.toMap().toString()); + Logger.log(ai.toMap().toString()); Logger.log(character.toMap().toString()); Logger.log(session.toMap().toString()); @@ -129,53 +130,66 @@ class GenerationOptions { _messages.addAll(session.getMessages()); } - _remoteUrl = model.parameters["remote_url"]; - _promptFormat = model.format; - _apiType = model.apiType; - _apiKey = model.parameters["api_key"]; - _remoteModel = model.parameters["remote_model"]; - _path = model.parameters["path"]; + _remoteUrl = ai.parameters["remote_url"]; + _promptFormat = ai.format; + _apiType = ai.apiType; + _apiKey = ai.parameters["api_key"]; + _remoteModel = ai.parameters["remote_model"]; + _path = ai.parameters["path"]; - _description = replaceCaseInsensitive(character.description, "{{char}}", character.name); - _description = replaceCaseInsensitive(_description, "", character.name); - _description = replaceCaseInsensitive(_description, "{{user}}", session.userName); - _description = replaceCaseInsensitive(_description, "", session.userName); + _description = replaceCaseInsensitive( + character.description, "{{char}}", character.name); + _description = + replaceCaseInsensitive(_description, "", character.name); + _description = + replaceCaseInsensitive(_description, "{{user}}", session.userName); + _description = + replaceCaseInsensitive(_description, "", session.userName); - _personality = replaceCaseInsensitive(character.personality, "{{char}}", character.name); - _personality = replaceCaseInsensitive(_personality, "", character.name); - _personality = replaceCaseInsensitive(_personality, "{{user}}", session.userName); - _personality = replaceCaseInsensitive(_personality, "", session.userName); + _personality = replaceCaseInsensitive( + character.personality, "{{char}}", character.name); + _personality = + replaceCaseInsensitive(_personality, "", character.name); + _personality = + replaceCaseInsensitive(_personality, "{{user}}", session.userName); + _personality = + replaceCaseInsensitive(_personality, "", session.userName); - _scenario = replaceCaseInsensitive(character.scenario, "{{char}}", character.name); + _scenario = replaceCaseInsensitive( + character.scenario, "{{char}}", character.name); _scenario = replaceCaseInsensitive(_scenario, "", character.name); - _scenario = replaceCaseInsensitive(_scenario, "{{user}}", session.userName); + _scenario = + replaceCaseInsensitive(_scenario, "{{user}}", session.userName); _scenario = replaceCaseInsensitive(_scenario, "", session.userName); - _system = replaceCaseInsensitive(character.system, "{{char}}", character.name); + _system = + replaceCaseInsensitive(character.system, "{{char}}", character.name); _system = replaceCaseInsensitive(_system, "", character.name); _system = replaceCaseInsensitive(_system, "{{user}}", session.userName); _system = replaceCaseInsensitive(_system, "", session.userName); - _nKeep = model.parameters["n_keep"]; - _seed = model.parameters["random_seed"] ? Random().nextInt(1000000) : model.parameters["seed"]; - _nPredict = model.parameters["n_predict"]; - _topK = model.parameters["top_k"]; - _topP = model.parameters["top_p"]; - _minP = model.parameters["min_p"]; - _tfsZ = model.parameters["tfs_z"]; - _typicalP = model.parameters["typical_p"]; - _penaltyLastN = model.parameters["penalty_last_n"]; - _temperature = model.parameters["temperature"]; - _penaltyRepeat = model.parameters["penalty_repeat"]; - _penaltyPresent = model.parameters["penalty_present"]; - _penaltyFreq = model.parameters["penalty_freq"]; - _mirostat = model.parameters["mirostat"]; - _mirostatTau = model.parameters["mirostat_tau"]; - _mirostatEta = model.parameters["mirostat_eta"]; - _penalizeNewline = model.parameters["penalize_nl"]; - _nCtx = model.parameters["n_ctx"]; - _nBatch = model.parameters["n_batch"]; - _nThread = model.parameters["n_threads"]; + _nKeep = ai.parameters["n_keep"]; + _seed = ai.parameters["random_seed"] + ? Random().nextInt(1000000) + : ai.parameters["seed"]; + _nPredict = ai.parameters["n_predict"]; + _topK = ai.parameters["top_k"]; + _topP = ai.parameters["top_p"]; + _minP = ai.parameters["min_p"]; + _tfsZ = ai.parameters["tfs_z"]; + _typicalP = ai.parameters["typical_p"]; + _penaltyLastN = ai.parameters["penalty_last_n"]; + _temperature = ai.parameters["temperature"]; + _penaltyRepeat = ai.parameters["penalty_repeat"]; + _penaltyPresent = ai.parameters["penalty_present"]; + _penaltyFreq = ai.parameters["penalty_freq"]; + _mirostat = ai.parameters["mirostat"]; + _mirostatTau = ai.parameters["mirostat_tau"]; + _mirostatEta = ai.parameters["mirostat_eta"]; + _penalizeNewline = ai.parameters["penalize_nl"]; + _nCtx = ai.parameters["n_ctx"]; + _nBatch = ai.parameters["n_batch"]; + _nThread = ai.parameters["n_threads"]; } catch (e) { Logger.log(e.toString()); } diff --git a/lib/main.dart b/lib/main.dart index 20b7da6f..e738f107 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -3,7 +3,7 @@ import 'dart:io'; import 'package:flutter/material.dart'; import 'package:maid/pages/home_page.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/providers/character.dart'; import 'package:maid/static/themes.dart'; @@ -24,7 +24,7 @@ void main() { MultiProvider( providers: [ ChangeNotifierProvider(create: (context) => MainProvider()), - ChangeNotifierProvider(create: (context) => Model()), + ChangeNotifierProvider(create: (context) => AiPlatform()), ChangeNotifierProvider(create: (context) => Character()), ChangeNotifierProvider(create: (context) => Session()), ], @@ -69,30 +69,22 @@ class MaidApp extends StatefulWidget { class MaidAppState extends State { @override Widget build(BuildContext context) { - return Consumer4( - builder: ( - context, - mainProvider, - model, - character, - session, - child - ) { + return Consumer4( + builder: (context, mainProvider, ai, character, session, child) { if (!mainProvider.initialised) { mainProvider.init(); - model.init(); + ai.init(); character.init(); session.init(); } return MaterialApp( - debugShowCheckedModeBanner: false, - title: 'Maid', - theme: Themes.lightTheme(), - darkTheme: Themes.darkTheme(), - themeMode: mainProvider.themeMode, - home: const HomePage(title: "Maid") - ); + debugShowCheckedModeBanner: false, + title: 'Maid', + theme: Themes.lightTheme(), + darkTheme: Themes.darkTheme(), + themeMode: mainProvider.themeMode, + home: const HomePage(title: "Maid")); }, ); } diff --git a/lib/pages/character_page.dart b/lib/pages/character_page.dart index 635deb5a..a0bde205 100644 --- a/lib/pages/character_page.dart +++ b/lib/pages/character_page.dart @@ -7,7 +7,6 @@ import 'package:maid/static/logger.dart'; import 'package:maid/widgets/dialogs.dart'; import 'package:maid/widgets/double_button_row.dart'; import 'package:maid/widgets/text_field_list_tile.dart'; -import 'package:maid/widgets/toggleable_text_field_list_tile.dart'; import 'package:provider/provider.dart'; import 'package:shared_preferences/shared_preferences.dart'; @@ -34,7 +33,8 @@ class _CharacterPageState extends State { super.initState(); WidgetsBinding.instance.addPostFrameCallback((_) async { final prefs = await SharedPreferences.getInstance(); - final loadedCharacters = json.decode(prefs.getString("characters") ?? "{}"); + final loadedCharacters = + json.decode(prefs.getString("characters") ?? "{}"); _characters.addAll(loadedCharacters); setState(() {}); }); @@ -298,7 +298,7 @@ class _CharacterPageState extends State { String oldName = character.name; Logger.log( "Updating character $oldName ====> $value"); - character.setName(value); + character.name = value; _characters.remove(oldName); } }, @@ -312,7 +312,7 @@ class _CharacterPageState extends State { labelText: 'Description', controller: _descriptionController, onChanged: (value) { - character.setDescription(value); + character.description = value; }, multiline: true, ), @@ -321,7 +321,7 @@ class _CharacterPageState extends State { labelText: 'Personality', controller: _personalityController, onChanged: (value) { - character.setPersonality(value); + character.personality = value; }, multiline: true, ), @@ -330,7 +330,7 @@ class _CharacterPageState extends State { labelText: 'Scenario', controller: _scenarioController, onChanged: (value) { - character.setScenario(value); + character.scenario = value; }, multiline: true, ), @@ -339,7 +339,7 @@ class _CharacterPageState extends State { labelText: 'System Prompt', controller: _systemController, onChanged: (value) { - character.setSystem(value); + character.system = value; }, multiline: true, ), @@ -352,7 +352,7 @@ class _CharacterPageState extends State { title: const Text('Use Greeting'), value: character.useGreeting, onChanged: (value) { - character.setUseGreeting(value); + character.useGreeting = value; }, ), if (character.useGreeting) ...[ @@ -388,7 +388,7 @@ class _CharacterPageState extends State { title: const Text('Use Examples'), value: character.useExamples, onChanged: (value) { - character.setUseExamples(value); + character.useExamples = value; }, ), if (character.useExamples) ...[ @@ -396,16 +396,20 @@ class _CharacterPageState extends State { leftText: "Add Example", // Adding a new example leftOnPressed: () { - _exampleControllers.add(TextEditingController()); // For the user part - _exampleControllers.add(TextEditingController()); // For the assistant part + _exampleControllers.add( + TextEditingController()); // For the user part + _exampleControllers.add( + TextEditingController()); // For the assistant part character.newExample(); }, rightText: "Remove Example", // Removing the last example rightOnPressed: () { if (_exampleControllers.length >= 2) { - _exampleControllers.removeLast(); // Remove assistant part controller - _exampleControllers.removeLast(); // Remove user part controller + _exampleControllers + .removeLast(); // Remove assistant part controller + _exampleControllers + .removeLast(); // Remove user part controller character.removeLastExample(); } }, @@ -414,7 +418,8 @@ class _CharacterPageState extends State { if (character.examples.isNotEmpty) ...[ for (int i = 0; i < character.examples.length; i++) TextFieldListTile( - headingText:'${character.examples[i]["role"]} content', + headingText: + '${character.examples[i]["role"]} content', labelText: character.examples[i]["role"], controller: _exampleControllers[i], onChanged: (value) { diff --git a/lib/pages/home_page.dart b/lib/pages/home_page.dart index 83d6afe7..6bd49836 100644 --- a/lib/pages/home_page.dart +++ b/lib/pages/home_page.dart @@ -3,7 +3,7 @@ import 'dart:math'; import 'package:flutter/material.dart'; import 'package:maid/pages/about_page.dart'; import 'package:maid/pages/character_page.dart'; -import 'package:maid/pages/model_page.dart'; +import 'package:maid/pages/platform_page.dart'; import 'package:maid/pages/sessions_page.dart'; import 'package:maid/pages/settings_page.dart'; import 'package:maid/providers/character.dart'; @@ -29,74 +29,63 @@ class HomePageState extends State { AppBar _buildAppBar(double aspectRatio) { if (aspectRatio < 0.9) { - return AppBar( - elevation: 0.0, - ); + return AppBar( + elevation: 0.0, + ); } return AppBar( backgroundColor: Theme.of(context).colorScheme.primary, - elevation: 0.0, - titleSpacing: 0, // Remove the default spacing - title: Row( - mainAxisAlignment: MainAxisAlignment.spaceEvenly, // Distribute the space evenly - children: [ - IconButton( - icon: const Icon(Icons.person), - onPressed: () { - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const CharacterPage() - ) - ); - }, - ), - IconButton( - icon: const Icon(Icons.account_tree_rounded), - onPressed: () { - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const ModelPage() - ) - ); - }, - ), - IconButton( - icon: const Icon(Icons.chat_rounded), - onPressed: () { - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const SessionsPage() - ) - ); - }, - ), - IconButton( - icon: const Icon(Icons.settings), - onPressed: () { - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const SettingsPage() - ) - ); - }, - ), - IconButton( - icon: const Icon(Icons.info), - onPressed: () { - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const AboutPage() - ) - ); - }, - ), - ], - ), + elevation: 0.0, + titleSpacing: 0, // Remove the default spacing + title: Row( + mainAxisAlignment: + MainAxisAlignment.spaceEvenly, // Distribute the space evenly + children: [ + IconButton( + icon: const Icon(Icons.person), + onPressed: () { + Navigator.push( + context, + MaterialPageRoute( + builder: (context) => const CharacterPage())); + }, + ), + IconButton( + icon: const Icon(Icons.account_tree_rounded), + onPressed: () { + Navigator.push( + context, + MaterialPageRoute( + builder: (context) => const PlatformPage())); + }, + ), + IconButton( + icon: const Icon(Icons.chat_rounded), + onPressed: () { + Navigator.push( + context, + MaterialPageRoute( + builder: (context) => const SessionsPage())); + }, + ), + IconButton( + icon: const Icon(Icons.settings), + onPressed: () { + Navigator.push( + context, + MaterialPageRoute( + builder: (context) => const SettingsPage())); + }, + ), + IconButton( + icon: const Icon(Icons.info), + onPressed: () { + Navigator.push(context, + MaterialPageRoute(builder: (context) => const AboutPage())); + }, + ), + ], + ), ); } @@ -127,17 +116,14 @@ class HomePageState extends State { ListTile( leading: Icon(Icons.person, color: Theme.of(context).colorScheme.onPrimary), - title: Text( - 'Character', - style: Theme.of(context).textTheme.labelLarge), + title: Text('Character', + style: Theme.of(context).textTheme.labelLarge), onTap: () { Navigator.pop(context); // Close the drawer Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const CharacterPage() - ) - ); + context, + MaterialPageRoute( + builder: (context) => const CharacterPage())); }, ), ListTile( @@ -150,11 +136,9 @@ class HomePageState extends State { onTap: () { Navigator.pop(context); // Close the drawer Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const ModelPage() - ) - ); + context, + MaterialPageRoute( + builder: (context) => const PlatformPage())); }, ), ListTile( @@ -167,11 +151,9 @@ class HomePageState extends State { onTap: () { Navigator.pop(context); // Close the drawer Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const SessionsPage() - ) - ); + context, + MaterialPageRoute( + builder: (context) => const SessionsPage())); }, ), ListTile( @@ -182,11 +164,9 @@ class HomePageState extends State { onTap: () { Navigator.pop(context); // Close the drawer Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const SettingsPage() - ) - ); + context, + MaterialPageRoute( + builder: (context) => const SettingsPage())); }, ), ListTile( @@ -196,12 +176,8 @@ class HomePageState extends State { Text('About', style: Theme.of(context).textTheme.labelLarge), onTap: () { Navigator.pop(context); // Close the drawer - Navigator.push( - context, - MaterialPageRoute( - builder: (context) => const AboutPage() - ) - ); + Navigator.push(context, + MaterialPageRoute(builder: (context) => const AboutPage())); }, ), ], @@ -226,7 +202,10 @@ class HomePageState extends State { if (history.isEmpty && character.useGreeting) { final newKey = UniqueKey(); final index = Random().nextInt(character.greetings.length); - session.add(newKey, message: character.greetings[index], userGenerated: false, notify: false); + session.add(newKey, + message: character.greetings[index], + userGenerated: false, + notify: false); history = {newKey: false}; } diff --git a/lib/pages/model_page.dart b/lib/pages/platform_page.dart similarity index 79% rename from lib/pages/model_page.dart rename to lib/pages/platform_page.dart index 38fde965..5a21c836 100644 --- a/lib/pages/model_page.dart +++ b/lib/pages/platform_page.dart @@ -3,7 +3,7 @@ import 'dart:convert'; import 'package:flutter/material.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/static/logger.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/platform_widgets/local_platform.dart'; import 'package:maid/widgets/platform_widgets/mistralai_platform.dart'; import 'package:maid/widgets/platform_widgets/ollama_platform.dart'; @@ -15,15 +15,15 @@ import 'package:maid/widgets/text_field_list_tile.dart'; import 'package:provider/provider.dart'; import 'package:shared_preferences/shared_preferences.dart'; -class ModelPage extends StatefulWidget { - const ModelPage({super.key}); +class PlatformPage extends StatefulWidget { + const PlatformPage({super.key}); @override - State createState() => _ModelPageState(); + State createState() => _PlatformPageState(); } -class _ModelPageState extends State { - late Map _models; +class _PlatformPageState extends State { + late Map _platforms; late TextEditingController _presetController; @override @@ -32,19 +32,19 @@ class _ModelPageState extends State { WidgetsBinding.instance.addPostFrameCallback((_) async { final prefs = await SharedPreferences.getInstance(); final loadedModels = json.decode(prefs.getString("models") ?? "{}"); - _models.addAll(loadedModels); + _platforms.addAll(loadedModels); setState(() {}); }); - final model = context.read(); - _models = {model.preset: model.toMap()}; - _presetController = TextEditingController(text: model.preset); + final ai = context.read(); + _platforms = {ai.preset: ai.toMap()}; + _presetController = TextEditingController(text: ai.preset); } @override void dispose() { SharedPreferences.getInstance().then((prefs) { - prefs.setString("models", json.encode(_models)); + prefs.setString("models", json.encode(_platforms)); }); super.dispose(); @@ -68,11 +68,11 @@ class _ModelPageState extends State { ), title: const Text("Model"), ), - body: Consumer(builder: (context, model, child) { - _models[model.preset] = model.toMap(); + body: Consumer(builder: (context, ai, child) { + _platforms[ai.preset] = ai.toMap(); SharedPreferences.getInstance().then((prefs) { - prefs.setString("last_model", json.encode(model.toMap())); + prefs.setString("last_model", json.encode(ai.toMap())); }); return Stack( @@ -82,7 +82,7 @@ class _ModelPageState extends State { children: [ const SizedBox(height: 10.0), Text( - model.preset, + ai.preset, textAlign: TextAlign.center, style: Theme.of(context).textTheme.titleLarge, ), @@ -92,8 +92,8 @@ class _ModelPageState extends State { showDialog( context: context, builder: (BuildContext context) { - return Consumer( - builder: (context, model, child) { + return Consumer( + builder: (context, ai, child) { return AlertDialog( title: const Text( "Switch Model", @@ -103,11 +103,11 @@ class _ModelPageState extends State { height: 200, width: 200, child: ListView.builder( - itemCount: _models.keys.length, + itemCount: _platforms.keys.length, itemBuilder: (BuildContext context, int index) { final item = - _models.keys.elementAt(index); + _platforms.keys.elementAt(index); return Padding( padding: const EdgeInsets.all(8.0), @@ -117,9 +117,9 @@ class _ModelPageState extends State { Container(color: Colors.red), onDismissed: (direction) { setState(() { - _models.remove(item); - if (model.preset == item) { - model.fromMap(_models.values + _platforms.remove(item); + if (ai.preset == item) { + ai.fromMap(_platforms.values .lastOrNull ?? {}); } @@ -130,7 +130,7 @@ class _ModelPageState extends State { }, child: Container( decoration: BoxDecoration( - color: model.preset == item + color: ai.preset == item ? Theme.of(context) .colorScheme .tertiary @@ -148,10 +148,10 @@ class _ModelPageState extends State { textAlign: TextAlign.center, ), onTap: () { - model - .fromMap(_models[item]); + ai.fromMap( + _platforms[item]); Logger.log( - "Model Set: ${model.preset}"); + "Model Set: ${ai.preset}"); Navigator.of(context).pop(); }, ), @@ -179,10 +179,10 @@ class _ModelPageState extends State { ), FilledButton( onPressed: () { - _models[model.preset] = model.toMap(); - model.newPreset(); - _models[model.preset] = model.toMap(); - model.notify(); + _platforms[ai.preset] = ai.toMap(); + ai.newPreset(); + _platforms[ai.preset] = ai.toMap(); + ai.notify(); }, child: Text( "New Preset", @@ -208,14 +208,14 @@ class _ModelPageState extends State { labelText: "Preset", controller: _presetController, onChanged: (value) { - if (_models.keys.contains(value)) { - model.fromMap(_models[value] ?? {}); - Logger.log("Model Set: ${model.preset}"); + if (_platforms.keys.contains(value)) { + ai.fromMap(_platforms[value] ?? {}); + Logger.log("Model Set: ${ai.preset}"); } else if (value.isNotEmpty) { - String oldName = model.preset; + String oldName = ai.preset; Logger.log("Updating model $oldName ====> $value"); - model.setPreset(value); - _models.remove(oldName); + ai.preset = value; + _platforms.remove(oldName); } }, ), @@ -230,23 +230,23 @@ class _ModelPageState extends State { leftText: "Load Parameters", leftOnPressed: () async { await storageOperationDialog( - context, model.importModelParameters); + context, ai.importModelParameters); setState(() { - _presetController.text = model.preset; + _presetController.text = ai.preset; }); }, rightText: "Save Parameters", rightOnPressed: () async { await storageOperationDialog( - context, model.exportModelParameters); + context, ai.exportModelParameters); }, ), const SizedBox(height: 15.0), FilledButton( onPressed: () { - model.resetAll(); + ai.resetAll(); setState(() { - _presetController.text = model.preset; + _presetController.text = ai.preset; }); }, child: Text( @@ -261,13 +261,13 @@ class _ModelPageState extends State { color: Theme.of(context).colorScheme.primary, ), const ApiDropdown(), - if (model.apiType == ApiType.local) + if (ai.apiType == ApiType.local) const LocalPlatform() - else if (model.apiType == ApiType.ollama) + else if (ai.apiType == ApiType.ollama) const OllamaPlatform() - else if (model.apiType == ApiType.openAI) + else if (ai.apiType == ApiType.openAI) const OpenAiPlatform() - else if (model.apiType == ApiType.mistralAI) + else if (ai.apiType == ApiType.mistralAI) const MistralAiPlatform(), ], ), diff --git a/lib/providers/model.dart b/lib/providers/ai_platform.dart similarity index 94% rename from lib/providers/model.dart rename to lib/providers/ai_platform.dart index 591e891c..9c2c6e95 100644 --- a/lib/providers/model.dart +++ b/lib/providers/ai_platform.dart @@ -10,10 +10,11 @@ import 'package:maid/static/logger.dart'; import 'package:shared_preferences/shared_preferences.dart'; import 'package:llama_cpp_dart/llama_cpp_dart.dart'; -class Model extends ChangeNotifier { +class AiPlatform extends ChangeNotifier { PromptFormatType _format = PromptFormatType.alpaca; ApiType _apiType = ApiType.local; String _preset = "Default"; + String _model = ""; Map _parameters = {}; void newPreset() { @@ -42,22 +43,27 @@ class Model extends ChangeNotifier { } } - void setPreset(String preset) { + set preset(String preset) { _preset = preset; notifyListeners(); } + set model(String model) { + _model = model; + notifyListeners(); + } + void setParameter(String key, dynamic value) { _parameters[key] = value; notifyListeners(); } - void setPromptFormat(PromptFormatType promptFormat) { + set promptFormat(PromptFormatType promptFormat) { _format = promptFormat; notifyListeners(); } - void setApiType(ApiType apiType) { + set apiType(ApiType apiType) { _apiType = apiType; notifyListeners(); } @@ -65,6 +71,7 @@ class Model extends ChangeNotifier { PromptFormatType get format => _format; ApiType get apiType => _apiType; String get preset => _preset; + String get model => _model; Map get parameters => _parameters; Future> getOptions() { diff --git a/lib/providers/character.dart b/lib/providers/character.dart index f0b00022..c40b3bef 100644 --- a/lib/providers/character.dart +++ b/lib/providers/character.dart @@ -92,7 +92,6 @@ class Character extends ChangeNotifier { } } - _system = inputJson["system_prompt"] ?? ""; _useExamples = inputJson["use_examples"] ?? true; @@ -130,27 +129,32 @@ class Character extends ChangeNotifier { return jsonCharacter; } - void setName(String newName) { + set name(String newName) { _name = newName; notifyListeners(); } - void setDescription(String newDescription) { + set description(String newDescription) { _description = newDescription; notifyListeners(); } - void setPersonality(String newPersonality) { + set personality(String newPersonality) { _personality = newPersonality; notifyListeners(); } - void setScenario(String newScenario) { + set scenario(String newScenario) { _scenario = newScenario; notifyListeners(); } - void setUseGreeting(bool useGreeting) { + set system(String newSystem) { + _system = newSystem; + notifyListeners(); + } + + set useGreeting(bool useGreeting) { _useGreeting = useGreeting; notifyListeners(); } @@ -175,12 +179,7 @@ class Character extends ChangeNotifier { notifyListeners(); } - void setSystem(String newSystem) { - _system = newSystem; - notifyListeners(); - } - - void setUseExamples(bool useExamples) { + set useExamples(bool useExamples) { _useExamples = useExamples; notifyListeners(); } @@ -334,16 +333,18 @@ class Character extends ChangeNotifier { _description = image.textData!["description"] ?? ""; _personality = image.textData!["personality"] ?? ""; _scenario = image.textData!["scenario"] ?? ""; - + if (image.textData!["greetings"] != null) { - _greetings = List.from(json.decode(image.textData!["greetings"] ?? "[]")); + _greetings = List.from( + json.decode(image.textData!["greetings"] ?? "[]")); } else { if (image.textData!["first_mes"] != null) { _greetings = [image.textData!["first_mes"] ?? ""]; } - + if (image.textData!["alternate_greetings"] != null) { - _greetings.addAll(List.from(json.decode(image.textData!["alternate_greetings"] ?? "[]"))); + _greetings.addAll(List.from( + json.decode(image.textData!["alternate_greetings"] ?? "[]"))); } } diff --git a/lib/static/generation_manager.dart b/lib/static/generation_manager.dart index d600e776..8cbeb9ff 100644 --- a/lib/static/generation_manager.dart +++ b/lib/static/generation_manager.dart @@ -3,7 +3,7 @@ import 'package:maid/providers/character.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/static/remote_generation.dart'; import 'package:maid/static/local_generation.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/classes/generation_options.dart'; import 'package:provider/provider.dart'; @@ -12,10 +12,10 @@ class GenerationManager { context.read().busy = true; GenerationOptions options = GenerationOptions( - model: context.read(), - character: context.read(), - session: context.read()); - + ai: context.read(), + character: context.read(), + session: context.read()); + if (options.apiType == ApiType.local) { LocalGeneration.prompt(input, options, context.read().stream); } else { diff --git a/lib/static/remote_generation.dart b/lib/static/remote_generation.dart index 943883a5..a1fe7ad0 100644 --- a/lib/static/remote_generation.dart +++ b/lib/static/remote_generation.dart @@ -3,7 +3,7 @@ import 'dart:io'; import 'package:http/http.dart'; import 'package:maid/classes/generation_options.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/static/logger.dart'; import 'package:permission_handler/permission_handler.dart'; import 'package:device_info_plus/device_info_plus.dart'; @@ -158,15 +158,15 @@ class RemoteGeneration { } } - static Future> getOptions(Model model) async { - switch (model.apiType) { + static Future> getOptions(AiPlatform ai) async { + switch (ai.apiType) { case ApiType.ollama: bool permissionGranted = await _requestPermission(); if (!permissionGranted) { return []; } - final url = Uri.parse("${model.parameters["remote_url"]}/api/tags"); + final url = Uri.parse("${ai.parameters["remote_url"]}/api/tags"); final headers = {"Accept": "application/json"}; try { diff --git a/lib/widgets/chat_widgets/chat_field.dart b/lib/widgets/chat_widgets/chat_field.dart index 26954451..fa8934a1 100644 --- a/lib/widgets/chat_widgets/chat_field.dart +++ b/lib/widgets/chat_widgets/chat_field.dart @@ -3,7 +3,7 @@ import 'dart:async'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/providers/session.dart'; import 'package:maid/static/generation_manager.dart'; import 'package:maid/static/logger.dart'; @@ -82,7 +82,7 @@ class _ChatFieldState extends State { child: Row( children: [ if (session.isBusy && - context.read().apiType != ApiType.ollama) + context.read().apiType != ApiType.ollama) const IconButton( onPressed: GenerationManager.stop, iconSize: 50, diff --git a/lib/widgets/dropdowns/api_dropdown.dart b/lib/widgets/dropdowns/api_dropdown.dart index fb20375f..182a6067 100644 --- a/lib/widgets/dropdowns/api_dropdown.dart +++ b/lib/widgets/dropdowns/api_dropdown.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:provider/provider.dart'; class ApiDropdown extends StatelessWidget { @@ -7,44 +7,42 @@ class ApiDropdown extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer( - builder: (context, model, child) { - return ListTile( - title: Row( - children: [ - const Expanded( - child: Text("API Type"), - ), - DropdownMenu( - dropdownMenuEntries: const [ - DropdownMenuEntry( - value: ApiType.local, - label: "Local", - ), - DropdownMenuEntry( - value: ApiType.ollama, - label: "Ollama", - ), - DropdownMenuEntry( - value: ApiType.openAI, - label: "OpenAI", - ), - DropdownMenuEntry( - value: ApiType.mistralAI, - label: "MistralAI", - ), - ], - onSelected: (ApiType? value) { - if (value != null) { - model.setApiType(value); - } - }, - initialSelection: model.apiType, - width: 200, - ) - ], - ) - ); + return Consumer(builder: (context, ai, child) { + return ListTile( + title: Row( + children: [ + const Expanded( + child: Text("API Type"), + ), + DropdownMenu( + dropdownMenuEntries: const [ + DropdownMenuEntry( + value: ApiType.local, + label: "Local", + ), + DropdownMenuEntry( + value: ApiType.ollama, + label: "Ollama", + ), + DropdownMenuEntry( + value: ApiType.openAI, + label: "OpenAI", + ), + DropdownMenuEntry( + value: ApiType.mistralAI, + label: "MistralAI", + ), + ], + onSelected: (ApiType? value) { + if (value != null) { + ai.apiType = value; + } + }, + initialSelection: ai.apiType, + width: 200, + ) + ], + )); }); } -} \ No newline at end of file +} diff --git a/lib/widgets/dropdowns/format_dropdown.dart b/lib/widgets/dropdowns/format_dropdown.dart index 80182dd8..f7f046ca 100644 --- a/lib/widgets/dropdowns/format_dropdown.dart +++ b/lib/widgets/dropdowns/format_dropdown.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:llama_cpp_dart/llama_cpp_dart.dart'; import 'package:provider/provider.dart'; @@ -8,7 +8,7 @@ class FormatDropdown extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return ListTile( title: Row( children: [ @@ -32,10 +32,10 @@ class FormatDropdown extends StatelessWidget { ], onSelected: (PromptFormatType? value) { if (value != null) { - model.setPromptFormat(value); + ai.promptFormat = value; } }, - initialSelection: model.format, + initialSelection: ai.format, width: 200, ) ], diff --git a/lib/widgets/dropdowns/model_dropdown.dart b/lib/widgets/dropdowns/model_dropdown.dart index 9b440c6f..de05a48e 100644 --- a/lib/widgets/dropdowns/model_dropdown.dart +++ b/lib/widgets/dropdowns/model_dropdown.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:provider/provider.dart'; class ModelDropdown extends StatelessWidget { @@ -7,43 +7,42 @@ class ModelDropdown extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer( - builder: (context, model, child) { - return ListTile( - title: Row( - children: [ - const Expanded( - child: Text("Remote Model"), - ), - FutureBuilder>( - future: model.getOptions(), - builder: (BuildContext context, AsyncSnapshot> snapshot) { - if (snapshot.data == null) { - return const SizedBox(height: 8.0); - } - - List> dropdownEntries = snapshot.data! + return Consumer(builder: (context, ai, child) { + return ListTile( + title: Row( + children: [ + const Expanded( + child: Text("Remote Model"), + ), + FutureBuilder>( + future: ai.getOptions(), + builder: + (BuildContext context, AsyncSnapshot> snapshot) { + if (snapshot.data == null) { + return const SizedBox(height: 8.0); + } + + List> dropdownEntries = snapshot.data! .map((String modelName) => DropdownMenuEntry( value: modelName, label: modelName, )) .toList(); - - return DropdownMenu( - dropdownMenuEntries: dropdownEntries, - onSelected: (String? value) { - if (value != null) { - model.setParameter("remote_model", value); - } - }, - initialSelection: model.parameters["remote_model"] ?? "", - width: 200, - ); - }, - ), - ], - ) - ); + + return DropdownMenu( + dropdownMenuEntries: dropdownEntries, + onSelected: (String? value) { + if (value != null) { + ai.model = value; + } + }, + initialSelection: ai.model, + width: 200, + ); + }, + ), + ], + )); }); } -} \ No newline at end of file +} diff --git a/lib/widgets/parameter_widgets/boolean_parameter.dart b/lib/widgets/parameter_widgets/boolean_parameter.dart index 1c80eccf..10dba403 100644 --- a/lib/widgets/parameter_widgets/boolean_parameter.dart +++ b/lib/widgets/parameter_widgets/boolean_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:provider/provider.dart'; class BooleanParameter extends StatelessWidget { @@ -11,12 +11,12 @@ class BooleanParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SwitchListTile( title: Text(title), - value: model.parameters[parameter] ?? false, + value: ai.parameters[parameter] ?? false, onChanged: (value) { - model.setParameter(parameter, value); + ai.setParameter(parameter, value); }, ); }); diff --git a/lib/widgets/parameter_widgets/mirostat_eta_parameter.dart b/lib/widgets/parameter_widgets/mirostat_eta_parameter.dart index 9f5d6d50..42c69909 100644 --- a/lib/widgets/parameter_widgets/mirostat_eta_parameter.dart +++ b/lib/widgets/parameter_widgets/mirostat_eta_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class MirostatEtaParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'mirostat_eta', - inputValue: model.parameters["mirostat_eta"] ?? 0.1, + inputValue: ai.parameters["mirostat_eta"] ?? 0.1, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("mirostat_eta", value); + ai.setParameter("mirostat_eta", value); }); }); } diff --git a/lib/widgets/parameter_widgets/mirostat_parameter.dart b/lib/widgets/parameter_widgets/mirostat_parameter.dart index 41c79f7f..d38b12b8 100644 --- a/lib/widgets/parameter_widgets/mirostat_parameter.dart +++ b/lib/widgets/parameter_widgets/mirostat_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class MirostatParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'mirostat', - inputValue: model.parameters["mirostat"] ?? 0.0, + inputValue: ai.parameters["mirostat"] ?? 0.0, sliderMin: 0.0, sliderMax: 128.0, sliderDivisions: 127, onValueChanged: (value) { - model.setParameter("mirostat", value.round()); + ai.setParameter("mirostat", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/mirostat_tau_parameter.dart b/lib/widgets/parameter_widgets/mirostat_tau_parameter.dart index 1c3a10b1..bdbfbb41 100644 --- a/lib/widgets/parameter_widgets/mirostat_tau_parameter.dart +++ b/lib/widgets/parameter_widgets/mirostat_tau_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class MirostatTauParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'mirostat_tau', - inputValue: model.parameters["mirostat_tau"] ?? 5.0, + inputValue: ai.parameters["mirostat_tau"] ?? 5.0, sliderMin: 0.0, sliderMax: 10.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("mirostat_tau", value); + ai.setParameter("mirostat_tau", value); }); }); } diff --git a/lib/widgets/parameter_widgets/n_batch_parameter.dart b/lib/widgets/parameter_widgets/n_batch_parameter.dart index aa5a9ff6..0aba98bb 100644 --- a/lib/widgets/parameter_widgets/n_batch_parameter.dart +++ b/lib/widgets/parameter_widgets/n_batch_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class NBatchParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'n_batch', - inputValue: model.parameters["n_batch"] ?? 512, + inputValue: ai.parameters["n_batch"] ?? 512, sliderMin: 1.0, sliderMax: 4096.0, sliderDivisions: 4095, onValueChanged: (value) { - model.setParameter("n_batch", value.round()); + ai.setParameter("n_batch", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/n_ctx_parameter.dart b/lib/widgets/parameter_widgets/n_ctx_parameter.dart index ac61be85..dfabae07 100644 --- a/lib/widgets/parameter_widgets/n_ctx_parameter.dart +++ b/lib/widgets/parameter_widgets/n_ctx_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class NCtxParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'n_ctx', - inputValue: model.parameters["n_ctx"] ?? 512, + inputValue: ai.parameters["n_ctx"] ?? 512, sliderMin: 1.0, sliderMax: 4096.0, sliderDivisions: 4095, onValueChanged: (value) { - model.setParameter("n_ctx", value.round()); + ai.setParameter("n_ctx", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/n_keep_parameter.dart b/lib/widgets/parameter_widgets/n_keep_parameter.dart index 8a1f15bb..594af3af 100644 --- a/lib/widgets/parameter_widgets/n_keep_parameter.dart +++ b/lib/widgets/parameter_widgets/n_keep_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class NKeepParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'n_keep', - inputValue: model.parameters["n_keep"] ?? 48, + inputValue: ai.parameters["n_keep"] ?? 48, sliderMin: 1.0, sliderMax: 1024.0, sliderDivisions: 1023, onValueChanged: (value) { - model.setParameter("n_keep", value.round()); + ai.setParameter("n_keep", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/n_predict_parameter.dart b/lib/widgets/parameter_widgets/n_predict_parameter.dart index 42977d3c..0cb648b8 100644 --- a/lib/widgets/parameter_widgets/n_predict_parameter.dart +++ b/lib/widgets/parameter_widgets/n_predict_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class NPredictParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'n_predict', - inputValue: model.parameters["n_predict"] ?? 512, + inputValue: ai.parameters["n_predict"] ?? 512, sliderMin: 1.0, sliderMax: 1024.0, sliderDivisions: 1023, onValueChanged: (value) { - model.setParameter("n_predict", value.round()); + ai.setParameter("n_predict", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/n_threads_parameter.dart b/lib/widgets/parameter_widgets/n_threads_parameter.dart index 561e63fd..d6d0d45e 100644 --- a/lib/widgets/parameter_widgets/n_threads_parameter.dart +++ b/lib/widgets/parameter_widgets/n_threads_parameter.dart @@ -1,7 +1,7 @@ import 'dart:io'; import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -10,20 +10,19 @@ class NThreadsParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'n_threads', - inputValue: - model.parameters["n_threads"] ?? Platform.numberOfProcessors, + inputValue: ai.parameters["n_threads"] ?? Platform.numberOfProcessors, sliderMin: 1.0, - sliderMax: model.apiType == ApiType.local + sliderMax: ai.apiType == ApiType.local ? Platform.numberOfProcessors.toDouble() : 128.0, sliderDivisions: 127, onValueChanged: (value) { - model.setParameter("n_threads", value.round()); - if (model.parameters["n_threads"] > Platform.numberOfProcessors) { - model.setParameter("n_threads", Platform.numberOfProcessors); + ai.setParameter("n_threads", value.round()); + if (ai.parameters["n_threads"] > Platform.numberOfProcessors) { + ai.setParameter("n_threads", Platform.numberOfProcessors); } }); }); diff --git a/lib/widgets/parameter_widgets/penalty_frequency_parameter.dart b/lib/widgets/parameter_widgets/penalty_frequency_parameter.dart index f2ad0dec..2f42fcb1 100644 --- a/lib/widgets/parameter_widgets/penalty_frequency_parameter.dart +++ b/lib/widgets/parameter_widgets/penalty_frequency_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class PenaltyFrequencyParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'penalty_freq', - inputValue: model.parameters["penalty_freq"] ?? 0.0, + inputValue: ai.parameters["penalty_freq"] ?? 0.0, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("penalty_freq", value); + ai.setParameter("penalty_freq", value); }); }); } diff --git a/lib/widgets/parameter_widgets/penalty_last_n_parameter.dart b/lib/widgets/parameter_widgets/penalty_last_n_parameter.dart index d69a80c5..b13a8fd0 100644 --- a/lib/widgets/parameter_widgets/penalty_last_n_parameter.dart +++ b/lib/widgets/parameter_widgets/penalty_last_n_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class PenaltyLastNParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'penalty_last_n', - inputValue: model.parameters["penalty_last_n"] ?? 64, + inputValue: ai.parameters["penalty_last_n"] ?? 64, sliderMin: 0.0, sliderMax: 128.0, sliderDivisions: 127, onValueChanged: (value) { - model.setParameter("penalty_last_n", value.round()); + ai.setParameter("penalty_last_n", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/penalty_present_parameter.dart b/lib/widgets/parameter_widgets/penalty_present_parameter.dart index e156a757..e9ec97ef 100644 --- a/lib/widgets/parameter_widgets/penalty_present_parameter.dart +++ b/lib/widgets/parameter_widgets/penalty_present_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class PenaltyPresentParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'penalty_present', - inputValue: model.parameters["penalty_present"] ?? 0.0, + inputValue: ai.parameters["penalty_present"] ?? 0.0, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("penalty_present", value); + ai.setParameter("penalty_present", value); }); }); } diff --git a/lib/widgets/parameter_widgets/penalty_repeat_parameter.dart b/lib/widgets/parameter_widgets/penalty_repeat_parameter.dart index 83d55a18..725c0dd4 100644 --- a/lib/widgets/parameter_widgets/penalty_repeat_parameter.dart +++ b/lib/widgets/parameter_widgets/penalty_repeat_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class PenaltyRepeatParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'penalty_repeat', - inputValue: model.parameters["penalty_repeat"] ?? 1.1, + inputValue: ai.parameters["penalty_repeat"] ?? 1.1, sliderMin: 0.0, sliderMax: 2.0, sliderDivisions: 200, onValueChanged: (value) { - model.setParameter("penalty_repeat", value); + ai.setParameter("penalty_repeat", value); }); }); } diff --git a/lib/widgets/parameter_widgets/seed_parameter.dart b/lib/widgets/parameter_widgets/seed_parameter.dart index 929d81aa..79a2fa91 100644 --- a/lib/widgets/parameter_widgets/seed_parameter.dart +++ b/lib/widgets/parameter_widgets/seed_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:provider/provider.dart'; class SeedParameter extends StatelessWidget { @@ -8,16 +8,16 @@ class SeedParameter extends StatelessWidget { @override Widget build(BuildContext context) { TextEditingController controller = TextEditingController( - text: context.read().parameters["seed"]?.toString() ?? ""); + text: context.read().parameters["seed"]?.toString() ?? ""); - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return Column( children: [ SwitchListTile( title: const Text('random_seed'), - value: model.parameters["random_seed"] ?? true, + value: ai.parameters["random_seed"] ?? true, onChanged: (value) { - model.setParameter("random_seed", value); + ai.setParameter("random_seed", value); }, ), Divider( @@ -26,7 +26,7 @@ class SeedParameter extends StatelessWidget { endIndent: 10, color: Theme.of(context).colorScheme.primary, ), - if (!(model.parameters["random_seed"] ?? true)) + if (!(ai.parameters["random_seed"] ?? true)) ListTile( title: Row( children: [ @@ -41,7 +41,7 @@ class SeedParameter extends StatelessWidget { labelText: 'seed', ), onChanged: (value) { - model.setParameter("seed", int.parse(value)); + ai.setParameter("seed", int.parse(value)); }, ), ), diff --git a/lib/widgets/parameter_widgets/string_parameter.dart b/lib/widgets/parameter_widgets/string_parameter.dart index 198bc393..2b37d7f5 100644 --- a/lib/widgets/parameter_widgets/string_parameter.dart +++ b/lib/widgets/parameter_widgets/string_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/text_field_list_tile.dart'; import 'package:provider/provider.dart'; @@ -12,14 +12,14 @@ class StringParameter extends StatelessWidget { @override Widget build(BuildContext context) { - _controller.text = context.read().parameters[parameter] ?? ""; + _controller.text = context.read().parameters[parameter] ?? ""; return TextFieldListTile( headingText: title, labelText: title, controller: _controller, onChanged: (value) { - context.read().setParameter(parameter, value); + context.read().setParameter(parameter, value); }); } } diff --git a/lib/widgets/parameter_widgets/temperature_parameter.dart b/lib/widgets/parameter_widgets/temperature_parameter.dart index 5e8bac24..78b06cd7 100644 --- a/lib/widgets/parameter_widgets/temperature_parameter.dart +++ b/lib/widgets/parameter_widgets/temperature_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class TemperatureParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'temperature', - inputValue: model.parameters["temperature"] ?? 0.8, + inputValue: ai.parameters["temperature"] ?? 0.8, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("temperature", value); + ai.setParameter("temperature", value); }); }); } diff --git a/lib/widgets/parameter_widgets/tfs_z_parameter.dart b/lib/widgets/parameter_widgets/tfs_z_parameter.dart index 825009b0..abf62e6c 100644 --- a/lib/widgets/parameter_widgets/tfs_z_parameter.dart +++ b/lib/widgets/parameter_widgets/tfs_z_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class TfsZParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'tfs_z', - inputValue: model.parameters["tfs_z"] ?? 1.0, + inputValue: ai.parameters["tfs_z"] ?? 1.0, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("tfs_z", value); + ai.setParameter("tfs_z", value); }); }); } diff --git a/lib/widgets/parameter_widgets/top_k_parameter.dart b/lib/widgets/parameter_widgets/top_k_parameter.dart index a68f75e4..0ee68bda 100644 --- a/lib/widgets/parameter_widgets/top_k_parameter.dart +++ b/lib/widgets/parameter_widgets/top_k_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class TopKParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'top_k', - inputValue: model.parameters["top_k"] ?? 40, + inputValue: ai.parameters["top_k"] ?? 40, sliderMin: 1.0, sliderMax: 128.0, sliderDivisions: 127, onValueChanged: (value) { - model.setParameter("top_k", value.round()); + ai.setParameter("top_k", value.round()); }); }); } diff --git a/lib/widgets/parameter_widgets/top_p_parameter.dart b/lib/widgets/parameter_widgets/top_p_parameter.dart index 0cee326d..0945b287 100644 --- a/lib/widgets/parameter_widgets/top_p_parameter.dart +++ b/lib/widgets/parameter_widgets/top_p_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class TopPParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'top_p', - inputValue: model.parameters["top_p"] ?? 0.95, + inputValue: ai.parameters["top_p"] ?? 0.95, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("top_p", value); + ai.setParameter("top_p", value); }); }); } diff --git a/lib/widgets/parameter_widgets/typical_p_parameter.dart b/lib/widgets/parameter_widgets/typical_p_parameter.dart index 69bbdaca..44284636 100644 --- a/lib/widgets/parameter_widgets/typical_p_parameter.dart +++ b/lib/widgets/parameter_widgets/typical_p_parameter.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/slider_list_tile.dart'; import 'package:provider/provider.dart'; @@ -8,15 +8,15 @@ class TypicalPParameter extends StatelessWidget { @override Widget build(BuildContext context) { - return Consumer(builder: (context, model, child) { + return Consumer(builder: (context, ai, child) { return SliderListTile( labelText: 'typical_p', - inputValue: model.parameters["typical_p"] ?? 1.0, + inputValue: ai.parameters["typical_p"] ?? 1.0, sliderMin: 0.0, sliderMax: 1.0, sliderDivisions: 100, onValueChanged: (value) { - model.setParameter("typical_p", value); + ai.setParameter("typical_p", value); }); }); } diff --git a/lib/widgets/platform_widgets/local_platform.dart b/lib/widgets/platform_widgets/local_platform.dart index 9759ea25..22545613 100644 --- a/lib/widgets/platform_widgets/local_platform.dart +++ b/lib/widgets/platform_widgets/local_platform.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:maid/providers/model.dart'; +import 'package:maid/providers/ai_platform.dart'; import 'package:maid/widgets/dialogs.dart'; import 'package:maid/widgets/parameter_widgets/boolean_parameter.dart'; import 'package:maid/widgets/parameter_widgets/mirostat_eta_parameter.dart'; @@ -43,7 +43,7 @@ class LocalPlatform extends StatelessWidget { Expanded( flex: 2, child: Text( - context.watch().parameters["path"] ?? "None", + context.watch().parameters["path"] ?? "None", textAlign: TextAlign.end, ), ), @@ -55,11 +55,11 @@ class LocalPlatform extends StatelessWidget { leftText: "Load GGUF", leftOnPressed: () { storageOperationDialog( - context, context.read().loadModelFile); + context, context.read().loadModelFile); }, rightText: "Unload GGUF", rightOnPressed: () { - context.read().setParameter("path", ""); + context.read().setParameter("path", ""); }), Divider( height: 20, diff --git a/pubspec.lock b/pubspec.lock index 7776ead1..5ce73369 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -440,30 +440,6 @@ packages: url: "https://pub.dev" source: hosted version: "0.3.1" - leak_tracker: - dependency: transitive - description: - name: leak_tracker - sha256: "78eb209deea09858f5269f5a5b02be4049535f568c07b275096836f01ea323fa" - url: "https://pub.dev" - source: hosted - version: "10.0.0" - leak_tracker_flutter_testing: - dependency: transitive - description: - name: leak_tracker_flutter_testing - sha256: b46c5e37c19120a8a01918cfaf293547f47269f7cb4b0058f21531c2465d6ef0 - url: "https://pub.dev" - source: hosted - version: "2.0.1" - leak_tracker_testing: - dependency: transitive - description: - name: leak_tracker_testing - sha256: a597f72a664dbd293f3bfc51f9ba69816f84dcd403cdac7066cb3f6003f3ab47 - url: "https://pub.dev" - source: hosted - version: "2.0.1" linkify: dependency: "direct main" description: @@ -499,18 +475,18 @@ packages: dependency: transitive description: name: matcher - sha256: d2323aa2060500f906aa31a895b4030b6da3ebdcc5619d14ce1aada65cd161cb + sha256: "1803e76e6653768d64ed8ff2e1e67bea3ad4b923eb5c56a295c3e634bad5960e" url: "https://pub.dev" source: hosted - version: "0.12.16+1" + version: "0.12.16" material_color_utilities: dependency: transitive description: name: material_color_utilities - sha256: "0e0a020085b65b6083975e499759762399b4475f766c21668c4ecca34ea74e5a" + sha256: "9528f2f296073ff54cb9fee677df673ace1218163c3bc7628093e7eed5203d41" url: "https://pub.dev" source: hosted - version: "0.8.0" + version: "0.5.0" math_expressions: dependency: transitive description: @@ -531,10 +507,10 @@ packages: dependency: transitive description: name: meta - sha256: d584fa6707a52763a52446f02cc621b077888fb63b93bbcb1143a7be5a0c0c04 + sha256: a6e590c838b18133bb482a2745ad77c5bb7715fb0451209e1a7567d416678b8e url: "https://pub.dev" source: hosted - version: "1.11.0" + version: "1.10.0" mistralai_dart: dependency: transitive description: @@ -587,10 +563,10 @@ packages: dependency: "direct main" description: name: path - sha256: "087ce49c3f0dc39180befefc60fdb4acd8f8620e5682fe2476afd0b3688bb4af" + sha256: "8829d8a55c13fc0e37127c29fedf290c102f4e40ae94ada574091fe0ff96c917" url: "https://pub.dev" source: hosted - version: "1.9.0" + version: "1.8.3" path_provider: dependency: "direct main" description: @@ -988,14 +964,6 @@ packages: url: "https://pub.dev" source: hosted version: "0.0.7+2" - vm_service: - dependency: transitive - description: - name: vm_service - sha256: b3d56ff4341b8f182b96aceb2fa20e3dcb336b9f867bc0eafc0de10f1048e957 - url: "https://pub.dev" - source: hosted - version: "13.0.0" web: dependency: transitive description: