# Analyse Image
Code used to analyse the data produced by the simulation and the algorithm masks.

In [1]:
import numpy as np
import polars as pl
import logging
import matplotlib.pyplot as plt

from scrs.constants import OUT_DIR
from scrs import Image

In [2]:
# Configure the root logger once. basicConfig() only has an effect while the
# root logger has no handlers, so calling it a second time does nothing.
logging.basicConfig()

# Only warnings and above are shown. Switch to logging.INFO here when more
# detail is wanted while debugging.
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

## Summaries
Read through all the files and extract:
- the cosmic ray counts
- the time taken to run

In [3]:
# Raw images are matched by the "*_raw.fits" pattern; the identifier of each
# one is the part of the file stem before the first underscore.
uuids = [
    raw_path.stem.split("_")[0]
    for raw_path in (OUT_DIR / "auto" / "raw").glob("*_raw.fits")
]

In [4]:
# CR Counts DataFrame
# One row per combination of the five boolean pixel flags (2**5 = 32 rows).
# Row number r holds the binary digits of r, with "true" as the lowest bit and
# "in_star" as the highest.
flag_names = ["true", "lc", "ac", "pc", "in_star"]
data = {
    name: [bool(row & (1 << bit)) for row in range(2**5)]
    for bit, name in enumerate(flag_names)
}

cr_counts = pl.DataFrame(data)
cr_counts

true,lc,ac,pc,in_star
bool,bool,bool,bool,bool
false,false,false,false,false
true,false,false,false,false
false,true,false,false,false
true,true,false,false,false
false,false,true,false,false
…,…,…,…,…
true,true,false,true,true
false,false,true,true,true
true,false,true,true,true
false,true,true,true,true


In [5]:
time_rows = []
counts = [0] * 32

# Only the five flag columns are handed to count_pixels. Selecting them by name
# keeps this cell re-runnable after the "count" column has been added below.
flag_columns = ["true", "lc", "ac", "pc", "in_star"]


def count_pixels(
    true: bool | None = None,
    lc: bool | None = None,
    ac: bool | None = None,
    pc: bool | None = None,
    in_star: bool | None = None,
):
    """Count the pixels of the current image that match the requested flags.

    Each argument selects on one mask: ``True`` keeps the pixels inside that
    mask, ``False`` keeps the pixels outside it, and ``None`` ignores the mask.

    The masks themselves (``true_mask``, ``lc_mask``, ``ac_mask``, ``pc_mask``
    and ``star_mask``) are the module-level arrays assigned by the loop below
    for the image currently being processed, so this must only be called once
    they have been set.

    Returns:
        The number of pixels satisfying every requested condition. With no
        conditions at all this is the total number of pixels in the image.
    """
    # Start from "every pixel selected" so that an empty set of conditions
    # counts the whole image rather than collapsing to a single scalar.
    selected = np.ones(true_mask.shape, dtype=bool)
    for wanted, mask in (
        (true, true_mask),
        (lc, lc_mask),
        (ac, ac_mask),
        (pc, pc_mask),
        (in_star, star_mask),
    ):
        if wanted is not None:
            selected = selected & (mask if wanted else ~mask)

    return int(selected.sum())


for uuid in uuids:
    # Classify all pixels in the image as true, lc, ac, pc, in_star
    img = Image().load_fits(OUT_DIR / "auto" / "raw" / f"{uuid}_raw.fits")
    cosmics = img.get_diff(0, 1).data
    true_mask = cosmics > 0

    star_mask = img.get_snapshot(idx=2).data > 0

    masks = Image().load_fits(OUT_DIR / "auto" / "masks" / f"{uuid}_mask.fits")
    hdu = masks.get_header()

    # Times
    # Append algorithm, time, uuid to the times dataframe
    for algorithm in ["LC", "AC", "PC"]:
        time_rows.append((algorithm, hdu[f"{algorithm}_TIME"], uuid))

    # CR Counts
    # NOTE(review): the snapshot indices below are taken as-is from the
    # existing analysis (current data = PC, 1 = AC, 2 = LC) — confirm they
    # match the order in which the masks were written.
    pc_mask = masks.data.astype(bool)  # Last one run
    ac_mask = masks.get_snapshot(idx=1).data.astype(bool)
    lc_mask = masks.get_snapshot(idx=2).data.astype(bool)

    # Count all pixels for each combination, and accumulate the count for the
    # corresponding row of the dataframe
    for i, mask_filters in enumerate(cr_counts.select(flag_columns).iter_rows()):
        counts[i] += count_pixels(*mask_filters)

cr_counts = cr_counts.with_columns(pl.Series("count", counts))

In [72]:
# One row per (algorithm, run time, image uuid) tuple gathered in time_rows.
times_schema = {"algorithm": pl.String, "time": pl.Float64, "uuid": pl.String}
times = pl.DataFrame(time_rows, schema=times_schema, orient="row")

