# 3d adaptation

## 1. generate data:

In [1]:
import os
from pathlib import Path
import numpy as np
import shutil
import pandas as pd
import quilt3
from aicsimageio import AICSImage
from aicsimageio.writers import OmeTiffWriter
import random

In [2]:
# Seed every random source used in this notebook with the same value so that
# repeated runs draw the same random numbers.
seed_value = 2023
os.environ['PYTHONHASHSEED'] = str(seed_value)
for seed_fn in (np.random.seed, random.seed):
    seed_fn(seed_value)

In [3]:
# turn off pandas parser warnings
import warnings
# Both calls below install an "ignore" filter for FutureWarning.
warnings.simplefilter(action='ignore', category=FutureWarning)
# turn off ome_types parser warning
warnings.filterwarnings("ignore", category=FutureWarning)

In [32]:
# `structure_name` value used to select rows from the metadata table.
cline = "FBL"
# Number of FOVs to download; choose what you need. Each FOV is sent to the
# holdout set with probability 0.2, giving a roughly 80/20 train/holdout split.
num_samples_per_cell_line = 10

# Set up the working directories.
parent_path = Path("/mnt/eternus/users/Yu/project/data_compression/data/labelfree_3d")
raw_path = parent_path / "download"
train_path = parent_path / "train"
holdout_path = parent_path / "holdout"
# parents=True creates any missing intermediate directories, so the cell also
# works when the parent hierarchy does not exist yet; exist_ok=True makes
# re-running the cell harmless.
for directory in (parent_path, raw_path, train_path, holdout_path):
    directory.mkdir(parents=True, exist_ok=True)

In [33]:
# connect to quilt and load meta table
pkg = quilt3.Package.browse(
    "aics/hipsc_single_cell_image_dataset", registry="s3://allencell"
)
meta_df_obj = pkg["metadata.csv"]
meta_df_obj.fetch(parent_path / "meta.csv")
meta_df = pd.read_csv(parent_path / "meta.csv")

# Select the rows of the chosen cell line, keep one row per FOVId and renumber
# the rows from 0. The calls are chained (each returns a new DataFrame) instead
# of using inplace=True on the result of query(), which pandas reports as
# "A value is trying to be set on a copy of a slice from a DataFrame".
meta_df_line = (
    meta_df.query("structure_name==@cline")
    .drop_duplicates(subset="FOVId")
    .reset_index(drop=True)
)

Loading manifest: 100%|██████████| 484465/484465 [00:14<00:00, 34.4k/s]
100%|██████████| 1.69G/1.69G [00:40<00:00, 41.8MB/s] 
Columns (33) have mixed types. Specify dtype option on import or set low_memory=False.

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [None]:
# Fetch each selected FOV and split it into an input stack (brightfield, "_IM")
# and a ground-truth stack (structure channel, "_GT").
for fov in meta_df_line.itertuples():
    # Stop once the requested number of samples has been processed.
    if fov.Index >= num_samples_per_cell_line:
        break

    # fov_path is split on "/" into the package sub-directory and file name.
    path_parts = fov.fov_path.split("/")
    package_entry = pkg[path_parts[0]][path_parts[1]]

    # Download the raw multi-channel image.
    raw_fn = raw_path / f"{fov.FOVId}_original.tiff"
    package_entry.fetch(raw_fn)

    # Read the brightfield and the structure channel as ZYX volumes.
    reader = AICSImage(raw_fn)
    bf_img = reader.get_image_data(
        "ZYX", C=fov.ChannelNumberBrightfield, S=0, T=0
    )
    str_img = reader.get_image_data(
        "ZYX", C=fov.ChannelNumberStruct, S=0, T=0
    )

    # Send the sample to the holdout folder with probability 0.2.
    data_path = holdout_path if random.random() < 0.2 else train_path

    OmeTiffWriter.save(bf_img, data_path / f"{fov.FOVId}_IM.tiff", dim_order="ZYX")
    OmeTiffWriter.save(str_img, data_path / f"{fov.FOVId}_GT.tiff", dim_order="ZYX")

