# Tutorial Plots - Preprint

In [13]:
# Make sure that the package can be imported
import sys
from os.path import expanduser, join, split
home = expanduser("~")
repo_path = join(home, 'Repositories/QuREBB')
# Only append once, so re-running this cell does not keep growing sys.path.
if repo_path not in sys.path:
    sys.path.append(repo_path)
from importlib import reload  # Python 3.4+

import glob
import time
from copy import deepcopy
from datetime import datetime as dt  # used to parse the timestamp prefixes of saved files

import numpy as np
import qutip as qt
import scipy.constants as cst
import matplotlib.pyplot as plt
import xarray as xr

import h5py as h5

# Folder with the saved preprint simulation results. Note the trailing slash:
# later cells build glob patterns as ``save_path + "*<label>*"``.
save_path = join(home, 'Repositories/QuREBB/tutorial_simulations/notebooks/simulations_preprint/')

# Disable QuTiP's automatic tidy-up of Qobj data
# (NOTE(review): presumably to keep very small matrix elements — confirm).
qt.settings.auto_tidyup = False

# Named hex colours used throughout the figures (insertion order preserved).
clr_dict = dict(
    dblue='#2596be',
    blue='#0810b9',
    lblue='#8488dc',
    orange='#FF7A48',
    dorange='#E3371E',
    red='#972B17',
    dred='#651C0F',
    dprl='#952b98',
    lprl='#b67dd7',
    vlprl='#cea9e5',
)

In [None]:
# Colour cycles derived from ``clr_dict`` so every hex value is defined in one place
# (the resulting lists are identical to the previously hard-coded literals).
default_colors_light = [clr_dict[key] for key in ('blue', 'orange', 'red', 'lprl')]
default_colors_dark = [clr_dict[key] for key in ('dblue', 'dorange', 'dred', 'dprl')]

default_shades_blue = [clr_dict[key] for key in ('lblue', 'blue', 'dblue')]
default_shades_purple = [clr_dict[key] for key in ('vlprl', 'lprl', 'dprl')]

## Helper Functions

In [None]:
import matplotlib.ticker as mticker
class AdditionalTickLocator(mticker.Locator):
    """Locator that wraps another locator and injects extra tick positions.

    All positioning work is delegated to ``chain``; whatever locations it
    produces are merged with ``ticks`` and returned sorted and de-duplicated
    (via ``np.unique``).

    Parameters
    ----------
    chain : matplotlib.ticker.Locator
        The locator to wrap (in ``axis_add_custom_ticks`` this is the axis'
        current major locator).
    ticks : iterable
        Extra tick positions that are always added.
    """

    def __init__(self, chain: mticker.Locator, ticks) -> None:
        super().__init__()
        assert chain is not None
        self._chain = chain
        self._additional_ticks = np.asarray([*ticks])

    def _add_locs(self, locs):
        # Merge the wrapped locator's output with the custom positions.
        combined = np.concatenate([np.asarray(locs), self._additional_ticks])
        return np.unique(combined)

    def tick_values(self, vmin, vmax):
        return self._add_locs(self._chain.tick_values(vmin, vmax))

    def __call__(self):
        # The wrapped locator computes its own ticks for the current view
        # (internally via its own tick_values); the custom ones are added here.
        return self._add_locs(self._chain())

    # The remaining Locator hooks are forwarded unchanged to the wrapped locator.
    def nonsingular(self, v0, v1):
        return self._chain.nonsingular(v0, v1)

    def set_params(self, **kwargs):
        return self._chain.set_params(**kwargs)

    def view_limits(self, vmin, vmax):
        return self._chain.view_limits(vmin, vmax)


