# Loading X-ray data from exp (⁵⁵Fe X-rays on SCA 20663)

In [1]:
# %%
%matplotlib widget  
# Turn ON notebook and OFF inline and ipympl when running on OSC
#%matplotlib notebook
#%matplotlib inline 
#%matplotlib ipympl


import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers projection='3d' on older Matplotlib)
from scipy.ndimage import binary_dilation, find_objects, label

# Load in FITS cube.
# Look for the file beside this script when run as a .py file; inside a
# notebook `__file__` is not defined, so use the working directory instead.
here = Path(__file__).parent if "__file__" in globals() else Path.cwd()

fits_path = here.joinpath('20190919_95k_1p1m0p1_fe55_20663_003_diff.fits')

# Read the primary HDU's array.  The code below indexes it as
# data[frame, y, x], i.e. expected shape (Nframe, 4096, 4096).
with fits.open(fits_path) as hdulist:
    data = hdulist[0].data

# Define constants
half_size = 256                      # half-width of the 512x512 analysis patch
Nframe, height, width = data.shape   # frames, rows (y), columns (x)

# Gini-coefficient helper function (Lotz, Primack & Madau 2004 formula)
def compute_gini(arr_flat: np.ndarray) -> float:
    """
    Compute the Lotz-Primack-Madau (LPM) Gini coefficient of a set of values.

        G = sum_i (2i - N - 1) |X_i| / (mean(|X|) * N * (N - 1))

    where |X_i| are the absolute values sorted in increasing order and
    i runs from 1 to N.

    Parameters
    ----------
    arr_flat : array_like
        Pixel values of any shape; they are flattened internally. Absolute
        values are used, as in the LPM definition, so negative pixels
        (possible in difference images) cannot drive G below zero.

    Returns
    -------
    float
        0 for perfectly uniform values, 1 when all flux sits in one value.
        0.0 is also returned for fewer than two values or all-zero input,
        where G is undefined.
    """
    vals = np.sort(np.abs(np.asarray(arr_flat, dtype=float).ravel()))
    n = vals.size
    if n <= 1:
        return 0.0
    mean_abs = vals.mean()
    if mean_abs == 0:
        return 0.0
    ranks = np.arange(1, n + 1)  # 1..N, matching the sorted order
    numerator = np.sum((2 * ranks - n - 1) * vals)
    return float(numerator / (mean_abs * n * (n - 1)))

# 4-connectivity (edge neighbours only, no diagonals) for grouping pixels
# into events.
_EVENT_STRUCTURE = np.array([[0, 1, 0],
                             [1, 1, 1],
                             [0, 1, 0]], dtype=int)


def _label_events(patch, sigma_thresh, min_pixels):
    """Label 4-connected blobs above median + sigma_thresh * sigma_est.

    Blobs with fewer than ``min_pixels`` pixels are discarded. Returns
    ``(labeled, n_labels)``: an int array shaped like ``patch`` (0 = no
    event, 1..n_labels = surviving events, numbered consecutively) and the
    number of events.
    """
    # Robust background level (sigma-clipped median) and noise estimate
    # (median absolute deviation scaled by 1.4826).
    _, med, _ = sigma_clipped_stats(patch, sigma=3.0, maxiters=5)
    mad = np.median(np.abs(patch - med))
    sigma_est = mad * 1.4826
    threshold = med + sigma_thresh * sigma_est

    labeled_temp, _ = label(patch > threshold, structure=_EVENT_STRUCTURE)

    # Pixel count of every component in a single pass (index 0 = background),
    # instead of rescanning the whole patch once per component.
    sizes = np.bincount(labeled_temp.ravel())
    keep = sizes >= min_pixels
    keep[0] = False
    mask_filtered = keep[labeled_temp]

    # Relabel so surviving events get consecutive IDs 1..n_labels.
    return label(mask_filtered, structure=_EVENT_STRUCTURE)


def _event_properties(patch, labeled, n_labels, x0, y0):
    """Return ``({id: gini}, {id: (centroid_x, centroid_y)})`` per event.

    Centroids are unweighted pixel-position means, shifted by (x0, y0) into
    full-frame pixel coordinates.
    """
    event_ginis = {}
    centroids = {}
    # find_objects gives each label's bounding box, so each event only scans
    # its own small region rather than the full patch.
    boxes = find_objects(labeled, max_label=n_labels)
    for lab_id, bbox in enumerate(boxes, start=1):
        if bbox is None:
            continue
        mask_event = labeled[bbox] == lab_id
        event_ginis[lab_id] = compute_gini(patch[bbox][mask_event])

        local_y, local_x = np.nonzero(mask_event)
        centroids[lab_id] = (
            local_x.mean() + bbox[1].start + x0,
            local_y.mean() + bbox[0].start + y0,
        )
    return event_ginis, centroids


