Skip to content

Commit

Permalink
Bug fix for windowed processing (#109)
Browse files Browse the repository at this point in the history
* Mussel bg areas -> 0

* Fix bug with windowed processing
  • Loading branch information
tayden committed May 7, 2024
1 parent be69948 commit 6f60378
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 106 deletions.
8 changes: 4 additions & 4 deletions kelp_o_matic/geotiff_io/geotiff_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
self.profile = src.profile
self.block_shapes = src.block_shapes

self._y0s = list(range(0, self.height, self.stride))
self._x0s = list(range(0, self.width, self.stride))
self._y0s = list(range(0, self.height - self.stride + 1, self.stride))
self._x0s = list(range(0, self.width - self.stride + 1, self.stride))
self.y0x0 = list(itertools.product(self._y0s, self._x0s))

def __len__(self) -> int:
Expand All @@ -62,13 +62,13 @@ def is_top_window(self, window: Window):
return window.row_off == 0

def is_bottom_window(self, window: Window):
return window.row_off == self._y0s[-1]
return window.row_off + window.height >= self.height

def is_left_window(self, window: Window):
return window.col_off == 0

def is_right_window(self, window: Window):
return window.col_off == self._x0s[-1]
return window.col_off + window.width >= self.width

def __getitem__(self, idx: int) -> ("np.ndarray", Window):
window = self.get_window(idx)
Expand Down
149 changes: 107 additions & 42 deletions kelp_o_matic/hann.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import math
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Union
from typing import Type, Annotated

import numpy as np
import rasterio
import torch
from rasterio.windows import Window

Expand Down Expand Up @@ -103,24 +101,21 @@ def _init_wi(size: int, device: torch.device.type) -> torch.Tensor:
class TorchMemoryRegister(object):
def __init__(
self,
image_path: Union[str, Path],
reg_depth: int,
window_size: int,
image_width: Annotated[int, "Width of the image in pixels"],
register_depth: Annotated[int, "Generally equal to the number of classes"],
window_size: Annotated[int, "Moving window size"],
kernel: Type[Kernel],
device: torch.device.type,
):
super().__init__()
self.image_path = Path(image_path)
self.n = reg_depth
self.n = register_depth
self.ws = window_size
self.hws = window_size // 2
self.kernel = kernel(size=window_size, device=device)
self.device = device

# Copy metadata from img
with rasterio.open(str(image_path), "r") as src:
src_width = src.width

self.height = self.ws
self.width = (math.ceil(src_width / self.ws) * self.ws) + self.hws
self.width = (math.ceil(image_width / self.ws) * self.ws) + self.hws
self.register = torch.zeros(
(self.n, self.height, self.width), device=self.device
)
Expand All @@ -131,40 +126,110 @@ def _zero_chip(self):
(self.n, self.hws, self.hws), dtype=torch.float, device=self.device
)

def step(self, new_logits: torch.Tensor, img_window: Window):
# 1. Read data from the registry to update with the new logits
def step(
self,
new_logits: torch.Tensor,
img_window: Window,
*,
top: bool,
bottom: bool,
left: bool,
right: bool,
):
# Read data from the registry to update with the new logits
# |a|b| |
# |c|d| |
with torch.no_grad():
logits_abcd = self.register[
:, :, img_window.col_off : img_window.col_off + self.ws
].clone()
logits_abcd += new_logits

# Update the registry and pop information-complete data
# |c|b| | + pop a
# |0|d| |
logits_a = logits_abcd[:, : self.hws, : self.hws]
logits_c = logits_abcd[:, self.hws :, : self.hws]
logits_c0 = torch.concat([logits_c, self._zero_chip], dim=1)
logits_bd = logits_abcd[:, :, self.hws :]

# write c0
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
logits_c0
)

# write bd
col_off_bd = img_window.col_off + self.hws
self.register[:, :, col_off_bd : col_off_bd + (self.ws - self.hws)] = logits_bd

# Return the information-complete predictions
logits_win = Window(
col_off=img_window.col_off,
row_off=img_window.row_off,
height=min(self.hws, img_window.height),
width=min(self.hws, img_window.width),
)
logits = logits_a[:, : img_window.height, : img_window.width]
logits_abcd += self.kernel(
new_logits, top=top, bottom=bottom, left=left, right=right
)

