In [1]:
# %%
import hyperiax
import jax
import jax.numpy as jnp
import numpy as np
import os
import pandas as pd
import sys


from hyperiax.execution import DependencyTreeExecutor
from hyperiax.models import UpLambda
from hyperiax.models.functional import pass_up


from help_functions.DICAROS import *
from help_functions.image_manipulation import *
from help_functions.align_shapes import *



from matplotlib import pyplot as plt


# Load the data and format it into a tree for the reconstruction

In [2]:
# Load all files
corrected_dataset = "./Papilnodae_dataset/"

# Read the Newick string (first line of the tree file). The context manager
# closes the file handle instead of leaving it open for the garbage collector.
with open(corrected_dataset+"/tree.tree") as tree_file:
    fline = tree_file.readline().rstrip()
tree = hyperiax.tree.builders.tree_from_newick(fline)

# The root has no incoming branch, so its edge length is set to zero.
tree.root.data["edge_length"] = 0

# Mean landmark configuration per species: species names are normalised to use
# underscores, used as the index, and the `count` column is dropped.
# Built as one chain from the CSV so that re-running the cell is idempotent.
landmarks = (
    pd.read_csv(corrected_dataset+"/mean_shapes.csv")
    .assign(species=lambda frame: frame["species"].str.replace(" ", "_"))
    .set_index('species')
    .drop(columns='count')
)
metadata = pd.read_csv(corrected_dataset+"/metadata.csv")

In [3]:
# Rows and columns of the landmark table (taken before re-ordering)
n, m = jnp.shape(landmarks)

# Dimension of a single landmark point (x, y)
d = 2

# Rearrange the landmarks so that row order follows the leaf order of the tree
node_names = [tree_leaf.name for tree_leaf in tree.iter_leaves()]
landmarks = landmarks.loc[node_names]

# Turn every flat row into a (-1, d) array of points and align the shapes.
# NOTE(review): `align_shapes` comes from help_functions.align_shapes; the name
# `gpa_landmarks` suggests a Procrustes-style alignment - confirm.
gpa_landmarks = [flat_row.reshape(-1, d) for flat_row in landmarks.values]
aligned_shapes, final_mean_shape = align_shapes(gpa_landmarks, max_iterations=100)

# Store the flattened aligned shape on the matching leaf (same order as above)
for i, leaf in enumerate(tree.iter_leaves()):
    flat_shape = aligned_shapes[i].flatten()
    leaf.data["value"] = jnp.array(flat_shape)

tree.root.data["edge_length"] = 0


# Prepare the reconstruction

In [4]:

# Index of the device to run on. Previously this value was defined but never
# used (the last device was always taken); it is now honoured, with a fallback
# to the last visible device when the index is out of range (e.g. CPU-only).
gpu_selector = int(3)

available_devices = jax.devices()
if 0 <= gpu_selector < len(available_devices):
    selected_gpu = available_devices[gpu_selector]
else:
    selected_gpu = available_devices[-1]
print(f"Using device: {selected_gpu}")

# Compile the fuse function for the selected device.
# NOTE(review): the `device=` argument of jax.jit is deprecated in recent JAX
# releases - confirm against the JAX version this notebook is pinned to.
fuse_lddmm = jax.jit(fuse_lddmm, device=selected_gpu)

# Upward pass that reconstructs the inner nodes: children pass up their
# 'value', 'sigma' and 'edge_length', which are combined by `fuse_lddmm`.
up_momentum = pass_up('value','sigma','edge_length')
upmodel_momentum = UpLambda(up_momentum, fuse_lddmm)
root_exe_momentum = DependencyTreeExecutor(upmodel_momentum, batch_size=5)

# Same construction for the edge-length correction, compiled for the same device
fuse_edgelength = jax.jit(fuse_edgelength, device=selected_gpu)
up_correct_edge = pass_up('edge_length')
upmodel_edge = UpLambda(up_correct_edge, fuse_edgelength)
root_exe_edgelength = DependencyTreeExecutor(upmodel_edge, batch_size=5)

