In [None]:
# Batch process: extract a slit slice from each FITS file in a directory.
# (Requires a large amount of RAM.)
# (Could be improved by splitting the work into smaller batches sized to the available memory.)

import os
import time
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from scipy.ndimage import rotate
from tqdm.notebook import tqdm  

def read_fits(file_path):
    """Read the image data and header of extension HDU 1 from a FITS file.

    The file is opened in a ``with`` block so the handle is released even if
    indexing the HDU list raises (for example when there is no extension 1).

    Parameters
    ----------
    file_path : str
        Path to the FITS file.

    Returns
    -------
    data : numpy.ndarray
        The HDU-1 image with its row order reversed by ``np.flipud``.
    header : astropy.io.fits.Header
        The header of HDU 1.
    """
    with fits.open(file_path) as hdul:
        hdu = hdul[1]
        data = hdu.data
        header = hdu.header
    # NOTE(review): the flip reverses the stored row order; every pixel
    # coordinate used downstream refers to this flipped array.
    return np.flipud(data), header

def pixel_to_heliocentric(x, y, header):
    """Convert pixel coordinates to sky offsets using the header's linear WCS keywords.

    Applies ``(pixel - CRPIX) * CDELT + CRVAL`` independently on each axis.
    Rotation terms (CROTA / PC matrix) are not read. Works on scalars and on
    NumPy arrays.

    Parameters
    ----------
    x, y : float or numpy.ndarray
        Pixel coordinates along the column (x) and row (y) axes.
    header : mapping
        FITS header (or dict) providing ``CRPIX1``, ``CRPIX2`` and ``CDELT1``.
        ``CDELT2`` defaults to ``CDELT1`` and ``CRVAL1``/``CRVAL2`` default to 0.

    Returns
    -------
    (x_out, y_out)
        Coordinates in the units of CDELT (arcsec per pixel per the original
        code's comment).
    """
    scale_x = header['CDELT1']
    # The y axis has its own plate scale; previously CDELT1 was used for both.
    scale_y = header.get('CDELT2', scale_x)
    # CRVAL is the coordinate value at the reference pixel; missing -> 0.
    x_heliocentric = (x - header['CRPIX1']) * scale_x + header.get('CRVAL1', 0.0)
    y_heliocentric = (y - header['CRPIX2']) * scale_y + header.get('CRVAL2', 0.0)
    # NOTE(review): FITS CRPIX is 1-based while NumPy indices are 0-based, and
    # read_fits flips rows with np.flipud, so (x, y) index a flipped array.
    # Confirm whether a +1 pixel offset / y-flip correction is intended.
    return x_heliocentric, y_heliocentric

def compute_rotation_angle(start, end):
    """Return the image rotation angle (degrees) for the slit ``start -> end``.

    With ``phi`` the direction of ``end - start`` measured counter-clockwise
    from +x, the result is ``phi - 90``. Rotating the slit vector by
    ``-angle`` with ``rotate_point`` (as the plotting and extraction code
    does) therefore turns it to point along +y.

    Parameters
    ----------
    start, end : tuple of float
        Slit endpoints (x, y) in pixel coordinates.

    Returns
    -------
    float
        Angle in degrees.
    """
    dx = end[0] - start[0]
    # y difference taken as start - end (opposite sign to dx).
    dy = start[1] - end[1]
    angle = np.arctan2(dy, dx) * 180 / np.pi
    # The former four-way branch on the signs of dx and angle returned
    # -angle - 90 in every case (for angle < 0, |angle| - 90 == -angle - 90),
    # so it reduces to this single expression.
    return -angle - 90

def rotate_point(x, y, angle, center):
    """Rotate the point ``(x, y)`` by ``angle`` degrees about ``center``.

    Uses the rotation matrix ``[[cos, -sin], [sin, cos]]`` applied to the
    offset from ``center``; positive angles turn +x toward +y.

    Parameters
    ----------
    x, y : float or numpy.ndarray
        Point coordinates.
    angle : float
        Rotation angle in degrees.
    center : tuple of float
        Pivot point ``(x_c, y_c)``.

    Returns
    -------
    tuple
        The rotated ``(x, y)``.
    """
    cx, cy = center
    theta = np.radians(angle)
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rel_x = x - cx
    rel_y = y - cy
    rotated_x = cx + (rel_x * cos_t - rel_y * sin_t)
    rotated_y = cy + (rel_x * sin_t + rel_y * cos_t)
    return rotated_x, rotated_y

