In [1]:
import torch
import torch.nn as nn
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2VisionModel

raw_image = Image.open("./images/demo.jpg").convert('RGB')

question = "how many dogs are in the picture?"

2024-02-06 22:32:46.146499: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-02-06 22:32:46.174926: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-02-06 22:32:46.304982: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### BLIP2

In [None]:
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir="./cache")

inputs = processor(raw_image, question, return_tensors="pt")

#### Conditional Generation model

In [None]:
generative_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float32, device_map="cpu", cache_dir='./cache')

In [None]:
generative_model.vision_model.save_pretrained("./blip2-opt-2.7b-vision")

- Push vision model to hub

In [None]:
generative_model.vision_model.push_to_hub("tmnam20/blip2-opt-2.7b-vision", private=True, token="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")

- Push processor to hub

In [None]:
processor.push_to_hub("tmnam20/blip2-opt-2.7b-vision", private=True, token="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")

#### Vision models

In [2]:
vision_processor = Blip2Processor.from_pretrained("tmnam20/blip2-opt-2.7b-vision",torch_dtype=torch.float32, cache_dir="./cache")
vision_model = Blip2VisionModel.from_pretrained("tmnam20/blip2-opt-2.7b-vision", device_map='cpu', torch_dtype=torch.float32, cache_dir="./cache")
vision_model = vision_model.eval()
vision_model

