In [1]:
import torch
import torchvision.transforms as T

import PIL

import os
import random
import cv2
import numpy as np

# from matplotlib import pyplot as plt 
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from PIL import Image
import copy
import json

from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
from diffusers import StableDiffusionInpaintPipeline


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Folder of pre-extracted APRICOT samples (each file is read with torch.load by
# getRandomImageMasks). NOTE(review): absolute cluster path -- adjust when running elsewhere.
apricot_dataset = "/project/trinity/datasets/apricot/pub/apricot-mask/data_mask_v2"
# os.listdir order is filesystem-dependent; sort it so a seeded random sample picks
# the same files on every run.
apricot_files = sorted(os.listdir(apricot_dataset))

# Seed Python's and torch's global RNGs so the sampled images and the (generator-less)
# inpainting calls should repeat across Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# APRICOT test-split annotation file; only its category list is used here.
annotations_file = "/project/trinity/datasets/APRICOT_ORIGINAL/Annotations/apricot_annotations_test.json"

with open(annotations_file, 'r') as fh:
    data = json.load(fh)

# Category id -> human-readable name, used to label both ground-truth and detector boxes.
# (The name suggests COCO-91 ids, matching torchvision's detector labels -- confirm.)
annotations_coco_91 = {category['id']: category['name'] for category in data['categories']}

In [5]:
def _rectangular_mask(mask):
    """Set the bounding rectangle of ``mask``'s non-zero pixels to 255 (in place) and return it.

    An all-zero mask is returned unchanged; previously ``np.min`` on the empty
    selection raised ``ValueError``.
    """
    nonzero = np.nonzero(mask)
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        return mask
    mask[rows.min():rows.max() + 1, cols.min():cols.max() + 1] = 255
    return mask


def _normalized_boxes_to_pixels(boxes, height, width):
    """Convert (N, 4) boxes from normalised (y1, x1, y2, x2) to pixel (y, x, h, w), truncated to ints.

    Columns 0/2 are scaled by ``height`` and columns 1/3 by ``width``. The input is not modified.
    """
    boxes = np.asarray(boxes, dtype=np.float64)
    pixel_boxes = np.empty_like(boxes)
    pixel_boxes[:, 0] = (boxes[:, 0] * height).astype(int)
    pixel_boxes[:, 1] = (boxes[:, 1] * width).astype(int)
    pixel_boxes[:, 2] = ((boxes[:, 2] - boxes[:, 0]) * height).astype(int)
    pixel_boxes[:, 3] = ((boxes[:, 3] - boxes[:, 1]) * width).astype(int)
    return pixel_boxes


def getRandomImageMasks(n=1, scale=1.0, shape=(512,512)):
    """Load ``n`` distinct random samples from ``apricot_files``, resized to ``shape``.

    Args:
        n: number of files drawn without replacement; ``random.sample`` raises
            ``ValueError`` if it exceeds ``len(apricot_files)``.
        scale: unused, kept for backward compatibility (``shape`` controls resizing).
        shape: target size handed to ``cv2.resize`` (OpenCV ``dsize``, i.e. (width, height)).

    Returns:
        imgs: list of CHW tensors built from each sample's ``'Image'`` entry.
        masks: list of uint8 masks, 255 inside the bounding rectangle of the sample's
            ``'Mask'`` and 0 elsewhere.
        all_bboxes: list of (N, 4) arrays of pixel boxes ordered (y, x, h, w).
        all_labels: list of (N,) label arrays.
    """
    imgs, masks, all_bboxes, all_labels = [], [], [], []

    # Sample the file names directly; the former range(0, len - 1) never reached the last file.
    for file_name in random.sample(apricot_files, n):
        # NOTE(review): torch.load unpickles arbitrary objects -- only use on trusted data.
        img_info = torch.load(os.path.join(apricot_dataset, file_name))

        img = cv2.resize(np.squeeze(img_info['Image']), shape)
        img = torch.from_numpy(np.transpose(img, (2, 0, 1)))  # HWC -> CHW
        imgs.append(img)
        _, height, width = img.shape

        # Mask assumed to be in [0, 1]; scale to uint8 before resizing.
        mask = (255.0 * np.squeeze(img_info['Mask'])).astype(np.uint8)
        mask = cv2.resize(mask, shape)
        masks.append(_rectangular_mask(mask))

        # Boxes/labels are stored with a leading singleton axis: (1, N, 4) and (1, N).
        boxes = img_info['Annotations'][0]['boxes']
        boxes = np.reshape(boxes, (boxes.shape[1], boxes.shape[2]))
        all_bboxes.append(_normalized_boxes_to_pixels(boxes, height, width))

        labels = img_info['Annotations'][0]['labels']
        all_labels.append(np.reshape(labels, (labels.shape[1],)))

    return imgs, masks, all_bboxes, all_labels

In [6]:
# Inpainting model used to regenerate the masked patch region.
# The notebook was written for GPU index 2; fall back to another device when that
# GPU does not exist so the cell still runs on smaller machines.
if torch.cuda.device_count() > 2:
    inpaint_device = "cuda:2"
elif torch.cuda.is_available():
    inpaint_device = "cuda:0"
else:
    inpaint_device = "cpu"
# float16 is intended for GPU use; keep float32 on CPU (half-precision CPU kernels
# may be missing -- verify for your torch version).
inpaint_dtype = torch.float16 if inpaint_device.startswith("cuda") else torch.float32

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=inpaint_dtype
)
pipe = pipe.to(inpaint_device)
# No text prompt: the fill is not steered toward any particular content.
prompt = ""

