In [9]:
import numpy as np
from astropy.io import fits
import os
from PIL import Image
import matplotlib.pyplot as plt

class AstroColorCombiner:
    """Combine per-filter FITS frames of one object into RGB / LRGB images.

    Frames are read from ``base_path`` using the naming scheme
    ``aligned_<object>_<filter>_stacked_sigma3.0.fits`` and cached in
    ``self.data`` keyed by filter name.
    """

    def __init__(self, base_path, object_name="NGC1365"):
        self.base_path = base_path
        self.object_name = object_name
        # filter name -> float64 pixel array, filled by load_all_filters()
        self.data = {}

    def load_fits_file(self, filter_name):
        """Load the primary-HDU image for ``filter_name`` as float64.

        Raises:
            FileNotFoundError: if the expected file does not exist.
        """
        filename = f"aligned_{self.object_name}_{filter_name}_stacked_sigma3.0.fits"
        filepath = os.path.join(self.base_path, filename)

        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")

        with fits.open(filepath) as hdul:
            data = hdul[0].data.astype(np.float64)
            print(f"Loaded {filter_name}: shape={data.shape}, min={data.min():.2f}, max={data.max():.2f}")

        return data

    def load_all_filters(self, filters=('r', 'g', 'i', 'clear')):
        """Load every filter in ``filters`` that exists on disk.

        Missing files are reported and skipped (best effort).

        Returns:
            The ``self.data`` dict (filter name -> array).
        """
        # The default is an immutable tuple so it cannot be mutated across calls.
        for f in filters:
            try:
                self.data[f] = self.load_fits_file(f)
            except FileNotFoundError:
                print(f"Warning: {f} filter not found, skipping...")
        print("\nAll available filters loaded successfully!")
        return self.data

    def _get_channel(self, filter_name):
        """Return the cached array for ``filter_name`` or raise a clear KeyError."""
        try:
            return self.data[filter_name]
        except KeyError:
            raise KeyError(
                f"Filter '{filter_name}' is not loaded; available: {sorted(self.data)}"
            ) from None

    def normalize_channel(self, data, lower_percentile=0.1, upper_percentile=99.9):
        """Rescale ``data`` to 0-1 between two percentiles, clipping the rest.

        Percentiles are computed over finite pixels only, so NaN/inf pixels
        do not poison the whole frame. NaN and -inf map to 0, +inf maps to 1.
        A frame with no finite pixels, or whose two percentiles coincide,
        yields all zeros instead of NaN from a division by zero.
        """
        data = np.asarray(data)
        finite = np.isfinite(data)
        if not finite.any():
            return np.zeros(data.shape, dtype=np.float64)

        lower, upper = np.percentile(data[finite], [lower_percentile, upper_percentile])
        span = upper - lower
        if span <= 0:
            # Flat image: there is no contrast to stretch.
            return np.zeros(data.shape, dtype=np.float64)

        normalized = (data - lower) / span
        normalized = np.nan_to_num(normalized, nan=0.0, posinf=1.0, neginf=0.0)
        return np.clip(normalized, 0, 1)

    def apply_color_balance(self, r, g, b, r_weight=1.0, g_weight=1.0, b_weight=1.0):
        """Scale each channel by its weight, then rescale all three jointly.

        If any weighted channel exceeds 1.0, all three are divided by the
        common maximum so relative colour ratios are preserved.
        """
        r_balanced = r * r_weight
        g_balanced = g * g_weight
        b_balanced = b * b_weight

        # Renormalize to keep within 0-1 range
        max_val = max(r_balanced.max(), g_balanced.max(), b_balanced.max())
        if max_val > 1.0:
            r_balanced = r_balanced / max_val
            g_balanced = g_balanced / max_val
            b_balanced = b_balanced / max_val

        return r_balanced, g_balanced, b_balanced

    def create_synthetic_luminance(self, r, g, b, method='average'):
        """Build a luminance plane from the colour channels.

        ``method`` is ``'average'``, ``'weighted'`` (0.299/0.587/0.114) or
        ``'maximum'``. Any other value falls back to the plain average.
        """
        if method == 'weighted':
            return 0.299 * r + 0.587 * g + 0.114 * b
        if method == 'maximum':
            return np.maximum(np.maximum(r, g), b)
        # 'average' and unrecognised methods
        return (r + g + b) / 3.0

    def _normalized_rgb(self, r_filter, g_filter, b_filter,
                        lower_percentile, upper_percentile):
        """Return the three requested filters, each percentile-normalised."""
        return tuple(
            self.normalize_channel(self._get_channel(name),
                                   lower_percentile, upper_percentile)
            for name in (r_filter, g_filter, b_filter)
        )

    def create_rgb_image(self, r_filter='r', g_filter='g', b_filter='i',
                        r_weight=1.0, g_weight=1.0, b_weight=1.0,
                        lower_percentile=0.1, upper_percentile=99.9):
        """Create an (H, W, 3) RGB image from three loaded filters.

        Raises:
            KeyError: if a requested filter has not been loaded.
        """
        print("\n=== Creating RGB Image ===")
        print(f"Using filters: R={r_filter}, G={g_filter}, B={b_filter}")

        r_norm, g_norm, b_norm = self._normalized_rgb(
            r_filter, g_filter, b_filter, lower_percentile, upper_percentile)

        r_bal, g_bal, b_bal = self.apply_color_balance(r_norm, g_norm, b_norm,
                                                        r_weight, g_weight, b_weight)

        rgb_image = np.dstack((r_bal, g_bal, b_bal))

        print(f"RGB image created: shape={rgb_image.shape}")
        return rgb_image

    def create_lrgb_image(self, r_filter='r', g_filter='g', b_filter='i',
                         luminance_filter='clear', use_synthetic_lum=False,
                         r_weight=1.0, g_weight=1.0, b_weight=1.0,
                         luminance_method='weighted', lum_weight=0.5,
                         lower_percentile=0.1, upper_percentile=99.9):
        """Create an (H, W, 3) LRGB image.

        The luminance plane is the normalised ``luminance_filter`` frame.
        It is synthesised from the colour channels instead when
        ``use_synthetic_lum`` is true or that filter is not loaded. Each
        output channel is ``lum_weight * L + (1 - lum_weight) * colour``.

        Raises:
            KeyError: if a requested colour filter has not been loaded.
        """
        print("\n=== Creating LRGB Image ===")
        print(f"Using filters: R={r_filter}, G={g_filter}, B={b_filter}")

        r_norm, g_norm, b_norm = self._normalized_rgb(
            r_filter, g_filter, b_filter, lower_percentile, upper_percentile)

        if use_synthetic_lum or luminance_filter not in self.data:
            luminance = self.create_synthetic_luminance(r_norm, g_norm, b_norm, luminance_method)
            print(f"Using synthetic luminance (method: {luminance_method})")
        else:
            luminance = self.normalize_channel(self.data[luminance_filter], lower_percentile, upper_percentile)
            print(f"Using {luminance_filter} filter as luminance")

        r_bal, g_bal, b_bal = self.apply_color_balance(r_norm, g_norm, b_norm,
                                                        r_weight, g_weight, b_weight)

        # Linear blend of luminance (detail) and colour, clipped to 0-1.
        blended = [
            np.clip(luminance * lum_weight + channel * (1 - lum_weight), 0, 1)
            for channel in (r_bal, g_bal, b_bal)
        ]
        lrgb_image = np.dstack(blended)

        print(f"LRGB image created: shape={lrgb_image.shape}")
        return lrgb_image

    def save_tiff(self, image, output_path, bits=16):
        """Save a 0-1 float RGB image as a TIFF.

        ``bits == 16`` writes a 16-bit TIFF via tifffile. Any other value
        writes 8-bit via PIL. Values are clipped to [0, 1] (NaN -> 0) first
        so out-of-range pixels saturate instead of wrapping around in the
        unsigned-integer cast.
        """
        from tifffile import imwrite

        image = np.clip(np.nan_to_num(np.asarray(image, dtype=np.float64), nan=0.0), 0, 1)

        if bits == 16:
            image_scaled = (image * 65535).astype(np.uint16)
            # tifffile handles multi-channel 16-bit data properly
            imwrite(output_path, image_scaled, photometric='rgb')
        else:
            image_scaled = (image * 255).astype(np.uint8)
            # PIL can handle 8-bit RGB fine
            img = Image.fromarray(image_scaled)
            img.save(output_path)

        print(f"Saved: {output_path}")

    def display_images(self, rgb_image, lrgb_image, output_path):
        """Show RGB and LRGB side by side and save the figure to ``output_path``."""
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))

        axes[0].imshow(rgb_image, origin='lower')
        axes[0].set_title('RGB Combination', fontsize=14, fontweight='bold')
        axes[0].axis('off')

        axes[1].imshow(lrgb_image, origin='lower')
        axes[1].set_title('LRGB Combination', fontsize=14, fontweight='bold')
        axes[1].axis('off')

        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.show()


