In [2]:
# CogVLM_Chat_Demo_with_LeGrad.ipynb

# This is a demo for using CogAgent and CogVLM in CLI with LeGrad functionality
# Make sure you have installed the vicuna-7b-v1.5 tokenizer model (https://huggingface.co/lmsys/vicuna-7b-v1.5); the full checkpoint of the vicuna-7b-v1.5 LLM is not required.
# We strongly suggest using a GPU with bfloat16 support; otherwise, it will be slow.
# Note that only one picture can be processed per conversation, which means you cannot replace or insert another picture during the conversation.

import argparse
import math
import types

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# LeGrad imports and utils (ensure these are defined as per your original code)
from legrad import hooked_resblock_forward, hooked_attention_forward, vit_dynamic_size_forward, min_max

# Simulate argparse
class Args:
    """Stand-in for the command-line namespace; options live as class attributes."""

    # Hugging Face identifiers of the checkpoint and of the tokenizer.
    from_pretrained = "THUDM/cogagent-chat-hf"
    local_tokenizer = "lmsys/vicuna-7b-v1.5"

    # Quantisation setting; a truthy value selects the 4-bit loading branch.
    quant = None

    # Precision flags: bf16 selects torch.bfloat16, otherwise float16 is used.
    bf16 = False
    fp16 = False

args = Args()
MODEL_PATH = args.from_pretrained
TOKENIZER_PATH = args.local_tokenizer
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
# float16 is the fallback whenever bf16 is not requested (args.fp16 is not consulted).
torch_type = torch.bfloat16 if args.bf16 else torch.float16

print(f"========Use torch type as:{torch_type} with device:{DEVICE}========\n\n")

# One flag drives both the 4-bit request and the device move. Previously the
# non-quantized branch passed `args.quant is not None`, so a falsy-but-set value
# (e.g. 0) asked for 4-bit weights and then called .to(DEVICE) on them.
use_4bit = bool(args.quant)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch_type,
    low_cpu_mem_usage=True,
    load_in_4bit=use_4bit,
    trust_remote_code=True
)
if not use_4bit:
    # Only full-precision weights are moved explicitly; the quantized path is
    # left wherever the loader placed it, exactly as the original quant branch did.
    model = model.to(DEVICE)
model = model.eval()

