In [1]:
import numpy as np
import os
from astropy.io import fits
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from scipy.stats import kurtosis
from numpy import hamming  # Import hamming window function

# Input directory: every *.fits file in it is processed as one target.
directory = r"C:\Users\lsann\Desktop\TESTPREP\Test10-09"

# Root output directory; a sub-directory per target is created inside it.
output_dir = r"C:\Users\lsann\Desktop\TESTPREP\Output-test10-09"
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(output_dir, exist_ok=True)

def dynamic_threshold(data, sensitivity=1.0):
    """Return a threshold lying ``sensitivity`` standard deviations above the mean.

    Parameters
    ----------
    data : array_like
        Values the statistics are computed over. Every element is used; NaNs
        are not skipped, so a single NaN makes the threshold NaN.
    sensitivity : float, optional
        Number of standard deviations added to the mean (default 1.0).

    Returns
    -------
    float
        ``mean(data) + sensitivity * std(data)``.
    """
    spread = np.std(data)
    centre = np.mean(data)
    return centre + spread * sensitivity

def rolling_mean(data, window=75):
    """Return the simple moving average of a 1-D sequence.

    Only fully overlapping windows are kept (``np.convolve(..., mode='valid')``),
    so the result has ``len(data) - window + 1`` elements.

    ``window`` is clamped to ``len(data)``. In 'valid' mode ``np.convolve``
    swaps its operands when the kernel is longer than the signal. Without the
    clamp it would return ``window - len(data) + 1`` partial sums instead of
    averages, e.g. for cubes with fewer slices than ``window``.

    Parameters
    ----------
    data : array_like
        1-D sequence of values (e.g. one value per spectral slice).
    window : int, optional
        Number of consecutive samples averaged per output point (default 75);
        tune it for a smoother or more detailed plot. Must be >= 1.

    Returns
    -------
    numpy.ndarray
        The moving averages; an empty array when ``data`` is empty.

    Raises
    ------
    ValueError
        If ``window`` is smaller than 1.
    """
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        # np.convolve raises on an empty input; there is simply nothing to average.
        return np.empty(0)
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    window = min(int(window), values.size)
    return np.convolve(values, np.ones(window) / window, mode='valid')

def apply_hamming_window(data):
    """Taper ``data`` with a separable Hamming window along every axis.

    The window is the outer product of 1-D ``numpy.hamming`` windows, one per
    axis and sized to that axis. Multiplying by it softens the hard edges of
    the array before a Fourier transform. For 3-D input the result is
    identical to the previous 3-D-only implementation; any dimensionality is
    now accepted.

    Parameters
    ----------
    data : numpy.ndarray
        Array of any dimensionality.

    Returns
    -------
    numpy.ndarray
        ``data * window``, a new array (the input is not modified).
    """
    data = np.asanyarray(data)
    window = np.ones((1,) * data.ndim)
    for axis, length in enumerate(data.shape):
        # Reshape the 1-D window so it broadcasts only along `axis`.
        shape = [1] * data.ndim
        shape[axis] = length
        window = window * hamming(length).reshape(shape)
    return data * window

