In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import tempfile
import pandas as pd

In [None]:
### Plot the 512 x 512 processed AIA data sets for 7/7, 7/11, 7/20 and 8/1


In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import tempfile
import pandas as pd

# --- Configurable Parameters ---
# Root directory that is searched recursively for *.fits files.
data_dir = "/mnt/data2/AIA_processed_data/94"
# Intensity at which the red "Saturation Threshold" marker is drawn on the
# log-log histogram. In this section it is only used for that marker, not to
# filter pixels.
saturation_threshold = 16000
# max_workers passed to the ProcessPoolExecutor that reads the FITS files.
num_workers = 56
# Number of files handed to the worker pool per batch.
chunk_size = 50
# Selects which statistics accumulator get_stats_engine() builds.
PREPROCESS_MODE = "disk"  # Options: "stream", "disk", "histogram"
# Directory given to DiskBasedStats for its temporary .npy files; it is
# created up front so the first save cannot fail on a missing directory.
temp_dir = "/mnt/data2/tmp"
os.makedirs(temp_dir, exist_ok=True)

# --- Stats Classes ---
class StreamingStats:
    """Accumulate pixel statistics while keeping every pixel in memory.

    Count, sum, sum of squares, minimum and maximum are maintained
    incrementally. The pixel values themselves are retained as well so that
    ``get_stats`` can compute exact percentiles.
    """

    def __init__(self):
        self.n = 0
        self.sum = 0.0
        self.sum_sq = 0.0
        self.min_val = np.inf
        self.max_val = -np.inf
        # One flat float64 array per update() call. Storing arrays instead of
        # a Python list of floats avoids boxing every single pixel.
        self.all_pixels = []

    def update(self, pixels):
        """Fold one batch of pixel values into the running statistics.

        Args:
            pixels: NumPy array of pixel values, any shape and numeric dtype.
                Empty arrays are ignored.
        """
        if pixels.size == 0:
            return
        # Copy into a flat float64 array *before* squaring or summing:
        # squaring in the input dtype silently wraps for small integer types
        # (e.g. int16 300**2), and float32 sums lose precision over many
        # batches.
        values = np.array(pixels, dtype=np.float64).ravel()
        self.n += values.size
        self.sum += float(values.sum())
        self.sum_sq += float(np.dot(values, values))
        self.min_val = min(self.min_val, float(values.min()))
        self.max_val = max(self.max_val, float(values.max()))
        self.all_pixels.append(values)

    def get_stats(self):
        """Return summary statistics for everything seen so far.

        Returns:
            ``None`` if no pixel has been added; otherwise a dict with the
            keys ``mean``, ``std``, ``min``, ``max``, ``p95``, ``p99``,
            ``p99.5``, ``total_pixels`` and ``all_pixels`` (a flat float64
            array of every stored value).
        """
        if self.n == 0:
            return None
        mean = self.sum / self.n
        variance = (self.sum_sq / self.n) - mean ** 2
        # Rounding can push a tiny variance below zero; clamp before sqrt.
        std = np.sqrt(max(0, variance))
        pixels = np.concatenate(self.all_pixels)
        # A single percentile call avoids partitioning the data three times.
        p95, p99, p99_5 = np.percentile(pixels, [95, 99, 99.5])
        return {
            'mean': mean,
            'std': std,
            'min': self.min_val,
            'max': self.max_val,
            'p95': p95,
            'p99': p99,
            'p99.5': p99_5,
            'total_pixels': self.n,
            'all_pixels': pixels
        }

