In [1]:
import torch
import torchvision.transforms as T

import PIL

import os
import random
import cv2
import numpy as np

# from matplotlib import pyplot as plt 
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import copy
import json

from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn
from diffusers import StableDiffusionInpaintPipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Folder holding one saved adversarial example per file.
coco_data_folder = "/homes/e34960/SegmentAndComplete/adv_data/coco_random_patch_100/data"
# os.listdir returns entries in arbitrary (filesystem-dependent) order; sort so that
# index-based random sampling picks the same files on every run once a seed is set.
coco_files = sorted(os.listdir(coco_data_folder))

In [3]:
annotations_file = "/project/trinity/datasets/APRICOT_ORIGINAL/Annotations/apricot_annotations_test.json"

# Load the APRICOT test annotation JSON.
with open(annotations_file, 'r') as annotations_fh:
    data = json.load(annotations_fh)

# Category id -> category name lookup, used to label boxes when plotting.
# Later entries with a duplicate id overwrite earlier ones.
annotations_coco_91 = {category['id']: category['name'] for category in data['categories']}

In [4]:
def _resize_to_chw_tensor(img_hwc, shape):
    """Resize an HWC array to ``shape`` (cv2 ``(width, height)`` order) and return it as a CHW tensor."""
    return torch.from_numpy(np.transpose(cv2.resize(img_hwc, shape), (2, 0, 1)))


def _patch_mask(h, w, xmin, ymin, patch_size, shape):
    """Build an ``h x w`` mask that is 255 on the square patch at ``(xmin, ymin)``, then resize it to ``shape``."""
    mask = np.zeros((h, w))
    mask[ymin:ymin + patch_size, xmin:xmin + patch_size] = 255
    # cv2.resize uses bilinear interpolation by default, so border pixels of the
    # resized patch can take intermediate values between 0 and 255.
    return cv2.resize(mask, shape)


def _xyxy_to_scaled_xywh(boxes, src_w, src_h, shape):
    """Convert ``[x0, y0, x1, y1]`` rows to ``[x, y, w, h]`` rescaled from ``(src_w, src_h)`` to ``shape``.

    Values are truncated with ``astype(int)`` but stored back into ``boxes`` (keeping its dtype).
    ``boxes`` is modified in place and returned.
    """
    boxes[:, 2] = ((boxes[:, 2] - boxes[:, 0]) * (shape[0] / src_w)).astype(int)
    boxes[:, 3] = ((boxes[:, 3] - boxes[:, 1]) * (shape[1] / src_h)).astype(int)
    # x0 / y0 are rescaled last because the width / height above need their unscaled values.
    boxes[:, 0] = (shape[0] * boxes[:, 0] / src_w).astype(int)
    boxes[:, 1] = (shape[1] * boxes[:, 1] / src_h).astype(int)
    return boxes


def getRandomImageMasks(n=1, scale=1.0, shape=(512,512), patch_size=100):
    """Sample ``n`` distinct saved examples and return resized images, patch masks and GT boxes.

    Each file in ``coco_data_folder`` is read with ``torch.load``. The result is expected to
    be a dict with keys ``'x_adv'``, ``'x'``, ``'xmin'``, ``'ymin'`` and ``'y'``, where
    ``'y'`` holds ``'boxes'`` and ``'labels'`` tensors.

    Args:
        n: number of distinct files to sample. ``random.sample`` raises ValueError if it
            exceeds ``len(coco_files)``.
        scale: unused; kept only so existing calls keep working.
        shape: output size in ``cv2.resize`` order, i.e. ``(width, height)``. Box x values
            are scaled by ``shape[0]`` and y values by ``shape[1]``.
        patch_size: side length, in source-image pixels, of the square masked region
            anchored at ``(xmin, ymin)``.

    Returns:
        ``(adv_imgs, orig_imgs, masks, all_bboxes, all_labels)``. Each is a list of length ``n``:
        - CHW tensors for the adversarial images;
        - CHW tensors for the clean images;
        - 2-D float masks;
        - ``[x, y, w, h]`` box arrays rescaled to ``shape``;
        - 1-D label arrays.
    """
    # Sample over the full index range (the previous ``len(coco_files) - 1`` bound
    # meant the last file could never be chosen).
    indices = random.sample(range(len(coco_files)), n)

    adv_imgs, orig_imgs, masks, all_bboxes, all_labels = [], [], [], [], []
    for idx in indices:
        img_info = torch.load(os.path.join(coco_data_folder, coco_files[idx]))

        adv_imgs.append(_resize_to_chw_tensor(np.squeeze(img_info['x_adv']), shape))

        orig_img = np.squeeze(img_info['x'])
        # Source-image size. The patch coordinates and GT boxes are expressed in this frame.
        h, w, _ = orig_img.shape
        orig_imgs.append(_resize_to_chw_tensor(orig_img, shape))

        masks.append(_patch_mask(h, w, img_info['xmin'], img_info['ymin'], patch_size, shape))

        boxes = img_info['y']['boxes'].cpu().detach().numpy()
        # Drop the leading singleton batch dimension: (1, K, C) -> (K, C).
        boxes = np.reshape(boxes, (boxes.shape[1], boxes.shape[2]))
        all_bboxes.append(_xyxy_to_scaled_xywh(boxes, w, h, shape))

        labels = img_info['y']['labels'].cpu().detach().numpy()
        all_labels.append(np.reshape(labels, (labels.shape[1],)))

    return adv_imgs, orig_imgs, masks, all_bboxes, all_labels

