In [None]:
import os, sys, time, math, shutil
from natsort import natsorted

import numpy as np
import cv2
import matplotlib.pyplot as plt
import cma

# from pymoo.core.problem import Problem
# from pymoo.algorithms.moo.nsga2 import NSGA2
# from pymoo.optimize import minimize

import glob
from tifffile import imwrite
import json 

from gvxrPython3 import gvxr
from gvxrPython3 import json2gvxr

import utils
from utils import average_images, flatField, getReference, displayResult, fitnessMSE, fitnessRMSE, fitnessMAE, fitnessSSIM, fitnessZNCC, getXrayImage
from utils import bbox, plot_directory, applyTransformation, inverseX, rescaleX

from utils import compareMAE, compareMSE, compareZNCC, compareSSIM, compareRMSE
from utils import fitnessHOGwithMAE, fitnessHOGwithMSE, fitnessHOGwithZNCC, fitnessHOGwithRMSE

In [None]:
# Acquisition / processing settings.
angular_step_in_deg = 3.6   # turntable increment between consecutive projections

# Padding settings live on the utils module so its helpers can read them.
utils.pad_width = 50        # border, in pixels, added around each image when padding
utils.use_padding = False

# Use downsampled images (and detector) for the coarse registration stages.
downsample = True

In [None]:
# Dataset folder to process. Other acquisitions used with this notebook:
#   "C:/Users/user/phd/Dataoff/"
#   "C:/Users/user/phd/26SepImages/"
#   "C:/Users/snn23kfl/project/"
#   "4thOCtober_image/"
#   "4thCotober_imageAngle/"
#   "15NovemberData/"
#   "24NovemberData/"
data_path = "25OctoberData/"

In [None]:
# Folder the notebook kernel was started in. IPython keeps its directory
# history in the `_dh` variable; when it is not available (e.g. the code is run
# as a plain script) fall back to the current working directory instead of
# failing with a KeyError.
_dir_history = globals().get('_dh')
current_folder = str(_dir_history[0]) if _dir_history else os.getcwd()
print(current_folder)

# JSON scene description used to initialise gVXR.
json_file = current_folder + "/simulation2.json"

In [None]:
# dirs = [
#     current_folder + "/" + data_path + "/plot1",
#     current_folder + "/" + data_path + "/plot2"
# ]

# for directory in dirs:
#     if os.path.isdir(directory):
#         shutil.rmtree(directory)

#     if not os.path.exists(directory):
#         os.mkdir(directory)

In [None]:
# Collect the dark-field, white-field and raw projection image files of the
# dataset, naturally sorted (case-insensitive) so that e.g. "_10" follows "_9".
def _sorted_glob(pattern):
    """Return the paths matching `pattern`, naturally sorted ignoring case."""
    return natsorted(glob.glob(pattern), key=lambda path: path.lower())

dark_field_paths  = _sorted_glob(data_path + '/darkfd/darkfd_*.tiff')
white_field_paths = _sorted_glob(data_path + '/whitefd/whitefd_*.tiff')
raw_image_paths   = _sorted_glob(data_path + '/raw_images/raw_image_*.tiff')

# Some acquisitions store the projections as JPEG in a differently named folder.
if len(raw_image_paths) == 0:
    raw_image_paths = _sorted_glob(data_path + '/rawimages/raw_images_*.jpg')

# Fail here with an explicit message rather than later with an obscure error
# from the averaging / flat-field cells.
for _label, _paths in (("dark-field", dark_field_paths),
                       ("white-field", white_field_paths),
                       ("raw projection", raw_image_paths)):
    if len(_paths) == 0:
        raise FileNotFoundError("No " + _label + " images found in " + data_path)

In [None]:
# Average the stacks of dark-field and white-field images at full resolution.
I_high_res_dark  = average_images(dark_field_paths, False)
I_high_res_white = average_images(white_field_paths, False)

# Second argument True: presumably the downsampled average (used with the
# low-resolution projections below).
if downsample:
    I_low_res_dark, I_low_res_white = (average_images(dark_field_paths, True),
                                       average_images(white_field_paths, True))

