> Uses the predictions from the detection model to get the segmentation predictions. Tested against the following packages.

```python
qct_data: 0.2.10
qct_utils: 2.0.3
```

In [None]:
import pyrootutils

# Locate the project root by looking for a `.git` or `pyproject.toml` marker.
# `pythonpath=True` / `dotenv=True` presumably put the root on the import path (for the
# `src.*` imports below) and load the project's `.env` — see the pyrootutils docs to confirm.
root = pyrootutils.setup_root(
    search_from="",
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

import json
import os
from typing import Any, Optional

import SimpleITK as sitk
import torch
from qct_data.ct_loader import CTAnnotLoader
from qct_utils.ct_schema import ITKDIM, BigCtscan, Ctscan
from qct_utils.ctscan_dataloaders.utils import read_annotation_csvs
from tqdm import tqdm

from src.apps.segmentation_infer import NoduleSegmInfer
from src.metrics.segmentation.tp_metrics import TPMetrics

# Test against BigCT output

In [None]:
# Checkpoint handed to `NoduleSegmInfer`.
ckpt_path = "segm_qnet_8820.ckpt"
# Ground-truth annotation CSV(s), read by `read_annotation_csvs` and `CTAnnotLoader`.
gt_annot_csv = [
    "/home/users/souvik.mandal/projects/qct/qct_data_updates/data/studies/FDA/WCG/wcg_s1_s2_gt.csv"
]
# Folder holding the existing `<sid>_bigct.json` prediction files.
pred_bigct_json_root = (
    "/home/users/souvik.mandal/projects/qct/qct_meta_training_framework/data/bigct_tmp"
)
# Folder holding the `<sid>.nii.gz` scans.
scans_root = (
    "/home/users/souvik.mandal/projects/qct/qct_data_updates/data/studies/FDA/WCG/data_tmp"
)
det_thr = 0.7  # too many nodules without a threshold; keep it `None` to disable thresholding
# Folder where the re-segmented `<sid>_bigct.json` files are written and later read back.
new_bigct_save_root = (
    "/home/users/souvik.mandal/projects/qct/qct_meta_training_framework/data/wcg/model8820_post"
)

In [None]:
# Load the ground-truth annotations and display how many distinct scans they cover.
gt_df = read_annotation_csvs(gt_annot_csv)
len(
    gt_df.scan_name.unique()
)  # only these series ids are used, since predictions on the remaining scans are all FPs

In [None]:
# Series ids that appear in the ground-truth annotations; the loops below iterate over these.
considered_series_ids = gt_df.scan_name.unique().tolist()

## Update the segmentation masks

1. If the new model is already integrated with prod, we don't need to run the code below.
2. The following section reads the bboxes from the BigCT JSON, runs inference with the new model, and saves the new results to another folder.
3. If you have a large number of datapoints, copy the code below into a script and run it in a tmux session.

In [None]:
# Build the segmentation inference wrapper from the checkpoint.
# NOTE(review): the device is hard-coded to "cuda:0"; adjust it if that GPU is unavailable.
device = torch.device("cuda:0")
segmentation_infer = NoduleSegmInfer(ckpt_path=ckpt_path, device=device)

In [None]:
def load_json(path: str) -> Any:
    """Read the JSON file at `path` and return the decoded Python object.

    The file is opened explicitly as UTF-8 so that decoding does not depend on
    the platform's locale default encoding.
    """
    with open(path, "r", encoding="utf-8") as file:
        return json.load(file)


def dump_json(data: Any, save_path: str) -> None:
    """Serialize `data` as JSON into `save_path`.

    The JSON is first written to `<save_path>.tmp` and only moved onto
    `save_path` once serialization has finished. If `json.dump` raises (e.g.
    `TypeError` for a non-serializable value) or the run is interrupted, an
    already existing `save_path` is left untouched instead of being truncated
    to a partial, unreadable JSON file. The exception is propagated unchanged.
    """
    tmp_path = f"{save_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(data, file)
        # Swap the finished file into place in one step.
        os.replace(tmp_path, save_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans up failures.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

In [None]:
def load_bigct(
    big_ct_root: str,
    sid: str,
    scans_root: str,
    det_thr: Optional[float] = None,
    load_scan: bool = True,
):
    """
    Read `<big_ct_root>/<sid>_bigct.json` and return its `Pred` ctscan.

    If the stored prediction carries no scan and `load_scan` is set, the volume is
    read from `<scans_root>/<sid>.nii.gz` and attached to the returned object.
    If `det_thr` is given, only annotations whose `annot.conf` is strictly greater
    than it are kept.
    """
    json_path = os.path.join(big_ct_root, f"{sid}_bigct.json")
    pred_ct = BigCtscan(**load_json(json_path)).Pred

    if pred_ct.Scan is None and load_scan:
        nifti_path = os.path.join(scans_root, f"{sid}.nii.gz")
        pred_ct.Scan = sitk.GetArrayFromImage(sitk.ReadImage(nifti_path))

    # No threshold requested: return every annotation as stored.
    if det_thr is None:
        return pred_ct

    kept = []
    for candidate in pred_ct.Annot:
        if candidate.annot.conf > det_thr:
            kept.append(candidate)
    pred_ct.Annot = kept
    return pred_ct

In [None]:
def export_bigct(ctscan: Ctscan, save_root: str):
    """
    Wrap `ctscan` as the `Pred` of a BigCtscan and write it as JSON to
    `<save_root>/<SeriesInstanceUID>_bigct.json`.
    """
    target = os.path.join(save_root, f"{ctscan.SeriesInstanceUID}_bigct.json")
    dump_json(BigCtscan(Pred=ctscan).dict(), target)

In [None]:
# uncomment if we need to update the bigct results from prod

# for sid in tqdm(considered_series_ids[:5]):
#     pred_ct = load_bigct(big_ct_root=pred_bigct_json_root, sid=sid, scans_root = scans_root, det_thr=det_thr)
#     new_pred_ct = segmentation_infer.predict_ctscan(pred_ct, crop_margin=ITKDIM(z=15, x=80, y=80), roi_margin=ITKDIM(z=5, y=20, x=20), conf_thr=0.5, volume_thr=0)
#     export_bigct(new_pred_ct, new_bigct_save_root)

## Model output performance
1. Uses `new_bigct_save_root` to load the predicted BigCT JSONs and `gt_annot_csv` to load the ground-truth annotations.
2. Performance is computed only on the TP nodules.

In [None]:
# Ground-truth loader built from the annotation CSV(s); indexed by series id below.
gt_annot_loader = CTAnnotLoader(scans_root=scans_root, csv_loc=gt_annot_csv)

In [None]:
# Accumulate segmentation metrics one series at a time, comparing the re-segmented
# predictions in `new_bigct_save_root` against the ground truth.
metric = TPMetrics(match_annots=True)

# NOTE(review): only the first 10 series are evaluated — drop the slice to cover all of them.
for sid in tqdm(considered_series_ids[:10]):
    gt_ctscan = gt_annot_loader[sid]
    # `load_scan=False`, so no scan file is read and `scans_root` can stay empty.
    pred_ctscan = load_bigct(
        big_ct_root=new_bigct_save_root, sid=sid, scans_root="", det_thr=None, load_scan=False
    )
    metric.update(pred_ctscans=[pred_ctscan], gt_ctscans=[gt_ctscan])

In [None]:
# Display the final metrics; presumably aggregates everything passed to `metric.update` above.
metric.compute()

### Inference on single series and visualize

In [None]:
# Inference on single ctscan
from copy import deepcopy

import matplotlib.pyplot as plt
from qct_utils.ct_vis.scan_vis import ctscan_to_df_pretty, vis_ctscan_annots

In [None]:
# Load the existing prediction (confidence > 0.7, with its scan) and the ground truth
# for a single series.
series_id = "1.3.6.1.4.1.55648.0105750886814768212503752621247817.4"
pred_ctscan = load_bigct(pred_bigct_json_root, series_id, scans_root, det_thr=0.7, load_scan=True)
gt_ctscan = gt_annot_loader[series_id]

In [None]:
# Run the segmentation model on a deep copy of the ground-truth ctscan, so that
# `gt_ctscan` itself cannot be modified by the inference call.
# NOTE(review): inference is run on the ground truth, not on `pred_ctscan` — confirm intended.
new_pred_ct = segmentation_infer.predict_ctscan(
    deepcopy(gt_ctscan),
    crop_margin=ITKDIM(z=15, x=80, y=80),
    roi_margin=ITKDIM(z=5, y=20, x=20),
    conf_thr=0.5,
    volume_thr=0,
)

In [None]:
# Tabular view of the ground-truth annotations; its center/size columns are used further
# down to pick the region to plot.
tdf = ctscan_to_df_pretty(gt_ctscan)
tdf

In [None]:
# Quick sanity check: sum of the first annotation's predicted mask (non-zero means non-empty).
new_pred_ct.Annot[0].annot.mask.sum()

In [None]:
# Build the annotated visualizations for the new model output, the old prediction and the
# ground truth; each result is read through its "scan" entry when plotting.
annotated_new_pred_scan = vis_ctscan_annots(new_pred_ct)
annotated_old_pred_scan = vis_ctscan_annots(pred_ctscan)
annotated_gt_scan = vis_ctscan_annots(gt_ctscan)

In [None]:
# Plot a 100x100 crop around one ground-truth nodule on every slice in its z range
# (z_center +/- (d / 2 + 3)), side by side: new prediction | old prediction | ground truth.
index = 1  # row of `tdf` to visualize
nodule = tdf.iloc[index]
crop_size = 100

# Clamp the crop origin at 0: a negative start would be interpreted as "count from the
# end" and yield an empty or wrong crop for nodules close to the image border.
y_start = max(int(nodule["y_center"]) - crop_size // 2, 0)
x_start = max(int(nodule["x_center"]) - crop_size // 2, 0)

# Clamp the slice range to the volume so negative indices do not wrap around and indices
# past the last slice do not raise.
# NOTE(review): assumes all three annotated scans share the ground truth's slice count.
num_slices = annotated_gt_scan["scan"].shape[0]
z_first = max(int(nodule["z_center"] - 3 - nodule["d"] / 2), 0)
z_last = min(int(nodule["z_center"] + 3 + nodule["d"] / 2), num_slices)

panels = (annotated_new_pred_scan, annotated_old_pred_scan, annotated_gt_scan)
# A separate loop variable keeps `index` pointing at the selected `tdf` row.
for z in range(z_first, z_last):
    plt.figure(z)
    for column, annotated in enumerate(panels, start=1):
        plt.subplot(1, 3, column)
        plt.imshow(
            annotated["scan"][z, y_start : y_start + crop_size, x_start : x_start + crop_size]
        )