In [35]:
# The raw downloads and the metadata table are not needed after the train and
# holdout folders have been written, so they can be deleted now.
import os
from shutil import rmtree

meta_csv = parent_path / "meta.csv"
rmtree(raw_path)
os.remove(meta_csv)

## 2. train:

In [None]:
# Train the 3D model (bmshj2018-factorized_3d with --use_3D) for 50 epochs.
# NOTE(review): `{metric}` is filled in from the notebook variable `metric`,
# which is only assigned in later cells (`metric = "mse"`); run one of those
# cells before this one.
!python3 ../train.py -d /mnt/data/ISAS.DE/yu.zhou/Yu/project/data_compression/experiment/3d_adaptation/data \
                    --train_split train \
                    --test_split test \
                    --aux-learning-rate 1e-3 \
                    --lambda 0.18 \
                    --epochs 50 \
                    -lr 1e-4 \
                    --batch-size 2 \
                    --model bmshj2018-factorized_3d \
                    --use_3D \
                    --quality 8 \
                    --metric {metric} \
                    --cuda \
                    --save_path /mnt/eternus/users/Yu/project/data_compression/experiment/3d_adaptation/model/fine_tune_v13.pth.tar \
                    --seed 2023

### Try with 2D:

Since the preliminary 3D result is not good, we go back to 2D to check the intermediate result. The objective is to verify the correctness of the network.

In [7]:
# Distortion metric interpolated into the training command as `--metric {metric}`.
metric = "mse"

In [None]:
# Train the 2D model (bmshj2018-factorized, no --use_3D) on the labelfree_2d data.
# `{metric}` is filled in from the notebook variable `metric`.
!python3 ../train.py -d /mnt/data/ISAS.DE/yu.zhou/Yu/project/data_compression/data/labelfree_2d \
                    --train_split train \
                    --test_split test \
                    --aux-learning-rate 1e-3 \
                    --lambda 0.18 \
                    --epochs 50 \
                    -lr 1e-4 \
                    --batch-size 2 \
                    --model bmshj2018-factorized \
                    --quality 8 \
                    --metric {metric} \
                    --cuda \
                    --save_path /mnt/eternus/users/Yu/project/data_compression/experiment/3d_adaptation/model/fine_tune_v15.pth.tar \
                    --seed 2023

### 3D fine-tuning:

It seems the poor 3D result is due to insufficient training data, so we add more data to the training set.

In [10]:
# Dataset root handed to train.py as `-d {training_path}`.
training_path = "/mnt/eternus/users/Jianxu/projects/im2im_experiments_v1/data/labelfree3D/FBL/"
# Distortion metric handed to train.py as `--metric {metric}`.
metric = "mse"

In [None]:
# Train the 3D model on the dataset at `training_path`, using the `train`
# split for training and the `holdout` split for testing.
# `{training_path}` and `{metric}` are filled in from the notebook variables.
!python3 ../train.py -d {training_path} \
                    --train_split train \
                    --test_split holdout \
                    --aux-learning-rate 1e-3 \
                    --lambda 0.18 \
                    --epochs 50 \
                    -lr 1e-4 \
                    --batch-size 2 \
                    --use_3D \
                    --model bmshj2018-factorized_3d \
                    --quality 8 \
                    --metric {metric} \
                    --cuda \
                    --save_path /mnt/eternus/users/Yu/project/data_compression/experiment/3d_adaptation/model/3d_adaptation_v1.pth.tar \
                    --seed 2023

Now we have a reasonable pretrained model to start from. The next step is to fine-tune it with the MS-SSIM loss. The code is below.

In [None]:

