diff --git a/lib/classes/generation_options.dart b/lib/classes/generation_options.dart index e1fa8244..b5f6ddd3 100644 --- a/lib/classes/generation_options.dart +++ b/lib/classes/generation_options.dart @@ -12,8 +12,7 @@ class GenerationOptions { late PromptFormatType _promptFormat; late ApiType _apiType; late String? _apiKey; - late String? _remoteModel; - late String? _path; + late String _model; late String _description; late String _personality; late String _scenario; @@ -44,8 +43,7 @@ class GenerationOptions { PromptFormatType get promptFormat => _promptFormat; ApiType get apiType => _apiType; String? get apiKey => _apiKey; - String? get remoteModel => _remoteModel; - String? get path => _path; + String get model => _model; String get description => _description; String get personality => _personality; String get scenario => _scenario; @@ -78,8 +76,7 @@ class GenerationOptions { map["prompt_format"] = _promptFormat.index; map["api_type"] = _apiType.index; map["api_key"] = _apiKey; - map["remote_model"] = _remoteModel; - map["path"] = _path; + map["model"] = _model; map["description"] = _description; map["personality"] = _personality; map["scenario"] = _scenario; @@ -134,8 +131,7 @@ class GenerationOptions { _promptFormat = ai.format; _apiType = ai.apiType; _apiKey = ai.parameters["api_key"]; - _remoteModel = ai.parameters["remote_model"]; - _path = ai.parameters["path"]; + _model = ai.model; _description = replaceCaseInsensitive( character.description, "{{char}}", character.name); diff --git a/lib/providers/ai_platform.dart b/lib/providers/ai_platform.dart index 9c2c6e95..08f8a672 100644 --- a/lib/providers/ai_platform.dart +++ b/lib/providers/ai_platform.dart @@ -180,7 +180,7 @@ class AiPlatform extends ChangeNotifier { Logger.log("Loading model from $file"); - _parameters["path"] = file.path; + _model = file.path; } catch (e) { return "Error: $e"; } diff --git a/lib/static/local_generation.dart b/lib/static/local_generation.dart index e2ab60c5..9da993dc 100644 --- a/lib/static/local_generation.dart +++ b/lib/static/local_generation.dart @@ -9,11 +9,8 @@ class LocalGeneration { static Timer? _timer; static DateTime? _startTime; - static void prompt( - String input, - GenerationOptions options, - void Function(String?) callback - ) async { + static void prompt(String input, GenerationOptions options, + void Function(String?) callback) async { _timer = null; _startTime = null; _completer = Completer(); @@ -40,13 +37,9 @@ class LocalGeneration { samplingParams.mirostatTau = options.mirostatTau; samplingParams.mirostatEta = options.mirostatEta; samplingParams.penalizeNl = options.penalizeNewline; - + _llamaProcessor = LlamaProcessor( - options.path!, - modelParams, - contextParams, - samplingParams - ); + options.model, modelParams, contextParams, samplingParams); List> messages = [ { @@ -75,10 +68,7 @@ class LocalGeneration { break; } - messages.add({ - 'role': 'system', - 'content': options.system - }); + messages.add({'role': 'system', 'content': options.system}); } _llamaProcessor!.messages = options.messages; diff --git a/lib/static/remote_generation.dart b/lib/static/remote_generation.dart index a1fe7ad0..89e61f53 100644 --- a/lib/static/remote_generation.dart +++ b/lib/static/remote_generation.dart @@ -19,7 +19,7 @@ class RemoteGeneration { final chat = ChatOllama( baseUrl: '${options.remoteUrl}/api', defaultOptions: ChatOllamaOptions( - model: options.remoteModel ?? 'llama2', + model: options.model, numKeep: options.nKeep, seed: options.seed, numPredict: options.nPredict, @@ -58,7 +58,7 @@ class RemoteGeneration { baseUrl: options.remoteUrl, apiKey: options.apiKey, defaultOptions: ChatOpenAIOptions( - model: options.remoteModel ?? 'gpt-3.5-turbo', + model: options.model, temperature: options.temperature, frequencyPenalty: options.penaltyFreq, presencePenalty: options.penaltyPresent, @@ -86,7 +86,7 @@ class RemoteGeneration { baseUrl: '${options.remoteUrl}/v1', apiKey: options.apiKey, defaultOptions: ChatMistralAIOptions( - model: options.remoteModel ?? 'mistral-small', + model: options.model, topP: options.topP, temperature: options.temperature, ), diff --git a/lib/widgets/platform_widgets/local_platform.dart b/lib/widgets/platform_widgets/local_platform.dart index 22545613..b075ee70 100644 --- a/lib/widgets/platform_widgets/local_platform.dart +++ b/lib/widgets/platform_widgets/local_platform.dart @@ -43,7 +43,7 @@ class LocalPlatform extends StatelessWidget { Expanded( flex: 2, child: Text( - context.watch().parameters["path"] ?? "None", + context.watch().model, textAlign: TextAlign.end, ), ), @@ -59,7 +59,7 @@ class LocalPlatform extends StatelessWidget { }, rightText: "Unload GGUF", rightOnPressed: () { - context.read().setParameter("path", ""); + context.read().model = ""; }), Divider( height: 20,