In [1]:
import timeit                                   # Used in timer_function to test code execution times
import numpy as np                              # Numerical processing library
import matplotlib.pyplot as plt                 # Matplotlib plotting library

###Helper functions 
def drag_race(number, repeats, functions, args_lists, sig_figs=3, share_args=False):
    """
    Runs timeit.repeat on each function in the the input list 'functions' a specified number of times and 
    prints the minimum runtime for each func
    
    # Arguments:
    number:     Number of times to run the functions per repeat.
    repeats:    Number of times to time the function (each time function is timed it is run 'number' times).
    functions:  The functions to be timed, in format [function_name_1, function_name_2].
    args_lists: Arguments to pass to the functions using format [[F1-arg1, F1-arg2], [F2-arg1, F2-arg2, F2-arg3]] Unless all 
                functions take same arguments in which case pass [[shared_arg1, shared_arg2]] and then also set share_args=True.
    sig_figs:   Sets the number of significant figures for the printed results readout [Default=3].
    share_args: If all functions share the same argumnets then passing share_args=True allows user to only input them once and they are used for all fucntions [Default=False].
    
    # Returns:
    No values are returned instead function automatically prints statment with function names and min runtimes.
    """
    
    if share_args == True:
        args_lists = args_lists * len(functions)  # If share args is used the single set of arguments is copied for the numebr of function requiring them
        
    for i, function in enumerate(functions):
        
        run_times = timeit.repeat(lambda: function(*args_lists[i]), number=number, repeat=repeats)
        min_time = min(run_times)/number

        print("\nFunction: {}\nRuntime: {} ms (minimum result over {} runs)".format(function.__name__, round(min_time*1000, sig_figs), repeats))

In [2]:
#%% - Dependencies
# External Libraries
import numpy as np 
import matplotlib.pyplot as plt
import torch
import os
import importlib.util
import re

# Enable for dynamic plots
#%matplotlib notebook
# Enable for still plots
%matplotlib inline 

#%% - functions
# Import and prepare Autoencoder model
def setup_encoder_decoder(latent_dim,pretrained_model_path, AE_file_folder_path):
    #from DC3D_Autoencoder_V1 import Encoder, Decoder   # make this programatic from the model folder as it also contains the backup AE file

    def import_encoder_decoder(folder_path):
        module_name_pattern = r"DC3D_Autoencoder_V\w+\.py"
        module_path_pattern = os.path.join(folder_path, module_name_pattern)

        matching_files = [file for file in os.listdir(folder_path) if re.match(module_name_pattern, file)]
        if not matching_files:
            raise ImportError(f"No DC3D_Autoencoder module found in {folder_path}\n")

        module_name = matching_files[0][:-3]
        module_path = os.path.join(folder_path, f"{module_name}.py")
        print(f"Loaded {module_name}\n")
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        return module.Encoder, module.Decoder
    
    Encoder, Decoder = import_encoder_decoder(AE_file_folder_path)
    encoder = Encoder(encoded_space_dim=latent_dim, fc2_input_dim=128, encoder_debug=False, record_activity=False)
    decoder = Decoder(encoded_space_dim=latent_dim, fc2_input_dim=128, decoder_debug=False, record_activity=False)
    encoder.double()   
    decoder.double()

    # load the full state dictionary into memory
    full_state_dict = torch.load(pretrained_model_path)

    # load the state dictionaries into the models
    encoder.load_state_dict(full_state_dict['encoder_state_dict'])
    decoder.load_state_dict(full_state_dict['decoder_state_dict'])

    encoder.eval()                                   #.eval() is a  switch for some specific layers/parts of the model that behave differently during training and inference (evaluating) time. For example, Dropouts Layers, BatchNorm Layers etc. You need to turn off them during model evaluation, and .eval() will do it for you. In addition, the common practice for evaluating/validation is using torch.no_grad() in pair with model.eval() to turn off gradients computation
    decoder.eval()    
    return encoder, decoder


# New function to add n noise points to the 2d numpy array
def add_noise_points_nonorm(input_image, noise_points=100, time_dimension=100):
    image = input_image.copy()
    
    if noise_points > 0:
        #Find dimensions of input image 
        x_dim = image.shape[0]
        y_dim = image.shape[1]

        #Create a list of random x and y coordinates
        x_coords = np.random.randint(0, x_dim, noise_points)
        y_coords = np.random.randint(0, y_dim, noise_points)

        # Iterate through noise_points number of random pixels to noise
        for i in range(noise_points):

            # Add a random number between 0 and time_dimension to the pixel 
            image[x_coords[i], y_coords[i]] = np.random.uniform(0, time_dimension)
    return image

