# Notebook to generate raster plots for predictions, ground truth and mismatches between them aligned with trophallaxis-likelihoods by the model

Plots generate for all videos in:
- test set
- val set




## Setup

- run1 called jumbo_graceful_mongoose on Wandb: https://wandb.ai/sposiboh/feral_public/runs/ribfpl6y
- run2 called huge_athletic_cassowary on Wandb: https://wandb.ai/sposiboh/feral_public/runs/ksmed0o3

In [13]:
import json
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import ListedColormap
import numpy as np
import os

run1 = {
    "name": "run1",
    "chunk_length": 64,
    "chunk_shift": 32
}
run2 = {
    "name": "run2",
    "chunk_length": 64,
    "chunk_shift": 16
}

RUN = run2

# whether to plot the troph-likelihood for all chunks (True) or just the final likelihood (False)
PLOT_CHUNKS = True

FOLDER = RUN["name"]
CHUNK_LENGTH = RUN["chunk_length"]
CHUNK_SHIFT = RUN["chunk_shift"]
# how often each frame can be processed by FERAL at max
FRAME_FREQ = int(CHUNK_LENGTH / CHUNK_SHIFT)

TEST_FRAME_LABELS_FERAL_JSON = f"{FOLDER}/ensembled_test.json"
TEST_PROBS_JSON = f"{FOLDER}/raw_test.json"
VAL_FRAME_LABELS_BORIS_JSON = f"{FOLDER}/feral_behavioral_labels.json"
VAL_PROBS_JSON = f"{FOLDER}/best_cp_validation_set.json"

video_dict = {}

## Parse data from validation and test set

In [14]:
def parse_test_frame_labels(json_path: str, video_dict: dict):
    """Parses the predicted labels by FERAL and puts them in the video dictionary
       for the corresponing video. Fills trophallaxis probability values with NaN

    Args:
        json_path (str): Path to file
        video_dict (dict): Dictionary which contains all the data by video
    """
    with open(json_path, "r") as f:
        preds_json = json.load(f)
    for filename, predictions in preds_json["pred"].items():
        video_dict[filename] = {
            "predictions": predictions,
            "n_frames": len(predictions),
            "troph_prob": np.full(shape=(len(predictions), FRAME_FREQ), fill_value=np.nan),
            "type": "test",
        }

    for filename, gt in preds_json["gt"].items():
        video_dict[filename]["gt"] = gt


def parse_val_frame_labels(json_path: str, video_dict: dict):
    """Parses the ground truth labels exported from BORIS and puts them in the video dictionary
       for the corresponing video. Fills trophallaxis probability values with NaN

    Args:
        json_path (str): Path to file
        video_dict (dict): Dictionary which contains all the data by video
    """
    with open(json_path, "r") as f:
        preds_json = json.load(f)
    val_videos = preds_json["splits"]["val"]
    for video in val_videos:
        gt = preds_json["labels"][video]
        video_dict[video] = {
            "n_frames": len(gt),
            "troph_prob": np.full(shape=(len(gt), FRAME_FREQ), fill_value=np.nan),
            "gt": gt,
            "type": "val"
        }


def parse_probs(json_path: str, video_dict: dict):
    """Parses the trophallaxis probability values and puts them in the video dictionary
       for the corresponing video

    Args:
        json_path (str): Path to file
        video_dict (dict): Dictionary which contains all the data by video
    """

    with open(json_path, "r") as f:
        answers_json = json.load(f)

    ind = 0
    for filename, outputs, targets in answers_json:
        name_split = filename.split("_globalind_")
        frame_str = name_split[1]
        frame_split = frame_str.split("_chunkind_")

        filename = name_split[0]
        frame_index = int(frame_split[0])
        chunk_index = int(frame_split[1])

        # new video starts in input file
        if frame_index == 0:
            ind = 0
        # new chunk starts in input file
        elif chunk_index == 0:
            ind += 1
            if ind >= FRAME_FREQ:
                ind = 0

        video_dict[filename]["troph_prob"][frame_index][ind] = outputs[1]


parse_test_frame_labels(TEST_FRAME_LABELS_FERAL_JSON, video_dict)
parse_probs(TEST_PROBS_JSON, video_dict)
parse_val_frame_labels(VAL_FRAME_LABELS_BORIS_JSON, video_dict)
parse_probs(VAL_PROBS_JSON, video_dict)

## Create plots for all videos

confs_troph contains confidence values, which FERAL predicted and gets filled with FRAME_FREQ subarrays:
- confs_troph[0] contains frame-wise conf values in chunks starting from 
    - 0,
    - CHUNK_LENGTH, 
    - 2 * CHUNK_LENGTH, 
    - ...
- confs_troph[1] contains frame-wise conf values in chunks starting from 
    - CHUNK_SHIFT, 
    - CHUNK_LENGTH + CHUNK_SHIFT,
    - 2 * CHUNK_LENGTH + CHUNK_SHIFT, 
    - ...
- confs_troph[2] contains frame-wise conf values in chunks starting from
    - 2 * CHUNK_SHIFT, 
    - CHUNK_LENGTH + 2 * CHUNK_SHIFT, 
    - 2 * CHUNK_LENGTH + 2 * CHUNK_SHIFT, 
    - ...
- ...

