Skip to content

Commit e909dd7

Browse files
Move plot_shine_results.py to scripts/ directory
Utility scripts should be kept separate from the package itself. Created scripts/ directory for utility scripts. Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
1 parent f3bb9df commit e909dd7

File tree

1 file changed

+63
-63
lines changed

1 file changed

+63
-63
lines changed
Lines changed: 63 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da
5656
FileNotFoundError: If observation.npz or posterior.nc are not found in output_dir.
5757
"""
5858
output_path = Path(output_dir)
59-
59+
6060
# Load observation data
6161
obs_file = output_path / 'observation.npz'
6262
if not obs_file.exists():
6363
raise FileNotFoundError(f"observation.npz not found in {output_dir}")
6464
obs_data = np.load(obs_file)
6565
print(f"Observation data loaded from {obs_file}")
6666
print(f"Available keys: {list(obs_data.keys())}")
67-
67+
6868
# Load posterior estimates
6969
posterior_file = output_path / 'posterior.nc'
7070
if not posterior_file.exists():
@@ -74,7 +74,7 @@ def load_data(output_dir: Path) -> Tuple[Dict[str, Any], az.InferenceData, xr.Da
7474
print(f"\nPosterior data loaded from {posterior_file}")
7575
print(f"Dataset structure:")
7676
print(posterior)
77-
77+
7878
return obs_data, idata, posterior
7979

8080

@@ -92,13 +92,13 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
9292
print("\n" + "="*70)
9393
print("Plotting Observation")
9494
print("="*70)
95-
95+
9696
image = obs_data.get('image', None)
9797
psf = obs_data.get('psf', None)
9898
noise_map = obs_data.get('noise_map', None)
99-
99+
100100
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
101-
101+
102102
# Plot the galaxy image
103103
if image is not None:
104104
im1 = axes[0].imshow(image, origin='lower', cmap='viridis')
@@ -109,15 +109,15 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
109109
axes[0].text(0.02, 0.98, f'Max: {image.max():.2e}\nMin: {image.min():.2e}',
110110
transform=axes[0].transAxes, verticalalignment='top',
111111
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
112-
112+
113113
# Plot the PSF
114114
if psf is not None:
115115
im2 = axes[1].imshow(psf, origin='lower', cmap='hot')
116116
axes[1].set_title('PSF Model', fontsize=14, fontweight='bold')
117117
axes[1].set_xlabel('X pixel')
118118
axes[1].set_ylabel('Y pixel')
119119
plt.colorbar(im2, ax=axes[1], label='Normalized Flux')
120-
120+
121121
# Plot the noise map
122122
if noise_map is not None:
123123
if noise_map.ndim == 0: # Scalar noise
@@ -132,7 +132,7 @@ def plot_observation(obs_data: Dict[str, Any], output_dir: Path) -> None:
132132
axes[2].set_xlabel('X pixel')
133133
axes[2].set_ylabel('Y pixel')
134134
plt.colorbar(im3, ax=axes[2], label='Noise σ')
135-
135+
136136
output_file = Path(output_dir) / 'observation_visual.png'
137137
plt.tight_layout()
138138
plt.savefig(output_file, dpi=150, bbox_inches='tight')
@@ -155,45 +155,45 @@ def plot_posterior_distributions(posterior: xr.Dataset, param_names: List[str],
155155
print("\n" + "="*70)
156156
print("Plotting Posterior Distributions")
157157
print("="*70)
158-
158+
159159
n_params = len(param_names)
160160
n_cols = min(3, n_params)
161161
n_rows = int(np.ceil(n_params / n_cols))
162-
162+
163163
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows))
164164
if n_params == 1:
165165
axes = np.array([axes])
166166
axes = axes.flatten()
167-
167+
168168
for idx, param in enumerate(param_names):
169169
ax = axes[idx]
170170
samples = posterior[param].values
171-
171+
172172
if samples.ndim > 1:
173173
samples = samples.flatten()
174-
174+
175175
ax.hist(samples, bins=50, density=True, alpha=0.7,
176176
color='steelblue', edgecolor='black')
177-
177+
178178
mean_val = np.mean(samples)
179179
median_val = np.median(samples)
180180
std_val = np.std(samples)
181-
181+
182182
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2,
183183
label=f'Mean: {mean_val:.4f}')
184184
ax.axvline(median_val, color='green', linestyle=':', linewidth=2,
185185
label=f'Median: {median_val:.4f}')
186-
186+
187187
ax.set_xlabel(f'{param}', fontsize=12)
188188
ax.set_ylabel('Density', fontsize=12)
189189
ax.set_title(f'{param} Posterior\nσ = {std_val:.4f}', fontsize=12, fontweight='bold')
190190
ax.legend(fontsize=9)
191191
ax.grid(True, alpha=0.3)
192-
192+
193193
# Hide empty subplots
194194
for idx in range(n_params, len(axes)):
195195
axes[idx].axis('off')
196-
196+
197197
output_file = Path(output_dir) / 'posterior_distributions.png'
198198
plt.tight_layout()
199199
plt.savefig(output_file, dpi=150, bbox_inches='tight')
@@ -232,10 +232,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
232232
"The 'corner' package is required for corner plots. "
233233
"Install it with: pip install corner"
234234
)
235-
235+
236236
# Prepare data: stack all parameters as columns
237237
samples_array = np.column_stack([posterior[param].values.flatten() for param in param_names])
238-
238+
239239
# Create corner plot with confidence intervals
240240
fig = corner.corner(
241241
samples_array,
@@ -252,10 +252,10 @@ def plot_corner(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
252252
truth_color='red',
253253
title_kwargs={"fontsize": 11},
254254
)
255-
255+
256256
plt.suptitle('Corner Plot: Joint & Marginal Distributions',
257257
fontsize=14, fontweight='bold', y=0.995)
258-
258+
259259
output_file = Path(output_dir) / 'corner_plot.png'
260260
plt.savefig(output_file, dpi=150, bbox_inches='tight')
261261
plt.close()
@@ -278,23 +278,23 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
278278
pip install corner
279279
"""
280280
shear_params = [p for p in param_names if 'g1' in p.lower() or 'g2' in p.lower() or 'shear' in p.lower()]
281-
281+
282282
if not shear_params:
283283
print("\nNo shear parameters found. Skipping shear analysis.")
284284
return
285-
285+
286286
print("\n" + "="*70)
287287
print("Plotting Shear Analysis")
288288
print("="*70)
289289
print(f"Found shear parameters: {shear_params}")
290-
290+
291291
g1_param = next((p for p in param_names if 'g1' in p.lower()), None)
292292
g2_param = next((p for p in param_names if 'g2' in p.lower()), None)
293-
293+
294294
if not (g1_param and g2_param):
295295
print("Could not identify both g1 and g2 parameters. Skipping.")
296296
return
297-
297+
298298
# Import corner package
299299
try:
300300
import corner
@@ -303,19 +303,19 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
303303
"The 'corner' package is required for shear analysis plots. "
304304
"Install it with: pip install corner"
305305
)
306-
306+
307307
g1_samples = posterior[g1_param].values.flatten()
308308
g2_samples = posterior[g2_param].values.flatten()
309-
309+
310310
# Prepare data: stack g1 and g2 as columns
311311
samples_array = np.column_stack([g1_samples, g2_samples])
312-
312+
313313
# Calculate statistics
314314
g1_mean = np.mean(g1_samples)
315315
g1_std = np.std(g1_samples)
316316
g2_mean = np.mean(g2_samples)
317317
g2_std = np.std(g2_samples)
318-
318+
319319
# Create corner plot for shear parameters only
320320
fig = corner.corner(
321321
samples_array,
@@ -332,14 +332,14 @@ def plot_shear_analysis(posterior: xr.Dataset, param_names: List[str], output_di
332332
truth_color='red',
333333
title_kwargs={"fontsize": 11},
334334
)
335-
336-
plt.suptitle('Shear Parameters Corner Plot (g1, g2)',
335+
336+
plt.suptitle('Shear Parameters Corner Plot (g1, g2)',
337337
fontsize=14, fontweight='bold', y=0.995)
338-
338+
339339
output_file = Path(output_dir) / 'shear_analysis.png'
340340
plt.savefig(output_file, dpi=150, bbox_inches='tight')
341341
plt.close()
342-
342+
343343
print(f"\nShear estimates:")
344344
print(f" g1 = {g1_mean:.6f} ± {g1_std:.6f}")
345345
print(f" g2 = {g2_mean:.6f} ± {g2_std:.6f}")
@@ -361,17 +361,17 @@ def print_summary_statistics(posterior: xr.Dataset, param_names: List[str]) -> N
361361
print("="*70)
362362
print(f"{'Parameter':<20} {'Mean':<12} {'Std':<12} {'Median':<12} {'95% CI':<20}")
363363
print("-"*70)
364-
364+
365365
for param in param_names:
366366
samples = posterior[param].values.flatten()
367367
mean_val = np.mean(samples)
368368
std_val = np.std(samples)
369369
median_val = np.median(samples)
370370
ci_low = np.percentile(samples, 2.5)
371371
ci_high = np.percentile(samples, 97.5)
372-
372+
373373
print(f"{param:<20} {mean_val:<12.6f} {std_val:<12.6f} {median_val:<12.6f} [{ci_low:.6f}, {ci_high:.6f}]")
374-
374+
375375
print("="*70)
376376

377377

@@ -389,39 +389,39 @@ def plot_trace(posterior: xr.Dataset, param_names: List[str], output_dir: Path)
389389
output_dir (Path): Directory where the plot will be saved.
390390
"""
391391
has_chains = any(dim in posterior.dims for dim in ['chain', 'draw', 'sample'])
392-
392+
393393
if not has_chains:
394394
print("\nNo chain/draw dimensions found - likely MAP or point estimate.")
395395
print("Skipping trace plots.")
396396
return
397-
397+
398398
print("\n" + "="*70)
399399
print("Plotting Trace Plots")
400400
print("="*70)
401-
401+
402402
n_params = len(param_names)
403403
fig, axes = plt.subplots(n_params, 1, figsize=(12, 3*n_params))
404-
404+
405405
if n_params == 1:
406406
axes = [axes]
407-
407+
408408
for idx, param in enumerate(param_names):
409409
samples = posterior[param].values
410-
410+
411411
# Trace plot
412412
if samples.ndim >= 2:
413413
for chain in range(samples.shape[0]):
414414
axes[idx].plot(samples[chain], alpha=0.7, label=f'Chain {chain}')
415415
else:
416416
axes[idx].plot(samples, alpha=0.7)
417-
417+
418418
axes[idx].set_ylabel(param, fontsize=11)
419419
axes[idx].set_xlabel('Iteration', fontsize=11)
420420
axes[idx].set_title(f'{param} - Trace', fontsize=12, fontweight='bold')
421421
axes[idx].grid(True, alpha=0.3)
422422
if samples.ndim >= 2 and samples.shape[0] <= 10:
423423
axes[idx].legend(fontsize=8)
424-
424+
425425
output_file = Path(output_dir) / 'trace_plots.png'
426426
plt.tight_layout()
427427
plt.savefig(output_file, dpi=150, bbox_inches='tight')
@@ -444,37 +444,37 @@ def plot_correlation_matrix(posterior: xr.Dataset, param_names: List[str], outpu
444444
if len(param_names) <= 1:
445445
print("\nOnly one parameter - skipping correlation matrix.")
446446
return
447-
447+
448448
print("\n" + "="*70)
449449
print("Plotting Correlation Matrix")
450450
print("="*70)
451-
451+
452452
# Create correlation matrix
453453
data_matrix = np.column_stack([posterior[param].values.flatten() for param in param_names])
454454
corr_matrix = np.corrcoef(data_matrix.T)
455-
455+
456456
# Plot correlation matrix
457457
fig, ax = plt.subplots(figsize=(10, 8))
458458
im = ax.imshow(corr_matrix, cmap='RdBu_r', vmin=-1, vmax=1, aspect='auto')
459-
459+
460460
# Set ticks
461461
ax.set_xticks(range(len(param_names)))
462462
ax.set_yticks(range(len(param_names)))
463463
ax.set_xticklabels(param_names, rotation=45, ha='right')
464464
ax.set_yticklabels(param_names)
465-
465+
466466
# Add colorbar
467467
cbar = plt.colorbar(im, ax=ax)
468468
cbar.set_label('Correlation', fontsize=12)
469-
469+
470470
# Add correlation values
471471
for i in range(len(param_names)):
472472
for j in range(len(param_names)):
473473
ax.text(j, i, f'{corr_matrix[i, j]:.2f}',
474474
ha="center", va="center", color="black", fontsize=10)
475-
475+
476476
ax.set_title('Parameter Correlation Matrix', fontsize=14, fontweight='bold', pad=20)
477-
477+
478478
output_file = Path(output_dir) / 'correlation_matrix.png'
479479
plt.tight_layout()
480480
plt.savefig(output_file, dpi=150, bbox_inches='tight')
@@ -505,34 +505,34 @@ def main() -> None:
505505
required=True,
506506
help='Directory containing observation.npz and posterior.nc files (and where plots will be saved)'
507507
)
508-
508+
509509
args = parser.parse_args()
510-
510+
511511
# Verify output directory exists
512512
output_dir = Path(args.output)
513513
if not output_dir.exists():
514514
print(f"Error: Directory {output_dir} does not exist")
515515
sys.exit(1)
516-
516+
517517
print("="*70)
518518
print("SHINE RESULTS VISUALIZATION")
519519
print("="*70)
520520
print(f"Output directory: {output_dir.absolute()}")
521-
521+
522522
# Setup plotting style
523523
setup_plot_style()
524-
524+
525525
# Load data
526526
try:
527527
obs_data, idata, posterior = load_data(output_dir)
528528
except FileNotFoundError as e:
529529
print(f"\nError: {e}")
530530
sys.exit(1)
531-
531+
532532
# Get parameter names
533533
param_names = list(posterior.data_vars)
534534
print(f"\nInferred parameters: {param_names}")
535-
535+
536536
# Generate all plots
537537
plot_observation(obs_data, output_dir)
538538
plot_posterior_distributions(posterior, param_names, output_dir)
@@ -541,7 +541,7 @@ def main() -> None:
541541
print_summary_statistics(posterior, param_names)
542542
plot_trace(posterior, param_names, output_dir)
543543
plot_correlation_matrix(posterior, param_names, output_dir)
544-
544+
545545
# Final summary
546546
print("\n" + "="*70)
547547
print("VISUALIZATION COMPLETE")

0 commit comments

Comments
 (0)