In [7]:
def getRepaintedImages(imgs, masks):
    """Inpaint each image's masked region with the global ``pipe`` and ``prompt``.

    Args:
        imgs: CHW float tensors, assumed to hold values in [0, 1].
        masks: uint8 HxW masks paired with ``imgs``; white (255) pixels mark the area
            the pipeline regenerates, black (0) is kept (diffusers' inpainting convention).

    Returns:
        List of CHW float32 tensors in [0, 1], one per input image.

    Raises:
        ValueError: if ``imgs`` and ``masks`` differ in length.
    """
    if len(imgs) != len(masks):
        raise ValueError(f"got {len(imgs)} images but {len(masks)} masks")

    repainted_images = []
    for img_tensor, mask in zip(imgs, masks):
        hwc = np.transpose(img_tensor.detach().cpu().numpy(), (1, 2, 0))
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        img_uint8 = np.clip(255.0 * hwc, 0, 255).astype(np.uint8)
        # Hand the pipeline PIL images: how diffusers scales raw NumPy input (0-255 vs 0-1)
        # has changed between releases, while RGB / "L" PIL images are handled consistently.
        result = pipe(
            prompt=prompt,
            image=Image.fromarray(img_uint8),
            mask_image=Image.fromarray(np.asarray(mask, dtype=np.uint8)),
        ).images[0]
        chw = np.transpose(np.asarray(result), (2, 0, 1)).astype(np.float32) / 255.0
        repainted_images.append(torch.from_numpy(chw))

    return repainted_images

In [8]:
# Draw the sample, then inpaint each image's rectangular patch mask
# (the pipeline prints one progress bar per image). Note: getRandomImageMasks
# ignores `scale`; `shape` alone sets the output size.
num_images = 20
sampled_data = getRandomImageMasks(n=num_images, scale=0.20, shape=(512, 512))
adv_imgs, masks, all_bboxes, all_labels = sampled_data
repainted_imgs = getRepaintedImages(adv_imgs, masks)

100%|██████████| 50/50 [00:04<00:00, 10.68it/s]
100%|██████████| 50/50 [00:04<00:00, 10.87it/s]
100%|██████████| 50/50 [00:04<00:00, 10.84it/s]
100%|██████████| 50/50 [00:04<00:00, 10.80it/s]
100%|██████████| 50/50 [00:04<00:00, 10.76it/s]
100%|██████████| 50/50 [00:04<00:00, 10.71it/s]
100%|██████████| 50/50 [00:04<00:00, 10.69it/s]
100%|██████████| 50/50 [00:04<00:00, 10.68it/s]
100%|██████████| 50/50 [00:04<00:00, 10.62it/s]
100%|██████████| 50/50 [00:04<00:00, 10.58it/s]
100%|██████████| 50/50 [00:04<00:00, 10.59it/s]
100%|██████████| 50/50 [00:04<00:00, 10.58it/s]
100%|██████████| 50/50 [00:04<00:00, 10.56it/s]
100%|██████████| 50/50 [00:04<00:00, 10.54it/s]
100%|██████████| 50/50 [00:04<00:00, 10.53it/s]
100%|██████████| 50/50 [00:04<00:00, 10.54it/s]
100%|██████████| 50/50 [00:04<00:00, 10.52it/s]
100%|██████████| 50/50 [00:04<00:00, 10.52it/s]
100%|██████████| 50/50 [00:04<00:00, 10.49it/s]
100%|██████████| 50/50 [00:04<00:00, 10.50it/s]