In [15]:
def create_plots_for_video(
    confs_troph: np.ndarray, 
    confs_troph_final: np.ndarray, 
    filename: str, 
    video_dict: dict
):
    """Saves the full plot with raster plots for predictions, ground truth and mismatches between them
       aligned with trophallaxis-likelihood values predicted by the model

    Args:
        confs_troph (np.ndarray): Frame-wise troph-likelihood values predicted by FERAL for the chunks
        confs_troph_final (np.ndarray): Final frame-wise troph-likelihood values predicted by FERAL by averaging chunk values for each frame
        filename (str): Video filename
        video_dict (dict): Dictionary which contains all the data by video
    """
    fig, axs = plt.subplots(4, 1, figsize=(16, 8), sharex=True, gridspec_kw={'height_ratios': [1, 1, 1, 5]})
    fig.text(0.5, 0.04, "Frame", ha='center', fontsize=12)
    fig.text(0.5, 0.96, f"Video: {filename}", ha='center', fontsize=14)
    arr_pred = np.array(video_dict[filename]["predictions"]).reshape(1, -1)
    arr_gt = np.array(video_dict[filename]["gt"]).reshape(1, -1)
    arr_diff = (arr_pred != arr_gt).astype(int)

    # Prepare colormaps and labels
    base_cmap = cm.get_cmap('nipy_spectral')

    class_names = ["other", "troph"]

    # Skip dark colors near 0.0 â€” start sampling from 0.05 or 0.1
    color_range = np.linspace(0.1, 1.0, len(class_names))
    colors = [base_cmap(val) for val in color_range]

    if 'other' in class_names:
        ind = list(class_names).index('other')
        colors[ind], colors[-1] = colors[-1], colors[ind]

    cmap = ListedColormap(colors)
    diff_cmap = ListedColormap(['white', 'red'])
    labels = ['Prediction', 'Ground truth', 'Mismatch']
    
    for i, arr in enumerate([arr_pred, arr_gt, arr_diff]):
        cmap_used = cmap if i < 2 else diff_cmap
        axs[i].imshow(arr, aspect='auto', cmap=cmap_used, interpolation='nearest', vmin=0, vmax=1)
        axs[i].set_yticks([0])
        axs[i].set_yticklabels([labels[i]], fontsize=12, rotation=0, va='center')
        axs[i].tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
        axs[i].tick_params(axis='y', which='both', left=False)

    x = np.arange(n_frames)
    axs[3].plot(x, confs_troph_final, label="Combined troph. likelihood", linewidth = 1.5, markersize=2, zorder=4, color='magenta')
    axs[3].plot(x, np.full(shape=(n_frames, 1), fill_value=0.5), linewidth=1, color='orange')
    axs[3].set_ylim(0, 1)
    axs[3].set_yticks(np.arange(0, 1.1, 0.1))
    axs[3].set_xticks(np.arange(0, n_frames+1, 32))
    axs[3].set_ylabel("Trophallaxis-likelihood", fontsize=12)
    axs[3].grid()

    # include chunk likelihoods in plot
    if PLOT_CHUNKS and FRAME_FREQ <= 4:
        colors_graphs = ["green", "mediumblue", "cyan", "darkred"]
        plot_labels = []
        for start in range(0, CHUNK_LENGTH, CHUNK_SHIFT):
            plot_labels.append(f"Chunks starting at {start}, {start+CHUNK_LENGTH}, ...")
        
        for i, conf_arr in enumerate(confs_troph):
            for start in range(0, len(conf_arr), CHUNK_LENGTH):
                end = start + CHUNK_LENGTH
                x = np.arange(start + i * CHUNK_SHIFT, end + i * CHUNK_SHIFT)

                # only show one label per conf_arr
                if start < CHUNK_LENGTH:
                    label = plot_labels[i]
                else:
                    label = '_nolegend_'

                # first and last chunk plot line should show in front of final troph. likelihood plot line
                if start + i * CHUNK_SHIFT == 0 or end + i * CHUNK_SHIFT >= len(conf_arr):
                    zorder = 5
                else:
                    zorder = 1
                axs[3].plot(x, conf_arr[start:end], label=label, linewidth=0.8, markersize=1, zorder=zorder, color=colors_graphs[i])

        if FRAME_FREQ == 2:
            for i, label in enumerate(axs[3].get_xticklabels()):
                label.set_color(colors_graphs[i % 2])

    axs[3].legend()

    fig.canvas.draw()

    chunk_str = "with_chunks" if PLOT_CHUNKS else "without_chunks"
    fig_path = f"{FOLDER}/plots/{video_dict[filename]['type']}_set/{chunk_str}/likelihood_analysis_{filename}.png"
    os.makedirs(os.path.dirname(fig_path), exist_ok=True)
    plt.savefig(fig_path)
    #plt.show()
    plt.close(fig)


for filename in video_dict.keys():
    n_frames = video_dict[filename]["n_frames"]

    confs_troph = np.array([None] * FRAME_FREQ)
    for i in range(FRAME_FREQ): 
        confs_troph[i] = np.array([video_dict[filename]["troph_prob"][j][i] for j in range(n_frames)])

    # calculate the final conf value
    confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)

    # replace trailing NaNs at the end with last non-NaN value
    if np.isnan(confs_troph_final[-1]):
        valid_idx = np.where(~np.isnan(confs_troph_final))[0]
        if valid_idx.size > 0:
            last_valid_value = confs_troph_final[valid_idx[-1]]
            confs_troph_final[valid_idx[-1] + 1:] = last_valid_value

    # clear nan values
    for i in range(FRAME_FREQ):
       confs_troph[i] = confs_troph[i][~np.isnan(confs_troph[i])]

    # add predicted labels by FERAL in case they don't exist in the dict (for the validation videos)
    if "predictions" not in video_dict[filename]:
        preds = confs_troph_final.round()
        video_dict[filename]["predictions"] = preds

    create_plots_for_video(
        confs_troph, 
        confs_troph_final, 
        filename, 
        video_dict
    )
    

  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('nipy_spectral')
  confs_troph_final = np.nanmean(np.stack([confs_troph[i] for i in range(FRAME_FREQ)]), axis=0)
  base_cmap = cm.get_cmap('