### CLIP-GPT

# In [None]:
import torch
import torch.nn as nn

from transformers import CLIPProcessor, CLIPModel, CLIPConfig
from transformers import GPT2Model
import torch.nn.functional as F

# Default compute device for this cell: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CLIPGPTCLS(nn.Module):
    """Classifier that feeds CLIP text/vision token features into a GPT-2 decoder.

    Pipeline: CLIP text and vision transformers -> per-token linear projection
    to 768 -> concatenation (text tokens first) -> GPT-2 -> mean over the
    sequence axis -> Linear(768, 512) -> BatchNorm1d -> Dropout -> Linear head.

    Args:
        num_class: Number of output classes of the final linear layer.
    """

    def __init__(self, num_class=18):
        super().__init__()

        # CLIP encoders (visual and text).
        # NOTE(review): ``CLIPModel(config)`` only reuses the hub *configuration*;
        # the encoder weights are freshly initialised. If pretrained CLIP weights
        # are intended, ``CLIPModel.from_pretrained(...)`` is needed -- confirm.
        config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32")
        # The temporary model is deliberately not moved to a device here: doing so
        # would leave the CLIP sub-modules and the layers created below on
        # different devices. Callers place the whole module with ``.to(...)``.
        model = CLIPModel(config)

        self.config = model.config  # provides output_attentions / output_hidden_states defaults
        self.text_model = model.text_model  # CLIP text encoder
        self.vision_model = model.vision_model  # CLIP image encoder
        # Replace CLIP's projection heads with fresh ones mapping to 768, the
        # width consumed by ``intermediate_layer`` after the GPT-2 decoder.
        self.visual_projection = nn.Linear(model.visual_projection.in_features, 768)
        self.text_projection = nn.Linear(model.text_projection.in_features, 768)
        self.logit_scale = model.logit_scale  # CLIP logit scale; kept as a parameter, unused in forward

        # GPT2 decoder
        self.VCA_decoder = GPT2Model.from_pretrained('gpt2')

        # intermediate layers
        self.intermediate_layer = nn.Linear(768, 512)
        # NOTE: despite the attribute name this is BatchNorm1d; the name is kept
        # so existing checkpoints (state_dict keys) still load.
        self.LayerNorm = nn.BatchNorm1d(512)
        # self.LayerNorm = nn.LayerNorm(512)  # use this one if only one data point in each batch
        self.dropout = nn.Dropout(0.2)

        # classifier
        self.classifier = nn.Linear(in_features=512, out_features=num_class)

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        position_ids=None,
        return_loss=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute class logits for a batch of (text, image) pairs.

        Args:
            input_ids: Token ids for the text encoder (required).
            pixel_values: Image tensor for the vision encoder (required).
            attention_mask: 0/1 padding mask for ``input_ids``. When omitted,
                every text token is treated as valid.
            position_ids: Optional position ids forwarded to the text encoder.
            return_loss: Accepted for signature compatibility; not used.
            output_attentions: Forwarded to both CLIP encoders; defaults to
                ``self.config.output_attentions``.
            output_hidden_states: Forwarded to both CLIP encoders; defaults to
                ``self.config.output_hidden_states``.
            return_dict: Forwarded to both CLIP encoders.

        Returns:
            Tensor of shape ``(batch, num_class)`` with unnormalised logits.

        Raises:
            ValueError: If ``input_ids`` or ``pixel_values`` is ``None``.
        """
        if input_ids is None or pixel_values is None:
            raise ValueError("CLIPGPTCLS.forward requires both input_ids and pixel_values")

        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Run the CLIP visual and text encoders.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 0 of each encoder output is used as the per-token feature
        # sequence; both are projected to the same width (768).
        image_embeds = self.visual_projection(vision_outputs[0])
        text_embeds = self.text_projection(text_outputs[0])

        # Build the masks on the device the activations already live on instead
        # of a module-level global, so the model works wherever it was placed.
        if attention_mask is None:
            text_attention_mask = torch.ones(
                text_embeds.shape[:2], dtype=torch.long, device=text_embeds.device
            )
        else:
            text_attention_mask = attention_mask.to(text_embeds.device)
        # Every visual token is valid; share the text mask's dtype so the
        # concatenation does not depend on implicit type promotion.
        visual_attention_mask = torch.ones(
            image_embeds.shape[:2], dtype=text_attention_mask.dtype, device=image_embeds.device
        )

        # Concatenate along the sequence axis, text first (both are 768 wide).
        inputs_embeds = torch.cat((text_embeds, image_embeds), dim=1)
        inputs_attention_mask = torch.cat((text_attention_mask, visual_attention_mask), dim=1)

        # decode
        decoder_output = self.VCA_decoder(inputs_embeds=inputs_embeds, attention_mask=inputs_attention_mask)

        # Average over the sequence axis: (batch, seq, 768) -> (batch, 768).
        # Padded positions are included in the mean, as in the original
        # swapaxes + adaptive_avg_pool1d formulation.
        pooled = decoder_output.last_hidden_state.mean(dim=1)

        # intermediate layers
        out = self.intermediate_layer(pooled)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        return self.classifier(out)

### Surgical-GPT

# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from transformers import VisualBertConfig, GPT2Config, GPT2Tokenizer
from transformers import VisualBertModel, GPT2Model

# Default compute device for this cell: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SurgicalGPTCLS(nn.Module):
    """Classifier that fuses a ResNet-18 image feature with GPT-2 question tokens.

    Pipeline: ResNet-18 (classification head removed) -> VisualBERT visual
    projection -> concatenation after the GPT-2 word embeddings of the
    question -> GPT-2 -> mean over the sequence axis -> Linear(768, 512) ->
    BatchNorm1d -> Dropout -> Linear head.

    Args:
        num_class: Number of output classes of the final linear layer.
        vis_pos_emb: ``'zeroes'`` or ``'pos'`` to pass explicit token-type and
            position ids to GPT-2 (visual positions all zero, or counted from
            zero, respectively). Any other value, including the default
            ``None``, lets GPT-2 use its own defaults.
    """

    def __init__(self, num_class=18, vis_pos_emb=None):
        super().__init__()

        self.vis_pos_emb = vis_pos_emb

        # Image backbone: ResNet-18 with its final fully-connected layer
        # replaced by a pass-through, so the pooled feature is returned. (The
        # original built an empty ``nn.Sequential`` from the children of an
        # ``nn.Linear``, which has none -- i.e. also a pass-through.)
        self.img_feature_extractor = models.resnet18(pretrained=True)  # ResNet18
        self.img_feature_extractor.fc = nn.Identity()

        # Visual embedding: only the visual projection of a config-built
        # VisualBERT is kept; it maps 512-d image features to the model width.
        VB_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")  # VisualBert
        VB_config.visual_embedding_dim = 512
        visualbert = VisualBertModel(config=VB_config)
        self.visual_embedder = visualbert.embeddings.visual_projection

        # Question embedding: the word-token embedding table of a pretrained
        # GPT-2. Only ``wte`` is retained, so nothing is configured on the
        # temporary model itself.
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        question_embedder = GPT2Model.from_pretrained('gpt2')
        self.question_embedder = question_embedder.wte  # word token embedding

        # GPT2 decoder
        self.VCA_decoder = GPT2Model.from_pretrained('gpt2')
        # ``pad_token_id`` must be the integer id, not the token string.
        self.VCA_decoder.config.pad_token_id = tokenizer.eos_token_id

        # intermediate layers
        self.intermediate_layer = nn.Linear(768, 512)  # (512+768)
        # NOTE: despite the attribute name this is BatchNorm1d; the name is kept
        # so existing checkpoints (state_dict keys) still load.
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        # classifier
        self.classifier = nn.Linear(512, num_class)

    def forward(self, input, img):
        """Compute class logits for a batch of (question, image) pairs.

        Args:
            input: Mapping with ``'input_ids'`` and ``'attention_mask'`` tensors
                for the question. The mapping is read, not modified. (The name
                shadows the builtin but is kept for caller compatibility.)
            img: Image batch accepted by the ResNet-18 backbone.

        Returns:
            Tensor of shape ``(batch, num_class)`` with unnormalised logits.
        """
        # Follow the device the model was actually placed on rather than a
        # module-level global.
        model_device = self.question_embedder.weight.device

        # Image encoder feature, with a length-1 sequence axis inserted.
        img_feature = self.img_feature_extractor(img.to(model_device))
        img_feature = torch.unsqueeze(img_feature, dim=1)
        visual_embeds = self.visual_embedder(img_feature)

        # Question embedding.
        input_ids = input['input_ids'].to(model_device)
        question_attention_mask = input['attention_mask'].to(model_device)
        question_embeds = self.question_embedder(input_ids)

        # Every visual token is valid; share the question mask's dtype so the
        # concatenation does not depend on implicit type promotion.
        visual_attention_mask = torch.ones(
            visual_embeds.shape[:-1], dtype=question_attention_mask.dtype, device=model_device
        )

        # Question first, then visual tokens.
        inputs_embeds = torch.cat((question_embeds, visual_embeds), dim=1)
        attention_mask = torch.cat((question_attention_mask, visual_attention_mask), dim=1)

        if self.vis_pos_emb in ('zeroes', 'pos'):
            batch_size, question_len = question_embeds.shape[:2]
            visual_len = visual_embeds.shape[1]

            # Token types: 0 for question tokens, 1 for visual tokens.
            question_id_type = torch.zeros((batch_size, question_len), dtype=torch.long, device=model_device)
            visual_id_type = torch.ones((batch_size, visual_len), dtype=torch.long, device=model_device)

            # Question positions always count up from zero.
            question_position_id = torch.arange(question_len, device=model_device).unsqueeze(0).repeat(batch_size, 1)
            if self.vis_pos_emb == 'zeroes':
                visual_position_id = torch.zeros((batch_size, visual_len), dtype=torch.long, device=model_device)
            else:
                # 'pos': visual positions restart from zero and count up.
                visual_position_id = torch.arange(visual_len, device=model_device).unsqueeze(0).repeat(batch_size, 1)

            token_type_ids = torch.cat((question_id_type, visual_id_type), dim=1)
            position_ids = torch.cat((question_position_id, visual_position_id), dim=1)

            decoder_output = self.VCA_decoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
                                             position_ids=position_ids, token_type_ids=token_type_ids)
        else:
            decoder_output = self.VCA_decoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

        # Average over the sequence axis: (batch, seq, 768) -> (batch, 768).
        # Padded positions are included in the mean, as in the original
        # swapaxes + adaptive_avg_pool1d formulation.
        pooled = decoder_output.last_hidden_state.mean(dim=1)

        # intermediate layers
        out = self.intermediate_layer(pooled)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        # classifier
        return self.classifier(out)

### Surgical-LLaMA

# In [None]:
from torch import nn
import torch.utils.data
from typing import Tuple

from transformers import VisualBertConfig
from transformers import VisualBertModel, SwinModel
from transformers import LlamaForSequenceClassification

# Default compute device for this cell: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MLP(nn.Module):
    """Stack of linear layers with an activation between consecutive layers.

    Args:
        sizes: Layer widths, input first. ``len(sizes) - 1`` linear layers are
            created; fewer than two entries gives an empty (pass-through) stack.
        bias: Whether each linear layer has a bias term.
        act: Zero-argument factory for the activation placed after every
            linear layer except the last one.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        widths = list(zip(sizes[:-1], sizes[1:]))
        final_idx = len(widths) - 1

        stack = []
        for idx, (fan_in, fan_out) in enumerate(widths):
            stack.append(nn.Linear(fan_in, fan_out, bias=bias))
            # No activation after the output layer.
            if idx != final_idx:
                stack.append(act())
        self.model = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        return self.model(x)

