In [None]:
import torch
from torchvision.transforms import v2

import numpy as np
import matplotlib.pyplot as plt
import importlib

import time

# Process-wide settings passed to the libraries through environment variables.
#  - HDF5_USE_FILE_LOCKING: turns HDF5 file locking off
#    (presumably needed by the data-loading modules -- TODO confirm).
#  - PYTORCH_CUDA_ALLOC_CONF: tunes the PyTorch CUDA caching allocator to
#    enable more efficient use of GPU memory.
import os 
os.environ.update({
    "HDF5_USE_FILE_LOCKING": "FALSE",
    "PYTORCH_CUDA_ALLOC_CONF": "backend:native, garbage_collection_threshold:0.6, max_split_size_mb:64",
})

In [None]:
import Modules.Models.UNets as UNets
import Modules.Data.DICHeLaDataset as DICHeLaSegDataset 
import Modules.Data.ImageStackTransform as ImageStackTransform  
import Modules.TrainAndValidate.LossFunctions as LossFunctions
import Modules.Utils.OverlapTile as OverlapTile

In [None]:
## source dataset and model file path configuration

# Saved model file to evaluate; the commented-out line is an alternative run
# kept for quick switching.
# src_model_path = r".\Results\model_2024-07-07-18-50-42.pt"
src_model_path = r".\Results\model_2024-07-07-11-16-47.pt"

# Glob patterns selecting the test frames (t*.tif) of sequence folders 01 and 02.
test_data_file_path_globs = [
    r"E:\Python\DataSet\TorchDataSet\DIC-C2DH-HeLa\Test\DIC-C2DH-HeLa\DIC-C2DH-HeLa\01\t*.tif",
    r"E:\Python\DataSet\TorchDataSet\DIC-C2DH-HeLa\Test\DIC-C2DH-HeLa\DIC-C2DH-HeLa\02\t*.tif",
]

In [None]:
## create necessary data transform
# Pipeline applied to every input image: wrap it as a torchvision image
# tensor, then cast to 32-bit float. scale = False keeps the raw intensity
# values (no rescaling during the cast).
data_transform = v2.Compose(
    transforms = [
        v2.ToImage(),
        v2.ToDtype(dtype = torch.float32, scale = False),
    ]
)

In [None]:
## create data set
color_categories = False

# Re-import the dataset module so that edits made to its source file are
# picked up without restarting the kernel.
importlib.reload(DICHeLaSegDataset)

# No segmentation globs are passed (seg_image_path_globs = None) -- presumably
# because the test split has no ground truth; verify against the dataset class.
# Only the input images get a transform.
test_dataset = DICHeLaSegDataset.DICHeLaWeightedSegDataset(
    data_image_path_globs = test_data_file_path_globs,
    seg_image_path_globs = None,
    color_categories = color_categories,
    data_transform = data_transform,
    target_transform = None,
    common_transform = None,
)

In [None]:
## create test dataloader for loading data
test_batch_size = 1

# shuffle = False: batches are served in dataset order, so a given batch
# index always refers to the same images.
test_dataloader = torch.utils.data.DataLoader(
    dataset = test_dataset,
    batch_size = test_batch_size,
    shuffle = False,
)

In [None]:
## load model

# The file holds a whole pickled model object (it is printed and called as a
# module below), not just a state_dict, hence:
#  - weights_only = False: newer PyTorch releases default to weights_only = True,
#    which refuses to unpickle full model objects.
#  - map_location = "cpu": lets a checkpoint saved on a GPU load on a machine
#    without that GPU; the model is moved to the selected device before use.
# NOTE(security): unpickling can execute arbitrary code -- only load model
# files from a trusted source.
model = torch.load(src_model_path, map_location = "cpu", weights_only = False)

print(model)

In [None]:
## use parallel computing if possible
# Preference order: CUDA GPU, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
## Use overlap and tile strategy to process large image
# Re-import the helper module so edits to its source are picked up.
importlib.reload(OverlapTile)

check_idx = 0        # index of the image inside the selected batch
check_bath_idx = 0   # index of the batch to inspect (name kept as-is for compatibility)

sub_image_size = (128, 128)
stride = (64, 64)

