diff --git a/.env b/.env index c6573ff..14c4431 100644 --- a/.env +++ b/.env @@ -46,6 +46,7 @@ MATRIX_INSTANCE= # for example lemmings.world, can be left empty if it's same as #### api keys SLACK_BOT_TOKEN= MATRIX_API_TOKEN= +AI_HORDE_API_KEY= #### other settings USE_LEMMYVERSE_LINK_SLACK=0 diff --git a/config/services.yaml b/config/services.yaml index e07df18..9097610 100644 --- a/config/services.yaml +++ b/config/services.yaml @@ -27,6 +27,7 @@ parameters: app.fediseer.api: '%env(FEDISEER_API_URL)%' app.fediseer.key: '%env(FEDISEER_API_KEY)%' + app.ai_horde.api_key: '%env(AI_HORDE_API_KEY)%' app.image_check.regex: '%env(IMAGE_CHECK_REGEX)%' @@ -88,7 +89,10 @@ services: arguments: $removalLogValidity: '@app.log_validity' - App\Service\ExpressionLanguage: + App\Service\Expression\ExpressionLanguage: calls: - - registerProvider: ['@App\Service\ExpressionLanguageFunctions'] - Symfony\Component\ExpressionLanguage\ExpressionLanguage: '@App\Service\ExpressionLanguage' + - registerProvider: ['@App\Service\Expression\ExpressionLanguageFunctions'] + - registerProvider: ['@App\Service\Expression\ExpressionLanguageAiFunctions'] + - registerProvider: ['@App\Service\Expression\ExpressionLanguageStringFunctions'] + - registerProvider: ['@App\Service\Expression\ExpressionLanguageLemmyFunctions'] + Symfony\Component\ExpressionLanguage\ExpressionLanguage: '@App\Service\Expression\ExpressionLanguage' diff --git a/src/Automod/ModAction/ComplexRuleAction.php b/src/Automod/ModAction/ComplexRuleAction.php index ff55c9b..9610d64 100644 --- a/src/Automod/ModAction/ComplexRuleAction.php +++ b/src/Automod/ModAction/ComplexRuleAction.php @@ -9,7 +9,7 @@ use App\Enum\FurtherAction; use App\Enum\RunConfiguration; use App\Repository\ComplexRuleRepository; -use App\Service\ExpressionLanguage; +use App\Service\Expression\ExpressionLanguage; use LogicException; use Rikudou\LemmyApi\LemmyApi; use Rikudou\LemmyApi\Response\Model\Person; diff --git a/src/Enum/AiActor.php b/src/Enum/AiActor.php new file mode 100644 index 0000000..5532ce4 --- /dev/null +++ b/src/Enum/AiActor.php @@ -0,0 +1,10 @@ +expressionLanguage->evaluate($message->expression, $message->context); + } +} diff --git a/src/Service/AiHorde/AiHorde.php b/src/Service/AiHorde/AiHorde.php new file mode 100644 index 0000000..157780a --- /dev/null +++ b/src/Service/AiHorde/AiHorde.php @@ -0,0 +1,134 @@ +apiKey) { + throw new LogicException('There is no api key set, cannot use AI actions'); + } + + $models = $this->findModels($model); + if (!count($models)) { + throw new LogicException('There was an error while looking for available models - no model able to handle your message seems to be online. Please try again later.'); + } + $formatter = $this->findFormatter($model) ?? throw new LogicException("Could not find formatter for {$model->value}"); + [$maxLength, $maxContextLength] = $this->getMaxLength($model); + + $response = $this->httpClient->request(Request::METHOD_POST, 'https://aihorde.net/api/v2/generate/text/async', [ + 'json' => [ + 'prompt' => $formatter->getPrompt(new MessageHistory( + ...[...$history, new Message(role: AiActor::User, content: $message)], + )), + 'params' => [ + 'max_length' => $maxLength, + 'max_context_length' => $maxContextLength, + ], + 'models' => $models, + ], + 'headers' => [ + 'apikey' => $this->apiKey, + ], + ]); + $json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR); + $jobId = $json['id']; + + do { + $response = $this->httpClient->request(Request::METHOD_GET, "https://aihorde.net/api/v2/generate/text/status/{$jobId}", [ + 'headers' => [ + 'apikey' => $this->apiKey, + ], + ]); + $json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR); + if (!$json['done']) { + sleep(1); + } + } while (!$json['done']); + + if (!isset($json['generations'][0])) { + throw new LogicException('Missing generations output'); + } + + $output = $formatter->formatOutput($json['generations'][0]['text']); + + return $output->content; + } + + /** + * @return array + */ + public function findModels(AiModel $model): array + { + $response = $this->httpClient->request(Request::METHOD_GET, 'https://aihorde.net/api/v2/status/models?type=text'); + $json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR); + + return array_values(array_map( + fn (array $modelData) => $modelData['name'], + array_filter($json, fn (array $modelData) => fnmatch("*/{$model->value}", $modelData['name'])), + )); + } + + private function findFormatter(AiModel $model): ?MessageFormatter + { + foreach ($this->formatters as $formatter) { + if ($formatter->supports($model)) { + return $formatter; + } + } + + return null; + } + + private function getMaxLength(AiModel $model): array + { + $response = $this->httpClient->request(Request::METHOD_GET, 'https://aihorde.net/api/v2/workers?type=text'); + $json = json_decode($response->getContent(), true, flags: JSON_THROW_ON_ERROR); + $workers = array_filter( + $json, + fn (array $worker) => count(array_filter( + $worker['models'], + fn (string $modelName) => fnmatch("*/{$model->value}", $modelName), + )) > 0, + ); + $targetLength = 1024; + $targetContext = 2048; + + if (!count(array_filter($workers, fn(array $worker) => $worker['max_length'] >= $targetLength))) { + $targetLength = max(array_map(fn (array $worker) => $worker['max_length'], $workers)); + } + if (!count(array_filter($workers, fn(array $worker) => $worker['max_context_length'] >= $targetContext))) { + $targetContext = max(array_map(fn (array $worker) => $worker['max_context_length'], $workers)); + } + + if ($targetLength > $targetContext / 2) { + $targetLength = $targetContext / 2; + } + + return [$targetLength, $targetContext]; + } +} diff --git a/src/Service/AiHorde/Message/Message.php b/src/Service/AiHorde/Message/Message.php new file mode 100644 index 0000000..2fa1ca3 --- /dev/null +++ b/src/Service/AiHorde/Message/Message.php @@ -0,0 +1,25 @@ + $this->role->value, + 'content' => $this->content, + ]; + } +} diff --git a/src/Service/AiHorde/Message/MessageHistory.php b/src/Service/AiHorde/Message/MessageHistory.php new file mode 100644 index 0000000..01b564a --- /dev/null +++ b/src/Service/AiHorde/Message/MessageHistory.php @@ -0,0 +1,73 @@ + + * @implements ArrayAccess + */ +final class MessageHistory implements IteratorAggregate, ArrayAccess, Countable, JsonSerializable +{ + /** + * @var array + */ + private array $messages; + + public function __construct(Message ...$messages) + { + $this->messages = $messages; + } + + public function getIterator(): Traversable + { + return new ArrayIterator($this->messages); + } + + public function offsetExists(mixed $offset): bool + { + return isset($this->messages[$offset]); + } + + public function offsetGet(mixed $offset): Message + { + return $this->messages[$offset]; + } + + public function offsetSet(mixed $offset, mixed $value): void + { + if (!$value instanceof Message) { + throw new InvalidArgumentException('Only instances of ' . Message::class . ' are supported'); + } + if ($offset !== null) { + $this->messages[$offset] = $value; + } else { + $this->messages[] = $value; + } + } + + public function offsetUnset(mixed $offset): void + { + unset($this->messages[$offset]); + } + + public function count(): int + { + return count($this->messages); + } + + /** + * @return array + */ + public function jsonSerialize(): array + { + return array_map(fn (Message $message) => $message->jsonSerialize(), $this->messages); + } +} diff --git a/src/Service/AiHorde/MessageFormatter/ChatMLPromptFormat.php b/src/Service/AiHorde/MessageFormatter/ChatMLPromptFormat.php new file mode 100644 index 0000000..7a19466 --- /dev/null +++ b/src/Service/AiHorde/MessageFormatter/ChatMLPromptFormat.php @@ -0,0 +1,43 @@ +{$message->role->value}\n{$message->content}<|im_end|>"; + }, [...$messages]))); + } + + public function formatOutput(string $message): Message + { + $role = 'assistant'; + $message = trim($message); + + if (str_starts_with($message, '<|im_start|>')) { + $message = substr($message, strlen('<|im_start|>')); + $parts = explode("\n", $message, 2); + $message = $parts[1]; + $role = $parts[0]; + } + if (str_ends_with($message, '<|im_end|>')) { + $message = substr($message, 0, -strlen('<|im_end|>')); + } + + $role = AiActor::tryFrom($role) ?? AiActor::Assistant; + + return new Message(role: $role, content: $message); + } + + public function supports(AiModel $model): bool + { + return in_array($model, [AiModel::Mistral7BOpenHermes], true); + } +} diff --git a/src/Service/AiHorde/MessageFormatter/MessageFormatter.php b/src/Service/AiHorde/MessageFormatter/MessageFormatter.php new file mode 100644 index 0000000..3da55bd --- /dev/null +++ b/src/Service/AiHorde/MessageFormatter/MessageFormatter.php @@ -0,0 +1,17 @@ + throw new LogicException('This function cannot be compiled'); + } +} diff --git a/src/Service/ExpressionLanguage.php b/src/Service/Expression/ExpressionLanguage.php similarity index 88% rename from src/Service/ExpressionLanguage.php rename to src/Service/Expression/ExpressionLanguage.php index 428b893..3577fdb 100644 --- a/src/Service/ExpressionLanguage.php +++ b/src/Service/Expression/ExpressionLanguage.php @@ -1,6 +1,6 @@ uncompilableFunction(), + $this->aiAnalyzeFunction(...), + ), + ]; + } + + private function aiAnalyzeFunction(array $context, string $message, ?string $systemPrompt = null): string + { + $history = new MessageHistory(); + if ($systemPrompt !== null) { + $history[] = new Message(role: AiActor::System, content: $systemPrompt); + } + + return $this->aiHorde->getResponse($message, AiModel::Mistral7BOpenHermes, $history); + } +} diff --git a/src/Service/ExpressionLanguageFunctions.php b/src/Service/Expression/ExpressionLanguageFunctions.php similarity index 77% rename from src/Service/ExpressionLanguageFunctions.php rename to src/Service/Expression/ExpressionLanguageFunctions.php index b07537b..d588785 100644 --- a/src/Service/ExpressionLanguageFunctions.php +++ b/src/Service/Expression/ExpressionLanguageFunctions.php @@ -1,15 +1,13 @@ anyFunction(...), ), new ExpressionFunction( - 'and', + '_and_', $this->uncompilableFunction(), $this->andFunction(...), ), new ExpressionFunction( - 'or', + '_or_', $this->uncompilableFunction(), $this->orFunction(...), ), + new ExpressionFunction( + 'catchError', + $this->uncompilableFunction(), + $this->catchErrorFunction(...), + ), ]; } - private function uncompilableFunction(): Closure - { - return fn () => throw new LogicException('This function cannot be compiled'); - } - private function asyncFunction(array $context, string $expression): bool { $this->messageBus->dispatch(new RunExpressionAsyncMessage($context, $expression)); @@ -106,4 +104,14 @@ private function orFunction(array $context, string ...$expressions): bool return $result; } + + private function catchErrorFunction(array $context, string $expression, string $onErrorExpression): bool + { + try { + return $this->expressionLanguage->evaluate($expression, $context); + } catch (Throwable $exception) { + $context['exception'] = $exception; + return $this->expressionLanguage->evaluate($onErrorExpression, $context); + } + } } diff --git a/src/Service/Expression/ExpressionLanguageLemmyFunctions.php b/src/Service/Expression/ExpressionLanguageLemmyFunctions.php new file mode 100644 index 0000000..6b3827c --- /dev/null +++ b/src/Service/Expression/ExpressionLanguageLemmyFunctions.php @@ -0,0 +1,38 @@ +uncompilableFunction(), + $this->removeComment(...), + ), + ]; + } + + public function removeComment(array $context, int|CommentView|Comment $comment, ?string $reason = null): bool + { + $id = match (true) { + $comment instanceof CommentView => $comment->comment->id, + $comment instanceof Comment => $comment->id, + is_int($comment) => $comment, + }; + + return $this->api->moderator()->removeComment(comment: $id, reason: $reason); + } +} diff --git a/src/Service/Expression/ExpressionLanguageStringFunctions.php b/src/Service/Expression/ExpressionLanguageStringFunctions.php new file mode 100644 index 0000000..dc0b78c --- /dev/null +++ b/src/Service/Expression/ExpressionLanguageStringFunctions.php @@ -0,0 +1,24 @@ +uncompilableFunction(), + $this->toLower(...), + ) + ]; + } + + private function toLower(array $context, string $string): string + { + return mb_strtolower($string); + } +}