diff --git a/ai_diffusion/document.py b/ai_diffusion/document.py
index 461ad04f..a1d12db4 100644
--- a/ai_diffusion/document.py
+++ b/ai_diffusion/document.py
@@ -81,7 +81,7 @@ def insert_mask_layer(
     def set_layer_content(self, layer: krita.Node, img: Image, bounds: Bounds, make_visible=True):
         raise NotImplementedError
 
-    def create_group_layer(self, name: str):
+    def create_group_layer(self, name: str, parent: krita.Node | None = None) -> krita.Node:
         raise NotImplementedError
 
     def hide_layer(self, layer: krita.Node):
@@ -324,16 +324,27 @@ def set_layer_content(self, layer: krita.Node, img: Image, bounds: Bounds, make_
         self.refresh(layer)
         return layer
 
-    def create_group_layer(self, name: str):
+    def create_group_layer(self, name: str, parent: krita.Node | None = None):
+        create_paint_layer = parent is None
         group = self._doc.createGroupLayer(name)
-        active = self._doc.activeNode()
-        parent = active.parentNode()
-        if active.type() != "grouplayer" and parent.parentNode() is not None:
-            active = parent
-            parent = parent.parentNode()
+        if parent is None:
+            active = self._doc.activeNode()
+            parent = active.parentNode()
+            if active.type() != "grouplayer" and parent.parentNode() is not None:
+                active = parent
+                parent = parent.parentNode()
+        else:
+            children = parent.childNodes()
+            active = None if not children else children[0]
+            for child in children:
+                if child.type() == "grouplayer":
+                    break
+                active = child
         parent.addChildNode(group, active)
-        paint = self._doc.createNode("Paint Layer", "paintlayer")
-        group.addChildNode(paint, None)
+        if create_paint_layer:
+            paint = self._doc.createNode("Paint Layer", "paintlayer")
+            group.addChildNode(paint, None)
+        return group
 
     def hide_layer(self, layer: krita.Node):
         layer.setVisible(False)
@@ -443,10 +454,11 @@ def _traverse_layers(node: krita.Node, type_filter=None):
 
 def _find_layer_above(doc: krita.Document, layer_below: krita.Node | None):
     if layer_below:
-        nodes = doc.rootNode().childNodes()
+        nodes = layer_below.parentNode().childNodes()
         index = nodes.index(layer_below)
         if index >= 1:
             return nodes[index - 1]
+        return layer_below
     return None
 
 
@@ -553,7 +565,12 @@ def update(self):
         self.changed.emit()
 
     def find(self, id: QUuid):
-        return next((l.node for l in self._layers if l.id == id), None)
+        if self._doc is None:
+            return None
+        root = self._doc.rootNode()
+        if root.uniqueId() == id:
+            return root
+        return next((l for l in _traverse_layers(root) if l.uniqueId() == id), None)
 
     def updated(self):
         self.update()
diff --git a/ai_diffusion/image.py b/ai_diffusion/image.py
index 87f00ef3..ed02a91e 100644
--- a/ai_diffusion/image.py
+++ b/ai_diffusion/image.py
@@ -183,6 +183,18 @@ def minimum_size(bounds: "Bounds", min_size: int, max_extent: Extent):
         )
         return Bounds.clamp(result, max_extent)
 
+    @staticmethod
+    def intersection(a: "Bounds", b: "Bounds"):
+        x = max(a.x, b.x)
+        y = max(a.y, b.y)
+        width = min(a.x + a.width, b.x + b.width) - x
+        height = min(a.y + a.height, b.y + b.height) - y
+        return Bounds(x, y, max(0, width), max(0, height))
+
+    @property
+    def area(self):
+        return self.width * self.height
+
     def relative_to(self, reference: "Bounds"):
         """Return bounds relative to another bounds."""
         return Bounds(self.x - reference.x, self.y - reference.y, self.width, self.height)
@@ -360,6 +372,12 @@ def make_opaque(self, background=Qt.GlobalColor.white):
     def invert(self):
         self._qimage.invertPixels()
 
+    def average(self):
+        assert self.is_mask
+        avg = Image.scale(self, Extent(1, 1)).pixel(0, 0)
+        avg = avg[0] if isinstance(avg, tuple) else avg
+        return avg / 255
+
     @property
     def data(self):
         self.to_krita_format()
diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py
index 86d1af25..c306f440 100644
--- a/ai_diffusion/jobs.py
+++ b/ai_diffusion/jobs.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, fields, field
 from datetime import datetime
 from enum import Enum, Flag
-from typing import Deque, NamedTuple
+from typing import Any, Deque, NamedTuple
 from PyQt5.QtCore import QObject, pyqtSignal
 
 from .image import Bounds, ImageCollection
@@ -32,6 +32,7 @@ class JobKind(Enum):
 class JobRegion:
     layer_id: str
     prompt: str
+    is_background: bool = False
 
 
 @dataclass
@@ -45,6 +46,12 @@ class JobParams:
     frame: tuple[int, int, int] = (0, 0, 0)
     animation_id: str = ""
 
+    @staticmethod
+    def from_dict(data: dict[str, Any]):
+        data["bounds"] = Bounds(*data["bounds"])
+        data["regions"] = [JobRegion(**r) for r in data.get("regions", [])]
+        return JobParams(**data)
+
     @classmethod
     def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None):
         if a is None or b is None:
diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py
index 67f11cee..0bc54ac8 100644
--- a/ai_diffusion/model.py
+++ b/ai_diffusion/model.py
@@ -151,47 +151,88 @@ def create_region(self):
         doc = self._model.document
         doc.create_group_layer(f"Region {len(self)}")
 
-    def to_api(self, parent_layer_id: QUuid | None, bounds: Bounds | None = None):
+    def to_api(self, parent_layer_id: QUuid | None, bounds: Bounds):
+        result = ConditioningInput(
+            positive=self.root.prompt,
+            negative=self.root.negative_prompt,
+            control=[c.to_api(bounds) for c in self.root.control],
+        )
+        if len(self._regions) == 0:
+            return result, []
+
         # Assemble all regions by finding group layers which are direct children of the parent layer.
-        # Ignore regions with no prompt or control layers.
+        # Filter out regions with:
+        # * no content (empty mask)
+        # * no prompt or control layers
+        # * less than 10% overlap (estimate based on bounding box)
         layers = self._model.layers
-        parent_layer = layers.find(parent_layer_id) if parent_layer_id else layers.root
+        parent_layer = ensure(layers.find(parent_layer_id)) if parent_layer_id else layers.root
         api_regions: list[RegionInput] = []
         job_regions: list[JobRegion] = []
         for layer in layers:
             if layer.type() == "grouplayer" and layer.parentNode() == parent_layer:
+                layer_bounds = _region_layer_bounds(layer)
+                if layer_bounds.area == 0:
+                    print(f"Skipping empty region {layer.name()}")
+                    continue
+
                 region = self._lookup_region(layer.uniqueId())
-                if region.prompt != "" or len(region.control) > 0:
-                    api_regions.append(region.to_api(bounds))
-                    job_regions.append(JobRegion(region.layer_id, region.prompt))
+                if region.prompt == "" and len(region.control) == 0:
+                    continue
+
+                overlap_rough = Bounds.intersection(bounds, layer_bounds).area / bounds.area
+                if overlap_rough < 0.1:
+                    print(f"Skipping region {region.prompt[:10]}: overlap is {overlap_rough}")
+                    continue
+
+                api_regions.append(region.to_api(bounds))
+                job_regions.append(JobRegion(region.layer_id, region.prompt))
 
         # Remove from each region mask any overlapping areas from regions above it.
         accumulated_mask = None
-        for region in reversed(api_regions):
+        for i in range(len(api_regions) - 1, -1, -1):
+            region = api_regions[i]
+            mask = region.mask
             if accumulated_mask is None:
-                accumulated_mask = region.mask
+                accumulated_mask = Image.copy(region.mask)
             else:
-                current = region.mask
-                region.mask = Image.mask_subtract(region.mask, accumulated_mask)
-                accumulated_mask = Image.mask_add(accumulated_mask, current)
-
-        # If the regions don't cover the entire image, add a final region for the remaining area.
-        if accumulated_mask is not None:
-            average = Image.scale(accumulated_mask, Extent(1, 1)).pixel(0, 0)
-            fully_covered = isinstance(average, tuple) and average[0] >= 254
-            if not fully_covered:
-                accumulated_mask.invert()
-                api_regions.append(RegionInput(accumulated_mask, self.root.prompt))
-
-        return (
-            ConditioningInput(
-                positive=self.root.prompt,
-                negative=self.root.negative_prompt,
-                control=[c.to_api(bounds) for c in self.root.control],
-                regions=api_regions,
-            ),
-            job_regions,
-        )
+                mask = Image.mask_subtract(mask, accumulated_mask)
+
+            coverage = mask.average()
+            if coverage > 0.9:
+                # Single region covers (almost) entire image, don't use regional conditioning.
+                print(f"Using single region {region.positive[:10]}: coverage is {coverage}")
+                result.positive = workflow.merge_prompt(region.positive, result.positive)
+                result.control += region.control
+                return result, [job_regions[i]]
+            elif coverage < 0.1:
+                # Region has less than 10% coverage, remove it.
+                print(f"Skipping region {region.positive[:10]}: coverage is {coverage}")
+                api_regions.pop(i)
+                job_regions.pop(i)
+            else:
+                # Accumulate mask for next region, and store modified mask.
+                accumulated_mask = Image.mask_add(accumulated_mask, region.mask)
+                region.mask = mask
+
+        # If there are no regions left, don't use regional conditioning.
+        if len(api_regions) == 0:
+            result.positive = workflow.merge_prompt(self.root.prompt, "")
+            return result, []
+
+        # If the region(s) don't cover the entire image, add a final region for the remaining area.
+        assert accumulated_mask is not None, "Expecting at least one region mask"
+        total_coverage = accumulated_mask.average()
+        if total_coverage < 1:
+            print(f"Adding background region: total coverage is {total_coverage}")
+            accumulated_mask.invert()
+            api_regions.append(RegionInput(accumulated_mask, "background"))
+            job_regions.append(
+                JobRegion(parent_layer.uniqueId().toString(), "background", is_background=True)
+            )
+
+        result.regions = api_regions
+        return result, job_regions
 
     def siblings(self, region: Region):
         def get_regions(layers: list[krita.Node]):
@@ -264,6 +305,21 @@ def _layer_id_str(a: QUuid | str | None):
     return a
 
 
+def _region_layer_bounds(layer: krita.Node):
+    layer_bounds = Bounds.from_qrect(layer.bounds())
+    for child in layer.childNodes():
+        if child.type() == "transparencymask":
+            mask_sel = krita.Selection()
+            data = child.pixelData(*layer_bounds)
+            mask_sel.setPixelData(data, *layer_bounds)
+            mask_sel_bounds = Bounds(
+                mask_sel.x(), mask_sel.y(), mask_sel.width(), mask_sel.height()
+            )
+            return mask_sel_bounds
+
+    return layer_bounds
+
+
 class Model(QObject, ObservableProperties):
     """Represents diffusion workflows for a specific Krita document. Stores all inputs related to
     image generation. Launches generation jobs. Listens to server messages and keeps a
@@ -604,17 +660,29 @@ def apply_result(self, job_id: str, index: int):
         else:
             img = job.results[index]
             for region in job.params.regions:
-                if region_layer := self.layers.find(QUuid(region.layer_id)):
-                    if not any(l.parentNode() == region_layer for l in self.layers.masks):
-                        mask = self._doc.get_layer_mask(region_layer, job.params.bounds)
-                        self._doc.insert_mask_layer(
-                            "Transparency Mask", mask, job.params.bounds, region_layer
-                        )
-
-                    self._doc.insert_layer(
-                        f"[Generated] {region.prompt}", img, job.params.bounds, parent=region_layer
+                region_layer = self.layers.find(QUuid(region.layer_id)) or self.layers.root
+                has_layers = len(region_layer.childNodes()) > 0
+                has_mask = any(l.parentNode() == region_layer for l in self.layers.masks)
+                if has_layers and not has_mask:
+                    mask = self._doc.get_layer_mask(region_layer, job.params.bounds)
+                    self._doc.insert_mask_layer(
+                        "Transparency Mask", mask, job.params.bounds, region_layer
                     )
+                below = None
+                if region.is_background:
+                    for node in region_layer.childNodes():
+                        if node.type() == "grouplayer":
+                            below = node
+                            break
+                self._doc.insert_layer(
+                    f"[Generated] {region.prompt}",
+                    img,
+                    job.params.bounds,
+                    below=below,
+                    parent=region_layer,
+                )
+
         if self._layer:
             self._layer.remove()
             self._layer = None
diff --git a/ai_diffusion/persistence.py b/ai_diffusion/persistence.py
index d98af553..a6f50b73 100644
--- a/ai_diffusion/persistence.py
+++ b/ai_diffusion/persistence.py
@@ -28,8 +28,7 @@ class _HistoryResult:
 
     @staticmethod
    def from_dict(data: dict[str, Any]):
-        data["params"]["bounds"] = Bounds(*data["params"]["bounds"])
-        data["params"] = JobParams(**data["params"])
+        data["params"] = JobParams.from_dict(data["params"])
         return _HistoryResult(**data)
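
Illustrative sketches for review follow; they are not part of the patch. First, the new 10% overlap filter in `Regions.to_api` only compares bounding boxes via the `Bounds.intersection` and `Bounds.area` helpers added to image.py, so regions that clearly lie outside the generation bounds are rejected before their masks are assembled. A minimal standalone sketch of that estimate (the `Bounds` tuple below is a stand-in, not the plugin's class):

```python
from typing import NamedTuple


class Bounds(NamedTuple):
    # Stand-in for ai_diffusion.image.Bounds, only what the overlap estimate needs.
    x: int
    y: int
    width: int
    height: int

    @property
    def area(self) -> int:
        return self.width * self.height

    @staticmethod
    def intersection(a: "Bounds", b: "Bounds") -> "Bounds":
        x = max(a.x, b.x)
        y = max(a.y, b.y)
        width = min(a.x + a.width, b.x + b.width) - x
        height = min(a.y + a.height, b.y + b.height) - y
        return Bounds(x, y, max(0, width), max(0, height))


# A region layer whose bounding box covers less than 10% of the generation
# bounds is skipped before its mask is ever turned into a RegionInput.
generation = Bounds(0, 0, 1024, 1024)
region_layer = Bounds(900, 900, 200, 200)  # mostly outside the generated area
overlap = Bounds.intersection(generation, region_layer).area / generation.area
print(overlap)        # ~0.015
print(overlap < 0.1)  # True -> region is skipped
```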
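Second, the reworked mask pass in `to_api` walks regions from the top of the layer stack down, subtracts everything already claimed by higher regions, and classifies each region by its remaining coverage: above 0.9 it collapses to a single prompt, below 0.1 it is dropped, otherwise it is kept and accumulated. The same control flow, sketched with numpy arrays standing in for the plugin's mask images (the 0.9/0.1 thresholds come from the diff, everything else is illustrative):

```python
import numpy as np


def assemble_regions(masks: list[np.ndarray]) -> list[np.ndarray] | None:
    """masks: coverage masks in [0, 1], ordered bottom to top like the layer stack."""
    accumulated = None
    kept = list(masks)
    for i in range(len(kept) - 1, -1, -1):                    # top-most region first
        mask = kept[i]
        if accumulated is None:
            accumulated = mask.copy()                         # nothing above the top region to subtract
        else:
            mask = np.clip(mask - accumulated, 0.0, 1.0)      # drop areas claimed by regions above
        coverage = float(mask.mean())
        if coverage > 0.9:
            return None    # one region covers (almost) everything: no regional conditioning
        elif coverage < 0.1:
            kept.pop(i)    # region is hidden behind the ones above it
        else:
            accumulated = np.clip(accumulated + kept[i], 0.0, 1.0)
            kept[i] = mask  # keep the mask with overlaps removed
    return kept


h = w = 64
left = np.zeros((h, w))
left[:, : w // 2] = 1.0                     # bottom region: left half of the canvas
top = np.ones((h, w))                       # top region: whole canvas
print(assemble_regions([left, top]))        # None -> collapses to a single prompt
```

In the real method, the remaining gap is then inverted and appended as a "background" region whenever the accumulated coverage stays below 1.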
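Finally, `JobParams.from_dict` together with the defaulted `is_background` flag keeps previously persisted history entries loadable: a stored dict without a `regions` key, or with region dicts that predate the flag, still deserializes. A self-contained round-trip sketch using trimmed stand-ins for the real dataclasses (field lists abbreviated):

```python
from dataclasses import asdict, dataclass, field
from typing import Any, NamedTuple


class Bounds(NamedTuple):  # stand-in, only the fields needed here
    x: int
    y: int
    width: int
    height: int


@dataclass
class JobRegion:  # stand-in mirroring ai_diffusion.jobs.JobRegion
    layer_id: str
    prompt: str
    is_background: bool = False


@dataclass
class JobParams:  # trimmed stand-in for ai_diffusion.jobs.JobParams
    bounds: Bounds
    prompt: str = ""
    regions: list[JobRegion] = field(default_factory=list)

    @staticmethod
    def from_dict(data: dict[str, Any]) -> "JobParams":
        data["bounds"] = Bounds(*data["bounds"])
        data["regions"] = [JobRegion(**r) for r in data.get("regions", [])]
        return JobParams(**data)


# An entry written by an older version: bounds stored as a plain list, no regions at all.
legacy = {"bounds": [0, 0, 512, 512], "prompt": "forest"}
print(JobParams.from_dict(legacy))

# A new entry round-trips through asdict() and back, including the is_background flag.
params = JobParams(Bounds(0, 0, 512, 512), "forest", [JobRegion("id-1", "tree", True)])
print(JobParams.from_dict(asdict(params)))
```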