In [1]:
import os
import json
import random
from types import SimpleNamespace

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image

from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

# imports modules for registration
from minigpt4.datasets.builders import *
from minigpt4.models import *
from minigpt4.processors import *
from minigpt4.runners import *
from minigpt4.tasks import *

  warn(f"Failed to load image Python extension: {e}")
  from .autonotebook import tqdm as notebook_tqdm


Setting ds_accelerator to cuda (auto detect)


In [2]:
# Reproducibility: turn off cuDNN benchmark mode and require deterministic
# cuDNN algorithms, then seed Python's, NumPy's and PyTorch's global RNGs
# with one shared value.
cudnn.benchmark = False
cudnn.deterministic = True

seed = 2023
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [3]:
# Build MiniGPT-4 from the eval config and wrap it in the Chat helper.
# Config is constructed from an argparse-style namespace, so mimic the CLI args here.
args = SimpleNamespace(
    cfg_path='eval_configs/minigpt4_eval.yaml',
    gpu_id=0,
    options=None,
)
cfg = Config(args)

# Single source of truth for the target device (used for both model and Chat).
device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)
# Inference / activation capture only: switch every submodule to eval mode so
# train-mode behaviour (e.g. dropout) cannot perturb the captured activations.
model.eval()

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device=device)

Loading VIT
Loading VIT Done
Loading Q-Former
Loading Q-Former Done
Loading LLAMA


Loading checkpoint shards: 100%|██████████| 2/2 [01:09<00:00, 34.75s/it]


Loading LLAMA Done
Load 4 training prompts
Prompt Example 
###Human: <Img><ImageHere></Img> Could you describe the contents of this image for me? ###Assistant: 
Load BLIP2-LLM Checkpoint: /home/chengzhang/Multimodal-Quantization/MiniGPT-4/checkpoints/prerained_minigpt4_7b.pth


In [4]:
# Collect the modules whose forward inputs will be snapshotted, keyed by a name
# of the form '<component>/<layer-index>-<role>'. The key doubles as the
# relative output path used by the weight / activation dumps.
# Only the LLaMA decoder is captured by default; set the flags below to True to
# additionally capture the ViT and Q-Former projections.
CAPTURE_VIT = False
CAPTURE_QFORMER = False

linear_modules: dict[str, torch.nn.Linear] = {}
# Decoder norm layers. They are not nn.Linear, hence the generic nn.Module value type.
ln_modules: dict[str, torch.nn.Module] = {}

if CAPTURE_VIT:
    for i, block in enumerate(model.visual_encoder.blocks):
        linear_modules[f'vit/{i}-qkv-proj'] = block.attn.qkv
        linear_modules[f'vit/{i}-o-proj'] = block.attn.proj
        linear_modules[f'vit/{i}-fc1'] = block.mlp.fc1
        linear_modules[f'vit/{i}-fc2'] = block.mlp.fc2

if CAPTURE_QFORMER:
    for i, layer in enumerate(model.Qformer.bert.encoder.layer):
        linear_modules[f'q-former/{i}-self-q-proj'] = layer.attention.self.query
        linear_modules[f'q-former/{i}-self-k-proj'] = layer.attention.self.key
        linear_modules[f'q-former/{i}-self-v-proj'] = layer.attention.self.value
        linear_modules[f'q-former/{i}-self-o-proj'] = layer.attention.output.dense
        # Not every Q-Former layer has a cross-attention sub-block.
        if hasattr(layer, 'crossattention'):
            linear_modules[f'q-former/{i}-cross-q-proj'] = layer.crossattention.self.query
            linear_modules[f'q-former/{i}-cross-k-proj'] = layer.crossattention.self.key
            linear_modules[f'q-former/{i}-cross-v-proj'] = layer.crossattention.self.value
            linear_modules[f'q-former/{i}-cross-o-proj'] = layer.crossattention.output.dense
        linear_modules[f'q-former/{i}-fc1'] = layer.intermediate_query.dense
        linear_modules[f'q-former/{i}-fc2'] = layer.output_query.dense

# One pass over the decoder layers fills both registries.
for i, layer in enumerate(model.llama_model.model.layers):
    linear_modules[f'llama-ori/{i}-q-proj'] = layer.self_attn.q_proj
    linear_modules[f'llama-ori/{i}-k-proj'] = layer.self_attn.k_proj
    linear_modules[f'llama-ori/{i}-v-proj'] = layer.self_attn.v_proj
    linear_modules[f'llama-ori/{i}-o-proj'] = layer.self_attn.o_proj
    linear_modules[f'llama-ori/{i}-gate-proj'] = layer.mlp.gate_proj
    linear_modules[f'llama-ori/{i}-down-proj'] = layer.mlp.down_proj
    linear_modules[f'llama-ori/{i}-up-proj'] = layer.mlp.up_proj
    ln_modules[f'llama-ori/{i}-input-ln'] = layer.input_layernorm
    ln_modules[f'llama-ori/{i}-post-attn-ln'] = layer.post_attention_layernorm

# Tag every module with its key so a single shared forward hook can derive the
# output file name from the module alone.
for name, module in [*linear_modules.items(), *ln_modules.items()]:
    module.unique_name = name

