In [1]:
import networkx as nx
import argparse
import os
import sys
import numpy as np
import json
import torch
import torchvision
from PIL import Image
import gc
sys.path.append('/root/autodl-tmp/GroundedSAM_src/Grounded-Segment-Anything')
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

In [2]:
import pickle
def save_graph(graph, file_path):
    """Serialize ``graph`` with pickle and write it to ``file_path``.

    The file is opened in binary write mode, so an existing file at that
    path is overwritten.
    """
    with open(file_path, 'wb') as handle:
        pickle.Pickler(handle).dump(graph)

def load_image(image_path):
    """Read an image from disk and prepare it for the detector.

    Returns a pair ``(rgb_image, tensor)``: the RGB-converted PIL image and
    the result of running it through ``ToTensor`` followed by ``Normalize``.
    """
    rgb_image = Image.open(image_path).convert("RGB")

    # Resizing (T.RandomResize([800], max_size=1333)) is commented out in this
    # notebook, so the pipeline contains only tensor conversion and normalization.
    steps = [
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = T.Compose(steps)

    # The transforms take (image, target); no target is supplied here.
    tensor, _ = pipeline(rgb_image, None)
    return rgb_image, tensor

def load_model(model_config_path, model_checkpoint_path, device):
    """Build the model described by a config file and load checkpoint weights.

    Args:
        model_config_path: Path passed to ``SLConfig.fromfile``.
        model_checkpoint_path: Checkpoint file; its ``"model"`` entry is used.
        device: Stored on the config as ``device`` before the model is built.

    Returns:
        The model, switched to evaluation mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)

    # The checkpoint is always deserialized onto the CPU.
    state = torch.load(model_checkpoint_path, map_location="cpu")

    # strict=False: key mismatches do not raise; the printed result reports them.
    print(net.load_state_dict(clean_state_dict(state["model"]), strict=False))

    net.eval()
    return net


def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
    """Run the grounding model on one image and keep the confident predictions.

    Args:
        model: Detector; moved to ``device`` and called with the image and caption.
        image: Image tensor; a leading dimension is added before the forward pass.
        caption: Text prompt. It is lower-cased, stripped, and terminated with
            a "." if it does not already end with one.
        box_threshold: A prediction is kept when its largest logit (after
            sigmoid) is strictly greater than this value.
        text_threshold: Per-token cut-off used to select the phrase tokens.
        device: Device on which the forward pass runs.

    Returns:
        ``(boxes, scores, phrases)``: the kept boxes, a tensor with the
        largest logit of each kept prediction, and one label per kept
        prediction made of the phrase followed by the score in parentheses.
    """
    prompt = caption.lower().strip()
    if not prompt.endswith("."):
        prompt += "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[prompt])

    # Work on the first (only) batch element, on the CPU.
    query_logits = outputs["pred_logits"].cpu().sigmoid()[0]
    query_boxes = outputs["pred_boxes"].cpu()[0]

    # Keep only the predictions whose best logit clears the box threshold.
    keep = query_logits.max(dim=1)[0] > box_threshold
    logits_filt = query_logits[keep]
    boxes_filt = query_boxes[keep]

    # Tokenize the prompt once; the tokens are reused for every prediction.
    text_tokenizer = model.tokenizer
    tokenized = text_tokenizer(prompt)

    pred_phrases = []
    scores = []
    for logit in logits_filt:
        best = logit.max().item()
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, text_tokenizer)
        # The score is appended as the first four characters of its string form.
        pred_phrases.append(phrase + f"({str(best)[:4]})")
        scores.append(best)

    return boxes_filt, torch.Tensor(scores), pred_phrases

In [3]:
import mimetypes
import os
from io import BytesIO
from typing import Union
import cv2
import requests
import torch
import transformers
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm
import sys

# Make the Otter sources importable: the otter_ai directory itself goes on sys.path.
sys.path.append("/root/autodl-tmp/Otter/Otter/src/otter_ai")
# NOTE: this changes the working directory for the whole kernel, so every relative
# path used after this cell resolves against the otter_ai directory.
os.chdir('/root/autodl-tmp/Otter/Otter/src/otter_ai')
# ".." is a relative entry, so it is resolved against the working directory set above.
sys.path.append("..")
from models.otter.modeling_otter import OtterForConditionalGeneration


# Silence urllib3 warnings; get_image() in this cell issues requests with verify=False.
requests.packages.urllib3.disable_warnings()

# ------------------- Utility Functions -------------------


def get_content_type(file_path):
    """Guess the MIME type of ``file_path`` from its name.

    Returns the type string (for example ``"image/jpeg"``), or ``None`` when
    the type cannot be guessed.
    """
    # guess_type returns (type, encoding); only the type is of interest.
    return mimetypes.guess_type(file_path)[0]


# ------------------- Image and Video Handling Functions -------------------

def get_image(url: str) -> Union[Image.Image, list]:
    """Open an image from a local path or a remote URL.

    Args:
        url: A local file path, or a URL (anything containing "://").

    Returns:
        The opened PIL image. (The annotation also allows a list, but this
        implementation only ever returns an image.)

    Raises:
        ValueError: If the content type is unknown or does not contain
            "image". The message mentions video, although no video branch
            exists in this function.
    """
    is_remote = "://" in url

    if is_remote:
        # NOTE(review): verify=False turns off TLS certificate verification for
        # both requests below; only use this with hosts you trust.
        content_type = requests.head(url, stream=True, verify=False).headers.get("Content-Type")
    else:
        content_type = get_content_type(url)

    # The content type is None when the file extension is not recognised or the
    # server sends no Content-Type header. Testing `"image" in None` would raise
    # TypeError, so that case is reported as the same ValueError as a non-image type.
    if not content_type or "image" not in content_type:
        raise ValueError("Invalid content type. Expected image or video.")

    if is_remote:
        return Image.open(requests.get(url, stream=True, verify=False).raw)
    return Image.open(url)


# ------------------- OTTER Prompt and Response Functions -------------------


def get_formatted_prompt(prompt: str, in_context_prompts: list = []) -> str:
    """Build the text prompt, optionally preceded by in-context examples.

    Args:
        prompt: The question to ask about the final image.
        in_context_prompts: Iterable of ``(question, answer)`` pairs. Each pair
            becomes one complete example chunk placed before the final
            question. The default list is only read, never modified.

    Returns:
        The example chunks (if any) followed by the open-ended final turn.
    """
    examples = "".join(
        f"<image>User: {question} GPT:<answer> {answer}<|endofchunk|>"
        for question, answer in in_context_prompts
    )
    return examples + f"<image>User: {prompt} GPT:<answer>"


def get_response(image_list, prompt: str, model=None, image_processor=None, in_context_prompts: list = []) -> str:
    """Ask the model a question about an image (or list of frames).

    Args:
        image_list: A single PIL image or a list of frames.
        prompt: The question; wrapped by ``get_formatted_prompt`` together
            with ``in_context_prompts`` before tokenization.
        model: Model exposing ``text_tokenizer``, ``generate`` and ``device``.
            Required in practice, despite the ``None`` default.
        image_processor: Object whose ``preprocess`` turns the frames into
            ``pixel_values``. Required in practice, despite the ``None`` default.
        in_context_prompts: ``(question, answer)`` pairs used as in-context
            examples. The default list is only read, never modified.

    Returns:
        The generated answer: the text after the last ``<answer>`` marker, cut
        at the first ``<|endofchunk|>``, with surrounding whitespace and
        double quotes removed.

    Raises:
        ValueError: If ``image_list`` is neither a PIL image nor a list.
    """
    if isinstance(image_list, Image.Image):
        frames = [image_list]
    elif isinstance(image_list, list):  # list of video frames
        frames = image_list
    else:
        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")

    # Two extra dimensions are added to the pixel values: first at index 1, then at index 0.
    vision_x = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)

    # Format once and reuse the same string for logging and tokenization.
    formatted_prompt = get_formatted_prompt(prompt, in_context_prompts)
    print(formatted_prompt)

    # Use the model's own tokenizer throughout. The previous code looked up a
    # notebook-level `tokenizer` for the bad-words list, which fails with
    # NameError on a fresh kernel.
    text_tokenizer = model.text_tokenizer
    lang_x = text_tokenizer([formatted_prompt], return_tensors="pt")
    bad_words_id = text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    decoded = text_tokenizer.decode(generated_text[0])
    answer = decoded.split("<answer>")[-1].strip()
    answer = answer.split("<|endofchunk|>")[0].strip()
    return answer.strip('"')

def get_response_test(image_list, prompt: str, model=None, image_processor=None) -> str:
    """Ask the model a question using ``prompt`` exactly as given.

    Unlike ``get_response``, the prompt is tokenized verbatim: no
    ``<image>User: ... GPT:<answer>`` wrapping is applied and nothing is printed.

    Args:
        image_list: A single PIL image or a list of frames.
        prompt: The full text fed to the tokenizer.
        model: Model exposing ``text_tokenizer``, ``generate`` and ``device``.
            Required in practice, despite the ``None`` default.
        image_processor: Object whose ``preprocess`` turns the frames into
            ``pixel_values``. Required in practice, despite the ``None`` default.

    Returns:
        The generated answer: the text after the last ``<answer>`` marker, cut
        at the first ``<|endofchunk|>``, with surrounding whitespace and
        double quotes removed.

    Raises:
        ValueError: If ``image_list`` is neither a PIL image nor a list.
    """
    if isinstance(image_list, Image.Image):
        frames = [image_list]
    elif isinstance(image_list, list):  # list of video frames
        frames = image_list
    else:
        raise ValueError("Invalid input data. Expected PIL Image or list of video frames.")

    # Two extra dimensions are added to the pixel values: first at index 1, then at index 0.
    vision_x = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].unsqueeze(1).unsqueeze(0)

    # Use the model's own tokenizer throughout. The previous code looked up a
    # notebook-level `tokenizer` for the bad-words list, which fails with
    # NameError on a fresh kernel.
    text_tokenizer = model.text_tokenizer
    lang_x = text_tokenizer([prompt], return_tensors="pt")
    bad_words_id = text_tokenizer(["User:", "GPT1:", "GFT:", "GPT:"], add_special_tokens=False).input_ids

    generated_text = model.generate(
        vision_x=vision_x.to(model.device),
        lang_x=lang_x["input_ids"].to(model.device),
        attention_mask=lang_x["attention_mask"].to(model.device),
        max_new_tokens=512,
        num_beams=3,
        no_repeat_ngram_size=3,
        bad_words_ids=bad_words_id,
    )

    decoded = text_tokenizer.decode(generated_text[0])
    answer = decoded.split("<answer>")[-1].strip()
    answer = answer.split("<|endofchunk|>")[0].strip()
    return answer.strip('"')

In [10]:
import pickle
# Load the question decompositions: a dict keyed by the original question text whose
# values are newline-separated, numbered sub-questions (see the displayed output).
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open('/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/decomp_vqav2_1000.pkl', 'rb') as f2:
    decomp = pickle.load(f2)

# Display the mapping for inspection.
decomp

{'What is this photo taken looking through?': '1. Is this a photo?\n2. What is the object being looked through in this photo?',
 'What position is this man playing?': '1. Is there a man in the image?\n2. Is the man playing a sport?\n3. What sport is the man playing?\n4. What position is the man playing?',
 'What color is the players shirt?': "1. Is there a player in the image?\n2. Is the player wearing a shirt?\n3. What color is the player's shirt?",
 'Is this man a professional baseball player?': '1. Is there a man in the image?\n2. Is the man engaged in a baseball-related activity?\n3. Does the man possess the physical attributes typically associated with professional baseball players?\n4. Does the man have any identifying features or accessories that suggest he is a professional baseball player?',
 'What color is the snow?': '1. Is there snow present?\n2. What is the color of the snow?',
 'What is the person doing?': "1. Is there a person in the image?\n2. What actions or movements 

In [11]:
!pwd

/root/autodl-tmp/Otter/Otter/src/otter_ai


### BLIP2 baseline

In [13]:
# Load the reference answers. The slice displayed further down shows a list with
# one list of answer strings per entry.
# NOTE(review): presumably ordered like the questions in `decomp` — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
with open('/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/vqav2_anwers.pkl', 'rb') as f2:
    all_answers = pickle.load(f2)

In [16]:
# Sanity check: the project directory is written as an absolute path.
os.path.isabs('/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/')

True

In [21]:
# Keep only the first 1000 answer lists.
# NOTE(review): presumably to line up with decomp_vqav2_1000.pkl — TODO confirm.
all_answers1000 = all_answers[:1000]
# Display the subset for inspection.
all_answers1000

[['net', 'netting', 'mesh'],
 ['pitcher', 'catcher'],
 ['orange'],
 ['yes', 'no'],
 ['white'],
 ['skiing'],
 ['red', 'red and white', 'black'],
 ['frisbee', 'white frisbee', 'frisbie', 'flying disc'],
 ['yes'],
 ['frisbee'],
 ['yes'],
 ['yes'],
 ['airplane',
  'plane trail',
  'contrail',
  'jet trail',
  'snow',
  'contrails',
  'jet stream',
  'air trail'],
 ['yes', 'no'],
 ['white and purple',
  'white',
  'purple and white',
  'white and lavender',
  'lavender and white'],
 ['brushing teeth', 'brushing', 'brushing their teeth', 'toothbrush'],
 ['yes'],
 ['no'],
 ['waiting', 'standing', 'walking', 'unknown', 'frowning', 'pouting'],
 ['yes', 'no', 'poorly'],
 ['no', 'yes'],
 ['no'],
 ['no', 'yes'],
 ['yes', 'no'],
 ['black white', 'black and white', 'white and gray', 'black white gray'],
 ['no', 'yes'],
 ['no'],
 ['yes'],
 ['skateboard'],
 ['1', '3', '2'],
 ['backwards', 'blue'],
 ['no'],
 ['green', 'green and black'],
 ['yes'],
 ['motorbike', 'motorcycle', 'dirt bike', 'motocross bi

In [26]:
import os
# Root of the Grounded-Segment-Anything checkout; the config and weights paths
# in the following cells are built from it.
# NOTE(review): this differs from the Grounded-Segment-Anything directory appended
# to sys.path in the first cell ('/root/autodl-tmp/GroundedSAM_src/...') — confirm
# both point at the same code.
HOME_dino = '/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/GroundedSAM_src/Grounded-Segment-Anything/'
HOME_dino

'/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/GroundedSAM_src/Grounded-Segment-Anything/'

In [27]:
import os

# Path of the GroundingDINO Swin-T config file inside the checkout.
CONFIG_PATH = os.path.join(HOME_dino, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
# Report whether the file actually exists before trying to load it.
print(CONFIG_PATH, "; exist:", os.path.isfile(CONFIG_PATH))

/root/autodl-tmp/Decomposing-Complex-Visual-Textual-Tasks-into-Intermediate-Questions-for-Multi-Modal-COT/GroundedSAM_src/Grounded-Segment-Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py ; exist: True


In [28]:
import os

# File name of the GroundingDINO checkpoint.
WEIGHTS_NAME = "groundingdino_swint_ogc.pth"

# Full checkpoint path: <HOME_dino>/weights/<WEIGHTS_NAME>.
# NOTE(review): unlike CONFIG_PATH, the existence of this file is not checked.
WEIGHTS_PATH = os.path.join(HOME_dino, "weights", WEIGHTS_NAME)

In [29]:
# NOTE(review): this import rebinds `load_model` and `load_image`, replacing the
# helpers of the same names defined earlier in this notebook. Every later call
# (e.g. load_image inside save_box) therefore uses the groundingdino versions,
# whose return values may differ from the notebook's (PIL image, tensor) pair —
# TODO confirm against the library.
from groundingdino.util.inference import load_model, load_image, predict, annotate

# Two-argument call: this relies on the imported load_model, since the notebook's
# own version requires a third `device` argument.
model = load_model(CONFIG_PATH, WEIGHTS_PATH)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


final text_encoder_type: bert-base-uncased


In [31]:
def save_box(img_pth):
    """Load the image at ``img_pth`` and return the pair produced by ``load_image``.

    Despite its name, this function neither computes boxes nor saves anything;
    it only forwards to whichever ``load_image`` is currently bound in the
    notebook namespace.

    Args:
        img_pth: Path of the image file to load.

    Returns:
        The two values returned by ``load_image``, as a tuple.
    """
    image_pil, image = load_image(img_pth)
    return image_pil, image

# Example image used to exercise save_box; the two results are whatever the
# currently bound load_image returns for this path.
image_path='/root/autodl-tmp/mmcot_test/n157115.jpg'
img_pil, img = save_box(image_path)

In [None]:
def grounding(image, tags, image_pil, device=None, box_threshold=0.25, text_threshold=0.2, iou_threshold=0.38):
    """Detect the objects named in ``tags`` and return de-duplicated pixel boxes.

    Uses the notebook-level ``model`` together with ``get_grounding_output``.

    Args:
        image: Image tensor passed to ``get_grounding_output``.
        tags: Text prompt naming the objects to look for.
        image_pil: Object whose ``.size`` is ``(width, height)``, as for a PIL
            image; it is used to scale the boxes to pixels.
        device: Device for the forward pass. When omitted, a notebook-level
            ``device`` variable is used if one exists (the earlier behaviour);
            otherwise CUDA is chosen when available, else the CPU.
        box_threshold: Minimum best-logit value for a box to be kept.
        text_threshold: Per-token cut-off used when extracting phrases.
        iou_threshold: Overlap threshold passed to non-maximum suppression.

    Returns:
        ``(boxes, scores, phrases)`` for the boxes that survive NMS. Boxes are
        in ``(x0, y0, x1, y1)`` pixel coordinates. Earlier versions computed
        these values but returned nothing.
    """
    if device is None:
        # Honour a notebook-level `device` when present so existing call sites
        # behave as before; fall back to auto-detection on a fresh kernel.
        device = globals().get("device") or ("cuda" if torch.cuda.is_available() else "cpu")

    boxes_filt, scores, pred_phrases = get_grounding_output(
        model, image, tags, box_threshold, text_threshold, device=device
    )

    # Scale each box by the image width/height, then convert (cx, cy, w, h) to
    # corner form: subtract half the size from the centre, then add the size.
    W, H = image_pil.size
    boxes_filt = boxes_filt * torch.tensor([W, H, W, H], dtype=boxes_filt.dtype)
    boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
    boxes_filt[:, 2:] += boxes_filt[:, :2]
    boxes_filt = boxes_filt.cpu()

    # Use NMS to drop heavily overlapping boxes; `keep` holds the surviving indices.
    keep = torchvision.ops.nms(boxes_filt, scores, iou_threshold)
    kept_phrases = [pred_phrases[idx] for idx in keep.tolist()]
    return boxes_filt[keep], scores[keep], kept_phrases

In [None]:
def readimage(img_dir, )