In [5]:
import argparse
import os
import copy

import numpy as np
import json
import torch
from PIL import Image, ImageDraw, ImageFont

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam, SamPredictor 
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [43]:
def load_image(image_path):
    """Open an image file and prepare it for the Grounding DINO model.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image_pil, image)`` pair: the image as an RGB ``PIL.Image``, and
        the result of passing it through the resize / to-tensor / normalize
        pipeline below (the original notes give its layout as ``3, h, w``).
    """
    image_pil = Image.open(image_path).convert("RGB")

    # Pipeline stages, applied in this order.
    resize = T.RandomResize([800], max_size=1333)
    to_tensor = T.ToTensor()
    # Per-channel mean / std; these values match the commonly used ImageNet statistics.
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    pipeline = T.Compose([resize, to_tensor, normalize])

    # The transforms are called with (image, target); no target is supplied here.
    image, _ = pipeline(image_pil, None)
    return image_pil, image


def load_model(model_config_path, model_checkpoint_path, device):
    """Build a Grounding DINO model from a config file and load its weights.

    Args:
        model_config_path: Path handed to ``SLConfig.fromfile``.
        model_checkpoint_path: Checkpoint file; its ``"model"`` entry holds the
            state dict.
        device: Stored on the config as ``device`` before the model is built.
            This function does not itself call ``.to(device)`` on the model.

    Returns:
        The model, switched to evaluation mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    model = build_model(config)

    # The checkpoint is always deserialized onto the CPU, whatever `device` is.
    state_dict = torch.load(model_checkpoint_path, map_location="cpu")["model"]

    # strict=False: key mismatches are reported in the printed result rather than raised.
    load_res = model.load_state_dict(clean_state_dict(state_dict), strict=False)
    print(load_res)

    model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image/caption pair and keep confident boxes.

    Args:
        model: Callable as ``model(batch, captions=[...])`` returning a mapping
            with ``"pred_logits"`` and ``"pred_boxes"``; must also provide
            ``.to(device)`` and a ``.tokenizer`` attribute.
        image: Image tensor without a batch dimension (one is added here).
        caption: Text prompt. It is lower-cased, stripped, and given a trailing
            ``"."`` if it lacks one.
        box_threshold: A query is kept when its highest token score is strictly
            greater than this value.
        text_threshold: Token scores strictly greater than this value are passed
            to ``get_phrases_from_posmap`` to build the phrase.
        with_logits: When true, ``"(<score>)"`` is appended to each phrase, where
            the score is the first four characters of the query's top score.
        device: Device on which the forward pass runs.

    Returns:
        ``(boxes_filt, pred_phrases)``: the kept rows of ``pred_boxes`` (on the
        CPU) and one phrase string per kept row.
    """
    # Normalise the prompt: lower-case, trimmed, terminated by a period.
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    # Results for the single batch element, moved to the CPU.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256) per the original notes
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4) per the original notes

    # Boolean-mask indexing already yields fresh tensors, so no explicit copy is needed.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            phrase = phrase + f"({str(logit.max().item())[:4]})"
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases

def show_mask(mask, ax, random_color=False):
    """Draw a mask on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are height and width; it must be
            reshapeable to ``(h, w, 1)``.
        ax: Object providing ``imshow`` (e.g. a matplotlib axes).
        random_color: When true, use a random RGB colour instead of the fixed
            blue. Alpha is 0.6 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4): masked pixels get the colour, the rest stay zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_box(box, ax, label):
    """Outline a box on ``ax`` in green and write ``label`` at its first corner.

    Args:
        box: Indexable as ``(x0, y0, x1, y1)``; width and height are derived as
            ``x1 - x0`` and ``y1 - y0``.
        ax: Axes-like object providing ``add_patch`` and ``text``.
        label: Text placed at ``(x0, y0)``.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill: only the outline is visible
        lw=2,
    )
    ax.add_patch(outline)
    ax.text(left, top, label)


def _split_label(label):
    """Split ``"phrase(0.87)"`` into ``("phrase", 0.87)``.

    Only the LAST ``"("`` is treated as the start of the score, so phrases that
    themselves contain parentheses are handled. Returns ``(label, None)`` when
    the label has no parseable ``"(<float>)"`` suffix.
    """
    name, sep, tail = label.rpartition('(')
    if not sep or not tail.endswith(')'):
        return label, None
    try:
        return name, float(tail[:-1])  # drop the closing ')'
    except ValueError:
        return label, None


