In [None]:
# Format the correct output
# !python -m spacy download en_core_web_sm
# %pip install google-generativeai
# import google.generativeai as genai
# from google.generativeai.types import HarmCategory, HarmBlockThreshold
import os
import sys
sys.path.append('../../')
import json 
import copy
import re
import time
import cv2
import sys
import argparse
import numpy as np
import torch
from pathlib import Path
from matplotlib import pyplot as plt
from typing import Any, Dict, List
from sam_segment import predict_masks_with_sam_prompts
from stable_diffusion_inpaint import fill_img_with_sd
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask, show_points, get_clicked_point
from tasks.segmentation_task import SegmentationSample, SegmentationTask
from diffusers import StableDiffusionInpaintPipeline
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline
from PIL import Image
import io
from utils import load_img_to_array, save_array_to_img, format_img
import pandas as pd
from tasks.segmentation_task import SegmentationSample, SegmentationTask

# load pyarrow
import pandas as pd

# Locations used by this notebook: the VQAv2 arrow file, the per-question LLM
# annotations (one "<question_id>\t<json>" record per line) and the directory
# holding the generated images.
# NOTE(review): vqa_path points at the *train* arrow while the annotation file
# and results directory are named "val" — confirm this pairing is intended.
vqa_path = "{}/data/VQAv2_arrows/vqav2_train.arrow".format(os.path.expanduser("~/SegmentationSubstitution"))
vqa_qid_obj_dir = 'vqav2_val_obj.txt'
output_dir = "../../results/vqa_removal_val"

# load the VQA dataset
data = pd.read_feather(vqa_path)
# List the generated files through the single `output_dir` constant instead of
# repeating the path literal.
all_generated_files = os.listdir(output_dir)
data.head()

# Filled by the following statements: question_id -> merged record / raw lookup.
vqa_data = {}
qid_img_q = {}

# Index every question by its id so the annotation file can be joined against
# it: question_id -> image cell, question text and image id of the owning row.
for _, row in data.iterrows():
    img, img_id, questions = row['image'], row['image_id'], row['questions']
    for idx, qid in enumerate(row['question_id']):
        qid_img_q[qid] = dict(img=img, q=questions[idx], img_id=img_id)


# Join the LLM annotations with the question index. Each non-empty line must be
# "<question_id>\t<json>", where the JSON carries 'object' and 'new_answer'.
# Question ids that are absent from the loaded dataset are skipped.
# The file is read as UTF-8 (the standard JSON encoding) rather than the
# platform default.
with open(vqa_qid_obj_dir, 'r', encoding='utf-8') as f:
    for line_no, row in enumerate(f, start=1):
        row = row.rstrip()
        if not row:
            # tolerate blank lines (e.g. a trailing newline at end of file)
            continue
        content = row.split('\t')
        if len(content) != 2:
            # explicit error instead of `assert`, which is stripped under -O
            raise ValueError(
                f"{vqa_qid_obj_dir}:{line_no}: expected 2 tab-separated fields, got {len(content)}"
            )
        qid = int(content[0])
        if qid not in qid_img_q:
            continue
        llm_res = json.loads(content[1])
        vqa_data[qid] = {
            'object': llm_res['object'],
            'q': qid_img_q[qid]['q'],
            'img': qid_img_q[qid]['img'],
            'new_label': llm_res['new_answer'],
            'qid': qid,
            'img_id': qid_img_q[qid]['img_id']
        }

### Load all the pipelines and models at once
# Everything is placed on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading SD model...")
# Stable Diffusion 2 inpainting pipeline, loaded in full precision (float32).
infill_pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float32,
    ).to(device)
# The message below covers both the Grounding DINO detector and the SAM segmenter.
print("Loading SAM model...")
# Checkpoint ids; later cells pass these two names to grounded_segmentation().
detector_id = "IDEA-Research/grounding-dino-tiny"
segmenter_id = "facebook/sam-vit-base"

# NOTE(review): detect() and segment() in this file build their own model
# objects from the ids and do not read detector_pipe / segmenter / processor;
# these globals may be consumed by the imported project modules — confirm.
detector_pipe = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
segmenter = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
processor = AutoProcessor.from_pretrained(segmenter_id)