In [5]:
def save_weights(
    weights_dir='/home/chengzhang/Multimodal-Quantization/MiniGPT-4/snapshot/weights',
    modules=None,
):
    """Save the weight tensor of every registered linear layer.

    Each weight is written to ``{weights_dir}/{name}.pt``. Because registry
    names look like ``'llama-ori/0-q-proj'``, the per-component sub-directory
    is created on demand.

    Args:
        weights_dir: root output directory.
        modules: mapping of name -> module exposing ``.weight``. Defaults to the
            notebook-level ``linear_modules`` registry.
    """
    if modules is None:
        modules = linear_modules
    for name, module in modules.items():
        path = os.path.join(weights_dir, f'{name}.pt')
        # Names contain '/', so e.g. weights/llama-ori/ must exist before saving.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save(module.weight, path)

# save_weights()

In [6]:
# OK-VQA validation questions; each record supplies the 'question',
# 'image_id' and 'question_id' fields read by the activation-dump cells.
questions_path = '../datasets/OK-VQA/question/OpenEnded_mscoco_val2014_questions.json'
with open(questions_path) as fp:
    questions = json.load(fp)['questions']

# Shared list of forward-hook handles handed to the dump functions.
hooks = []

In [7]:
def save_activations(
    hooks,
    num_questions=10,
    act_root='/home/chengzhang/Multimodal-Quantization/MiniGPT-4/snapshot/activations',
):
    """Dump the inputs of every registered linear / norm layer for image+text samples.

    For each of the first ``num_questions`` OK-VQA questions, runs a MiniGPT-4
    forward pass on (image, question) and saves each hooked module's positional
    input tuple to ``{act_root}/{question_id}/{module.unique_name}.pt``.

    Args:
        hooks: list of hook handles. Handles already in it are removed first; the
            list is mutated in place and is empty again when this returns, so no
            hook keeps firing on later forward passes.
        num_questions: how many questions from the start of ``questions`` to process
            (fewer are processed if the list is shorter).
        act_root: root directory for the activation dumps.
    """
    # Drop hooks left over from an earlier call so nothing is written twice.
    for handle in hooks:
        handle.remove()
    hooks.clear()

    for q in questions[:num_questions]:
        question = q['question']
        question_id = q['question_id']
        image_path = f'../datasets/OK-VQA/image/val2014/COCO_val2014_{str(q["image_id"]).zfill(12)}.jpg'

        # COCO contains grayscale images, but the processor expects 3 channels,
        # so convert to RGB. The context manager closes the image file.
        with Image.open(image_path) as raw_image:
            image = chat.vis_processor(raw_image.convert('RGB'))
        # Place the batch on the same device the Chat / model were built for.
        image = image.unsqueeze(0).to(chat.device, dtype=torch.float16)

        act_folder = f'{act_root}/{question_id}'
        for component in ('vit', 'q-former', 'llama-ori'):
            os.makedirs(f'{act_folder}/{component}', exist_ok=True)

        def hook(m, inputs, output):
            torch.save(inputs, f'{act_folder}/{m.unique_name}.pt')

        try:
            for module in (*linear_modules.values(), *ln_modules.values()):
                hooks.append(module.register_forward_hook(hook))
            # Capture only: no autograd graph needed for the 7B forward pass.
            with torch.no_grad():
                model({'image': image, 'text_input': [question]})
        finally:
            # Always detach this sample's hooks, even on error; otherwise they
            # would keep overwriting this sample's files on later forwards.
            for handle in hooks:
                handle.remove()
            hooks.clear()

save_activations(hooks)

In [8]:
def save_text_activations(
    hooks,
    num_questions=10,
    act_root='/home/chengzhang/Multimodal-Quantization/MiniGPT-4/snapshot/text-activations',
):
    """Dump the inputs of every registered linear / norm layer for text-only samples.

    For each of the first ``num_questions`` OK-VQA questions, runs
    ``model.text_forward`` (defined outside this notebook) on the question
    alone and saves each hooked module's positional input tuple to
    ``{act_root}/{question_id}/{module.unique_name}.pt``.

    Args:
        hooks: list of hook handles. Handles already in it are removed first; the
            list is mutated in place and is empty again when this returns, so no
            hook keeps firing on later forward passes.
        num_questions: how many questions from the start of ``questions`` to process
            (fewer are processed if the list is shorter).
        act_root: root directory for the text-only activation dumps.
    """
    # Drop hooks left over from an earlier call so nothing is written twice.
    for handle in hooks:
        handle.remove()
    hooks.clear()

    for q in questions[:num_questions]:
        question = q['question']
        act_folder = f"{act_root}/{q['question_id']}"
        os.makedirs(f'{act_folder}/llama-ori', exist_ok=True)

        def hook(m, inputs, output):
            torch.save(inputs, f'{act_folder}/{m.unique_name}.pt')

        try:
            for module in (*linear_modules.values(), *ln_modules.values()):
                hooks.append(module.register_forward_hook(hook))
            # Capture only: no autograd graph needed.
            with torch.no_grad():
                model.text_forward({'text_input': [question]})
        finally:
            # Always detach this sample's hooks, even on error; otherwise they
            # would keep overwriting this sample's files on later forwards.
            for handle in hooks:
                handle.remove()
            hooks.clear()

save_text_activations(hooks)