class DiskBasedStats:
    """Accumulate pixel statistics, spilling the raw pixels to ``.npy`` files.

    Count, sum, sum of squares, minimum and maximum are maintained in memory.
    Each batch of pixels is written to its own file and only read back in
    ``get_stats`` to compute percentiles.
    """

    def __init__(self, temp_dir=None):
        """Create an empty accumulator.

        Args:
            temp_dir: Base directory for spill files. Defaults to the system
                temporary directory. It is created if missing.
        """
        self.n = 0
        self.sum = 0.0
        self.sum_sq = 0.0
        self.min_val = np.inf
        self.max_val = -np.inf
        self.temp_files = []
        self.temp_dir = temp_dir or tempfile.gettempdir()
        os.makedirs(self.temp_dir, exist_ok=True)
        # Private sub-directory of temp_dir, created on the first update().
        # Keeping each instance in its own directory stops two instances (or
        # a re-run) from overwriting each other's pixels_<i>.npy files.
        self._spill_dir = None

    def update(self, pixels):
        """Fold one batch of pixel values in and write the batch to disk.

        Args:
            pixels: NumPy array of pixel values. Empty arrays are ignored.
        """
        if pixels.size == 0:
            return
        # Accumulate in float64: squaring in the input dtype silently wraps
        # for small integer types and float32 sums lose precision.
        values = pixels.astype(np.float64, copy=False).ravel()
        self.n += values.size
        self.sum += float(values.sum())
        self.sum_sq += float(np.dot(values, values))
        self.min_val = min(self.min_val, np.min(pixels))
        self.max_val = max(self.max_val, np.max(pixels))
        if self._spill_dir is None:
            self._spill_dir = tempfile.mkdtemp(prefix="pixels_", dir=self.temp_dir)
        filename = os.path.join(self._spill_dir, f"pixels_{len(self.temp_files)}.npy")
        # The batch is stored in its original dtype to keep the files small.
        np.save(filename, pixels)
        self.temp_files.append(filename)

    def get_stats(self):
        """Return summary statistics for everything seen so far.

        Reads every spill file back into memory to compute percentiles.

        Returns:
            ``None`` if no pixel has been added; otherwise a dict with the
            keys ``mean``, ``std``, ``min``, ``max``, ``p95``, ``p99``,
            ``p99.5``, ``total_pixels`` and ``all_pixels``.
        """
        if self.n == 0:
            return None
        mean = self.sum / self.n
        variance = (self.sum_sq / self.n) - mean ** 2
        # Rounding can push a tiny variance below zero; clamp before sqrt.
        std = np.sqrt(max(0, variance))
        pixels = [np.load(f) for f in tqdm(self.temp_files, desc="Loading pixel data")]
        all_pixels = np.concatenate(pixels)
        # A single percentile call avoids partitioning the data three times.
        p95, p99, p99_5 = np.percentile(all_pixels, [95, 99, 99.5])
        return {
            'mean': mean,
            'std': std,
            'min': self.min_val,
            'max': self.max_val,
            'p95': p95,
            'p99': p99,
            'p99.5': p99_5,
            'total_pixels': self.n,
            'all_pixels': all_pixels
        }

    def cleanup(self):
        """Delete this instance's spill files and its private directory.

        Call this once the result of ``get_stats`` is no longer needed from
        disk; ``get_stats`` cannot reload pixels after cleanup.
        """
        for path in self.temp_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
        self.temp_files = []
        if self._spill_dir is not None:
            try:
                os.rmdir(self._spill_dir)
            except OSError:
                # Directory missing or holding foreign files: leave it alone.
                pass
            self._spill_dir = None

class HistogramOnlyStats:
    """Collect pixel values and reduce them to a single histogram."""

    def __init__(self, bins=500):
        """Create an empty accumulator.

        Args:
            bins: Passed through as the ``bins`` argument of ``np.histogram``.
        """
        # One flat float64 array per update() call. Storing arrays instead of
        # a Python list of floats avoids boxing every single pixel.
        self.pixel_values = []
        self.bins = bins

    def update(self, pixels):
        """Store one batch of pixel values.

        Args:
            pixels: NumPy array of pixel values, any shape and numeric dtype.
                Empty arrays are ignored.
        """
        if pixels.size > 0:
            # ravel() keeps multi-dimensional input from turning into nested
            # rows, which would make total_pixels count rows, not pixels.
            self.pixel_values.append(np.array(pixels, dtype=np.float64).ravel())

    def get_stats(self):
        """Return the histogram of everything seen so far.

        Returns:
            ``None`` if no pixel has been added; otherwise a dict with the
            keys ``histogram`` (counts), ``bin_edges`` and ``total_pixels``.
        """
        if not self.pixel_values:
            return None
        pixels = np.concatenate(self.pixel_values)
        hist, bin_edges = np.histogram(pixels, bins=self.bins)
        return {
            'histogram': hist,
            'bin_edges': bin_edges,
            'total_pixels': int(pixels.size)
        }

# --- Dispatcher ---
def get_stats_engine(mode):
    """Build the statistics accumulator that corresponds to ``mode``.

    Args:
        mode: One of "stream", "disk" or "histogram".

    Returns:
        A new StreamingStats, DiskBasedStats (using the module-level
        ``temp_dir``) or HistogramOnlyStats instance.

    Raises:
        ValueError: If ``mode`` is none of the known names.
    """
    # Constructors are wrapped in lambdas so nothing is built until a name
    # actually matches.
    builders = (
        ("stream", lambda: StreamingStats()),
        ("disk", lambda: DiskBasedStats(temp_dir=temp_dir)),
        ("histogram", lambda: HistogramOnlyStats()),
    )
    for known_mode, build in builders:
        if mode == known_mode:
            return build()
    raise ValueError("Unknown PREPROCESS_MODE")

# --- File Discovery ---
# Collect every .fits file below data_dir, at any depth, and stop early when
# there is nothing to process.
print("Searching for .fits files...")
fits_pattern = os.path.join(data_dir, "**", "*.fits")
fits_files = glob.glob(fits_pattern, recursive=True)