In [73]:
# Remove times greater than 2x the mean time, per algorithm. This is a very simple outlier removal method
# to handle runs impacted by the computer going on standby. Works for the data used, but not a
# robust general solution.
#
# The per-algorithm mean is computed as a window expression, so each row is
# compared against the mean of its own algorithm in a single pass. This also
# copes with an empty table, where there is no mean to multiply.
times = times.filter(
    pl.col("time") < 2 * pl.col("time").mean().over("algorithm")
)

In [76]:
"""
Get the time statistics for each algorithm
"""
# Order in which the algorithms are listed in the summary table.
orders = ["LC", "AC", "PC"]

times_summary = (
    times.group_by("algorithm")
    .agg(
        pl.col("time").mean().alias("mean (s)"),
        pl.col("time").median().alias("median (s)"),
        pl.col("time").max().alias("max (s)"),
        pl.col("time").min().alias("min (s)"),
        pl.col("time").std().alias("std (s)"),
    )
    # group_by gives no guarantee about row order, so sort explicitly. Casting
    # to an Enum makes the sort follow the position in `orders` rather than
    # alphabetical order; the "algorithm" column itself stays a string.
    .sort(pl.col("algorithm").cast(pl.Enum(orders)))
)
times_summary

algorithm,mean (s),median (s),max (s),min (s),std (s)
str,f64,f64,f64,f64,f64
"""LC""",119.067409,118.821917,124.750534,114.779644,1.428138
"""AC""",11.89094,11.814318,13.596487,11.555918,0.305816
"""PC""",123.29202,122.899282,142.911443,121.228592,2.218811


In [None]:
"""
Calculate the accuracy, precision, recall, specificity, F1 score, and IoU for each.
Includes the total, in_star, and out_star metrics since being near a star will likely impact the
CR detection performance.
"""

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def get_metrics(in_star: bool | None = None):
    """Compute per-algorithm detection metrics from the ``cr_counts`` table.

    For each of the "lc", "ac" and "pc" columns the pixel counts are summed
    into true/false positives/negatives against the "true" column, and the
    accuracy, precision, recall, specificity, F1 score and IoU are derived
    from them.

    Args:
        in_star: ``True`` restricts the counts to rows with ``in_star`` set,
            ``False`` to rows with it unset, and ``None`` uses every row.

    Returns:
        A DataFrame with one row per algorithm (name upper-cased) and one
        column per metric. A metric whose denominator is zero is NaN.
    """
    # The star filter does not depend on the algorithm, so build it once.
    # When no filtering is requested a literal True leaves the other
    # conditions unchanged.
    if in_star is None:
        star_filter = pl.lit(True)
    else:
        star_filter = pl.col("in_star") == in_star

    def total(condition):
        # Sum the pixel counts of every flag combination matching `condition`.
        return cr_counts.filter(condition & star_filter).get_column("count").sum()

    metrics = []
    for alg in ["lc", "ac", "pc"]:
        detected = pl.col(alg)
        actual = pl.col("true")

        tp = total(detected & actual)
        fp = total(detected & ~actual)
        fn = total(~detected & actual)
        tn = total(~detected & ~actual)

        precision = _safe_ratio(tp, tp + fp)
        recall = _safe_ratio(tp, tp + fn)

        metrics.append({
            "algorithm": alg.upper(),
            "accuracy": _safe_ratio(tp + tn, tp + tn + fp + fn),
            "precision": precision,
            "recall": recall,
            "specificity": _safe_ratio(tn, tn + fp),
            # NaN precision/recall propagates to NaN here; 0 + 0 is caught by
            # the zero-denominator check.
            "f1": _safe_ratio(2 * (precision * recall), precision + recall),
            "iou": _safe_ratio(tp, tp + fp + fn),
        })

    return pl.DataFrame(metrics)


# Metrics over every row, over the in_star rows only, and over the remaining
# rows only (in that order).
total_metrics, in_star_metrics, out_star_metrics = (
    get_metrics(in_star=star_flag) for star_flag in (None, True, False)
)

In [78]:
# Persist every table to its own CSV file in the "auto" output directory.
csv_outputs = {
    "cr_counts.csv": cr_counts,
    "times.csv": times,
    "times_summary.csv": times_summary,
    "total_metrics.csv": total_metrics,
    "in_star_metrics.csv": in_star_metrics,
    "out_star_metrics.csv": out_star_metrics,
}
for csv_name, csv_frame in csv_outputs.items():
    csv_frame.write_csv(OUT_DIR / "auto" / csv_name)

## Reading for further analysis
Reload the CSV files written in the previous section, so that further analysis can start from here without recomputing the summaries.

In [79]:
# Load the saved tables back from the "auto" output directory, in the same
# order as the names on the left-hand side.
(
    cr_counts,
    times_summary,
    total_metrics,
    in_star_metrics,
    out_star_metrics,
) = (
    pl.read_csv(OUT_DIR / "auto" / csv_name)
    for csv_name in (
        "cr_counts.csv",
        "times_summary.csv",
        "total_metrics.csv",
        "in_star_metrics.csv",
        "out_star_metrics.csv",
    )
)