Skip to content

Commit

Permalink
Extend region filter logic:
Browse files Browse the repository at this point in the history
- do a rough bounding box check and remove regions with <% overlap
- do a per-pixel check and remove regions <% coverage
- use regions with >% coverage  exclusive (disable regional prompt)
- add background region for insuficcient coverage
  • Loading branch information
Acly committed May 9, 2024
1 parent 7eb7b72 commit 89206c9
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 52 deletions.
39 changes: 28 additions & 11 deletions ai_diffusion/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def insert_mask_layer(
def set_layer_content(self, layer: krita.Node, img: Image, bounds: Bounds, make_visible=True):
raise NotImplementedError

def create_group_layer(self, name: str):
def create_group_layer(self, name: str, parent: krita.Node | None = None) -> krita.Node:
raise NotImplementedError

def hide_layer(self, layer: krita.Node):
Expand Down Expand Up @@ -324,16 +324,27 @@ def set_layer_content(self, layer: krita.Node, img: Image, bounds: Bounds, make_
self.refresh(layer)
return layer

def create_group_layer(self, name: str):
def create_group_layer(self, name: str, parent: krita.Node | None = None):
create_paint_layer = parent is None
group = self._doc.createGroupLayer(name)
active = self._doc.activeNode()
parent = active.parentNode()
if active.type() != "grouplayer" and parent.parentNode() is not None:
active = parent
parent = parent.parentNode()
if parent is None:
active = self._doc.activeNode()
parent = active.parentNode()
if active.type() != "grouplayer" and parent.parentNode() is not None:
active = parent
parent = parent.parentNode()
else:
children = parent.childNodes()
active = None if not children else children[0]
for child in children:
if child.type() == "grouplayer":
break
active = child
parent.addChildNode(group, active)
paint = self._doc.createNode("Paint Layer", "paintlayer")
group.addChildNode(paint, None)
if create_paint_layer:
paint = self._doc.createNode("Paint Layer", "paintlayer")
group.addChildNode(paint, None)
return group

def hide_layer(self, layer: krita.Node):
layer.setVisible(False)
Expand Down Expand Up @@ -443,10 +454,11 @@ def _traverse_layers(node: krita.Node, type_filter=None):

def _find_layer_above(doc: krita.Document, layer_below: krita.Node | None):
if layer_below:
nodes = doc.rootNode().childNodes()
nodes = layer_below.parentNode().childNodes()
index = nodes.index(layer_below)
if index >= 1:
return nodes[index - 1]
return layer_below
return None


Expand Down Expand Up @@ -553,7 +565,12 @@ def update(self):
self.changed.emit()

def find(self, id: QUuid):
return next((l.node for l in self._layers if l.id == id), None)
if self._doc is None:
return None
root = self._doc.rootNode()
if root.uniqueId() == id:
return root
return next((l for l in _traverse_layers(root) if l.uniqueId() == id), None)

def updated(self):
self.update()
Expand Down
18 changes: 18 additions & 0 deletions ai_diffusion/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ def minimum_size(bounds: "Bounds", min_size: int, max_extent: Extent):
)
return Bounds.clamp(result, max_extent)

@staticmethod
def intersection(a: "Bounds", b: "Bounds"):
x = max(a.x, b.x)
y = max(a.y, b.y)
width = min(a.x + a.width, b.x + b.width) - x
height = min(a.y + a.height, b.y + b.height) - y
return Bounds(x, y, max(0, width), max(0, height))

@property
def area(self):
return self.width * self.height

def relative_to(self, reference: "Bounds"):
"""Return bounds relative to another bounds."""
return Bounds(self.x - reference.x, self.y - reference.y, self.width, self.height)
Expand Down Expand Up @@ -360,6 +372,12 @@ def make_opaque(self, background=Qt.GlobalColor.white):
def invert(self):
self._qimage.invertPixels()

def average(self):
assert self.is_mask
avg = Image.scale(self, Extent(1, 1)).pixel(0, 0)
avg = avg[0] if isinstance(avg, tuple) else avg
return avg / 255

@property
def data(self):
self.to_krita_format()
Expand Down
9 changes: 8 additions & 1 deletion ai_diffusion/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass, fields, field
from datetime import datetime
from enum import Enum, Flag
from typing import Deque, NamedTuple
from typing import Any, Deque, NamedTuple
from PyQt5.QtCore import QObject, pyqtSignal

from .image import Bounds, ImageCollection
Expand Down Expand Up @@ -32,6 +32,7 @@ class JobKind(Enum):
class JobRegion:
layer_id: str
prompt: str
is_background: bool = False


@dataclass
Expand All @@ -45,6 +46,12 @@ class JobParams:
frame: tuple[int, int, int] = (0, 0, 0)
animation_id: str = ""

@staticmethod
def from_dict(data: dict[str, Any]):
data["bounds"] = Bounds(*data["bounds"])
data["regions"] = [JobRegion(**r) for r in data.get("regions", [])]
return JobParams(**data)

@classmethod
def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None):
if a is None or b is None:
Expand Down
144 changes: 106 additions & 38 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,47 +151,88 @@ def create_region(self):
doc = self._model.document
doc.create_group_layer(f"Region {len(self)}")

