# Global Stuffs

In [6]:
import os
import glob
import numpy as np
from param import output
import pyFAI
import matplotlib.pyplot as plt
import tifffile
import fabio
import h5py
from tqdm import tqdm
from joblib import Parallel, delayed

# import Path




# --- Utility Functions ---

def set_plot_style(axs, fonts, xlabel, ylabel):
    axs.set_xlabel(xlabel, fontsize=fonts)
    axs.set_ylabel(ylabel, fontsize=fonts)
    axs.tick_params(axis='both', which='major', direction='out', length=4, width=1)
    axs.tick_params(which='minor', width=1, size=2)
    axs.minorticks_on()
    axs.set_facecolor('white')
    for key in axs.spines:
        axs.spines[key].set_linewidth(1)
    axs.tick_params(axis='x', labelsize=fonts)
    axs.tick_params(axis='y', labelsize=fonts)
    return axs

def plot_2d_images(image, poni_file, output_folder, file_label, azm_range, vmin=None, vmax=None):
    ai = pyFAI.load(poni_file)
    pixel_size = ai.pixel1
    detector_distance = ai.dist
    beamx = ai.getFit2D()['centerX']
    beamy = ai.getFit2D()['centerY']
    wavelength = ai.wavelength
    x_pixels = image.shape[1]
    y_pixels = image.shape[0]
    x_coords = np.arange(x_pixels) - beamx
    y_coords = np.arange(y_pixels) - beamy
    xx, yy = np.meshgrid(x_coords, y_coords)
    qx = 1e-9 * 2 * np.pi / wavelength * np.sin(pixel_size * xx / detector_distance)
    qy = 1e-9 * 2 * np.pi / wavelength * np.sin(pixel_size * yy / detector_distance)
    azimuthal_angles = np.degrees(np.arctan2(qy, qx))
    mask = (azimuthal_angles >= azm_range[0]) & (azimuthal_angles <= azm_range[1])
    fig, ax = plt.subplots(figsize=(6,6))
    image = np.nan_to_num(image, nan=0.0)
    image = 10000000 * (image / np.sum(image))
    im_avg = ax.pcolormesh(qx, qy, image, cmap='jet', vmin=vmin, vmax=vmax, shading='auto')
    ax.set_aspect('equal')
    cbar_avg = fig.colorbar(im_avg, ax=ax, shrink=.8, label='Intensity [a.u.]')
    cbar_avg.ax.tick_params(labelsize=20)
    cbar_avg.ax.yaxis.label.set_size(20)
    ax.set_xlabel(r"$q_x$ [nm$^{-1}$]", fontsize=20)
    ax.set_ylabel(r"$q_y$ [nm$^{-1}$]", fontsize=20)
    ax.set_title(f"{file_label}_2D", fontsize=14, y=1.05)
    ax.tick_params(axis='both', which='major', labelsize=20, width=1.5)
    output_folder = os.path.join(output_folder, "Figures")
    os.makedirs(output_folder, exist_ok=True)
    plt.savefig(f"{output_folder}/{file_label}.png", dpi=300, bbox_inches='tight')
    plt.close()

# For sorting files consistently
def sort_files(filenames):
    import re
    def extract_numbers(filename):
        match = re.search(r'_s(\d+)_0*(\d+)', filename)
        if match:
            snumber = int(match.group(1))
            number = int(match.group(2))
            return snumber, number
        return float('inf'), float('inf')
    return sorted(filenames, key=extract_numbers)

# --- TIFF Data Reading ---
def get_tiff_img_data(fname):
    with tifffile.TiffFile(fname) as tif:
        for idx, page in enumerate(tif.pages):
            if idx == 0:
                img_data = page.asarray().astype(float)
            else:
                img_data = np.concatenate((img_data, page.asarray().astype(float)))
    img_data[img_data > 1e8] = np.nan
    return img_data

# --- HDF5 Data Reading (Eiger/Linkam) ---
def get_single_image_h5(fname, n, mask=None, threshold=1e8):
    try:
        imgs = fabio.open(fname)
    except Exception:
        return None
    img = imgs.get_frame(n).data.astype(float)
    img[img > threshold] = np.nan
    if mask is not None:
        img[mask == 1] = np.nan
    return img

def get_nImages(fname):
    try:
        return fabio.open(fname).nframes
    except Exception:
        return 1

