# This notebook computes statistics for the videos in the Konstanz dataset based on the results with FERAL

## Add labels to results based on prediction values

Also invert structure at first level (results[preds][video] --> results[video][preds])

In [69]:
import json
from copy import deepcopy

def add_labels_to_results(
    input_path,
    output_path,
):
    """Derive per-frame labels from raw predictions and regroup them by video.

    Reads the JSON file at ``input_path``, takes its ``"preds"`` mapping
    (video name -> list of ``[other_pred, troph_pred]`` pairs) and builds
    ``{video: {"pred_labels": [...], "preds": [...]}}``. The result is
    written to ``output_path`` as JSON and returned.

    Each label is ``round(troph_pred)``. Python rounds halves to the nearest
    even integer, so a prediction of exactly 0.5 becomes label 0.
    NOTE(review): presumably the predictions are probabilities in [0, 1];
    confirm against the code that produces the input file.

    Args:
        input_path: Path of the JSON file holding the raw predictions.
        output_path: Path the regrouped JSON is written to (overwritten).

    Returns:
        The regrouped dictionary that was written to ``output_path``.
    """
    # JSON text is UTF-8 by specification. Being explicit keeps reading and
    # writing independent of the platform's locale encoding, which matters
    # because ensure_ascii=False emits non-ASCII characters verbatim.
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    output_data = {
        video: {
            "pred_labels": [round(troph_pred) for _other_pred, troph_pred in video_preds],
            "preds": video_preds,
        }
        for video, video_preds in data["preds"].items()
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)

    return output_data


# Input file with the raw predictions (read by add_labels_to_results).
PREDS_JSON = "inference_preds_all_clips.json"
# Output file for the per-video results; later cells write to it again.
OUTPUT_JSON = "inference_results_all_clips.json"

# Per-video results ({video: {"pred_labels": ..., "preds": ...}}) that the
# following cells keep extending.
results_json = add_labels_to_results(
    input_path=PREDS_JSON,
    output_path=OUTPUT_JSON,
)


## Add run lengths

In [70]:
# Run lengths (in frames) of every video processed so far; appended to by
# add_run_lengths_to_results and read later for the overall statistics.
all_run_lengths = []


def add_run_lengths_to_results(data, output_path):
    """Attach the lengths of all runs of label 1 to every video entry.

    For each video, consecutive frames whose predicted label equals 1 form
    one run. The run lengths (in frames, in order of appearance) are stored
    under ``data[video]["run_lengths"]``.

    Side effects:
        * ``data`` is modified in place.
        * The run lengths of all videos are appended to the module-level
          ``all_run_lengths`` list, so calling this twice without resetting
          that list counts every run twice.
        * The extended ``data`` is written to ``output_path`` as JSON.

    Args:
        data: Mapping ``video -> entry`` where each entry has a
            ``"pred_labels"`` list.
        output_path: Path of the JSON file to (over)write.

    Returns:
        The same ``data`` object, now including ``"run_lengths"``.
    """
    # Local import so the cell does not rely on imports of other cells.
    from itertools import groupby

    run_lengths = []
    for video_data in data.values():
        # groupby yields one group per stretch of equal consecutive labels;
        # only the stretches of 1s are runs of interest.
        run_lengths_vid = [
            sum(1 for _ in run)
            for label, run in groupby(video_data["pred_labels"])
            if label == 1
        ]
        video_data["run_lengths"] = run_lengths_vid
        run_lengths.extend(run_lengths_vid)

    all_run_lengths.extend(run_lengths)

    # Explicit UTF-8: ensure_ascii=False writes non-ASCII characters verbatim,
    # which must not depend on the platform's locale encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    return data

# Extend every video entry with its run lengths; the call also writes the
# extended data to OUTPUT_JSON and appends to all_run_lengths.
results_json = add_run_lengths_to_results(
    results_json,
    output_path=OUTPUT_JSON,
)

## Clean runs of short events

Three different thresholds for the minimum troph. run length are used:
- 200ms (10 frames)
- 1s (50 frames)
- 4s (200 frames)

The 200 ms results are used as the baseline because any run shorter than this is very likely too short to be trophallaxis.

In [None]:
import numpy as np