# Main plotting function
def plot_patch_and_events(
    frame_idx: int,
    x_center: int,
    y_center: int,
    sigma_thresh: float,
    min_pixels: int
):
    """
    Detect and display events in a 512x512 patch of one frame.

    1) Extract a 512x512 patch at (frame_idx, x_center, y_center).
    2) Define threshold = median + sigma_thresh x sigma_est (MAD-based).
    3) Label 4-connected pixels above that threshold, discard blobs with
       fewer than min_pixels pixels.
    4) Compute each blob's Gini, then plot (3D surface + 2D overlay) and
       print the Ginis.

    Invalid frame indices or centers that would push the patch off the
    detector print a warning and return without plotting.
    """
    # Validate frame index
    if not (0 <= frame_idx < Nframe):
        print(f"Warning: frame must be in [0, {Nframe-1}]. Got {frame_idx}")
        return

    # Validate that 512x512 patch stays in bounds
    if not (half_size <= x_center < width - half_size
            and half_size <= y_center < height - half_size):
        print(
            f"Warning: center out of range.\n"
            f"  x_center in [{half_size}, {width - half_size - 1}],\n"
            f"  y_center in [{half_size}, {height - half_size - 1}]."
        )
        return

    # Slice out the 512x512 patch (rows = y, columns = x)
    y0, y1 = y_center - half_size, y_center + half_size
    x0, x1 = x_center - half_size, x_center + half_size
    patch = data[frame_idx, y0:y1, x0:x1]   # shape = (512, 512)

    labeled, n_labels = _label_events(patch, sigma_thresh, min_pixels)
    event_ginis, centroids = _event_properties(patch, labeled, n_labels, x0, y0)

    # Full-frame (X, Y) grids for the 3D surface, each shaped like patch
    X, Y = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))

    # plot: 3D surface on the left, 2D overlay on the right
    fig = plt.figure(figsize=(12, 5))

    # -- Left: 3D surface (rcount/ccount cap the rendered samples at 50x50)
    ax3d = fig.add_subplot(1, 2, 1, projection='3d')
    surf = ax3d.plot_surface(
        X, Y, patch,
        cmap='viridis',
        edgecolor='none',
        rcount=50,
        ccount=50
    )
    ax3d.set_xlabel('X Pixel')
    ax3d.set_ylabel('Y Pixel')
    ax3d.set_zlabel('DN')
    ax3d.set_title(f'Frame {frame_idx} @ (x={x_center}, y={y_center})')
    fig.colorbar(surf, ax=ax3d, shrink=0.5, pad=0.1).set_label('DN')

    # -- Right: 2D patch with red outlines + event IDs
    ax2d = fig.add_subplot(1, 2, 2)
    # imshow's extent is given in pixel *edges*.  Pixel centres sit at the
    # integer coordinates x0..x1-1 / y0..y1-1, so the edges lie half a pixel
    # outside them; this keeps the overlay markers on the pixels they mark.
    im = ax2d.imshow(
        patch,
        cmap='gray',
        origin='lower',
        extent=(x0 - 0.5, x1 - 0.5, y0 - 0.5, y1 - 0.5)
    )
    ax2d.set_xlabel('X Pixel')
    ax2d.set_ylabel('Y Pixel')
    ax2d.set_title('2D Patch with Event Boundaries & IDs')

    # Outline every event pixel with a single scatter call
    ev_y, ev_x = np.nonzero(labeled)
    if ev_x.size:
        ax2d.scatter(
            ev_x + x0,
            ev_y + y0,
            s=10,
            facecolors='none',
            edgecolors='red',
            linewidths=0.8
        )

    for lab_id, (cx, cy) in centroids.items():
        ax2d.text(
            cx, cy, str(lab_id),
            color='yellow',
            fontsize=8,
            ha='center',
            va='center'
        )

    fig.colorbar(im, ax=ax2d, shrink=0.5, pad=0.1).set_label('DN')

    plt.tight_layout()
    plt.show()

    # Print out each event's Gini
    if not event_ginis:
        print("No events detected in this 512x512 patch.")
    else:
        print("Event_ID → Gini (over that event's pixels):")
        for eid, gval in event_ginis.items():
            print(f"  {eid:2d} → {gval:.4f}")

# Build interactive sliders.  All of them use continuous_update=False, so the
# plot is recomputed only when a slider handle is released, not while dragging.
frame_slider = widgets.IntSlider(
    description='Frame:',
    value=61, min=0, max=Nframe - 1, step=1,
    continuous_update=False,
)

# Center limits match the bounds plot_patch_and_events checks, so the
# 512x512 patch always stays inside the frame.
x_slider = widgets.IntSlider(
    description='X Center:',
    value=1030, min=half_size, max=width - half_size - 1, step=1,
    continuous_update=False,
)
y_slider = widgets.IntSlider(
    description='Y Center:',
    value=420, min=half_size, max=height - half_size - 1, step=1,
    continuous_update=False,
)