In [None]:
# Load the raw projections acquired over one full turn (0 to 360 degrees
# inclusive), optionally pad them, and build a copy downsampled by two
# cv2.pyrDown passes for the coarse registration stages.
I_high_res_raw = []
I_low_res_raw = []
angles_in_deg = []
for i, fname in enumerate(raw_image_paths):

    angle = angular_step_in_deg * i

    # The small epsilon keeps the 360-degree projection despite rounding.
    if angle >= 360.000001:
        continue

    image = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise IOError("Could not read image " + fname)

    if utils.use_padding:
        # Pad with the image's median grey level. The padding width is stored on
        # the utils module (the bare name `pad_width` is not defined here).
        median_value = np.median(image)
        image = np.pad(image, (utils.pad_width, utils.pad_width), 'constant',
                       constant_values=(median_value, median_value))

    angles_in_deg.append(angle)
    I_high_res_raw.append(image)

    if downsample:
        I_low_res_raw.append(cv2.pyrDown(cv2.pyrDown(image)))


I_high_res_raw = np.array(I_high_res_raw, dtype=np.single)

# Flat-field correction: flat1 uses the measured dark field, flat2 a zero dark
# field. Both are clamped to [0, 1] in place.
I_high_res_flat1 = flatField(I_high_res_raw, I_high_res_white, I_high_res_dark)
I_high_res_flat2 = flatField(I_high_res_raw, I_high_res_white, np.zeros(I_high_res_dark.shape))
np.clip(I_high_res_flat1, 0, 1, out=I_high_res_flat1)
np.clip(I_high_res_flat2, 0, 1, out=I_high_res_flat2)

if downsample:
    I_low_res_raw = np.array(I_low_res_raw, dtype=np.single)

    I_low_res_flat1 = flatField(I_low_res_raw, I_low_res_white, I_low_res_dark)
    I_low_res_flat2 = flatField(I_low_res_raw, I_low_res_white, np.zeros(I_low_res_dark.shape))
    np.clip(I_low_res_flat1, 0, 1, out=I_low_res_flat1)
    np.clip(I_low_res_flat2, 0, 1, out=I_low_res_flat2)


In [None]:
# Visual check of the flat-field correction: the averaged dark and white
# images, then one raw projection and its two corrected versions (the middle
# projection when a stack is loaded). Both branches now use the same titles.
if len(I_high_res_raw.shape) == 2:
    raw_preview, flat1_preview, flat2_preview = I_high_res_raw, I_high_res_flat1, I_high_res_flat2
else:
    middle = I_high_res_raw.shape[0] // 2
    raw_preview   = I_high_res_raw[middle]
    flat1_preview = I_high_res_flat1[middle]
    flat2_preview = I_high_res_flat2[middle]

# (title, image, upper limit of the grey-level range)
panels = [
    ("Image with \nlight off", I_high_res_dark, 255),
    ("Image with \nlight on", I_high_res_white, 255),
    ("Raw projection", raw_preview, 255),
    ("Using image with \nlight off as dark field", flat1_preview, 1),
    ("Using np.zeros as\n dark field", flat2_preview, 1),
]

plt.figure(figsize=(17,5))
for position, (title, image, vmax) in enumerate(panels, start=1):
    plt.subplot(2, 5, position)
    plt.title(title)
    plt.imshow(image, cmap="gray", vmin=0, vmax=vmax)
    plt.colorbar()

In [None]:
I_high_res_flat = I_high_res_flat2

if downsample:
    I_low_res_flat = I_low_res_flat2

In [None]:
if not os.path.exists(data_path + "/flat_images"):
    os.mkdir(data_path + "/flat_images")

for i, img in enumerate(I_high_res_flat):
    imwrite(data_path + "/flat_images/projection_" + str(i).zfill(4) + ".tif", img)

In [None]:
def _binarise_stack(flat_images):
    """Otsu-threshold each flat-field image (values in [0, 1]) into a uint8 0/255 mask."""
    masks = []
    for flat in flat_images:
        # With cv2.THRESH_OTSU the threshold argument (127) is ignored; Otsu picks it.
        _, mask = cv2.threshold((255 * flat).astype(np.uint8), 127, 255, cv2.THRESH_OTSU)
        masks.append(mask)
    return np.array(masks, dtype=np.uint8)