def main():
    """Run the RGB and LRGB combinations for one object and write the results.

    Loads the aligned stacks, builds both colour images with the same
    channel settings, saves each as a 16-bit TIFF and writes a side-by-side
    comparison figure.
    """
    # Configuration
    input_path = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/second_alignment"
    output_path = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined"
    object_name = "NGC1365"

    os.makedirs(output_path, exist_ok=True)
    print(f"Output directory: {output_path}\n")

    combiner = AstroColorCombiner(input_path, object_name)
    combiner.load_all_filters()

    # Settings shared by both combinations; adjust the weights to tune
    # the colour balance.
    channel_setup = dict(
        r_filter='r',
        g_filter='g',
        b_filter='i',   # i-band frame drives the blue channel
        r_weight=1.0,
        g_weight=1.0,
        b_weight=1.0,
        lower_percentile=0.1,
        upper_percentile=99.9,
    )

    # ====== RGB ======
    rgb_image = combiner.create_rgb_image(**channel_setup)
    rgb_output = os.path.join(output_path, f'{object_name}_RGB_16bit.tif')
    combiner.save_tiff(rgb_image, rgb_output, bits=16)

    # ====== LRGB (clear filter as luminance) ======
    lrgb_image = combiner.create_lrgb_image(
        luminance_filter='clear',
        use_synthetic_lum=False,   # True -> derive luminance from the colour channels
        lum_weight=0.5,            # 0 = pure colour, 1 = pure luminance
        **channel_setup,
    )
    lrgb_output = os.path.join(output_path, f'{object_name}_LRGB_16bit.tif')
    combiner.save_tiff(lrgb_image, lrgb_output, bits=16)

    # Side-by-side comparison, saved next to the TIFFs
    comparison_path = os.path.join(output_path, f'{object_name}_comparison.png')
    combiner.display_images(rgb_image, lrgb_image, comparison_path)

    rule = "="*60
    print("\n" + rule)
    print("Processing Complete!")
    print(rule)
    print(f"\nOutput files created in: {output_path}")
    print(f"  - {object_name}_RGB_16bit.tif")
    print(f"  - {object_name}_LRGB_16bit.tif")
    print(f"  - {object_name}_comparison.png")
    print("\nTip: Adjust the weight parameters to fine-tune color balance!")


# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()

Output directory: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined

Loaded r: shape=(2048, 2048), min=0.00, max=1.00
Loaded g: shape=(2048, 2048), min=0.00, max=1.02
Loaded i: shape=(2048, 2048), min=-0.00, max=0.98
Loaded clear: shape=(2048, 2048), min=-0.00, max=1.01

All available filters loaded successfully!

=== Creating RGB Image ===
Using filters: R=r, G=g, B=i
RGB image created: shape=(2048, 2048, 3)
Saved: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined/NGC1365_RGB_16bit.tif

=== Creating LRGB Image ===
Using filters: R=r, G=g, B=i
Using clear filter as luminance
LRGB image created: shape=(2048, 2048, 3)
Saved: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined/NGC1365_LRGB_16bit.tif

Processing Complete!

Output files created in: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama

In [5]:
import numpy as np
from astropy.io import fits
import os
from tifffile import imwrite
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
from PIL import Image
import threading