# Detection threshold in units of the MAD-based sigma estimate.
sigma_slider = widgets.FloatSlider(
    description='Sigma Threshold:',
    value=3.0, min=0.0, max=10.0, step=0.5,
    continuous_update=False,
)

# Smallest connected blob (in pixels) kept as an event.
minpix_slider = widgets.IntSlider(
    description='Min Pixels:',
    value=3, min=1, max=20, step=1,
    continuous_update=False,
)

# interact() re-runs plot_patch_and_events with the current slider values;
# as the cell's last expression its return value is echoed by Jupyter.
widgets.interact(
    plot_patch_and_events,
    frame_idx=frame_slider,
    x_center=x_slider,
    y_center=y_slider,
    sigma_thresh=sigma_slider,
    min_pixels=minpix_slider,
)


interactive(children=(IntSlider(value=61, continuous_update=False, description='Frame:', max=99), IntSlider(va…

<function __main__.plot_patch_and_events(frame_idx: int, x_center: int, y_center: int, sigma_thresh: float, min_pixels: int)>

In [None]:
# Worker budget for the mask thread pool below.  os.cpu_count() returns None
# when the count cannot be determined, so fall back to 1 instead of
# crashing; the +4 is subtracted again when the core count is reported.
num_of_cores = (os.cpu_count() or 1) + 4

# X-ray energy (in eV)
xray_en = 5898.75

# data dimensions: frames, rows (y), columns (x)
Nframe, h, w = data.shape

def compute_mask_med_frame(data, sigma_mult):
    """Flag hot pixels from the per-pixel median over all frames.

    Parameters
    ----------
    data : ndarray
        Frame stack; the median is taken along axis 0.
    sigma_mult : float
        Threshold in units of the MAD-based sigma (1.4826 * MAD) of the
        median image.

    Returns
    -------
    ndarray of bool
        True where the median image exceeds its own median plus
        sigma_mult * sigma.
    """
    print("⏳ Finding hot pixels…")
    stack_median = np.median(data, axis=0)
    centre = np.median(stack_median)          # computed once, used twice
    sigma_est = 1.4826 * np.median(np.abs(stack_median - centre))
    thresh_med = centre + sigma_mult * sigma_est
    print(f"✅ Done looking for hot pixels (σ={sigma_est:.3f}, thresh={thresh_med:.1f})")
    return stack_median > thresh_med

def compute_mask_first_frame(data, sigma_mult):
    """Flag "very hot" pixels using only the first frame of the stack.

    Parameters
    ----------
    data : ndarray
        Frame stack; only ``data[0]`` is examined.
    sigma_mult : float
        Threshold in units of the MAD-based sigma (1.4826 * MAD) of the
        first frame.

    Returns
    -------
    ndarray of bool
        True where the first frame exceeds its median plus
        sigma_mult * sigma.
    """
    print("⏳ Finding very hot pixels…")
    frame0 = data[0]
    level = np.median(frame0)
    sigma_est = np.median(np.abs(frame0 - level)) * 1.4826
    thresh0 = level + sigma_mult * sigma_est
    print(f"✅ Done looking for very hot pixels (σ={sigma_est:.3f}, thresh={thresh0:.1f})")
    return frame0 > thresh0

def compute_mask_no_response(data, sat_cut):
    """Flag pixels whose signal barely changes from frame to frame.

    Parameters
    ----------
    data : ndarray
        Frame stack with frames along axis 0.
    sat_cut : float
        A pixel is flagged when the median over frames of
        |frame[k+1] - frame[k]| is below this value.

    Returns
    -------
    ndarray of bool, shape ``data.shape[1:]``
    """
    print("⏳ Finding non-responsive pixels…")
    if np.issubdtype(data.dtype, np.unsignedinteger):
        # Unsigned subtraction wraps around (5 - 10 -> 251 for uint8), turning
        # small negative steps into huge ones; promote to the smallest
        # signed type that holds every value before differencing.
        data = data.astype(np.promote_types(data.dtype, np.int8))
    frame_diff = np.diff(data, axis=0)               # (Nframe-1, ny, nx)
    # Take |diff| in place so only one (Nframe-1)-frame temporary is alive.
    np.abs(frame_diff, out=frame_diff)
    med_diff = np.median(frame_diff, axis=0)
    mask_non_res = med_diff < sat_cut
    print(f"✅ Done looking for non-responsive pixels (median(med_diff)={np.median(med_diff):.3e})")
    return mask_non_res

# parameters
sigma_mult = 9.8    # hot / very-hot threshold, in units of the MAD-based sigma
sat_cut     = 5.999 # median |frame-to-frame diff| below this => non-responsive

print(f"Number of cores available for parallelization = {num_of_cores - 4}")


# 1) prepare tasks as (fn, arg) pairs
tasks = [
    (compute_mask_med_frame,   sigma_mult),
    (compute_mask_first_frame, sigma_mult),
    (compute_mask_no_response, sat_cut),
]

# 2) helper to call each fn with (data, param)
def _run_mask(fn, param):
    return fn(data, param)

# 3) run all three concurrently with the stdlib thread pool.  Threads (unlike
#    processes) share `data` without copying the cube, and Executor.map
#    yields results in task order, so the unpacking below is stable.
print("⏳ Computing all masks…")
with ThreadPoolExecutor(max_workers=min(len(tasks), num_of_cores)) as pool:
    mask_hot, mask_veryhot, mask_non_res = pool.map(
        _run_mask,                         # worker that calls fn(data, param)
        [fn for fn, _ in tasks],           # list of the 3 mask functions
        [param for _, param in tasks],     # their corresponding single argument
    )

print("🔗 Combining masks into one boolean array…")
base_mask = mask_hot | mask_veryhot | mask_non_res

# create a mask for pixels adjacent to a pixel with flagged response: any neighbor of the base_mask
print("⏳ Finding all adjacent pixels…")
mask_adj  = binary_dilation(base_mask, structure=np.ones((3,3)), border_value=0) & ~base_mask
print("✅ Done with adjacent pixel mask")

print("🔗 Combining all masks into final array…")
maskArray = base_mask | mask_adj
print("🎉 maskArray ready, shape =", maskArray.shape)

print("Comparing to percentages from Hirata, 2024, Table 2:")
# fractions in percent (mask.mean() = mask.sum() / mask.size)
frac_non_res = mask_non_res.mean() * 100
frac_hot     = mask_hot.mean() * 100
frac_veryhot = mask_veryhot.mean() * 100
frac_adj     = mask_adj.mean() * 100
frac_all     = maskArray.mean() * 100  # union

print(f"Non-resp pixels: {frac_non_res:.2f}% (vs. 0.53%)")
print(f"Hot pixels: {frac_hot:.2f}% (vs. 0.20%)")
print(f"Very hot pixels: {frac_veryhot:.2f}% (vs. 0.11%)")
print(f"Adjacent pixels: {frac_adj:.2f}% (vs. 2.47%)")
print(f"Union:       {frac_all:.2f}%  (vs. 3.01%)")

# Optional audible "done" notification.  `chime` is neither defined nor
# imported in the visible cells, so only call it if something provided it.
if "chime" in globals():
    chime(777, 550)

In [None]:
# Count flagged pixels in each category (masks come from the previous cell).
very_hot_pix = np.count_nonzero(mask_veryhot)
hot_pix      = np.count_nonzero(mask_hot)
non_res_pix  = np.count_nonzero(mask_non_res)
adj_pix      = np.count_nonzero(mask_adj)
all_bad_pix  = np.count_nonzero(maskArray)

for caption, count in (
    ("Number of very hot pixels = ", very_hot_pix),
    ("Number of hot pixels = ", hot_pix),
    ("Number of non-responsive pixels = ", non_res_pix),
    ("Number of adjacent pixels = ", adj_pix),
    ("Total number of unusable pixels = ", all_bad_pix),
):
    print(caption, count)

# Categorical map: 0 = good, 1 = hot, 2 = very hot, 3 = non-resp, 4 = adjacent.
# Codes are written in increasing order, so a pixel in several masks ends up
# with the highest code among them.
labels = np.zeros((h, w), dtype=int)
for code, flagged in enumerate((mask_hot, mask_veryhot, mask_non_res, mask_adj), start=1):
    labels[flagged] = code

cmap = mpl.colors.ListedColormap([
    "#000000",  # 0 = good pixels (black)
    "#e41a1c",  # 1 = hot         (red)
    "#377eb8",  # 2 = very hot    (blue)
    "#4daf4a",  # 3 = non-resp    (green)
    "#984ea3",  # 4 = adjacent    (purple)
])
# Bin edges 0,1,...,5 give one colour bin per integer code; ticks at the
# bin centres below carry the category names.
norm = mpl.colors.BoundaryNorm(np.arange(6), cmap.N)

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(labels, origin="lower", cmap=cmap, norm=norm)
cbar = fig.colorbar(im, ax=ax, ticks=[0.5, 1.5, 2.5, 3.5, 4.5])
cbar.ax.set_yticklabels(["good", "hot", "very hot", "non-resp", "adjacent"])
ax.set_title("Unusable pixels in SCA 20663")
ax.set_xlabel("x (pixels)")
ax.set_ylabel("y (pixels)")
fig.tight_layout()
plt.show()