class SurgicalLlamaCLS(nn.Module):
    """Classifier that fuses Swin image tokens with LLaMA question tokens.

    Pipeline: Swin-Tiny -> VisualBERT visual projection -> MLP up to the LLaMA
    embedding width -> concatenation after the LLaMA token embeddings of the
    question -> ``LlamaForSequenceClassification`` -> logits.

    Args:
        num_class: Number of labels of the LLaMA classification head.
    """

    def __init__(self, num_class=59):
        super().__init__()
        # image processing
        self.img_feature_extractor = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

        # Visual embedding: only the visual projection of a config-built
        # VisualBERT is kept.
        VB_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
        VB_config.visual_embedding_dim = 768
        visualbert = VisualBertModel(config=VB_config)
        self.visual_embedder = visualbert.embeddings.visual_projection

        # LLaMA text embedding + decoder. This must be created BEFORE the
        # projection MLP below, which sizes itself from the embedding table
        # (the original read ``self.question_embedder`` before assigning it and
        # failed with AttributeError on construction).
        self.VCA_decoder = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf", num_labels=num_class)
        self.VCA_decoder.config.pad_token_id = self.VCA_decoder.config.eos_token_id
        self.question_embedder = self.VCA_decoder.model.embed_tokens

        # Projection from the visual embedding width to the LLM embedding width.
        llm_embedding_size = self.question_embedder.weight.shape[1]  # 4096 per the original note
        # Width produced by ``visual_embedder`` (768 for the configuration above
        # per the original code).
        image_feature_size = self.visual_embedder.out_features
        self.image_project = MLP((image_feature_size, llm_embedding_size // 2, llm_embedding_size))

    def forward(self, inputs, img):
        """Compute class logits for a batch of (question, image) pairs.

        Args:
            inputs: Mapping with ``'input_ids'`` and ``'attention_mask'``
                tensors for the question. The mapping is read, not modified.
            img: Image batch passed to the Swin backbone as ``pixel_values``.

        Returns:
            The ``logits`` tensor produced by the LLaMA classification model.
        """
        # Follow the device the model was actually placed on rather than a
        # module-level global.
        model_device = self.question_embedder.weight.device

        # Image features: element 0 of the backbone output is the token
        # sequence (e.g. [1, 49, 768] per the original debugging note).
        img_feature = self.img_feature_extractor(pixel_values=img.to(model_device))

        # Visual embedding, then projection up to the LLM embedding width.
        visual_embeds = self.visual_embedder(img_feature[0])
        visual_embeds = self.image_project(visual_embeds)

        # Question embedding.
        input_ids = inputs['input_ids'].to(model_device)
        question_attention_mask = inputs['attention_mask'].to(model_device)
        question_embeds = self.question_embedder(input_ids)

        # Keep both halves of the sequence in the same dtype (no-op when they
        # already agree) so the concatenation does not silently upcast.
        visual_embeds = visual_embeds.to(question_embeds.dtype)
        # Every visual token is valid.
        visual_attention_mask = torch.ones(
            visual_embeds.shape[:-1], dtype=question_attention_mask.dtype, device=model_device
        )

        # Question first, then visual tokens.
        inputs_embeds = torch.cat((question_embeds, visual_embeds), dim=1)
        attention_mask = torch.cat((question_attention_mask, visual_attention_mask), dim=1)

        # LLaMA decoder with sequence-classification head.
        decoder_output = self.VCA_decoder(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        return decoder_output.logits


if __name__ == '__main__':
    # Smoke test: run one image/caption pair through SurgicalLlamaCLS and
    # print the resulting logits.
    import os

    from torchvision import transforms
    from PIL import Image
    from transformers import LlamaTokenizer
    from huggingface_hub import login

    # SECURITY: never commit access tokens. The token is read from the
    # HF_TOKEN environment variable; when it is unset, credentials already
    # cached on this machine are relied on instead.
    # NOTE(review): an access token was previously hard-coded on this line and
    # must be treated as leaked -- revoke it in the Hugging Face account.
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Force three channels so the 3-channel Normalize above also works for
    # grayscale or RGBA files, and close the file handle once decoded.
    with Image.open("./cat.jpg") as img_file:
        image = preprocess(img_file.convert("RGB"))
    img_tensor = image.unsqueeze(0)

    # Load the LLaMA tokenizer; it is given the EOS token as its pad token so
    # that padding to max_length works.
    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token
    text = 'a photo of a cat'

    # Tokenize the text.
    inputs = tokenizer(text, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    model = SurgicalLlamaCLS().to(device)

    # Inference only: disable training-mode layers and gradient tracking.
    model.eval()
    with torch.no_grad():
        result = model(inputs, img_tensor)
    print(result)