class InteractiveColorBalancer:
    """Hold the state and image maths behind the interactive balance tool.

    Stretched frames are read from ``<base_path>/<stretch>/`` and cached in
    ``self.data`` keyed by filter name. The weights, percentiles, mode and
    stretch are plain attributes that the GUI callbacks overwrite.
    """

    # Directory names whose file-name prefix differs from the directory name.
    _STRETCH_PREFIXES = {'logarithmic': 'log'}

    def __init__(self, base_path, object_name="NGC1365"):
        self.base_path = base_path
        self.object_name = object_name
        self.data = {}
        self.current_stretch = 'asinh'
        self.current_mode = 'RGB'

        # Default parameters
        self.r_weight = 1.0
        self.g_weight = 1.0
        self.b_weight = 1.0
        self.lum_weight = 0.5
        self.lower_percentile = 0.1
        self.upper_percentile = 99.9

        # Control flag (cleared by the GUI when its window closes)
        self.keep_running = True

    def load_fits_file(self, filter_name, stretch_method):
        """Load one stretched frame as float64.

        The file is ``stretched_<prefix>_<object>_<filter>.fits`` inside the
        ``stretch_method`` sub-directory. The prefix equals the method name
        except for ``'logarithmic'``, which uses ``'log'``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        stretch_prefix = self._STRETCH_PREFIXES.get(stretch_method, stretch_method)

        filename = f"stretched_{stretch_prefix}_{self.object_name}_{filter_name}.fits"
        filepath = os.path.join(self.base_path, stretch_method, filename)

        if not os.path.exists(filepath):
            raise FileNotFoundError(f"File not found: {filepath}")

        with fits.open(filepath) as hdul:
            data = hdul[0].data.astype(np.float64)

        return data

    def load_stretch_method(self, stretch_method):
        """Load all filters for ``stretch_method``, replacing the cache.

        Missing files are reported and skipped. The cache is rebuilt from
        scratch so a frame from a previously loaded stretch can never be
        mixed into the new one.
        """
        print(f"Loading {stretch_method} stretch...")
        self.current_stretch = stretch_method
        filters = ['r', 'g', 'i', 'clear']

        loaded = {}
        for f in filters:
            try:
                loaded[f] = self.load_fits_file(f, stretch_method)
            except FileNotFoundError:
                print(f"Warning: {f} filter not found")
        self.data = loaded

        print(f"Loaded {stretch_method} stretch successfully!")

    def normalize_channel(self, data, lower_percentile, upper_percentile):
        """Rescale ``data`` to 0-1 between two percentiles, clipping the rest.

        Percentiles are computed over finite pixels only. NaN and -inf map
        to 0, +inf maps to 1. A frame with no finite pixels, or whose two
        percentiles coincide, yields all zeros instead of NaN.
        """
        data = np.asarray(data)
        finite = np.isfinite(data)
        if not finite.any():
            return np.zeros(data.shape, dtype=np.float64)

        lower, upper = np.percentile(data[finite], [lower_percentile, upper_percentile])
        span = upper - lower
        if span <= 0:
            # Flat image: there is no contrast to stretch.
            return np.zeros(data.shape, dtype=np.float64)

        normalized = (data - lower) / span
        normalized = np.nan_to_num(normalized, nan=0.0, posinf=1.0, neginf=0.0)
        return np.clip(normalized, 0, 1)

    def apply_color_balance(self, r, g, b, r_weight, g_weight, b_weight):
        """Scale each channel by its weight, then rescale all three jointly.

        If any weighted channel exceeds 1.0, all three are divided by the
        common maximum so relative colour ratios are preserved.
        """
        r_balanced = r * r_weight
        g_balanced = g * g_weight
        b_balanced = b * b_weight

        # Renormalize to keep within 0-1 range
        max_val = max(r_balanced.max(), g_balanced.max(), b_balanced.max())
        if max_val > 1.0:
            r_balanced = r_balanced / max_val
            g_balanced = g_balanced / max_val
            b_balanced = b_balanced / max_val

        return r_balanced, g_balanced, b_balanced

    def _normalized(self, filter_name):
        """Return one cached filter normalised with the current percentiles."""
        try:
            frame = self.data[filter_name]
        except KeyError:
            raise KeyError(
                f"Filter '{filter_name}' is not loaded for stretch "
                f"'{self.current_stretch}'; available: {sorted(self.data)}"
            ) from None
        return self.normalize_channel(frame, self.lower_percentile, self.upper_percentile)

    def _balanced_rgb(self):
        """Return the r/g/i frames normalised and colour-balanced (R, G, B order)."""
        return self.apply_color_balance(
            self._normalized('r'), self._normalized('g'), self._normalized('i'),
            self.r_weight, self.g_weight, self.b_weight)

    def create_rgb_image(self):
        """Create the (H, W, 3) RGB image for the current settings.

        Raises:
            KeyError: if the r, g or i frame is not loaded.
        """
        return np.dstack(self._balanced_rgb())

    def create_lrgb_image(self):
        """Create the (H, W, 3) LRGB image for the current settings.

        Each channel is ``lum_weight * L + (1 - lum_weight) * colour``,
        clipped to 0-1, with L taken from the ``'clear'`` frame.

        Raises:
            KeyError: if the r, g, i or clear frame is not loaded.
        """
        luminance = self._normalized('clear')
        blended = [
            np.clip(luminance * self.lum_weight + channel * (1 - self.lum_weight), 0, 1)
            for channel in self._balanced_rgb()
        ]
        return np.dstack(blended)

    def get_current_image(self):
        """Return the RGB image in ``'RGB'`` mode, otherwise the LRGB image."""
        if self.current_mode == 'RGB':
            return self.create_rgb_image()
        return self.create_lrgb_image()

    def save_current_image(self, output_path):
        """Write the current image as a 16-bit TIFF and an 8-bit PNG preview.

        The file name encodes object, mode, stretch and the active weights.

        Returns:
            ``(tiff_path, png_path)``.
        """
        image = self.get_current_image()

        # Create output directory
        os.makedirs(output_path, exist_ok=True)

        # Generate filename
        filename = f"{self.object_name}_{self.current_mode}_{self.current_stretch}_"
        filename += f"R{self.r_weight:.2f}_G{self.g_weight:.2f}_B{self.b_weight:.2f}"
        if self.current_mode == 'LRGB':
            filename += f"_L{self.lum_weight:.2f}"

        # Save as 16-bit TIFF
        tiff_path = os.path.join(output_path, filename + ".tif")
        image_16bit = (image * 65535).astype(np.uint16)
        imwrite(tiff_path, image_16bit, photometric='rgb')

        # Save as PNG preview
        png_path = os.path.join(output_path, filename + ".png")
        image_8bit = (image * 255).astype(np.uint8)
        img = Image.fromarray(image_8bit)
        img.save(png_path)

        print(f"\n✓ Saved images:")
        print(f"  TIFF: {tiff_path}")
        print(f"  PNG:  {png_path}")

        return tiff_path, png_path


def launch_interactive_tool(base_path, output_path, object_name="NGC1365", initial_stretch='asinh'):
    """Launch interactive color balance tool with sliders

    Builds a matplotlib window with the colour image, four sliders
    (R/G/B weights and luminance weight), radio buttons for mode and
    stretch method, and Reset/Save buttons, then blocks in ``plt.show``
    until the window is closed.

    Args:
        base_path: directory passed to ``InteractiveColorBalancer``.
        output_path: directory the Save button writes into.
        object_name: object name passed to the balancer and shown in titles.
        initial_stretch: stretch loaded first; must be one of 'asinh',
            'hyperbolic' or 'logarithmic' (the radio-button lookup below
            raises ValueError otherwise).

    Returns:
        The ``InteractiveColorBalancer`` holding the final settings.
    """
    
    # Initialize balancer
    balancer = InteractiveColorBalancer(base_path, object_name)
    balancer.load_stretch_method(initial_stretch)
    
    # Use TkAgg backend for better interactivity
    plt.switch_backend('TkAgg')
    
    # Create figure and axes
    fig = plt.figure(figsize=(14, 10))
    fig.canvas.manager.set_window_title('Interactive Color Balance Tool - Keep this window open!')
    
    # Main image display
    ax_image = plt.axes([0.1, 0.35, 0.8, 0.6])
    
    # Initial image
    initial_image = balancer.get_current_image()
    # `im` is kept so update() can swap the pixel data in place
    im = ax_image.imshow(initial_image, origin='lower')
    ax_image.set_title(f'{object_name} - {balancer.current_mode} ({balancer.current_stretch})', 
                       fontsize=14, fontweight='bold')
    ax_image.axis('off')
    
    # Slider positions
    slider_color = 'lightgoldenrodyellow'
    
    # Red weight slider
    ax_r = plt.axes([0.15, 0.25, 0.7, 0.02], facecolor=slider_color)
    slider_r = Slider(ax_r, 'Red', 0.1, 2.0, valinit=1.0, valstep=0.05, color='red')
    
    # Green weight slider
    ax_g = plt.axes([0.15, 0.20, 0.7, 0.02], facecolor=slider_color)
    slider_g = Slider(ax_g, 'Green', 0.1, 2.0, valinit=1.0, valstep=0.05, color='green')
    
    # Blue weight slider
    ax_b = plt.axes([0.15, 0.15, 0.7, 0.02], facecolor=slider_color)
    slider_b = Slider(ax_b, 'Blue', 0.1, 2.0, valinit=1.0, valstep=0.05, color='blue')
    
    # Luminance weight slider (for LRGB mode)
    ax_lum = plt.axes([0.15, 0.10, 0.7, 0.02], facecolor=slider_color)
    slider_lum = Slider(ax_lum, 'Luminance', 0.0, 1.0, valinit=0.5, valstep=0.05, color='gray')
    
    # Mode radio buttons (RGB / LRGB)
    ax_mode = plt.axes([0.02, 0.7, 0.08, 0.15], facecolor=slider_color)
    radio_mode = RadioButtons(ax_mode, ('RGB', 'LRGB'))
    
    # Stretch method radio buttons
    ax_stretch = plt.axes([0.02, 0.45, 0.08, 0.2], facecolor=slider_color)
    radio_stretch = RadioButtons(ax_stretch, ('asinh', 'hyperbolic', 'logarithmic'))
    # Pre-select the initial stretch; done before callbacks are connected,
    # so this does not trigger a reload.
    radio_stretch.set_active(['asinh', 'hyperbolic', 'logarithmic'].index(initial_stretch))
    
    # Reset button
    ax_reset = plt.axes([0.02, 0.35, 0.08, 0.04])
    btn_reset = Button(ax_reset, 'Reset', color=slider_color, hovercolor='0.975')
    
    # Save button
    ax_save = plt.axes([0.02, 0.30, 0.08, 0.04])
    btn_save = Button(ax_save, 'Save', color='lightgreen', hovercolor='0.975')
    
    def update(val=None):
        """Update image when sliders change"""
        # Copy the slider positions into the balancer before recomputing
        balancer.r_weight = slider_r.val
        balancer.g_weight = slider_g.val
        balancer.b_weight = slider_b.val
        balancer.lum_weight = slider_lum.val
        
        # Get updated image
        updated_image = balancer.get_current_image()
        im.set_data(updated_image)
        
        # Update title
        ax_image.set_title(f'{object_name} - {balancer.current_mode} ({balancer.current_stretch}) | '
                          f'R:{balancer.r_weight:.2f} G:{balancer.g_weight:.2f} B:{balancer.b_weight:.2f}' +
                          (f' L:{balancer.lum_weight:.2f}' if balancer.current_mode == 'LRGB' else ''),
                          fontsize=12, fontweight='bold')
        
        fig.canvas.draw_idle()
    
    def change_mode(label):
        """Change between RGB and LRGB mode"""
        balancer.current_mode = label
        
        # Show/hide luminance slider based on mode
        slider_lum.ax.set_visible(label == 'LRGB')
        
        update()
    
    def change_stretch(label):
        """Change stretch method"""
        # Reloads the FITS frames from disk for the selected stretch
        balancer.load_stretch_method(label)
        update()
    
    def reset(event):
        """Reset all sliders to default values"""
        # NOTE(review): each Slider.reset() presumably fires update() via
        # on_changed already; the explicit call below guarantees one redraw.
        slider_r.reset()
        slider_g.reset()
        slider_b.reset()
        slider_lum.reset()
        update()
    
    def save(event):
        """Save current image"""
        balancer.save_current_image(output_path)
    
    def on_close(event):
        """Handle window close event"""
        balancer.keep_running = False
        print("\n✓ Interactive session ended.")
    
    # Connect callbacks
    slider_r.on_changed(update)
    slider_g.on_changed(update)
    slider_b.on_changed(update)
    slider_lum.on_changed(update)
    radio_mode.on_clicked(change_mode)
    radio_stretch.on_clicked(change_stretch)
    btn_reset.on_clicked(reset)
    btn_save.on_clicked(save)
    fig.canvas.mpl_connect('close_event', on_close)
    
    # Initial visibility of luminance slider
    slider_lum.ax.set_visible(balancer.current_mode == 'LRGB')
    
    plt.suptitle('Interactive Color Balance Tool - Close window when done', 
                 fontsize=16, fontweight='bold', y=0.98)
    
    # Add instructions
    instructions = ("Instructions:\n"
                   "• Adjust sliders to change color balance\n"
                   "• Switch between RGB/LRGB modes\n"
                   "• Try different stretch methods\n"
                   "• Click 'Save' to export current view\n"
                   "• Click 'Reset' to restore defaults\n"
                   "• Close window when finished")
    fig.text(0.02, 0.12, instructions, fontsize=9, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # Show the plot - this will block until window is closed
    print("\n" + "="*70)
    print("Interactive window is now open!")
    print("Adjust the sliders and settings in the window.")
    print("Close the window when you're done.")
    print("="*70 + "\n")
    
    plt.show(block=True)
    
    return balancer


# Main execution
if __name__ == "__main__":
    # Script entry point: open the interactive tool on the hard-coded
    # workspace and report where saved images end up.
    base_path = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/second_alignment"
    output_path = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined/interactive"
    object_name = "NGC1365"

    banner = "="*70
    print(banner)
    print("Interactive Color Balance Tool for Astronomy Images")
    print(banner)
    print("\nInitializing...")

    try:
        balancer = launch_interactive_tool(base_path, output_path, object_name, initial_stretch='asinh')

        print("\n" + banner)
        print("Session complete!")
        print(banner)
        print(f"\nSaved images are in: {output_path}")

    except Exception as e:
        # Top-level boundary: report the failure with hints instead of a traceback.
        print(f"\n❌ Error: {str(e)}")
        print("\nTroubleshooting:")
        for hint in (
            "1. Make sure you have a display available (not running headless)",
            "2. Try: pip install PyQt5",
            "3. Check that your file paths are correct",
        ):
            print(hint)

Interactive Color Balance Tool for Astronomy Images

Initializing...
Loading asinh stretch...
Loaded asinh stretch successfully!

Interactive window is now open!
Adjust the sliders and settings in the window.
Close the window when you're done.


✓ Interactive session ended.

Session complete!

Saved images are in: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/colour_combined/interactive


In [5]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
# Silence every warning for the rest of the session. This also hides NumPy
# divide-by-zero / invalid-value RuntimeWarnings raised by the colour maths.
warnings.filterwarnings('ignore')

# Configuration
# Input root: load_aligned_image() reads <BASE_PATH>/<method>/aligned_<method>_<object>_<filter>.fits
BASE_PATH = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/STACKED/aligned"
# Output root for the combined images (presumably used by driver code further down the file)
OUTPUT_PATH = "/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/STACKED/color_combined"

# Object name embedded in every input file name
OBJECT_NAME = "NGC1365"

# All stretch methods already processed in your pipeline
# (each name is both a sub-directory of BASE_PATH and part of the file name)
STRETCH_METHODS = ['normalized', 'log', 'hyperbolic', 'asinh', 'gamma', 'adaptive_asinh']

# Filter mapping for RGB: output channel -> filter name used in file names
FILTER_MAPPING = {
    'R': 'i',      # Red channel
    'G': 'r',      # Green channel  
    'B': 'g',      # Blue channel
    'L': 'clear'   # Luminance (for LRGB)
}

# Simple color combination parameters
# (read by create_rgb_image / create_lrgb_image via their `params` argument)
COLOR_PARAMS = {
    'saturation_boost': 1.2,       # Saturation enhancement (1.0 = no change)
    'lrgb_luminance_weight': 0.60, # Weight for luminance in LRGB
    'lrgb_color_weight': 0.40,     # Weight for color in LRGB
    'apply_clip': True,            # Clip to 0-1 range for display
}

# Display parameters
DISPLAY_PARAMS = {
    'comparison_figure_size': (24, 12),   # figsize of the comparison figure
    'individual_figure_size': (12, 12),   # not referenced in the visible code
    'zoom_size': 600,                     # half-width in pixels of the centred zoom crop
}


def load_aligned_image(filter_name, method):
    """Read one aligned, pre-stretched FITS frame.

    The file is ``<BASE_PATH>/<method>/aligned_<method>_<OBJECT_NAME>_<filter_name>.fits``.
    Non-finite pixels (NaN, +inf, -inf) are replaced with 0 and a one-line
    statistics summary is printed.

    Returns:
        ``(data, header)``: float64 pixel array and the primary header.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    fits_path = Path(BASE_PATH) / method / f"aligned_{method}_{OBJECT_NAME}_{filter_name}.fits"
    if not fits_path.exists():
        raise FileNotFoundError(f"File not found: {fits_path}")

    with fits.open(fits_path) as hdul:
        primary = hdul[0]
        pixels = primary.data.astype(np.float64)
        header = primary.header

    # Neutralise NaN/inf so later arithmetic and statistics stay finite
    pixels = np.nan_to_num(pixels, nan=0.0, posinf=0.0, neginf=0.0)

    lowest = np.min(pixels)
    highest = np.max(pixels)
    print(f"  {filter_name}: shape={pixels.shape}, "
          f"range=[{lowest:.6f}, {highest:.6f}], "
          f"mean={np.mean(pixels):.6f}, median={np.median(pixels):.6f}")

    return pixels, header