In [None]:
import io
from PIL import Image
import matplotlib.pyplot as plt
import random
import random
from dataclasses import dataclass
from typing import Any, List, Dict, Optional, Union, Tuple

import cv2
import torch
import requests
import numpy as np
from PIL import Image
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline

def get_raw_image(index):
    """Decode the image stored at ``data['image'][index]`` into an RGB PIL image.

    Reads the notebook-global ``data`` frame. The cell is wrapped in a
    ``BytesIO``, so it is presumably the encoded image bytes.
    """
    encoded = data['image'][index]
    stream = io.BytesIO(encoded)
    return Image.open(stream).convert("RGB")

def convert_image_to_binary(image):
    """Serialize an image to JPEG-encoded bytes.

    Args:
        image: A PIL-style image exposing ``mode``, ``convert`` and ``save``.

    Returns:
        bytes: The JPEG encoding of the image.

    Images whose mode JPEG cannot store (for example ``RGBA`` or ``P``) are
    converted to ``RGB`` first; saving them directly would raise. This matches
    ``get_raw_image``, which also normalizes to RGB.
    """
    jpeg_modes = ("L", "RGB", "CMYK")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    with io.BytesIO() as output:
        # Write the JPEG into the in-memory buffer and copy the bytes out
        # before the buffer is closed.
        image.save(output, format='JPEG')
        binary_data = output.getvalue()
    return binary_data

def get_question_from_file(file):
    """Look up the question text and dataset row for a generated file name.

    The name is parsed as ``<image_id>_<question_id>.<ext>``. The row is the
    first one in the notebook-global ``data`` whose ``image_id`` matches.

    Returns:
        tuple: ``(question, row_id)`` where ``row_id`` is the index label of
        the matching row in ``data``.
    """
    parts = file.split("_")
    image_id = parts[0]
    matching_rows = data[(data['image_id'] == int(image_id))]
    row_id = matching_rows.index[0]

    question_id = parts[1].split(".")[0]
    row_question_ids = list(data['question_id'][row_id])
    position = row_question_ids.index(int(question_id))
    question = data['questions'][row_id][position]
    return question, row_id