def remove_short_troph_runs(results, min_run_length, output_path):
    """Replace runs of 1s shorter than ``min_run_length`` with 0s.

    Works on a deep copy, so ``results`` itself is left untouched. For every
    video the cleaned entry gets:

    * ``"pred_labels"``: labels with the short runs zeroed out
    * ``"n_runs_removed"``: number of runs that were removed
    * ``"n_runs_removed_gte_10"``: removed runs that were at least 10 frames
    * ``"n_frames_removed"``: total number of frames set to 0
    * ``"n_troph_runs"``: number of runs that were kept
    * ``"run_lengths"``: lengths of the kept runs. If the entry already has
      ``"run_lengths"``, that list is filtered; otherwise the lengths found
      while scanning the labels are used.

    Args:
        results: Mapping ``video -> entry`` with a ``"pred_labels"`` list.
        min_run_length: Minimum length (in frames) a run needs to be kept.
        output_path: Path of the JSON file the cleaned results are written to.

    Returns:
        The cleaned deep copy of ``results``.
    """
    results_cleaned = deepcopy(results)
    for data in results_cleaned.values():
        # Already a private copy thanks to deepcopy, so it can be edited in
        # place. Zeroing happens only after a run was scanned completely.
        labels = data["pred_labels"]
        n = len(labels)
        i = 0
        n_runs_removed = 0
        n_runs_removed_gte_10 = 0
        n_frames_removed = 0
        kept_run_lengths = []

        while i < n:
            if labels[i] != 1:
                i += 1
                continue

            # Walk to the end of the current run of 1s.
            start = i
            while i < n and labels[i] == 1:
                i += 1
            run_length = i - start

            if run_length >= min_run_length:
                kept_run_lengths.append(run_length)
                continue

            n_runs_removed += 1
            n_frames_removed += run_length
            if run_length >= 10:
                n_runs_removed_gte_10 += 1
            labels[start:i] = [0] * run_length

        data["n_runs_removed"] = n_runs_removed
        data["n_runs_removed_gte_10"] = n_runs_removed_gte_10
        data["n_frames_removed"] = n_frames_removed
        data["n_troph_runs"] = len(kept_run_lengths)
        if "run_lengths" in data:
            data["run_lengths"] = [
                run_length
                for run_length in data["run_lengths"]
                if run_length >= min_run_length
            ]
        else:
            # No precomputed run lengths: use the ones found by the scan.
            data["run_lengths"] = kept_run_lengths

    # Explicit UTF-8: ensure_ascii=False writes non-ASCII characters verbatim,
    # which must not depend on the platform's locale encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results_cleaned, f, ensure_ascii=False, indent=2)
    return results_cleaned

# One output file per minimum run length.
JSON_200MS = "inference_results_all_clips_above_200ms.json"
JSON_1S = "inference_results_all_clips_above_1s.json"
JSON_4S = "inference_results_all_clips_above_4s.json"

# Thresholds are given in frames: 10 = 200 ms, 50 = 1 s, 200 = 4 s.
# Each call starts from the uncleaned results_json, so the three result
# sets are independent of each other.
# use this as the baseline as clips below 10 frames are too short
results_above_200ms = remove_short_troph_runs(results_json, 10, JSON_200MS)
results_above_1s = remove_short_troph_runs(results_json, 50, JSON_1S)
results_above_4s = remove_short_troph_runs(results_json, 200, JSON_4S)




## Compute statistics for the different experiment series

For each group the following stats are calculated/given
| **Stat** | **Description** |
| ---------- | --------------- |
| frames_orig_used | Frame count of the original videos that were used to extract the clips |
| total_frames | Total frames in all clips combined |
| rel_prop | Relative proportion: (troph. frames in group / total frames of the original videos of the experiment group) * (total frames of the original videos of the control group (hex) / troph. frames in control group (hex)), i.e. the group's share of troph. frames divided by the control group's share; it is 1 for hex |


Additionally the following stats are calculated for the different minimum run lengths of the trophallaxis (for 200ms, 1s and 4s). They are explained here only for the 1s variant but the 200ms and 4s variants are equivalent, just calculated with different thresholds for the run lengths
| **Stat** | **Description** |
| ---------- | --------------- |
| troph_frames_above_1s | Number of troph. frames remaining if all runs of trophallaxis below a duration of 1 sec. get removed |
| n_runs_removed_1s | Number of runs below a duration of 1 sec. that get removed |
| n_runs_removed_gte_10_1s | Number of runs between 0.2 sec. (10 frames) and 1 sec. that get removed |
| n_troph_runs_1s | Total count of troph. sequences of >= 1 sec. |
| rel_prop_1s | same as rel. prop stat but only for runs above 1 sec. |
|||
| mean_run_length_1s | Mean of troph. run lengths |
| stdev_run_length_1s | Standard deviation of troph. run lengths |
| variance_run_length_1s | Variance of troph. run lengths |
| median_run_length_1s | Median of troph. run lengths |

