# In [7]:
import matplotlib.pyplot as plt
import argparse
import vtk
import numpy as np
from pathlib import Path
import shutil

from svinterface.core.polydata import Centerlines
from svinterface.core.zerod.solver import SolverResults
from svinterface.manager.baseManager import Manager
from svinterface.utils.io import write_json

from svinterface.plotting.params import set_params


def _valid_point_ids(c_3d: Centerlines) -> np.ndarray:
    """Return the sorted, de-duplicated point ids at which results are compared.

    The set is point 0 plus every point whose ``Caps_0D``, ``Junctions_0D`` or
    ``Vessels_0D`` point-data value is not -1.
    """
    # TODO: report which cap/junction/vessel each point actually belongs to.
    tagged = [
        np.nonzero(np.asarray(c_3d.get_pointdata_array(arr_name)) != -1)[0]
        for arr_name in ("Caps_0D", "Junctions_0D", "Vessels_0D")
    ]
    # np.unique both sorts and de-duplicates (a point may be tagged in several arrays).
    return np.unique(np.concatenate([[0], *tagged]))


def _pressure_series(centerlines: Centerlines, point_ids) -> tuple:
    """Sample every ``pressure_<time>`` point-data array at the given point ids.

    Returns:
        ``(times, pressures)``: ``times`` is a 1D ascending array and ``pressures``
        has shape ``(len(point_ids), len(times))``.

    Raises:
        ValueError: if ``centerlines`` has no ``pressure_<time>`` arrays.
    """
    point_data = centerlines.polydata.GetPointData()
    # Resolve each array once and sort by time, so the trapezoidal mean does not
    # depend on the order in which the arrays happen to be stored.
    steps = sorted(
        (
            (float(arr_name.split('_')[1]), point_data.GetArray(arr_name))
            for arr_name in centerlines.get_pointdata_arraynames()
            if arr_name.startswith("pressure_")
        ),
        key=lambda step: step[0],
    )
    if not steps:
        raise ValueError("no 'pressure_<time>' point-data arrays found")
    times = np.array([t for t, _ in steps], dtype=float)
    pressures = np.array(
        [[arr.GetValue(int(point_id)) for _, arr in steps] for point_id in point_ids],
        dtype=float,
    )
    return times, pressures


def _pressure_summary(times: np.ndarray, pressures: np.ndarray) -> tuple:
    """Return per-point ``(systolic, mean, diastolic)`` arrays.

    Systolic/diastolic are the max/min over time. Mean is the trapezoidal
    time average over ``[times[0], times[-1]]``.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    duration = times[-1] - times[0]
    if duration > 0:
        mean = trapezoid(pressures, times, axis=1) / duration
    else:
        # A single time step (zero span) has no interval to integrate over.
        mean = pressures.mean(axis=1)
    return pressures.max(axis=1), mean, pressures.min(axis=1)


def plot_valid(c_3d: Centerlines, c_1d_list: list, save_dir=None):
    """Compare systolic/mean/diastolic pressures of a 3D solution with other results.

    Args:
        c_3d: centerlines holding the 3D solution as ``pressure_<time>`` point-data
            arrays, plus the ``Caps_0D``/``Junctions_0D``/``Vessels_0D`` arrays that
            select the comparison points (see ``_valid_point_ids``).
        c_1d_list: iterable of ``(centerlines, label)`` pairs whose ``pressure_<time>``
            arrays are compared against ``c_3d``. A single centerlines object is also
            accepted and labelled ``"0D"``. NOTE(review): the point ids chosen on
            ``c_3d`` are reused to index these centerlines, so they are assumed to
            share ``c_3d``'s point ordering -- confirm.
        save_dir: optional directory. When given, both figures are saved there as PNG.

    Returns:
        ``(fig1, fig2)``. fig1 scatters each statistic against the point index.
        fig2 scatters the compared values (y) against the 3D values (x), with an
        identity line spanning the 3D range.
    """
    # Backward-compatible convenience: allow passing one centerlines object directly.
    if hasattr(c_1d_list, "get_pointdata_arraynames"):
        c_1d_list = [(c_1d_list, "0D")]

    valid = _valid_point_ids(c_3d)
    point_index = np.arange(len(valid))
    titles = ("Systolic", "Mean", "Diastolic")

    fig1, ax1 = plt.subplots(1, 3, figsize=(8, 3))
    fig2, ax2 = plt.subplots(1, 3, figsize=(8, 3))
    fig1.suptitle("Summary Comparison of 0D and 3D Values at Relevant Points")
    fig2.suptitle("3D vs 0D Pressures at Relevant Points.")

    stats_3d = _pressure_summary(*_pressure_series(c_3d, valid))
    for ax, title, values in zip(ax1, titles, stats_3d):
        ax.scatter(point_index, values, label='3D')
        ax.set_title(title)
        ax.set_ylabel("Pressure (mmHg)")
        ax.set_xlabel("Points")
    for ax, title, values in zip(ax2, titles, stats_3d):
        lo, hi = values.min(), values.max()
        ax.plot([lo, hi], [lo, hi])  # identity line: perfect agreement with 3D
        ax.set_title(title)
        ax.set_ylabel("0D Pressure (mmHg)")
        ax.set_xlabel("3D Pressure (mmHg)")

    compared = False
    for c_1d, name in c_1d_list:
        compared = True
        stats_1d = _pressure_summary(*_pressure_series(c_1d, valid))
        for summary_ax, versus_ax, values_3d, values_1d in zip(ax1, ax2, stats_3d, stats_1d):
            summary_ax.scatter(point_index, values_1d, label=name)
            # 3D on the x axis, the compared result on the y axis.
            versus_ax.scatter(values_3d, values_1d, label=name)

    # Legends are created last: a legend only captures the artists that exist when
    # it is built, so building it earlier would omit the compared series.
    for ax in ax1:
        ax.legend()
    if compared:
        for ax in ax2:
            ax.legend()

    if save_dir is not None:
        save_dir = Path(save_dir)
        fig1.savefig(str(save_dir / "summary_comparison.png"))
        fig2.savefig(str(save_dir / "3d_vs_0d.png"))

    return fig1, fig2

# In [None]:
# Centerlines carrying the 3D solution as "pressure_<time>" point-data arrays.
c_3d = Centerlines.load_polydata('../../data/diseased/AS1_SU0308_stent/results/AS1_SU0308_nonlinear/3D_DIR/prestent/AS1_SU0308_3D_centerlines.formatted.vtp')

# 0D simulation directory. The comparison output folder is created beneath it.
# NOTE(review): `zero_d_sim` was used below but never defined; it is assumed to
# be this 0D simulation directory -- confirm.
zero_d_sim_dir = '../../data/diseased/AS1_SU0308_stent/results/AS1_SU0308_nonlinear/LPN_DIR/AS1_SU0308.sim.1/'
zero_d_sim = Path(zero_d_sim_dir)
# NOTE(review): confirm `Centerlines.load_centerlines` accepts this directory path.
c_0d = Centerlines.load_centerlines(zero_d_sim_dir)

comp_folder = zero_d_sim / "3D_vs_0D"
# Recreate the output folder from scratch so files from previous runs are removed.
if comp_folder.exists():
    shutil.rmtree(str(comp_folder))
comp_folder.mkdir(exist_ok=True, parents=True)

# plotting params
set_params()

# plot_valid iterates over (centerlines, label) pairs and returns both figures.
fig1, fig2 = plot_valid(c_3d, [(c_0d, "0D")])
fig1.savefig(str(comp_folder / "summary_comparison.png"))
fig2.savefig(str(comp_folder / "3d_vs_0d.png"))