In [4]:
import copy
import glob
import os
import sys

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from jax import random
from jax import tree_util as jtu

sys.path.append('../refactored_scripts/')

from model_building import get_theta_shape, build_xyz_basis, build_lindblad_operators, prepare_initial_state
from diagnostics import print_training_info, print_hamiltonian_parameters, print_noise_parameters, print_relative_error, generate_diagnostic_trajectories
from mlp import init_mlp_params, mlp_forward, make_step_fn, train_phase
from figures import plot_noise_parameters, plot_hamiltonian_parameters, plot_mixed_state_fidelity, plot_purity, plot_pure_state_fidelity, plot_observables, plot_training_loss

# Shorthand alias for the JAX array type (presumably for type annotations;
# no use of it is visible in this part of the notebook).
Array = jnp.ndarray

In [21]:
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return its contents.

    The file is read with ``yaml.safe_load``, so the result is whatever
    plain Python object the document describes.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

def load_experimental_data(config):
    """Load the measurement-count table for the system size in ``config``.

    Searches the current working directory for files named
    ``experimental_data_quantum_sampling_L{N}_*_counts.csv`` (with
    ``N = config["L"]``) and reads the alphabetically first match, so the
    same file is picked on every run when several data sets share one ``L``.

    Args:
        config: Mapping that provides the key ``"L"``.

    Returns:
        A tuple ``(bitstrings, counts_shots)`` where ``bitstrings`` is the
        CSV's ``bitstring`` index column cast to ``np.float32`` and
        ``counts_shots`` is the remaining columns cast to ``np.int32``.

    Raises:
        FileNotFoundError: If no file matches the search pattern.
    """
    prefix = "experimental_data_quantum_sampling_"
    suffix = ".csv"
    N = config["L"]

    # glob.glob returns matches in arbitrary order; sorting makes the
    # selected file deterministic across runs and machines.
    files = sorted(glob.glob(f"{prefix}L{N}_*_counts{suffix}"))

    if not files:
        raise FileNotFoundError(f"No data found for L={N}")

    data_path = files[0]
    # Short label for the banner below: the file name without the fixed
    # prefix and extension (the glob pattern guarantees both are present).
    file_core = data_path[len(prefix):-len(suffix)]

    print(f"\n{'='*60}")
    print(f"LOADING DATA: {file_core}")
    print(f"{'='*60}")

    # Read the matched path directly instead of rebuilding it from the label.
    df_counts = pd.read_csv(data_path, index_col='bitstring')

    # NOTE(review): an all-digit index is parsed as numbers, so labels lose
    # their leading zeros (e.g. "0010" becomes 10), and float32 cannot
    # represent every such value exactly above 2**24. The dtype is kept for
    # compatibility with existing callers -- confirm whether they really need
    # numeric labels before relying on this for larger L.
    bitstrings = df_counts.index.values.astype(np.float32)
    counts_shots = df_counts.values.astype(np.int32)

    return bitstrings, counts_shots

In [22]:
# Location of the Lindbladian-learning YAML configuration.
# The absolute path below only exists on one machine; it is kept as the
# default, and can be overridden without editing the notebook by setting the
# LINDBLAD_CONFIG_PATH environment variable before starting the kernel.
config_file = os.environ.get(
    "LINDBLAD_CONFIG_PATH",
    "/Users/omichel/Desktop/qilimanjaro/projects/retech/retech_2025/config_files/lindbladian_learning_configuration.yaml",
)
# Load configuration
CONFIG = load_config(config_file)
# TODO(review): this cell was annotated "Print useful information", but no
# print/diagnostic call follows -- add it or drop the note.

In [24]:
# Load the bitstring labels and the count table from the CSV that
# load_experimental_data finds for the system size CONFIG["L"].
bitstrings, counts_shots = load_experimental_data(CONFIG)


LOADING DATA: L4_Chi_4_R10000_counts


In [25]:
# Inspect the loaded labels. They are float32 numbers (load_experimental_data
# casts the CSV index with astype(np.float32)), not strings.
print(bitstrings)

[0.000e+00 1.000e+00 1.000e+01 1.100e+01 1.000e+02 1.010e+02 1.100e+02
 1.110e+02 1.000e+03 1.001e+03 1.010e+03 1.011e+03 1.100e+03 1.101e+03
 1.110e+03 1.111e+03]