# Fine-tune the 3D model with the ms-ssim metric, starting from --checkpoint.
# NOTE(review): the checkpoint is presumably the best model of the previous
# run (its --save_path was 3d_adaptation_v1.pth.tar; the `_best` suffix and
# the different path root should be confirmed against train.py and the mounts).
!python3 ../train.py -d /mnt/eternus/users/Jianxu/projects/im2im_experiments_v1/data/labelfree3D/FBL/ \
                    --train_split train \
                    --test_split holdout \
                    --aux-learning-rate 1e-4 \
                    --lambda 220.0 \
                    --epochs 50 \
                    -lr 5e-5 \
                    --batch-size 2 \
                    --use_3D \
                    --model bmshj2018-factorized_3d \
                    --checkpoint /mnt/data/ISAS.DE/yu.zhou/Yu/project/data_compression/experiment/3d_adaptation/model/3d_adaptation_v1_best.pth.tar \
                    --quality 8 \
                    --metric ms-ssim \
                    --cuda \
                    --save_path /mnt/eternus/users/Yu/project/data_compression/experiment/3d_adaptation/model/3d_adaptation_v2.pth.tar \
                    --seed 2023

## 3. Inference:

Instead of using `codec.py`, we run the network's forward pass directly to get the prediction. We use sliding-window inference to avoid memory overhead.

In [None]:
import torch
import numpy as np
from monai.transforms import (
    RandSpatialCropSamples,
    LoadImage,
    SaveImage,
    Compose,
    AddChannel,
    RepeatChannel,
    ToTensor,
    Transform,
    Transpose,
    CastToType,
    EnsureType,
    ScaleIntensityRangePercentiles,
)
from monai.inferers import sliding_window_inference
from compressai.zoo import image_models, models
from compressai.zoo.pretrained import load_pretrained
from aicsimageio import AICSImage
from aicsimageio.writers import  OmeTiffWriter

In [3]:
class Normalize(Transform):
    """Transform that divides an image by 65535 (the uint16 maximum).

    For uint16 input this maps the intensities onto [0, 1].
    """

    def __init__(self):
        super().__init__()

    def __call__(self, img):
        """Return ``img`` divided by 65535.0."""
        return img / 65535.0
    
def torch2img(x: torch.Tensor):
    """Convert a tensor to a NumPy array with values clipped to [0, 1].

    The tensor is detached from the autograd graph, clamped to [0, 1],
    stripped of all size-1 dimensions and moved to the CPU. The clamp is
    out-of-place, so the caller's tensor is not modified and tensors that
    require grad are accepted.

    Args:
        x: Tensor of any shape.

    Returns:
        NumPy array holding the clamped values. The values are NOT rescaled
        to uint16; multiply by ``2**16 - 1`` and cast if integer output is
        needed.
    """
    return x.detach().clamp(0, 1).squeeze().cpu().numpy()

In [4]:
# Model and checkpoint configuration used for inference.
model = "bmshj2018-factorized_3d"
quality = 8
metric = "ms-ssim"
device = torch.device('cpu')
checkpoint = "/mnt/eternus/users/Yu/project/data_compression/experiment/3d_adaptation/model/3d_adaptation_v2_best.pth.tar"
model_info = models[model]

# Pre-processing: load the file, add a channel axis, permute the axes with
# indices (0, 3, 2, 1) and divide by 65535 (Normalize).
# An earlier variant used Transpose(indices=(0, 3, 1, 2)) followed by
# RandSpatialCropSamples(roi_size=(64, 256, 256), num_samples=1,
# random_size=False, random_center=False).
transform = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        Transpose(indices=(0, 3, 2, 1)),
        Normalize(),
    ]
)

# Read the weights from the checkpoint and build the network in eval mode.
state_dict = load_pretrained(
    torch.load(checkpoint, map_location=device)['state_dict']
)
net = model_info(quality=quality, metric=metric, pretrained=False)
net = net.from_state_dict(state_dict)
net = net.to(device).eval()

def infer(img):
    """Run one forward pass of ``net`` and return its reconstruction.

    Args:
        img: (tensor) N x C x Z x H x W

    Returns:
        The ``"x_hat"`` entry of the network output.
    """
    # Inference only: disable autograd so no graph or intermediate
    # activations are kept in memory for a backward pass.
    with torch.no_grad():
        return net(img)["x_hat"]