# Define LeWrapper and LePreprocess classes
class LeWrapper(nn.Module):
    """Wrap a model so that LeGrad explanation maps can be computed from it.

    On construction every non-dunder attribute of ``model`` is copied onto the
    wrapper, then the forward methods of the vision tower are swapped for the
    hooked variants imported from ``legrad``.

    Args:
        model: Model to wrap. ``_activate_hooks`` reads ``model.visual`` and
            accepts only a ``VisionTransformer`` or ``TimmModel`` there.
        layer_index: First vision layer to hook; a negative value counts back
            from the last layer.

    NOTE(review): ``VisionTransformer`` and ``TimmModel`` are not imported in
    the visible cells — presumably they come from open_clip; confirm and import
    them before instantiating this class.
    """

    def __init__(self, model, layer_index=-2):
        super().__init__()
        # Mirror the wrapped model's attributes and methods so the wrapper can
        # be used in its place (e.g. self.visual, self.encode_image).
        for attr in dir(model):
            if not attr.startswith('__'):
                setattr(self, attr, getattr(model, attr))
        self._activate_hooks(layer_index=layer_index)

    def _activate_hooks(self, layer_index):
        """Detect the vision-tower family, record its geometry and install hooks.

        Sets ``patch_size``, ``starting_depth`` and ``model_type`` on the wrapper.

        Raises:
            ValueError: if ``self.visual`` is neither supported type.
        """
        print('Activating necessary hooks and gradients ....')
        if isinstance(self.visual, VisionTransformer):
            self.visual.forward = types.MethodType(vit_dynamic_size_forward, self.visual)
            self.patch_size = self.visual.patch_size[0]
            # Normalise a negative index to an absolute block position.
            self.starting_depth = layer_index if layer_index >= 0 else len(self.visual.transformer.resblocks) + layer_index
            if self.visual.attn_pool is None:
                self.model_type = 'clip'
                self._activate_self_attention_hooks()
            else:
                self.model_type = 'coca'
                # NOTE(review): _activate_att_pool_hooks is not defined in this
                # class; it only resolves if the wrapped model supplies it.
                self._activate_att_pool_hooks(layer_index=layer_index)
        elif isinstance(self.visual, TimmModel):
            self.visual.trunk.dynamic_img_size = True
            self.visual.trunk.patch_embed.dynamic_img_size = True
            self.visual.trunk.patch_embed.strict_img_size = False
            self.visual.trunk.patch_embed.flatten = False
            self.visual.trunk.patch_embed.output_fmt = 'NHWC'
            self.model_type = 'timm_siglip'
            self.patch_size = self.visual.trunk.patch_embed.patch_size[0]
            self.starting_depth = layer_index if layer_index >= 0 else len(self.visual.trunk.blocks) + layer_index
            # NOTE(review): _activate_timm_attn_pool_hooks is not defined in this
            # class; it only resolves if the wrapped model supplies it.
            self._activate_timm_attn_pool_hooks(layer_index=layer_index)
        else:
            raise ValueError("Model currently not supported, see legrad.list_pretrained() for a list of available models")
        print('Hooks and gradients activated!')

    def _activate_self_attention_hooks(self):
        """Freeze all parameters except those in hooked blocks, then hook those blocks."""
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith('visual.transformer.resblocks'):
                # The block index is the path component right after the prefix.
                depth = int(name.split('visual.transformer.resblocks.')[-1].split('.')[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            self.visual.transformer.resblocks[layer].attn.forward = types.MethodType(hooked_attention_forward, self.visual.transformer.resblocks[layer].attn)
            self.visual.transformer.resblocks[layer].forward = types.MethodType(hooked_resblock_forward, self.visual.transformer.resblocks[layer])

    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        """Dispatch to the explanation routine matching ``self.model_type``.

        ``apply_correction`` is forwarded only on the siglip path.

        Raises:
            ValueError: if ``model_type`` matches none of the known families
                (this case used to fall through and return ``None``).
        """
        if 'clip' in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        elif 'siglip' in self.model_type:
            # NOTE(review): compute_legrad_siglip is not defined in this class.
            return self.compute_legrad_siglip(text_embedding, image, apply_correction=apply_correction)
        elif 'coca' in self.model_type:
            # NOTE(review): compute_legrad_coca is not defined in this class.
            return self.compute_legrad_coca(text_embedding, image)
        raise ValueError(f"Unsupported model_type for LeGrad: {self.model_type!r}")

    def compute_legrad_clip(self, text_embedding, image=None):
        """Accumulate gradient-based attention relevance over the hooked blocks.

        Args:
            text_embedding: Prompt embeddings; the first dimension is the
                number of prompts.
            image: Optional input; when given, ``encode_image`` is run first so
                the hooked blocks hold fresh ``feat_post_mlp`` and
                ``attention_map`` values.

        Returns:
            The summed per-layer explanation maps after ``min_max``.
        """
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # The result is discarded; the pass only refreshes hooked activations.
            _ = self.encode_image(image)
        blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
        image_features_list = []
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp
            intermediate_feat = self.visual.ln_post(intermediate_feat.mean(dim=0)) @ self.visual.proj
            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
            image_features_list.append(intermediate_feat)
        # One token is dropped from the count and from the relevance below
        # (index 0), leaving a square grid of the remaining tokens.
        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))
        # Loop-invariant: each prompt is scored against its own similarity entry.
        prompt_mask = F.one_hot(torch.arange(0, num_prompts)).float().requires_grad_(True).to(text_embedding.device)
        accum_expl_map = 0
        for blk, img_feat in zip(blocks_list[self.starting_depth:], image_features_list):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)
            target = torch.sum(prompt_mask * sim)
            attn_map = blk.attn.attention_map
            grad = torch.autograd.grad(target, [attn_map], retain_graph=True, create_graph=True)[0]
            # Split the fused (prompt * head) leading axis: (b h) n m -> b h n m.
            grad = grad.reshape(num_prompts, -1, grad.shape[-2], grad.shape[-1])
            grad = torch.clamp(grad, min=0.)
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            # b (w h) -> 1 b w h
            expl_map = image_relevance.reshape(1, num_prompts, w, h)
            expl_map = F.interpolate(expl_map, scale_factor=self.patch_size, mode='bilinear')
            accum_expl_map += expl_map
        accum_expl_map = min_max(accum_expl_map)
        return accum_expl_map

