In [8]:
import torch
import gzip
import numpy as np
import pandas as pd
import sys

sys.path.append("..")
from functions.prompts import Prompt
from functions.modified_predictor import modifiedPredictor
from functions.pipeline import *
from tqdm import tqdm
import pickle
from torchmetrics.classification import (
    BinaryF1Score,
    BinaryAccuracy,
    BinaryJaccardIndex,
)
from torchmetrics.functional.classification import (
    binary_accuracy,
    binary_f1_score,
    binary_jaccard_index,
)

In [2]:
# Precomputed point prompts, indexed by class label (1, 2, 3) and then by image.
# Only load pickles you trust: pickle.load can execute arbitrary code.
prompts_path = "/Users/lisa/Documents/Master/sam-lab/ACDC/prompts2.pickle"
with open(prompts_path, mode="rb") as prompts_file:
    prompts_dict = pickle.load(prompts_file)

`prompts_dict[1][0]` gives the point coordinates and labels of the class-1 prompts for the first image (the prompts were generated for a batch size of 50).

In [7]:
# Batches of (embeddings, ground_truths) from the ACDC folder, plus the predictor
# whose .predict() turns point prompts into per-class logits.
ACDC_BATCH_SIZE = 50
dataloader = get_batch("../ACDC", ACDC_BATCH_SIZE, debug=False)
mp = modifiedPredictor()

In [20]:
# Number of refinement rounds per image; each round adds one point to every class prompt.
iterations: int = 15

In [19]:
# Error-sampling refinement experiment.
# For every image, start from the precomputed masks and one-point prompts. Then, for
# `iterations` rounds:
#   1. add one point to every class prompt (Prompt.add_point_to_prompts; presumably
#      sampled from the error map built from the masks passed via give_masks);
#   2. re-predict classes 1-3 and merge the logits into a multiclass mask;
#   3. log per-prompt Dice / IoU / accuracy.
RESULTS_DIR = "/Users/lisa/Documents/Master/sam-lab/results"
MASKS_DIR = "/Users/lisa/Documents/Master/sam-lab/ACDC/masks2"
BATCH_SIZE = 50  # must match the batch size passed to get_batch
CLASSES = (1, 2, 3)
# Debug limit: only the first batch is processed (matches the previous hard `break`).
# Set to None to run the whole dataset.
MAX_BATCHES = 1


def _score_masks(new_masks, ground_truth, c):
    """Score every predicted mask against the ground truth for class ``c``.

    ``new_masks`` holds one multiclass label map per prompt. Each map is binarised
    (``== c``) and compared with the binarised ground truth (``ground_truth == c``).

    Returns a dict mapping "IOU", "Dice" and "Accuracy" to a list with one float
    per mask, rounded to 3 decimals.
    """
    preds = torch.where(new_masks == c, 1, 0).squeeze(1)
    # The target is identical for every mask of this image: build it once.
    targets = torch.where(ground_truth == c, 1, 0)
    scores = {"IOU": [], "Dice": [], "Accuracy": []}
    for pred in preds:
        # Functional metrics avoid instantiating three metric modules per mask.
        scores["IOU"].append(round(float(binary_jaccard_index(pred, targets)), 3))
        # Binary F1 on masks is the Dice coefficient.
        scores["Dice"].append(round(float(binary_f1_score(pred, targets)), 3))
        scores["Accuracy"].append(round(float(binary_accuracy(pred, targets)), 3))
    return scores


# NOTE(review): `re` shadows the stdlib module name; kept because later cells may use it.
re = Results(RESULTS_DIR, "error_sampling_0")

for i, batch in enumerate(dataloader):
    embeddings, ground_truths = batch

    # load already computed masks
    with gzip.open(f"{MASKS_DIR}/batch_{i}.npy.gz", "rb") as f:
        masks = torch.tensor(np.load(f))

    # Precomputed one-foreground-point prompts (randomly sampled), one entry per image.
    # Batch i covers images BATCH_SIZE*i ... BATCH_SIZE*i + BATCH_SIZE - 1.
    # Slicing with [i : i + 50] would give overlapping windows for i > 0.
    first_image = BATCH_SIZE * i
    batch_prompts = {
        c: prompts_dict[c][first_image : first_image + BATCH_SIZE] for c in CLASSES
    }

    # len(embeddings) rather than BATCH_SIZE: the last batch may be smaller.
    for j in tqdm(range(len(embeddings))):
        embedding = embeddings[j]
        ground_truth = ground_truths[j]

        # One Prompt per class, given the initial masks to compute error maps.
        prompts = {}
        for c in CLASSES:
            prompts[c] = Prompt(
                c, ground_truth, batch_prompts[c][j][0], batch_prompts[c][j][1]
            )
            prompts[c].give_masks(masks[j])

        for _ in range(iterations):
            for c in CLASSES:
                prompts[c].add_point_to_prompts()
            sam_prompts = {c: prompts[c].get_prompts_sam() for c in CLASSES}

            # Foreground (label 1) / background (label 0) point counts per prompt.
            n_fg = {
                c: (sam_prompts[c][1] == 1).sum(dim=1, keepdim=True) for c in CLASSES
            }
            n_bg = {
                c: (sam_prompts[c][1] == 0).sum(dim=1, keepdim=True) for c in CLASSES
            }

            # Inference only: no autograd graph is needed for evaluation.
            with torch.no_grad():
                logit_stack = torch.cat(
                    [
                        mp.predict(embedding, sam_prompts[c][0], sam_prompts[c][1])
                        for c in CLASSES
                    ],
                    dim=1,
                )
                new_masks = multiclass_prob_batched(logit_stack, hard_labels=True)

            metrics = {c: _score_masks(new_masks, ground_truth, c) for c in CLASSES}

            # The refined masks drive the next round's point sampling.
            for c in CLASSES:
                prompts[c].give_masks(new_masks)

            # One result row per prompt; column order matches the previous layout.
            results = []
            for k in range(len(n_fg[CLASSES[0]])):
                result = {"image_id": first_image + j}
                result.update({f"f_points_class_{c}": int(n_fg[c][k]) for c in CLASSES})
                result.update({f"b_points_class_{c}": int(n_bg[c][k]) for c in CLASSES})
                result.update({f"dice_class_{c}": metrics[c]["Dice"][k] for c in CLASSES})
                result.update({f"IOU_class_{c}": metrics[c]["IOU"][k] for c in CLASSES})
                result.update(
                    {f"accuracy_class_{c}": metrics[c]["Accuracy"][k] for c in CLASSES}
                )
                results.append(result)
            re.append_row(results)

    if MAX_BATCHES is not None and i + 1 >= MAX_BATCHES:
        break

  4%|▍         | 2/50 [03:31<1:24:25, 105.53s/it]


KeyboardInterrupt: 