Blip2VisionModel(
  (embeddings): Blip2VisionEmbeddings(
    (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
  )
  (encoder): Blip2Encoder(
    (layers): ModuleList(
      (0-38): 39 x Blip2EncoderLayer(
        (self_attn): Blip2Attention(
          (dropout): Dropout(p=0.0, inplace=False)
          (qkv): Linear(in_features=1408, out_features=4224, bias=True)
          (projection): Linear(in_features=1408, out_features=1408, bias=True)
        )
        (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        (mlp): Blip2MLP(
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1408, out_features=6144, bias=True)
          (fc2): Linear(in_features=6144, out_features=1408, bias=True)
        )
        (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
      )
    )
  )
  (post_layernorm): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
)

In [3]:
inputs = vision_processor(raw_image, question, return_tensors="pt")

In [4]:
input_pixels = inputs['pixel_values']
with torch.no_grad():
    vision_embeddings = vision_model(input_pixels)
print(f'Input pixels shape: {input_pixels.shape}')
print(f'Vision embeddings shape: {vision_embeddings.last_hidden_state.shape}')

Input pixels shape: torch.Size([1, 3, 224, 224])
Vision embeddings shape: torch.Size([1, 257, 1408])


In [5]:
input_pixels = inputs['pixel_values']
print(f'Shape before extend: {input_pixels.shape}') # [batch_size, 3, height, width]
# extend to [batch_size, 1, 3, height, width]
input_pixels = input_pixels.unsqueeze(1)
print(f'Shape after extend: {input_pixels.shape}')
input_pixels = input_pixels.repeat(2, 4, 1, 1, 1)
print(f'Shape after repeat: {input_pixels.shape}')

Shape before extend: torch.Size([1, 3, 224, 224])
Shape after extend: torch.Size([1, 1, 3, 224, 224])
Shape after repeat: torch.Size([2, 4, 3, 224, 224])


- Test train the model

In [18]:
vision_model.train()

encoded = []
for batch_idx in range(input_pixels.size(0)):
    print(batch_idx)
    item = input_pixels[batch_idx].contiguous()
    encoded_item = vision_model(pixel_values=item)
    encoded.append(encoded_item.last_hidden_state)
encoded = torch.stack(encoded)

0
1


- Check backward compatibility of vision models

In [None]:
encoded.sum().backward()

In [None]:
print(f'Shape of encoded: {encoded.shape}')

In [None]:
# [batch_size, num_frames, num_pathces, hidden_size]
# get the first token in num_patches
encoded = encoded[:, :, 0, :]
encoded.shape

In [None]:
projection_layer = nn.Linear(vision_model.config.hidden_size, 5)

In [None]:
scaled_encoded_features = projection_layer(encoded)
print(f'Projected image features shape = {scaled_encoded_features.shape}')

### CLIP image features

In [2]:
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("images/demo.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)

Label probs: [[0.0123  0.962   0.02563]]


In [3]:
image_features.shape

torch.Size([1, 512])

### Llama 2

In [6]:
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM

In [7]:
llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", cache="./cache")
llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="cpu", cache_dir="./cache")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Video Llama

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union

class LlamaVideo(nn.Module):

    def __init__(
        self,
        vision_encoder,
        tokenizer,
        language_model,
        freeze_vision_encoder=True,
        freeze_language_model=True,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.tokenizer = tokenizer
        self.lm = language_model
        
        vision_encoder_output_dim = vision_encoder.config.hidden_size
        lm_output_dim = language_model.config.hidden_size
        self.vision_projection = nn.Linear(vision_encoder_output_dim, lm_output_dim)
        if freeze_vision_encoder:
            self.freeze_vision_encoder()
        if freeze_language_model:
            self.freeze_language_model()
        
    def freeze_vision_encoder(self):
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        print("Freeze vision encoder")
        
    def freeze_language_model(self):
        for param in self.lm.parameters():
            param.requires_grad = False
        print("Freeze language model")
    
    def forward(
        self, 
        frames: torch.Tensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Forward pass of VideoLlama

        Args:
            frames (torch.Tensor): Input in shape (batch_size, seq_len, channels, width, height) or (batch_size, seq_len, features)
            tokenized_input (dict): input dictionary of the prompts
        """
        if len(frames.size()) == 5: # [batch_size, seq_length, channels, width, height]
            frame_features = []
            for batch_idx in range(frames.size(0)):
                cur_frames = frames[batch_idx].contiguous()
                frame_features.append(
                    self.vision_encoder(cur_frames).last_hidden_state[:, 0, :]
                )
            frame_features = torch.stack(frame_features)
            frame_features = self.vision_projection(frame_features)
        elif len(frames.size()) == 3: # [batch_size, seq_length, features]
            frame_features = self.vision_projection(frames)
        else:
            raise ValueError("Invalid input shape. Must be either (batch_size, seq_len, channels, width, height) or (batch_size, seq_len, features)")
        
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("Only one of input_ids or inputs_embeds can be populated")
        
        if input_ids is not None:
            embedding_layer = self.lm.get_input_embeddings()
            text_features = embedding_layer(input_ids)
            input_features = torch.cat([frame_features, text_features], dim=1)
        elif inputs_embeds is not None:
            text_features = inputs_embeds
            input_features = torch.cat([frame_features, text_features], dim=1)
        else:
            input_features = frame_features

        video_feature_len = frame_features.size(1)
        if labels is not None:
            padding = torch.full((labels.size(0), video_feature_len), -100, dtype=torch.long, device=labels.device)
            labels = torch.cat([padding, labels], dim=1)
        attention_mask = torch.cat([torch.ones((input_features.size(0), video_feature_len), dtype=torch.long, device=input_features.device), attention_mask], dim=1)

        outputs = self.lm(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=input_features,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return outputs

    @staticmethod
    def from_pretrained(
        vision_encoder_name_or_path: str,
        language_model_name_or_path: str,
        vision_projection_name_or_path: str,
        **kwargs
    ):
        vision_encoder = Blip2VisionModel.from_pretrained(vision_encoder_name_or_path, **kwargs)
        language_model = LlamaForCausalLM.from_pretrained(language_model_name_or_path, **kwargs)
        model = LlamaVideo(vision_encoder, language_model, **kwargs)
        if vision_projection_name_or_path is not None:
            model.vision_projection.load_state_dict(torch.load(vision_projection_name_or_path))
        return model

    def generate(
        self,
        video,
        input_tokenized,
        use_nucleus_sampling=False,
        num_beams=4,
        max_length=256,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        if len(video.size()) == 5:
            frame_features = []
            for batch_idx in range(video.size(0)):
                item = video[batch_idx].contiguous()
                encoded_item = self.vision_encoder(pixel_values=item)
                frame_features.append(encoded_item.last_hidden_state[:, 0, :])
            frame_features = torch.stack(frame_features)
            frame_features = self.vision_projection(frame_features)
        elif len(video.size()) == 3:
            frame_features = self.vision_projection(video)
        else:
            raise ValueError("Invalid input shape. Must be either (batch_size, seq_len, channels, width, height) or (batch_size, seq_len, features)")
        
        video_feature_len = frame_features.size(1)
        attention_mask = torch.ones((frame_features.size(0), video_feature_len), dtype=torch.long, device=frame_features.device)
        
        outputs = self.lm.generate(
            input_ids=None,
            inputs_embeds=frame_features,
            attention_mask=attention_mask,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
        output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        return output_text

In [9]:
llama_tokenizer.pad_token = llama_tokenizer.eos_token

In [10]:
question1 = "how many dogs are in the picture?"
question2 = "what color is the dog?"

model_inputs = llama_tokenizer(
    [question1, question2], 
    return_tensors="pt",
    padding=True,
)
model_inputs['labels'] = model_inputs['input_ids'].clone()
model_inputs


{'input_ids': tensor([[    1,   920,  1784, 26361,   526,   297,   278,  7623, 29973],
        [    1,   825,  2927,   338,   278, 11203, 29973,     2,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0]]), 'labels': tensor([[    1,   920,  1784, 26361,   526,   297,   278,  7623, 29973],
        [    1,   825,  2927,   338,   278, 11203, 29973,     2,     2]])}

In [18]:
llama_video = LlamaVideo(
    vision_encoder=vision_model,
    tokenizer=llama_tokenizer,
    language_model=llama_model
)

Freeze vision encoder
Freeze language model


In [12]:
tmp_output = llama_video(
    frames=input_pixels,
    **model_inputs
)

In [13]:
tmp_output.keys()

odict_keys(['loss', 'logits', 'past_key_values'])

In [14]:
tmp_output.logits.size(), tmp_output.loss

(torch.Size([2, 13, 32000]), tensor(6.0217, grad_fn=<NllLossBackward0>))

In [15]:
llama_video.lm.hf_device_map

{'': device(type='cpu')}

In [19]:
generated = llama_video.generate(
    video=input_pixels,
    input_tokenized=model_inputs,
    use_nucleus_sampling=False,
    num_beams=4,
    max_length=4,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1.0,
    num_captions=1,
    temperature=1,
)



In [20]:
generated

['OOOO', 'OOOO']

: 