#### Load a sample image --> preprocess (chunk) --> predict (on chunks) --> reconstruct the full image

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tifffile
import mrcfile
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.models import load_model
from models import DFCAN  # Assuming this defines your model
from loss_functions import mse_ssim
import os
from csbdeep.data.generate import norm_percentiles, sample_percentiles

# Set up paths
# NOTE(review): the model is loaded from the F-actin directory while the test
# image below is a Microtubules cell — confirm this pairing is intended.
root_dir = '../F-actin'
model_dir = Path(root_dir) / 'SRModel_1400_ready'

# Raw SIM stack to super-resolve (plain string: it has no placeholders).
mrc_file = '/share/klab/argha/Microtubules/Test/Cell_019/RawSIMData_level_05.mrc'

# All plots and predicted TIFFs are written here; output_dir is already a Path.
output_dir = Path.cwd() / 'SR_Model_plots_and_results'
output_dir.mkdir(exist_ok=True)

: 

##### Note:
- Library version mismatch with `typing_extensions`: reinstall it before running prediction.
- Revert it afterwards for training.


In [None]:

# Load the trained SR model. `mse_ssim` is registered as a custom object so
# Keras can resolve that name while deserializing the saved model
# (presumably it was used as the training loss — confirm).
if not os.listdir(model_dir):
    raise ValueError("Model directory is empty. Please check the model path.")
print(f"Loading model from {model_dir}")
with tf.keras.utils.custom_object_scope({'mse_ssim': mse_ssim}):
    trained_model = load_model(model_dir)

# Read the raw stack from the .mrc file.
with mrcfile.open(mrc_file, mode='r') as mrc:
    full_image = mrc.data

# Move the leading axis to the end; for the sample file this turns
# (9, 502, 502) into (502, 502, 9).
full_image = np.transpose(full_image, (1, 2, 0))

print(f"Transformed full image shape: {full_image.shape} :: {full_image.dtype} :: {type(full_image)} : min value: {np.min(full_image)} :: max value: {np.max(full_image)}")
# Recorded output: Transformed full image shape: (502, 502, 9) :: uint16 :: <class 'numpy.ndarray'>

# TODO: percentile-normalize the data with exactly the same rules that were
# used to build the training set (expected training input: (495, 128, 128, 9)).

# Percentile normalization function
def prctile_norm(x, min_prc=1, max_prc=99.9):
    """Linearly rescale ``x`` so its ``min_prc`` percentile maps to 0 and its
    ``max_prc`` percentile maps to (just under) 1.

    Values outside the percentile range are not clipped, so the result can
    fall below 0 or above 1.

    Parameters
    ----------
    x : array_like
        Input data; percentiles are taken over all elements (flattened).
    min_prc, max_prc : float
        Lower / upper percentiles, in [0, 100].

    Returns
    -------
    ndarray
        ``(x - p_lo) / (p_hi - p_lo + 1e-7)``. The ``1e-7`` keeps a constant
        input from dividing by zero (such an input maps to all zeros).
    """
    x = np.asarray(x)
    # Both bounds in a single percentile call instead of three separate passes.
    lo, hi = np.percentile(x, [min_prc, max_prc])
    return (x - lo) / (hi - lo + 1e-7)

# Percentiles 0 and 100 are the min and max, so this is plain min-max scaling
# of the whole stack. NOTE(review): confirm this matches the preprocessing
# used to build the training set.
full_image_percentile = prctile_norm(full_image, min_prc=0, max_prc=100)
print(f" full_image_percentile  Transformed full image shape: {full_image_percentile.shape} :: {full_image_percentile.dtype} :: {type(full_image_percentile)} : min value: {np.min(full_image_percentile)} :: max value: {np.max(full_image_percentile)}")



In [None]:
# The normalized stack is the model input from here on.
normalized_image = full_image_percentile

# Show one slice (index 1 of the last axis) together with summary statistics.
preview_title = (
    f'Normalized Image shape: {normalized_image.shape} \n :: dtype: {normalized_image.dtype} ::\n'
    f' min_value: {np.min(normalized_image)} :: \nmax_value: {np.max(normalized_image)}'
)

plt.figure(figsize=(10, 10))
plt.imshow(normalized_image[..., 1])
plt.title(preview_title)
plt.axis('off')
plt.savefig(f'{output_dir}/sr_normalized_image.png')
plt.show()

# Alias of the normalized stack.
image = normalized_image

#### Step 1: Split the image into 128 x 128 x 9 chunks

In [None]:


def chunk_image(image, chunk_size):
    """Split an image into square tiles of ``chunk_size`` along its first two axes.

    Tiles are placed on a regular ``chunk_size`` grid in row-major order
    (top-to-bottom, then left-to-right). When the grid does not divide the
    image evenly, the last tile of each row/column is pulled back so that it
    ends exactly at the image border. It then overlaps its neighbour instead
    of being truncated. Every tile is therefore ``chunk_size`` x ``chunk_size``
    whenever the image is at least that large along both axes. A smaller axis
    yields a single tile spanning that whole axis.

    Parameters
    ----------
    image : ndarray
        Array whose first two axes are (height, width); any trailing axes
        (e.g. channels) are kept in every tile.
    chunk_size : int
        Tile edge length in pixels.

    Returns
    -------
    chunks : list of ndarray
        The tiles, each ``image[y:y + chunk_size, x:x + chunk_size]``.
    chunk_coords : list of tuple(int, int)
        ``(x, y)`` top-left corner of each tile, in the same order as ``chunks``.
    """
    image_height, image_width = image.shape[:2]

    def _tile_starts(length):
        # Grid offsets; the last one is clamped so its tile ends at the border
        # (this replaces a hard-coded 384 -> 502 - 128 shift that only worked
        # for 502-pixel images).
        last_start = max(length - chunk_size, 0)
        return [min(start, last_start) for start in range(0, length, chunk_size)]

    chunks = []
    chunk_coords = []
    for y in _tile_starts(image_height):
        for x in _tile_starts(image_width):
            chunks.append(image[y:y + chunk_size, x:x + chunk_size])
            chunk_coords.append((x, y))

    return chunks, chunk_coords

# No resizing is actually performed: this is an alias of the normalized stack.
resized_image = normalized_image
chunk_size = 128  # tile edge, matching the 128 x 128 training patches

chunks, chunk_coords = chunk_image(resized_image, chunk_size)
print(f'after chunkinh: {len(chunks)} :: {len(chunk_coords)} :: {type(chunks)} {type(chunk_coords)}')

# Stack the list of tiles into a single float32 batch (tiles along axis 0).
chunks = np.asarray(chunks, dtype=np.float32)

print(f'chunks: {chunks.shape} :: {chunks.dtype} :: {type(chunks)}')


In [None]:

# Visualize the chunks in a near-square grid (slice 1 of the last axis of each chunk).
num_chunks = chunks.shape[0]
ncols = int(np.ceil(np.sqrt(num_chunks)))
nrows = int(np.ceil(num_chunks / ncols))
plt.figure(figsize=(10, 10))
plt.suptitle('Chunks')
print(f'chunks: {chunks.shape} {type(chunks)} : chunk_coords: {len(chunk_coords)}{type(chunk_coords)} ')
for i, (chunk, (x_start, y_start)) in enumerate(zip(chunks, chunk_coords)):
    print(f'chunk.shape: {chunk.shape} :: {type(chunk)} , {x_start}, {y_start}')
    print(f'min value: {np.min(chunk)} :: max value: {np.max(chunk)} dtype: {chunk.dtype}')
    plt.subplot(nrows, ncols, i + 1)
    plt.imshow(chunk[..., 1])
    plt.title(f"({x_start}, {y_start}):S {chunk.shape}", fontsize=8)
    plt.axis('off')
# Save once, after every subplot is drawn, rather than re-rendering the
# whole figure to disk on each loop iteration.
plt.savefig(f'{output_dir}/SR_chunk_.png')
plt.show()

#### Step 2: Define the chunk upscaling function (super-resolution model prediction)


In [None]:
def upscale_chunks(chunks, chunk_coords, upscale_factor=2, model=None, save_dir=None):
    """Predict every chunk with the SR model and scale the chunk coordinates.

    As a side effect, each prediction is written to
    ``<save_dir>/predicted_image_numpy<k>.tif`` (``k`` is 1-based).

    Parameters
    ----------
    chunks : array_like
        Batch of input tiles stacked along axis 0.
    chunk_coords : sequence of (int, int)
        ``(x, y)`` top-left corner of each input tile, in the order of ``chunks``.
    upscale_factor : int, default 2
        Multiplier applied to the coordinates. It is not checked against the
        model; it should equal the model's spatial upscaling factor.
    model : optional
        Object with a Keras-style ``predict``; defaults to the notebook-level
        ``trained_model``.
    save_dir : str or Path, optional
        Directory for the per-chunk TIFFs; defaults to the notebook-level
        ``output_dir``.

    Returns
    -------
    upscaled_chunks : list of tf.Tensor
        Each prediction with its trailing axis removed via ``tf.squeeze``.
    upscaled_chunk_coords : list of (int, int)
        Input coordinates multiplied by ``upscale_factor``.
    """
    # Resolve the defaults lazily so existing calls keep using the notebook globals.
    model = trained_model if model is None else model
    save_dir = output_dir if save_dir is None else save_dir

    chunks = np.array(chunks)
    print(chunks.shape)
    predictions = model.predict(chunks)
    print(f'predictions: {predictions.shape} {type(predictions)} {predictions.dtype} min_value: {np.min(predictions)} max_value: {np.max(predictions)}')

    upscaled_chunks = []
    upscaled_chunk_coords = []
    # A single pass squeezes, saves and collects each prediction once.
    for i, pred in enumerate(predictions):
        print(f'output of teh prediction each chunk: {pred.shape}')
        # Drop the trailing axis; assumes the model outputs a single channel.
        squeezed = tf.squeeze(pred, axis=-1)
        print(f' chunk_pred_shape : {squeezed.shape}')
        tifffile.imwrite(f'{save_dir}/predicted_image_numpy{i + 1}.tif', squeezed.numpy())
        upscaled_chunks.append(squeezed)

        x_start, y_start = chunk_coords[i]
        upscaled_chunk_coords.append((x_start * upscale_factor, y_start * upscale_factor))

    return upscaled_chunks, upscaled_chunk_coords