class LePreprocess(nn.Module):
    """Image preprocessing: a fixed square bicubic resize followed by the last
    three steps of an existing ``preprocess`` pipeline.

    Args:
        preprocess: Object exposing a ``transforms`` sequence with at least
            three entries.
        image_size: Side length passed to the resize step for both dimensions.
    """

    def __init__(self, preprocess, image_size):
        super().__init__()
        # NOTE(review): Compose, Resize and InterpolationMode are not imported in
        # the visible cells — presumably torchvision.transforms; confirm.
        square_resize = Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)
        trailing_steps = [preprocess.transforms[offset] for offset in (-3, -2, -1)]
        self.transform = Compose([square_resize, *trailing_steps])

    def forward(self, image):
        """Apply the composed pipeline to ``image`` and return the result."""
        pipeline = self.transform
        return pipeline(image)

# Wrapping the CogVLM model with LeGrad functionality
# The name `model` is rebound to the wrapper, so re-running this cell without
# reloading the checkpoint wraps an already-wrapped object.
# NOTE(review): LeWrapper._activate_hooks reads `self.visual` and accepts only a
# VisionTransformer or TimmModel there — confirm this checkpoint exposes one.
model = LeWrapper(model)

# Function to chat and compute LeGrad
def chat_with_model(image_path, queries):
    """Run a scripted conversation with the global ``model`` and print each reply.

    Relies on the module-level ``model``, ``tokenizer``, ``DEVICE`` and
    ``torch_type``. When an image is supplied, a LeGrad explanation map is
    computed and visualised after every turn.

    Args:
        image_path: Path of the single image used for the whole conversation.
            A falsy value switches to text-only chat.
        queries: Iterable of user messages, processed in order. The literal
            string "clear" resets the conversation history.

    Returns:
        None. Replies are printed; heatmaps are passed to ``visualize``.
    """
    # System prompt for the first text-only turn; it ends with the same
    # "USER: ... ASSISTANT:" framing used for follow-up turns below.
    text_only_template = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions. "
        "USER: {} ASSISTANT:"
    )

    if image_path:
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(image_path) as source_image:
            image = source_image.convert('RGB')
    else:
        image = None

    history = []
    text_only_first_query = image is None

    for query in queries:
        if query == "clear":
            history.clear()
            # With an empty history the next text-only turn is a "first" turn
            # again and must receive the system prompt.
            text_only_first_query = image is None
            continue

        if image is None:
            if text_only_first_query:
                query = text_only_template.format(query)
                text_only_first_query = False
            else:
                # Replay earlier turns as plain text ahead of the new question.
                old_prompt = ''
                for old_query, response in history:
                    old_prompt += old_query + " " + response + "\n"
                query = old_prompt + f"USER: {query} ASSISTANT:"
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
        else:
            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])

        # Add a batch dimension and move everything to the target device.
        inputs = {
            'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
            'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
            'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
            'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
        }
        if 'cross_images' in input_by_model and input_by_model['cross_images']:
            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]

        # add any transformers params here.
        gen_kwargs = {"max_length": 2048, "do_sample": False}  # "temperature": 0.9
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens (drop the echoed prompt).
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            response = tokenizer.decode(outputs[0])
            response = response.split("</s>")[0]
            print("\nCog:", response)
        history.append((query, response))

        # Compute LeGrad explanation map (needs gradients, hence outside no_grad).
        # Skipped for text-only chat: there is no image to explain or overlay.
        if image is not None:
            # NOTE(review): encode_text is not defined by LeWrapper itself and
            # `visualize` is not imported in the visible cells — confirm both
            # are available before relying on this step.
            text_embedding = model.encode_text(tokenizer([query]).to(DEVICE), normalize=True)
            explainability_map = model.compute_legrad_clip(image=image, text_embedding=text_embedding)
            visualize(heatmaps=explainability_map, image=image)

# Example usage
# `image_path` is a placeholder: chat_with_model opens it with Image.open, so it
# must point at a real file (or be falsy for text-only chat).
image_path = "path/to/your/image.jpg"  # Provide path to your image
# Queries are sent in order within one conversation; the literal string
# "clear" would reset the history.
queries = [
    "Hello, how are you?",
    "Can you tell me more about this image?",
    "What objects can you identify?"
]

