Skip to content

Commit

Permalink
Merge pull request #376 from danemadsen/main
Browse files Browse the repository at this point in the history
Add option to use default ollama options
  • Loading branch information
danemadsen committed Feb 29, 2024
2 parents 39ebb87 + b8f2836 commit 34917bc
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 31 deletions.
19 changes: 13 additions & 6 deletions lib/providers/ai_platform.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class AiPlatform extends ChangeNotifier {
String _model = "";

bool _randomSeed = true;
bool _useDefault = true;

int _nKeep = 48;
int _seed = 0;
Expand Down Expand Up @@ -72,15 +73,15 @@ class AiPlatform extends ChangeNotifier {
case AiPlatformType.ollama:
_url = await GenerationManager.getOllamaUrl();
case AiPlatformType.openAI:
_url = "https://api.openai.com/v1/";
_url = "https://api.openai.com/v1/";
case AiPlatformType.mistralAI:
_url = "https://api.mistral.ai/v1/";
_url = "https://api.mistral.ai/v1/";
default:
_url = "";
_url = "";
}

_apiType = apiType;

notifyListeners();
}

Expand Down Expand Up @@ -114,6 +115,11 @@ class AiPlatform extends ChangeNotifier {
notifyListeners();
}

set useDefault(bool useDefault) {
_useDefault = useDefault;
notifyListeners();
}

set nKeep(int nKeep) {
_nKeep = nKeep;
notifyListeners();
Expand Down Expand Up @@ -221,6 +227,7 @@ class AiPlatform extends ChangeNotifier {
String get url => _url;
String get model => _model;
bool get randomSeed => _randomSeed;
bool get useDefault => _useDefault;
int get nKeep => _nKeep;
int get seed => _seed;
int get nPredict => _nPredict;
Expand Down Expand Up @@ -250,8 +257,8 @@ class AiPlatform extends ChangeNotifier {
if (inputJson.isEmpty) {
resetAll();
} else {
_promptFormat = PromptFormatType
.values[inputJson["prompt_promptFormat"] ?? PromptFormatType.alpaca.index];
_promptFormat = PromptFormatType.values[
inputJson["prompt_promptFormat"] ?? PromptFormatType.alpaca.index];
_apiType = AiPlatformType
.values[inputJson["api_type"] ?? AiPlatformType.local.index];
_preset = inputJson["preset"] ?? "Default";
Expand Down
63 changes: 38 additions & 25 deletions lib/static/generation_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,38 @@ class GenerationManager {
static void ollamaRequest(List<ChatMessage> chatMessages, AiPlatform ai,
void Function(String?) callback) async {
try {
final chat = ChatOllama(
baseUrl: '${ai.url}/api',
defaultOptions: ChatOllamaOptions(
model: ai.model,
numKeep: ai.nKeep,
seed: ai.seed,
numPredict: ai.nPredict,
topK: ai.topK,
topP: ai.topP,
typicalP: ai.typicalP,
temperature: ai.temperature,
repeatPenalty: ai.penaltyRepeat,
frequencyPenalty: ai.penaltyFreq,
presencePenalty: ai.penaltyPresent,
mirostat: ai.mirostat,
mirostatTau: ai.mirostatTau,
mirostatEta: ai.mirostatEta,
numCtx: ai.nCtx,
numBatch: ai.nBatch,
numThread: ai.nThread,
),
);
ChatOllama chat;
if (ai.useDefault) {
chat = ChatOllama(
baseUrl: '${ai.url}/api',
defaultOptions: ChatOllamaOptions(
model: ai.model,
),
);
} else {
chat = ChatOllama(
baseUrl: '${ai.url}/api',
defaultOptions: ChatOllamaOptions(
model: ai.model,
numKeep: ai.nKeep,
seed: ai.seed,
numPredict: ai.nPredict,
topK: ai.topK,
topP: ai.topP,
typicalP: ai.typicalP,
temperature: ai.temperature,
repeatPenalty: ai.penaltyRepeat,
frequencyPenalty: ai.penaltyFreq,
presencePenalty: ai.penaltyPresent,
mirostat: ai.mirostat,
mirostatTau: ai.mirostatTau,
mirostatEta: ai.mirostatEta,
numCtx: ai.nCtx,
numBatch: ai.nBatch,
numThread: ai.nThread,
),
);
}

final stream = chat.stream(PromptValue.chat(chatMessages));

Expand Down Expand Up @@ -361,9 +371,11 @@ class GenerationManager {
if (!Platform.isAndroid && !Platform.isIOS) {
return true;
}

// Get sdk version
final sdk = await DeviceInfoPlugin().androidInfo.then((value) => value.version.sdkInt);
final sdk = await DeviceInfoPlugin()
.androidInfo
.then((value) => value.version.sdkInt);
var permissions = <Permission>[]; // List of permissions to request

if (sdk <= 32) {
Expand All @@ -376,7 +388,8 @@ class GenerationManager {

// Request permissions and check if all are granted
Map<Permission, PermissionStatus> statuses = await permissions.request();
bool allPermissionsGranted = statuses.values.every((status) => status.isGranted);
bool allPermissionsGranted =
statuses.values.every((status) => status.isGranted);

if (allPermissionsGranted) {
Logger.log("Nearby Devices - permission granted");
Expand Down
20 changes: 20 additions & 0 deletions lib/ui/mobile/widgets/parameter_widgets/use_default.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import 'package:flutter/material.dart';
import 'package:maid/providers/ai_platform.dart';
import 'package:provider/provider.dart';

class UseDefaultParameter extends StatelessWidget {
const UseDefaultParameter({super.key});

@override
Widget build(BuildContext context) {
return Consumer<AiPlatform>(builder: (context, ai, child) {
return SwitchListTile(
title: const Text('Use Default Parameters'),
value: ai.useDefault,
onChanged: (value) {
ai.useDefault = value;
},
);
});
}
}
2 changes: 2 additions & 0 deletions lib/ui/mobile/widgets/platform_widgets/ollama_platform.dart
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import 'package:maid/ui/mobile/widgets/parameter_widgets/top_p_parameter.dart';
import 'package:maid/ui/mobile/widgets/parameter_widgets/typical_p_parameter.dart';
import 'package:maid/ui/mobile/widgets/dropdowns/model_dropdown.dart';
import 'package:maid/ui/mobile/widgets/parameter_widgets/url_parameter.dart';
import 'package:maid/ui/mobile/widgets/parameter_widgets/use_default.dart';

class OllamaPlatform extends StatelessWidget {
const OllamaPlatform({super.key});
Expand All @@ -45,6 +46,7 @@ class OllamaPlatform extends StatelessWidget {
endIndent: 10,
color: Theme.of(context).colorScheme.primary,
),
const UseDefaultParameter(),
const PenalizeNlParameter(),
const SeedParameter(),
const NThreadsParameter(),
Expand Down

0 comments on commit 34917bc

Please sign in to comment.