In [None]:
# Environment variables are exported BEFORE torch is imported so that they are
# already present in the process environment at whatever point the library
# reads them.
# NOTE(review): presumably the allocator `backend` option is read when torch
# is loaded -- confirm against the PyTorch CUDA memory-management notes.
import os

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
# adjust PyTorch parameter to enable more efficient use of GPU memory
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native, garbage_collection_threshold:0.6, max_split_size_mb:64"

# standard library
import importlib
import time

# third party
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision.transforms import v2

In [None]:
import Modules.Models.UNets as UNets
import Modules.Data.DICHeLaDataset as DICHeLaSegDataset 
import Modules.Data.ImageStackTransform as ImageStackTransform  
import Modules.TrainAndValidate.TrainAndValidate as TrainAndValidate
import Modules.TrainAndValidate.LossFunctions as LossFunctions

In [None]:
# Glob patterns for the training/validation images and their segmentation
# masks.  All patterns live under one dataset root, so the lists are assembled
# from that root plus the per-sequence sub-paths ("01" and "02").
_dataset_root = r"E:\Python\DataSet\TorchDataSet\DIC-C2DH-HeLa\Train\DIC-C2DH-HeLa\DIC-C2DH-HeLa"
_sequence_ids = ("01", "02")

# <root>\<sequence>\t*.tif
trainvalidate_data_file_path_globs = [
    "\\".join((_dataset_root, seq_id, "t*.tif"))
    for seq_id in _sequence_ids
]

# <root>\<sequence>_ST\SEG_ERODE\man_seg*.tif
trainvalidate_seg_file_path_globs = [
    "\\".join((_dataset_root, seq_id + "_ST", "SEG_ERODE", "man_seg*.tif"))
    for seq_id in _sequence_ids
]

In [None]:
## define data transforms
# Re-import the transform module so edits made to it since the last run are
# picked up without restarting the kernel.
importlib.reload(ImageStackTransform)

# Random geometric augmentations, built one by one and then chained in the
# order listed below.
# NOTE(review): `fills` presumably gives one fill rule per image in the stack
# -- confirm against ImageStackTransform.
elastic_step = ImageStackTransform.ElasticTransform(
    fills=["mean", "min", "min"],
    alpha=50,
    sigma=5,
)
rotation_step = ImageStackTransform.RandomRotation(
    fills=["mean", "min", "min"],
    degrees=[-45, 45],
)
crop_step = ImageStackTransform.RandomCrop(
    size=(256, 256),
    pad_if_needed=True,
    padding_mode="reflect",
)
horizontal_flip_step = ImageStackTransform.RandomHorizontalFlip(p=0.5)
vertical_flip_step = ImageStackTransform.RandomVerticalFlip(p=0.5)

common_transform = v2.Compose(
    [
        elastic_step,
        rotation_step,
        crop_step,
        horizontal_flip_step,
        vertical_flip_step,
    ]
)

## NOTE: scaling and normalization are not always helpful. Depending on the dataset, they can shift the distribution of the data and cause problems at inference time.
## NOTE: if the source data's gray scale is well controlled, there is no need to scale and normalize.

# Input-image pipeline: wrap as an image tensor, cast to float WITHOUT
# rescaling the intensities (scale=False), then resize to size 512 with
# antialiasing enabled.
data_transform_steps = [
    v2.ToImage(),
    v2.ToDtype(torch.float, scale=False),
    v2.Resize(size=512, antialias=True),
]
data_transform = v2.Compose(data_transform_steps)

# Label-mask pipeline: wrap as an image tensor, cast to integer class indices
# without rescaling, resize, then drop the leading channel axis.
#
# The resize uses nearest-neighbour sampling.  The default (bilinear) mode
# interpolates between neighbouring pixels, which blends class indices and can
# produce label values that do not exist in the original mask.
target_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.long, scale = False),
    v2.Resize(size = 512, interpolation = v2.InterpolationMode.NEAREST),
    # remove axis 0 when it has length 1 (no-op otherwise)
    v2.Lambda(lambda x: torch.squeeze(x, dim = 0)),
])

In [None]:
## create data set
# Reload so that edits to the dataset module take effect without a restart.
importlib.reload(DICHeLaSegDataset)

color_categories = False

# Collect the constructor arguments first, then build the dataset from them.
dataset_kwargs = dict(
    data_image_path_globs=trainvalidate_data_file_path_globs,
    seg_image_path_globs=trainvalidate_seg_file_path_globs,
    data_transform=data_transform,
    target_transform=target_transform,
    common_transform=common_transform,
    color_categories=color_categories,
)
trainvalidate_dataset = DICHeLaSegDataset.DICHeLaSegDataset(**dataset_kwargs)

