In [None]:
# Notebook parameters. The `{{ ... }}` markers are template placeholders that
# have to be substituted with real values before this cell can run.
# job_type selects which trajectory dataset is read; 'sort' and 'shim' are the
# values handled further down.
job_type = '{{ job_type }}'
# Indexed and iterated over as a sequence of HDF5 file paths by the plotting cell.
h5_filepaths = {{ genome_h5_filepaths }}

In [None]:
import os

# Choose the HDF5 dataset that holds the trajectory for this kind of job.
if job_type == 'sort':
    dataset_name = 'id_trajectory'
elif job_type == 'shim':
    dataset_name = 'id_trajectory_shimmed'
else:
    # Fail here with a clear message instead of leaving dataset_name undefined,
    # which would only surface later as a NameError at the first HDF5 read.
    raise ValueError(
        "Unsupported job_type {!r}: expected 'sort' or 'shim'".format(job_type)
    )

In [None]:
import math

import h5py
import numpy as np
import matplotlib.pyplot as plt

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names. Prefer the new name when this
# matplotlib provides it; otherwise fall back to the legacy name.
_WHITEGRID_STYLE = 'seaborn-v0_8-whitegrid'
if _WHITEGRID_STYLE not in plt.style.available:
    _WHITEGRID_STYLE = 'seaborn-whitegrid'
plt.style.use(_WHITEGRID_STYLE)

# Trajectory plots

In [None]:
# Get the x, z indices for plotting. Only the dataset's shape is needed, and
# h5py exposes it without reading the array into memory. All h5 files are
# assumed to hold trajectory data of the same dimensions, so the first file
# is used as the reference.
with h5py.File(h5_filepaths[0], 'r') as h5_file:
    # Indexing the file (rather than .get()) raises a KeyError naming the
    # dataset if it is missing, instead of failing later on None.
    x_axis_length, z_axis_length = h5_file[dataset_name].shape[:2]

x_axis_centre_index = x_axis_length // 2  # round down if the length is odd
z_axis_centre_index = z_axis_length // 2  # round down if the length is odd

# First, centre and last index along each of the two leading axes.
x_axis_indices = [
    0,
    x_axis_centre_index,
    x_axis_length - 1
]

z_axis_indices = [
    0,
    z_axis_centre_index,
    z_axis_length - 1
]

# Per-subplot settings, paired up in order: subplot title, vertical axis label
# and the index of the trajectory component drawn in that subplot.
HARDCODED_TITLES = ['X trajectory\n', 'Z trajectory\n']
HARDCODED_VERTICAL_AXES_LABELS = ['x', 'z']
HARDCODED_TRAJECTORY_INDICES = [0, 1]

# (x index, z index, figure title) for every figure; the centre position
# comes first so that it is plotted first.
xz_positions = [(x_axis_centre_index, z_axis_centre_index, 'Centre x, centre z')]

# Title fragments for the lower / centre / upper entry of each index list.
X_TITLE_PREFIXES = ('Lower x, ', 'Centre x, ', 'Upper x, ')
Z_TITLE_SUFFIXES = ('lower z', 'centre z', 'upper z')
CENTRE_SLOT = 1  # position of the centre entry within each index list

# Append every (x, z) combination in x-major order, skipping the
# centre/centre pair because it is already the first entry of xz_positions.
xz_positions.extend(
    (x_index, z_index, X_TITLE_PREFIXES[x_slot] + Z_TITLE_SUFFIXES[z_slot])
    for x_slot, x_index in enumerate(x_axis_indices)
    for z_slot, z_index in enumerate(z_axis_indices)
    if (x_slot, z_slot) != (CENTRE_SLOT, CENTRE_SLOT)
)

            
def _plot_trajectory_slice(axes, data_slice, label):
    """Draw the trajectory components of one (x, z) slice onto the subplots.

    axes: subplot axes, paired in order with HARDCODED_TRAJECTORY_INDICES.
    data_slice: array read for a single (x, z) position; its first axis runs
        along the horizontal plot axis and its transpose is indexed by
        trajectory component.
    label: legend entry for the plotted lines.
    """
    # NOTE(review): this yields N points spanning 0..N, so the spacing is
    # N/(N-1) rather than 1 - confirm that is the intended horizontal axis.
    s = np.linspace(0, len(data_slice), len(data_slice))
    components = data_slice.transpose()
    for subplot_ax, data_index in zip(axes, HARDCODED_TRAJECTORY_INDICES):
        subplot_ax.plot(
            s,
            components[data_index],
            label=label,
            linewidth=4.0,
            alpha=0.7
        )


# One figure per (x, z) position, each with one subplot per trajectory component.
for x_index, z_index, fig_title in xz_positions:
    fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(25, 25))
    fig.suptitle(fig_title, fontsize=50)

    # Frameless axes spanning the whole figure, with its own ticks, tick labels
    # and grid switched off. tick_params expects booleans here: the string
    # 'off' is truthy and would leave the ticks enabled.
    overlay_ax = fig.add_subplot(111, frameon=False)
    overlay_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
    overlay_ax.grid(False)

    for subplot_ax, title, y_label in zip(ax, HARDCODED_TITLES, HARDCODED_VERTICAL_AXES_LABELS):
        plt.setp(subplot_ax.get_xticklabels(), fontsize=40)
        plt.setp(subplot_ax.get_yticklabels(), fontsize=40)
        subplot_ax.set_xlabel('s', fontsize=40)
        subplot_ax.set_ylabel(y_label, fontsize=40)
        subplot_ax.set_title(title, fontsize=40)
        subplot_ax.xaxis.grid(linewidth=2.0)
        subplot_ax.yaxis.grid(linewidth=2.0)

    for filepath in h5_filepaths:
        filename = os.path.basename(filepath)
        with h5py.File(filepath, 'r') as h5_file:
            # Read only the slice for this (x, z) position rather than the
            # whole dataset, which would otherwise be re-read for every figure.
            data_slice = h5_file[dataset_name][x_index, z_index]
        _plot_trajectory_slice(ax, data_slice, filename)

    if job_type == 'shim':
        # plot the original genome too: it's contained in every shimmed genome so just
        # arbitrarily pick the 0th element in the list of h5 files
        with h5py.File(h5_filepaths[0], 'r') as h5_file:
            data_slice = h5_file['id_trajectory_original'][x_index, z_index]
        _plot_trajectory_slice(ax, data_slice, 'original genome')

    # Build each legend once, after every line has been added. `prop` fixes the
    # entry font size; a separate `fontsize` is ignored when `prop` is given.
    for subplot_ax in ax:
        subplot_ax.legend(
            loc='lower left',
            title='Genomes',
            title_fontsize=40,
            prop={'size': 40}
        )

    # Reserve space at the top for the figure title and on the right-hand side.
    fig.tight_layout(rect=[0, 0, 0.9, 0.95])

plt.show()