@@ -56,15 +56,15 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da
5656 FileNotFoundError: If observation.npz or posterior.nc are not found in output_dir.
5757 """
5858 output_path = Path (output_dir )
59-
59+
6060 # Load observation data
6161 obs_file = output_path / 'observation.npz'
6262 if not obs_file .exists ():
6363 raise FileNotFoundError (f"observation.npz not found in { output_dir } " )
6464 obs_data = np .load (obs_file )
6565 print (f"Observation data loaded from { obs_file } " )
6666 print (f"Available keys: { list (obs_data .keys ())} " )
67-
67+
6868 # Load posterior estimates
6969 posterior_file = output_path / 'posterior.nc'
7070 if not posterior_file .exists ():
@@ -74,7 +74,7 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da
7474 print (f"\n Posterior data loaded from { posterior_file } " )
7575 print (f"Dataset structure:" )
7676 print (posterior )
77-
77+
7878 return obs_data , idata , posterior
7979
8080
@@ -92,13 +92,13 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
9292 print ("\n " + "=" * 70 )
9393 print ("Plotting Observation" )
9494 print ("=" * 70 )
95-
95+
9696 image = obs_data .get ('image' , None )
9797 psf = obs_data .get ('psf' , None )
9898 noise_map = obs_data .get ('noise_map' , None )
99-
99+
100100 fig , axes = plt .subplots (1 , 3 , figsize = (15 , 5 ))
101-
101+
102102 # Plot the galaxy image
103103 if image is not None :
104104 im1 = axes [0 ].imshow (image , origin = 'lower' , cmap = 'viridis' )
@@ -109,15 +109,15 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
109109 axes [0 ].text (0.02 , 0.98 , f'Max: { image .max ():.2e} \n Min: { image .min ():.2e} ' ,
110110 transform = axes [0 ].transAxes , verticalalignment = 'top' ,
111111 bbox = dict (boxstyle = 'round' , facecolor = 'white' , alpha = 0.8 ))
112-
112+
113113 # Plot the PSF
114114 if psf is not None :
115115 im2 = axes [1 ].imshow (psf , origin = 'lower' , cmap = 'hot' )
116116 axes [1 ].set_title ('PSF Model' , fontsize = 14 , fontweight = 'bold' )
117117 axes [1 ].set_xlabel ('X pixel' )
118118 axes [1 ].set_ylabel ('Y pixel' )
119119 plt .colorbar (im2 , ax = axes [1 ], label = 'Normalized Flux' )
120-
120+
121121 # Plot the noise map
122122 if noise_map is not None :
123123 if noise_map .ndim == 0 : # Scalar noise
@@ -132,7 +132,7 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
132132 axes [2 ].set_xlabel ('X pixel' )
133133 axes [2 ].set_ylabel ('Y pixel' )
134134 plt .colorbar (im3 , ax = axes [2 ], label = 'Noise σ' )
135-
135+
136136 output_file = Path (output_dir ) / 'observation_visual.png'
137137 plt .tight_layout ()
138138 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
@@ -155,45 +155,45 @@ def plot_posterior_distributions(posterior: xr.Dataset, param_names: List[str],
155155 print ("\n " + "=" * 70 )
156156 print ("Plotting Posterior Distributions" )
157157 print ("=" * 70 )
158-
158+
159159 n_params = len (param_names )
160160 n_cols = min (3 , n_params )
161161 n_rows = int (np .ceil (n_params / n_cols ))
162-
162+
163163 fig , axes = plt .subplots (n_rows , n_cols , figsize = (5 * n_cols , 4 * n_rows ))
164164 if n_params == 1 :
165165 axes = np .array ([axes ])
166166 axes = axes .flatten ()
167-
167+
168168 for idx , param in enumerate (param_names ):
169169 ax = axes [idx ]
170170 samples = posterior [param ].values
171-
171+
172172 if samples .ndim > 1 :
173173 samples = samples .flatten ()
174-
174+
175175 ax .hist (samples , bins = 50 , density = True , alpha = 0.7 ,
176176 color = 'steelblue' , edgecolor = 'black' )
177-
177+
178178 mean_val = np .mean (samples )
179179 median_val = np .median (samples )
180180 std_val = np .std (samples )
181-
181+
182182 ax .axvline (mean_val , color = 'red' , linestyle = '--' , linewidth = 2 ,
183183 label = f'Mean: { mean_val :.4f} ' )
184184 ax .axvline (median_val , color = 'green' , linestyle = ':' , linewidth = 2 ,
185185 label = f'Median: { median_val :.4f} ' )
186-
186+
187187 ax .set_xlabel (f'{ param } ' , fontsize = 12 )
188188 ax .set_ylabel ('Density' , fontsize = 12 )
189189 ax .set_title (f'{ param } Posterior\n σ = { std_val :.4f} ' , fontsize = 12 , fontweight = 'bold' )
190190 ax .legend (fontsize = 9 )
191191 ax .grid (True , alpha = 0.3 )
192-
192+
193193 # Hide empty subplots
194194 for idx in range (n_params , len (axes )):
195195 axes [idx ].axis ('off' )
196-
196+
197197 output_file = Path (output_dir ) / 'posterior_distributions.png'
198198 plt .tight_layout ()
199199 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
@@ -232,10 +232,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
232232 "The 'corner' package is required for corner plots. "
233233 "Install it with: pip install corner"
234234 )
235-
235+
236236 # Prepare data: stack all parameters as columns
237237 samples_array = np .column_stack ([posterior [param ].values .flatten () for param in param_names ])
238-
238+
239239 # Create corner plot with confidence intervals
240240 fig = corner .corner (
241241 samples_array ,
@@ -252,10 +252,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
252252 truth_color = 'red' ,
253253 title_kwargs = {"fontsize" : 11 },
254254 )
255-
255+
256256 plt .suptitle ('Corner Plot: Joint & Marginal Distributions' ,
257257 fontsize = 14 , fontweight = 'bold' , y = 0.995 )
258-
258+
259259 output_file = Path (output_dir ) / 'corner_plot.png'
260260 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
261261 plt .close ()
@@ -278,23 +278,23 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
278278 pip install corner
279279 """
280280 shear_params = [p for p in param_names if 'g1' in p .lower () or 'g2' in p .lower () or 'shear' in p .lower ()]
281-
281+
282282 if not shear_params :
283283 print ("\n No shear parameters found. Skipping shear analysis." )
284284 return
285-
285+
286286 print ("\n " + "=" * 70 )
287287 print ("Plotting Shear Analysis" )
288288 print ("=" * 70 )
289289 print (f"Found shear parameters: { shear_params } " )
290-
290+
291291 g1_param = next ((p for p in param_names if 'g1' in p .lower ()), None )
292292 g2_param = next ((p for p in param_names if 'g2' in p .lower ()), None )
293-
293+
294294 if not (g1_param and g2_param ):
295295 print ("Could not identify both g1 and g2 parameters. Skipping." )
296296 return
297-
297+
298298 # Import corner package
299299 try :
300300 import corner
@@ -303,19 +303,19 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
303303 "The 'corner' package is required for shear analysis plots. "
304304 "Install it with: pip install corner"
305305 )
306-
306+
307307 g1_samples = posterior [g1_param ].values .flatten ()
308308 g2_samples = posterior [g2_param ].values .flatten ()
309-
309+
310310 # Prepare data: stack g1 and g2 as columns
311311 samples_array = np .column_stack ([g1_samples , g2_samples ])
312-
312+
313313 # Calculate statistics
314314 g1_mean = np .mean (g1_samples )
315315 g1_std = np .std (g1_samples )
316316 g2_mean = np .mean (g2_samples )
317317 g2_std = np .std (g2_samples )
318-
318+
319319 # Create corner plot for shear parameters only
320320 fig = corner .corner (
321321 samples_array ,
@@ -332,14 +332,14 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
332332 truth_color = 'red' ,
333333 title_kwargs = {"fontsize" : 11 },
334334 )
335-
336- plt .suptitle ('Shear Parameters Corner Plot (g1, g2)' ,
335+
336+ plt .suptitle ('Shear Parameters Corner Plot (g1, g2)' ,
337337 fontsize = 14 , fontweight = 'bold' , y = 0.995 )
338-
338+
339339 output_file = Path (output_dir ) / 'shear_analysis.png'
340340 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
341341 plt .close ()
342-
342+
343343 print (f"\n Shear estimates:" )
344344 print (f" g1 = { g1_mean :.6f} ± { g1_std :.6f} " )
345345 print (f" g2 = { g2_mean :.6f} ± { g2_std :.6f} " )
@@ -361,17 +361,17 @@ def print_summary_statistics(posterior: xr.Dataset, param_names: List[str]) -> N
361361 print ("=" * 70 )
362362 print (f"{ 'Parameter' :<20} { 'Mean' :<12} { 'Std' :<12} { 'Median' :<12} { '95% CI' :<20} " )
363363 print ("-" * 70 )
364-
364+
365365 for param in param_names :
366366 samples = posterior [param ].values .flatten ()
367367 mean_val = np .mean (samples )
368368 std_val = np .std (samples )
369369 median_val = np .median (samples )
370370 ci_low = np .percentile (samples , 2.5 )
371371 ci_high = np .percentile (samples , 97.5 )
372-
372+
373373 print (f"{ param :<20} { mean_val :<12.6f} { std_val :<12.6f} { median_val :<12.6f} [{ ci_low :.6f} , { ci_high :.6f} ]" )
374-
374+
375375 print ("=" * 70 )
376376
377377
@@ -389,39 +389,39 @@ def plot_trace(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
389389 output_dir (Path): Directory where the plot will be saved.
390390 """
391391 has_chains = any (dim in posterior .dims for dim in ['chain' , 'draw' , 'sample' ])
392-
392+
393393 if not has_chains :
394394 print ("\n No chain/draw dimensions found - likely MAP or point estimate." )
395395 print ("Skipping trace plots." )
396396 return
397-
397+
398398 print ("\n " + "=" * 70 )
399399 print ("Plotting Trace Plots" )
400400 print ("=" * 70 )
401-
401+
402402 n_params = len (param_names )
403403 fig , axes = plt .subplots (n_params , 1 , figsize = (12 , 3 * n_params ))
404-
404+
405405 if n_params == 1 :
406406 axes = [axes ]
407-
407+
408408 for idx , param in enumerate (param_names ):
409409 samples = posterior [param ].values
410-
410+
411411 # Trace plot
412412 if samples .ndim >= 2 :
413413 for chain in range (samples .shape [0 ]):
414414 axes [idx ].plot (samples [chain ], alpha = 0.7 , label = f'Chain { chain } ' )
415415 else :
416416 axes [idx ].plot (samples , alpha = 0.7 )
417-
417+
418418 axes [idx ].set_ylabel (param , fontsize = 11 )
419419 axes [idx ].set_xlabel ('Iteration' , fontsize = 11 )
420420 axes [idx ].set_title (f'{ param } - Trace' , fontsize = 12 , fontweight = 'bold' )
421421 axes [idx ].grid (True , alpha = 0.3 )
422422 if samples .ndim >= 2 and samples .shape [0 ] <= 10 :
423423 axes [idx ].legend (fontsize = 8 )
424-
424+
425425 output_file = Path (output_dir ) / 'trace_plots.png'
426426 plt .tight_layout ()
427427 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
@@ -444,37 +444,37 @@ def plot_correlation_matrix(posterior: xr.Dataset, param_names: List[str], outpu
444444 if len (param_names ) <= 1 :
445445 print ("\n Only one parameter - skipping correlation matrix." )
446446 return
447-
447+
448448 print ("\n " + "=" * 70 )
449449 print ("Plotting Correlation Matrix" )
450450 print ("=" * 70 )
451-
451+
452452 # Create correlation matrix
453453 data_matrix = np .column_stack ([posterior [param ].values .flatten () for param in param_names ])
454454 corr_matrix = np .corrcoef (data_matrix .T )
455-
455+
456456 # Plot correlation matrix
457457 fig , ax = plt .subplots (figsize = (10 , 8 ))
458458 im = ax .imshow (corr_matrix , cmap = 'RdBu_r' , vmin = - 1 , vmax = 1 , aspect = 'auto' )
459-
459+
460460 # Set ticks
461461 ax .set_xticks (range (len (param_names )))
462462 ax .set_yticks (range (len (param_names )))
463463 ax .set_xticklabels (param_names , rotation = 45 , ha = 'right' )
464464 ax .set_yticklabels (param_names )
465-
465+
466466 # Add colorbar
467467 cbar = plt .colorbar (im , ax = ax )
468468 cbar .set_label ('Correlation' , fontsize = 12 )
469-
469+
470470 # Add correlation values
471471 for i in range (len (param_names )):
472472 for j in range (len (param_names )):
473473 ax .text (j , i , f'{ corr_matrix [i , j ]:.2f} ' ,
474474 ha = "center" , va = "center" , color = "black" , fontsize = 10 )
475-
475+
476476 ax .set_title ('Parameter Correlation Matrix' , fontsize = 14 , fontweight = 'bold' , pad = 20 )
477-
477+
478478 output_file = Path (output_dir ) / 'correlation_matrix.png'
479479 plt .tight_layout ()
480480 plt .savefig (output_file , dpi = 150 , bbox_inches = 'tight' )
@@ -505,34 +505,34 @@ def main() -> None:
505505 required = True ,
506506 help = 'Directory containing observation.npz and posterior.nc files (and where plots will be saved)'
507507 )
508-
508+
509509 args = parser .parse_args ()
510-
510+
511511 # Verify output directory exists
512512 output_dir = Path (args .output )
513513 if not output_dir .exists ():
514514 print (f"Error: Directory { output_dir } does not exist" )
515515 sys .exit (1 )
516-
516+
517517 print ("=" * 70 )
518518 print ("SHINE RESULTS VISUALIZATION" )
519519 print ("=" * 70 )
520520 print (f"Output directory: { output_dir .absolute ()} " )
521-
521+
522522 # Setup plotting style
523523 setup_plot_style ()
524-
524+
525525 # Load data
526526 try :
527527 obs_data , idata , posterior = load_data (output_dir )
528528 except FileNotFoundError as e :
529529 print (f"\n Error: { e } " )
530530 sys .exit (1 )
531-
531+
532532 # Get parameter names
533533 param_names = list (posterior .data_vars )
534534 print (f"\n Inferred parameters: { param_names } " )
535-
535+
536536 # Generate all plots
537537 plot_observation (obs_data , output_dir )
538538 plot_posterior_distributions (posterior , param_names , output_dir )
@@ -541,7 +541,7 @@ def main() -> None:
541541 print_summary_statistics (posterior , param_names )
542542 plot_trace (posterior , param_names , output_dir )
543543 plot_correlation_matrix (posterior , param_names , output_dir )
544-
544+
545545 # Final summary
546546 print ("\n " + "=" * 70 )
547547 print ("VISUALIZATION COMPLETE" )
0 commit comments