def to_api(self, parent_layer_id: QUuid | None, bounds: Bounds | None = None):
def to_api(self, parent_layer_id: QUuid | None, bounds: Bounds):
result = ConditioningInput(
positive=self.root.prompt,
negative=self.root.negative_prompt,
control=[c.to_api(bounds) for c in self.root.control],
)
if len(self._regions) == 0:
return result, []

# Assemble all regions by finding group layers which are direct children of the parent layer.
# Ignore regions with no prompt or control layers.
# Filter out regions with:
# * no content (empty mask)
# * no prompt or control layers
# * less than 10% overlap (esimate based on bounding box)
layers = self._model.layers
parent_layer = layers.find(parent_layer_id) if parent_layer_id else layers.root
parent_layer = ensure(layers.find(parent_layer_id)) if parent_layer_id else layers.root
api_regions: list[RegionInput] = []
job_regions: list[JobRegion] = []
for layer in layers:
if layer.type() == "grouplayer" and layer.parentNode() == parent_layer:
layer_bounds = _region_layer_bounds(layer)
if layer_bounds.area == 0:
print(f"Skipping empty region {layer.name()}")
continue

region = self._lookup_region(layer.uniqueId())
if region.prompt != "" or len(region.control) > 0:
api_regions.append(region.to_api(bounds))
job_regions.append(JobRegion(region.layer_id, region.prompt))
if region.prompt == "" and len(region.control) == 0:
continue

overlap_rough = Bounds.intersection(bounds, layer_bounds).area / bounds.area
if overlap_rough < 0.1:
print(f"Skipping region {region.prompt[:10]}: overlap is {overlap_rough}")
continue

api_regions.append(region.to_api(bounds))
job_regions.append(JobRegion(region.layer_id, region.prompt))

# Remove from each region mask any overlapping areas from regions above it.
accumulated_mask = None
for region in reversed(api_regions):
for i in range(len(api_regions) - 1, -1, -1):
region = api_regions[i]
mask = region.mask
if accumulated_mask is None:
accumulated_mask = region.mask
accumulated_mask = Image.copy(region.mask)
else:
current = region.mask
region.mask = Image.mask_subtract(region.mask, accumulated_mask)
accumulated_mask = Image.mask_add(accumulated_mask, current)

# If the regions don't cover the entire image, add a final region for the remaining area.
if accumulated_mask is not None:
average = Image.scale(accumulated_mask, Extent(1, 1)).pixel(0, 0)
fully_covered = isinstance(average, tuple) and average[0] >= 254
if not fully_covered:
accumulated_mask.invert()
api_regions.append(RegionInput(accumulated_mask, self.root.prompt))

return (
ConditioningInput(
positive=self.root.prompt,
negative=self.root.negative_prompt,
control=[c.to_api(bounds) for c in self.root.control],
regions=api_regions,
),
job_regions,
)
mask = Image.mask_subtract(mask, accumulated_mask)

