In [1]:
from models.Sentinel_Models.Sentinel_Retina_Net import Sentinel
import torch
import torchvision
import os


# Save Model

In [2]:

# Checkpoint to convert; the TorchScript export is written next to it.
model_path = "/data/Dataset_Compilation_and_Statistics/Sentinel_Datasets/Best_models/RME04_E73.pt"
# Swap only the file extension. str.replace(".pt", ...) would also rewrite any
# other ".pt" that happens to occur inside the file name.
ts_name = os.path.splitext(os.path.basename(model_path))[0] + ".torchscript"
save_path = os.path.dirname(model_path)

# Instantiate the model, put it in eval mode, let load_original_model() read the
# checkpoint, then compile the module with TorchScript and save it.
model = Sentinel()
model.eval()
model.load_original_model(model_path)
TS_Model = torch.jit.script(model)
path = os.path.join(save_path, ts_name)
TS_Model.save(path)
print("Saved Model")

# Round-trip check: reload the artifact that was just written.
model = torch.jit.load(str(path))  # str() keeps this working if path is a pathlib.Path
model.eval()  # Ensure it's in inference mode; as the last expression this is the cell's displayed value


Loading Model: /data/Dataset_Compilation_and_Statistics/Sentinel_Datasets/Best_models/RME04_E73.pt
Saved Model