def compute_and_plot_fourier_3d_full(data, title_prefix, output_dir, projection_axis='xy'):
    """Plot one central plane of the 3-D Fourier transform of ``data``.

    Processing steps:

    1. The cube is cleaned with ``np.nan_to_num``.
    2. It is Hamming-tapered along every axis (``apply_hamming_window``).
    3. It is transformed with ``np.fft.fftn``, and magnitude and phase are
       ``fftshift``-ed so the zero frequency sits at the centre.
    4. A single plane through the centre of the spectrum is shown as log10
       magnitude and phase, side by side. This is a slice, not a sum or
       projection.
    5. The figure is saved as
       ``<output_dir>/<title_prefix>_fourier_<projection_axis>.png``.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array. NaN/inf values are replaced via ``np.nan_to_num``.
    title_prefix : str
        Used in the plot titles and in the output file name.
    output_dir : str
        Directory the PNG is written to (must already exist).
    projection_axis : {'xy', 'xz', 'yz'}, optional
        Which plane to keep, in terms of array axes (0, 1, 2):

        - 'xy' keeps axes (0, 1) at the centre of axis 2;
        - 'xz' keeps axes (0, 2) at the centre of axis 1;
        - 'yz' keeps axes (1, 2) at the centre of axis 0.

        NOTE(review): the batch script in this file indexes cubes as
        ``data[slice, y, x]``. For such cubes 'yz' is the spatial (y, x)
        plane, so the labels do not follow that axis order.

    Returns
    -------
    None
        If the transform is identically zero, a warning is printed and
        nothing is saved.

    Raises
    ------
    ValueError
        If ``projection_axis`` is not 'xy', 'xz' or 'yz' (checked before the
        expensive FFT), or if ``data`` is not 3-D.
    """
    # Axis held fixed (at its centre index) for each plane choice.
    fixed_axis_by_plane = {'xy': 2, 'xz': 1, 'yz': 0}
    if projection_axis not in fixed_axis_by_plane:
        raise ValueError(f"Invalid Projection Axis: {projection_axis}. Use 'xy', 'xz', or 'yz'.")
    fixed_axis = fixed_axis_by_plane[projection_axis]

    # NaNs (e.g. blanked pixels) would poison the whole FFT; treat them as zero.
    data = np.nan_to_num(data)
    if data.ndim != 3:
        raise ValueError(f"Expected a 3-D array, got shape {data.shape}.")

    # Taper the edges so the cube boundaries do not leak power across the spectrum.
    data_windowed = apply_hamming_window(data)

    fourier_transform = np.fft.fftn(data_windowed)
    # Shift the zero-frequency component to the centre of the spectrum.
    magnitude_transform = np.fft.fftshift(np.abs(fourier_transform))
    phase_transform = np.fft.fftshift(np.angle(fourier_transform))

    # An all-zero transform would only produce a blank plot.
    if np.all(magnitude_transform == 0):
        print(f"Warning: Fourier Transform resulted in an empty plot for {title_prefix}.")
        return

    # Central plane: index the fixed axis at its midpoint, keep the other two.
    plane_index = [slice(None)] * 3
    plane_index[fixed_axis] = magnitude_transform.shape[fixed_axis] // 2
    plane_index = tuple(plane_index)
    magnitude_display = magnitude_transform[plane_index]
    phase_display = phase_transform[plane_index]

    # Log scale for better visualization. log10(0) = -inf would break the
    # colour scaling, so zeros are floored at the smallest positive magnitude
    # present in the plane.
    positive = magnitude_display[magnitude_display > 0]
    floor = positive.min() if positive.size else 1.0
    magnitude_display = np.log10(np.maximum(magnitude_display, floor))

    fig, (ax_magnitude, ax_phase) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        image = ax_magnitude.imshow(magnitude_display, cmap='gray', origin='lower')
        ax_magnitude.set_title(f'{title_prefix} - Magnitude')
        fig.colorbar(image, ax=ax_magnitude)

        image = ax_phase.imshow(phase_display, cmap='hsv', origin='lower')
        ax_phase.set_title(f'{title_prefix} - Phase')
        fig.colorbar(image, ax=ax_phase)

        fig.tight_layout()

        output_path = os.path.join(output_dir, f'{title_prefix}_fourier_{projection_axis}.png')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    print(f"Fourier transform plots saved for {title_prefix} in {output_dir}")

# ---------------------------------------------------------------------------
# Batch processing of every *.fits cube in `directory`.
#
# For each file:
#   1. smooth the cube and sum it along axis 0 into an integrated map;
#   2. locate the source in that map, ignoring a border of `margin` pixels;
#   3. take the slice where that pixel peaks and estimate a source radius there;
#   4. build a circular source aperture and a background annulus;
#   5. compute SNR / flux / background std for the integrated map (with a
#      dynamic-threshold retry when the SNR is NaN or 0) and for every slice;
#   6. write results.txt and diagnostic PNGs into a per-target sub-directory.
# The source-to-image-centre offset of every target is collected in
# `distances` and summarised at the end.
# ---------------------------------------------------------------------------


def _peak_pixel(flux_map, threshold):
    """Return (y, x) of the first pixel, in row-major order, above ``threshold``.

    If no pixel exceeds ``threshold``, fall back to the brightest pixel.
    Otherwise ``np.argmax`` on an all-False mask would silently return (0, 0).
    Coordinates are returned as plain ints.
    """
    above = flux_map > threshold
    flat_index = np.argmax(above) if above.any() else np.argmax(flux_map)
    y, x = np.unravel_index(flat_index, flux_map.shape)
    return int(y), int(x)


