In [1]:
import time
import matplotlib.pyplot as plt
import numpy as np
import torch

import sys; sys.path.append("../src")
from cholec import \
    CholecExample, \
    CholecDataset, \
    get_llm_generated_answer, \
    isolate_individual_features, \
    distill_relevant_features, \
    calculate_expert_alignment_score

  from pandas.core import (


In [2]:
dataset = CholecDataset(split="test")
num_samples = 5
items = [dataset[i] for i in range(5)]

### Stage 0: Get LLM explanations

In [3]:
_t = time.time()
# Parallelized calls to the LLM
llm_answers = get_llm_generated_answer([item["image"] for item in items])
print(f"Elapsed: {time.time() - _t:.3f}")

Elapsed: 22.340


In [4]:
examples = [
    CholecExample(
        id=item["id"],
        image=item["image"],
        organ_masks=[item["organs"] == i for i in [0, 1, 2]],
        gonogo_masks=[item["gonogo"] == i for i in [0, 1, 2, 3]],
        llm_explanation = llm_ans
    )
    for (item, llm_ans) in zip(items, llm_answers)
]

### Stage 1: Atomic claim extraction

In [5]:
_t = time.time()
# Parallelized calls to get the list of list of claims
all_all_claims = isolate_individual_features([example.llm_explanation for example in examples])
print(f"Elapsed: {time.time() - _t:.3f}")

Elapsed: 8.948


In [6]:
for i in range(len(all_all_claims)):
    examples[i].all_claims = all_all_claims[i]

In [7]:
len(examples), len(all_all_claims), [len(ac) for ac in all_all_claims]

(5, 5, [14, 15, 1, 19, 15])

### Stage 2: Distill relevant claims

In [8]:
# There's several claims per example, so ful parallelization won't buy **that** much.
_t = time.time()
for example in examples:
    example.relevant_claims = distill_relevant_features(example.image, example.all_claims)
print(f"Elapsed: {time.time() - _t:.3f}")

Elapsed: 32.740


In [9]:
[len(ex.relevant_claims) for ex in examples]

[9, 9, 0, 15, 11]

### Stage 3: Calculate alignment scores

In [10]:
_t = time.time()
for i, ex in enumerate(examples):
    aligns_info = calculate_expert_alignment_score(ex.relevant_claims)
    for info in aligns_info:
        try:
            if all([k in info.keys() for k in ["Category ID", "Category", "Alignment", "Reasoning"]]):
                ex.alignment_categories.append((info["Category ID"], info["Category"]))
                ex.alignment_scores.append(info["Alignment"])
                ex.alignment_reasonings.append(info["Reasoning"])
    
    
                if len(ex.alignment_scores) > 0:
                    ex.final_alignment_score = sum(ex.alignment_scores) / len(ex.alignment_scores)
            raise ValueError(f"Insufficient keys: {info.keys()}")
        except Exception as e:
            print(f"\tExplanation {i}, with exception {e}")

    print(f"Elapsed {i+1}/{len(examples)}: {time.time() - _t:.3f}")

	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Category ID', 'Alignment', 'Reasoning'])
	Explanation 0, with exception Insufficient keys: dict_keys(['Category', 'Catego

In [11]:
for ex in examples:
    print(ex.final_alignment_score)

0.6888888888888889
0.7777777777777778
0
0.7200000000000001
0.7999999999999999