Mean, standard deviation, variance and median are also calculated for the total set.




In [72]:
import statistics

# Experiment series; a video is assigned to the group its name starts with.
groups = ["hex", "OCI", "OLE"]
# (label, cleaned result set) pairs, one per minimum run length.
result_sets_tuples = [("200ms", results_above_200ms), ("1s", results_above_1s), ("4s", results_above_4s)]
# group -> {stat name -> value}. The insertion order of the keys is the order
# in which save_stats later writes the lines.
stats = {}

# Start every per-group counter at zero.
for group in groups:
    stats[group] = {
        "total_frames": 0,
        "run_lengths": []
    }
    # Only the label is needed here; result_set is unused in this loop.
    for time_str, result_set in result_sets_tuples:
        stats[group][f"troph_frames_above_{time_str}"] = 0
        stats[group][f"n_runs_removed_{time_str}"] = 0
        stats[group][f"n_runs_removed_gte_10_{time_str}"] = 0
        stats[group][f"n_troph_runs_{time_str}"] = 0

# Frames from all original videos that were used, summed up by group (the first
# five minutes of a video were never used and some of the 20 videos per group
# were not used). These are hand-entered constants, not computed here.
stats["hex"]["frames_orig_used"] = 1087622      # 18 vids (missing 6, 13)
stats["OCI"]["frames_orig_used"] = 1257284      # 20 vids
stats["OLE"]["frames_orig_used"] = 777168       # 13 vids (missing 4, 9, 13, 17, 18, 19, 20)


def save_stats(stats, file):
    """Write all statistics to ``file`` and echo them on stdout.

    Every group starts with a header line (a dedicated one for the key
    ``"all"``), followed by one ``name: value`` line per statistic. On stdout
    each group is additionally followed by blank lines; the file gets no such
    separator.

    Args:
        stats: Mapping ``group -> {stat name -> value}``.
        file: Writable text file object.
    """
    for group_name, entries in stats.items():
        header = (
            " ----- Overall statistics -----"
            if group_name == "all"
            else f" ----- Group {group_name} statistics -----"
        )
        report_lines = [header]
        report_lines.extend(f"{name}: {val}" for name, val in entries.items())

        for line in report_lines:
            file.write(line + "\n")
            print(line)
        print("\n")


def update_group_stats(video):
    """Add the numbers of one video to the statistics of its group.

    The group is the first entry of ``groups`` that ``video`` starts with.
    Frame count and run lengths come from the 200 ms result set; the cleaning
    counters are accumulated for every result set in ``result_sets_tuples``.

    Args:
        video: Name of the video (key of the result sets).

    Raises:
        ValueError: If the name starts with none of the known groups.
    """
    matching_group = next((name for name in groups if video.startswith(name)), None)
    if matching_group is None:
        raise ValueError("A video comes from none of the defined experiment series!")

    group_stats = stats[matching_group]

    # general stats
    baseline_entry = results_above_200ms[video]
    group_stats["total_frames"] += len(baseline_entry["pred_labels"])
    group_stats["run_lengths"].extend(baseline_entry["run_lengths"])

    # stats about the troph cleaning
    for time_str, result_set in result_sets_tuples:
        entry = result_set[video]
        group_stats[f"troph_frames_above_{time_str}"] += sum(entry["pred_labels"])
        group_stats[f"n_runs_removed_{time_str}"] += entry["n_runs_removed"]
        group_stats[f"n_runs_removed_gte_10_{time_str}"] += entry["n_runs_removed_gte_10"]
        group_stats[f"n_troph_runs_{time_str}"] += entry["n_troph_runs"]


def _summarize_run_lengths(run_lengths_secs, suffix=""):
    """Return mean/stdev/variance/median of ``run_lengths_secs`` as a dict.

    Keys are ``mean_run_length``, ``stdev_run_length``, ``variance_run_length``
    and ``median_run_length``, each followed by ``suffix``. Statistics that are
    undefined for the given amount of data (mean/median of no values,
    stdev/variance of fewer than two values) are reported as NaN instead of
    raising ``statistics.StatisticsError``.
    """
    nan = float("nan")
    has_values = len(run_lengths_secs) >= 1
    has_spread = len(run_lengths_secs) >= 2
    return {
        f"mean_run_length{suffix}": statistics.mean(run_lengths_secs) if has_values else nan,
        f"stdev_run_length{suffix}": statistics.stdev(run_lengths_secs) if has_spread else nan,
        f"variance_run_length{suffix}": statistics.variance(run_lengths_secs) if has_spread else nan,
        f"median_run_length{suffix}": statistics.median(run_lengths_secs) if has_values else nan,
    }