class AdditionalTickFormatter(mticker.Formatter):
    """Formatter that wraps another formatter and overrides labels of chosen ticks.

    A position that is a key of ``ticks`` (exact equality) is labelled with the
    mapped value; every other position, and all other Formatter hooks, are
    handled by ``chain``.

    Parameters
    ----------
    chain : matplotlib.ticker.Formatter
        The formatter to wrap (in ``axis_add_custom_ticks`` this is the axis'
        current major formatter).
    ticks : mapping
        Tick position -> label. Stored by reference, not copied.
    """

    def __init__(self, chain: mticker.Formatter, ticks) -> None:
        super().__init__()
        assert chain is not None
        self._chain = chain
        self._additional_ticks = ticks

    def __call__(self, x, pos=None):
        if x not in self._additional_ticks:
            return self._chain.__call__(x, pos)
        return self._additional_ticks[x]

    def format_data_short(self, value):
        if value not in self._additional_ticks:
            return self._chain.format_data_short(value)
        return self(value)

    # Pure pass-through hooks so the wrapped formatter keeps its own state.
    def get_offset(self):
        return self._chain.get_offset()

    def _set_locator(self, locator):
        self._chain._set_locator(locator)

    def set_locs(self, locs):
        self._chain.set_locs(locs)
        
def axis_add_custom_ticks(axis, ticks):
    """Add labelled custom ticks to ``axis`` on top of its existing major ticks.

    Parameters
    ----------
    axis : matplotlib axis (e.g. ``ax.xaxis``)
        Axis whose major locator/formatter are wrapped in place.
    ticks : dict
        Tick position -> label; the keys become extra tick positions and the
        values are used as their labels.
    """
    # Capture both current objects before replacing either of them.
    old_locator, old_formatter = axis.get_major_locator(), axis.get_major_formatter()
    axis.set_major_locator(AdditionalTickLocator(old_locator, ticks.keys()))
    axis.set_major_formatter(AdditionalTickFormatter(old_formatter, ticks))

def find_latest_dataset(label, save_folder):
    """Return the path of the most recent dataset in ``save_folder`` matching ``label``.

    Dataset file names are expected to start with a 15-character
    ``YYYYmmdd-HHMMSS`` timestamp; that prefix decides which file is latest.
    Matching files without a parseable timestamp prefix are ignored.

    Parameters
    ----------
    label : str
        Substring that must appear in the file name (e.g. ``'ProtocolA'``).
    save_folder : str
        Directory to search.

    Returns
    -------
    str
        Path of the matching file with the latest timestamp prefix.

    Raises
    ------
    FileNotFoundError
        If no matching file carries a parseable timestamp prefix.
    """
    # Local import: the notebook otherwise imports ``datetime`` only in a later cell.
    from datetime import datetime

    # Escape so brackets etc. in the folder or label are matched literally.
    pattern = join(glob.escape(save_folder), f"*{glob.escape(label)}*")
    stamped = []
    for path in glob.glob(pattern):
        try:
            stamp = datetime.strptime(split(path)[-1][:15], '%Y%m%d-%H%M%S')
        except ValueError:
            # Not a timestamped dataset file; skip it rather than crash.
            continue
        stamped.append((stamp, path))

    if not stamped:
        raise FileNotFoundError(
            f"No timestamped dataset matching '{label}' found in {save_folder}"
        )
    return max(stamped)[1]

## Basic Parameters

In [44]:
# All saved result files for this protocol. ``glob`` returns files in arbitrary
# (filesystem-dependent) order; sorting by the timestamp-prefixed names makes the
# ``f[0]`` / ``f[1]`` indexing in the following cells reproducible.
label='ProtocolA'
f = sorted(glob.glob( save_path + f"*{label}*"))

In [45]:
# First 15 characters of the first two file names: their YYYYmmdd-HHMMSS timestamp prefixes.
t0, t1 = [split(path)[-1][:15] for path in (f[0], f[1])]

In [46]:
from datetime import datetime as dt

In [56]:
# NOTE(review): stray scratch calculation; its result is not used by any other visible cell — consider removing.
np.sqrt(0.9)

0.9486832980505138

In [54]:
# Convert both timestamp prefixes to datetime objects, then display them in
# chronological order.
tt0, tt1 = (dt.strptime(stamp, '%Y%m%d-%H%M%S') for stamp in (t0, t1))

np.sort([tt1,tt0])

array([datetime.datetime(2023, 10, 8, 0, 23, 1),
       datetime.datetime(2023, 10, 8, 11, 25, 12)], dtype=object)