def rgb_to_hsv(rgb):
    """Convert an RGB array to HSV, following the ``colorsys`` convention.

    Args:
        rgb: array-like whose last axis holds (R, G, B). Integer input is
            promoted to float64; float input keeps its dtype.

    Returns:
        Array of the same shape with (H, S, V) on the last axis. H is in
        [0, 1). Achromatic pixels (R == G == B) get H = 0 and S = 0.
    """
    rgb = np.asarray(rgb)
    if not np.issubdtype(rgb.dtype, np.floating):
        rgb = rgb.astype(np.float64)
    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]

    maxc = np.maximum(np.maximum(r, g), b)
    minc = np.minimum(np.minimum(r, g), b)
    v = maxc
    deltac = maxc - minc

    has_value = maxc != 0
    has_color = deltac != 0

    # Divide by a stand-in of 1 where the real denominator is zero. Those
    # results are discarded by the masks, and no divide-by-zero warning is
    # raised.
    safe_max = np.where(has_value, maxc, np.ones_like(maxc))
    safe_delta = np.where(has_color, deltac, np.ones_like(deltac))

    s = np.where(has_value, deltac / safe_max, 0.0)

    rc = (maxc - r) / safe_delta
    gc = (maxc - g) / safe_delta
    bc = (maxc - b) / safe_delta

    # Later assignments win on ties; tied channels give the same hue anyway.
    h = np.where(r == maxc, bc - gc, 0.0)
    h = np.where(g == maxc, 2.0 + rc - bc, h)
    h = np.where(b == maxc, 4.0 + gc - rc, h)
    # Grey pixels have no defined hue; report 0 as colorsys does.
    h = np.where(has_color, (h / 6.0) % 1.0, 0.0)

    return np.stack([h, s, v], axis=-1)


