In [1]:
import torch
from torch import nn
import pytorch_lightning as pl
from open_mmicl.lvlm_interface import FlamingoInterface
import omegaconf
import json
import yaml
import requests
from PIL import Image
import torch.nn.functional as F

def calculate_kl_divergence(logits1, logits2, epsilon=0.):
    # 将logits转换为概率分布，并添加一个小的正数以提高数值稳定性
    print(logits1, logits2)
    probs1 = F.softmax(logits1, dim=-1) + epsilon
    probs2 = F.softmax(logits2, dim=-1) + epsilon
    print(probs1, probs2)
    # 计算KL散度
    kl_div = F.kl_div(probs1.log(), probs2, reduction='batchmean')

    return kl_div


In [2]:
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
        stream=True
    ).raw
)

prompts = [
    [
        demo_image_one,
        "An image of two cats.",
        demo_image_two,
        "An image of a bathroom sink.",
        query_image,
        "An image of a table with food"
    ],
    [
        demo_image_two,
        "An image of a bathroom sink.",
        demo_image_one,
        "An image of two cats.",
        query_image,
        "An image of a table with food"
    ],
]



In [3]:
with open('./configs/lvlm/flamingo_3B.yaml', 'r')as f:
    data = yaml.load(f, Loader=yaml.FullLoader)

with open('./configs/task/caption.yaml', 'r')as f:
    task_data = yaml.load(f, Loader=yaml.FullLoader)

cfg = omegaconf.DictConfig(data)
device = 'cuda'
task_config = omegaconf.DictConfig(task_data)

In [4]:
cfg.flamingo_checkpoint_dir = '/home/pyz/checkpoint/openflamingo/OpenFlamingo-3B-vitl-mpt1b'
cfg.tokenizer_path = 'anas-awadalla/mpt-1b-redpajama-200b'
cfg.lang_encoder_path
task_config.template = cfg.caption_prompt_template
task_config.instruction = ''
task_config.column_token_map

{'single_caption': '<X>'}

In [5]:
llm = FlamingoInterface(
        lang_encoder_path=cfg.lang_encoder_path,
        tokenizer_path=cfg.tokenizer_path,
        flamingo_checkpoint_dir=cfg.flamingo_checkpoint_dir,
        cross_attn_every_n_layers=cfg.cross_attn_every_n_layers,
        hf_root=cfg.hf_root,
        precision='bf16',
        device=device,
        prompt_template=task_config.template,
        column_token_map=task_config.column_token_map,
        icd_join_char=cfg.icd_join_char,
        load_from_local=cfg.load_from_local,
        instruction=task_config.instruction,
        init_device=cfg.init_device,
        image_field=task_config.image_field,
        label_field=task_config.output_column,
    )

