Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,40 @@ def test_macenko_torch():
# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_numpy.flatten(), result_torch.flatten(), decimal=2, verbose=True)

def test_multitarget_macenko_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
target = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/target.png")), cv2.COLOR_BGR2RGB), (size, size))
to_transform = cv2.resize(cv2.cvtColor(cv2.imread(os.path.join(curr_file_path, "../data/source.png")), cv2.COLOR_BGR2RGB), (size, size))

# setup preprocessing and preprocess image to be normalized
T = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255)
])
target = T(target)
t_to_transform = T(to_transform)

# initialize normalizers for each backend and fit to target image
single_normalizer = torchstain.normalizers.MacenkoNormalizer(backend="torch")
single_normalizer.fit(target)

multi_normalizer = torchstain.normalizers.MultiMacenkoNormalizer(backend="torch", norm_mode="avg-post")
multi_normalizer.fit([target, target, target])


# transform
result_single, _, _ = single_normalizer.normalize(I=t_to_transform, stains=True)
result_multi, _, _ = multi_normalizer.normalize(I=t_to_transform, stains=True)

# convert to numpy and set dtype
result_single = result_single.numpy().astype("float32") / 255.
result_multi = result_multi.numpy().astype("float32") / 255.

# assess whether the normalized images are identical across backends
np.testing.assert_almost_equal(result_single.flatten(), result_multi.flatten(), decimal=2, verbose=True)


def test_reinhard_torch():
size = 1024
curr_file_path = os.path.dirname(os.path.realpath(__file__))
Expand Down
1 change: 1 addition & 0 deletions torchstain/base/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .he_normalizer import HENormalizer
from .macenko import MacenkoNormalizer
from .multitarget import MultiMacenkoNormalizer
from .reinhard import ReinhardNormalizer
14 changes: 9 additions & 5 deletions torchstain/base/normalizers/multitarget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
def MultiMacenkoNormalizer(backend='torch', **kwargs):
if backend == 'torch':
from torchstain.torch.normalizers.multitarget import MultiMacenkoNormalizer
return MultiMacenkoNormalizer(**kwargs)
def MultiMacenkoNormalizer(backend="torch", **kwargs):
if backend == "numpy":
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for NumPy backend")
elif backend == "torch":
from torchstain.torch.normalizers import TorchMultiMacenkoNormalizer
return TorchMultiMacenkoNormalizer(**kwargs)
elif backend == "tensorflow":
raise NotImplementedError("MultiMacenkoNormalizer is not implemented for TensorFlow backend")
else:
raise Exception(f'Unsupported backend {backend}')
raise Exception(f"Unsupported backend {backend}")
2 changes: 1 addition & 1 deletion torchstain/torch/normalizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from torchstain.torch.normalizers.macenko import TorchMacenkoNormalizer
from torchstain.torch.normalizers.multitarget import MultiMacenkoNormalizer
from torchstain.torch.normalizers.multitarget import TorchMultiMacenkoNormalizer
from torchstain.torch.normalizers.reinhard import TorchReinhardNormalizer
20 changes: 11 additions & 9 deletions torchstain/torch/normalizers/multitarget.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import torch
from torchstain.torch.utils import cov, percentile

"""
Implementation of the multi-target normalizer from the paper: https://arxiv.org/pdf/2406.02077
"""
class MultiMacenkoNormalizer:
def __init__(self, norm_mode='avg-post'):
class TorchMultiMacenkoNormalizer:
def __init__(self, norm_mode="avg-post"):
self.norm_mode = norm_mode
self.HERef = torch.tensor([[0.5626, 0.2159],
[0.7201, 0.8012],
[0.4062, 0.5581]])
self.maxCRef = torch.tensor([1.9705, 1.0308])
self.updated_lstsq = hasattr(torch.linalg, 'lstsq')
self.updated_lstsq = hasattr(torch.linalg, "lstsq")

def __convert_rgb2od(self, I, Io, beta):
I = I.permute(1, 2, 0)
Expand Down Expand Up @@ -48,7 +49,8 @@ def __find_concentration(self, OD, HE):
def __compute_matrices_single(self, I, Io, alpha, beta):
OD, ODhat = self.__convert_rgb2od(I, Io=Io, beta=beta)

_, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
# _, eigvecs = torch.symeig(cov(ODhat.T), eigenvectors=True)
_, eigvecs = torch.linalg.eigh(cov(ODhat.T), UPLO='U')
eigvecs = eigvecs[:, [1, 2]]

HE = self.__find_HE(ODhat, eigvecs, alpha)
Expand All @@ -59,15 +61,15 @@ def __compute_matrices_single(self, I, Io, alpha, beta):
return HE, C, maxC

def fit(self, Is, Io=240, alpha=1, beta=0.15):
if self.norm_mode == 'avg-post':
if self.norm_mode == "avg-post":
HEs, _, maxCs = zip(*(
self.__compute_matrices_single(I, Io, alpha, beta)
for I in Is
))

self.HERef = torch.stack(HEs).mean(dim=0)
self.maxCRef = torch.stack(maxCs).mean(dim=0)
elif self.norm_mode == 'concat':
elif self.norm_mode == "concat":
ODs, ODhats = zip(*(
self.__convert_rgb2od(I, Io, beta)
for I in Is
Expand All @@ -83,7 +85,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode == 'avg-pre':
elif self.norm_mode == "avg-pre":
ODs, ODhats = zip(*(
self.__convert_rgb2od(I, Io, beta)
for I in Is
Expand All @@ -100,7 +102,7 @@ def fit(self, Is, Io=240, alpha=1, beta=0.15):
maxCs = torch.stack([percentile(C[0, :], 99), percentile(C[1, :], 99)])
self.HERef = HE
self.maxCRef = maxCs
elif self.norm_mode == 'fixed-single' or self.norm_mode == 'stochastic-single':
elif self.norm_mode == "fixed-single" or self.norm_mode == "stochastic-single":
# single img
self.HERef, _, self.maxCRef = self.__compute_matrices_single(Is[0], Io, alpha, beta)
else:
Expand All @@ -127,4 +129,4 @@ def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
E[E > 255] = 255
E = E.T.reshape(h, w, c).int()

return Inorm, H, E
return Inorm, H, E