In [5]:
# Stable Diffusion 2 inpainting pipeline, loaded in float16 and moved onto a fixed GPU.
INPAINT_MODEL_ID = "stabilityai/stable-diffusion-2-inpainting"
INPAINT_DEVICE = "cuda:2"

pipe = StableDiffusionInpaintPipeline.from_pretrained(INPAINT_MODEL_ID, torch_dtype=torch.float16)
pipe.to(INPAINT_DEVICE)

# Text prompt used for every inpainting call. It is empty, so no text conditioning is supplied.
prompt = ""

In [6]:
def getRepaintedImages(imgs, masks, generator=None):
    """Inpaint each image with the module-level ``pipe`` and ``prompt``.

    Args:
        imgs: sequence of CHW float tensors, assumed to hold values in [0, 1]. They are
            scaled by 255, clipped and cast to uint8 before being passed to the pipeline.
        masks: sequence indexed in parallel with ``imgs``. ``masks[i]`` is passed unchanged
            to the pipeline as ``mask_image``.
        generator: optional ``torch.Generator`` forwarded to the pipeline call so the
            diffusion sampling can be made reproducible. ``None`` (the default) matches
            the previous behaviour.

    Returns:
        list of CHW float32 tensors: the pipeline's first output image divided by 255,
        one per input image.
    """
    repainted_images = []
    for i, img in enumerate(imgs):
        img_hwc = np.transpose(img.numpy(), (1, 2, 0))
        # Clip before the uint8 cast. Out-of-range values would otherwise wrap
        # around instead of saturating at 0 / 255.
        img_uint8 = np.clip(255.0 * img_hwc, 0, 255).astype(np.uint8)
        result = pipe(prompt=prompt, image=img_uint8, mask_image=masks[i], generator=generator).images[0]
        repainted = torch.from_numpy(np.transpose(np.array(result), (2, 0, 1)).astype(np.float32))
        repainted /= 255.0
        repainted_images.append(repainted)

    return repainted_images

In [7]:
# Seed Python's RNG (used to pick which files are sampled) and torch's global RNGs
# (which the inpainting pipeline presumably draws its noise from when no generator
# is passed) so that Restart & Run All reproduces the same images and repaints.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

num_images = 20
# getRandomImageMasks ignores its `scale` argument, so it is not passed here.
adv_imgs, orig_imgs, masks, all_bboxes, all_labels = getRandomImageMasks(n=num_images, shape=(512, 512))
repainted_imgs = getRepaintedImages(adv_imgs, masks)

100%|██████████| 50/50 [00:04<00:00, 10.80it/s]
100%|██████████| 50/50 [00:04<00:00, 10.85it/s]
100%|██████████| 50/50 [00:04<00:00, 10.78it/s]
100%|██████████| 50/50 [00:04<00:00, 10.76it/s]
100%|██████████| 50/50 [00:04<00:00, 10.70it/s]
100%|██████████| 50/50 [00:04<00:00, 10.70it/s]
100%|██████████| 50/50 [00:04<00:00, 10.68it/s]
100%|██████████| 50/50 [00:04<00:00, 10.61it/s]
100%|██████████| 50/50 [00:04<00:00, 10.61it/s]
100%|██████████| 50/50 [00:04<00:00, 10.60it/s]
100%|██████████| 50/50 [00:04<00:00, 10.58it/s]
100%|██████████| 50/50 [00:04<00:00, 10.57it/s]
100%|██████████| 50/50 [00:04<00:00, 10.58it/s]
100%|██████████| 50/50 [00:04<00:00, 10.55it/s]
100%|██████████| 50/50 [00:04<00:00, 10.55it/s]
100%|██████████| 50/50 [00:04<00:00, 10.53it/s]
100%|██████████| 50/50 [00:04<00:00, 10.53it/s]
100%|██████████| 50/50 [00:04<00:00, 10.53it/s]
100%|██████████| 50/50 [00:04<00:00, 10.53it/s]
100%|██████████| 50/50 [00:04<00:00, 10.52it/s]