def compute_distance_and_comparison(start, end, data):
    """Check which slit endpoint lies farther from the image centre.

    The centre is taken as ``(n_cols / 2, n_rows / 2)``. The slit is expected
    to run outward, i.e. ``end`` farther from the centre than ``start``.

    Parameters
    ----------
    start, end : tuple of float
        Slit endpoints (x, y) in pixel coordinates of ``data``.
    data : numpy.ndarray
        Image whose first two dimensions are (rows, columns).

    Returns
    -------
    str
        One of three fixed messages: correct order, a swap warning, or an
        equidistance note (also returned when the comparison is undecidable).
    """
    cx = data.shape[1] / 2
    cy = data.shape[0] / 2

    def _dist(pt):
        px, py = pt
        return np.sqrt((px - cx)**2 + (py - cy)**2)

    d_start = _dist(start)
    d_end = _dist(end)

    if d_start < d_end:
        return "Correct slit coord input. End point is farther from the image center than the start point."
    if d_start > d_end:
        return "Warning: Start point is farther from the image center than the end point. Please exchange the coords."
    return "Please note that start point and end point are equidistant from the image center."

def rotate_image(data, angle):
    """Rotate ``data`` by ``angle`` degrees, keeping the input shape.

    Thin wrapper over ``scipy.ndimage.rotate`` with ``reshape=False`` (other
    options at their scipy defaults), so corners rotated out of the frame
    are cropped.
    """
    rotated = rotate(data, angle=angle, reshape=False)
    return rotated

def plot_images_and_slice(original_data, rotated_data, start, end, zoom_percent):
    """Plot the original and rotated images with the slit, plus a zoom on the slit.

    Two figures are created and left open (showing/closing is up to the
    caller): a side-by-side view of ``original_data`` and ``rotated_data``
    with the slit overlaid, and a square crop of ``original_data`` centred
    on the slit midpoint.

    Parameters
    ----------
    original_data : numpy.ndarray
        2-D image in which ``start`` and ``end`` are given.
    rotated_data : numpy.ndarray
        The image rotated by ``compute_rotation_angle(start, end)`` degrees;
        presumably the output of ``rotate_image`` as in ``process_directory``.
    start, end : tuple of float
        Slit endpoints (x, y) in pixel coordinates of ``original_data``.
    zoom_percent : float
        Half-width of the zoom crop as a fraction of the slit length.
    """
    vmin = 0
    vmax = original_data.mean() * 10
    slit_style = dict(color='#1E90FF', lw=2, linestyle='--')

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    axes[0].imshow(original_data, origin='lower', cmap='afmhot', vmin=vmin, vmax=vmax)
    axes[0].set_title('Original Image')
    axes[0].plot([start[0], end[0]], [start[1], end[1]], **slit_style)

    angle = compute_rotation_angle(start, end)
    # scipy.ndimage.rotate (reshape=False) rotates about (n - 1) / 2 along each
    # axis; using shape / 2 drew the rotated slit half a pixel off.
    center = ((original_data.shape[1] - 1) / 2, (original_data.shape[0] - 1) / 2)
    start_rotated = rotate_point(start[0], start[1], -angle, center)
    end_rotated = rotate_point(end[0], end[1], -angle, center)

    axes[1].imshow(rotated_data, origin='lower', cmap='afmhot', vmin=vmin, vmax=vmax)
    axes[1].set_title(f'Rotated Image (Clockwise rotated angle: {angle:.2f}°)')
    axes[1].plot([start_rotated[0], end_rotated[0]], [start_rotated[1], end_rotated[1]], **slit_style)

    plt.tight_layout()
    ##plt.show()

    # Square crop of the original image around the slit midpoint. Rotation
    # preserves length, so the slit length is measured in the original frame.
    line_length = np.hypot(end[0] - start[0], end[1] - start[1])
    # At least one pixel, so a degenerate slit (start == end) still gives a
    # non-empty crop.
    half_size = max(1, int(line_length * zoom_percent))
    center_x = (start[0] + end[0]) / 2
    center_y = (start[1] + end[1]) / 2

    x_min = max(0, int(center_x - half_size))
    x_max = min(original_data.shape[1], int(center_x + half_size))
    y_min = max(0, int(center_y - half_size))
    y_max = min(original_data.shape[0], int(center_y + half_size))

    zoom_in_data = original_data[y_min:y_max, x_min:x_max]

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(zoom_in_data, origin='lower', cmap='afmhot', vmin=vmin, vmax=vmax)
    ax.set_title('Zoom In on Slice Line')
    # Shift the slit into the crop's local pixel coordinates.
    ax.plot([start[0] - x_min, end[0] - x_min], [start[1] - y_min, end[1] - y_min], **slit_style)

    ##plt.show()
    