def hsv_to_rgb(hsv):
    """Convert an HSV array (last axis = H, S, V) back to RGB.

    The hue circle is split into six sectors. Within each sector one
    channel equals V, one equals the floor ``V * (1 - S)``, and the third
    ramps linearly between them. Returns a float64 array.
    """
    hue, sat, val = hsv[..., 0], hsv[..., 1], hsv[..., 2]

    scaled = hue * 6.0
    sector = scaled.astype(int)
    frac = scaled - sector

    floor = val * (1.0 - sat)
    falling = val * (1.0 - sat * frac)
    rising = val * (1.0 - sat * (1.0 - frac))

    # Wrap the sector index into 0-5 (hue 1.0 lands back in sector 0)
    sector = sector % 6

    # One lookup table per output channel, indexed by sector
    red = np.choose(sector, [val, falling, floor, floor, rising, val])
    green = np.choose(sector, [rising, val, val, falling, floor, floor])
    blue = np.choose(sector, [floor, floor, rising, val, val, falling])

    return np.stack([red, green, blue], axis=-1).astype(np.float64)


def boost_saturation(rgb, factor=1.2):
    """Multiply the HSV saturation of ``rgb`` by ``factor``, capped to 0-1.

    A factor of exactly 1.0 returns the input object untouched.
    """
    if factor != 1.0:
        hsv = rgb_to_hsv(rgb)
        scaled_saturation = hsv[..., 1] * factor
        hsv[..., 1] = np.clip(scaled_saturation, 0, 1)
        return hsv_to_rgb(hsv)

    return rgb


