In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Jun 15 07:15:42 2023

@author: mbootsma

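Trains a U-Net on 3-channel float32 input in which the Hoechst channel is
repeated three times (no separate inclusion/exclusion channels are read here).
Pre-processing applied in this script: outlier clipping, flat-field correction,
rolling-ball background subtraction and z-scaling. No log-transform or min-max
scaling reaches the training tensors in this script.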
"""

import sys
sys.path.append('../src/') 
import SEE_TC as ctc

import os
import pandas as pd
import numpy as np
import nd2
from skimage import restoration
from sklearn.model_selection import KFold
import tensorflow as tf

# USER VARS
BIN_ID = "cells"  # name of the binary layer handed to ctc.extract_binary
CHANNEL_1 = "350"  # Hoechst; handed to ctc.extract_image_3Channel
N_EPOCH = 1000  # upper bound on training epochs (early stopping may end sooner)

# NOTE(review): absolute, machine-specific path; not referenced by the cells visible here.
path = "/mnt/Data02/Bootsma/project_CTC_ML/CTC_ML/"

GPU_ids = ["3"]  # physical GPU id(s) this notebook is allowed to use
# Derive the mask from GPU_ids so the two can never disagree. This should run
# before TensorFlow first touches the GPU, otherwise the mask may be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(GPU_ids)
# Once CUDA_VISIBLE_DEVICES is set, the visible devices are renumbered from 0,
# so the single physical GPU "3" is addressed as /GPU:0. The previous value '3'
# named a device that is not visible to the process.
gpu_index = '0'  # index into the *visible* GPUs, not the physical id

In [None]:
# --- tile / batch geometry ---
SIZE_IMG = 32    # edge length of the square tiles (used for patchifying and as model input size)
SIZE_BATCH = 16  # mini-batch size passed to model.fit
N_CHANNELS = 3   # channel count passed to the model constructor

# --- sampling ---
BACKGROUND_PRP = 0.5  # thinning rate handed to ctc.thin_background_rankedBackground

# --- channel selection ---
c = 0  # index of the Hoechst channel in the channel-last image

In [None]:
# Input slides, relative to the notebook's working directory.
# img_path    : slide used for the main train/test tiles
# img_HD_path : high-density (HD) slide mixed into every fold
img_path, img_HD_path = (
    "../train_data/01_train_UNET_slide_00.nd2",
    "../train_data/01_train_UNET_slide_01.nd2",
)

In [None]:
# Results table with one row per metric. No per-fold columns are added to it
# anywhere in the code visible here.
df_results = pd.DataFrame({'rownames' : ['val_loss', 'val_accuracy']})


# NOTE(review): the ND2 handle (`metadata`) is never read below; the ctc helpers
# receive the file path instead. The `with` is kept so the file is still opened
# and closed exactly as before.
with nd2.ND2File(img_path) as metadata:
    ################### read and pre-process
    # The Hoechst channel id is passed for all three channel slots.
    img = ctc.extract_image_3Channel(img_path, CHANNEL_1, CHANNEL_1, CHANNEL_1)
    img = img.transpose(1, 2, 0)  # place channel last

    img_i = img.astype(np.float32)  # convert for processing

    print("Clipping...")  # clip outliers before estimating the flat field
    image_clipped = ctc.adaptive_clipping_with_iqr(img_i[..., c])

    print("Flattening...")  # estimate the flat field from the clipped channel
    flat_field_c = ctc.estimate_flat_field(image_clipped)
    flat_field_c = flat_field_c / np.max(flat_field_c)  # normalize the correction mask to a max of 1

    image_flat = img_i[..., c] / flat_field_c  # apply the flat-field correction

    # Rescale the corrected image to the full uint16 range (0-65535).
    image_flat = (image_flat * 65535 / np.max(image_flat)).astype(np.uint16)

    print("Subtracting background...")
    x = image_flat
    background_x = restoration.rolling_ball(x, radius=21)  # radius should be ~2x object size...
    x_rb = x - background_x

    # Replicate the single processed channel into a 3-channel image.
    img_2_seg = np.stack((x_rb, x_rb, x_rb), axis=-1)

    # Shift values up by one; 65535 is first lowered to 65534 so the +1 cannot
    # wrap around in uint16.
    img_2_seg[np.where(img_2_seg == 65535)] = 65534
    img_2_seg = img_2_seg + 1

    # NOTE(review): a log10 of this image used to be computed here and stored in
    # `img`, but it was never used: the z-scaling below has always run on the
    # non-log values. The dead computation was removed; if a log-transform is
    # actually intended it has to be applied to `img_2_seg` itself.

    # Z-scale. Statistics are accumulated in float64; the image is then scaled
    # in float32. (Previously the std was rounded to float16 first, which only
    # lost precision.)
    mean = np.mean(img_2_seg, dtype=np.float64)
    std_dev = np.std(img_2_seg, dtype=np.float64)
    if std_dev == 0:
        raise ValueError("Cannot z-scale: the background-subtracted image is constant.")
    img_2_seg = (img_2_seg.astype(np.float32) - np.float32(mean)) / np.float32(std_dev)

    ##################
    img_patches = ctc.patchify_image_3Channel(img_2_seg, SIZE_IMG)  # tile the normalized image
    binary = ctc.extract_binary(img_path, BIN_ID)  # load the binary (label) layer
    binary_patches = ctc.patchify_image(binary, SIZE_IMG)

    # Reduce the proportion of tiles with no cells present.
    img_patches, binary_patches = ctc.thin_background_rankedBackground(img_patches, binary_patches, BACKGROUND_PRP)
    training_img = img_patches
    training_binary = binary_patches
            
######################## TRAIN MODEL ######################
# Hold out the last 10% of the tiles as a test set; the first 90% feed the k-fold loop.
n_tiles = training_img.shape[0]
split_idx = int(n_tiles * 0.9)

# The test slice starts AT split_idx. Starting at split_idx + 1 silently
# dropped one tile from both the training and the test set.
test_img = training_img[split_idx:]
test_binary = training_binary[split_idx:]

training_img = training_img[:split_idx]
training_binary = training_binary[:split_idx]

# Every tile must end up in exactly one of the two sets.
assert training_img.shape[0] + test_img.shape[0] == n_tiles

# define fold indices for k-fold training
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

# report the shapes going into training
print("input shape (img/bin training/validation):")
print(training_img.shape, training_binary.shape, test_img.shape, test_binary.shape)

###################### Pre-process the High Density (HD) slide once
# Nothing in this section depends on the fold, so it is computed a single time
# up front instead of being repeated (rolling-ball included) inside every fold.
print("inserting high density (HD) samples")

img_HD = ctc.extract_image_3Channel(img_HD_path, CHANNEL_1, CHANNEL_1, CHANNEL_1)
img_HD = img_HD.transpose(1, 2, 0)  # place channel last
img_i = img_HD.astype(np.float32)  # convert for processing

print("Clipping...")  # clip outliers
c = 0
image_clipped = ctc.adaptive_clipping_with_iqr(img_i[..., c])
print("Flattening...")  # estimate the flat field from the clipped channel
flat_field_c = ctc.estimate_flat_field(image_clipped)
flat_field_c = flat_field_c / np.max(flat_field_c)  # normalize the correction mask to a max of 1
image_flat = img_i[..., c] / flat_field_c  # apply the flat-field correction
image_flat = (image_flat * 65535 / np.max(image_flat)).astype(np.uint16)  # rescale to the uint16 range

print("Subtracting background...")
x = image_flat
background_x = restoration.rolling_ball(x, radius=21)  # radius should be ~2x object size...
x_rb = x - background_x

# Replicate the single processed channel into a 3-channel image.
img_HD = np.stack((x_rb, x_rb, x_rb), axis=-1)

# Shift values up by one; 65535 is first lowered to 65534 so the +1 cannot
# wrap around in uint16.
img_HD[np.where(img_HD == 65535)] = 65534
img_HD = img_HD + 1

# NOTE(review): a log10 of this image used to be computed here and stored in
# `img`, but it was never used: the z-scaling below has always run on the
# non-log values. The dead computation was removed; if a log-transform is
# actually intended it has to be applied to `img_HD` itself.

# Z-scale. Statistics are accumulated in float64; the image is then scaled in
# float32. (Previously the std was rounded to float16 first, which only lost
# precision.)
mean = np.mean(img_HD, dtype=np.float64)
std_dev = np.std(img_HD, dtype=np.float64)
if std_dev == 0:
    raise ValueError("Cannot z-scale: the background-subtracted HD image is constant.")
img_HD = (img_HD.astype(np.float32) - np.float32(mean)) / np.float32(std_dev)

# Un-thinned HD tiles. Kept under separate names so the per-fold thinning below
# always starts from the full set rather than from an already-thinned one.
img_HD_patches_all = ctc.patchify_image_3Channel(img_HD, SIZE_IMG)
binary_HD = ctc.extract_binary(img_HD_path, BIN_ID)  # load the binary (label) layer
binary_HD_patches_all = ctc.patchify_image(binary_HD, SIZE_IMG)

for fold, (train_index, val_index) in enumerate(kf.split(training_img, training_binary)):

    print(f"Fold {fold + 1}/{n_splits}")
    X_train, X_val = training_img[train_index], training_img[val_index]
    y_train, y_val = training_binary[train_index], training_binary[val_index]

    ###################### Add HD samples after the folds are defined so every fold gets a 50-50 split of them
    # Thinning and splitting stay inside the loop, exactly as before, in case
    # either helper is randomized.
    # NOTE(review): assumes these helpers return new arrays and leave their
    # inputs untouched - confirm in SEE_TC.
    img_HD_patches, binary_HD_patches = ctc.thin_background_rankedBackground(
        img_HD_patches_all, binary_HD_patches_all, BACKGROUND_PRP
    )  # reduce the proportion of tiles with no cells present
    print("HD patch count: ")
    print(img_HD_patches.shape)
    # split the high density data in half: one half for training, one for validation
    train_img_HD, train_binary_HD, test_img_HD, test_binary_HD = ctc.split_array_trainTest(img_HD_patches, binary_HD_patches, 0.5)

    # append the HD halves to this fold's training and validation sets
    X_train = np.concatenate((X_train, train_img_HD), axis=0)
    y_train = np.concatenate((y_train, train_binary_HD), axis=0)

    X_val = np.concatenate((X_val, test_img_HD), axis=0)
    y_val = np.concatenate((y_val, test_binary_HD), axis=0)
    print("input shape (img/bin training/validation):")
    print(X_train.shape, y_train.shape, X_val.shape, y_val.shape)
    ######################
    with tf.device(f'/GPU:{gpu_index}'):
        # Seed NumPy and TensorFlow before the model is built. The previous
        # `np.random.seed = seed` replaced the seeding function with an int
        # instead of calling it, so nothing was seeded. If that old cell already
        # ran in this kernel, restart the kernel before running this one.
        seed = 42
        np.random.seed(seed)
        tf.random.set_seed(seed)

        model = ctc.unet_small_objects_deep(SIZE_IMG, SIZE_IMG, N_CHANNELS)  # fresh model for each fold
        # Stop after 200 epochs without val_loss improvement and roll back to the best weights.
        callbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=200, mode='min', restore_best_weights=True)]

        model.fit(X_train, y_train,
                  validation_data=(X_val, y_val),
                  batch_size=SIZE_BATCH, epochs=N_EPOCH, verbose=1, shuffle=True, callbacks=callbacks)

        # Score this fold's model on the held-out test tiles (the final 10% of
        # the slide), not on the fold's validation split.
        scores = model.evaluate(test_img, test_binary, verbose=0)
    # Labelled "Test": these numbers come from test_img/test_binary.
    # Indexing assumes the model reports loss first and accuracy second.
    print(f"Test loss: {scores[0]}, Test accuracy: {scores[1]}")