I_high_res_binary = _binarise_stack(I_high_res_flat)
I_low_res_binary = _binarise_stack(I_low_res_flat) if downsample else []

In [None]:
# Source-to-object (sod) and source-to-detector (sdd) distances, in mm.
# Values used in earlier experiments:
#   sod: 43, 47, 48 or 25 with sdd: 61
#   rabbit: sod 41, sdd 51
sdd = 130
sod = 104

In [None]:
# Search range of stage 1, one entry per parameter: source position (x, y, z),
# detector position (x, y, z) and object position on the turntable (x, y).
# The object orientation (-180..180) is searched separately in stage 2.
_stage1_lower = [-50, -500, -50,   # source position
                 -50,    0, -50,   # detector position
                 -50,  -50]        # object position on turntable
_stage1_upper = [ 50,    0,  50,   # source position
                  50,  500,  50,   # detector position
                  50,   50]        # object position on turntable
utils.data_range = [_stage1_lower, _stage1_upper]

In [None]:
# Density of the Ti90Al10 sample, in g/cm3 (passed to gvxr.setDensity later).
Ti90Al10 = 5.68

# Nominal geometry: source on the -y axis at `sod` from the origin, detector on
# the +y axis so that the source-detector distance is `sdd`.
x_src, y_src, z_src = 0, -sod, 0
x_det, y_det, z_det = 0, (sdd - sod), 0

# Rotation-axis and object offsets, all zero initially.
x_rot, y_rot, z_rot = 0, 0, 0
x_obj, y_obj, z_obj = 0, 0, 0

# Object orientation (alpha_x, alpha_y, alpha_z) depends on the dataset. Select
# it from `data_path` instead of letting the second dataset's values silently
# overwrite the first; "24NovemberData/" values remain the fallback.
_orientation_by_dataset = {
    "25OctoberData/":  (90, -90, 0),
    "24NovemberData/": (-90, -0, 0),
}
alpha_x, alpha_y, alpha_z = _orientation_by_dataset.get(data_path, (-90, -0, 0))

# Initial guess for stage 1: source and detector positions and the object
# position on the turntable (orientation is searched in stage 2).
x_default = [
    x_src, y_src, z_src,
    x_det, y_det, z_det,
    x_obj, y_obj,
]

# Presumably maps the physical values into the optimiser's normalised
# [-1, 1] search space (CMA-ES bounds are [-1, 1] below).
x_default = inverseX(x_default)

In [None]:
# Initialise gVXR using our JSON file
json2gvxr.initGVXR(json_file, "OPENGL", 3, 2)

In [None]:
# Load our source properties
json2gvxr.initSourceGeometry()
json2gvxr.initSpectrum(verbose=0)

gvxr.setSourcePosition(x_src, y_src, z_src, "mm")

In [None]:
# Detector: load it from the JSON file and move it to the nominal position.
# (Original note on the JSON detector entry: "NumberOfPixels": [2880, 2880].)
json2gvxr.initDetector()
gvxr.setDetectorPosition(x_det, y_det, z_det, "mm")