# report how many samples the dataset exposes
nof_samples = len(trainvalidate_dataset)
print(f"Tot data size = {nof_samples}")

In [None]:
## split train and validate dataset
# A generator with a fixed seed keeps the train/validate membership identical
# from run to run.
# NOTE(review): both subsets wrap the same dataset object, so validation
# samples presumably pass through the same random `common_transform` as
# training samples -- confirm this is intended.
data_split_ratios = [0.8, 0.2]
data_split_rand_genenrator = torch.Generator()
data_split_rand_genenrator.manual_seed(0)

train_dataset, validate_dataset = torch.utils.data.random_split(
    dataset=trainvalidate_dataset,
    lengths=data_split_ratios,
    generator=data_split_rand_genenrator,
)

# report the size of each subset
for split_name, split_subset in (("Train", train_dataset), ("Validate", validate_dataset)):
    print(f"{split_name} data size = {len(split_subset)}")

In [None]:
## check data and label
# Pull one sample from the chosen subset, print the tensor sizes, and show
# the image next to its label map.
check_dataset = train_dataset
check_idx = 6

check_data, check_label = check_dataset[check_idx]
for sample_tensor in (check_data, check_label):
    print(sample_tensor.size())

check_data = check_data.numpy()
check_label = check_label.numpy()

# (title, array to draw, extra imshow keyword arguments) for each panel
preview_panels = (
    # move axis 0 to the end: channel-first -> channel-last for imshow
    ("Data", np.moveaxis(check_data, 0, 2), {}),
    ("Target", check_label, {"cmap": "Set3"}),
)

plt.figure(figsize=(7, 2))
for panel_no, (panel_title, panel_image, imshow_kwargs) in enumerate(preview_panels, start=1):
    plt.subplot(1, 2, panel_no)
    plt.imshow(panel_image, **imshow_kwargs)
    plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    plt.title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
## Create data loaders for the training and validation datasets

# NOTE: Use a very small batch size here to fit the data into my small GPU memory
train_batch_size = 8
train_bath_size = train_batch_size  # original (misspelled) name kept for backward compatibility
validate_batch_size = 8

# Training batches are reshuffled every epoch.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = train_batch_size,
    shuffle = True,
)

# Validation batches are served in a fixed order: shuffling brings no benefit
# when nothing is being optimised, and a stable order makes batch `k` refer to
# the same dataset indices each time the loader is iterated.
validate_dataloader = torch.utils.data.DataLoader(
    validate_dataset,
    batch_size = validate_batch_size,
    shuffle = False,
)

In [None]:
## Load model
# Reload so that edits to the model module take effect without a restart.
importlib.reload(UNets)

# network hyper-parameters
layer_nof_channels = [32, 64, 128, 256, 512]
in_channels = 1   # input image number of channels
out_channels = 2  # output segmentation number of classes

model_kwargs = {
    "in_channels": in_channels,
    "out_channels": out_channels,
    "layer_nof_channels": layer_nof_channels,
}
model = UNets.Simple3LayerUNet(**model_kwargs)

# show the module structure
print(model)

In [None]:
## use parallel computing if possible
# Preference order: CUDA first, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

In [None]:
## quickly check if model can run
# Smoke test: one forward pass on the CPU with the first batch the training
# loader yields; only the output size is printed.

model.to("cpu")
with torch.no_grad():
    first_batch = next(iter(train_dataloader))
    check_features, check_labels = first_batch
    check_features = check_features.to("cpu")
    model.eval()
    check_output = model(check_features)
    print(check_output.size())

In [None]:
## training configuration
# Reload so that edits to the loss module take effect without a restart.
importlib.reload(LossFunctions)

# optimisation hyper-parameters
nof_epochs = 400
learning_rate = 2E-5
stop_lr = 1E-10  # the training loop stops once the learning rate drops below this

# loss (alternative: torch.nn.CrossEntropyLoss())
loss_func = LossFunctions.CrossEntropyLoss(reduction="mean")

# optimise every parameter of the model with Adam
train_parameters = model.parameters()
optimizer = torch.optim.Adam(params=train_parameters, lr=learning_rate)

# Optional step-wise learning-rate decay, currently disabled:
# scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer = optimizer,
#     step_size = 80,
#     gamma = 0.1,
# )

In [None]:
## training loop
# Reload so that edits to the train/validate helpers take effect without a restart.
importlib.reload(TrainAndValidate)

model = model.to(device)

# Per-epoch history; entries of epochs that never ran stay at zero.
learning_rates = torch.zeros((nof_epochs,))
train_losses = torch.zeros((nof_epochs,))
validate_losses = torch.zeros((nof_epochs,))