def save_mask_data(output_dir, mask_list, box_list, label_list):
    """Write an index-coloured mask image and a JSON description of the detections.

    Two files are written into ``output_dir`` (which must already exist):

    * ``mask.jpg`` - each mask painted with its 1-based index, background 0.
    * ``mask.json`` - a list starting with the background entry, followed by one
      ``{'value', 'label', 'logit', 'box'}`` entry per ``(label, box)`` pair.

    Args:
        output_dir: Directory receiving both files.
        mask_list: Tensor of masks; its last two dimensions give the image size
            and index 0 of each mask is used as the 2-D mask
            (assumes shape (N, 1, H, W) - TODO confirm against the caller).
        box_list: Iterable of tensors supporting ``.numpy()``, one per label.
        label_list: Strings shaped like ``"phrase(score)"``. Labels without a
            score suffix are accepted; their ``'logit'`` is written as ``null``.
    """
    value = 0  # 0 for background

    # Later masks overwrite earlier ones wherever they overlap.
    mask_img = torch.zeros(mask_list.shape[-2:])
    for idx, mask in enumerate(mask_list):
        mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1

    # The figure is left open after saving, exactly as before.
    plt.figure(figsize=(10, 10))
    plt.imshow(mask_img.numpy())
    plt.axis('off')
    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

    json_data = [{
        'value': value,
        'label': 'background'
    }]
    for label, box in zip(label_list, box_list):
        value += 1
        # The previous `label.split('(')` raised ValueError for phrases containing
        # '(' and for labels produced without a score suffix.
        name, logit = _split_label(label)
        json_data.append({
            'value': value,
            'label': name,
            'logit': logit,
            'box': box.numpy().tolist(),
        })
    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
        json.dump(json_data, f)

In [5]:
# initialize SAM and give the predictor the image as an RGB array
# NOTE(review): `sam_checkpoint`, `device`, `image_path`, `image_pil` and `boxes_filt`
# must already exist in the kernel; none of them is assigned earlier in the visible notebook.
sam = build_sam(checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)  # OpenCV reads BGR
predictor.set_image(image)

# Convert the boxes IN PLACE: scale by the image width/height, then turn
# (centre, size) into (top-left, bottom-right) corners.
# Because `boxes_filt` is modified in place, re-running this cell without
# recomputing `boxes_filt` applies the conversion a second time.
size = image_pil.size
W, H = size
boxes_filt *= torch.Tensor([W, H, W, H])
boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2   # top-left = centre - size / 2
boxes_filt[:, 2:] += boxes_filt[:, :2]       # bottom-right = size + top-left

boxes_filt = boxes_filt.cpu()
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

# Box-prompted prediction only (no point prompts); multimask_output is disabled.
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

In [6]:
# Automatic mask generation over the whole image, reusing the `sam` model and the
# `image` array created by the SAM cell earlier in the notebook.
# `sam_model_registry` is imported here but not used in this cell.
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
mask_generator = SamAutomaticMaskGenerator(sam)
# NOTE: this rebinds `masks`; the box-prompted result assigned to `masks` by
# `predictor.predict_torch(...)` is no longer reachable under that name.
masks = mask_generator.generate(image)

In [7]:
# Quick inspection of the automatic-mask result: how many entries were produced,
# and which keys the first entry carries.
print(len(masks))
# Raises IndexError if `masks` is empty.
print(masks[0].keys())

122
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])


In [None]:
class RelateAnything(torch.nn.Module):
    """Design sketch: segment an image, pair up masks, and classify each pair's relation.

    The original cell was pseudo-code that could not even be defined: it used
    ``self`` at class scope and ``forward`` had no ``self`` parameter. This
    version keeps the same pipeline but takes its collaborators as constructor
    arguments, since none of them is defined in the notebook.

    NOTE(review): the roles below are inferred from the sketch's own comments;
    confirm them before building on this class.

    Args:
        sam: Callable ``image -> (masks_out, feat_out)``.
        pair_net: Callable ``feat_out -> [(i, j, score), ...]`` where ``i`` and
            ``j`` index into ``masks_out``.
        clip: Callable mapping a mask (or the sum of two masks) to a feature
            supporting subtraction.
        relation_classifier: Callable mapping a relation feature to a relation
            class id (stands in for the sketch's
            ``argmax(rel_feat, predefined_relation_dict)``).
    """

    _COMPONENTS = ("sam", "pair_net", "clip", "relation_classifier")

    def __init__(self, sam=None, pair_net=None, clip=None, relation_classifier=None):
        super().__init__()
        self.sam = sam
        self.pair_net = pair_net
        self.clip = clip
        self.relation_classifier = relation_classifier

    def forward(self, image):
        """Return ``[(i, rel_class, j), ...]``, one triple per pair proposed by ``pair_net``.

        Raises:
            NotImplementedError: If any collaborator was not supplied.
        """
        missing = [name for name in self._COMPONENTS if getattr(self, name) is None]
        if missing:
            raise NotImplementedError(
                "RelateAnything is a sketch; missing components: " + ", ".join(missing)
            )

        masks_out, feat_out = self.sam(image)
        pair_out = self.pair_net(feat_out)
        # pair_out = [(int, int, float), (int, int, float), ...]

        triples = []
        for i, j, _score in pair_out:
            mask_i, mask_j = masks_out[i], masks_out[j]
            # Relation feature: feature of the combined masks minus each mask's own feature.
            rel_feat = self.clip(mask_i + mask_j) - self.clip(mask_i) - self.clip(mask_j)
            rel_class = self.relation_classifier(rel_feat)
            triples.append((i, rel_class, j))
        return triples