def _brightest_slice(spectrum):
    """Return the index of the largest value in a 1-D spectrum, ignoring NaNs.

    Returns 0 for an all-NaN spectrum.
    """
    spectrum = np.asarray(spectrum, dtype=float)
    return int(np.argmax(np.where(np.isnan(spectrum), -np.inf, spectrum)))


def _estimate_source_radius(slice_data, y0, x0, fraction=0.9):
    """Estimate the radius (pixels) of the source centred on (y0, x0) in one slice.

    1. Pixels are counted against ``fraction`` times the peak of *this* slice.
    2. If there are at most 4 such pixels, the source is compact. The radius
       is 0.5 for exactly one pixel and 1.0 otherwise.
    3. Otherwise, use the azimuthally averaged profile around (y0, x0):
       a. bin it at integer radii;
       b. smooth it with a Gaussian (sigma = 2 bins);
       c. the radius is the first bin where the profile falls to ``fraction``
          of its maximum;
       d. if that bin is 0, use the first bin with positive flux instead;
       e. return 0.5 as a last resort, so apertures built from the radius are
          never empty.
    """
    clean = np.nan_to_num(np.asarray(slice_data, dtype=float))
    # NOTE: the count uses this slice's own peak. The integrated-map threshold
    # is a sum over all slices, so it is not on the same scale as one slice.
    num_pixels = int(np.count_nonzero(clean > fraction * clean.max()))
    if num_pixels <= 4:
        return 0.5 if num_pixels == 1 else 1.0

    yy, xx = np.indices(clean.shape)
    radius_bin = np.sqrt((xx - x0) ** 2 + (yy - y0) ** 2).astype(int).ravel()
    flux_per_bin = np.bincount(radius_bin, clean.ravel())
    pixels_per_bin = np.bincount(radius_bin)
    profile = np.divide(flux_per_bin, pixels_per_bin,
                        out=np.zeros_like(flux_per_bin), where=pixels_per_bin > 0)

    smooth = gaussian_filter(profile, sigma=2)
    radius = int(np.argmax(smooth <= fraction * smooth.max()))
    if radius == 0:
        radius = int(np.argmax(smooth > 0))
    return radius if radius > 0 else 0.5


def _aperture_masks(shape, y0, x0, sigma):
    """Return (source_mask, background_mask) boolean arrays of ``shape``.

    The source mask is the disc r <= 3*sigma around (y0, x0). The background
    mask is the annulus 6*sigma < r <= 9*sigma around the same point.
    """
    yy, xx = np.indices(shape)
    r2 = (xx - x0) ** 2 + (yy - y0) ** 2
    source_mask = r2 <= (3 * sigma) ** 2
    background_mask = (r2 > (6 * sigma) ** 2) & (r2 <= (9 * sigma) ** 2)
    return source_mask, background_mask


def _aperture_stats(image, source_mask, background_mask):
    """Return (mean in source, std in background, SNR) of a 2-D ``image``.

    NaN pixels are ignored. A statistic over no valid pixels is NaN, returned
    without a warning. The SNR is ``mean / std``, or 0 when the std is exactly
    0; it is NaN when either statistic is NaN.
    """
    source_values = image[source_mask]
    background_values = image[background_mask]
    mean_source = np.nanmean(source_values) if np.any(~np.isnan(source_values)) else np.nan
    std_background = np.nanstd(background_values) if np.any(~np.isnan(background_values)) else np.nan
    snr = mean_source / std_background if std_background != 0 else 0
    return mean_source, std_background, snr