if right and bottom:
# Need to return entire window
logits_win = img_window
logits = logits_abcd[:, : img_window.height, : img_window.width]

elif right:
# Need to return a and b sections

# Update the registry and pop information-complete data
# |c|d| | + pop a+b
# |0|0| |
logits_ab = logits_abcd[:, : self.hws, :]
logits_cd = logits_abcd[:, self.hws :, :]
logits_00 = torch.concat([self._zero_chip, self._zero_chip], dim=2)

# write cd and 00
self.register[
:, : self.hws, img_window.col_off : img_window.col_off + self.ws
] = logits_cd
self.register[
:, self.hws :, img_window.col_off : img_window.col_off + self.ws
] = logits_00

logits_win = Window(
col_off=img_window.col_off,
row_off=img_window.row_off,
height=min(self.hws, img_window.height),
width=min(self.ws, img_window.width),
)
logits = logits_ab[:, : logits_win.height, : logits_win.width]
elif bottom:
# Need to return a and c sections only

# Update the registry and pop information-complete data
# |0|b| | + pop a+c
# |0|d| |
logits_ac = logits_abcd[:, :, : self.hws]
logits_00 = torch.concat([self._zero_chip, self._zero_chip], dim=1)
logits_bd = logits_abcd[:, :, self.hws :]

# write 00 and bd
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
logits_00 # Not really necessary since this is the last row
)
self.register[
:, :, img_window.col_off + self.hws : img_window.col_off + self.ws
] = logits_bd

logits_win = Window(
col_off=img_window.col_off,
row_off=img_window.row_off,
height=min(self.ws, img_window.height),
width=min(self.hws, img_window.width),
)
logits = logits_ac[:, : img_window.height, : img_window.width]
else:
# Need to return "a" section only

# Update the registry and pop information-complete data
# |c|b| | + pop a
# |0|d| |
logits_a = logits_abcd[:, : self.hws, : self.hws]
logits_c = logits_abcd[:, self.hws :, : self.hws]
logits_c0 = torch.concat([logits_c, self._zero_chip], dim=1)
logits_bd = logits_abcd[:, :, self.hws :]

# write c0
self.register[:, :, img_window.col_off : img_window.col_off + self.hws] = (
logits_c0
)

# write bd
col_off_bd = img_window.col_off + self.hws
self.register[:, :, col_off_bd : col_off_bd + (self.ws - self.hws)] = (
logits_bd
)

logits_win = Window(
col_off=img_window.col_off,
row_off=img_window.row_off,
height=min(self.hws, img_window.height),
width=min(self.hws, img_window.width),
)
logits = logits_a[:, : img_window.height, : img_window.width]

return logits, logits_win
11 changes: 7 additions & 4 deletions kelp_o_matic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ def __init__(
dtype="uint8",
nodata=0,
)
self.kernel = BartlettHannKernel(crop_size, self.model.device)
self.register = TorchMemoryRegister(
self.input_path, self.model.register_depth, crop_size, self.model.device
image_width=self.reader.width,
register_depth=self.model.register_depth,
window_size=crop_size,
kernel=BartlettHannKernel,
device=self.model.device,
)

def __call__(self):
Expand Down Expand Up @@ -107,14 +110,14 @@ def __call__(self):
else:
logits = self.model(crop.unsqueeze(0))[0]

logits = self.kernel(
write_logits, write_window = self.register.step(
logits,
read_window,
top=self.reader.is_top_window(read_window),
bottom=self.reader.is_bottom_window(read_window),
left=self.reader.is_left_window(read_window),
right=self.reader.is_right_window(read_window),
)
write_logits, write_window = self.register.step(logits, read_window)
labels = self.model.post_process(write_logits)

# Write outputs
Expand Down
2 changes: 1 addition & 1 deletion kelp_o_matic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class KelpRGBSpeciesSegmentationModel(_SpeciesSegmentationModel):

class MusselRGBPresenceSegmentationModel(_Model):
register_depth = 1
all_black_val = -1
all_black_val = 0

torchscript_path = (
"UNetPlusPlus_EfficientNetB4_mussel_presence_rgb_jit_dice=0.9269.pt"
Expand Down
25 changes: 12 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6f60378

Please sign in to comment.