def calc_run_length_stats():
    """Compute run length statistics (in seconds) overall and per group.

    * ``stats["all"]`` is replaced by the statistics over every entry of
      ``all_run_lengths`` that is at least 10 frames long.
    * For every group in ``groups`` the statistics over
      ``stats[group]["run_lengths"]`` are stored three times: for all of these
      runs (suffix ``_200ms``), for runs of at least 1 s (``_1s``) and for
      runs of at least 4 s (``_4s``).

    Frames are converted to seconds by dividing by 50. A group with too few
    runs for a statistic gets NaN for it rather than aborting the whole
    computation.
    """
    frames_per_second = 50
    min_frames = 10  # 200 ms baseline

    all_run_lengths_secs = [
        frames / frames_per_second for frames in all_run_lengths if frames >= min_frames
    ]
    stats["all"] = _summarize_run_lengths(all_run_lengths_secs)

    # (suffix label, minimum duration in seconds); None means "no extra filter"
    # because the per-group run lengths are taken as they are for the baseline.
    thresholds = (("200ms", None), ("1s", 1), ("4s", 4))

    for group in groups:
        # convert run lengths to seconds
        group_run_lengths_secs = [
            frames / frames_per_second for frames in stats[group]["run_lengths"]
        ]
        for label, min_secs in thresholds:
            if min_secs is None:
                selected = group_run_lengths_secs
            else:
                selected = [secs for secs in group_run_lengths_secs if secs >= min_secs]
            stats[group].update(_summarize_run_lengths(selected, f"_{label}"))


def calc_rel_prop():
    """Store the relative trophallaxis proportion of every group.

    For each label in ``result_sets_tuples`` the key ``rel_prop_<label>`` is
    set in ``stats[group]``: 1 for the group ``"hex"``, otherwise the group's
    troph. frames per original frame, scaled by the original frames per
    troph. frame of ``"hex"``.
    """
    labels = [time_str for time_str, _ in result_sets_tuples]
    for group in groups:
        group_stats = stats[group]
        for label in labels:
            if group == "hex":
                value = 1
            else:
                troph_key = f"troph_frames_above_{label}"
                value = (
                    group_stats[troph_key] / group_stats["frames_orig_used"]
                    * stats["hex"]["frames_orig_used"] / stats["hex"][troph_key]
                )
            group_stats[f"rel_prop_{label}"] = value

# Accumulate the per-video numbers into the per-group statistics.
for video in results_json.keys():
    update_group_stats(video)

# Both calls read and extend the module-level stats dictionary; run them only
# after all videos have been accumulated above.
calc_run_length_stats()
calc_rel_prop()

# Write the report to a text file; save_stats also prints it.
with open("feral_inference_stats.txt", "w") as f:
    save_stats(stats, f)
    
        




 ----- Group hex statistics -----
total_frames: 305864
run_lengths: [184, 162, 69, 215, 254, 96, 105, 145, 336, 58, 190, 168, 66, 32, 37, 199, 254, 176, 320, 260, 202, 392, 247, 248, 29, 138, 52, 79, 416, 16, 62, 70, 80, 127, 444, 480, 375, 270, 16, 192, 160, 686, 410, 187, 100, 288, 287, 586, 205, 67, 25, 251, 102, 155, 158, 281, 272, 208, 10, 184, 41, 14, 11, 631, 169, 67, 352, 123, 66, 54, 626, 155, 197, 325, 22, 196, 125, 109, 273, 506, 114, 302, 253, 131, 259, 450, 910, 105, 353, 311, 279, 40, 217, 48, 521, 519, 256, 818, 123, 295, 368, 315, 270, 285, 95, 58, 377, 74, 237, 45, 132, 439, 192, 272, 48, 521, 19, 64, 400, 248, 53, 307, 327, 196, 263, 135, 244, 234, 142, 118, 139, 368, 69, 39, 255, 97, 96, 290, 2295, 183, 318, 53, 273, 340, 734, 65, 80, 258, 28, 224, 537, 11, 134, 146, 82, 199, 414, 131, 238, 183, 112, 301, 66, 13, 115, 158, 213, 226, 69, 348, 209, 149, 211, 309, 71, 331, 48, 150, 27, 81, 182, 31, 426, 53, 240, 129, 10, 270, 121, 119, 191, 108, 448, 593, 65, 82, 263, 3