<a href="https://colab.research.google.com/github/atrichaudhuri/endovis18-vqa/blob/main/checking_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/lalithjets/SurgicalGPT.git

Cloning into 'SurgicalGPT'...
remote: Enumerating objects: 89, done.[K
remote: Counting objects: 100% (89/89), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 89 (delta 42), reused 82 (delta 35), pack-reused 0[K
Receiving objects: 100% (89/89), 749.26 KiB | 1.17 MiB/s, done.
Resolving deltas: 100% (42/42), done.


In [2]:
# Installing required dependencies
!pip -q install transformers
!pip -q install accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m66.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m102.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m79.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.23.0-py3-none-any.whl (258 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.1/258.1 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.23.0


In [3]:
# Testing by loading the model only
# Type 1
# import math
# import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import transformers
# from timm.models import create_model
from transformers import  VisualBertConfig, GPT2Config
from transformers import VisualBertModel, GPT2Model, ViTModel, SwinModel

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class EFVLEGPT2SwinClassification(nn.Module):
    """Early-fusion visual-question classifier: Swin image encoder + LLaMA decoder.

    Question tokens are embedded with the decoder's own input-embedding table,
    image patches are encoded by Swin, both sequences are concatenated
    (question first) and passed through the decoder. The decoder output is
    mean-pooled over the sequence and classified.

    Sub-versions:
        v0: Swin patch features + VisualBERT visual projection + decoder
        v1: Swin patch features + decoder

    Args:
        num_class: number of output classes.
        model_subver: 'v0' or 'v1' (see above).
        vis_pos_emb: None lets the decoder use its default positions;
            'zeroes' gives every visual token position 0; 'pos' numbers the
            visual tokens 0..N-1. In both non-None modes the question tokens
            are numbered 0..T-1.
        llm_name: checkpoint passed to ``transformers.LlamaModel.from_pretrained``.

    Raises:
        ValueError: if ``model_subver`` or ``vis_pos_emb`` is not a supported value.
    """

    _SUB_VERSIONS = ('v0', 'v1')
    _VIS_POS_MODES = (None, 'zeroes', 'pos')

    def __init__(self, num_class = 12, model_subver = 'v0', vis_pos_emb = None,
                 llm_name = 'chaoyi-wu/PMC_LLAMA_7B'):
        super(EFVLEGPT2SwinClassification, self).__init__()

        # Fail fast, before any checkpoint is downloaded. An unknown sub-version
        # used to surface only inside forward(), and a misspelt vis_pos_emb
        # (e.g. 'zeros') was silently treated like None.
        if model_subver not in self._SUB_VERSIONS:
            raise ValueError("model_subver must be one of %r, got %r" % (self._SUB_VERSIONS, model_subver))
        if vis_pos_emb not in self._VIS_POS_MODES:
            raise ValueError("vis_pos_emb must be one of %r, got %r" % (self._VIS_POS_MODES, vis_pos_emb))

        self.sub_ver = model_subver
        self.vis_pos_emb = vis_pos_emb

        ## image processing
        self.img_feature_extractor = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        swin_dim = self.img_feature_extractor.config.hidden_size

        ## Visual_embedding
        # visual bert embedding: only the (randomly initialised) visual projection
        # layer is kept; the rest of the VisualBERT model is discarded.
        if self.sub_ver == "v0":
            VB_config = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
            VB_config.visual_embedding_dim = swin_dim
            visualbert = VisualBertModel(config=VB_config)
            self.visual_embedder = visualbert.embeddings.visual_projection
            visual_dim = VB_config.hidden_size
        else:
            visual_dim = swin_dim

        ## visual_context_aware decoder
        # Loaded ONCE as the base LlamaModel: it exposes `last_hidden_state`
        # (the causal-LM wrapper does not) and avoids holding a second copy of
        # the checkpoint just to reach the token-embedding table.
        self.VCAdecoder = transformers.LlamaModel.from_pretrained(
            llm_name, offload_folder="offload", offload_state_dict=True, device_map="auto")
        decoder_dim = self.VCAdecoder.config.hidden_size

        ## bridge visual features to the decoder width
        # The visual features and the decoder embeddings generally have different
        # widths, so they cannot be concatenated along the sequence axis as-is.
        if visual_dim == decoder_dim:
            self.visual_to_decoder = nn.Identity()
        else:
            self.visual_to_decoder = nn.Linear(visual_dim, decoder_dim)

        ## intermediate_layers
        self.intermediate_layer = nn.Linear(decoder_dim, 512)
        # NOTE: despite the attribute name this is BatchNorm1d; in training mode
        # it needs more than one sample per batch.
        self.LayerNorm = nn.BatchNorm1d(512)
        self.dropout = nn.Dropout(0.1)

        ## classifier
        self.classifier = nn.Linear(512, num_class)

    @property
    def question_embedder(self):
        """Token-embedding table of the decoder (shared, not a second copy)."""
        return self.VCAdecoder.get_input_embeddings()

    def forward(self, input, img):
        """Classify a batch of (question, image) pairs.

        Args:
            input: mapping with 'input_ids' and 'attention_mask' tensors
                (presumably tokenizer output of shape (batch, tokens) -- confirm
                against the caller).
            img: mapping with 'pixel_values' (plus any extra keyword arguments
                for the Swin model).

        Returns:
            Tensor of shape (batch, num_class) with unnormalised class scores.

        Neither `input` nor `img` is modified; tensors are copied to the
        module-level `device` locally.
        """
        ## image encoder features
        pixel_values = img['pixel_values'].to(device)
        img_feature = self.img_feature_extractor(**{**img, 'pixel_values': pixel_values})

        ## visual Embedding
        visual_embeds = img_feature[0]
        if self.sub_ver == 'v0':
            visual_embeds = self.visual_embedder(visual_embeds)
        visual_embeds = self.visual_to_decoder(visual_embeds)

        ## question embedding
        input_ids = input['input_ids'].to(device)
        question_embeds = self.question_embedder(input_ids)

        # With device_map="auto" the decoder may live on a different device than
        # the other sub-modules; align everything with the question embeddings.
        fused_device = question_embeds.device
        visual_embeds = visual_embeds.to(device=fused_device, dtype=question_embeds.dtype)
        question_attention_mask = input['attention_mask'].to(fused_device)

        batch_size, num_visual = visual_embeds.shape[:2]
        num_question = question_embeds.shape[1]

        # every visual token is attended to
        visual_attention_mask = question_attention_mask.new_ones((batch_size, num_visual))

        ## combine visual and question embeds: question first
        inputs_embeds = torch.cat((question_embeds, visual_embeds), dim=1)
        attention_mask = torch.cat((question_attention_mask, visual_attention_mask), dim=1)

        ## explicit position ids (LLaMA has no token_type_ids, so none are passed)
        decoder_kwargs = {}
        if self.vis_pos_emb is not None:
            question_position_id = torch.arange(num_question, dtype=torch.long, device=fused_device)
            if self.vis_pos_emb == 'pos':
                visual_position_id = torch.arange(num_visual, dtype=torch.long, device=fused_device)
            else:  # 'zeroes'
                visual_position_id = torch.zeros(num_visual, dtype=torch.long, device=fused_device)
            position_ids = torch.cat((question_position_id, visual_position_id), dim=0)
            decoder_kwargs['position_ids'] = position_ids.unsqueeze(0).repeat(batch_size, 1)

        ## VCA decoder
        # use_cache=False: this is a single full pass, the key/value cache is never reused.
        decoder_output = self.VCAdecoder(inputs_embeds=inputs_embeds,
                                         attention_mask=attention_mask,
                                         use_cache=False,
                                         **decoder_kwargs)

        # Average over the sequence axis (same result as adaptive_avg_pool1d to
        # length 1); padded positions are included in the average.
        pooled = decoder_output.last_hidden_state.mean(dim=1)
        pooled = pooled.to(device=self.intermediate_layer.weight.device,
                           dtype=self.intermediate_layer.weight.dtype)

        ## intermediate layers
        out = self.intermediate_layer(pooled)
        out = self.LayerNorm(out)
        out = self.dropout(out)

        ## classifier
        out = self.classifier(out)
        return out

In [None]:
# Build the classifier: 12 output classes, the 'v0' visual embedding path,
# and no explicit position ids for the fused sequence.
model = EFVLEGPT2SwinClassification(
    num_class=12,
    model_subver='v0',
    vis_pos_emb=None,
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]