if downsample:
    # Keep the physical detector size but use 4x fewer pixels per axis, to match
    # the images downsampled with two cv2.pyrDown passes.
    detector_size = np.array(gvxr.getDetectorSize("mm"))

    # Full-resolution settings, restored before stage 3.
    old_number_of_pixels = np.array(gvxr.getDetectorNumberOfPixels())
    old_spacing = detector_size / old_number_of_pixels

    gvxr.setDetectorNumberOfPixels(old_number_of_pixels[0] // 4, old_number_of_pixels[1] // 4)
    new_number_of_pixels = np.array(gvxr.getDetectorNumberOfPixels())
    new_spacing = detector_size / new_number_of_pixels

    gvxr.setDetectorPixelSize(new_spacing[0], new_spacing[1], "mm")

    print("\tDetector number of pixels:", new_number_of_pixels)
    print("\tPixel spacing:", new_spacing, "mm")

if utils.use_padding:
    # Enlarge the detector by the border added around the images. The width is
    # stored on the utils module (the bare name `pad_width` is undefined here).
    # NOTE(review): with downsample=True the images are padded before the two
    # pyrDown passes, so in low-resolution pixels the border is about
    # utils.pad_width / 4 — confirm the detector padding should match that.
    number_of_pixels = gvxr.getDetectorNumberOfPixels()
    gvxr.setDetectorNumberOfPixels(number_of_pixels[0] + 2 * utils.pad_width,
                                   number_of_pixels[1] + 2 * utils.pad_width)

In [None]:
# Load our samples
json2gvxr.initSamples(json_file, verbose=0)

gvxr.setDensity("cuboid", Ti90Al10, "g/cm3")

In [None]:
utils.default_up_vector    = gvxr.getDetectorUpVector();
utils.default_right_vector = gvxr.getDetectorRightVector();

In [None]:
gvxr.computeXRayImage();
gvxr.displayScene()
gvxr.setZoom(400)
gvxr.setSceneRotationMatrix([0.6925417184829712, 0.010556249879300594, -0.7213009595870972, 0.0, -0.7204560041427612, 0.060654886066913605, -0.6908417344093323, 0.0, 0.036457497626543045, 0.998101532459259, 0.04961010813713074, 0.0, 0.0, 0.0, 0.0, 1.0])

In [None]:
# Let's get an x-ray image
fig = plt.figure()
plt.imshow(gvxr.computeXRayImage(), cmap="gray")
plt.colorbar()
plt.show()

In [None]:
utils.figsize = (15, 15)

if downsample:
    utils.ref_image, utils.selected_angles, utils.indices = getReference(I_low_res_binary, angles_in_deg, 5)
else:
    utils.ref_image, utils.selected_angles, utils.indices = getReference(I_high_res_binary, angles_in_deg, 5)

In [None]:
displayResult(x_default, figsize=utils.figsize)

In [None]:
# Stage 1: optimise the source, detector and object positions with CMA-ES and
# the ZNCC fitness. The best solution and the fitness history are cached in the
# dataset folder, so re-running the cell reloads them instead of re-optimising.
opts = cma.CMAOptions()
opts.set('tolfun', 1e-4)
opts['tolx'] = 1e-4
# Normalised search space: one [-1, 1] bound per parameter of the initial guess.
opts['bounds'] = [len(x_default) * [-1], len(x_default) * [1]]

x_fname       = current_folder + "/" + data_path + "/x_best1.dat"
fitness_fname = current_folder + "/" + data_path + "/fitness_set1.npz"
utils.plot_directory = current_folder + "/" + data_path + "/plot1"

if not os.path.exists(x_fname) or not os.path.exists(fitness_fname):

    # Start from an empty plot folder.
    if os.path.isdir(utils.plot_directory):
        shutil.rmtree(utils.plot_directory)

    if not os.path.exists(utils.plot_directory):
        os.mkdir(utils.plot_directory)

    # Reset the state kept in utils (presumably updated by the fitness function).
    utils.best_fitness = sys.float_info.max
    utils.fitness_set = []
    utils.counter = 1

    start_time = time.time()
    es = cma.CMAEvolutionStrategy(x_default, 0.65, opts)
    es.optimize(fitnessZNCC)
    elapsed_time = time.time() - start_time
    print("Stage 1 optimisation time:", elapsed_time, "s")

    x_best1 = es.result.xbest
    np.savetxt(x_fname, x_best1)

    utils.fitness_set = np.array(utils.fitness_set)
    np.savez(fitness_fname, utils.fitness_set)

else:
    x_best1 = np.loadtxt(x_fname)
    # The NpzFile is a context manager; closing it releases the file handle.
    with np.load(fitness_fname) as data:
        utils.fitness_set = data[data.files[0]]

# rescaleX presumably maps the normalised solution back to physical values.
x_current = rescaleX(x_best1)
utils.x_best = x_current

In [None]:
plt.plot(utils.fitness_set[:,0], 100 * 1.0 / utils.fitness_set[:,1])
plt.xlabel("Number of fitness evaluations")
plt.ylabel("Fitness value (ZNCC in %)")

In [None]:
print(rescaleX(x_best1))

In [None]:
displayResult(x_best1, figsize=utils.figsize)

In [None]:
# applyTransformation(x_best1)

In [None]:
# Stage 2: optimise the object orientation (two angles searched in
# [-180, 180] degrees) with the SSIM fitness; the stage-1 positions are kept
# (presumably read from utils.x_best by the fitness function). Results are
# cached in the dataset folder, as in stage 1.
opts.set('tolfun', 1e-4)
opts['tolx'] = 1e-4

x_fname       = current_folder + "/" + data_path + "/x_best2.dat"
fitness_fname = current_folder + "/" + data_path + "/fitness_set2.npz"
utils.plot_directory = current_folder + "/" + data_path + "/plot2"

utils.data_range = [[-180, -180], [180, 180]]
n_orientation_params = len(utils.data_range[0])
# Normalised search space: one [-1, 1] bound per orientation parameter.
opts['bounds'] = [n_orientation_params * [-1], n_orientation_params * [1]]

if not os.path.exists(x_fname) or not os.path.exists(fitness_fname):

    # Start from an empty plot folder.
    if os.path.isdir(utils.plot_directory):
        shutil.rmtree(utils.plot_directory)

    if not os.path.exists(utils.plot_directory):
        os.mkdir(utils.plot_directory)

    utils.best_fitness = sys.float_info.max
    utils.fitness_set = []
    utils.counter = 1

    start_time = time.time()
    es = cma.CMAEvolutionStrategy(n_orientation_params * [0.0], 0.65, opts)
    es.optimize(fitnessSSIM)
    elapsed_time = time.time() - start_time
    print("Stage 2 optimisation time:", elapsed_time, "s")

    x_best2 = es.result.xbest
    np.savetxt(x_fname, x_best2)

    utils.fitness_set = np.array(utils.fitness_set)
    np.savez(fitness_fname, utils.fitness_set)

else:
    x_best2 = np.loadtxt(x_fname)
    # Close the .npz archive once its single array has been read.
    with np.load(fitness_fname) as data:
        utils.fitness_set = data[data.files[0]]

# Append the orientation to the stage-1 parameters. Only the first
# len(x_default) entries of x_current are kept, so re-running this cell no
# longer appends the angles a second time.
x_current = np.hstack((x_current[:len(x_default)], rescaleX(x_best2)))

In [None]:
plt.plot(utils.fitness_set[:,0], utils.fitness_set[:,1])
plt.xlabel("Number of fitness evaluations")
plt.ylabel("Fitness value (RMSE)")

In [None]:
print(x_current)

In [None]:
displayResult(x_best2, figsize=utils.figsize)

In [None]:
# Stage 3: refine all parameters together (stage-1 positions and stage-2
# angles) with the RMSE fitness on the full-resolution images. Results are
# cached in the dataset folder, as in the previous stages.
opts.set('tolfun', 1e-5)
opts['tolx'] = 1e-5

x_fname       = current_folder + "/" + data_path + "/x_best3.dat"
fitness_fname = current_folder + "/" + data_path + "/fitness_set3.npz"
utils.plot_directory = current_folder + "/" + data_path + "/plot3"

# Half-width of the search interval around the current estimate: 50 for the
# eight positions (source, detector, object on turntable), 90 for the two
# orientation angles.
half_ranges = 8 * [50] + 2 * [90]
assert len(x_current) == len(half_ranges), \
    "x_current should hold the 8 stage-1 positions followed by the 2 stage-2 angles"

utils.data_range = [[value - delta for value, delta in zip(x_current, half_ranges)],
                    [value + delta for value, delta in zip(x_current, half_ranges)]]
# Normalised search space: one [-1, 1] bound per parameter.
opts['bounds'] = [len(half_ranges) * [-1], len(half_ranges) * [1]]

if downsample:
    # Back to the full-resolution detector and binary reference images.
    # NOTE(review): if utils.use_padding is True this drops the detector padding.
    gvxr.setDetectorNumberOfPixels(old_number_of_pixels[0], old_number_of_pixels[1])
    gvxr.setDetectorPixelSize(old_spacing[0], old_spacing[1], "mm")
    utils.ref_image, utils.selected_angles, utils.indices = getReference(I_high_res_binary, angles_in_deg, 5)

if not os.path.exists(x_fname) or not os.path.exists(fitness_fname):

    # Start from an empty plot folder.
    if os.path.isdir(utils.plot_directory):
        shutil.rmtree(utils.plot_directory)

    if not os.path.exists(utils.plot_directory):
        os.mkdir(utils.plot_directory)

    utils.best_fitness = sys.float_info.max
    utils.fitness_set = []
    utils.counter = 1

    start_time = time.time()
    # Start at the centre of the search interval, i.e. the current estimate.
    es = cma.CMAEvolutionStrategy(len(half_ranges) * [0.0], 0.2, opts)
    es.optimize(fitnessRMSE)
    elapsed_time = time.time() - start_time
    print("Stage 3 optimisation time:", elapsed_time, "s")

    x_best3 = es.result.xbest
    np.savetxt(x_fname, x_best3)

    utils.fitness_set = np.array(utils.fitness_set)
    np.savez(fitness_fname, utils.fitness_set)

else:
    x_best3 = np.loadtxt(x_fname)
    # Close the .npz archive once its single array has been read.
    with np.load(fitness_fname) as data:
        utils.fitness_set = data[data.files[0]]

x_current = rescaleX(x_best3)

In [None]:
plt.plot(utils.fitness_set[:,0], utils.fitness_set[:,1])
plt.xlabel("Number of fitness evaluations")
plt.ylabel("Fitness value (RMSE)")

In [None]:
print(x_current)

In [None]:
utils.ref_image, utils.selected_angles, utils.indices = getReference(I_high_res_flat, angles_in_deg, 5)

utils.ref_image -= np.min(utils.ref_image)
utils.ref_image /= np.max(utils.ref_image)

In [None]:
displayResult(x_best3, figsize=utils.figsize)

In [None]:
def getCentreOfRotationPosition(x):
    """Return the centre of rotation as a 3-element array (mm).

    The centre of rotation is fixed at the origin, so `x` is not used; the
    parameter is kept so the signature matches the other geometry accessors.
    The statements that followed the return (centre of utils.bbox) could
    never run and have been removed; use the bounding-box centre here if
    that behaviour is wanted.
    """
    return np.array([0, 0, 0])

def getSourcePosition(x):
    """Return the X-ray source position, x[0], x[1], x[2], as a new NumPy array."""
    return np.array([x[index] for index in range(0, 3)])
    
def getDetectorPosition(x):
    """Return the detector position, x[3], x[4], x[5], as a new NumPy array."""
    return np.array([x[index] for index in range(3, 6)])
    
def getRotationAxisPosition(x):
    """Return the object position on the turntable as a 3-element array.

    In this notebook's parameter layout x[6] and x[7] are the object position
    on the turntable and x[8], x[9] are orientation angles, so the third
    component is 0, matching how printX and saveJSON build the object
    position. The previous body assigned locals but had no return (so it
    always returned None) and read the angle x[8] as a coordinate.
    """
    return np.array([x[6], x[7], 0])

def printX(x):
    """Print the geometry described by the parameter vector `x`.

    Layout: x[0:3] source position, x[3:6] detector position, x[6:8] object
    position on the turntable (z taken as 0), x[8:10] object orientation
    angles; a 12-element vector also carries two tilts in x[10] and x[11].
    """
    source = getSourcePosition(x)
    detector = getDetectorPosition(x)

    x_obj, y_obj, z_obj = x[6], x[7], 0
    alpha_x, alpha_y = x[8], x[9]

    print("Source position:", source, "mm")
    print("Detector position:", detector, "mm")
    # Distance between the source and detector encoded in `x` (not the nominal
    # x_src/x_det globals, which never change after initialisation).
    print("Source-Detector Distance (SDD):", float(np.linalg.norm(source - detector)), "mm")
    print("Object position:", x_obj, y_obj, z_obj, "mm")
    print("Object orientation:", alpha_x, alpha_y, "degrees")
    print("Centre of rotation position:", getCentreOfRotationPosition(x), "mm")

    if len(x) == 12:
        print("Tilt around", utils.default_up_vector, ":", x[10], "degrees")
        print("Tilt around", utils.default_right_vector, ":", x[11], "degrees")

printX(x_current)

In [None]:
#applyTransformation(x_best2)

In [None]:
def saveJSON(x, fname, image_path):
    """Write a gVXR JSON scan description for the geometry `x`.

    Parameters
    ----------
    x : sequence of numbers
        Parameter vector in this notebook's layout: x[0:3] source position,
        x[3:6] detector position, x[6:8] object position on the turntable
        (written with z = 0), x[8:10] rotation angles about the x and y axes.
    fname : str
        Path of the JSON file to create (overwritten if it exists).
    image_path : str
        Value stored as ["Scan"]["OutFolder"].

    The detector vectors, number of pixels and size, and the source spectrum
    are read from the current gVXR state; the samples come from
    json2gvxr.params. The z rotation uses the notebook-level `alpha_z`, and
    the final angle uses the notebook-level `angular_step_in_deg`.
    """
    # .tolist() / float() give plain Python numbers, which json can always
    # serialise (NumPy float32 or int64 scalars cannot be).
    source_position = getSourcePosition(x).tolist()
    detector_position = getDetectorPosition(x).tolist()
    detector_size = gvxr.getDetectorSize("mm")

    x_obj, y_obj, z_obj = float(x[6]), float(x[7]), 0
    alpha_x, alpha_y = float(x[8]), float(x[9])

    beam = [
        {"Energy": energy, "Unit": "keV", "PhotonCount": count}
        for energy, count in zip(gvxr.getEnergyBins("keV"), gvxr.getPhotonCountEnergyBins())
    ]

    samples = []
    for mesh in json2gvxr.params["Samples"]:
        sample = {
            "Label": mesh["Label"],
            # The JSON file is written inside the dataset folder, hence "../"
            # (presumably mesh paths are relative to the notebook folder).
            "Path": "../" + mesh["Path"],
            "Unit": mesh["Unit"],
            "Material": mesh["Material"],
        }
        if "Density" in mesh:
            sample["Density"] = mesh["Density"]
        sample["Transform"] = [
            ["Translation", x_obj, y_obj, z_obj, "mm"],
            ["Rotation", alpha_x, 1, 0, 0],
            ["Rotation", alpha_y, 0, 1, 0],
            ["Rotation", alpha_z, 0, 0, 1],
        ]
        samples.append(sample)

    # Use the centre returned by getCentreOfRotationPosition instead of a
    # separate hard-coded copy of it.
    rot_centre = np.asarray(getCentreOfRotationPosition(x)).tolist()

    number_of_projections = I_high_res_binary.shape[0]

    dictionary = {
        "WindowSize": [800, 600],
        "Detector": {
            "Position": detector_position + ["mm"],
            "UpVector": gvxr.getDetectorUpVector(),
            "RightVector": gvxr.getDetectorRightVector(),
            "NumberOfPixels": gvxr.getDetectorNumberOfPixels(),
            "Size": [detector_size[0], detector_size[1], "mm"],
        },
        "Source": {
            "Position": source_position + ["mm"],
            "Shape": "PointSource",
            "Beam": beam,
        },
        "Samples": samples,
        "Scan": {
            "CenterOfRotation": rot_centre + ["mm"],
            # One projection every angular_step_in_deg degrees, last one included.
            "FinalAngle": (number_of_projections - 1) * angular_step_in_deg,
            "IncludeFinalAngle": True,
            "NumberOfProjections": number_of_projections,
            "GifPath": "preview.gif",
            "OutFolder": image_path,
            "Flat-Field Correction": True,
        },
    }

    # Convert and write JSON object to file
    with open(fname, "w") as outfile:
        json.dump(dictionary, outfile, indent = 4)


visible_light_CT_json_file = current_folder + "/" + data_path + "/visible_light.json"
simulated_CT_json_file = current_folder + "/" + data_path + "/simulation.json"

saveJSON(x_current, visible_light_CT_json_file, "flat_images/")
saveJSON(x_current, simulated_CT_json_file,     "simulation/")