In [7]:
from os import path

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import LazyConfig, instantiate
from diffusers import AutoPipelineForText2Image, DiffusionPipeline
from diffusers_interpret import StableDiffusionPipelineDetExplainer

from gill import layers

from torch.cuda import amp

In [8]:
model_path = '/mnt/workspace/model'
eva_id = 'zacbi2023/eva02'
sd_id = 'AI-ModelScope/stable-diffusion-v1-5'

In [9]:
# prepare eva
eva_base_path = path.join(model_path, eva_id)
eva_coco_config_rpath = 'projects/ViTDet/configs/eva2_o365_to_coco/eva2_o365_to_coco_cascade_mask_rcnn_vitdet_l_8attn_1536_lrd0p8.py'
eva_config_path = path.join(eva_base_path, eva_coco_config_rpath)

# replace with your eva02 weights path
eva_coco_weights_rpth = 'checkpoints/eva02_L_coco_seg_sys_o365.pth'
eva_weights_path = path.join(eva_base_path, eva_coco_weights_rpth)

custum_cfg = ['MODEL.RETINANET.SCORE_THRESH_TEST', 0.5,
                'MODEL.ROI_HEADS.SCORE_THRESH_TEST', 0.5,
                'MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH', 0.5,
                'DATASETS.TEST', [],
                'MODEL.WEIGHTS', eva_weights_path]
eva_cfg = LazyConfig.load(eva_config_path)
LazyConfig.apply_overrides(
    eva_cfg, [f"{key}={value}" for key, value in zip(custum_cfg[::2], custum_cfg[1::2])])

device = 'cuda'
eva = instantiate(eva_cfg.model).to(device)
DetectionCheckpointer(eva).load(eva_weights_path)
eva.eval()

Some model parameters or buffers are not found in the checkpoint:
[34mbackbone.net.blocks.0.attn.qkv.{bias, weight}[0m
[34mbackbone.net.blocks.0.attn.{rel_pos_h, rel_pos_w}[0m
[34mbackbone.net.blocks.0.mlp.fc1.{bias, weight}[0m
[34mbackbone.net.blocks.0.mlp.fc2.{bias, weight}[0m
[34mbackbone.net.blocks.1.attn.qkv.{bias, weight}[0m
[34mbackbone.net.blocks.1.attn.{rel_pos_h, rel_pos_w}[0m
[34mbackbone.net.blocks.1.mlp.fc1.{bias, weight}[0m
[34mbackbone.net.blocks.1.mlp.fc2.{bias, weight}[0m
[34mbackbone.net.blocks.10.attn.qkv.{bias, weight}[0m
[34mbackbone.net.blocks.10.attn.{rel_pos_h, rel_pos_w}[0m
[34mbackbone.net.blocks.10.mlp.fc1.{bias, weight}[0m
[34mbackbone.net.blocks.10.mlp.fc2.{bias, weight}[0m
[34mbackbone.net.blocks.11.attn.qkv.{bias, weight}[0m
[34mbackbone.net.blocks.11.attn.{rel_pos_h, rel_pos_w}[0m
[34mbackbone.net.blocks.11.mlp.fc1.{bias, weight}[0m
[34mbackbone.net.blocks.11.mlp.fc2.{bias, weight}[0m
[34mbackbone.net.blocks.12.attn.qkv.{

GeneralizedRCNN(
  (backbone): SimpleFeaturePyramid(
    (simfp_2): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
      (1): LayerNorm()
      (2): GELU(approximate='none')
      (3): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
      (4): Conv2d(
        256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
        (norm): LayerNorm()
      )
      (5): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): LayerNorm()
      )
    )
    (simfp_3): Sequential(
      (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
      (1): Conv2d(
        512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
        (norm): LayerNorm()
      )
      (2): Conv2d(
        256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
        (norm): LayerNorm()
      )
    )
    (simfp_4): Sequential(
      (0): Conv2d(
        1024, 256, kernel_size=(1, 1), stride=(1, 

In [10]:
torch_dtype=torch.bfloat16
sd_pipe = AutoPipelineForText2Image.from_pretrained(
    path.join(model_path, sd_id), torch_dtype=torch_dtype).to(device)
explainer = StableDiffusionPipelineDetExplainer(pipe=sd_pipe, det_model=eva)

Loading pipeline components...: 100%|██████████| 7/7 [00:02<00:00,  3.00it/s]


In [11]:
# llm生成的(batch_size, seq_len, hidden_dim)
raw_emb = torch.load('/mnt/workspace/data/tensor/raw_emb_tensor_cat_1.pt').to(torch_dtype)
raw_emb.requires_grad_(True)
# embedding img0-imge8
gen_prefix_embs = torch.load('/mnt/workspace/data/tensor/gen_prefix_embs_tensor_cat_1.pt').to(torch_dtype)
gen_prefix_embs.requires_grad_(True)

# gill_mapper: linear + Transformer + linear
gen_text_hidden_fcs = layers.GenTextHiddenFcs()
gill_state_dict = torch.load('/mnt/workspace/github/gill/checkpoints/gill_opt/pretrained_ckpt.pth.tar')

gen_text_hidden_fcs_state_dict = {}
for key, val in gill_state_dict['state_dict'].items():
    if 'gen_text_hidden_fcs' in key:
        prefix = 'gen_text_hidden_fcs' + key.split('gen_text_hidden_fcs')[1]
        gen_text_hidden_fcs_state_dict[prefix] = val
gen_text_hidden_fcs.load_state_dict(gen_text_hidden_fcs_state_dict)
gen_text_hidden_fcs.cuda()
gen_text_hidden_fcs.to(torch_dtype)



GenTextHiddenFcs(
  (gen_text_hidden_fcs): ModuleList(
    (0): TextFcLayer(
      (fc): Linear(in_features=4096, out_features=512, bias=True)
      (tfm): Transformer(
        (encoder): TransformerEncoder(
          (layers): ModuleList(
            (0-3): 4 x TransformerEncoderLayer(
              (self_attn): MultiheadAttention(
                (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
              )
              (linear1): Linear(in_features=512, out_features=2048, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (linear2): Linear(in_features=2048, out_features=512, bias=True)
              (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (dropout1): Dropout(p=0.0, inplace=False)
              (dropout2): Dropout(p=0.0, inplace=False)
            )
          )
          (norm): LayerNorm((512,), ep

In [12]:
gen_emb = gen_text_hidden_fcs.gen_text_hidden_fcs[0](raw_emb, gen_prefix_embs)

In [7]:
with torch.cuda.amp.autocast(dtype=torch.float16):
    output = explainer(
        prompt_embeds=gen_emb,
        num_inference_steps=50,
        target_cls_id=15,
        raw_embeds=raw_emb,
        n_last_diffusion_steps_to_consider_for_attributions=1
    )

  latents_shape = (batch_size, self.pipe.unet.in_channels, height // 8, width // 8)


  0%|          | 0/51 [00:00<?, ?it/s]

Calculating token attributions... 

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  x = F.conv2d(


Done!


In [9]:
output.token_attributions

[('0', 0.7846115),
 ('1', 1.3332276),
 ('2', 1.826821),
 ('3', 2.8268588),
 ('4', 0.3671959),
 ('5', 1.0050043),
 ('6', 1.2384505),
 ('7', 0.6856177)]