def plot_heliocentric_images(data, header, start, end, zoom_percent):
    """Plot ``data`` in header-derived sky coordinates with the slit overlaid.

    Left panel: the full image. Right panel: the same image zoomed to a
    square around the slit midpoint. The figure is left open.

    Parameters
    ----------
    data : numpy.ndarray
        2-D image (rows = y, columns = x).
    header : mapping
        FITS header passed to ``pixel_to_heliocentric``.
    start, end : tuple of float
        Slit endpoints (x, y) in pixel coordinates of ``data``.
    zoom_percent : float
        Half-width of the zoom window as a fraction of the slit length,
        measured in the converted coordinates.

    Returns
    -------
    (vmin, vmax) : tuple
        Display range used here: 0 and 7x the image mean.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    vmin = 0
    vmax = data.mean() * 7

    n_rows, n_cols = data.shape[0], data.shape[1]
    # imshow's extent is the outer boundary of the image, half a pixel beyond
    # the first and last pixel centres. The pixel->sky conversion is linear,
    # so converting the two outer edges suffices; this avoids building two
    # image-sized coordinate grids just to take their min/max.
    x_edges, y_edges = pixel_to_heliocentric(np.array([-0.5, n_cols - 0.5]),
                                             np.array([-0.5, n_rows - 0.5]), header)
    # (left, right, bottom, top): edge order stays correct even for a
    # negative plate scale.
    extent = [x_edges[0], x_edges[1], y_edges[0], y_edges[1]]

    start_heliocentric = pixel_to_heliocentric(start[0], start[1], header)
    end_heliocentric = pixel_to_heliocentric(end[0], end[1], header)
    slit_x = [start_heliocentric[0], end_heliocentric[0]]
    slit_y = [start_heliocentric[1], end_heliocentric[1]]
    slit_style = dict(color='#1E90FF', lw=2, linestyle='--')

    # original
    axes[0].imshow(data, origin='lower', cmap='afmhot', vmin=vmin, vmax=vmax, extent=extent)
    axes[0].set_title('Original Image in Heliocentric Coordinates (arcsec)')
    axes[0].plot(slit_x, slit_y, **slit_style)

    # Zoom In: square window of half-width line_length * zoom_percent
    line_length = np.linalg.norm(np.array(start_heliocentric) - np.array(end_heliocentric))
    zoom_margin = line_length * zoom_percent
    center_x = (start_heliocentric[0] + end_heliocentric[0]) / 2
    center_y = (start_heliocentric[1] + end_heliocentric[1]) / 2

    axes[1].imshow(data, origin='lower', cmap='afmhot', vmin=vmin, vmax=vmax, extent=extent)
    axes[1].set_xlim(center_x - zoom_margin, center_x + zoom_margin)
    axes[1].set_ylim(center_y - zoom_margin, center_y + zoom_margin)
    axes[1].set_title('Zoom In on Slice Line (Heliocentric Coordinates)')
    axes[1].plot(slit_x, slit_y, **slit_style)

    plt.tight_layout()
    ##plt.show()

    return vmin, vmax


def extract_slice(data, start, end, delta_pix, center, angle, vmin, vmax):
    """Sample an 8-bit intensity slice along the slit in a rotated image.

    The slit endpoints, given in the un-rotated frame, are mapped into the
    rotated frame with ``rotate_point(..., -angle, center)``. ``int(length)``
    points are spaced uniformly along the mapped slit. At each point,
    ``delta_pix`` pixels across the slit are read (nearest pixel).
    Intensities are clipped to ``[vmin, vmax]`` and linearly rescaled to
    ``0..255``.

    Parameters
    ----------
    data : numpy.ndarray
        2-D rotated image (rows = y, columns = x).
    start, end : tuple of float
        Slit endpoints (x, y) in the un-rotated image.
    delta_pix : int
        Number of pixels sampled across the slit at each step.
    center : tuple of float
        Rotation centre (x, y) used when the image was rotated.
    angle : float
        Rotation angle in degrees that was applied to the image.
    vmin, vmax : float
        Intensity range mapped onto 0..255.

    Returns
    -------
    numpy.ndarray of uint8, shape (int(length), delta_pix)
        Row k holds the samples at the k-th point from ``start`` to ``end``.
        Samples that fall outside the image are 0.
    """
    start_x, start_y = rotate_point(start[0], start[1], -angle, center)
    end_x, end_y = rotate_point(end[0], end[1], -angle, center)
    dx = end_x - start_x
    dy = end_y - start_y
    length = np.hypot(dx, dy)

    # Positions along the slit (same count as before), as a column vector.
    t = np.linspace(0, 1, num=int(length))[:, None]

    # Unit vector across the slit. For a slit heading +y this is +x, matching
    # the old slope-based direction, but unlike the slope form (which forced
    # slope 0 when dy == 0) it is also correct for horizontal slits.
    if length > 0:
        perp_x, perp_y = dy / length, -dx / length
    else:
        perp_x, perp_y = 0.0, 0.0

    # Offsets across the slit, centred on the slit itself. The old
    # range(-delta_pix//2, delta_pix//2) floors -1//2 to -1, so odd widths were
    # shifted by one pixel (delta_pix=1 sampled beside the slit, not on it).
    offsets = (np.arange(delta_pix) - delta_pix // 2)[None, :]

    # Nearest-pixel indices, shape (n_points, delta_pix).
    cols = np.rint(start_x + t * dx + offsets * perp_x).astype(int)
    rows = np.rint(start_y + t * dy + offsets * perp_y).astype(int)

    # Out-of-image samples stay 0. Previously they were dropped, which made
    # the final reshape(-1, delta_pix) fail or misalign rows.
    inside = (rows >= 0) & (rows < data.shape[0]) & (cols >= 0) & (cols < data.shape[1])
    samples = np.zeros(cols.shape, dtype=np.uint8)

    # Only the sampled pixels are clipped/converted, not the whole image.
    values = np.clip(data[rows[inside], cols[inside]].astype(float), vmin, vmax)
    span = vmax - vmin
    if span > 0:
        # Rescale to 0..255. Casting the clipped values straight to uint8 gave
        # out-of-range (non-saturating) results whenever vmax exceeded 255.
        samples[inside] = ((values - vmin) * (255.0 / span)).astype(np.uint8)
    return samples

def save_and_show_slice(slice_data, output_dir, file_name, fits_path):
    """Preview ``slice_data`` in a new figure and save it as an image file.

    The file is written to ``<output_dir>/<fits stem>_<file_name>``, where
    the stem is the basename of ``fits_path`` without a trailing ``.fits``.
    ``output_dir`` is created if needed. The preview figure is left open.

    Parameters
    ----------
    slice_data : numpy.ndarray
        2-D slice image.
    output_dir : str
        Destination directory.
    file_name : str
        Suffix of the output file name, including its extension.
    fits_path : str
        Path of the source FITS file, used to prefix the output name.
    """
    # exist_ok avoids the exists()/makedirs() race of the previous check.
    os.makedirs(output_dir, exist_ok=True)

    fits_name = os.path.basename(fits_path)
    # Strip only a trailing ".fits"; str.replace also removed it mid-name.
    if fits_name.endswith('.fits'):
        fits_name = fits_name[:-len('.fits')]
    file_name = f"{fits_name}_{file_name}"

    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(slice_data, cmap='afmhot', origin='lower', aspect='auto')
    ax.set_title('Extracted Slice')
    fig.colorbar(im, ax=ax)

    plt.tight_layout()
    ##plt.show()

    output_path = os.path.join(output_dir, file_name)
    # imsave writes row 0 at the top; flipping matches the origin='lower' preview.
    plt.imsave(output_path, np.flipud(slice_data), cmap='afmhot')

def process_and_extract_slice(data, start_point, end_point, delta_pix, output_dir, file_path, vmin, vmax):
    """Extract the slit slice from an already-rotated image and save it.

    Parameters
    ----------
    data : numpy.ndarray
        Image already rotated by ``compute_rotation_angle(start_point,
        end_point)`` degrees (as done in ``process_directory``).
    start_point, end_point : tuple
        Slit endpoints (x, y) in the un-rotated image.
    delta_pix : int
        Number of pixels sampled across the slit.
    output_dir : str
        Directory the slice image is written to.
    file_path : str
        Source FITS path, used to prefix the output file name.
    vmin, vmax : float
        Intensity range forwarded to ``extract_slice``.
    """
    angle = compute_rotation_angle(start_point, end_point)
    # scipy.ndimage.rotate (reshape=False, used by rotate_image) rotates about
    # (n - 1) / 2 along each axis; shape / 2 shifted the mapped slit by half a
    # pixel in x and y.
    n_rows, n_cols = data.shape[0], data.shape[1]
    center = ((n_cols - 1) / 2, (n_rows - 1) / 2)

    slice_data = extract_slice(data, start_point, end_point, delta_pix, center, angle, vmin, vmax)

    file_name = f"slice_{start_point[0]}_{start_point[1]}_{end_point[0]}_{end_point[1]}.png"
    save_and_show_slice(slice_data, output_dir, file_name, file_path)

def process_directory(directory_path, start_point, end_point, delta_pix, output_dir, zoom_percent):
    """Run the plotting and slit-extraction pipeline on every ``.fits`` file in a directory.

    For each file: read it, print the endpoint-order check, rotate the image,
    draw the diagnostic figures, extract and save the slice, then close all
    figures.

    Parameters
    ----------
    directory_path : str
        Directory scanned (non-recursively) for names ending in ``.fits``.
    start_point, end_point : tuple
        Slit endpoints (x, y) in pixel coordinates of the images.
    delta_pix : int
        Number of pixels sampled across the slit.
    output_dir : str
        Directory the slice images are written to.
    zoom_percent : float
        Zoom half-width as a fraction of the slit length, used by both
        diagnostic plots (previously the first plot ignored it and used 0.7).
    """
    # os.listdir order is arbitrary; sort so frames are processed in name order.
    fits_files = sorted(f for f in os.listdir(directory_path) if f.endswith(".fits"))
    # The angle depends only on the slit, so compute it once.
    angle = compute_rotation_angle(start_point, end_point)

    # leave must be a bool: the string 'False' is truthy and kept the bar.
    for file_name in tqdm(fits_files, desc="Processing FITS files", leave=False, unit="file"):
        full_path = os.path.join(directory_path, file_name)
        print(f"Processing file: {full_path}")

        data, header = read_fits(full_path)
        print(compute_distance_and_comparison(start_point, end_point, data))

        rotated_data = rotate_image(data, angle)
        plot_images_and_slice(data, rotated_data, start_point, end_point, zoom_percent)
        vmin, vmax = plot_heliocentric_images(data, header, start_point, end_point, zoom_percent)

        process_and_extract_slice(rotated_data, start_point, end_point, delta_pix, output_dir, full_path, vmin, vmax)

        plt.close('all')
        # Release this frame's arrays before the next read so two full frames
        # are not held in memory at the same time.
        del data, rotated_data, header



# Slit endpoints (x, y) in pixel coordinates of the images returned by read_fits.
# Earlier trial slits, kept for reference:
#   start (700, 1667) -> end (0, 2120)
#   start (700, 1616) -> end (0, 2252)
start_point = (700, 1620)
end_point = (0, 2100)
delta_pix = 1  # number of pixels sampled across the slit
zoom_percent = 0.7  # zoom half-width as a fraction of the slit length
directory_path = "/data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007"
output_each_slice_dir = f"/data2/pqf_SDO_SolO/SDO/slice_from_each_img/test_{start_point[0]}_{start_point[1]}_{end_point[0]}_{end_point[1]}"
# exist_ok avoids the exists()/makedirs() race of a separate check.
os.makedirs(output_each_slice_dir, exist_ok=True)

process_directory(directory_path, start_point, end_point, delta_pix, output_each_slice_dir, zoom_percent)


# Saved file name: "<FITS file name without .fits>_slice_{start_point[0]}_{start_point[1]}_{end_point[0]}_{end_point[1]}.png"
# If no plots are shown, find every "##plt.show()" line and remove the leading "##".
# To keep all plots open, comment out the line "plt.close('all')" by prefixing it with "##".

# The vmin and vmax settings may need adjusting.


Processing FITS files:   0%|          | 0/302 [00:00<?, ?file/s]

Processing file: /data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007/aia.lev1_euv_12s.2023-10-02T125255Z.304.image_lev1.fits
Correct slit coord input. End point is farther from the image center than the start point.
Processing file: /data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007/aia.lev1_euv_12s.2023-10-02T124907Z.304.image_lev1.fits
Correct slit coord input. End point is farther from the image center than the start point.
Processing file: /data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007/aia.lev1_euv_12s.2023-10-02T124531Z.304.image_lev1.fits
Correct slit coord input. End point is farther from the image center than the start point.
Processing file: /data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007/aia.lev1_euv_12s.2023-10-02T125043Z.304.image_lev1.fits
Correct slit coord input. End point is farther from the image center than the start point.
Processing file: /data2/pqf_SDO_SolO/data_SDO/2023_10_02_120007/aia.lev1_euv_12s.2023-10-02T122607Z.304.image_lev1.fits
Correct slit coord input. End point 