# Correct edge lengths to adjust for uncertainty
tree = root_exe_edgelength.up(tree)

# Use one kernel width for the whole tree: the mean of `find_sigma` over the
# leaf shapes, repeated once per landmark dimension and stored on every node.
sigma = np.mean([find_sigma(leaf_node_["value"]) for leaf_node_ in tree.iter_leaves()])
for node in tree.iter_bfs():
    node['sigma'] = jnp.array([sigma]*d)

Using device: cuda:3


In [5]:
# Do reconstruction: run the upward pass built around `fuse_lddmm` over the tree.
# NOTE: this rebinds `tree` to the result, so re-running this cell applies the
# pass to the already processed tree rather than to the freshly loaded one.
tree = root_exe_momentum.up(tree)


# Code for reconstruction on the images 

Here we lift the image along the branches


# First prepare the data 

- First, find the image.
- Match the image to the shape data of that specimen.
- Replace the mean shape at the tip with the specimen's shape data.
- Repeat the reconstruction and apply the diffeomorphic transformation to the image.

In [89]:
# Directory holding the example specimen photographs
image_path = "./Papilnodae_dataset/example_images"

# Per-specimen landmark table (one row per photographed individual)
individual_landmarks = pd.read_csv(corrected_dataset+"/speciment_shapes.csv")

# Animation parameters
n_steps = 100      # number of time steps passed to dts() for each branch
sigma_factor = 1   # multiplier applied to the node sigma when building the kernel
depth = 0          # index of the branch currently processed (0 = tip branch)
n_images = 50      # frames distributed over the path from tip to root


image_idx_list = [46,90,61,16,31,65,13,89] # Images used for the paper
# Pick one file name from the image directory.
# NOTE(review): os.listdir() returns entries in arbitrary order, so these
# indices only select the paper's images if the directory listing order matches
# the one they were chosen from - confirm.
image = os.listdir(image_path)[image_idx_list[3]]


In [90]:

# Extract gbifID (sixth underscore-separated field of the file name) and get
# the landmark rows of that specimen
local_gbifid_int = int(image.split("_", 6)[5])
leaf_landmarks = individual_landmarks[individual_landmarks["gbifID"] == local_gbifid_int]

# Stop with a readable message instead of an IndexError from `.iloc[0]` below
if leaf_landmarks.empty:
    print(f"No landmarks found for gbifID {local_gbifid_int}")
    sys.exit(1)

# Get species name, written with underscores like the tree's leaf names
leaf_species = leaf_landmarks.iloc[0]["species"]
leaf_species = leaf_species.replace(" ", "_") if leaf_species else None

# Get landmark coordinates (everything after the first three columns of the
# first matching row) and remove specific points.
# NOTE(review): the removed indices presumably correspond to landmarks that are
# absent from the mean shapes - confirm.
leaf_landmarks = leaf_landmarks.iloc[0, 3:].values[:]
leaf_landmarks = np.delete(leaf_landmarks, [181,180,59,58])

# Find matching species in tree
leaf_node = None
for leaf in tree.iter_leaves():
    if leaf.name == leaf_species:
        leaf_node = leaf
        break

if leaf_node is None:
    print(f"No node found for species {leaf_species}")
    sys.exit(1)

print(f"Found node for species {leaf_species}: {leaf_node.name}")

# Turn the image directory into the full path of the selected image. The guard
# keeps the cell idempotent: re-running it no longer appends the file name a
# second time to an already complete path.
if os.path.basename(image_path) != image:
    image_path = os.path.join(image_path, image)

Found node for species Euryades_corethrus: Euryades_corethrus


In [91]:

def _ensure_dir(path):
    """Create `path` (and missing parents) if needed and return it unchanged."""
    os.makedirs(path, exist_ok=True)
    return path


# Frames of the warped butterfly image: general folder, then one per species
outputdir_gen_butterfly = _ensure_dir("Output_folder/manipulated_images_butterfly/")
outputdir_butterfly = _ensure_dir(outputdir_gen_butterfly + leaf_species)