In [9]:
# COCO-pretrained Faster R-CNN (ResNet-50 FPN), evaluated on adversarial vs repainted images.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of `weights=`;
# use the new API when available and fall back on older torchvision releases.
try:
    from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
except ImportError:
    model = fasterrcnn_resnet50_fpn(pretrained=True)
else:
    # DEFAULT resolves to the same COCO weights that `pretrained=True` loaded.
    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)

In [10]:
# Detect objects in the adversarial originals and in their repainted versions.
model.eval()
# These forward passes are inference only: no_grad() stops autograd from recording a
# graph for every image (saving memory); outputs are only read via .detach() later.
with torch.no_grad():
    adv_predictions = model(adv_imgs)
    repaint_predictions = model(repainted_imgs)

In [1]:
# Both detector panels share one score cut-off. The original cell used 0.6 for the
# adversarial panel but 0.5 for the repainted one, which biased the side-by-side
# comparison toward showing more detections after repainting.
SCORE_THRESHOLD = 0.5
# Ground-truth entries with this label are not drawn (presumably the adversarial
# patch's own annotation in APRICOT -- TODO confirm against the annotation file).
IGNORED_GT_LABEL = -10


def _to_display_image(img_tensor):
    """Convert a CHW float tensor (assumed to hold values in [0, 1]) to an HWC uint8 array for imshow."""
    hwc = np.transpose(img_tensor.detach().cpu().numpy(), (1, 2, 0))
    # Clip so values slightly outside [0, 1] saturate instead of wrapping in the uint8 cast.
    return np.clip(255.0 * hwc, 0, 255).astype(np.uint8)


def _prediction_to_xywh(prediction):
    """Split one torchvision detection dict into (x, y, w, h) boxes, labels and scores.

    torchvision returns boxes as (x1, y1, x2, y2); matplotlib's Rectangle wants the
    top-left corner plus width/height. The prediction tensors are not modified.
    """
    boxes = prediction['boxes'].detach().cpu().numpy()
    xywh = np.column_stack([
        boxes[:, 0],
        boxes[:, 1],
        boxes[:, 2] - boxes[:, 0],
        boxes[:, 3] - boxes[:, 1],
    ])
    labels = prediction['labels'].detach().cpu().numpy()
    scores = prediction['scores'].detach().cpu().numpy()
    return xywh, labels, scores


def _draw_boxes(ax, image, boxes_xywh, labels, keep, title):
    """Show ``image`` on ``ax`` and draw every box whose ``keep`` flag is true, labelled by class name."""
    ax.imshow(image)
    ax.axis('off')
    ax.set_title(title)
    for (x, y, w, h), label, show in zip(boxes_xywh, labels, keep):
        if not show:
            continue
        ax.add_patch(patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none'))
        # Fall back to the raw id instead of raising KeyError for ids missing from the map.
        name = annotations_coco_91.get(int(label), str(label))
        ax.text(x, y, name, fontsize=8, bbox=dict(facecolor='black', alpha=0.8, pad=1), color='white')


# Iterate over what was actually computed rather than `num_images`, so a stale
# constant cannot index past the end of the lists.
for i in range(len(adv_imgs)):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Panel 1: adversarial image with ground-truth boxes, stored as pixel (y, x, h, w).
    adv_img = _to_display_image(adv_imgs[i])
    gt_boxes = np.asarray(all_bboxes[i])
    gt_labels = np.asarray(all_labels[i])
    _draw_boxes(axes[0], adv_img, gt_boxes[:, [1, 0, 3, 2]], gt_labels,
                gt_labels != IGNORED_GT_LABEL, 'GT Bounding Boxes')

    # Panel 2: detector output on the adversarial image.
    adv_boxes, adv_labels, adv_scores = _prediction_to_xywh(adv_predictions[i])
    _draw_boxes(axes[1], adv_img, adv_boxes, adv_labels,
                adv_scores > SCORE_THRESHOLD, 'Adversarial Bounding Boxes (FRCNN)')

    # Panel 3: detector output on the repainted image.
    rep_img = _to_display_image(repainted_imgs[i])
    rep_boxes, rep_labels, rep_scores = _prediction_to_xywh(repaint_predictions[i])
    _draw_boxes(axes[2], rep_img, rep_boxes, rep_labels,
                rep_scores > SCORE_THRESHOLD, 'Repainted Bounding Boxes')

    fig.tight_layout()
    plt.show()
    # Close this figure explicitly so 20 figures do not accumulate in memory.
    plt.close(fig)

