In [None]:
# import modules:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from pathlib import Path
import os
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
import monai
import random

In [2]:
import torch
import numpy as np
from monai.transforms import (
    LoadImage,
    Compose,
    AddChannel,
    Transform,
    Transpose,
    ScaleIntensity,
)
from monai.inferers import sliding_window_inference
from compressai.zoo import image_models, models
from compressai.zoo.pretrained import load_pretrained
from aicsimageio import AICSImage
from aicsimageio.writers import OmeTiffWriter
from tqdm.contrib import tenumerate

In [None]:
# Experiment parameters.
#
# Cell line to work on. In the paper, four nuclear structures were tested:
#   - fibrillarin    -> cline = "FBL"
#   - nucleophosmin  -> cline = "NPM1"
#   - lamin B1       -> cline = "LMNB1"
#   - histone H2B    -> cline = "HIST1H2BJ"
cline = "FBL"

# Locations of the 3D data, relative to the directory this notebook runs in:
# one folder per cell line, each with a `train` and a `holdout` sub-folder.
parent_path_3d = Path("../../../../data/labelfree3D", cline)
train_path_3d = parent_path_3d / "train"
holdout_path_3d = parent_path_3d / "holdout"

In [3]:
# define funcs and hyper-parameters:

# Seed every random number generator the notebook touches so that a
# "Restart Kernel -> Run All" reproduces the same numbers.
_SEED = 2023
np.random.seed(_SEED)
random.seed(_SEED)
# torch drives the model forward passes below, so it is seeded as well.
torch.manual_seed(_SEED)
# NOTE: setting PYTHONHASHSEED at runtime does not change hashing in the
# already-running kernel; it is only inherited by child processes started
# afterwards (e.g. the `!python3 ... train.py` cells).
os.environ["PYTHONHASHSEED"] = str(_SEED)