# --- Core Unified Integration Function ---
def integrate_files_unified(
        base_path, poni_file, mask_file, folder, keyword,
        output_base_path, file_type='tiff', nsave2d=3, azm_range=(-60, 60),
        q1=None, q2=None, vmin=None, vmax=None, plot=False, method='splitpixel',
        npt=4000, threshold=1e8, parallel=False, reprocess=False):
    ai = pyFAI.load(poni_file)
    mask = fabio.open(mask_file).data
    output_folder = os.path.join(output_base_path, 'OneD_integrated_WAXS_01', folder)
    os.makedirs(os.path.join(output_folder, 'Figures'), exist_ok=True)

    if file_type == 'tiff':
        data_folder = os.path.join(base_path, folder, keyword, 'Threshold 1')
        tif_files = [f for f in os.listdir(data_folder) if f.endswith('.tif') and not f.startswith('._')]
        tif_files = sort_files(tif_files)
        azimuthal_data, radial_data = [], []
        for i, tif_file in enumerate(tqdm(tif_files, desc='Processing TIFF'), start=1):
            file_path = os.path.join(data_folder, tif_file)
            img = get_tiff_img_data(file_path)
            img[mask == 1] = 0
            # Metadata logic can be adapted here if needed
            normfactor = 1.0
            if i % nsave2d == 0:
                plot_2d_images(img, poni_file, output_folder, tif_file, azm_range, vmin, vmax)
            result = ai.integrate1d(img, npt, error_model='poisson',
                                    correctSolidAngle=True, azimuth_range=azm_range,
                                    normalization_factor=normfactor,
                                    polarization_factor=0.95, method=method, mask=mask)
            q, I_azimuthal, error_azm = result.radial, result.intensity, result.sigma
            azimuthal_data.append((q, I_azimuthal, error_azm, tif_file))
            chi, I_radial = ai.integrate_radial(
                img,
                npt=720,
                radial_range=(q1, q2),
                azimuth_range=(-180, 0),
                mask=mask,
                normalization_factor=normfactor,
                correctSolidAngle=True,
                polarization_factor=0.95,
                method=method,
                unit='chi_deg',
                radial_unit='q_nm^-1')
            radial_data.append((chi, I_radial, tif_file))
        print('TIFF reduction complete.')
        # Add your HDF5 output saving logic here if desired (using azimuthal_data and radial_data)
        return azimuthal_data, radial_data

    elif file_type == 'h5':
        h5_pattern = os.path.join(base_path, folder, f'*{keyword}*_master.h5')
        h5_files = glob.glob(h5_pattern)

        azimuthal_data, radial_data = [], []   # initialize here

        for h5_file in h5_files:
            nframes = get_nImages(h5_file)

            def integrate_h5_frame(n):
                img = get_single_image_h5(h5_file, n, mask, threshold=threshold)
                if img is None:
                    return None
                result = ai.integrate1d(img, npt, correctSolidAngle=True, azimuth_range=azm_range,
                                        polarization_factor=0.95, method=method, mask=mask, unit='q_nm^-1')
                q, I = result.radial, result.intensity
                chi, I_radial = ai.integrate_radial(
                    img, npt=720, radial_range=(q1, q2), azimuth_range=(-180, 0),
                    mask=mask, correctSolidAngle=True, polarization_factor=0.95,
                    method=method, unit='chi_deg', radial_unit='q_nm^-1')
                return (q, I, n), (chi, I_radial, n)

            print(f'Processing {os.path.basename(h5_file)} ...')
            if parallel:
                results = Parallel(n_jobs=-1)(delayed(integrate_h5_frame)(n) for n in range(nframes))
            else:
                results = [integrate_h5_frame(n) for n in tqdm(range(nframes), desc='Processing HDF5')]

            for res in results:
                if res is None: 
                    continue
                azm, rad = res
                azimuthal_data.append(azm)
                radial_data.append(rad)

        print('HDF5 reduction complete.')
        return azimuthal_data, radial_data


    else:
        raise ValueError("file_type must be 'tiff' or 'h5'.")

# --- Usage Example ---
from pathlib import Path
base_path = Path(r'/Volumes/SSD1/RawData1/Redesigned_Plastics/May2025/2025_05_Anjani')
save_path  = base_path / 'processed_data'
save_path.mkdir(exist_ok=True)

calib_path = base_path / 'calibration'
xye_path  = save_path / 'xye_data'
xye_path.mkdir(exist_ok=True) 

calib_file = calib_path / 'LaB6_linkam_15kev.poni'
dp_mask_file1 = calib_path / 'mask_01.edf'

poni_file = str(calib_file)
mask_file = str(dp_mask_file1)
output_base_path = str(save_path)

# For TIFF input:
# integrate_files_unified(
#     base_path, poni_file, mask_file, "P3HB", "some_sample_keyword",
#     output_base_path, file_type='tiff', nsave2d=3, azm_range=(-60, 60),
#     q1=9.25, q2=13, vmin=0, vmax=15, plot=True
# 

# For HDF5 input:
integrate_files_unified(
     base_path, poni_file, mask_file, "linkam", "Run8_LDPE_30C_50ums_scan001_master",
    output_base_path, file_type='h5', nsave2d=1, azm_range=(-45, 45),
         q1=14, q2=16, vmin=0, vmax=2, plot=True, method='csr', npt=2000, threshold=1e8, parallel=True
 )



HDF5 reduction complete.


([], [])