@dataclass
class BoundingBox:
    """Axis-aligned box described by its two corner coordinates."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[float]:
        """Return the corners as ``[xmin, ymin, xmax, ymax]``."""
        corners = (self.xmin, self.ymin, self.xmax, self.ymax)
        return list(corners)

@dataclass
class DetectionResult:
    """One detected object: confidence, label, box and an optional mask."""

    score: float
    label: str
    box: BoundingBox
    # Filled in later by segment(); None until a mask has been attached.
    # Annotated as np.ndarray (the type) rather than np.array (a function).
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> 'DetectionResult':
        """Build a result from a detector output dict.

        Args:
            detection_dict: Mapping with ``'score'``, ``'label'`` and a
                ``'box'`` mapping holding ``'xmin'``, ``'ymin'``, ``'xmax'``
                and ``'ymax'``.

        Returns:
            DetectionResult: The parsed detection, with ``mask`` left as None.

        Raises:
            KeyError: If any of the keys above is missing.
        """
        corners = detection_dict['box']
        box = BoundingBox(
            xmin=corners['xmin'],
            ymin=corners['ymin'],
            xmax=corners['xmax'],
            ymax=corners['ymax'],
        )
        return cls(
            score=detection_dict['score'],
            label=detection_dict['label'],
            box=box,
        )
        
def annotate(image: Union[Image.Image, np.ndarray], detection_results: List[DetectionResult]) -> np.ndarray:
    """Draw each detection's box, caption and mask outline on the image.

    Args:
        image: RGB image as a PIL image or an array.
        detection_results: Detections to draw; ``mask`` may be None.

    Returns:
        np.ndarray: A new RGB array with the annotations drawn on it.

    Each detection gets its own random colour, so repeated calls produce
    different colours.
    """
    # OpenCV draws in BGR order; convert in, draw, and convert back at the end.
    image_cv2 = np.array(image) if isinstance(image, Image.Image) else image
    image_cv2 = cv2.cvtColor(image_cv2, cv2.COLOR_RGB2BGR)

    for detection in detection_results:
        box = detection.box
        mask = detection.mask

        # Sample a random color for each detection
        color = np.random.randint(0, 256, size=3).tolist()

        cv2.rectangle(image_cv2, (box.xmin, box.ymin), (box.xmax, box.ymax), color, 2)

        # Keep the caption baseline inside the frame when the box touches the
        # top edge; `ymin - 10` alone would place the text above the image.
        caption_y = max(box.ymin - 10, 15)
        cv2.putText(
            image_cv2,
            f'{detection.label}: {detection.score:.2f}',
            (box.xmin, caption_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )

        if mask is not None:
            # Binarize by "non-zero" rather than multiplying by 255: a uint8
            # mask already stored as 0/255 would wrap around under `mask * 255`.
            mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(image_cv2, contours, -1, color, 2)

    return cv2.cvtColor(image_cv2, cv2.COLOR_BGR2RGB)

def plot_detections(
    image: Union[Image.Image, np.ndarray],
    detections: List[DetectionResult],
    save_name: Optional[str] = None
) -> np.ndarray:
    """Return the image annotated with the given detections.

    Args:
        image: RGB image as a PIL image or an array.
        detections: Detections to draw (see ``annotate``).
        save_name: Optional output path. When given, the annotated image is
            also written there; previously this argument was ignored.

    Returns:
        np.ndarray: The annotated RGB array produced by ``annotate``.
    """
    annotated = annotate(image, detections)
    if save_name:
        Image.fromarray(annotated).save(save_name)
    return annotated
    
def random_named_css_colors(num_colors: int) -> List[str]:
    """
    Pick distinct named CSS colors at random.

    Args:
    - num_colors (int): How many colors are wanted.

    Returns:
    - list: ``num_colors`` distinct color names, or every known name (in
      random order) when more are requested than exist.
    """
    palette = [
        'aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure', 'beige',
        'bisque', 'black', 'blanchedalmond', 'blue', 'blueviolet', 'brown',
        'burlywood', 'cadetblue', 'chartreuse', 'chocolate', 'coral',
        'cornflowerblue', 'cornsilk', 'crimson', 'cyan', 'darkblue',
        'darkcyan', 'darkgoldenrod', 'darkgray', 'darkgreen', 'darkgrey',
        'darkkhaki', 'darkmagenta', 'darkolivegreen', 'darkorange',
        'darkorchid', 'darkred', 'darksalmon', 'darkseagreen',
        'darkslateblue', 'darkslategray', 'darkslategrey', 'darkturquoise',
        'darkviolet', 'deeppink', 'deepskyblue', 'dimgray', 'dimgrey',
        'dodgerblue', 'firebrick', 'floralwhite', 'forestgreen', 'fuchsia',
        'gainsboro', 'ghostwhite', 'gold', 'goldenrod', 'gray', 'green',
        'greenyellow', 'grey', 'honeydew', 'hotpink', 'indianred', 'indigo',
        'ivory', 'khaki', 'lavender', 'lavenderblush', 'lawngreen',
        'lemonchiffon', 'lightblue', 'lightcoral', 'lightcyan',
        'lightgoldenrodyellow', 'lightgray', 'lightgreen', 'lightgrey',
        'lightpink', 'lightsalmon', 'lightseagreen', 'lightskyblue',
        'lightslategray', 'lightslategrey', 'lightsteelblue', 'lightyellow',
        'lime', 'limegreen', 'linen', 'magenta', 'maroon',
        'mediumaquamarine', 'mediumblue', 'mediumorchid', 'mediumpurple',
        'mediumseagreen', 'mediumslateblue', 'mediumspringgreen',
        'mediumturquoise', 'mediumvioletred', 'midnightblue', 'mintcream',
        'mistyrose', 'moccasin', 'navajowhite', 'navy', 'oldlace', 'olive',
        'olivedrab', 'orange', 'orangered', 'orchid', 'palegoldenrod',
        'palegreen', 'paleturquoise', 'palevioletred', 'papayawhip',
        'peachpuff', 'peru', 'pink', 'plum', 'powderblue', 'purple',
        'rebeccapurple', 'red', 'rosybrown', 'royalblue', 'saddlebrown',
        'salmon', 'sandybrown', 'seagreen', 'seashell', 'sienna', 'silver',
        'skyblue', 'slateblue', 'slategray', 'slategrey', 'snow',
        'springgreen', 'steelblue', 'tan', 'teal', 'thistle', 'tomato',
        'turquoise', 'violet', 'wheat', 'white', 'whitesmoke', 'yellow',
        'yellowgreen',
    ]

    # Never ask for more names than the palette holds.
    available = len(palette)
    sample_size = num_colors if num_colors < available else available
    return random.sample(palette, sample_size)

def plot_detections_plotly(
    image: np.ndarray,
    detections: List[DetectionResult],
    class_colors: Optional[Dict[int, str]] = None
) -> None:
    """Show an interactive Plotly figure with mask outlines and toggleable boxes.

    Args:
        image: Image array passed to ``plotly.express.imshow``.
        detections: Detections to display. A detection whose mask is missing
            or empty gets a bounding box but no outline trace.
        class_colors: Optional mapping from detection position (0-based) to a
            colour. When omitted, random named CSS colours are used and
            recycled if there are more detections than available names.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    if class_colors is None:
        palette = random_named_css_colors(len(detections))
        # The palette is finite; wrap around instead of indexing past its end.
        class_colors = {
            idx: palette[idx % len(palette)] for idx in range(len(detections))
        }

    fig = px.imshow(image)

    # One single-rectangle list per detection, so each button can show one box.
    shapes = []
    for idx, detection in enumerate(detections):
        label = detection.label
        score = detection.score
        mask = detection.mask
        color = class_colors[idx]

        # Only trace an outline when there is a mask with foreground pixels;
        # an empty mask has no contour to extract.
        has_foreground = mask is not None and bool(np.any(np.asarray(mask).astype(np.uint8)))
        polygon = mask_to_polygon(mask) if has_foreground else []
        if polygon:
            fig.add_trace(go.Scatter(
                # repeat the first vertex to close the outline
                x=[point[0] for point in polygon] + [polygon[0][0]],
                y=[point[1] for point in polygon] + [polygon[0][1]],
                mode='lines',
                line=dict(color=color, width=2),
                fill='toself',
                name=f"{label}: {score:.2f}"
            ))

        xmin, ymin, xmax, ymax = detection.box.xyxy
        shapes.append([
            dict(
                type="rect",
                xref="x", yref="y",
                x0=xmin, y0=ymin,
                x1=xmax, y1=ymax,
                line=dict(color=color)
            )
        ])

    # Buttons: hide all boxes, show a single detection's box, or show all.
    button_shapes = [dict(label="None", method="relayout", args=["shapes", []])]
    button_shapes += [
        dict(label=f"Detection {idx+1}", method="relayout", args=["shapes", shape])
        for idx, shape in enumerate(shapes)
    ]
    all_shapes = [rect for group in shapes for rect in group]
    button_shapes.append(dict(label="All", method="relayout", args=["shapes", all_shapes]))

    fig.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        # margin=dict(l=0, r=0, t=0, b=0),
        showlegend=True,
        updatemenus=[
            dict(
                type="buttons",
                direction="up",
                buttons=button_shapes
            )
        ],
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    # Show plot
    fig.show()