In [8]:
# Pretrained torchvision Faster R-CNN. It is never moved to a GPU, so it runs on CPU
# with the CPU image tensors built above.
model = fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Inference only. Disabling autograd avoids keeping activation graphs for all three
# batches in memory. The plotting code's `.detach()` calls remain harmless.
with torch.no_grad():
    adv_predictions = model(adv_imgs)
    repaint_predictions = model(repainted_imgs)
    orig_predictions = model(orig_imgs)

In [1]:
# Detections with a score at or below this value are not drawn.
SCORE_THRESHOLD = 0.5


def _to_display_image(img_chw):
    """Convert a CHW float tensor (assumed in [0, 1]) to an HWC uint8 array for ``imshow``."""
    # Clip so out-of-range values saturate rather than wrap in the uint8 cast.
    return np.clip(255.0 * np.transpose(img_chw.numpy(), (1, 2, 0)), 0, 255).astype(np.uint8)


def _category_name(label):
    """Look up a category name in ``annotations_coco_91``, falling back to the numeric id."""
    label_id = int(label)
    return annotations_coco_91.get(label_id, str(label_id))


def _draw_box(ax, x, y, box_w, box_h, label):
    """Draw one red ``(x, y, w, h)`` rectangle with its category name at the top-left corner."""
    ax.add_patch(patches.Rectangle((x, y), box_w, box_h, linewidth=1, edgecolor='r', facecolor='none'))
    ax.text(x, y, _category_name(label), fontsize=8,
            bbox=dict(facecolor='black', alpha=0.8, pad=1), color='white')


def _show_image(ax, img_chw, title):
    """Render ``img_chw`` on ``ax`` with axes hidden and set the panel title."""
    ax.imshow(_to_display_image(img_chw))
    ax.axis('off')
    ax.set_title(title)


def _plot_ground_truth(ax, img_chw, bboxes, labels):
    """Draw ground-truth boxes, given as ``[x, y, w, h]`` rows, over ``img_chw``."""
    _show_image(ax, img_chw, 'GT Bounding Boxes')
    for (x, y, box_w, box_h), label in zip(bboxes, labels):
        # Entries labelled -10 are skipped (presumably the adversarial patch itself; confirm).
        if label != -10:
            _draw_box(ax, x, y, box_w, box_h, label)


def _plot_predictions(ax, img_chw, prediction, title, score_threshold=SCORE_THRESHOLD):
    """Draw detector boxes (``[x0, y0, x1, y1]``) whose score exceeds ``score_threshold``."""
    _show_image(ax, img_chw, title)
    boxes = prediction['boxes'].detach().numpy()
    labels = prediction['labels'].detach().numpy()
    scores = prediction['scores'].detach().numpy()
    for (x0, y0, x1, y1), label, score in zip(boxes, labels, scores):
        if score > score_threshold:
            _draw_box(ax, x0, y0, x1 - x0, y1 - y0, label)


for i in range(num_images):
    # One row per sample, four panels:
    #   1. GT boxes on the clean image
    #   2. detections on the adversarial image
    #   3. detections on the repainted image
    #   4. detections on the clean image
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    _plot_ground_truth(axes[0], orig_imgs[i], all_bboxes[i], all_labels[i])
    _plot_predictions(axes[1], adv_imgs[i], adv_predictions[i], 'Adversarial Bounding Boxes')
    _plot_predictions(axes[2], repainted_imgs[i], repaint_predictions[i], 'Repainted Bounding Boxes')
    _plot_predictions(axes[3], orig_imgs[i], orig_predictions[i], 'Vanilla Detector Bounding Boxes')
    fig.tight_layout()
    plt.show()
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)