def image_loader(input_image_path, noise_points, time_dimension=100):
    ### Load image from path 
    input_image = np.load(input_image_path)
    
    # Add noise if noise_points is greater than 1
    noisy_image = add_noise_points_nonorm(input_image, noise_points, time_dimension)
    
    # Turn input image into tensor and add two extra dimesnions to start of array so shape goes from (x,y) to (1,1,x,y) to represent batch and channel dims
    noisy_image_tensor = torch.tensor(noisy_image)
    noisy_image_tensor.double()  
    
    return input_image, noisy_image_tensor

def custom_normalisation(data, reconstruction_threshold, time_dimension=100):
    data = ((data / time_dimension) / (1/(1-reconstruction_threshold))) + reconstruction_threshold
    for row in data:   ###REPLACE USING NP.WHERE
        for i, ipt in enumerate(row):
            if ipt == reconstruction_threshold:
                row[i] = 0
    return data

def custom_renormalisation(data, reconstruction_threshold, time_dimension=100):
    data = np.where(data > reconstruction_threshold, ((data - reconstruction_threshold)*(1/(1-reconstruction_threshold)))*(time_dimension), 0)
    return data

def build_3d(data, time_dimension=100):
    # Apply the processing functions to the data
    shape = data.shape
    processed_data = np.zeros((shape[0], shape[1], time_dimension))

    i, j = np.nonzero(data)           # Compute the indices for the non-zero elements of data in the third dimension of array_3D
    k = data[i, j].astype(int)        # Convert the values to integers
    processed_data[i, j, k-1] = 1     # array_3D is now a 3D numpy array of size n by m by time_dimension_max, with the non-zero values from the original 2D array set to 1 in the appropriate location 
    return(processed_data)

def multi_3d_plotting(input_image, noised_image, rec_image, masked_rec_image, time_dimension=100, show_2d_projection=True):
    # Create a 3D plot with two subplots
    fig = plt.figure(figsize=(8, 6))
    titles = ['Input Image', 'Noised Image', 'DC3D Rec Image', 'Masked DC3D Rec Img']
    
    # Loop through both input_image and rec_image and create a 2d subplot for each
    for i, image in enumerate([input_image, noised_image, rec_image, masked_rec_image]):    
        ax = fig.add_subplot(2, 4, i+1)
        ax.imshow(image)
        ax.set_xlabel('X (px)')
        ax.set_ylabel('Y (px)')
        # Set the title of the subplot
        ax.set_title(titles[i])
        
    # Loop through both input_image and rec_image and create a 3d subplot for each
    for i, image in enumerate([input_image, noised_image, rec_image, masked_rec_image]):
        ax = fig.add_subplot(2, 4, 4 + i+1, projection='3d')

        # Generate 3D image data
        image_3d = build_3d(image, time_dimension)
        
        # Assume image is your 3D array of size n by m by t_max
        n, m, t_max = image_3d.shape

        # Create a meshgrid of x, y, and z values for the 3D plot
        x, y, z = np.meshgrid(np.arange(m), np.arange(n), np.arange(t_max))

        # Flatten the x, y, and z values and image for plotting
        x = x.flatten()
        y = y.flatten()
        z = z.flatten()
        image_3d = image_3d.flatten()

        # Plot the 3D scatter plot with the non-zero values in the array set to 1
        ax.scatter(z[image_3d == 1], x[image_3d == 1], y[image_3d == 1], c=z[image_3d == 1], cmap='viridis', marker='o', s=10, alpha=0.5, depthshade=False, linewidth=0)
        #ax.set_xlabel('Time')
        #ax.set_ylabel('X (px)')
        #ax.set_zlabel('Y (px)')
        ax.set_xlim(0, t_max)
        ax.set_ylim(0, 88)
        ax.set_zlim(0, 128)
        
        if show_2d_projection:
            # plot the projected 2D data on the xy-plane
            ax.scatter(np.zeros_like(z[image_3d == 1]), x[image_3d == 1], y[image_3d == 1], c='r', marker='o', s=10, alpha=0.3, depthshade=False, linewidth=0)

    # Show the plot
    plt.tight_layout()
    plt.show()

    