coverage = mask.average()
if coverage > 0.9:
# Single region covers (almost) entire image, don't use regional conditioning.
print(f"Using single region {region.positive[:10]}: coverage is {coverage}")
result.positive = workflow.merge_prompt(region.positive, result.positive)
result.control += region.control
return result, [job_regions[i]]
elif coverage < 0.1:
# Region has less than 10% coverage, remove it.
print(f"Skipping region {region.positive[:10]}: coverage is {coverage}")
api_regions.pop(i)
job_regions.pop(i)
else:
# Accumulate mask for next region, and store modified mask.
accumulated_mask = Image.mask_add(accumulated_mask, region.mask)
region.mask = mask

# If there are no regions left, don't use regional conditioning.
if len(api_regions) == 0:
result.positive = workflow.merge_prompt(self.root.prompt, "")
return result, []

# If the region(s) don't cover the entire image, add a final region for the remaining area.
assert accumulated_mask is not None, "Expecting at least one region mask"
total_coverage = accumulated_mask.average()
if total_coverage < 1:
print(f"Adding background region: total coverage is {total_coverage}")
accumulated_mask.invert()
api_regions.append(RegionInput(accumulated_mask, "background"))
job_regions.append(
JobRegion(parent_layer.uniqueId().toString(), "background", is_background=True)
)

result.regions = api_regions
return result, job_regions

def siblings(self, region: Region):
def get_regions(layers: list[krita.Node]):
Expand Down Expand Up @@ -264,6 +305,21 @@ def _layer_id_str(a: QUuid | str | None):
return a


def _region_layer_bounds(layer: krita.Node):
layer_bounds = Bounds.from_qrect(layer.bounds())
for child in layer.childNodes():
if child.type() == "transparencymask":
mask_sel = krita.Selection()
data = child.pixelData(*layer_bounds)
mask_sel.setPixelData(data, *layer_bounds)
mask_sel_bounds = Bounds(
mask_sel.x(), mask_sel.y(), mask_sel.width(), mask_sel.height()
)
return mask_sel_bounds

return layer_bounds


class Model(QObject, ObservableProperties):
"""Represents diffusion workflows for a specific Krita document. Stores all inputs related to
image generation. Launches generation jobs. Listens to server messages and keeps a
Expand Down Expand Up @@ -604,17 +660,29 @@ def apply_result(self, job_id: str, index: int):
else:
img = job.results[index]
for region in job.params.regions:
if region_layer := self.layers.find(QUuid(region.layer_id)):
if not any(l.parentNode() == region_layer for l in self.layers.masks):
mask = self._doc.get_layer_mask(region_layer, job.params.bounds)
self._doc.insert_mask_layer(
"Transparency Mask", mask, job.params.bounds, region_layer
)

self._doc.insert_layer(
f"[Generated] {region.prompt}", img, job.params.bounds, parent=region_layer
region_layer = self.layers.find(QUuid(region.layer_id)) or self.layers.root
has_layers = len(region_layer.childNodes()) > 0
has_mask = any(l.parentNode() == region_layer for l in self.layers.masks)
if has_layers and not has_mask:
mask = self._doc.get_layer_mask(region_layer, job.params.bounds)
self._doc.insert_mask_layer(
"Transparency Mask", mask, job.params.bounds, region_layer
)

below = None
if region.is_background:
for node in region_layer.childNodes():
if node.type() == "grouplayer":
below = node
break
self._doc.insert_layer(
f"[Generated] {region.prompt}",
img,
job.params.bounds,
below=below,
parent=region_layer,
)

if self._layer:
self._layer.remove()
self._layer = None
Expand Down
3 changes: 1 addition & 2 deletions ai_diffusion/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ class _HistoryResult:

@staticmethod
def from_dict(data: dict[str, Any]):
data["params"]["bounds"] = Bounds(*data["params"]["bounds"])
data["params"] = JobParams(**data["params"])
data["params"] = JobParams.from_dict(data["params"])
return _HistoryResult(**data)


Expand Down

0 comments on commit 89206c9

Please sign in to comment.