In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import cv2
from collections import defaultdict
import csv
import os

# Execute the helper notebook in this kernel's namespace; presumably this is
# what defines `experiment_chooser` (verify against utils/experiment_chooser.ipynb).
# `choose()` is called here and the selection is read back in a later cell via
# `experiment_chooser.fetch()` / `fetch_base()`.
%run utils/experiment_chooser.ipynb
experiment_chooser.choose()

In [None]:
def track(labels_arr, tracking_arr):
    """Collect the per-timepoint label of every track present in the first frame.

    Args:
        labels_arr: Sequence of per-timepoint label arrays, iterated in
            lockstep with `tracking_arr` and indexed with positions taken
            from the matching tracking array.
        tracking_arr: Sequence of per-timepoint arrays of track ids, where 0
            means "no track".

    Returns:
        dict mapping each positive track id found in ``tracking_arr[0]`` to
        the list of labels sampled at that id's first pixel (C order) in each
        consecutive timepoint. A track stops growing the first time its id is
        missing from a frame; later reappearances are ignored. Ids that first
        appear after frame 0 are never recorded. An empty stack yields ``{}``.
    """
    if len(tracking_arr) == 0:
        return {}

    lines = {track_id: [] for track_id in np.unique(tracking_arr[0]) if track_id > 0}

    for timepoint, (labels, tracking) in enumerate(zip(labels_arr, tracking_arr)):
        tracking = np.asarray(tracking)
        # One vectorised pass gives every id in the frame together with the
        # flat index of its first occurrence, replacing a Python loop over
        # every non-zero pixel.
        track_ids, first_flat = np.unique(tracking.ravel(), return_index=True)

        for track_id, flat_idx in zip(track_ids, first_flat):
            if track_id == 0 or track_id not in lines:
                continue
            # Only extend tracks that have been seen in every frame so far.
            if len(lines[track_id]) != timepoint:
                continue
            idx = np.unravel_index(flat_idx, tracking.shape)
            lines[track_id].append(labels[idx])

    return lines

def get_wellspec(label, wells):
    """Return the first entry of `wells` whose `.label` equals `label`.

    Returns None when no entry matches.
    """
    matching = (candidate for candidate in wells if candidate.label == label)
    return next(matching, None)
        
def format_row(well, line, line_id, tps):
    """Build one survival-table row for a single tracked cell.

    Args:
        well: Well spec exposing `.label` and `.drugs` (each drug exposing
            `.drug_label`).
        line: Per-timepoint labels of the track; 1 is treated as live and 2
            as dead.
        line_id: Identifier written to the "id" and "well-id" columns.
        tps: Number of timepoints a fully observed track spans.

    Returns:
        dict keyed by the survival CSV columns, or None when the track is
        empty or does not start out live.
    """
    # only look at live cells
    if len(line) == 0 or line[0] != 1:
        return None

    # observed death: index of the first dead label, -1 if the cell never dies
    dead_at = next((idx for idx, label in enumerate(line) if label == 2), -1)

    # `censored` is always the int 0/1 so the CSV column has a single encoding
    # (previously one branch wrote True, producing a mix of "0"/"1"/"True").
    censored = 0
    if dead_at != -1:
        # censor zombies: the cell is labelled live again after it was labelled dead
        if 1 in line[dead_at:]:
            censored = 1
    elif len(line) < tps:
        # unobserved death (segmentation fails): the track ends early without a death
        censored = 1

    death_cause = 'death' if dead_at != -1 else 'NA'
    # NOTE(review): `event` mirrors `censored` rather than being its complement;
    # confirm this matches what the downstream survival analysis expects.
    event = bool(censored)
    # NOTE(review): a track lost before `tps` still reports last_tp == tps rather
    # than its own length; confirm this is intended.
    last_tp = tps if dead_at == -1 else dead_at + 1 # 1-indexed
    last_time = last_tp

    group = "-".join(drug.drug_label for drug in well.drugs)
    if group == "":
        group = "NA"

    row = {
        "well": well.label,
        "id": line_id,
        "well-id": f"{well.label}-{line_id}",
        "group": group,
        "cell_type": "NA",
        "drug": "NA",
        "drug_conc": "NA",
        "column": "NA",
        "last_tp": last_tp,
        "last_time": last_time,
        "death_cause": death_cause,
        "censored": censored,
        "event": event
    }

    return row

In [None]:
# Column order of the survival CSV; these are the keys of the dicts built by format_row.
fieldnames = ['well','id','well-id','group','cell_type','drug','drug_conc','column','last_tp','last_time','death_cause','censored','event']
experiment = experiment_chooser.fetch()
experiment_base = experiment_chooser.fetch_base()
output_dir = experiment_base / "results"
os.makedirs(output_dir, exist_ok=True)

# Sorted so the CSV row order is reproducible regardless of filesystem listing order.
label_paths = sorted((experiment_base / "processed_imgs" / "object_predictions").glob("*.h5"))
track_paths = sorted((experiment_base / "processed_imgs" / "tracking").glob("*.h5"))

# Pair every label file with the tracking file of the same name; files without
# a counterpart are ignored.
track_path_by_name = {track_e.name: track_e for track_e in track_paths}
label_track_paired = [
    (label_e, track_path_by_name[label_e.name])
    for label_e in label_paths
    if label_e.name in track_path_by_name
]

with open(output_dir / "survival_data.csv", 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for label_path, track_path in label_track_paired:
        # The well label is the part of the file name before the first "-".
        well_label = label_path.name.split("-")[0]
        well = get_wellspec(well_label, experiment.mfile.wells)
        if well is None:
            # format_row needs well.label / well.drugs, so report and skip
            # instead of failing with an AttributeError part-way through the export.
            print(f"Skipping {label_path.name}: no well labelled {well_label!r} in the experiment")
            continue

        # Open both files in a context manager so the HDF5 handles are closed
        # once the tracks have been read.
        with h5py.File(label_path, "r") as label_file, h5py.File(track_path, "r") as track_file:
            labels = label_file["exported_data"]
            tracks = track_file["exported_data"]
            tracking_lines = track(labels, tracks)

        # Longest track in this file; default=0 keeps a file with no tracks
        # from raising on an empty max().
        tps = max(map(len, tracking_lines.values()), default=0)
        # NOTE(review): the "id" column is the enumeration index of the track,
        # not the tracking id used as the key of tracking_lines.
        rows = [format_row(well, line, idx, tps) for idx, line in enumerate(tracking_lines.values())]
        rows = [row for row in rows if row is not None]
        writer.writerows(rows)