# Number of epochs that ran to completion.  It is updated only after an epoch
# has finished, so it holds the same kind of value (a count) whether the loop
# stops early on the learning-rate threshold or runs through all `nof_epochs`.
end_nof_epochs = 0

for i_epoch in range(nof_epochs):
    print(f" ------ Epoch {i_epoch} ------ ")

    # learning rate of the first parameter group
    cur_lr = optimizer.param_groups[0]['lr']

    # stop once the learning rate has fallen below the configured floor
    if cur_lr < stop_lr:
        break

    print(f"current lr = {cur_lr}")
    learning_rates[i_epoch] = cur_lr

    cur_train_loss = TrainAndValidate.train_one_epoch(
        model = model,
        train_loader = train_dataloader,
        loss_func = loss_func,
        optimizer = optimizer,
        device = device,
    )

    cur_validate_loss = TrainAndValidate.validate_one_epoch(
        model = model,
        validate_loader = validate_dataloader,
        loss_func = loss_func,
        device = device,
    )

    # scheduler.step()

    train_losses[i_epoch] = cur_train_loss
    validate_losses[i_epoch] = cur_validate_loss

    # this epoch is now fully recorded
    end_nof_epochs = i_epoch + 1

    print("\n")

In [None]:
# plot training and validation metrics
# Three stacked panels on a logarithmic y axis: training loss, validation
# loss and learning rate, each indexed by epoch.
metric_curves = (
    (train_losses, "train loss"),
    (validate_losses, "validation loss"),
    (learning_rates, "learning rate"),
)

fig, axes = plt.subplots(len(metric_curves), 1)
for ax, (metric_values, metric_label) in zip(axes, metric_curves):
    ax.plot(metric_values, label = metric_label)
    ax.set_yscale("log")
    ax.legend()

# all panels share the same x meaning; label only the bottom one
axes[-1].set_xlabel("epoch")

fig.tight_layout()
plt.show()

In [None]:
## check learning result
# Run the model on one batch of the chosen loader and show input, predicted
# class map and ground truth side by side for one sample of that batch.
check_idx = 0        # sample index inside the batch
check_batch_idx = 0  # which batch of the loader to inspect
check_dataloader = validate_dataloader

model.to(device)
# switch to evaluation mode before running inference
model.eval()
with torch.no_grad():
    check_features = None
    check_labels = None

    # Create the iterator ONCE and advance it; calling iter() inside the loop
    # would restart the loader every pass and always return a first batch.
    check_batch_iter = iter(check_dataloader)
    for i_batch in range(check_batch_idx + 1):
        check_features, check_labels = next(check_batch_iter)

    check_features = check_features.to(device)
    check_preds = model(check_features)

# bring everything back to the CPU for plotting
check_features = check_features.detach().cpu()
check_preds = check_preds.detach().cpu()
check_labels = check_labels.detach().cpu()

# pick the index with the highest score along dim 1 for every pixel
check_preds = torch.argmax(check_preds, dim = 1)

check_feature = check_features[check_idx,...].numpy()
check_pred = check_preds[check_idx,...].numpy()
check_label = check_labels[check_idx,...].numpy()

# move axis 0 to the end so imshow receives a channel-last array
check_feature = np.rollaxis(check_feature,0,3)

plt.figure(figsize = (7,2))

plt.subplot(1,3,1)
plt.imshow(check_feature)
plt.title("input")

plt.subplot(1,3,2)
plt.imshow(check_pred)
plt.colorbar()
plt.title("prediction")

plt.subplot(1,3,3)
plt.imshow(check_label)
plt.colorbar()
plt.title("ground truth")

plt.tight_layout()
plt.show()

In [None]:
## save model and model parameters
# Writes two timestamped files into ./Results: the whole model object and
# its state dict.

# Build the path with os.path.join so the separator matches the host OS
# (on Windows this is still ".\Results").
dst_dir_path = os.path.join(".", "Results")
# exist_ok avoids the separate existence check and the race it implies
os.makedirs(dst_dir_path, exist_ok = True)

dst_model_name = "model_" + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
dst_model_file_name = dst_model_name + ".pt"
dst_modelstate_file_name = dst_model_name + "_state.pt"

# Whole-model save: the module object itself is serialised, so loading it
# later requires the model class to be importable in that session.
dst_model_file_path = os.path.join(dst_dir_path, dst_model_file_name)
torch.save(model, dst_model_file_path)
print("model saved to: " + dst_model_file_path)

# State-dict save: parameters and buffers only.
dst_modelstate_file_path = os.path.join(dst_dir_path, dst_modelstate_file_name)
torch.save(model.state_dict(), dst_modelstate_file_path)
print("model state saved to: " + dst_modelstate_file_path)