In [None]:
import torch
from torchvision.transforms import v2

import numpy as np
import matplotlib.pyplot as plt
import importlib

import time

# Runtime tuning through environment variables:
#   * HDF5_USE_FILE_LOCKING      - turn off HDF5 file locking
#   * PYTORCH_CUDA_ALLOC_CONF    - CUDA caching-allocator options intended to make
#                                  more efficient use of GPU memory
import os

os.environ.update({
    "HDF5_USE_FILE_LOCKING": "FALSE",
    "PYTORCH_CUDA_ALLOC_CONF": "backend:native, garbage_collection_threshold:0.6, max_split_size_mb:64",
})

In [None]:
import Modules.Models.UNets as UNets
import Modules.Data.DICHeLaDataset as DICHeLaSegDataset 
import Modules.Data.ImageStackTransform as ImageStackTransform  
import Modules.Utils.Evaluations as Evaluations

In [None]:
## data / checkpoint locations
# The defaults reproduce the original Windows layout. To run on another machine
# without editing this cell, set the environment variables
#   DIC_HELA_DATA_ROOT  - folder that contains the "01", "02", "01_ST", "02_ST" sub-folders
#   DIC_HELA_MODEL_PATH - checkpoint file to evaluate
dataset_root = os.environ.get(
    "DIC_HELA_DATA_ROOT",
    r"E:\Python\DataSet\TorchDataSet\DIC-C2DH-HeLa\Train\DIC-C2DH-HeLa\DIC-C2DH-HeLa",
)

# image sequences used for training + validation
sequence_ids = ("01", "02")

# one glob per sequence for the raw images ...
trainvalidate_data_file_path_globs = [
    os.path.join(dataset_root, seq_id, "t*.tif")
    for seq_id in sequence_ids
]

# ... and the matching glob per sequence for the segmentation masks
trainvalidate_seg_file_path_globs = [
    os.path.join(dataset_root, f"{seq_id}_ST", "SEG_ERODE", "man_seg*.tif")
    for seq_id in sequence_ids
]

# checkpoint to evaluate (alternative checkpoint: model_2024-07-07-11-16-47.pt)
src_model_path = os.environ.get(
    "DIC_HELA_MODEL_PATH",
    os.path.join(".", "Results", "model_2024-07-07-18-50-42.pt"),
)

In [None]:
## define data transforms
# re-import the module so local edits are picked up without restarting the kernel
importlib.reload(ImageStackTransform)

# no transform is shared between image and mask
common_transform = None

# Input image pipeline: convert to an image tensor, cast to float WITHOUT rescaling
# the intensity values, then resize (an int `size` is matched to the smaller edge).
data_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float, scale = False),
    v2.Resize(size = 512, antialias = True),
])

# Target (label mask) pipeline: same steps, but cast to integer class ids.
# The mask MUST be resized with nearest-neighbour interpolation: the default
# (bilinear) averages neighbouring class ids and rounds the result, which
# creates labels that do not exist in the annotation along object borders.
# NOTE(review): if the checkpoint was trained with bilinear-resized masks the
# reported metrics may shift slightly at object boundaries - confirm.
target_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.long, scale = False),
    v2.Resize(size = 512, interpolation = v2.InterpolationMode.NEAREST),
    # drop the leading channel axis so the mask no longer carries it
    v2.Lambda(lambda x: torch.squeeze(x, dim = 0)),
])

In [None]:
## create data set
# re-import the dataset module so local edits are picked up without a kernel restart
importlib.reload(DICHeLaSegDataset)

# forwarded unchanged to the dataset constructor
color_categories = False

# collect the constructor arguments in one place, then build the dataset
dataset_kwargs = {
    "data_image_path_globs": trainvalidate_data_file_path_globs,
    "seg_image_path_globs": trainvalidate_seg_file_path_globs,
    "data_transform": data_transform,
    "target_transform": target_transform,
    "common_transform": common_transform,
    "color_categories": color_categories,
}
trainvalidate_dataset = DICHeLaSegDataset.DICHeLaSegDataset(**dataset_kwargs)

print(f"Tot data size = {len(trainvalidate_dataset)}")

In [None]:
## split train and validate dataset
# fractions of the full dataset assigned to (train, validate)
data_split_ratios = [0.8, 0.2]

# fixed seed so the split is identical on every run
data_split_rand_genenrator = torch.Generator().manual_seed(0)

dataset_splits = torch.utils.data.random_split(
    dataset = trainvalidate_dataset,
    lengths = data_split_ratios,
    generator = data_split_rand_genenrator,
)
train_dataset, validate_dataset = dataset_splits

print(f"Train data size = {len(train_dataset)}")
print(f"Validate data size = {len(validate_dataset)}")

In [None]:
## check data and label
# visual sanity check: show one (image, mask) pair from the chosen dataset
check_idx = 0
check_dataset = validate_dataset

check_data, check_label = check_dataset[check_idx]
print(check_data.size())
print(check_label.size())

# matplotlib works on numpy arrays
check_data = check_data.numpy()
check_label = check_label.numpy()

fig, (data_ax, target_ax) = plt.subplots(1, 2, figsize = (7, 3))

# left panel: input image, first axis moved to the end for imshow
data_im = data_ax.imshow(np.transpose(check_data, (1, 2, 0)))
fig.colorbar(data_im, ax = data_ax)
data_ax.set_xticks([])
data_ax.set_yticks([])
data_ax.set_title("Data")

# right panel: label mask drawn with a qualitative colormap
target_im = target_ax.imshow(check_label, cmap = "Set3")
fig.colorbar(target_im, ax = target_ax)
target_ax.set_xticks([])
target_ax.set_yticks([])
target_ax.set_title("Target")

fig.tight_layout()
plt.show()

In [None]:
## Create data loaders for the training and validation datasets

# NOTE: the batch size is kept small so that a batch fits into a small GPU memory
train_bath_size = 16
validate_batch_size = 16

# both loaders iterate in a fixed order (no shuffling)
train_dataloader, validate_dataloader = (
    torch.utils.data.DataLoader(subset, batch_size = n_per_batch, shuffle = False)
    for subset, n_per_batch in (
        (train_dataset, train_bath_size),
        (validate_dataset, validate_batch_size),
    )
)

In [None]:
## load model

# SECURITY: the checkpoint appears to be a fully pickled model object, so loading
# it runs the pickle machinery and can execute arbitrary code. Only load
# checkpoints from a trusted source.
model = torch.load(
    src_model_path,
    # keep the saved device when CUDA is present; otherwise remap the tensors to
    # the CPU so a CUDA-saved checkpoint can still be loaded on this machine
    map_location = None if torch.cuda.is_available() else "cpu",
    # newer PyTorch releases default to weights_only=True, which rejects pickled
    # model objects; state the intended behaviour explicitly
    weights_only = False,
)

# this notebook only evaluates: switch off train-time behaviour such as
# dropout and batch-norm statistics updates
if isinstance(model, torch.nn.Module):
    model.eval()

print(model)

In [None]:
## use parallel computing if possible
# preference order: CUDA GPU, then Apple MPS, then plain CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
## iterate through the dataset and evaluate the model in terms of IoU
# re-import the evaluation helpers so local edits are picked up without a kernel restart
importlib.reload(Evaluations)

# dataset to score and the background value handed to the evaluator
src_dataloader = validate_dataloader
bkg_val = 0

mean_ious = Evaluations.mean_iou_over_dataset(
    model = model, src_dataloader = src_dataloader, device = device, bkg_val = bkg_val
)

# header line followed by the per-class result on its own line
print("IOUs{class:val} :", mean_ious, sep = "\n")