In [1]:
# Standard library
import json
import sys
import time

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch

# Project-local: make ../src importable so the `cholec` package resolves.
# Guarded so re-running this cell does not append the same path repeatedly.
if "../src" not in sys.path:
    sys.path.append("../src")

from cholec import (
    CholecExample,
    CholecDataset,
    get_llm_generated_answer,
    isolate_individual_features,
    distill_relevant_features,
    calculate_expert_alignment_scores,
)


In [2]:
# Load the "test" split and collect its first `num_samples` items for this walkthrough.
num_samples = 3
dataset = CholecDataset(split="test")

items = []
for idx in range(num_samples):
    items.append(dataset[idx])

### Stage 0: Get LLM explanations

In [3]:
# Stage 0: ask the LLM for an explanation of each image (the helper parallelizes the calls).
# perf_counter is a monotonic clock, so the elapsed time is not affected by system-clock changes.
_t = time.perf_counter()
llm_answers = get_llm_generated_answer([item["image"] for item in items])
print(f"Elapsed: {time.perf_counter() - _t:.3f}")

Elapsed: 15.269


In [4]:
# Pair every dataset item with its LLM explanation.
# zip() silently drops trailing elements on a length mismatch, so verify the counts first.
llm_answer_list = list(llm_answers)
assert len(llm_answer_list) == len(items), (
    f"got {len(llm_answer_list)} LLM answers for {len(items)} items"
)

examples = [
    CholecExample(
        id=item["id"],
        image=item["image"],
        # One mask per label value, built by comparing the label map to each value.
        # NOTE(review): assumes organ labels are 0-2 and go/no-go labels are 0-3 — confirm against CholecDataset.
        organ_masks=[item["organs"] == label for label in range(3)],
        gonogo_masks=[item["gonogo"] == label for label in range(4)],
        llm_explanation=llm_ans,
    )
    for item, llm_ans in zip(items, llm_answer_list)
]

### Stage 1: Atomic claim extraction

In [5]:
# Stage 1: split each LLM explanation into atomic claims.
# The helper parallelizes the calls and returns a list of claim lists, one per explanation.
_t = time.perf_counter()
all_all_claims = isolate_individual_features([example.llm_explanation for example in examples])
print(f"Elapsed: {time.perf_counter() - _t:.3f}")

Elapsed: 4.357


In [6]:
# Attach each example's extracted claims.
# Check the counts first: positional indexing would raise on extra claim lists
# and silently leave trailing examples without `all_claims` on missing ones.
assert len(all_all_claims) == len(examples), (
    f"got {len(all_all_claims)} claim lists for {len(examples)} examples"
)
for example, claims in zip(examples, all_all_claims):
    example.all_claims = claims

In [7]:
# Sanity check: number of examples, number of claim lists, and claims per example.
len(examples), len(all_all_claims), list(map(len, all_all_claims))

(3, 3, [12, 13, 17])

### Stage 2: Distill relevant claims

In [8]:
# Stage 2: keep only the claims relevant to each image.
# There are several claims per example, so full parallelization won't buy *that* much;
# examples are processed sequentially here.
_t = time.perf_counter()
for example in examples:
    example.relevant_claims = distill_relevant_features(example.image, example.all_claims)
print(f"Elapsed: {time.perf_counter() - _t:.3f}")

Elapsed: 12.271


In [9]:
# Number of claims kept per example after distillation.
[len(example.relevant_claims) for example in examples]

[5, 5, 4]

### Stage 3: Calculate alignment scores

In [10]:
# Stage 3: score each relevant claim against the expert annotations, then average
# the per-claim "Alignment" values into a single score per example.
_t = time.perf_counter()
for position, ex in enumerate(examples, start=1):
    align_infos = calculate_expert_alignment_scores(ex.relevant_claims)

    # Unpack the per-claim records into parallel lists on the example.
    ex.alignable_claims = [info["Claim"] for info in align_infos]
    ex.aligned_category_ids = [info["Category ID"] for info in align_infos]
    ex.alignment_scores = [info["Alignment"] for info in align_infos]
    ex.alignment_reasonings = [info["Reasoning"] for info in align_infos]

    # Mean per-claim alignment; falls back to 0.0 when no claim could be aligned.
    # NOTE(review): that fallback is indistinguishable from a genuine zero score.
    if ex.alignment_scores:
        final_score = sum(ex.alignment_scores) / len(ex.alignment_scores)
    else:
        final_score = 0.0
    ex.final_alignment_score = final_score

    print(
        f"Elapsed {position}/{len(examples)}: "
        f"{time.perf_counter() - _t:.3f}, score {final_score:.3f}"
    )

Elapsed 1/3: 2.035, score 0.840
Elapsed 2/3: 4.380, score 0.720
Elapsed 3/3: 6.081, score 0.525


### Save to file

In [11]:
# Persist every example, including its per-claim alignment details, as JSON.
# The output directory is created first so a fresh checkout doesn't raise FileNotFoundError.
from pathlib import Path

dump_path = Path("_dump") / "cholec_dump.json"
dump_path.parent.mkdir(parents=True, exist_ok=True)
with dump_path.open("w", encoding="utf-8") as f:
    json.dump([ex.to_dict() for ex in examples], f, indent=2)