def create_rgb_image(r_data, g_data, b_data, params):
    """Stack three already-processed planes into an (..., 3) RGB array.

    ``params['saturation_boost']`` (applied when != 1.0) and
    ``params['apply_clip']`` (clip to 0-1) control the optional steps.
    Channel and result statistics are printed along the way.
    """
    print("\n  Creating RGB image (simple channel stacking)...")

    for label, plane in (("R", r_data), ("G", g_data), ("B", b_data)):
        print(f"    {label} channel: mean={np.mean(plane):.6f}, max={np.max(plane):.6f}")

    # The inputs are assumed to be fully processed; just stack them
    composite = np.stack([r_data, g_data, b_data], axis=-1)

    boost = params['saturation_boost']
    if boost != 1.0:
        composite = boost_saturation(composite, boost)
        print(f"    Applied saturation boost: {boost}")

    if params['apply_clip']:
        composite = np.clip(composite, 0, 1)

    print(f"    Final RGB: mean={np.mean(composite):.6f}, max={np.max(composite):.6f}, min={np.min(composite):.6f}")

    return composite


def create_lrgb_image(l_data, r_data, g_data, b_data, params):
    """Combine a luminance plane with three colour planes into LRGB.

    The colour image is divided by its per-pixel channel mean (floored at
    1e-10) and multiplied by the luminance. That result is blended with
    the original colour using ``params['lrgb_luminance_weight']`` and
    ``params['lrgb_color_weight']``. The saturation boost and final clip
    follow ``params['saturation_boost']`` / ``params['apply_clip']``.
    """
    print("\n  Creating LRGB image...")

    print(f"    L channel: mean={np.mean(l_data):.6f}, max={np.max(l_data):.6f}")

    color = np.stack([r_data, g_data, b_data], axis=-1)

    boost = params['saturation_boost']
    if boost != 1.0:
        color = boost_saturation(color, boost)

    # Per-pixel brightness of the colour image, floored to avoid dividing by zero
    brightness = np.maximum(np.mean(color, axis=-1, keepdims=True), 1e-10)

    # Unit-mean colour ratios carrying the luminance detail
    detail = (color / brightness) * l_data[..., np.newaxis]

    lum_weight = params['lrgb_luminance_weight']
    color_weight = params['lrgb_color_weight']
    combined = lum_weight * detail + color_weight * color

    if params['apply_clip']:
        combined = np.clip(combined, 0, 1)

    print(f"    Final LRGB: mean={np.mean(combined):.6f}, max={np.max(combined):.6f}, min={np.min(combined):.6f}")
    print(f"    Weights used: Luminance={lum_weight}, Color={color_weight}")

    return combined


def save_color_image(image, filename, header=None):
    """Write an RGB image to a multi-extension FITS file.

    Red goes into the primary HDU (with ``header`` if given), green and
    blue into two image extensions. Each plane is stored as float32 and
    tagged with a ``CHANNEL`` keyword. An existing file is overwritten.
    """
    hdu_list = fits.HDUList()

    red_hdu = fits.PrimaryHDU(image[..., 0].astype(np.float32), header=header)
    red_hdu.header['CHANNEL'] = 'RED'
    hdu_list.append(red_hdu)

    for plane_index, channel_name in ((1, 'GREEN'), (2, 'BLUE')):
        extension = fits.ImageHDU(image[..., plane_index].astype(np.float32))
        extension.header['CHANNEL'] = channel_name
        hdu_list.append(extension)

    hdu_list.writeto(filename, overwrite=True)


def save_png(image, filename, dpi=300):
    """Render ``image`` without axes or padding and save it to ``filename``.

    The image is drawn on a 12 x 12 inch figure with the origin at the
    lower-left corner, then written at the requested ``dpi``.
    """
    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(1, 1, 1)

    ax.imshow(image, origin='lower')
    ax.axis('off')

    # Strip all margins so the saved file contains only the image.
    fig.tight_layout(pad=0)
    fig.savefig(filename, dpi=dpi, bbox_inches='tight', pad_inches=0)
    plt.close(fig)