model.to(device)
# Inference mode: dropout / batch-norm layers, if the model has any, must not
# run in training mode while predicting.
model.eval()
with torch.no_grad():
    # Create the iterator once and advance it; calling iter() on every pass
    # would restart the loader and always yield the first batch.
    batch_iter = iter(test_dataloader)
    for i_batch in range(check_bath_idx + 1):
        check_features, check_labels, check_weights = next(batch_iter)
    check_feature = check_features[check_idx,...]

    ## overlap and tile
    # split -> tiles plus their locations; merge -> reassembled full-size result
    # (exact tensor layouts are defined in Modules.Utils.OverlapTile).
    sub_images, sub_image_locs = OverlapTile.split(check_feature, sub_image_size, stride)

    # All tiles go through the model as a single batch; results come back to the CPU.
    sub_images = model(sub_images.to(device)).cpu()

    # Most likely class per pixel -- assumes dim 1 of the model output is the
    # class dimension (TODO confirm against the model definition).
    sub_images = torch.argmax(sub_images, dim = 1, keepdim = True)
    check_pred = OverlapTile.merge(sub_images, sub_image_locs, check_feature.size()[-2:])

check_feature = check_feature.detach().cpu().numpy()
check_pred = check_pred.detach().cpu().numpy()

# Move the leading axis to the end for plotting
# (presumably channel-first -> channel-last; TODO confirm).
check_feature = np.rollaxis(check_feature,0,3)
check_pred = np.rollaxis(check_pred,0,3)

# Quick visual check: input image on the left, predicted labels (with a
# colorbar) on the right.
check_fig, (check_ax_feature, check_ax_pred) = plt.subplots(1, 2, figsize = (9,4))
check_ax_feature.imshow(check_feature)
check_pred_image = check_ax_pred.imshow(check_pred)
check_fig.colorbar(check_pred_image, ax = check_ax_pred)
check_fig.tight_layout()
plt.show()

In [None]:
## create destination dir path
# Raw string: "\R" and "\P" are not valid escape sequences, so a plain string
# literal triggers an invalid-escape warning on recent Python versions.
dst_plot_subdir_path = r".\Results\Plots"

# exist_ok = True makes the cell safe to re-run and avoids a separate
# check-then-create step.
os.makedirs(dst_plot_subdir_path, exist_ok = True)

In [None]:
## make nicer picture
# Three panels side by side -- input image, predicted segmentation, and the
# segmentation drawn semi-transparently over the input -- saved as a PNG.

plot_bkg_color = 0   # label value treated as background in the overlay

plot_image = check_feature
plot_segmentation = check_pred

plot_dst_png_file_name = "OverlapTile.png"
plot_dst_png_file_path = os.path.join(dst_plot_subdir_path, plot_dst_png_file_name)

# Overlay layer: keep the label value on foreground pixels and NaN elsewhere,
# so only non-background pixels are colored on top of the input image.
plot_foreground = plot_segmentation != plot_bkg_color
plot_canvas = np.full(plot_segmentation.shape, np.nan)
plot_canvas[plot_foreground] = plot_segmentation[plot_foreground]

fig, (ax_input, ax_seg, ax_overlay) = plt.subplots(1, 3, figsize = (13,4))

# panel 1: raw input
input_layer = ax_input.imshow(plot_image, cmap = "gray")
fig.colorbar(input_layer, ax = ax_input)
ax_input.set_title("Input image")

# panel 2: predicted labels
seg_layer = ax_seg.imshow(plot_segmentation, cmap = "Set3")
fig.colorbar(seg_layer, ax = ax_seg)
ax_seg.set_title("Segmentation")

# panel 3: labels blended over the input; the colorbar refers to the overlay
ax_overlay.imshow(plot_image, cmap = "gray")
overlay_layer = ax_overlay.imshow(plot_canvas, cmap = "Set3", alpha = 0.5)
fig.colorbar(overlay_layer, ax = ax_overlay)
ax_overlay.set_title("Segmentation over input image")

# pixel coordinates carry no information here, so hide all ticks
for panel_ax in (ax_input, ax_seg, ax_overlay):
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

fig.tight_layout()

# save before show so the written file contains the rendered figure
fig.savefig(plot_dst_png_file_path, bbox_inches='tight', dpi = 300)

plt.show()

print(plot_dst_png_file_path)