<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. It will be removed in version 1.3. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead with `channel_dim='no_channel'`.


In [5]:
# Source image and the file the (normalised) input will be saved to.
input = "/mnt/data/ISAS.DE/yu.zhou/Yu/project/data_compression/experiment/3d_adaptation/data/test/7632_IM.tiff"
output_dir = "/mnt/data/ISAS.DE/yu.zhou/Yu/project/data_compression/experiment/3d_adaptation/data/pred/7632_test.tiff"

# Take the first element of the transformed image, then add two leading axes
# (batch and channel, per the original note) before moving it to the device.
volume = transform(input)[0]
img = volume.unsqueeze(0).unsqueeze(0).to(device)

# Run the network window by window and blend the windows with gaussian weights.
pred = sliding_window_inference(
    inputs=img,
    roi_size=[32, 256, 256],
    sw_batch_size=4,
    predictor=infer,
    overlap=0.1,
    mode='gaussian',
    device=torch.device("cpu"),
)

- Save the input and predicted images:

In [6]:
from pathlib import Path

# Convert both tensors to NumPy arrays.
img = torch2img(img)
pred = torch2img(pred)

# Derive the prediction file name by swapping "test" for "pred" in the file
# name only. Replacing on the full path string would also rewrite any
# directory component that happens to contain "test".
test_fn = Path(output_dir)
pred_fn = test_fn.with_name(test_fn.name.replace('test', 'pred'))

OmeTiffWriter.save(img, output_dir, dim_order="ZYX")
OmeTiffWriter.save(pred, pred_fn, dim_order="ZYX")



## 4. Evaluate:

Evaluate using MSE, SSIM, PSNR, and Pearson correlation:

In [7]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [8]:
def compare_image(img, pred, dimension = '3d'):
    """Score a prediction against its reference image.

    Args:
        img: Reference array.
        pred: Predicted array; must have the same shape as ``img``.
        dimension: '3d' or '2d' (case-insensitive). Selects how many trailing
            axes (3 or 2) make up the pixel count that normalises the MSE.

    Returns:
        Tuple ``(mse, ssim_value, psnr_value, corr)``. SSIM and PSNR are
        computed with ``data_range = 1``.

    Raises:
        AssertionError: If the two shapes differ.
        ValueError: If ``dimension`` is neither '2d' nor '3d'.
    """
    assert img.shape == pred.shape, "shape should be the same!"

    trailing_axes = {'3d': 3, '2d': 2}.get(dimension.lower())
    if trailing_axes is None:
        raise ValueError("Invalid dimension input. Expected '2d' or '3d'.")

    # Multiply the sizes of the last `trailing_axes` axes.
    num_pixel = 1
    for offset in range(1, trailing_axes + 1):
        num_pixel *= img.shape[-offset]

    residual = img - pred
    mse = np.sum(residual ** 2) / num_pixel
    ssim_value = ssim(img, pred, data_range = 1)
    psnr_value = psnr(img, pred, data_range = 1)
    # Off-diagonal entry of the 2x2 correlation matrix of the flattened arrays.
    corr_matrix = np.corrcoef(img.ravel(), pred.ravel())
    corr = corr_matrix[0, 1]
    return mse, ssim_value, psnr_value, corr


In [9]:
# Score the prediction against the input and print the four metrics,
# each with four decimals.
metrics = compare_image(img, pred)
mse, ssim_value, psnr_value, corr = metrics
report = f"""Metrics:
- MSE  : {mse:.4f}
- SSIM : {ssim_value:.4f}
- PSNR : {psnr_value:.4f}
- CORR : {corr:.4f}"""
print(report)

Metrics:
- MSE  : 0.0003
- SSIM : 0.8894
- PSNR : 35.6661
- CORR : 0.9240