def create_comparison_plot(rgb_image, lrgb_image, method, params, output_path):
    """Create a six-panel RGB vs LRGB comparison figure and save it as a PNG.

    Layout (2 x 3 grid): full RGB, full LRGB and the channel-averaged
    absolute difference on the top row; centre crops of RGB and LRGB plus a
    statistics text panel on the bottom row. The figure is written to
    ``comparison_{method}_{OBJECT_NAME}.png`` inside ``output_path``.

    Args:
        rgb_image: RGB composite array; presumably (height, width, 3) - the
            last axis is averaged for the difference map.
        lrgb_image: LRGB composite array; it is subtracted element-wise from
            ``rgb_image``, so the two must be shape-compatible.
        method: Stretch method name used in titles and in the file name.
        params: Mapping providing 'saturation_boost',
            'lrgb_luminance_weight' and 'lrgb_color_weight', which are
            echoed in the statistics panel.
        output_path: Directory in which the PNG is written.
    """
    fig = plt.figure(figsize=DISPLAY_PARAMS['comparison_figure_size'])
    gs = fig.add_gridspec(2, 3, hspace=0.15, wspace=0.15)

    # Full RGB
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(rgb_image, origin='lower')
    ax1.set_title('RGB (i-r-g)', fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Full LRGB
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(lrgb_image, origin='lower')
    ax2.set_title('LRGB (clear+i-r-g)', fontsize=14, fontweight='bold')
    ax2.axis('off')

    # Difference (absolute, averaged over the colour channels)
    ax3 = fig.add_subplot(gs[0, 2])
    diff = np.abs(lrgb_image - rgb_image)
    diff_display = np.mean(diff, axis=-1)
    im3 = ax3.imshow(diff_display, origin='lower', cmap='hot')
    ax3.set_title('Difference |LRGB - RGB|', fontsize=14, fontweight='bold')
    ax3.axis('off')
    plt.colorbar(im3, ax=ax3, fraction=0.046, label='Absolute difference')

    # Zoomed comparisons (center of galaxy)
    cy, cx = rgb_image.shape[0] // 2, rgb_image.shape[1] // 2
    zoom = DISPLAY_PARAMS['zoom_size']

    # Clamp the lower bounds at 0: a negative slice start would wrap around
    # and select pixels from the far edge (or nothing) instead of the centre
    # whenever zoom exceeds half the image size. Upper bounds past the end
    # are already truncated by slicing.
    y0, y1 = max(cy - zoom, 0), cy + zoom
    x0, x1 = max(cx - zoom, 0), cx + zoom

    ax4 = fig.add_subplot(gs[1, 0])
    ax4.imshow(rgb_image[y0:y1, x0:x1], origin='lower')
    ax4.set_title('RGB (zoomed center)', fontsize=12, fontweight='bold')
    ax4.axis('off')

    ax5 = fig.add_subplot(gs[1, 1])
    ax5.imshow(lrgb_image[y0:y1, x0:x1], origin='lower')
    ax5.set_title('LRGB (zoomed center)', fontsize=12, fontweight='bold')
    ax5.axis('off')

    # Statistics
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.axis('off')

    stats_text = f"""STRETCH: {method.upper()}
{'='*35}

RGB Image:
  Mean: {np.mean(rgb_image):.6f}
  Std:  {np.std(rgb_image):.6f}
  Max:  {np.max(rgb_image):.6f}
  Min:  {np.min(rgb_image):.6f}
  
LRGB Image:
  Mean: {np.mean(lrgb_image):.6f}
  Std:  {np.std(lrgb_image):.6f}
  Max:  {np.max(lrgb_image):.6f}
  Min:  {np.min(lrgb_image):.6f}
  
Difference:
  Mean: {np.mean(diff):.6f}
  Max:  {np.max(diff):.6f}

Parameters:
  Saturation: {params['saturation_boost']}
  Lum weight: {params['lrgb_luminance_weight']}
  Color weight: {params['lrgb_color_weight']}

LRGB adds detail from the
clear (luminance) filter while
preserving color information.
    """

    ax6.text(0.05, 0.95, stats_text, transform=ax6.transAxes,
             fontsize=9.5, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

    fig.suptitle(f'{OBJECT_NAME} - RGB vs LRGB Comparison ({method})',
                 fontsize=16, fontweight='bold')

    comparison_file = Path(output_path) / f"comparison_{method}_{OBJECT_NAME}.png"
    fig.savefig(comparison_file, dpi=150, bbox_inches='tight')
    plt.close(fig)


def create_individual_plots(rgb_image, lrgb_image, method, output_path):
    """Save the RGB and LRGB composites as separate titled 300-dpi PNGs.

    Each image gets its own figure (sized by
    ``DISPLAY_PARAMS['individual_figure_size']``) with a bold title and no
    axes. Files are written to ``output_path`` as
    ``RGB_{method}_{OBJECT_NAME}.png`` and ``LRGB_{method}_{OBJECT_NAME}.png``.
    """
    destination = Path(output_path)
    figure_size = DISPLAY_PARAMS['individual_figure_size']

    # RGB is rendered first, then LRGB.
    for label, composite in (('RGB', rgb_image), ('LRGB', lrgb_image)):
        fig, ax = plt.subplots(1, 1, figsize=figure_size)
        ax.imshow(composite, origin='lower')
        ax.set_title(
            f'{OBJECT_NAME} - {label} ({method})',
            fontsize=16, fontweight='bold', pad=20,
        )
        ax.axis('off')
        plt.tight_layout()

        target_file = destination / f"{label}_{method}_{OBJECT_NAME}.png"
        plt.savefig(target_file, dpi=300, bbox_inches='tight')
        plt.close()


def process_stretch_method(method, r_data, g_data, b_data, l_data, r_header, output_path):
    """Build and save the RGB and LRGB products for one stretch method.

    Creates ``output_path/method`` and writes into it the RGB and LRGB
    composites (FITS + PNG each) and the comparison figure. The titled
    individual figures go into an ``individual_plots`` subdirectory.

    Args:
        method: Stretch method name; used for the subdirectory and file names.
        r_data, g_data, b_data: Channel arrays forwarded to
            ``create_rgb_image`` and ``create_lrgb_image``.
        l_data: Luminance array forwarded to ``create_lrgb_image``.
        r_header: Header passed to ``save_color_image`` for both the RGB and
            the LRGB FITS files.
        output_path: Parent directory for the per-method output directory.
    """
    print(f"\n{'='*80}")
    print(f"PROCESSING: {method.upper()}")
    print(f"{'='*80}")

    # Create output subdirectory
    method_path = Path(output_path) / method
    method_path.mkdir(parents=True, exist_ok=True)

    # Create RGB image (just stack channels)
    rgb_image = create_rgb_image(r_data, g_data, b_data, COLOR_PARAMS)

    # Save RGB
    rgb_fits = method_path / f"RGB_{method}_{OBJECT_NAME}.fits"
    rgb_png = method_path / f"RGB_{method}_{OBJECT_NAME}.png"
    save_color_image(rgb_image, rgb_fits, r_header)
    save_png(rgb_image, rgb_png)
    print(f"  ✓ Saved: {rgb_fits.name}")
    print(f"  ✓ Saved: {rgb_png.name}")

    # Create LRGB image
    lrgb_image = create_lrgb_image(l_data, r_data, g_data, b_data, COLOR_PARAMS)

    # Save LRGB
    lrgb_fits = method_path / f"LRGB_{method}_{OBJECT_NAME}.fits"
    lrgb_png = method_path / f"LRGB_{method}_{OBJECT_NAME}.png"
    save_color_image(lrgb_image, lrgb_fits, r_header)
    save_png(lrgb_image, lrgb_png)
    print(f"  ✓ Saved: {lrgb_fits.name}")
    print(f"  ✓ Saved: {lrgb_png.name}")

    # Create comparison plot
    create_comparison_plot(rgb_image, lrgb_image, method, COLOR_PARAMS, method_path)
    print(f"  ✓ Saved: comparison_{method}_{OBJECT_NAME}.png")

    # create_individual_plots() names its files RGB_{method}_{OBJECT_NAME}.png
    # and LRGB_{method}_{OBJECT_NAME}.png - exactly the names save_png() used
    # above. Writing them into method_path would silently overwrite the
    # full-frame PNGs, so the titled figures get their own subdirectory.
    plots_path = method_path / "individual_plots"
    plots_path.mkdir(parents=True, exist_ok=True)
    create_individual_plots(rgb_image, lrgb_image, method, plots_path)
    print(f"  ✓ Saved: {plots_path.name}/ (titled RGB and LRGB figures)")

    print(f"\n✓ Completed {method}")


def main():
    """Run the color combination for every configured stretch method.

    Prints the configuration (paths, stretch methods, filter mapping and
    color parameters), creates ``OUTPUT_PATH``, then for each entry of
    ``STRETCH_METHODS`` loads the four mapped filters and hands them to
    ``process_stretch_method``. A method that fails is reported and skipped
    so the remaining methods still run; the final summary lists which
    methods completed and which were skipped.
    """
    import traceback

    print("="*80)
    print(f"COLOR COMBINATION - {OBJECT_NAME}")
    print("="*80)
    print("\nPipeline: Load pre-processed data → Stack channels → Save")
    print("(No normalization, no stretching - data already processed)")
    print(f"\nInput path: {BASE_PATH}")
    print(f"Output path: {OUTPUT_PATH}")
    print(f"\nProcessing {len(STRETCH_METHODS)} stretch methods:")
    for m in STRETCH_METHODS:
        print(f"  • {m}")

    print("\nFilter mapping:")
    print(f"  R (Red):   {FILTER_MAPPING['R']}")
    print(f"  G (Green): {FILTER_MAPPING['G']}")
    print(f"  B (Blue):  {FILTER_MAPPING['B']}")
    print(f"  L (Lum):   {FILTER_MAPPING['L']}")

    print("\nColor combination parameters:")
    print(f"  Saturation boost: {COLOR_PARAMS['saturation_boost']}")
    print(f"  LRGB luminance weight: {COLOR_PARAMS['lrgb_luminance_weight']}")
    print(f"  LRGB color weight: {COLOR_PARAMS['lrgb_color_weight']}")

    # Create main output directory
    Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

    # Track outcomes explicitly: a method directory existing on disk does not
    # prove this run succeeded (it may be left over from an earlier run or
    # from a failure part-way through processing).
    completed = []
    failed = []

    # Load and process each stretch method
    for method in STRETCH_METHODS:
        print(f"\n{'='*80}")
        print(f"LOADING: {method.upper()}")
        print(f"{'='*80}")

        try:
            r_data, r_header = load_aligned_image(FILTER_MAPPING['R'], method)
            g_data, _ = load_aligned_image(FILTER_MAPPING['G'], method)
            b_data, _ = load_aligned_image(FILTER_MAPPING['B'], method)
            l_data, _ = load_aligned_image(FILTER_MAPPING['L'], method)

            # Process this method
            process_stretch_method(
                method, r_data, g_data, b_data, l_data,
                r_header, OUTPUT_PATH
            )

        except FileNotFoundError as e:
            print(f"\n✗ Error: {e}")
            print(f"  Skipping {method}")
            failed.append(method)
        except Exception as e:
            # Top-level boundary: report the failure in full, then carry on
            # with the remaining methods instead of aborting the whole run.
            print(f"\n✗ Unexpected error processing {method}: {e}")
            traceback.print_exc()
            print(f"  Skipping {method}")
            failed.append(method)
        else:
            completed.append(method)

    # Summary
    print("\n" + "="*80)
    if failed:
        print(f"FINISHED: {len(completed)} OF {len(STRETCH_METHODS)} METHODS COMPLETE")
    else:
        print("ALL METHODS COMPLETE!")
    print("="*80)
    print(f"\nOutput directory: {OUTPUT_PATH}/")
    print("\nCreated subdirectories:")
    for method in completed:
        print(f"  • {method}/")
        print(f"    - RGB_{method}_{OBJECT_NAME}.fits/.png")
        print(f"    - LRGB_{method}_{OBJECT_NAME}.fits/.png")
        print(f"    - comparison_{method}_{OBJECT_NAME}.png")

    if failed:
        print("\nSkipped methods (see errors above):")
        for method in failed:
            print(f"  • {method}")
        print(f"\n✗ {len(failed)} stretch method(s) did not complete")
    else:
        print("\n✓ All color combinations complete!")

    print("\nNext steps:")
    print("  • Review all comparison images")
    print("  • Choose the stretch method that best shows galaxy structure")
    print("  • If colors need adjustment, modify COLOR_PARAMS:")
    print("    - saturation_boost: increase for more vivid colors")
    print("    - lrgb_luminance_weight: increase for more detail")
    print("    - lrgb_color_weight: increase for more color")


# Entry point: run the pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

COLOR COMBINATION - NGC1365

Pipeline: Load pre-processed data → Stack channels → Save
(No normalization, no stretching - data already processed)

Input path: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/STACKED/aligned
Output path: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/STACKED/color_combined

Processing 6 stretch methods:
  • normalized
  • log
  • hyperbolic
  • asinh
  • gamma
  • adaptive_asinh

Filter mapping:
  R (Red):   i
  G (Green): r
  B (Blue):  g
  L (Lum):   clear

Color combination parameters:
  Saturation boost: 1.2
  LRGB luminance weight: 0.6
  LRGB color weight: 0.4

LOADING: NORMALIZED
  i: shape=(2048, 2048), range=[-0.000134, 1.015710], mean=0.018680, median=0.018959
  r: shape=(2048, 2048), range=[0.000000, 1.000000], mean=0.012641, median=0.012595
  g: shape=(2048, 2048), range=[0.000000, 1.000580], mean=0.019138, median=0.019284
  clear: shape=(2048, 2