def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Return the vertices of the largest external contour in a mask.

    Args:
        mask: 2-D mask; it is cast to uint8 and non-zero pixels are foreground.

    Returns:
        list: ``[x, y]`` vertices of the contour with the largest area, or an
        empty list when the mask contains no contour. Previously an empty
        mask raised ``ValueError`` from ``max()`` on an empty sequence.
    """
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []

    largest_contour = max(contours, key=cv2.contourArea)

    # Contours come back with a singleton middle axis; flatten to (N, 2).
    return largest_contour.reshape(-1, 2).tolist()

def polygon_to_mask(polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: uint8 mask with the polygon interior set to 255 and the
      rest 0. An empty polygon yields an all-zero mask.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Normalize to an (N, 2) int32 array, the point layout fillPoly expects.
    pts = np.asarray(polygon, dtype=np.int32).reshape(-1, 2)
    if pts.shape[0] == 0:
        # Nothing to fill; skip the OpenCV call, which needs actual points.
        return mask

    # Fill the polygon with white color (255)
    cv2.fillPoly(mask, [pts], color=(255,))

    return mask

def load_image(image_str: str) -> Image.Image:
    """Load an RGB image from a local path or an HTTP(S) URL.

    Args:
        image_str: File path, or a URL starting with ``http``.

    Returns:
        Image.Image: The image converted to RGB.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
        requests.Timeout: If the server does not answer within 30 seconds.
    """
    if image_str.startswith("http"):
        # Bound the wait and fail on HTTP errors, so an error page is never
        # handed to the image decoder. The decoded body is used rather than
        # the raw socket stream.
        response = requests.get(image_str, timeout=30)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    return Image.open(image_str).convert("RGB")

def get_boxes(results: List['DetectionResult']) -> List[List[List[float]]]:
    """Collect the boxes of all detections as one batch entry.

    Args:
        results: Detections whose ``box.xyxy`` is read.

    Returns:
        list: A single-element list holding the list of
        ``[xmin, ymin, xmax, ymax]`` boxes, i.e. ``[[box, box, ...]]``.
    """
    return [[result.box.xyxy for result in results]]

def refine_masks(masks: torch.BoolTensor, polygon_refinement: bool = False) -> List[np.ndarray]:
    """Collapse predicted masks to one binary uint8 array per detection.

    Args:
        masks: Tensor indexed as (detection, channel, height, width); the
            channel axis is averaged away and any positive value counts as
            foreground.
        polygon_refinement: When True, each non-empty mask is replaced by the
            filled polygon of its largest contour.

    Returns:
        list: One 2-D uint8 array per detection. Unrefined masks hold 0/1.
        Refined masks hold whatever ``polygon_to_mask`` writes (0/255 in this
        file).
    """
    masks = masks.cpu().float()
    masks = masks.permute(0, 2, 3, 1)
    masks = masks.mean(axis=-1)
    masks = (masks > 0).int()
    masks = masks.numpy().astype(np.uint8)
    masks = list(masks)

    if polygon_refinement:
        for idx, mask in enumerate(masks):
            if not mask.any():
                # An empty mask has no contour to trace; leave it all-zero
                # instead of failing inside the polygon helpers.
                continue
            polygon = mask_to_polygon(mask)
            masks[idx] = polygon_to_mask(polygon, mask.shape)

    return masks

# Detection pipelines already built, keyed by (checkpoint id, device).
_DETECTOR_CACHE: Dict[Tuple[str, str], Any] = {}


def _get_detector(detector_id: str, device: str) -> Any:
    """Return the zero-shot detection pipeline for a checkpoint, building it once."""
    key = (detector_id, device)
    if key not in _DETECTOR_CACHE:
        _DETECTOR_CACHE[key] = pipeline(model=detector_id, task="zero-shot-object-detection", device=device)
    return _DETECTOR_CACHE[key]


def detect(
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
    detector_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.

    Args:
        image: Image to search.
        labels: Text labels to look for; a trailing "." is appended to any
            label that lacks one before it is sent to the detector.
        threshold: Score threshold forwarded to the pipeline.
        detector_id: Checkpoint id; defaults to
            "IDEA-Research/grounding-dino-tiny".

    Returns:
        list: One ``DetectionResult`` per detection (masks not yet set);
        empty when ``labels`` is empty.

    The pipeline is cached per (checkpoint, device), so repeated calls do not
    reload the weights. The cached model stays in memory for the session.
    """
    if not labels:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    detector_id = detector_id if detector_id is not None else "IDEA-Research/grounding-dino-tiny"
    object_detector = _get_detector(detector_id, device)

    labels = [label if label.endswith(".") else label+"." for label in labels]

    results = object_detector(image,  candidate_labels=labels, threshold=threshold)
    return [DetectionResult.from_dict(result) for result in results]