class AverageMeter:
    """Running (optionally weighted) average of a stream of values.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight (number of samples) seen since the last ``reset``.
        avg: ``sum / count``; 0 until the first ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def normalizeItensity(image):
    """Normalize image intensities with MONAI's ``NormalizeIntensity`` transform.

    Args:
        image: numpy array; it is cast to float32 before normalization.

    Returns:
        The normalized image as a numpy array.
    """
    # float32 numpy -> torch tensor, which is then handed to the MONAI transform.
    as_tensor = torch.from_numpy(image.astype(np.float32))
    normalized = monai.transforms.NormalizeIntensity()(as_tensor)
    # Hand the result back in the same container type the caller passed in.
    return normalized.numpy()


def transform_img(image):
    """Cast ``image`` to float32 and divide by 65535 (the largest uint16 value).

    Args:
        image: numpy array of raw intensities.

    Returns:
        float32 numpy array holding ``image / 65535``.
    """
    max_uint16 = 65535
    as_float = image.astype(np.float32)
    return as_float / max_uint16


def compare_images(path1, path2, gt=True):
    """Compute similarity metrics between two 3D images stored on disk.

    Both images are read as ZYX volumes and each one is passed through its own
    ``ScaleIntensity`` call before any metric is computed, so the metrics
    compare the rescaled volumes, not the raw intensities.

    Args:
        path1: path of the first image.
        path2: path of the second image.
        gt: unused; kept so existing callers that pass it keep working.

    Returns:
        Tuple ``(mse_value, ssim_value, psnr_value, corr)``: mean squared
        error, structural similarity, peak signal-to-noise ratio (the latter
        two with ``data_range=1``) and the Pearson correlation coefficient of
        the flattened volumes.
    """
    # Load the two images
    image1 = AICSImage(path1).get_image_data("ZYX")
    image2 = AICSImage(path2).get_image_data("ZYX")
    # Rescale each image independently, then go back to numpy for skimage.
    scaler = ScaleIntensity()
    image1 = scaler(image1).cpu().numpy()
    image2 = scaler(image2).cpu().numpy()
    mse_value = mse(image1, image2)
    ssim_value = ssim(image1, image2, data_range=1)
    psnr_value = psnr(image1, image2, data_range=1)
    corr = np.corrcoef(image1.ravel(), image2.ravel())[0, 1]
    # Return the computed value, not the imported `mse` function itself.
    return mse_value, ssim_value, psnr_value, corr


class Normalize(Transform):
    """MONAI transform that divides its input by 65535.

    This maps uint16-range intensities onto [0, 1].
    """

    def __call__(self, img):
        """Return ``img / 65535.0``."""
        return img / 65535.0


def torch2img(x: torch.Tensor):
    """Convert a tensor to a numpy array with values clamped to [0, 1].

    All size-1 dimensions are squeezed away. The result stays floating point;
    it is NOT rescaled to uint16.

    Args:
        x: input tensor. It is left unmodified.

    Returns:
        numpy array holding the clamped, squeezed values of ``x``.
    """
    # Detach first and clamp out-of-place: the in-place `clamp_` silently
    # overwrote the caller's tensor and raises on leaf tensors that require grad.
    return x.detach().clamp(0, 1).squeeze().cpu().numpy()

## Model Training

Since no pre-trained model is available for the 3D case, we need to train one from scratch. Here we use the `bmshj2018-factorized_3d` model.

- Pre-train with the MSE loss for 50 epochs.

In [None]:
# Stage 1 (pre-training): run train.py on the `train` split of {parent_path_3d}
# (tested on `holdout`) for 50 epochs with `--metric mse`, on GPU (`--cuda`).
# The checkpoint is written to ./pretrain.pth.tar, which the fine-tuning cell
# below loads via `--checkpoint`.
!python3 ../../../../train.py -d {parent_path_3d} \
                    --train_split train \
                    --test_split holdout \
                    --aux-learning-rate 1e-3 \
                    --lambda 0.18 \
                    --epochs 50 \
                    -lr 1e-4 \
                    --batch-size 2 \
                    --model bmshj2018-factorized_3d \
                    --use_3D \
                    --quality 8 \
                    --metric mse \
                    --cuda \
                    --save_path ./pretrain.pth.tar \
                    --seed 2023

- Fine-tune for another 50 epochs using the MS-SSIM loss.

In [None]:

# Stage 2 (fine-tuning): resume from ./pretrain.pth.tar and train for another
# 50 epochs with `--metric ms-ssim`; the result is written to ./fine-tune.pth.tar.
# train.py is addressed with the same relative path as in the pre-training cell
# (four levels up, next to the `data` folder) instead of `../train.py`.
!python3 ../../../../train.py -d {parent_path_3d} \
                    --train_split train \
                    --test_split holdout \
                    --aux-learning-rate 1e-4 \
                    --lambda 220.0 \
                    --epochs 50 \
                    -lr 5e-5 \
                    --batch-size 2 \
                    --use_3D \
                    --model bmshj2018-factorized_3d \
                    --checkpoint ./pretrain.pth.tar \
                    --quality 8 \
                    --metric ms-ssim \
                    --cuda \
                    --save_path ./fine-tune.pth.tar \
                    --seed 2023

## Inference

Instead of using `codec.py`, we run a direct forward pass through the network to get the prediction. We use sliding-window inference to keep memory usage bounded.

In [4]:
# Inference configuration. The model name, metric and quality mirror the
# values passed to train.py in the fine-tuning cell above.
model = "bmshj2018-factorized_3d"
device = torch.device("cpu")
metric = "ms-ssim"
quality = 8
# Look up the model constructor in compressai's model registry.
model_info = models[model]
# Checkpoint written by the fine-tuning stage (`--save_path ./fine-tune.pth.tar`).
checkpoint = "./fine-tune.pth.tar"
# Pre-processing applied to every input file before it is fed to the network.
transform = Compose(
    [
        # Read the file and keep only the image array (no metadata).
        LoadImage(image_only=True),
        # Prepend a channel axis.
        AddChannel(),
        # Swap axes 1 and 3, leaving the channel axis in place.
        # NOTE(review): presumably converts the loader's axis order to Z, Y, X - confirm.
        Transpose(indices=(0, 3, 2, 1)),
        # Divide by 65535 (see the `Normalize` transform defined above).
        Normalize(),
    ]
)
# Load only the weights from the checkpoint, mapped onto `device`.
state_dict = torch.load(checkpoint, map_location=device)["state_dict"]
# NOTE(review): presumably remaps legacy parameter names - confirm against compressai.
state_dict = load_pretrained(state_dict)
# Build the network from the weights, move it to `device`, and switch to eval mode.
net = (
    model_info(quality=quality, metric=metric, pretrained=False)
    .from_state_dict(state_dict)
    .to(device)
    .eval()
)
# NOTE(review): presumably rebuilds the entropy-coder tables - confirm against compressai.
net.update(force=True)

# Global variable to store the call count
call_count = 0


def counter(func):
    """Decorator that counts calls to ``func`` in the global ``call_count``.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that increments ``call_count`` by one and then forwards all
        arguments to ``func``, returning its result. The wrapper keeps the
        name and docstring of ``func``.
    """
    import functools  # local import keeps this cell self-contained

    # Without `wraps`, the decorated function would be reported as "wrapper"
    # and lose its docstring.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        global call_count  # the counter lives at module (notebook) level
        call_count += 1
        return func(*args, **kwargs)

    return wrapper


@counter
def infer(img):
    """Run ``net`` on one batch, without gradients, and return its ``x_hat`` output.

    img: (tensor) N x C x Z x H x W
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        outputs = net(img)
    return outputs["x_hat"]