# Frames of the tree with the moving marker: general folder, then one per species
outputdir_gen_tree = _ensure_dir("Output_folder/manipulated_images_tree/")
outputdir_tree = _ensure_dir(outputdir_gen_tree + leaf_species)




# First we plot the tree

We plot the position marker on the tree first, because we need the original length of each branch, before the branch lengths are adjusted by the independent-contrast correction.

In [92]:
# Load a fresh copy of the tree for plotting, so that its edge lengths are the
# ones stored in the file. The context manager closes the file handle.
with open(os.path.join(corrected_dataset, "tree.tree")) as tree_file:
    fline = tree_file.readline().rstrip()
tree_plot = hyperiax.tree.builders.tree_from_newick(fline)
tree_plot.root.data["edge_length"] = 0

# The tip which we want to trail from - the name has to match the tip name
# (spaces instead of underscores, as used for the plot labels below)
add_trail = leaf_species.replace("_", " ")

# Adjusting names for plotting purposes: inner nodes get no label, tips get a
# two-space suffix and spaces instead of underscores
for leaf_plot in tree_plot.iter_bfs():
    if len(leaf_plot.children) != 0:
        leaf_plot.name = ""
    else:
        leaf_plot.name = leaf_plot.name+"  "
        leaf_plot.name = leaf_plot.name.replace('_', ' ')


In [93]:
# Locate the tip to start from; `leaf_plot` stays bound to it after the loop
print(f"Find leaf for: {add_trail}")
flag = False
for leaf_plot in tree_plot.iter_leaves_dfs():
    print(leaf_plot.name)
    # Tip labels carry a two-character suffix, which is cut off before comparing
    if leaf_plot.name[:-2] != add_trail:
        continue
    flag = True
    print("Match found")
    break
else:
    # The loop ran to the end without a break: no tip has the requested name
    print("No trail found")
    sys.exit()

figure_counter = 0

Find leaf for: Euryades corethrus
Papilio phestus  
Papilio ambrax  
Papilio polytes  
Papilio protenor  
Papilio memnon  
Papilio deiphobus  
Papilio antimachus  
Papilio gigon  
Papilio polyxenes  
Papilio zelicaon  
Papilio xuthus  
Papilio glaucus  
Papilio troilus  
Papilio slateri  
Papilio thoas  
Papilio cresphontes  
Papilio aristodemus  
Parides photinus  
Parides eurimedes  
Parides agavus  
Euryades corethrus  
Match found


In [94]:

from hyperiax.tree.plot_utils import *


def _new_tree_figure():
    """Open a figure showing `tree_plot` with names and hidden axes."""
    figure, axes = plt.subplots(figsize=(8, 6))
    plot_tree_(tree_plot, ax=axes, inc_names=True)
    axes.axis('off')
    return figure, axes


def _save_tree_frame(figure, depth_idx, time_idx):
    """Save the current figure as img_<depth>_<time>.png and close it."""
    plt.savefig(f'{outputdir_tree}/img_{depth_idx}_{time_idx}.png', bbox_inches='tight')
    plt.close(figure)


depth_counter = 0
time_points_saved = []

# Initial frame: marker on the selected tip. The node coordinates are read only
# after the tree has been drawn.
# NOTE(review): presumably plot_tree_ is what fills in x_temp / y_temp - confirm.
fig, axs = _new_tree_figure()
axs.plot(leaf_plot["x_temp"], leaf_plot["y_temp"], 'ro', markersize=10)
_save_tree_frame(fig, depth_counter, 0)

# Sum of the edge lengths on the path from the selected tip to the root
total_distance = 0
current_leaf = leaf_plot
while current_leaf.parent is not None:
    total_distance += current_leaf["edge_length"]
    current_leaf = current_leaf.parent