def _save_line_plot(values, label, ylabel, title, path, color=None):
    """Save a 10x6 line plot of ``values`` against their index to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(range(len(values)), values, label=label, color=color)
        ax.set_xlabel('Slice Index')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)
        fig.savefig(path)
    finally:
        plt.close(fig)


distances = []

# Process each FITS file in the directory (sorted for a reproducible order).
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(".fits"):
        continue
    file_path = os.path.join(directory, filename)
    target_name = os.path.splitext(filename)[0]

    # memmap=True avoids reading the whole cube up front; the context manager
    # closes the file even when processing fails.
    with fits.open(file_path, memmap=True) as HDUL:
        data = HDUL[0].data
        if data is None:
            print(f"Skipping {filename}: primary HDU contains no data.")
            continue
        # Drop degenerate leading axes (e.g. a length-1 Stokes axis) so the
        # cube can be indexed as data[slice, y, x].
        while data.ndim > 3 and data.shape[0] == 1:
            data = data[0]
        if data.ndim != 3:
            print(f"Skipping {filename}: expected a 3-D cube, got shape {data.shape}.")
            continue

        # One output directory per target.
        target_dir = os.path.join(output_dir, target_name)
        os.makedirs(target_dir, exist_ok=True)
        results_file = os.path.join(target_dir, "results.txt")

        try:
            # Gaussian smoothing (sigma = 1 pixel on every axis, including the
            # slice axis) to enhance faint sources.
            # NOTE(review): NaN pixels in `data` spread to their neighbours here.
            smoothed_data = gaussian_filter(data, sigma=1)

            # Integrated map: sum of the smoothed cube over the slice axis.
            integrated_image = np.nansum(smoothed_data, axis=0)
            total_flux = np.nansum(integrated_image)
            # Working copy for detection/photometry; its border is zeroed below,
            # while `integrated_image` is kept intact for plotting.
            integrated_flux = integrated_image.copy()

            center_of_image = np.array([integrated_flux.shape[0] // 2, integrated_flux.shape[1] // 2])

            # Zero a border so edge noise/artifacts are not detected as the
            # source. The guard matters: [-0:] would select the whole axis.
            margin = 20
            if margin > 0:
                integrated_flux[:margin, :] = 0
                integrated_flux[-margin:, :] = 0
                integrated_flux[:, :margin] = 0
                integrated_flux[:, -margin:] = 0

            # Source = first pixel within 1% of the peak (99% threshold).
            threshold = 0.99 * np.max(integrated_flux)
            y_center, x_center = _peak_pixel(integrated_flux, threshold)

            # Slice where the source pixel is brightest, and the radius there.
            max_slice_index = _brightest_slice(data[:, y_center, x_center])
            slice_data = data[max_slice_index]
            source_radius = _estimate_source_radius(slice_data, y_center, x_center)
            sigma_fit = source_radius / 2.0

            source_mask, background_mask = _aperture_masks(slice_data.shape, y_center, x_center, sigma_fit)
            # SNR of the (border-zeroed) integrated map inside the apertures.
            mean_source, std_background, overall_snr = _aperture_stats(
                integrated_flux, source_mask, background_mask)

            # If the SNR is NaN or 0, retry detection with a dynamic
            # (mean + 1.5 std) threshold; sigma_fit from the first pass is kept.
            # This runs before every quantity derived from the source position.
            if np.isnan(overall_snr) or overall_snr == 0.0:
                threshold = dynamic_threshold(integrated_flux, sensitivity=1.5)
                y_center, x_center = _peak_pixel(integrated_flux, threshold)
                max_slice_index = _brightest_slice(data[:, y_center, x_center])
                slice_data = data[max_slice_index]
                source_mask, background_mask = _aperture_masks(slice_data.shape, y_center, x_center, sigma_fit)
                mean_source, std_background, overall_snr = _aperture_stats(
                    integrated_flux, source_mask, background_mask)

            center_of_source = np.array([y_center, x_center])

            # Offset between the final source position and the image centre.
            distance = np.sqrt((x_center - center_of_image[1])**2 + (y_center - center_of_image[0])**2)
            distances.append(distance)

            # Finite pixels of the selected slice (kurtosis and histogram
            # would otherwise be NaN or fail on NaN pixels).
            finite_pixels = slice_data[np.isfinite(slice_data)]
            kurt = kurtosis(finite_pixels, fisher=True) if finite_pixels.size > 0 else 0

            # Per-slice mean source flux, background std and SNR. A separate
            # loop variable keeps `slice_data` pointing at the selected slice.
            mean_fluxes, std_backgrounds, snr_values = [], [], []
            for plane in data:
                plane_mean, plane_std, plane_snr = _aperture_stats(plane, source_mask, background_mask)
                mean_fluxes.append(plane_mean)
                std_backgrounds.append(plane_std)
                snr_values.append(plane_snr)

            mean_flux_rolling = rolling_mean(mean_fluxes)
            std_background_rolling = rolling_mean(std_backgrounds)
            snr_rolling = rolling_mean(snr_values)

            # Save results. Coordinates are plain ints, so lists print as
            # "[y, x]" rather than numpy scalar reprs.
            with open(results_file, "w", encoding="utf-8") as f:
                f.write(f"Name of target: {filename}\n")
                f.write(f"center_of_image: {center_of_image.tolist()}\n")
                f.write(f"center_of_source: {[y_center, x_center]}\n")
                f.write(f"source_radius: {source_radius}\n")
                f.write(f"sigma_fit: {sigma_fit}\n")
                f.write(f"overall_snr: {overall_snr}\n")
                f.write(f"kurtosis: {kurt}\n")
                # Background std of the integrated map (the one used for overall_snr).
                f.write(f"std_background: {std_background}\n")
                # Mean integrated flux per pixel inside the source aperture.
                f.write(f"Source flux: {mean_source}\n")
                f.write(f"Total flux of the cube: {total_flux}\n")

            # Fourier Transform
            compute_and_plot_fourier_3d_full(data, target_name, target_dir, projection_axis='yz')

            _save_line_plot(
                snr_rolling, 'SNR (Rolling Mean)', 'SNR',
                f'{target_name} - SNR Across All Slices (Rolling Mean)',
                os.path.join(target_dir, f'{target_name}_snr_across_slices_rolling_mean.png'))
            _save_line_plot(
                mean_flux_rolling, 'Mean Flux (Rolling Mean)', 'Mean Flux',
                f'{target_name} - Mean Flux Across All Slices (Rolling Mean)',
                os.path.join(target_dir, f'{target_name}_mean_flux_across_slices_rolling_mean.png'),
                color='green')
            _save_line_plot(
                std_background_rolling, 'Standard Deviation (Rolling Mean)', 'Standard Deviation',
                f'{target_name} - Standard Deviation Across All Slices (Rolling Mean)',
                os.path.join(target_dir, f'{target_name}_std_dev_across_slices_rolling_mean.png'),
                color='red')

            # Pixel flux distribution of the selected slice
            plt.figure(figsize=(10, 6))
            plt.hist(finite_pixels, bins=100, color='blue', alpha=0.7)
            plt.xlabel('Pixel Flux')
            plt.ylabel('Number of Pixels')
            plt.title(f'{target_name} - Pixel Flux Distribution of the Selected Slice')
            plt.grid(True)
            plt.savefig(os.path.join(target_dir, f'{target_name}_pixel_flux_distribution.png'))
            plt.close()

            # Source (value 1) and background (value 2) regions
            plt.figure(figsize=(6, 6))
            plt.imshow(source_mask + 2 * background_mask, origin='lower', cmap='viridis')
            plt.title(f'{target_name} - Source (Green) and Background (Yellow) Regions')
            plt.colorbar(label='Region Mask')
            plt.savefig(os.path.join(target_dir, f'{target_name}_source_and_background_regions.png'))
            plt.close()

            # Integrated map (without the zeroed border)
            plt.figure(figsize=(10, 8))
            plt.imshow(integrated_image, origin='lower', cmap='viridis')
            plt.colorbar(label='Integrated Flux')
            plt.xlabel('RA (pixel)')
            plt.ylabel('DEC (pixel)')
            plt.title(f'{filename} - Integrated Cube')
            plt.savefig(os.path.join(target_dir, f'{target_name}_integrated_cube.png'))
            plt.close()
        except MemoryError:
            print(f"MemoryError: Unable to process {filename} due to memory constraints.")

# np.mean / np.std of an empty list would warn and return NaN.
mean_distance = np.mean(distances) if distances else 0
std_distance = np.std(distances) if distances else 0
print(f'mean distance is {mean_distance} and std is {std_distance}')

print(f"Processing complete. Results saved in {output_dir}")


Fourier transform plots saved for cropped_J0842 in C:\Users\lsann\Desktop\TESTPREP\Output-test10-09\cropped_J0842
Fourier transform plots saved for cropped_J0842_dirty_cube in C:\Users\lsann\Desktop\TESTPREP\Output-test10-09\cropped_J0842_dirty_cube
mean distance is 40.90795183044483 and std is 15.100976029316952
Processing complete. Results saved in C:\Users\lsann\Desktop\TESTPREP\Output-test10-09
