In [1]:
from datasets import load_dataset
import torch
import torch.nn as nn
import random
import open_clip
import os
import numpy as np
import torch
import diffusers 
import torchvision.transforms as transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from models import CLIPVisionTower
from transformers.optimization import Adafactor, AdafactorSchedule
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F

In [2]:
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="OmniMistral-v1_1/projection.pt", local_dir='./')
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="models.py", local_dir='./')

'./models.py'

In [3]:
!nvidia-smi

Sun May 12 20:52:19 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:4B:00.0 Off |                    0 |
| N/A   31C    P0              68W / 400W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [4]:
DEVICE = torch.device("cuda:0")

In [5]:
# dataset_anygpt = load_dataset("zhanjun/AnyGPT-data-image")

In [6]:
dataset_science = load_dataset("cnut1648/ScienceQA-LLAVA")

In [7]:
dataset_science['validation'][0]

{'id': 'validation-0',
 'image': None,
 'conversations': [{'from': 'human',
   'value': "Context: N/A\nQuestion: What does the verbal irony in this text suggest?\nAccording to Mr. Herrera's kids, his snoring is as quiet as a jackhammer.\nOptions: (A) The snoring is loud. (B) The snoring occurs in bursts."},
  {'from': 'gpt', 'value': 'The answer is A.'}],
 'question': "What does the verbal irony in this text suggest?\nAccording to Mr. Herrera's kids, his snoring is as quiet as a jackhammer.",
 'context': 'N/A',
 'choice': '(A) The snoring is loud. (B) The snoring occurs in bursts.',
 'answer': 'A',
 'lecture': 'Figures of speech are words or phrases that use language in a nonliteral or unusual way. They can make writing more expressive.\nVerbal irony involves saying one thing but implying something very different. People often use verbal irony when they are being sarcastic.\nOlivia seems thrilled that her car keeps breaking down.\nEach breakdown is as enjoyable as a punch to the face.'

In [8]:
clip, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k')
clip.to(DEVICE);
clip.visual.output_tokens = True

def preprocess_image(image_path, preprocess_function=preprocess_train):
    image = Image.open(image_path).convert("RGB")
    return preprocess_function(image)

def encode_image(path=None, image=None, preprocess_function=preprocess_train):
    if path:
        image_tensor = preprocess_image(path, preprocess_function).to(DEVICE)
        with torch.no_grad():
            hidden_states = clip.visual(image_tensor[None])[1].to(DEVICE).squeeze(0)
        return hidden_states
    
    if image:
        image_tensor = preprocess_function(image).to(DEVICE)
        with torch.no_grad():
            hidden_states = clip.visual(image_tensor[None])[1].to(DEVICE).squeeze(0)
        return hidden_states

In [9]:
model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-v1_1/tuned-model", torch_dtype=torch.bfloat16, device_map=DEVICE)

projection = torch.load("OmniMistral-v1_1/projection.pt", map_location=DEVICE)
tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-v1_1/tokenizer", use_fast=False)

clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
clip.load_model()
clip = clip.to(device=DEVICE, dtype=torch.bfloat16)



Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
choice='validation'
i  = random.randint(0, len(dataset_science[choice])-1)
image = None
while not image:
    i  = random.randint(0, len(dataset_science[choice])-1)
    item = dataset_science[choice][i]
    image = item['image']
conversation = item['conversations']
if type(conversation) == str:
    conversation = ast.literal_eval(conversation)
query = conversation[0]['value'].replace('\n<image>', '').replace('\n', ' ')
answer = item['solution'].replace('\n', ' ').strip() + '</s>'
with torch.no_grad():
    image_features = clip.image_processor(image, return_tensors='pt')
    image_embedding = clip(image_features['pixel_values']).to(device=DEVICE, dtype=torch.bfloat16)

    caption_ids = list(tokenizer.encode(answer, add_special_tokens=False))
    caption_ids = torch.LongTensor(caption_ids).to(DEVICE)
    user_query_ids = tokenizer.encode(query, add_special_tokens=False, return_tensors="pt").to(DEVICE)

In [11]:
vision_embs = image_embedding
projected_vision_embs = projection(vision_embs)
projected_vision_embs = projected_vision_embs
text_embeddings = model.model.embed_tokens(caption_ids)
prompt_embeddings = model.model.embed_tokens(user_query_ids)
bs = projected_vision_embs.shape[0]


embeddings1 = torch.cat([
            projected_vision_embs,
            prompt_embeddings
            ],
            dim=1)

embeddings2 = text_embeddings.repeat(bs,1,1)
embeddings = torch.cat([embeddings1, embeddings2], dim=1)

In [12]:
caption_ids.shape, user_query_ids.shape

(torch.Size([111]), torch.Size([1, 119]))

In [13]:
text_embeddings.shape, prompt_embeddings.shape,prompt_embeddings.repeat(bs,1,1).shape, projected_vision_embs.shape

(torch.Size([111, 4096]),
 torch.Size([1, 119, 4096]),
 torch.Size([1, 119, 4096]),
 torch.Size([1, 576, 4096]))

In [14]:
embeddings1.shape, embeddings2.shape, embeddings2.repeat(bs,1,1).shape

(torch.Size([1, 695, 4096]),
 torch.Size([1, 111, 4096]),
 torch.Size([1, 111, 4096]))

In [16]:
# bad_words_ids = tokenizer(["\n", "</s>", ":"], add_special_tokens=False).input_ids + [[13]]
# gen_params = {
#         "do_sample": False,
#         "max_new_tokens": 50,
#         "early_stopping": True,
#         "num_beams": 3,
#         "repetition_penalty": 1.0,
#         "remove_invalid_values": True,
#         "eos_token_id": 2,
#         "pad_token_id": 2,
#         "forced_eos_token_id": 2,
#         "use_cache": True,
#         "no_repeat_ngram_size": 4,
#         "bad_words_ids": bad_words_ids,
#         "num_return_sequences": 1,
#     }
# out = model.generate(inputs_embeds=embeddings, **gen_params)

In [17]:
out.shape, out[:, 1:].shape

(torch.Size([1, 9]), torch.Size([1, 8]))

In [15]:
embeddings[None, ...].shape

torch.Size([1, 1, 806, 4096])

In [16]:
class VAELoss(nn.Module):
    def __init__(self, λ=1.):
        super().__init__()
        self.λ = λ
        self.reconstruction_loss = nn.BCELoss()
    
    def forward(self, output, target, vq_loss):
        reconst_loss = self.reconstruction_loss(output, target)
        
        loss = reconst_loss + self.λ * vq_loss
        return {"loss": loss, "reconstruction loss": reconst_loss, "VQ loss": vq_loss}
loss_vae = VAELoss()

In [17]:
model_ae = diffusers.VQModel(1, 1).to(DEVICE)

In [18]:
# out = model_ae(embeddings[None, ...].float())
h = model_ae.encode(embeddings[None, ...].float()).latents
_, vq_loss, _ = model_ae.quantize(h)
# out = model_ae.decode(h).sample
# loss = loss_vae(out, embeddings[None, ...].float(), vq_loss)

In [19]:
vq_loss

tensor(0.1261, device='cuda:0', grad_fn=<AddBackward0>)

In [20]:
out = model_ae.decode(h).sample

In [21]:
out

tensor([[[[ 2.2227,  1.3053,  0.8109,  ..., -1.0408, -1.2240,  1.0533],
          [ 0.1085, -0.2438, -1.9387,  ..., -0.5151, -4.7525,  1.7800],
          [-2.6457,  2.7697, -3.5529,  ...,  0.6989,  1.3130,  2.7379],
          ...,
          [-3.6787,  2.6983,  1.2140,  ...,  0.2783, -3.1924, -2.7261],
          [-2.7598,  1.8086,  4.5208,  ...,  0.2744, -1.6552,  1.8046],
          [-1.5706,  3.2711,  2.4797,  ...,  1.2999,  1.6929,  1.0233]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>)

In [23]:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)
target.shape

torch.Size([3, 5])

In [25]:
reconstruction_loss = nn.CrossEntropyLoss()
reconst_loss = reconstruction_loss(out, embeddings[None, ...].float())
loss = reconst_loss + vq_loss

In [27]:
out.shape

torch.Size([1, 1, 743, 4096])

In [36]:
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn((1,1,743,4096), requires_grad=True)
target = torch.rand((1,1,743,4096), requires_grad=False)
output = loss(m(input), target)



In [32]:
m(input).shape, target.shape

(torch.Size([3, 2]), torch.Size([3, 2]))

In [39]:
output

tensor(0.8058, grad_fn=<BinaryCrossEntropyBackward0>)

In [26]:
loss


tensor(0.1261, device='cuda:0', grad_fn=<AddBackward0>)