# Do the lift: walk towards the root, one branch per iteration
while leaf_plot.parent is not None:
    branch_distance = leaf_plot["edge_length"]

    # Frames for this branch, proportional to its share of the path length;
    # a branch that is too short still gets one frame
    branch_plots = max(1, int((branch_distance / total_distance) * n_images))

    # Integer time indices on the 0..99 scale; the last one (branch end) is
    # dropped, and the remaining ones are saved for the image animation
    time_points = np.floor(np.linspace(0, 99, branch_plots + 1)).astype(int)
    frame_steps = time_points[:-1]
    time_points_saved.append(frame_steps)

    # Marker positions at those time indices. y runs from the node to its
    # parent; x stays at the node's own x coordinate.
    # NOTE(review): both x endpoints are the node's x_temp, so the marker does
    # not move in x along a branch - confirm this is intended.
    y_positions = np.linspace(leaf_plot.data["y_temp"], leaf_plot.parent.data["y_temp"], 100)[frame_steps]
    x_positions = np.linspace(leaf_plot.data["x_temp"], leaf_plot.data["x_temp"], 100)[frame_steps]

    for marker_x, marker_y, time_p in zip(x_positions, y_positions, frame_steps):
        fig, axs = _new_tree_figure()
        axs.plot(marker_x, marker_y, 'ro', markersize=10)
        _save_tree_frame(fig, depth_counter, time_p)

    leaf_plot = leaf_plot.parent
    depth_counter += 1

# Make the root dot
fig, axs = _new_tree_figure()
axs.plot(leaf_plot.data["x_temp"], 0, 'ro', markersize=10)
_save_tree_frame(fig, depth_counter, 0)

# Functions to generate the images. 

In [95]:


def chunker(seq, size):
    """Yield successive slices of a flat coordinate sequence.

    `seq` is treated as consecutive pairs of values, so `size` counts pairs:
    every yielded slice holds up to `2 * size` entries. A trailing unpaired
    entry of an odd-length sequence is not covered by any start position.
    """
    pair_count = len(seq) // 2
    slice_starts = range(0, pair_count, size)
    # Pair positions are doubled to obtain indices into the flat sequence
    yield from (seq[2 * first:2 * (first + size)] for first in slice_starts)

# Select leaf node using provided 
def ode_Hamiltonian_advect(c,y):
    """Right-hand side for flowing arbitrary points forward.

    `c` unpacks to (t, x, chart), of which only the current points `x` are
    used. `y` is a one-element tuple holding `qp`, with `qp[0]` the landmark
    positions q and `qp[1]` the momenta p. The result is the kernel between `x`
    and q (from the notebook-level manifold `M`) contracted with p, reshaped to
    rows of length `M.m`.
    """
    _, x, _ = c
    (qp,) = y
    positions, momenta = qp[0], qp[1]

    kernel = M.K(x, positions)
    return jnp.tensordot(kernel, momenta, (1, 0)).reshape((-1, M.m))

def ode_Hamiltonian_advect_rev(c,y):
    """Right-hand side for flowing arbitrary points backwards.

    Same inputs as the forward version: `c` unpacks to (t, x, chart) and `y`
    is a one-element tuple holding `qp` (`qp[0]` positions q, `qp[1]` momenta
    p). The result is the negated kernel-momentum contraction, reshaped to rows
    of length `M.m`, where `M` is the notebook-level manifold.
    """
    _, x, _ = c
    (qp,) = y
    positions, momenta = qp[0], qp[1]

    kernel = M.K(x, positions)
    return -jnp.tensordot(kernel, momenta, (1, 0)).reshape((-1, M.m))


from jaxgeometry.manifolds.landmarks import landmarks   
from jaxgeometry.Riemannian import metric
from jaxgeometry.dynamics import Hamiltonian
from jaxgeometry.stochastics import Brownian_coords
from jaxgeometry.Riemannian import Log
from jaxgeometry.dynamics import flow_differential


In [None]:
# Reset the correct leaf for the reconstructed tree; fail loudly if the species
# is missing instead of silently continuing with a stale or wrong node
for leaf in tree.iter_leaves():
    if leaf.name == leaf_species:
        leaf_node = leaf
        break
else:
    raise ValueError(f"No leaf named {leaf_species!r} in the reconstructed tree")

n_images=50