RecursiveScriptModule(
  original_name=Sentinel
  (retina_net): RecursiveScriptModule(
    original_name=RetinaNet
    (backbone): RecursiveScriptModule(
      original_name=BackboneWithFPN
      (body): RecursiveScriptModule(
        original_name=IntermediateLayerGetter
        (conv1): RecursiveScriptModule(original_name=Conv2d)
        (bn1): RecursiveScriptModule(original_name=FrozenBatchNorm2d)
        (relu): RecursiveScriptModule(original_name=ReLU)
        (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
        (layer1): RecursiveScriptModule(
          original_name=Sequential
          (0): RecursiveScriptModule(
            original_name=Bottleneck
            (conv1): RecursiveScriptModule(original_name=Conv2d)
            (bn1): RecursiveScriptModule(original_name=FrozenBatchNorm2d)
            (conv2): RecursiveScriptModule(original_name=Conv2d)
            (bn2): RecursiveScriptModule(original_name=FrozenBatchNorm2d)
            (conv3): RecursiveScriptModule(

# Test Model

In [7]:

from typing import Any, cast
import pydantic
import torch
import base64
import io
import numpy as np
from fastapi.responses import StreamingResponse
from astropy.io import fits
from importlib import resources
import json

import torchvision
from numpy import typing as npt
import numpy as np
import pathlib

import numpy as np
from astropy.visualization import ZScaleInterval
import numpy as np
from numpy import typing as npt
import cv2

def inference(
    data: list,
    model_path: str = "/data/Sentinel_Datasets/Finalized_datasets/LMNT01Sat_Training_Channel_Mixture_C/models/LMNT01_MixtureC/LMNT01-249-TS.torchscript",
    conf_threshold: float = 0.5,
) -> list:
    """
    Run the TorchScript detector over a batch of base64-encoded FITS images.

    Args:
        data: List of dict-like entries; each entry's ``"file"`` value is a
            base64-encoded FITS file whose primary HDU carries the image and a
            ``TRKMODE`` header card.
        model_path: TorchScript file to load. Defaults to the path that was
            previously hard-coded in this function.
        conf_threshold: Detections with a confidence below this value are dropped.

    Returns:
        One ``{"detections": [...]}`` dict per input entry, in input order.
        Entries whose ``TRKMODE`` is ``"sidereal"`` are not sent to the model and
        keep an empty detection list.

    Notes:
        - Every non-sidereal image is stacked into a single batch, so all of them
          must share one shape.
        - The decoding below assumes the model returns, per image, a ``[5, N]``
          tensor of ``(x_c, y_c, w, h, conf)`` with normalised coordinates
          -- TODO confirm against the exported model.
    """
    model = torch.jit.load(str(model_path))  # str() in case a pathlib.Path is passed
    model.eval()  # Ensure it's in inference mode

    batch_detections: list[dict[str, list[dict[str, Any]]]] = [
        {"detections": []} for _ in data
    ]
    images = []
    rate_indices = []   # position in `data` of every image that goes to the model
    x_resolutions = []  # indexed by position in `data`, NOT by position in the batch
    y_resolutions = []

    for i, file in enumerate(data):
        decoded = base64.b64decode(file["file"])
        # Read header and pixels while the HDU list is open so the handle can be
        # released immediately afterwards.
        with fits.open(io.BytesIO(decoded)) as hdul:
            primary = hdul[0]
            track_mode = primary.header["TRKMODE"]
            img_data = primary.data
        y_resolutions.append(img_data.shape[0])
        x_resolutions.append(img_data.shape[1])
        if track_mode == "sidereal":
            continue

        images.append(channel_mixture_C(img_data))  # [C, H, W] float32
        rate_indices.append(i)

    if not images:
        # Same plain-dict structure as the normal return path; the previous code
        # referenced an `entities` module that is not defined in this notebook.
        return batch_detections

    images_np = np.stack(images, axis=0)  # Shape: [B, C, H, W]
    batch = torch.from_numpy(images_np)

    with torch.no_grad():
        outputs = model(batch)  # Should be shape: [B, 5, N]

    for k, (orig_i, preds) in enumerate(zip(rate_indices, outputs)):
        frame = images_np[k, 0]
        H, W = frame.shape
        # Resolution lists cover every input, so they must be looked up with the
        # original index; using the batch index breaks once a frame was skipped.
        x_res = x_resolutions[orig_i]
        y_res = y_resolutions[orig_i]
        noise = np.std(frame)  # identical for every detection in this frame

        detections = []
        # permute -> [N, 5]; one (x_c, y_c, w, h, conf) row per candidate
        for x_c, y_c, w, h, conf in preds.permute(1, 0).tolist():
            if conf < conf_threshold:
                continue

            # Convert normalized center/size to pixel coordinates
            x_c *= W
            y_c *= H
            w *= W
            h *= H

            # Box corners, clamped to image bounds
            xmin = max(0, x_c - w / 2)
            xmax = min(W - 1, x_c + w / 2)
            ymin = max(0, y_c - h / 2)
            ymax = min(H - 1, y_c + h / 2)

            # Clamp the lookup pixel: a centre on or past the border would
            # otherwise raise IndexError (or wrap around for negative values).
            row = min(max(int(y_c), 0), H - 1)
            col = min(max(int(x_c), 0), W - 1)
            signal = frame[row, col]

            detections.append(
                {
                    "class_id": 0,  # Only one class
                    "pixel_centroid": [float(x_c) / x_res, float(y_c) / y_res],
                    "prob": float(conf),
                    "snr": float(signal / noise) if noise > 0 else 0.0,
                    "x_max": float(xmax) / x_res,
                    "x_min": float(xmin) / x_res,
                    "y_max": float(ymax) / y_res,
                    "y_min": float(ymin) / y_res,
                }
            )

        batch_detections[orig_i]["detections"] = detections

    return batch_detections



def _iqr_clip(x, threshold=5.0):
    """
    IQR-Clip normalization: Robust contrast normalization with hard clipping.
    
    Args:
        x (np.ndarray): Grayscale image, shape (H, W)
    
    Returns:
        np.ndarray: Normalized and clipped image, same shape, dtype float32
    """
    x = x.astype(np.float32)
    q1 = np.percentile(x, 25)
    q2 = np.percentile(x, 50)
    q3 = np.percentile(x, 75)
    iqr = q3 - q1

    # Normalize relative to the median (q2)
    x_norm = (x - q2) / (iqr + 1e-8)

    # Clip values beyond ±5 IQR
    x_clipped = np.clip(x_norm, -threshold, threshold)

    return x_clipped

def _iqr_log(x, threshold=5.0):
    """
    IQR-Log normalization: IQR-based normalization followed by log compression of outliers.
    
    Args:
        x (np.ndarray): Grayscale image, shape (H, W)
    
    Returns:
        np.ndarray: Soft-clipped image using log transform for values > ±5 IQR
    """
    x = x.astype(np.float32)
    q1 = np.percentile(x, 25)
    q2 = np.percentile(x, 50)
    q3 = np.percentile(x, 75)
    iqr = q3 - q1

    # Normalize relative to the median (q2)
    x_soft = (x - q2) / (iqr + 1e-8)

    # Apply log transformation to soft-clip tails
    threshold = 5.0

    # Positive tail
    over = x_soft > threshold
    x_soft[over] = threshold + np.log1p(x_soft[over] - threshold)

    # Negative tail
    under = x_soft < -threshold
    x_soft[under] = -threshold - np.log1p(-x_soft[under] - threshold)

    return x_soft

def _adaptive_iqr(fits_image:np.ndarray, bkg_subtract:bool=True, verbose:bool=False) -> np.ndarray:
    '''
    Performs Log1P contrast enhancement. Searches for the highest contrast image and enhances stars.
    Optionally can perform background subtraction as well

    Notes: Current configuration of the PSF model works for a scale of 4-5 arcmin
           sized image. Will make this more adjustable with calculations if needed.

    Input: The stacked frames to be processed for astrometric localization. Works
           best when background has already been corrected.

    Output: A numpy array of shape (2, N) where N is the number of stars extracted. 
    '''  

    if verbose:
        print("| Percentile | Contrast |")
        print("|------------|----------|")
    best_contrast_score = 0
    best_percentile = 0
    best_image = None
    percentiles=[]
    contrasts=[]

    for i in range(20):
        #Scans image to find optimal subtraction of median
        percentile = 90+0.5*i
        temp_image = fits_image-np.quantile(fits_image, (percentile)/100)
        temp_image[temp_image < 0] = 0
        scaled_data = np.log1p(temp_image)
        #Metric to optimize, currently it is prominence
        contrast = (np.max(scaled_data)+np.mean(scaled_data))/2-np.median(scaled_data)
        percentiles.append(percentile)
        contrasts.append(contrast)

        if contrast > best_contrast_score*1.05:
            best_contrast_multiplier = i
            best_image = scaled_data.copy()
            best_contrast_score = contrast
            best_percentile = percentile
        if verbose: print("|    {:.2f}   |   {:.2f}   |".format(percentile,contrast))
    if verbose: print("Best percentile): {}".format(best_percentile))
    if best_image is None:
        return fits_image
    return best_image

def _zscale(image:np.ndarray, contrast:float=.5) -> np.ndarray:
    """Pass ``image`` through a ``ZScaleInterval`` built with the given ``contrast`` and return the result."""
    return ZScaleInterval(contrast=contrast)(image)

def _minmax_scale(arr:np.ndarray) -> np.ndarray:
    """Scales a 2D NumPy array to the range [0, 1] using min-max normalization."""
    arr_min = arr.min()
    arr_max = arr.max()
    if arr_max == arr_min:
        return np.zeros_like(arr, dtype=float)  # Avoid division by zero
    return (arr - arr_min) / (arr_max - arr_min)

def _median_row_subtraction(img):
    """
    Subtracts the median from each row and adds back the global median.

    Args:
        img (np.ndarray): Input image of shape (H, W).

    Returns:
        np.ndarray: Processed image of shape (H, W), dtype float32.
    """
    img = img.astype(np.float32)
    global_median = np.median(img)
    row_medians = np.median(img, axis=1, keepdims=True)
    result = img - row_medians + global_median
    return result

def _median_column_subtraction(img):
    """
    Subtracts the median from each column and adds back the global median.

    Args:
        img (np.ndarray): Input image of shape (H, W).

    Returns:
        np.ndarray: Processed image of shape (H, W), dtype float32.
    """
    img = img.astype(np.float32)
    global_median = np.median(img)
    col_medians = np.median(img, axis=0, keepdims=True)
    result = img - col_medians + global_median
    return result

def adaptiveIQR(data:np.ndarray) -> np.ndarray:
    """
    Build a 3-channel uint8 image from the ``_adaptive_iqr`` enhancement.

    The enhanced image is min-max scaled, multiplied by 255, cast to uint8 and
    repeated three times along a new leading axis.
    """
    gray = (_minmax_scale(_adaptive_iqr(data)) * 255).astype(np.uint8)
    return np.stack([gray] * 3, axis=0)

def zscale(data:np.ndarray) -> np.ndarray:
    """
    Build a 3-channel uint8 image from the ``_zscale`` output.

    The ``_zscale`` result is multiplied by 255, cast to uint8 and repeated
    three times along a new leading axis.
    """
    gray = (_zscale(data) * 255).astype(np.uint8)
    return np.stack([gray] * 3, axis=0)

def iqr_clipped(data, threshold=5) -> np.ndarray:
    """
    Build a 3-channel uint8 image from the ``_iqr_clip`` normalisation.

    ``threshold`` is forwarded to ``_iqr_clip``; the result is min-max scaled,
    multiplied by 255, cast to uint8 and repeated along a new leading axis.
    """
    clipped = _iqr_clip(data, threshold)
    gray = (_minmax_scale(clipped) * 255).astype(np.uint8)
    return np.stack((gray, gray, gray), axis=0)

def iqr_log(data, threshold=5) -> np.ndarray:
    """
    Build a 3-channel uint8 image from the ``_iqr_log`` normalisation.

    ``threshold`` is forwarded to ``_iqr_log``; the result is min-max scaled,
    multiplied by 255, cast to uint8 and repeated along a new leading axis.
    """
    compressed = _iqr_log(data, threshold)
    gray = (_minmax_scale(compressed) * 255).astype(np.uint8)
    return np.stack((gray, gray, gray), axis=0)

def channel_mixture_A(data:np.ndarray) -> np.ndarray:
    """
    Build a 3-channel uint8 image: [scaled raw, IQR-clip contrast, zscale].

    Args:
        data: 2-D image. The ``/ 255`` scaling of the raw channel suggests
            16-bit input -- TODO confirm.

    Returns:
        uint8 array of shape (3, H, W).
    """
    zscaled = _zscale(data)
    zscaled = (zscaled * 255).astype(np.uint8)
    contrast_enhance = _iqr_clip(data)
    contrast_enhance = (_minmax_scale(contrast_enhance)*255).astype(np.uint8)

    # data / 255 reaches 256 and above for inputs >= 65280 (and is negative for
    # negative inputs). Casting those straight to uint8 wraps around, turning
    # the brightest pixels dark, so saturate to [0, 255] first.
    data = np.clip(data / 255, 0, 255).astype(np.uint8)
    return np.stack([data, contrast_enhance, zscaled], axis=0)

def channel_mixture_B(data:np.ndarray) -> np.ndarray:
    """
    Build a 3-channel uint8 image: [scaled raw, adaptive-IQR contrast, zscale].

    Args:
        data: 2-D image. The ``/ 255`` scaling of the raw channel suggests
            16-bit input -- TODO confirm.

    Returns:
        uint8 array of shape (3, H, W).
    """
    zscaled = _zscale(data)
    zscaled = (zscaled * 255).astype(np.uint8)
    contrast_enhance = _adaptive_iqr(data)
    contrast_enhance = (_minmax_scale(contrast_enhance)*255).astype(np.uint8)

    # data / 255 reaches 256 and above for inputs >= 65280 (and is negative for
    # negative inputs). Casting those straight to uint8 wraps around, turning
    # the brightest pixels dark, so saturate to [0, 255] first.
    data = np.clip(data / 255, 0, 255).astype(np.uint8)
    return np.stack([data, contrast_enhance, zscaled], axis=0)

def channel_mixture_C(data:np.ndarray) -> np.ndarray:
    """
    Build a 3-channel float32 image: [raw / 65535, IQR-log contrast, zscale].

    The contrast channel is the min-max scaled ``_iqr_log`` output and the last
    channel is the ``_zscale`` output; all three are cast to float32 and
    stacked along a new leading axis.
    """
    raw_channel = data.astype(np.float32) / 65535
    log_channel = _minmax_scale(_iqr_log(data)).astype(np.float32)
    zscale_channel = _zscale(data).astype(np.float32)
    return np.stack((raw_channel, log_channel, zscale_channel), axis=0)

def raw_file(data: np.ndarray) -> np.ndarray:
    """Divide ``data`` by 65535 and repeat the result three times along a new leading axis."""
    scaled = data / 65535
    return np.stack((scaled, scaled, scaled), axis=0)

def preprocess_image( image: npt.NDArray) -> npt.NDArray:
    """
    Scale a 2-D image by 1/65535, resize it to 512x512 and repeat it over 3 channels.

    Earlier revisions also computed ZScale limits and a multiple-of-32 padded
    size, but neither result was used; that dead work has been removed.

    Args:
        image: 2-D image array.

    Returns:
        float32 array of shape (3, 512, 512).

    Raises:
        ValueError: If ``image`` is not 2-D.
    """
    # The old code unpacked `height, width = image.shape`, which implicitly
    # rejected non-2-D input with ValueError; keep that contract explicit.
    if image.ndim != 2:
        raise ValueError(f"expected a 2-D image, got shape {image.shape}")

    image = image.astype(np.float32)/65535.0

    resized_image = cv2.resize(image, (512, 512))
    return np.stack([resized_image] * 3, axis=0)

# JSON file with the request entries handed to inference(); inference() reads a
# base64-encoded FITS payload from each entry's "file" key.
file_path = "detect-with-metadata-fc46157d-13b9-461b-ac71-7a528f419e43.json"

# Parse the file into Python objects.
with open(file_path, "r") as f:
    data = json.load(f)

inference(data)





RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/models/Sentinel_Models/Sentinel_Retina_Net.py", line 18, in forward
    raw_outputs0 = unchecked_cast(List[Dict[str, Tensor]], raw_outputs)
    output_formatter = self.output_formatter
    tranformed_outputs = (output_formatter).forward(raw_outputs0, )
                          ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    nms = self.nms
    return (nms).forward(tranformed_outputs, )
  File "code/__torch__/models/Subcomponents/Post_processing_adapters.py", line 38, in forward
        _24 = torch.view(batch_num["scores"], [1, 0])
        _25 = torch.select(storage_vector, 0, 4)
        _26 = torch.copy_(torch.slice(_25), _24)
              ~~~~~~~~~~~ <--- HERE
      else:
        pass

Traceback of TorchScript, original code (most recent call last):
  File "/home/davidchaparro/Repos/ModelTrainingExperiments/model_training/models/Sentinel_Models/Sentinel_Retina_Net.py", line 31, in forward
        if not isinstance(raw_outputs, list):
            raw_outputs = [raw_outputs]
        tranformed_outputs: torch.Tensor = self.output_formatter.forward(raw_outputs)
                                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        outputs: torch.Tensor = self.nms.forward(tranformed_outputs)
    
  File "/home/davidchaparro/Repos/ModelTrainingExperiments/model_training/models/Subcomponents/Post_processing_adapters.py", line 19, in forward
            storage_vector[3,:] = (batch_num["boxes"][:,3]-batch_num["boxes"][:,1])
            if batch_num["scores"].numel() == 0:
                storage_vector[4,:] = batch_num["scores"].view(1,0)
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            if batch_num["scores"].ndim == 1:
                storage_vector[4,:] = batch_num["scores"].unsqueeze(1)
RuntimeError: output with shape [0] doesn't match the broadcast shape [1, 0]