#### Step 3: Upscale all chunks and visualize the results

In [None]:


# Run the SR model on every chunk; the returned coordinates are already upscaled.
upscaled_chunks, upscaled_chunk_coords = upscale_chunks(chunks, chunk_coords)
print(f'see teh shapes and data :  {len(upscaled_chunks)} :: {len(upscaled_chunk_coords)} :: {type(upscaled_chunks)} :: {type(upscaled_chunk_coords)}')

# Stack the per-chunk predictions into one array (one chunk per leading index).
upscaled_chunks = np.array(upscaled_chunks)
print(f'upscaled_chunks: {upscaled_chunks.shape} {type(upscaled_chunks)} : {upscaled_chunks.dtype} : min_value: {np.min(upscaled_chunks)} max_value : {np.max(upscaled_chunks)}')

# Visualize the upscaled chunks in a near-square grid.
num_upscaled_chunks = upscaled_chunks.shape[0]
ncols_upscaled = int(np.ceil(np.sqrt(num_upscaled_chunks)))
nrows_upscaled = int(np.ceil(num_upscaled_chunks / ncols_upscaled))

plt.figure(figsize=(10, 10))
plt.suptitle('256x256 Upscaled Chunks')
for i, (upscaled_chunk, (x_start, y_start)) in enumerate(zip(upscaled_chunks, upscaled_chunk_coords)):
    print(f'min value: {np.min(upscaled_chunk)} :: max value: {np.max(upscaled_chunk)} dtype:  {upscaled_chunk.dtype}')
    plt.subplot(nrows_upscaled, ncols_upscaled, i + 1)
    plt.imshow(upscaled_chunk)
    plt.title(f'({x_start}, {y_start}):S {upscaled_chunk.shape}', fontsize=8)  # coordinates as title
    plt.axis('off')
# Save once, after every subplot is drawn, rather than on each loop iteration.
plt.savefig(f'{output_dir}/SR_upscaled_chunk.png')
plt.show()

#### Step 4: Reassemble the upscaled chunks into a 1004x1004 grid

In [None]:
# Reassemble the upscaled chunks into the full super-resolved image.
# The canvas size is derived from where the chunks land (1004 x 1004 for the
# 502 x 502 sample upscaled 2x) instead of being hard-coded.
channels = 1  # expected number of channels per upscaled chunk

chunk_extents = [
    (y_start + chunk.shape[0], x_start + chunk.shape[1])
    for chunk, (x_start, y_start) in zip(upscaled_chunks, upscaled_chunk_coords)
]
target_height = max((bottom for bottom, _ in chunk_extents), default=0)
target_width = max((right for _, right in chunk_extents), default=0)

# Start from -inf rather than 0. Overlapping chunks are merged with an
# element-wise maximum, and a zero background would silently clamp any
# negative prediction to 0.
final_image = np.full((target_height, target_width, channels), -np.inf)

for upscaled_chunk, (x_start, y_start) in zip(upscaled_chunks, upscaled_chunk_coords):
    # Give 2-D (H, W) predictions an explicit channel axis.
    if upscaled_chunk.ndim == 2:
        upscaled_chunk = upscaled_chunk[..., np.newaxis]

    if upscaled_chunk.shape[2] != channels:
        raise ValueError(f"Chunk has {upscaled_chunk.shape[2]} channels, expected {channels} channels.")

    # Chunks may overlap (e.g. edge tiles shifted back to the border); keep the
    # larger value. `region` is a view, so the maximum is written into final_image.
    region = final_image[y_start:y_start + upscaled_chunk.shape[0],
                         x_start:x_start + upscaled_chunk.shape[1], :]
    np.maximum(region, upscaled_chunk, out=region)

# Pixels not covered by any chunk fall back to 0 (the previous zero background).
# This should not occur when the chunks tile the whole image.
final_image[np.isneginf(final_image)] = 0.0

# Save channel-first: (channels, H, W).
tifffile.imwrite(f'{output_dir}/prediction _ fullimage.tif', final_image.transpose(2, 0, 1))

# Visualize the final image (single channel, shown as a 2-D array).
plt.figure(figsize=(10, 10))
plt.title(f'Reassembled Image  :: shape :: {final_image.shape} \n final_imahge:dtype :: {final_image.dtype} :: \n min_value: {np.min(final_image)} ::\n max_value: {np.max(final_image)}')
plt.imshow(final_image[..., 0])
plt.axis('off')
plt.savefig(f'{output_dir}/SR_final_image.png')
plt.show()