padding=125
n_steps = 100
sigma_factor = 1
depth = 0

# Add padding and center the image
new_image, offset = add_padding_and_center(image_path, padding)

# Convert landmarks to JAX array and adjust for padding offset
# (even entries are shifted by offset[0], odd entries by offset[1])
iterative_start_points = jnp.array(leaf_landmarks, dtype=jnp.float32) # type: ignore
iterative_start_points = iterative_start_points.at[0::2].set(iterative_start_points[::2] + offset[0])
iterative_start_points = iterative_start_points.at[1::2].set(iterative_start_points[1::2] + offset[1])

# Store original landmarks
image_landmarks = iterative_start_points.copy()
# Get dimensions of padded image
height, width = new_image.shape[:2]

# Create grid of points covering the image (one point per pixel), flattened to
# an interleaved x0, y0, x1, y1, ... vector
x, y = np.meshgrid(np.linspace(0, width-1, width), np.linspace(0, height-1, height))
x = x.flatten()
y = y.flatten()
org_img_coords = np.vstack((x, y)).T.flatten()

# Set processing parameters
chunk_size = 1000 * 1000
total_points = len(org_img_coords) // 2

# Set time steps for evolution
_dts = dts(n_steps=n_steps)


n_landmarks = int(len(image_landmarks)/d)
iter_image = new_image.copy()


def plot_frame(qps, idx, parent_landmarks, image_landmarks, results, iter_image, output_dir, depth,transform_params):
    """Render the image warped to time index `idx` and save it as a PNG frame.

    Defined once here instead of being re-created for every frame inside the
    loop below. `grid_points_moved` is read from the notebook globals at call
    time, exactly as before; `qps` is accepted but not used.
    """
    plt.figure(figsize=(1.9, 1.9))
    plt.imshow(interpolate_image(parent_landmarks,image_landmarks, results[idx, :].flatten(),iter_image,grid_points_moved,transform_params))
    # Make a scatter, with black color stars
    plt.axis('off')
    # Save figure
    plt.tight_layout(pad=0)
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig(f'{output_dir}/img_{depth}_{idx}.png', dpi=660)
    plt.close()


# Walk from the selected leaf towards the root, one branch per iteration.
# Start explicitly from `leaf_node` rather than relying on the loop variable
# left over from the search above.
leaf = leaf_node
depth = 0
while leaf and leaf.parent:
    # Estimate contrast: build the landmark manifold with this node's kernel width
    sigma_k = leaf['sigma']*sigma_factor
    M = landmarks(n_landmarks,k_sigma=sigma_k*jnp.eye(2))
    metric.initialize(M)
    Brownian_coords.initialize(M)
    Hamiltonian.initialize(M)
    Log.initialize(M,f=M.Exp_Hamiltonian)
    # NOTE(review): second Hamiltonian.initialize call kept from the original;
    # it looks redundant - confirm before removing.
    Hamiltonian.initialize(M)
    # Integrate the backward advection ODE over the landmark trajectory taken
    # in reverse order (qps[::-1])
    M.Hamiltonian_advect_rev = lambda xs,qps,dts: integrate(ode_Hamiltonian_advect_rev,None,
                                                        xs[0].reshape((-1,M.m)),xs[1],dts,qps[::-1])

    parent_landmarks = leaf.parent.data["value"]

    # Align the image landmarks and the pixel grid to the parent shape
    child_landmarks,grid_points_moved,_,_,transform_params= align_A_to_B(image_landmarks,parent_landmarks, org_img_coords)

    # Momentum taking the aligned child landmarks to the parent landmarks, and
    # the landmark trajectory it generates
    q = M.coords(child_landmarks)
    v = (parent_landmarks, [0])
    p = M.Log(q, v)[0]

    (_, qps, _) = M.Hamiltonian_dynamics(q, p, _dts)

    # Apply on reversed coordinates: one row of grid positions per time step
    results = np.empty((n_steps, total_points, 2))

    # Process chunks
    for i, chunk in enumerate(chunker(grid_points_moved, chunk_size)):
        _, xs_chunk = M.Hamiltonian_advect_rev((chunk, M.chart()), qps, _dts)
        # Ensure the slice matches the chunk size
        results[:, i * chunk_size:(i + 1) * chunk_size, :] = xs_chunk

    # Loop through the frames selected for this branch when the tree was plotted
    for idx in time_points_saved[depth]:
        plot_frame(qps, idx, parent_landmarks, image_landmarks, results, iter_image, outputdir_butterfly, depth,transform_params)

    # Update the iterative image, to next level, along with landmarks
    iter_image = interpolate_image(parent_landmarks,image_landmarks, results[n_steps-1, :].flatten(),iter_image,grid_points_moved,transform_params)
    _,_,_,image_landmarks,_ = align_A_to_B(image_landmarks,parent_landmarks, qps[n_steps-1][0],transform_params)
    child_landmarks = qps[n_steps-1][0]

    # Next
    leaf = leaf.parent
    depth += 1

    # When we end at the root, make sure to plot it
    if leaf.parent is None:
        plt.figure(figsize=(1.9, 1.9))
        plt.imshow(iter_image)
        # Make a scatter, with black color stars
        plt.axis('off')
        # Save figure
        plt.tight_layout(pad=0)
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.savefig(f'{outputdir_butterfly}/img_{depth}_{0}.png', dpi=660)
        plt.close()
        
    