# Prints each reply; nothing is returned.
chat_with_model(image_path, queries)




Downloading shards:   0%|          | 0/8 [00:00<?, ?it/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [1]:
import torch
from transformers import AutoModelForCausalLM

# Load the model
# trust_remote_code=True matches the main demo cell: the printed module classes
# (CogAgentModel, VisionExpertAttention, ...) come from the checkpoint's own
# code, so without the flag the load cannot proceed non-interactively.
model = AutoModelForCausalLM.from_pretrained("THUDM/cogagent-chat-hf", trust_remote_code=True)

# Explore the attributes
dir(model)

Please 'pip install apex'


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_assisted_decoding',
 '_auto_class',
 '_autoset_attn_implementation',
 '_backward_compatibility_gradient_checkpointing',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_beam_sample',
 '_beam_search',
 '_buffers',
 '_call_impl',
 '_check_and_enable_flash_attn_2',
 '_check_and_enable_sdpa',
 '_compiled_call_impl',
 '_constrained_beam_search',
 '_contrastive_search',
 '_convert_head_mask_to_5d',
 '_copy_lm_head_original_to_resized',
 '_create_repo',
 '_dispatch_accelerate_model',
 '_expand_inputs_for_generation',
 

In [2]:
# Print the top-level submodules
# named_children() yields (name, module) pairs; each module is interpolated via
# its full repr, so the printed text includes the nested layers as well.
for name, module in model.named_children():
    print(f"Submodule: {name} - {module}")


Submodule: model - CogAgentModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x CogAgentDecoderLayer(
      (self_attn): VisionExpertAttention(
        (rotary_emb): RotaryEmbedding()
        (vision_expert_query_key_value): Linear(in_features=4096, out_features=12288, bias=False)
        (vision_expert_dense): Linear(in_features=4096, out_features=4096, bias=False)
        (language_expert_query_key_value): Linear(in_features=4096, out_features=12288, bias=False)
        (language_expert_dense): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (cross_attn): CrossAttention(
        (query): Linear(in_features=4096, out_features=1024, bias=False)
        (key_value): Linear(in_features=1024, out_features=2048, bias=False)
        (dense): Linear(in_features=1024, out_features=4096, bias=False)
      )
      (mlp): VisionExpertMLP(
        (language_mlp): MLP(
          (gate_proj): Linear(in_features=4096, out_featu

In [1]:
import torch
from transformers import AutoModelForCausalLM

# Load the CogAgent checkpoint
# trust_remote_code=True matches the main demo cell: the model classes ship
# with the checkpoint, so the flag is needed for a non-interactive load.
model = AutoModelForCausalLM.from_pretrained("THUDM/cogagent-chat-hf", trust_remote_code=True)

# Print the top-level submodules to identify possible visual components
print("Top-level submodules:")
for name, module in model.named_children():
    print(f"Submodule: {name} - {module}")

# Drill down into specific submodules if necessary
# (one level deeper: the immediate children of every top-level submodule)
for name, module in model.named_children():
    print(f"\nInspecting submodule: {name}")
    for sub_name, sub_module in module.named_children():
        print(f"  Submodule: {sub_name} - {sub_module}")

Please 'pip install apex'


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Top-level submodules:
Submodule: model - CogAgentModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x CogAgentDecoderLayer(
      (self_attn): VisionExpertAttention(
        (rotary_emb): RotaryEmbedding()
        (vision_expert_query_key_value): Linear(in_features=4096, out_features=12288, bias=False)
        (vision_expert_dense): Linear(in_features=4096, out_features=4096, bias=False)
        (language_expert_query_key_value): Linear(in_features=4096, out_features=12288, bias=False)
        (language_expert_dense): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (cross_attn): CrossAttention(
        (query): Linear(in_features=4096, out_features=1024, bias=False)
        (key_value): Linear(in_features=1024, out_features=2048, bias=False)
        (dense): Linear(in_features=1024, out_features=4096, bias=False)
      )
      (mlp): VisionExpertMLP(
        (language_mlp): MLP(
          (gate_proj): Linear(in_fe