diff --git a/README.md b/README.md index 2699b409..3621cad2 100644 --- a/README.md +++ b/README.md @@ -553,6 +553,63 @@ if __name__ == "__main__": asyncio.run(main()) ``` +### Asynchronous Processing with Webhooks + +The Runware SDK supports asynchronous processing via webhooks for long-running operations. When you provide a `webhookURL`, the API immediately returns a task response and sends the final result to your webhook endpoint when processing completes. + +#### How it works + +1. Include `webhookURL` parameter in your request +2. Receive immediate response with `taskType` and `taskUUID` +3. Final result is POSTed to your webhook URL when ready + +Supported operations: +- Image Inference +- Photo Maker +- Image Caption +- Image Background Removal +- Image Upscale +- Prompt Enhance +- Video Inference + +#### Example + +```python +from runware import Runware, IImageInference + +async def main() -> None: + runware = Runware(api_key=RUNWARE_API_KEY) + await runware.connect() + + request_image = IImageInference( + positivePrompt="a beautiful mountain landscape", + model="civitai:36520@76907", + height=512, + width=512, + webhookURL="https://your-server.com/webhook/runware" + ) + + # Returns immediately with task info + response = await runware.imageInference(requestImage=request_image) + print(f"Task Type: {response.taskType}") + print(f"Task UUID: {response.taskUUID}") + # Result will be sent to your webhook URL +``` + +#### Webhook Payload Format +Your webhook endpoint will receive a POST request with the same format as synchronous responses: +```json{ + "data": [ + { + "taskType": "imageInference", + "taskUUID": "a770f077-f413-47de-9dac-be0b26a35da6", + "imageUUID": "77da2d99-a6d3-44d9-b8c0-ae9fb06b6200", + "imageURL": "https://im.runware.ai/image/...", + "cost": 0.0013 + } + ] +} +``` ### Model Upload diff --git a/examples/webhook_image_inference.py b/examples/webhook_image_inference.py new file mode 100644 index 00000000..e69de29b diff --git a/runware/base.py b/runware/base.py index 38d54d2f..dad088b9 100644 --- a/runware/base.py +++ b/runware/base.py @@ -36,6 +36,7 @@ IGoogleProviderSettings, IKlingAIProviderSettings, IFrameImage, + IAsyncTaskResponse, ) from .types import IImage, IError, SdkType, ListenerType from .utils import ( @@ -56,6 +57,7 @@ LISTEN_TO_IMAGES_KEY, isLocalFile, process_image, delay, + createAsyncTaskResponse, ) # Configure logging @@ -196,6 +198,8 @@ async def photoMaker(self, requestPhotoMaker: IPhotoMaker): request_object["includeCost"] = requestPhotoMaker.includeCost if requestPhotoMaker.outputType: request_object["outputType"] = requestPhotoMaker.outputType + if requestPhotoMaker.webhookURL: + request_object["webhookURL"] = requestPhotoMaker.webhookURL await self.send([request_object]) @@ -215,6 +219,10 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: if made_photo.get("taskType") != "photoMaker": continue + if not made_photo.get("imageUUID"): + del self._globalMessages[task_uuid] + resolve([made_photo]) + return True image_uuid = made_photo.get("imageUUID") if image_uuid not in unique_results: @@ -235,6 +243,9 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: # This indicates an error response raise RunwareAPIError(response) + if response and len(response) == 1 and not response[0].get("imageUUID"): + return createAsyncTaskResponse(response[0]) + if response: if not isinstance(response, list): response = [response] @@ -250,7 +261,7 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: async def imageInference( self, requestImage: IImageInference - ) -> Union[List[IImage], None]: + ) -> Union[List[IImage], IAsyncTaskResponse, None]: let_lis: Optional[Any] = None request_object: Optional[Dict[str, Any]] = None task_uuids: List[str] = [] @@ -453,6 +464,8 @@ async def imageInference( request_object["promptWeighting"] = requestImage.promptWeighting if requestImage.maskMargin: request_object["maskMargin"] = requestImage.maskMargin + if requestImage.webhookURL: + request_object["webhookURL"] = requestImage.webhookURL if hasattr(requestImage, "extraArgs"): # if extraArgs is present, and a dictionary, we will add its attributes to the request. # these may contain options used for public beta testing. @@ -491,7 +504,7 @@ async def _requestImages( retry_count: int, number_of_images: int, on_partial_images: Optional[Callable[[List[IImage], Optional[IError]], None]], - ) -> List[IImage]: + ) -> Union[List[IImage], IAsyncTaskResponse]: retry_count += 1 if let_lis: let_lis["destroy"]() @@ -514,31 +527,66 @@ async def _requestImages( } } await self.send(new_request_object) + has_webhook = request_object.get("webhookURL") + if has_webhook: + # For webhook requests, set up a listener and wait for acceptance confirmation + lis = self.globalListener(taskUUID=task_uuid) - let_lis = await self.listenToImages( - onPartialImages=on_partial_images, - taskUUID=task_uuid, - groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, - ) - images = await self.getSimililarImage( - taskUUID=task_uuids, - numberOfImages=number_of_images, - shouldThrowError=True, - lis=let_lis, - ) + def check_webhook(resolve: callable, reject: callable, *args: Any) -> bool: + response = self._globalMessages.get(task_uuid) + if response: + # Handle list of responses + if isinstance(response, list) and len(response) > 0: + first_response = response[0] + else: + first_response = response - let_lis["destroy"]() - # TODO: NameError("name 'image_path' is not defined"). I think I remove the images when I have onPartialImages - if images: - if "code" in images: - # This indicates an error response - raise RunwareAPIError(images) + if first_response.get("code"): + raise RunwareAPIError(first_response) + + # Check if this is an acceptance response (has taskType and taskUUID but no imageUUID) + if (first_response.get("taskType") == "imageInference" and + first_response.get("taskUUID") == task_uuid and + not first_response.get("imageUUID")): + del self._globalMessages[task_uuid] + resolve(first_response) + return True + return False + + response = await getIntervalWithPromise( + check_webhook, + debugKey=f"imageInference-webhook:{task_uuid}", + timeOutDuration=10000, # Shorter timeout for webhook acceptance + ) + + lis["destroy"]() - return instantiateDataclassList(IImage, images) + # Return async response for webhook + return createAsyncTaskResponse(response) + else: + # Normal synchronous flow + let_lis = await self.listenToImages( + onPartialImages=on_partial_images, + taskUUID=task_uuid, + groupKey=LISTEN_TO_IMAGES_KEY.REQUEST_IMAGES, + ) - # return images + images = await self.getSimililarImage( + taskUUID=task_uuids, + numberOfImages=number_of_images, + shouldThrowError=True, + lis=let_lis, + ) - async def imageCaption(self, requestImageToText: IImageCaption) -> IImageToText: + let_lis["destroy"]() + + if images: + if "code" in images: + raise RunwareAPIError(images) + + return instantiateDataclassList(IImage, images) + + async def imageCaption(self, requestImageToText: IImageCaption) -> Union[IImageToText, IAsyncTaskResponse]: try: await self.ensureConnection() return await asyncRetry( @@ -549,7 +597,7 @@ async def imageCaption(self, requestImageToText: IImageCaption) -> IImageToText: async def _requestImageToText( self, requestImageToText: IImageCaption - ) -> IImageToText: + ) -> Union[IImageToText, IAsyncTaskResponse]: inputImage = requestImageToText.inputImage image_uploaded = await self.uploadImage(inputImage) @@ -570,6 +618,9 @@ async def _requestImageToText( if requestImageToText.includeCost: task_params["includeCost"] = requestImageToText.includeCost + if requestImageToText.webhookURL: + task_params["webhookURL"] = requestImageToText.webhookURL + # Send the task with all applicable parameters await self.send([task_params]) @@ -605,6 +656,9 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: # This indicates an error response raise RunwareAPIError(response) + if not response.get("text"): + return createAsyncTaskResponse(response) + if response: return createImageToTextFromResponse(response) else: @@ -612,7 +666,7 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: async def imageBackgroundRemoval( self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> List[IImage]: + ) -> Union[List[IImage], IAsyncTaskResponse]: try: await self.ensureConnection() return await asyncRetry( @@ -623,7 +677,7 @@ async def imageBackgroundRemoval( async def _removeImageBackground( self, removeImageBackgroundPayload: IImageBackgroundRemoval - ) -> List[IImage]: + ) -> Union[List[IImage], IAsyncTaskResponse]: inputImage = removeImageBackgroundPayload.inputImage image_uploaded = await self.uploadImage(inputImage) @@ -653,6 +707,8 @@ async def _removeImageBackground( task_params["model"] = removeImageBackgroundPayload.model if removeImageBackgroundPayload.outputQuality: task_params["outputQuality"] = removeImageBackgroundPayload.outputQuality + if removeImageBackgroundPayload.webhookURL: + task_params["webhookURL"] = removeImageBackgroundPayload.webhookURL # Handle settings if provided - convert dataclass to dictionary and add non-None values if removeImageBackgroundPayload.settings: @@ -698,19 +754,22 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: # This indicates an error response raise RunwareAPIError(response) + if not response.get("imageUUID"): + return createAsyncTaskResponse(response) + image = createImageFromResponse(response) image_list: List[IImage] = [image] return image_list - async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> List[IImage]: + async def imageUpscale(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: try: await self.ensureConnection() return await asyncRetry(lambda: self._upscaleGan(upscaleGanPayload)) except Exception as e: raise e - async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> List[IImage]: + async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> Union[List[IImage], IAsyncTaskResponse]: inputImage = upscaleGanPayload.inputImage upscaleFactor = upscaleGanPayload.upscaleFactor @@ -736,6 +795,8 @@ async def _upscaleGan(self, upscaleGanPayload: IImageUpscale) -> List[IImage]: task_params["outputFormat"] = upscaleGanPayload.outputFormat if upscaleGanPayload.includeCost: task_params["includeCost"] = upscaleGanPayload.includeCost + if upscaleGanPayload.webhookURL: + task_params["webhookURL"] = upscaleGanPayload.webhookURL # Send the task with all applicable parameters await self.send([task_params]) @@ -772,6 +833,9 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: # This indicates an error response raise RunwareAPIError(response) + if not response.get("imageUUID"): + return createAsyncTaskResponse(response) + image = createImageFromResponse(response) # TODO: The respones has an upscaleImageUUID field, should I return it as well? image_list: List[IImage] = [image] @@ -779,7 +843,7 @@ def check(resolve: callable, reject: callable, *args: Any) -> bool: async def promptEnhance( self, promptEnhancer: IPromptEnhance - ) -> List[IEnhancedPrompt]: + ) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]: """ Enhance the given prompt by generating multiple versions of it. @@ -795,7 +859,7 @@ async def promptEnhance( async def _enhancePrompt( self, promptEnhancer: IPromptEnhance - ) -> List[IEnhancedPrompt]: + ) -> Union[List[IEnhancedPrompt], IAsyncTaskResponse]: """ Internal method to perform the actual prompt enhancement. @@ -822,6 +886,10 @@ async def _enhancePrompt( if promptEnhancer.includeCost: task_params["includeCost"] = promptEnhancer.includeCost + has_webhook = promptEnhancer.webhookURL + if has_webhook: + task_params["webhookURL"] = promptEnhancer.webhookURL + # Send the task with all applicable parameters await self.send([task_params]) @@ -834,11 +902,19 @@ def check(resolve: Any, reject: Any, *args: Any) -> bool: if isinstance(response, dict) and response.get("error"): reject(response) return True - # if response and len(response) >= promptVersions: + if response: - del self._globalMessages[taskUUID] - resolve(response) - return True + if isinstance(response, list) and len(response) > 0: + first_response = response[0] + if has_webhook and first_response.get("taskType") == "promptEnhance" and first_response.get( + "taskUUID") == taskUUID and not first_response.get("text"): + del self._globalMessages[taskUUID] + resolve(first_response) + return True + if first_response.get("text"): + del self._globalMessages[taskUUID] + resolve(response) + return True return False @@ -847,7 +923,8 @@ def check(resolve: Any, reject: Any, *args: Any) -> bool: ) lis["destroy"]() - + if not isinstance(response, list): + return createAsyncTaskResponse(response) if "code" in response[0]: # This indicates an error response raise RunwareAPIError(response[0]) @@ -982,17 +1059,19 @@ def listen_to_images_lis(m: Dict[str, Any]) -> None: ] if images: - self._globalImages.extend(images) - try: - partial_images = instantiateDataclassList(IImage, images) - if onPartialImages: - onPartialImages( - partial_images, None - ) # No error in this case - except Exception as e: - logger.error( - f"Error occurred in user on_partial_images callback function: {e}" - ) + valid_images = [img for img in images if img.get("imageUUID")] + if valid_images: + self._globalImages.extend(valid_images) + try: + partial_images = instantiateDataclassList(IImage, valid_images) + if onPartialImages: + onPartialImages( + partial_images, None + ) # No error in this case + except Exception as e: + logger.error( + f"Error occurred in user on_partial_images callback function: {e}" + ) # Handle error messages elif isinstance(m.get("errors"), list): errors = [ @@ -1381,16 +1460,20 @@ def check(resolve: Callable, reject: Callable, *args: Any) -> bool: raise RunwareAPIError({"message": str(e)}) - async def videoInference(self, requestVideo: IVideoInference) -> List[IVideo]: + async def videoInference(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: await self.ensureConnection() return await asyncRetry(lambda: self._requestVideo(requestVideo)) - async def _requestVideo(self, requestVideo: IVideoInference) -> List[IVideo]: + async def _requestVideo(self, requestVideo: IVideoInference) -> Union[List[IVideo], IAsyncTaskResponse]: await self._processVideoImages(requestVideo) requestVideo.taskUUID = requestVideo.taskUUID or getUUID() request_object = self._buildVideoRequest(requestVideo) + + if requestVideo.webhookURL: + request_object["webhookURL"] = requestVideo.webhookURL + await self.send([request_object]) - return await self._handleInitialVideoResponse(requestVideo.taskUUID, requestVideo.numberResults) + return await self._handleInitialVideoResponse(requestVideo.taskUUID, requestVideo.numberResults, request_object.get("webhookURL")) async def _processVideoImages(self, requestVideo: IVideoInference) -> None: frame_tasks = [] @@ -1469,7 +1552,7 @@ def _addProviderSettings(self, request_object: Dict[str, Any], requestVideo: IVi if provider_dict: request_object["providerSettings"] = provider_dict - async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int) -> List[IVideo]: + async def _handleInitialVideoResponse(self, task_uuid: str, number_results: int, has_webhook: Optional[str] = None) -> Union[List[IVideo], IAsyncTaskResponse]: lis = self.globalListener(taskUUID=task_uuid) def check_initial_response(resolve: callable, reject: callable, *args: Any) -> bool: @@ -1483,6 +1566,11 @@ def check_initial_response(resolve: callable, reject: callable, *args: Any) -> b if response.get("code"): raise RunwareAPIError(response) + if has_webhook and not response.get("videoUUID"): + del self._globalMessages[task_uuid] + resolve(response) + return True + if response.get("status") == "success": del self._globalMessages[task_uuid] resolve([response]) @@ -1501,6 +1589,9 @@ def check_initial_response(resolve: callable, reject: callable, *args: Any) -> b finally: lis["destroy"]() + if has_webhook and not isinstance(initial_response, list) and not initial_response == "POLL_NEEDED": + return createAsyncTaskResponse(initial_response) + if initial_response == "POLL_NEEDED": return await self._pollVideoResults(task_uuid, number_results) else: @@ -1572,4 +1663,4 @@ def connected(self) -> bool: :return: True if the connection is active and authenticated, False otherwise. """ - return self.isWebsocketReadyState() and self._connectionSessionUUID is not None + return self.isWebsocketReadyState() and self._connectionSessionUUID is not None \ No newline at end of file diff --git a/runware/types.py b/runware/types.py index 973108b2..299ff87f 100644 --- a/runware/types.py +++ b/runware/types.py @@ -111,6 +111,12 @@ class RunwareBaseType: url: Optional[str] = None +@dataclass +class IAsyncTaskResponse: + taskType: str + taskUUID: str + + @dataclass class IImage: taskType: str @@ -301,6 +307,7 @@ class IPhotoMaker: outputFormat: Optional[IOutputFormat] = None includeCost: Optional[bool] = None taskUUID: Optional[str] = None + webhookURL: Optional[str] = None def __post_init__(self): # Validate `inputImages` to ensure it has a maximum of 4 elements @@ -468,6 +475,7 @@ class IImageInference: referenceImages: Optional[List[Union[str, File]]] = field(default_factory=list) acePlusPlus: Optional[IAcePlusPlus] = None providerSettings: Optional[ImageProviderSettings] = None + webhookURL: Optional[str] = None extraArgs: Optional[Dict[str, Any]] = field(default_factory=dict) @@ -475,6 +483,7 @@ class IImageInference: class IImageCaption: inputImage: Optional[Union[File, str]] = None includeCost: bool = False + webhookURL: Optional[str] = None @dataclass @@ -512,6 +521,7 @@ class IPromptEnhance: promptVersions: int prompt: str includeCost: bool = False + webhookURL: Optional[str] = None @dataclass @@ -529,6 +539,7 @@ class IImageUpscale: outputType: Optional[IOutputType] = None outputFormat: Optional[IOutputFormat] = None includeCost: bool = False + webhookURL: Optional[str] = None class ReconnectingWebsocketProps: @@ -745,6 +756,7 @@ class IVideoInference: CFGScale: Optional[float] = None numberResults: Optional[int] = 1 providerSettings: Optional[VideoProviderSettings] = None + webhookURL: Optional[str] = None @dataclass class IVideo: diff --git a/runware/utils.py b/runware/utils.py index 36008e5b..e990e52a 100644 --- a/runware/utils.py +++ b/runware/utils.py @@ -27,6 +27,7 @@ IEnhancedPrompt, IError, UploadImageType, + IAsyncTaskResponse, ) import logging @@ -567,6 +568,15 @@ def accessDeepObject( # return current_value +def createAsyncTaskResponse(response: dict) -> IAsyncTaskResponse: + processed_fields = {} + + for field in fields(IAsyncTaskResponse): + if field.name in response: + processed_fields[field.name] = response[field.name] + + return instantiateDataclass(IAsyncTaskResponse, processed_fields) + def createEnhancedPromptsFromResponse(response: List[dict]) -> List[IEnhancedPrompt]: def process_single_prompt(prompt_data: dict) -> IEnhancedPrompt: processed_fields = {}