# Make the final gifs

In [None]:
#import os
#from PIL import Image
import re

def sort_png_files(png_files):
    """Return the file names ordered by the two integers in `img_<a>_<b>.png`.

    Names are compared by the pair (a, b) as integers. Names that do not start
    with that pattern are given the key (0, 0). The input is not modified.
    """
    frame_pattern = re.compile(r'img_(\d+)_(\d+)\.png')

    def extract_numbers(filename):
        found = frame_pattern.match(filename)
        if found is None:
            return 0, 0
        first, second = found.groups()
        return int(first), int(second)

    ordered = list(png_files)
    ordered.sort(key=extract_numbers)
    return ordered


def pngs_to_gif(directory, output_gif,rotation=0):
    """Combine the `img*.png` frames of a directory into one animated GIF.

    Parameters
    ----------
    directory : folder containing the frames, named `img_<depth>_<step>.png`.
    output_gif : path of the GIF file to write.
    rotation : angle passed to `Image.rotate` for every frame.

    Frames are ordered by `sort_png_files`. Prints a message and writes nothing
    when the directory holds no matching file.
    """
    # NOTE(review): `Image` (PIL) is not imported in this cell - the import
    # above is commented out - so it is presumably provided by one of the star
    # imports at the top of the notebook; confirm.

    # Names are sorted alphabetically first so that ties in the numeric key are
    # resolved deterministically. The former pre-sort on int(split('_')[1])
    # was redundant and raised on names such as "img.png".
    candidates = sorted(
        f for f in os.listdir(directory) if f.startswith("img") and f.endswith(".png")
    )
    png_files = sort_png_files(candidates)

    images = []
    for png_file in png_files:
        png_path = os.path.join(directory, png_file)
        # rotate() returns a new image, so the opened file can be closed at once
        with Image.open(png_path) as img:
            images.append(img.rotate(rotation))

    # Check if images list is not empty
    if images:
        first_image, *remaining_images = images
        first_image.save(output_gif, save_all=True, append_images=remaining_images, duration=100, loop=1)
    else:
        print("No images found in the provided directory.")

# Build both animations for the selected species: the tree with the moving
# marker (rotated by 270) and the warped butterfly image.
# NOTE(review): the second file name has no underscore before "butterfly",
# unlike the "_tree" name above - confirm whether that is intended.
tree_gif_path = f'./Output_folder/{leaf_species}_tree.gif'
butterfly_gif_path = f'./Output_folder/{leaf_species}butterfly.gif'

pngs_to_gif(outputdir_tree, tree_gif_path, rotation=270)
pngs_to_gif(outputdir_butterfly, butterfly_gif_path)