In [None]:
import os
import cv2
import monai
import numpy as np
from PIL import Image
from tqdm import tqdm
from statistics import mean
from typing import Dict, List, Tuple
from matplotlib import pyplot as plt

import torch
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn.functional import threshold, normalize

from transformers import SamProcessor
from transformers import SamModel

from finetune import resume, calc_metrics, postprocess_masks, \
    get_point_prompt_bymask, \
    PSVDataset, SAMDataset

# Pin this notebook to physical GPU 2.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
# initialised yet; torch and `finetune` are imported above, so confirm neither
# touches CUDA at import time (or move this assignment above the imports).
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Overlay colours for show_mask(): (name, RGB in [0, 1]). Every entry gets the
# same alpha, so each element of color_list is an RGBA array of length 4.
_MASK_ALPHA = 0.3
_MASK_RGB = [
    ("yellow", (1, 1, 0)),
    ("blue", (0, 0, 1)),
    ("purple", (0.5, 0, 0.5)),
    ("red", (1, 0, 0)),
    ("green", (0, 1, 0)),
    ("orange", (1, 0.5, 0)),
    ("pink", (1, 0, 1)),
    ("brown", (0.5, 0.25, 0)),
    ("gray", (0.5, 0.5, 0.5)),
    ("black", (0, 0, 0)),
    ("teal", (0, 0.5, 0.5)),
    ("navy", (0, 0, 0.5)),
]
color_list: List = [np.array([*rgb, _MASK_ALPHA]) for _name, rgb in _MASK_RGB]


def show_mask(mask, ax, mask_color=None, random_color=False):
    """Overlay a single mask on a matplotlib axis as a translucent RGBA image.

    Args:
        mask: Mask as a numpy array, torch tensor, or anything np.asarray
            accepts. Only the last two dimensions (H, W) are used; the mask
            must hold exactly H*W elements (e.g. shape (H, W) or (1, H, W)).
            Zero pixels come out fully transparent.
        ax: Matplotlib axis to draw on (only its ``imshow`` is called).
        mask_color: Explicit RGBA colour (4 values, array or sequence). Takes
            precedence over ``random_color``.
        random_color: If True and no ``mask_color`` is given, use a random RGB
            colour with alpha 0.6; otherwise fall back to a fixed blue, alpha 0.6.
    """
    if mask_color is not None:
        # asarray so plain lists/tuples work as well as numpy arrays.
        color = np.asarray(mask_color, dtype=float)
    elif random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    if isinstance(mask, torch.Tensor):
        # detach() so masks that still carry autograd history can be converted.
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)

    h, w = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


In [None]:
# Load the pretrained SAM ViT-H model and processor, then restore the
# fine-tuned weights from this experiment's latest checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckp_path = 'work_dirs/sam_psv/sam_loss_full/'
resume_path = os.path.join(ckp_path, 'latest.pth')

# Model and processor must come from the same hub checkpoint.
sam_hub_id = "facebook/sam-vit-huge"
model = SamModel.from_pretrained(sam_hub_id, mirror='tuna')
processor = SamProcessor.from_pretrained(sam_hub_id, mirror='tuna')
model, _, _ = resume(model, None, resume_path, device)
# Assign instead of leaving `model.to(device)` as the last expression, which
# would dump the full multi-hundred-line module repr into the notebook output.
model = model.to(device)


In [None]:
# Qualitative check: run the fine-tuned SAM on a slice of the test split using
# point prompts derived from each ground-truth mask, and show
# raw image | GT overlay | prediction overlay side by side.
model.eval()
# NOTE(review): iou/dsc are never filled in this notebook — presumably meant
# for calc_metrics results; confirm or drop.
iou = []
dsc = []
i_max = 20      # number of test samples to visualise
start_idx = 20  # index of the first visualised sample in the test split
test_dataset = PSVDataset(split='test')
for idx in range(i_max):
    item = test_dataset[idx + start_idx]
    image, gt_mask = item['image'], item['label']

    # Prompts are built from the GT mask by get_point_prompt_bymask (finetune module).
    prompts = get_point_prompt_bymask(gt_mask)
    inputs = processor(image, **prompts, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, multimask_output=False)

    # Resize the low-resolution mask logits using the processor's reshaped and
    # original sizes, then binarise at probability 0.5.
    upscaled_masks = postprocess_masks(
        outputs.pred_masks.squeeze(1),
        inputs["reshaped_input_sizes"][0].tolist(),
        inputs["original_sizes"][0].tolist())
    pred_mask = torch.sigmoid(upscaled_masks).squeeze().cpu().numpy()
    pred_mask = (pred_mask > 0.5).astype(np.uint8)
    # Optional (currently disabled): clean up speckles with cv2.erode / cv2.dilate
    # using a 3x3 kernel before plotting.

    # figsize is (width, height): a wide figure for one row of three panels.
    fig, axes = plt.subplots(1, 3, figsize=(30, 10))
    for ax, title in zip(axes, ("image", "ground truth", "prediction")):
        ax.imshow(image)
        ax.set_title(title)
    show_mask(gt_mask, axes[1], mask_color=color_list[0])
    show_mask(pred_mask, axes[2], mask_color=color_list[1])
    plt.show()
    # Free the figure so a long loop does not accumulate open figures.
    plt.close(fig)


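## Plot Accuracy Curves

The first plot shows the test mIoU logged every 5 epochs. The second shows the mIoU gain over each 5-epoch interval.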

In [None]:
# Self-contained cell: re-imports so the curve can be plotted without loading the model above.
import re
import numpy as np
import matplotlib.pyplot as plt

# mIoU logged every EPOCH_STEP epochs, starting at epoch 0.
accu_log = '21.73 41.20 68.45 72.60 74.59 75.96 76.94 77.66 78.28 78.59 79.10 79.45 79.75 79.85 80.10 80.32 80.53 80.74 80.76 80.77 80.81'
EPOCH_STEP = 5
accu = [float(v) for v in re.findall(r'(\d+\.\d+)', accu_log)]
# Derive the x-axis from the number of logged values so the lengths always match.
epochs = np.arange(len(accu)) * EPOCH_STEP

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(epochs, accu, color="black")
for idx, (epoch, y) in enumerate(zip(epochs, accu)):
    # Green dashed guide: horizontal, from the first epoch to this point.
    ax.plot(epochs[:idx + 1], [y] * (idx + 1),
            "--", lw=1, color="green", alpha=0.7)
    # Red dashed guide: vertical at this epoch, through every value logged so far.
    ax.plot([epoch] * (idx + 1), accu[:idx + 1],
            "--", lw=1, color="red", alpha=0.7)
ax.scatter(epochs, accu, color="orange", s=50)

ax.set_xlim(0, epochs[-1] + EPOCH_STEP)
ax.set_ylim(min(accu), 82)
ax.set_yticks([min(accu)] + list(range(30, 90, 10)))
ax.tick_params(axis="both", labelsize=14)
ax.set_xlabel('epochs', fontsize=16)
ax.set_ylabel('mIoU', fontsize=16)
# Open-frame look: hide the top/right spines and keep ticks only on bottom/left.
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

plt.show()

In [None]:
# Self-contained cell: plots the mIoU gain over each logging interval.
import re
import numpy as np
import matplotlib.pyplot as plt

# Same log as the cell above without its epoch-0 entry, i.e. mIoU at epochs 5, 10, ..., 100.
accu_log = '41.20 68.45 72.60 74.59 75.96 76.94 77.66 78.28 78.59 79.10 79.45 79.75 79.85 80.10 80.32 80.53 80.74 80.76 80.77 80.81'
EPOCH_STEP = 5
FIRST_EPOCH = 5
accu = [float(v) for v in re.findall(r'(\d+\.\d+)', accu_log)]
# Gain over each interval, placed at the epoch where the interval ends:
# diff at epoch e = mIoU(e) - mIoU(e - EPOCH_STEP). No padding value is appended.
diff = np.diff(accu).tolist()
epochs = FIRST_EPOCH + EPOCH_STEP * np.arange(1, len(accu))
# Round to the log's precision to hide float subtraction noise.
print([round(d, 2) for d in diff])

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(epochs, diff, color="black")
for epoch, y in zip(epochs, diff):
    # Green dashed guide: horizontal from the y-axis (x = 0) to this point.
    ax.plot([0, epoch], [y, y], "--", lw=1, color="green", alpha=0.7)
    # Red dashed guide: vertical from zero gain up to this point.
    ax.plot([epoch, epoch], [0, y], "--", lw=1, color="red", alpha=0.7)
ax.scatter(epochs, diff, color="orange", s=50)

# Keep 0 in view so the red guides (which start at zero) are fully visible,
# even if some interval has a negative gain.
ax.set_ylim(min(0, min(diff)), 30)
ax.set_xlim(0, epochs[-1])
ax.tick_params(axis="both", labelsize=14)
ax.set_xlabel('epochs', fontsize=16)
ax.set_ylabel('diff-mIoU', fontsize=16)
# Open-frame look: hide the top/right spines and keep ticks only on bottom/left.
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.xaxis.tick_bottom()
ax.yaxis.tick_left()

plt.show()