In [5]:
# Despite its name, `input_dir` is the sorted list of hold-out files whose
# names end in "IM.tiff", not a directory.
input_dir = list(holdout_path_3d.glob("*IM.tiff"))
input_dir.sort()
# Reconstructions go into a sub-folder of the hold-out folder; create it if needed.
output_dir = holdout_path_3d.joinpath("bmshj2018-factorized_3d_ms-ssim_8")
output_dir.mkdir(parents=True, exist_ok=True)
# The evaluation cell reads the reconstructions through this alias.
compress_dir = output_dir

In [27]:
# Reconstruct every hold-out volume with the network and save the result.
# The loop variable is `input_path` rather than `input`, so the builtin
# `input()` is not shadowed for the rest of the kernel session.
for i, input_path in tenumerate(input_dir):
    # Load + pre-process, drop the channel axis added by the transform, then
    # add batch and channel axes back and move the tensor to `device`.
    img = transform(input_path)[0].unsqueeze(0).unsqueeze(0).to(device)
    # Run the network patch by patch; the stitched output is kept on the CPU.
    pred = sliding_window_inference(
        inputs=img,
        predictor=infer,
        device=torch.device("cpu"),
        roi_size=[32, 256, 256],
        sw_batch_size=4,
        overlap=0.1,
        mode="gaussian",
    )
    # Clamp to [0, 1], squeeze the singleton axes, and convert to numpy.
    pred = torch2img(pred)
    # Saved next to the raw data as "<original stem>_decoded<original suffix>".
    OmeTiffWriter.save(
        pred,
        (output_dir / f"{input_path.stem}_decoded{input_path.suffix}"),
        dim_order="ZYX",
    )

100%|██████████| 105/105 [1:14:06<00:00, 42.35s/it]


## Evaluation

In [6]:
# Pair each raw hold-out file with its reconstruction by file name, using the
# same "<stem>_decoded<suffix>" pattern the inference cell writes. Matching two
# independently sorted globs could mis-pair files, and `zip` would silently
# drop files if the two lists differed in length; with explicit names a
# missing reconstruction fails loudly when it is opened.
decoded_paths = [
    compress_dir / f"{raw_path.stem}_decoded{raw_path.suffix}"
    for raw_path in input_dir
]
mse_value = AverageMeter()
ssim_value = AverageMeter()
psnr_value = AverageMeter()
corr_value = AverageMeter()
# `total` lets the progress bar show a percentage instead of a bare counter.
for i, (decode_path, raw_path) in tenumerate(
    zip(decoded_paths, input_dir), total=len(input_dir)
):
    tmp_mse, tmp_ssim, tmp_psnr, tmp_corr = compare_images(decode_path, raw_path)
    mse_value.update(tmp_mse)
    ssim_value.update(tmp_ssim)
    psnr_value.update(tmp_psnr)
    corr_value.update(tmp_corr)
# Report the per-volume averages over the hold-out set.
print(
    "MSE:",
    mse_value.avg,
    "SSIM:",
    ssim_value.avg,
    "PSNR:",
    psnr_value.avg,
    "CORR:",
    corr_value.avg,
)

105it [07:22,  4.22s/it]

MSE: 0.33440911084474484 SSIM: 0.9220190446920535 PSNR: 28.136645543964296 CORR: 0.9483616418217692