# Usage sketch for RelateAnything (this cell has no execution count).
# NOTE(review): `process_output` and `ChatGPT` are not defined anywhere in the
# visible notebook, so this cell cannot run as written.
# NOTE: `image` and `model` are rebound here; other cells use the same names for
# an RGB array / PIL image and for other models.
image = "demo.png"
model = RelateAnything()
output = model(image)
# output = [(int, int, int), (int, int, int), ...]
text_output = process_output(output)
# NOTE(review): `text_output` is never used; `generate` receives the raw `output`.
# Possibly `text_output` was intended here - confirm.
relation_caption = ChatGPT().generate(output)

In [50]:
# Caption the demo image with BLIP-2; `generated_text` is consumed by the
# scene-graph parsing cell.
# NOTE: this cell rebinds the notebook-level names `model`, `image` and `device`.
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import PIL.Image  # a bare `import PIL` does not load the `PIL.Image` submodule

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU: the CPU fallback previously still loaded and fed the
# model in float16, which is unsupported or very slow on many CPU builds of PyTorch.
blip2_dtype = torch.float16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=blip2_dtype)
model.to(device)

image = PIL.Image.open("assets/demo2.jpg")

# The inputs must use the same device and dtype as the model weights.
inputs = processor(image, return_tensors="pt").to(device, blip2_dtype)

generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.97s/it]


a corgi dog running in the grass


In [52]:
import pickle
import sys
import time
from pprint import pprint

# The parent directory is put on the import path before the local import below
# (presumably that is where `sng_parser` lives - verify for your checkout).
sys.path.append("../")
import sng_parser

# Parse the BLIP-2 caption into a scene graph; the captured output shows
# 'entities' and 'relations' keys.
graph = sng_parser.parse(generated_text)
pprint(graph)
# Tabular rendering of the same graph.
sng_parser.tprint(graph)

{'entities': [{'head': 'corgi dog',
               'lemma_head': 'corgi dog',
               'lemma_span': 'a corgi dog',
               'modifiers': [{'dep': 'det', 'lemma_span': 'a', 'span': 'a'}],
               'span': 'a corgi dog',
               'span_bounds': (0, 3),
               'type': 'unknown'},
              {'head': 'grass',
               'lemma_head': 'grass',
               'lemma_span': 'the grass',
               'modifiers': [{'dep': 'det',
                              'lemma_span': 'the',
                              'span': 'the'}],
               'span': 'the grass',
               'span_bounds': (5, 7),
               'type': 'unknown'}],
 'relations': [{'lemma_relation': 'run in',
                'object': 1,
                'relation': 'running in',
                'subject': 0}]}
Entities:
+-----------+-------------+-------------+
| Head      | Span        | Modifiers   |
|-----------+-------------+-------------|
| corgi dog | a corgi dog | a           |


In [53]:
# Flatten the parsed scene graph into plain lists.
entities = graph['entities']


def _head(entity_index):
    """Lower-cased head noun of the entity at ``entity_index``."""
    return entities[entity_index]['head'].lower()


# One [subject head, relation, object head] row per relation, all lower-cased.
# `relations_data` is built here but not displayed by this cell.
relations_data = [
    [_head(rel['subject']), rel['relation'].lower(), _head(rel['object'])]
    for rel in graph['relations']
]

# Show the head of every entity, in original case.
for item in entities:
    print(item['head'])

corgi dog
grass


In [63]:
# run grounding dino model once per parsed entity and collect every detection
# NOTE(review): `get_grounding_output` needs the Grounding DINO model and the
# transformed image tensor. In this notebook `model` and `image` are also assigned
# by the BLIP-2 caption cell (a BLIP-2 model and a PIL image); if that cell ran
# most recently, rebind both (e.g. via `load_model` / `load_image`) before this loop.
# `box_threshold` and `text_threshold` must be defined as well.
boxes_filt = []
pred_phrases = []
for item in entities:
    print(item['head'])
    current_boxes_filt, current_pred_phrases = get_grounding_output(
        model, image, item['head'], box_threshold, text_threshold, device=device
    )
    # Previously both results were discarded, leaving the two lists empty.
    # Extending with a 2-D tensor appends one tensor per box, so `boxes_filt`
    # ends up as a flat list of per-box tensors; stack it before any code that
    # expects a single tensor (e.g. one calling `boxes_filt.size(0)`).
    boxes_filt.extend(current_boxes_filt)
    pred_phrases.extend(current_pred_phrases)


corgi dog