# (model, processor) pairs already loaded, keyed by (checkpoint id, device).
_SEGMENTER_CACHE: Dict[Tuple[str, str], Tuple[Any, Any]] = {}


def _get_segmenter(segmenter_id: str, device: str) -> Tuple[Any, Any]:
    """Return the (model, processor) pair for a checkpoint, loading it once."""
    key = (segmenter_id, device)
    if key not in _SEGMENTER_CACHE:
        sam_model = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)
        sam_processor = AutoProcessor.from_pretrained(segmenter_id)
        _SEGMENTER_CACHE[key] = (sam_model, sam_processor)
    return _SEGMENTER_CACHE[key]


def segment(
    image: Image.Image,
    detection_results: List[DetectionResult],
    polygon_refinement: bool = False,
    segmenter_id: Optional[str] = None
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    Args:
        image: Image the detections refer to.
        detection_results: Detections whose boxes prompt the segmenter; each
            one has its ``mask`` attribute set in place.
        polygon_refinement: Forwarded to ``refine_masks``.
        segmenter_id: Checkpoint id; defaults to "facebook/sam-vit-base".

    Returns:
        list: The same ``detection_results`` list, now carrying masks. It is
        returned unchanged when it is empty, since there are no boxes to
        prompt the model with.

    The model and processor are cached per (checkpoint, device), and
    inference runs without gradient tracking.
    """
    if not detection_results:
        return detection_results

    device = "cuda" if torch.cuda.is_available() else "cpu"
    segmenter_id = segmenter_id if segmenter_id is not None else "facebook/sam-vit-base"

    segmentator, processor = _get_segmenter(segmenter_id, device)

    boxes = get_boxes(detection_results)
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = segmentator(**inputs)
    # Only one image is processed, so take the first batch entry.
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes
    )[0]

    masks = refine_masks(masks, polygon_refinement)

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results

def grounded_segmentation(
    image: Union[Image.Image, str],
    labels: List[str],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None
) -> Tuple[np.ndarray, List[DetectionResult]]:
    """Detect the given labels in an image, then segment each detection.

    Args:
        image: A PIL image, or a string that ``load_image`` can resolve.
        labels: Text labels handed to ``detect``.
        threshold: Detection score threshold handed to ``detect``.
        polygon_refinement: Handed to ``segment``.
        detector_id: Detector checkpoint id handed to ``detect``.
        segmenter_id: Segmenter checkpoint id handed to ``segment``.

    Returns:
        tuple: The image as an array, and the detections returned by
        ``segment``.
    """
    pil_image = load_image(image) if isinstance(image, str) else image

    found = detect(pil_image, labels, threshold, detector_id)
    segmented = segment(pil_image, found, polygon_refinement, segmenter_id)

    return np.array(pil_image), segmented

# Generated images ("<image_id>_<question_id>.jpeg") inspected in the cells below.
files = ['480275_480275003.jpeg', '262565_262565000.jpeg', '397587_397587000.jpeg', '266579_266579001.jpeg']

In [None]:
# Pick one generated example and look up its question and dataset row.
file = files[3]
# The name is split as "<image_id>_<question_id>.<ext>".
qid = file.split("_")[1].split(".")[0]
image_id = file.split("_")[0]
perturbed_path = f"../../results/vqa_removal_val/{file}"
question, row_id = get_question_from_file(file)

print(question)
print(file)
# NOTE(review): this unconditional raise stops the cell here, so the three
# assignments below never run and original_image / perturbed_image / object are
# not (re)defined by this cell. Looks like a leftover debug stop — confirm and
# remove if the following cells are meant to run.
raise ValueError
original_image = get_raw_image(row_id)
perturbed_image = Image.open(perturbed_path)
# NOTE(review): `object` shadows the builtin; later cells read this name, so
# renaming it means updating them too.
object = vqa_data[int(qid)]['object']


In [None]:
# Detect and segment the annotated object in the original image, outline it,
# and display the result next to the perturbed image.
segmentation_kwargs = dict(
    image=original_image,
    labels=[object],
    threshold=0.3,
    polygon_refinement=True,
    detector_id=detector_id,
    segmenter_id=segmenter_id,
)
image_array, detections = grounded_segmentation(**segmentation_kwargs)

annotated_array = plot_detections(image_array, detections)
masked_original_image = Image.fromarray(annotated_array)
plt.imshow(masked_original_image)
# last expression of the cell: rendered by the notebook
perturbed_image

In [None]:
# Save the annotated original and the perturbed image side by side on disk.
# Create the target directory first so the saves do not fail on a fresh checkout.
os.makedirs("examples", exist_ok=True)
masked_original_image.save(f"examples/{object}_original.png")
perturbed_image.save(f"examples/{object}_removed.png")


In [None]:
# VQA generations

# TODO: also render masks object detection side by side

# Hand-picked generated images; each trailing comment names the removed object
# and, where noted, whether GPT was tricked by the edit.
counterfactuals = [
    '526197_526197006.jpeg', # no doughnuts - GPT tricked?
    '528030_528030002.jpeg', # no necktie
    '395717_395717003.jpeg', # no boatsman
    '263961_263961000.jpeg', # no bike
    '525119_525119000.jpeg', # no food on the plate
    '437205_437205008.jpeg', # no bananas
    '396903_396903001.jpeg', # no plane
    '397734_397734009.jpeg', # no hotdog
    '132415_132415003.jpeg', # no fork
    '22461_22461002.jpeg', # no cereal
    '264568_264568000.jpeg', # no cook
    '262565_262565000.jpeg', # no bat - GPT tricked
    '632_632008.jpeg', # no mirror - GPT tricked
    '262509_262509005.jpeg', # no boat
    '131089_131089004.jpeg', # no bat kid
    '264375_264375000.jpeg', # no streetlight - GPT tricked?
    '5418_5418000.jpeg', # no giraffes
    '397587_397587000.jpeg', # no tie - GPT tricked
    '529105_529105000.jpeg', # no horse - GPT tricked
    '266579_266579001.jpeg', # no bird - GPT tricked
    '267664_267664002.jpeg', # no chair
    '480275_480275003.jpeg', # no bananas - GPT tricked
    '393523_393523001.jpeg', # no bridge
    '133100_133100001.jpeg', # no zebra - GPT tricked
    '134689_134689006.jpeg', # no giraffes
    '4175_4175000.jpeg', # no server - GPT tricked
    '267664_267664001.jpeg', # no cat
    '526711_526711002.jpeg', # no aircraft
    '5385_5385010.jpeg', # no bat - GPT tricked for variant of question "What is he holding?", but not original question
    '502766_502766000.jpeg', # no sheepdog. GPT not tricked despite context cues
    '5352_5352026.jpeg', # no beer, GPT tricked
    '393809_393809008.jpeg', # GPT tricked
    '133343_133343002.jpeg', # no sunglasses / eyes
    '43816_43816016.jpeg', # no mitt - GPT tricked
    # NOTE(review): the next entry is a question, not an
    # "<image_id>_<question_id>.jpeg" name like the others, so
    # get_question_from_file() cannot parse it — replace with the intended file.
    'How many different types of animals are on the field?', # no animals - GPT tricked
    '526580_526580007.jpeg', # no jeans - GPT mini tricked
]

# NOTE(review): `id` shadows the builtin and is not read in the visible cells.
id = 11
# Sixth entry from the end of the list ('5352_5352026.jpeg' as the list stands).
file = counterfactuals[-6]
# NOTE(review): this reads from "vqa_removal", whereas the earlier example cell
# reads from "vqa_removal_val" — confirm which directory is intended.
perturbed_path = f"../../results/vqa_removal/{file}"
question, row_id = get_question_from_file(file)
print(question)
# last expression of the cell: rendered by the notebook
Image.open(perturbed_path)

In [None]:
sys.path.append('../')
from segmentation_task import *

# Try an object substitution on the row selected in the previous cell.
original_image = get_raw_image(row_id)
# SegmentationSample is defined outside this file; the meaning of the two None
# positional arguments and of the boolean passed to substitution() is not
# visible here — TODO confirm against segmentation_task.
sample = SegmentationSample(question, None, None, image=np.array(original_image), segment_prompt="baseball bat")
perturbed_image = sample.substitution("broomstick", False)
# last expression of the cell: rendered by the notebook
perturbed_image

In [None]:
# Preview the first 20 entries of the questions column.
data['questions'][:20]

In [None]:
from segment import *

def perturb_vqav2_image(segment_prompt, inpaint_prompt, index):
    """Segment a prompted region of a dataset image and inpaint over it.

    Args:
        segment_prompt: Prompt passed to ``get_frames_from_prompt`` to select
            the region.
        inpaint_prompt: Prompt passed to ``inpaint_mask`` for the replacement.
        index: Row passed to ``get_raw_image`` to fetch the source image.

    Returns:
        Whatever ``inpaint_mask`` returns for the source image and mask.

    Relies on ``model``, ``get_frames_from_prompt`` and ``inpaint_mask`` being
    in scope (presumably from the ``segment`` star import — TODO confirm).
    """
    raw_image = get_raw_image(index)
    image_source, _unused, image_mask = get_frames_from_prompt("", segment_prompt, model, raw_image)

    # General perturbation: inpaint a replacement described by the prompt.
    # Known quirk: the object is sometimes removed entirely instead of being
    # replaced, particularly when it is small relative to the image.
    return inpaint_mask(inpaint_prompt, image_source, image_mask)

In [None]:
from tqdm import tqdm
from importlib import reload
reload(webqa)
import random
random.seed(0)

# Load the VQAv2 validation split, keep a deep copy intended to receive the
# perturbed rows, and count the question/answer pairs eligible for perturbation.
# NOTE(review): this rebinds the notebook-global `data`; get_raw_image() and
# get_question_from_file() read that global, so they use the validation split
# once this cell has run.
vqa_path = "../data/VQAv2_arrows/vqav2_val.arrow"
data = pd.read_feather(vqa_path)
perturbed_data = copy.deepcopy(data)
# Output directory derived from the arrow path ("<vqa_path>.perturbed").
dataset_root = vqa_path + ".perturbed"
os.makedirs(dataset_root, exist_ok=True)
# NOTE(review): the counter starts at 1, so the value displayed below is the
# number of matches plus one — confirm whether 0 was intended.
count = 1
for index in tqdm(range(len(data))):
    # Pair each question of the row with the first entry of its answer list.
    qa_list = list(zip(data['questions'][index], [x[0] for x in data['answers'][index]]))
    for i, (question, answer) in enumerate(qa_list):
        # Eligible: yes/no answers, or "what color" / "how many" questions.
        if (answer.lower() in ['yes', 'no'] or question.lower().startswith(('what color', 'how many'))):
            count += 1
# displayed by the notebook when this is the cell's last expression
count
            # object_noun = extract_object(question)
            # if answer.lower() in ['yes', 'no']:
            #     qcate = 'yesno'
            #     if 'yes' in answer.lower():
            #         # remove and set answer to no
            #         infill_prompt = "blank.png"
            #         rand_answer = 'no'
            #     else:
            #         # add and set answer to yes
            #         infill_prompt = object_noun
            #         rand_answer = 'yes'
            # else:
            #     # question.lower().startswith(('what color', 'how many')):
            #     if question.lower().startswith('what color'):
            #         qcate = 'color'
            #     else:
            #         qcate = 'number'      
            #     rand_answer = random.choice(webqa.domain_dict_gen[qcate])
            #     infill_prompt = rand_answer + ' ' + object_noun

#             print(question, answer, infill_prompt, object_noun)
#             perturbed_image = perturb_vqav2_image(object_noun, infill_prompt, index)
#             new_row = copy.deepcopy(data.iloc[index])
#             new_row['image'] = convert_image_to_binary(perturbed_image)
#             new_row['answers'] = [rand_answer] 
#             new_row['questions'] = [question]
#             perturbed_data = pd.concat([perturbed_data, pd.DataFrame([new_row])], ignore_index=True)

# train_table = pa.Table.from_pandas(perturbed_data)
  
# # save perturbed_data as new pyarrow file 
# with pa.OSFile(f"{dataset_root}/rand_augmented.arrow", "wb") as sink:
#         with pa.RecordBatchFileWriter(sink, train_table.schema) as writer:
#             writer.write_table(train_table)

# # remove original data
# perturbed_data_only = perturbed_data.drop(data.index)
# perturbed_table_only = pa.Table.from_pandas(perturbed_data_only)
# with pa.OSFile(f"{dataset_root}/rand_only.arrow", "wb") as sink:        
#         with pa.RecordBatchFileWriter(sink, perturbed_table_only.schema) as writer:
#             writer.write_table(perturbed_table_only)