Skip to content

Commit

Permalink
Merge pull request #358 from danemadsen/main
Browse files Browse the repository at this point in the history
refactor
  • Loading branch information
danemadsen committed Feb 22, 2024
2 parents 96885ee + 3598385 commit 8a44caf
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 29 deletions.
12 changes: 4 additions & 8 deletions lib/classes/generation_options.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ class GenerationOptions {
late PromptFormatType _promptFormat;
late ApiType _apiType;
late String? _apiKey;
late String? _remoteModel;
late String? _path;
late String _model;
late String _description;
late String _personality;
late String _scenario;
Expand Down Expand Up @@ -44,8 +43,7 @@ class GenerationOptions {
PromptFormatType get promptFormat => _promptFormat;
ApiType get apiType => _apiType;
String? get apiKey => _apiKey;
String? get remoteModel => _remoteModel;
String? get path => _path;
String get model => _model;
String get description => _description;
String get personality => _personality;
String get scenario => _scenario;
Expand Down Expand Up @@ -78,8 +76,7 @@ class GenerationOptions {
map["prompt_format"] = _promptFormat.index;
map["api_type"] = _apiType.index;
map["api_key"] = _apiKey;
map["remote_model"] = _remoteModel;
map["path"] = _path;
map["model"] = _model;
map["description"] = _description;
map["personality"] = _personality;
map["scenario"] = _scenario;
Expand Down Expand Up @@ -134,8 +131,7 @@ class GenerationOptions {
_promptFormat = ai.format;
_apiType = ai.apiType;
_apiKey = ai.parameters["api_key"];
_remoteModel = ai.parameters["remote_model"];
_path = ai.parameters["path"];
_model = ai.model;

_description = replaceCaseInsensitive(
character.description, "{{char}}", character.name);
Expand Down
2 changes: 1 addition & 1 deletion lib/providers/ai_platform.dart
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class AiPlatform extends ChangeNotifier {

Logger.log("Loading model from $file");

_parameters["path"] = file.path;
_model = file.path;
} catch (e) {
return "Error: $e";
}
Expand Down
20 changes: 5 additions & 15 deletions lib/static/local_generation.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ class LocalGeneration {
static Timer? _timer;
static DateTime? _startTime;

static void prompt(
String input,
GenerationOptions options,
void Function(String?) callback
) async {
static void prompt(String input, GenerationOptions options,
void Function(String?) callback) async {
_timer = null;
_startTime = null;
_completer = Completer();
Expand All @@ -40,13 +37,9 @@ class LocalGeneration {
samplingParams.mirostatTau = options.mirostatTau;
samplingParams.mirostatEta = options.mirostatEta;
samplingParams.penalizeNl = options.penalizeNewline;

_llamaProcessor = LlamaProcessor(
options.path!,
modelParams,
contextParams,
samplingParams
);
options.model, modelParams, contextParams, samplingParams);

List<Map<String, dynamic>> messages = [
{
Expand Down Expand Up @@ -75,10 +68,7 @@ class LocalGeneration {
break;
}

messages.add({
'role': 'system',
'content': options.system
});
messages.add({'role': 'system', 'content': options.system});
}

_llamaProcessor!.messages = options.messages;
Expand Down
6 changes: 3 additions & 3 deletions lib/static/remote_generation.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class RemoteGeneration {
final chat = ChatOllama(
baseUrl: '${options.remoteUrl}/api',
defaultOptions: ChatOllamaOptions(
model: options.remoteModel ?? 'llama2',
model: options.model,
numKeep: options.nKeep,
seed: options.seed,
numPredict: options.nPredict,
Expand Down Expand Up @@ -58,7 +58,7 @@ class RemoteGeneration {
baseUrl: options.remoteUrl,
apiKey: options.apiKey,
defaultOptions: ChatOpenAIOptions(
model: options.remoteModel ?? 'gpt-3.5-turbo',
model: options.model,
temperature: options.temperature,
frequencyPenalty: options.penaltyFreq,
presencePenalty: options.penaltyPresent,
Expand Down Expand Up @@ -86,7 +86,7 @@ class RemoteGeneration {
baseUrl: '${options.remoteUrl}/v1',
apiKey: options.apiKey,
defaultOptions: ChatMistralAIOptions(
model: options.remoteModel ?? 'mistral-small',
model: options.model,
topP: options.topP,
temperature: options.temperature,
),
Expand Down
4 changes: 2 additions & 2 deletions lib/widgets/platform_widgets/local_platform.dart
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LocalPlatform extends StatelessWidget {
Expanded(
flex: 2,
child: Text(
context.watch<AiPlatform>().parameters["path"] ?? "None",
context.watch<AiPlatform>().model,
textAlign: TextAlign.end,
),
),
Expand All @@ -59,7 +59,7 @@ class LocalPlatform extends StatelessWidget {
},
rightText: "Unload GGUF",
rightOnPressed: () {
context.read<AiPlatform>().setParameter("path", "");
context.read<AiPlatform>().model = "";
}),
Divider(
height: 20,
Expand Down

0 comments on commit 8a44caf

Please sign in to comment.