Using pad_token, but it is not set yet.
You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 50280. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
[32m2024-01-19 16:05:47.316[0m | [1mINFO    [0m | [36mopen_mmicl.lvlm_interface.flamingo_interface[0m:[36mcreate_model_and_transforms[0m:[36m317[0m - [1mFlamingo model initialized with 1046992944 trainable parameters[0m


In [6]:
class FlamingoICLAdapter(nn.Module):
    def __init__(self, llm):
        super().__init__()
        self.llm = llm
        n = 50280
        self.adapter = nn.Linear(n, n)

        # 创建一个对角线元素为1的单位矩阵
        identity_matrix = torch.eye(n)

        # 将线性层的权重初始化为单位矩阵
        self.adapter.weight.data = identity_matrix

        # 可选：如果您还想将偏置初始化为0
        self.adapter.bias.data.zero_()
        self.adapter = self.adapter.cuda()
        
        self.loss_fct = torch.nn.CrossEntropyLoss(
                reduction='none', ignore_index=self.llm.pad_token_id
            )
        
    def forward(self, model_input, ice_token_length):
        bs, aug_num, seq_len = model_input['lang_x'].shape
        length = (model_input['lang_x'] != self.llm.pad_token_id).sum(-1)
        query_length = length - ice_token_length
        for key in model_input:
            model_input[key] = model_input[key].reshape(-1, *model_input[key].shape[2:]).to('cuda')
        model_input['vision_x'] = model_input['vision_x'].to('cuda', torch.bfloat16)
        
        with torch.no_grad():
            with self.llm.autocast_context:
                outputs = self.llm(model_input)
            # 计算ICE Seq的质量分数（Prob）
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = model_input['lang_x'][..., 1:].contiguous()
            loss = self.loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            loss = loss.view(shift_labels.size())
            mask_length = ice_token_length.reshape(-1)
            loss_mask = torch.zeros_like(shift_labels)  # [batch, seqlen]
            for i in range(len(loss_mask)):
                for j in range(mask_length[i] - 1, len(loss_mask[i])):
                    loss_mask[i][j] = 1
            loss = loss * loss_mask
            lens = (model_input['lang_x'] != self.llm.pad_token_id).sum(-1)
            lens -= torch.tensor(mask_length, device=lens.device)
            ce_loss = loss.sum(-1) / lens
            scores = (-ce_loss).exp()
            scores = scores.reshape(bs, aug_num)
            logits = outputs.logits.reshape(bs, aug_num, seq_len, -1)

        loss = 0.
        for i in range(bs):
            best_ice_seq_idx = torch.argmax(scores[i])
            query_end_idx = ice_token_length[i] - 1 + query_length[i]
            best_logits = logits[i, best_ice_seq_idx, ice_token_length[i][best_ice_seq_idx] - 1: query_end_idx[best_ice_seq_idx]].unsqueeze(0)
            neg_logits = [logits[i, j, ice_token_length[i][j] - 1: query_end_idx[j]] for j in range(aug_num) if j != best_ice_seq_idx]
            neg_logits = torch.stack(neg_logits, dim=0)
            with self.llm.autocast_context:
                adapter_logits = self.adapter(neg_logits)
            best_logits = best_logits.repeat(aug_num - 1, 1, 1)
            loss += calculate_kl_divergence(adapter_logits.reshape(-1, 50280), best_logits.reshape(-1, 50280))
        return loss
            

In [7]:
model = FlamingoICLAdapter(llm)

In [8]:
model_input = model.llm.prepare_input(prompts)
for key in model_input:
    model_input[key] = model_input[key].unsqueeze(0)

In [9]:
model(model_input, torch.tensor([[17, 17]]))

  lens -= torch.tensor(mask_length, device=lens.device)


tensor([[ 66.5000,  11.3125,  59.2500,  ...,  40.0000,  12.6875,  11.3750],
        [ 12.5000,  -4.5312,  14.0000,  ...,  27.6250,  -4.4062,  -4.5312],
        [108.0000,  51.0000, 111.5000,  ...,  81.5000,  53.7500,  50.7500],
        ...,
        [103.0000,  37.5000, 105.0000,  ...,  67.5000,  40.2500,  37.5000],
        [ 83.0000,  34.2500,  89.0000,  ...,  65.5000,  36.5000,  34.2500],
        [107.5000,  38.5000, 110.5000,  ...,  68.0000,  42.0000,  38.5000]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ReshapeAliasBackward0>) tensor([[ 65.5000,  15.1250,  57.5000,  ...,  44.5000,  16.2500,  15.1250],
        [ 11.7500,  -5.3750,  12.8125,  ...,  26.2500,  -5.3125,  -5.4062],
        [ 99.0000,  44.5000, 103.0000,  ...,  74.0000,  47.2500,  44.5000],
        ...,
        [103.5000,  38.0000, 105.0000,  ...,  68.0000,  40.7500,  38.0000],
        [ 86.0000,  35.5000,  92.0000,  ...,  66.5000,  37.7500,  35.5000],
        [108.5000,  39.2500, 111.5000,  ...,  68.0000,  42

tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)

In [10]:
model.llm.autocast_context

<torch.cuda.amp.autocast_mode.autocast at 0x7f86c2263430>

In [11]:
model_input['lang_x'] 

tensor([[50278,  1145,  2460,   273,   767, 16581,    15, 50277, 50278,  1145,
          2460,   273,   247, 15336, 16338,    15, 50277, 50278,  1145,  2460,
           273,   247,  2829,   342,  2739],
        [50278,  1145,  2460,   273,   247, 15336, 16338,    15, 50277, 50278,
          1145,  2460,   273,   767, 16581,    15, 50277, 50278,  1145,  2460,
           273,   247,  2829,   342,  2739]], device='cuda:0')

In [1]:
'<image>Output:A yellow pair of scissors next to cut up leafs.<|endofchunk|><image>Output:A baby is sitting down eating a banana.<|endofchunk|><image>Output:A man in a hat receiving a kiss from an orange haired woman.<|endofchunk|><image>Output:A person walking their dog on the beach<|endofchunk|><image>Output:A couple of men dong road construction next to a parking meter.<|endofchunk|><image>Output:A pile of leafy greens sitting on top of a table.<|endofchunk|><image>Output:Two girls sitting on a couch with a woman and one of them is brushing her hair.<|endofchunk|>'

'<image>Output:A yellow pair of scissors next to cut up leafs.<|endofchunk|><image>Output:A baby is sitting down eating a banana.<|endofchunk|><image>Output:A man in a hat receiving a kiss from an orange haired woman.<|endofchunk|><image>Output:A person walking their dog on the beach<|endofchunk|><image>Output:A couple of men dong road construction next to a parking meter.<|endofchunk|><image>Output:A pile of leafy greens sitting on top of a table.<|endofchunk|><image>Output:Two girls sitting on a couch with a woman and one of them is brushing her hair.<|endofchunk|>'

: 

In [2]:
import torch


d = torch.load('/home/pyz/code/ICD-LM/icl_adapter_result/model_cpk/caption/icl_lora/min_vl-epoch=0-train_loss=0.00000-val_loss=0.00000.ckpt')

In [4]:
d['state_dict']

OrderedDict([('interface.model.vision_encoder.class_embedding',
              tensor([ 0.0138,  0.2357, -0.1285,  ...,  0.0171, -0.3332, -0.2366],
                     device='cuda:0')),
             ('interface.model.vision_encoder.positional_embedding',
              tensor([[ 0.0019,  0.0479, -0.0149,  ...,  0.0005, -0.0558, -0.0460],
                      [ 0.0114, -0.0413,  0.0357,  ...,  0.0271, -0.0313, -0.0383],
                      [-0.0026, -0.0340, -0.0006,  ...,  0.0216, -0.0294, -0.0423],
                      ...,
                      [-0.0038, -0.0350, -0.0048,  ..., -0.0228, -0.0328, -0.0412],
                      [-0.0046, -0.0360, -0.0026,  ..., -0.0350, -0.0355, -0.0353],
                      [-0.0073, -0.0287, -0.0144,  ..., -0.0202, -0.0272, -0.0360]],
                     device='cuda:0')),
             ('interface.model.vision_encoder.proj',
              tensor([[ 0.0224, -0.0139, -0.0072,  ..., -0.0058, -0.0078,  0.0139],
                      [ 0.0186,  0.

In [5]:
import torch 
from torch import nn

adapter = torch.nn.Sequential(
                    nn.Linear(50280, 128),
                    nn.Linear(128, 50280)
                )


In [7]:
state_dict = {k.replace('adapter.', ''): v for k, v in d['state_dict'].items()}


adapter.load_state_dict(state_dict)

<All keys matched successfully>

: 