def multi_3d_plotting2(input_image, rec_image, time_dimension=100, show_2d_projection=True):
    clrs=["r", "r"]
    # Create a 3D plot with two subplots
    fig = plt.figure(figsize=(5, 5))
    titles = ['Zero Noise Reconstruction']
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    # Loop through both input_image and rec_image and create a 3d subplot for each
    for i, image in enumerate([input_image, rec_image]):
        # Generate 3D image data
        image_3d = build_3d(image, time_dimension)
        
        # Assume image is your 3D array of size n by m by t_max
        n, m, t_max = image_3d.shape

        # Create a meshgrid of x, y, and z values for the 3D plot
        x, y, z = np.meshgrid(np.arange(m), np.arange(n), np.arange(t_max))

        # Flatten the x, y, and z values and image for plotting
        x = x.flatten()
        y = y.flatten()
        z = z.flatten()
        image_3d = image_3d.flatten()
        if i == 0 :
            col = "r"
        else:
            col = z[image_3d == 1]
        # Plot the 3D scatter plot with the non-zero values in the array set to 1
        ax.scatter(z[image_3d == 1], x[image_3d == 1], y[image_3d == 1], c=col, marker='o', s=10, alpha=0.5, depthshade=False, linewidth=0)
    
    ax.set_xlabel('t')
    ax.set_ylabel('x')
    ax.set_zlabel('y')
    ax.set_xlim(0, t_max)
    ax.set_ylim(0, 88)
    ax.set_zlim(0, 128)
    
    # Show the plot
    plt.show()
    
# Masking technique
def masking_recovery(input_image, recovered_image, print_result=True):
    raw_input_image = input_image.clone()
    net_recovered_image = recovered_image.copy()
    #Evaluate usefullness 
    # count the number of non-zero values
    masking_pixels = np.count_nonzero(net_recovered_image)
    image_shape = net_recovered_image.shape
    total_pixels = image_shape[0] * image_shape[1] * time_dimension
    # print the count
    if print_result:
        print(f"Total number of pixels in the timescan: {format(total_pixels, ',')}\nNumber of pixels returned by the masking: {format(masking_pixels, ',')}\nNumber of pixels removed from reconstruction by masking: {format(total_pixels - masking_pixels, ',')}")

    # use np.where and boolean indexing to update values in a
    mask_indexs = np.where(net_recovered_image != 0)
    net_recovered_image[mask_indexs] = raw_input_image[mask_indexs]
    result = net_recovered_image
    return result
                
                
#Following function runs the autoencoder on the input data
def deepclean3(input_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension=100):
    
    with torch.no_grad():
        norm_image = custom_normalisation(input_image_tensor, reconstruction_threshold, time_dimension)
        image_prepared = norm_image.unsqueeze(0).unsqueeze(0)   #Adds two extra dimesnions to start of array so shape goes from (x,y) to (1,1,x,y) to represent batch and channel dims
        rec_image = decoder(encoder(image_prepared))                         #Creates a recovered image (denoised image), by running a noisy image through the encoder and then the output of that through the decoder.
        rec = rec_image.squeeze().numpy()
        rec_image_renorm = custom_renormalisation(rec, reconstruction_threshold, time_dimension)
    return rec_image_renorm


In [7]:
#User Inputs
input_image_path = "N:\Yr 3 Project Datasets\Dataset 24_X10ks\Data\Flat SimpleX-128x88-1 Crosses, No3767.npy"
time_dimension = 100
noise_points = 1000
show_2d_projection = False

#AE Settings
reconstruction_threshold = 0.5
latent_dim = 10
model_name = "D25 50K lr0001 weightedMSE0point99-1 DAE np200"

#Path settings
pretrained_model_path = f"N:\\Yr 3 Project Results\\{model_name} - Training Results\\{model_name} - Model + Optimiser State Dicts.pth"
AE_file_folder_path = f"N:\\Yr 3 Project Results\\{model_name} - Training Results\\"

### Compute
# Setup encoder and decoder and load models
encoder, decoder = setup_encoder_decoder(latent_dim, pretrained_model_path, AE_file_folder_path)

# Load input image
if os.path.isdir(input_image_path): # If user provides folder path, selct a random .npy file from the input directory and add it to file path
    input_image_path = os.path.join(input_image_path, np.random.choice(os.listdir(input_image_path)))
    print(f"Input path is a folder path, therfore selecting random image from folder.\nImage selected is {input_image_path}\nFor a specific image, pass in a file path not folder.")

#%% - Driver
input_image, noisy_image_tensor = image_loader(input_image_path, noise_points, time_dimension)


#T1
def direct_func(noisy_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension):
    recovered_image = deepclean3(noisy_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension)
    pass

#T2
def masking_func(noisy_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension):
    recovered_image = deepclean3(noisy_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension)
    masked_rec_image = masking_recovery(noisy_image_tensor, recovered_image, print_result=False)
    pass

Loaded DC3D_Autoencoder_V1



In [17]:
# functions = [direct_func, masking_func]
args_list = [[noisy_image_tensor, reconstruction_threshold, encoder, decoder, time_dimension]]


drag_race(number=1, repeats=1000, functions=functions, args_lists=args_list, sig_figs=3, share_args=True)


Function: direct_func
Runtime: 108.94 ms (minimum result over 1000 runs)

Function: masking_func
Runtime: 110.289 ms (minimum result over 1000 runs)