if len(fits_files) == 0:
    raise FileNotFoundError(f"No .fits files found in {data_dir}")

print(f"Found {len(fits_files)} files.")

# --- FITS Processor ---
def process_single_file(file_path):
    """Extract the positive, finite pixels of the first image HDU in a file.

    The HDUs are scanned in order and the first one whose data has at least
    two dimensions decides the result; for 3-D data only the first plane
    (index 0) is used.

    Args:
        file_path: Path of the FITS file to open.

    Returns:
        A 1-D array of the selected pixel values. An empty array is returned
        when no suitable HDU exists, when no pixel qualifies, or when any
        exception occurs (the error is printed, not raised).
    """
    try:
        with fits.open(file_path) as hdul:
            for hdu in hdul:
                image = getattr(hdu, 'data', None)
                # Skip header-only HDUs and anything that is not image-like.
                if image is None or image.ndim < 2:
                    continue
                if image.ndim == 3:
                    image = image[0]
                keep = (image > 0) & np.isfinite(image)
                selected = image[keep]
                if selected.size:
                    return selected.flatten()
                return np.array([])
        return np.array([])
    except Exception as e:
        # Best effort: one unreadable file must not abort the whole run.
        print(f"Error processing {file_path}: {e}")
        return np.array([])

# --- Processing Loop ---
# Read the FITS files in parallel and feed every file's pixels into the
# statistics accumulator selected by PREPROCESS_MODE.
stats = get_stats_engine(PREPROCESS_MODE)
total_files = len(fits_files)

# The worker pool is created once for the whole run. Building and tearing
# down a pool of up to num_workers processes for every batch of chunk_size
# files costs far more than the work in a single batch.
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    for i in tqdm(range(0, total_files, chunk_size), desc="Processing FITS"):
        chunk = fits_files[i:i+chunk_size]
        # Files are still submitted one batch at a time so that at most
        # chunk_size per-file pixel arrays are pending in memory.
        for pixels in executor.map(process_single_file, chunk):
            if pixels.size:
                stats.update(pixels)

# --- Final Statistics and Plotting ---
# Summarise the accumulated pixels and draw the plot that matches
# PREPROCESS_MODE: a plain histogram for "histogram", otherwise printed
# statistics plus a log-log histogram with the saturation marker.
results = stats.get_stats()
if results is None:
    print("No valid data.")
elif PREPROCESS_MODE == "histogram":
    hist, bins = results['histogram'], results['bin_edges']
    plt.figure(figsize=(10, 5))
    plt.bar(bins[:-1], hist, width=np.diff(bins), align='edge', edgecolor='black')
    plt.xscale('log')
    plt.yscale('log')
    plt.title("Histogram of Intensity")
    plt.xlabel("Pixel Intensity")
    plt.ylabel("Count")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("aia_histogram_only.png", dpi=300)
    plt.show()
else:
    print(f"\nProcessed {results['total_pixels']:,} pixels.")
    print(f"Mean: {results['mean']:.2f}, Std: {results['std']:.2f}, Min: {results['min']:.1f}, Max: {results['max']:.1f}")
    print(f"P95: {results['p95']:.2f}, P99: {results['p99']:.2f}, P99.5: {results['p99.5']:.2f}")

    all_pixels = results['all_pixels']
    # The lower edge is clamped to 1, so pixels below 1 fall outside the bins.
    lower = max(1, all_pixels.min())
    upper = all_pixels.max()
    if upper <= lower:
        # np.histogram needs strictly increasing edges; widen a degenerate
        # range by one decade instead of failing.
        upper = lower * 10.0
    log_bins = np.logspace(np.log10(lower), np.log10(upper), 500)
    counts, bin_edges = np.histogram(all_pixels, bins=log_bins)
    # Not used for plotting any more; retained as a module-level name in case
    # later cells read it (cannot tell from this section).
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    log_edges = np.log10(bin_edges)

    plt.figure(figsize=(10, 5))
    # Each bar spans exactly its bin in log10 space. A fixed width at the
    # arithmetic bin centre makes neighbouring bars overlap or leave gaps
    # whenever the real bin width differs from that constant.
    plt.bar(log_edges[:-1], np.log10(counts + 1), width=np.diff(log_edges), align='edge', edgecolor='black')
    plt.axvline(np.log10(saturation_threshold), color='red', linestyle='--', label='Saturation Threshold')
    plt.xlabel("log10(Pixel Intensity)")
    # The +1 keeps empty bins at 0 instead of -inf; the label says so.
    plt.ylabel("log10(Count + 1)")
    plt.title("Log-Log Histogram of Intensity")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("aia_loglog_hist.png", dpi=300)
    plt.show()
