In [1]:
import json
import os

from PIL import Image
import torch

import sys
sys.path.append("/lid/home/saydalie/multimodal_cot/Emu3")

from emu3.tokenizer import Emu3VisionVQModel, Emu3VisionVQImageProcessor

  from .autonotebook import tqdm as notebook_tqdm


# Image tokenization

In [2]:
def smart_resize(image, image_area: int = 512 * 512):
    w, h = image.size
    current_area = h * w
    target_ratio = (image_area / current_area) ** 0.5

    th = int(round(h * target_ratio))
    tw = int(round(w * target_ratio))

    image = image.resize((tw, th))
    return image

In [3]:
image_processor = Emu3VisionVQImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer")
image_tokenizer = Emu3VisionVQModel.from_pretrained("BAAI/Emu3-VisionTokenizer", device_map="cuda:0")
image_tokenizer.eval()

Emu3VisionVQModel(
  (encoder): Emu3VisionVQEncoder(
    (conv_in): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x Emu3VisionVQResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (act): Emu3VisionVQActivation()
          )
        )
        (attn): ModuleList()
        (downsample): Emu3VisionVQDownsample(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): Emu3VisionVQResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
        

In [55]:
prompt = "What is the missing number of the part denoted with a question mark?"

image = Image.open("/lid/home/saydalie/multimodal_cot/LLM-PuzzleTest/PuzzleVQA/data/eval/images/circle_size_number/circle_size_number_0009.png").convert("RGB")
image = smart_resize(image)

In [56]:
image = image_processor(image, return_tensors="pt")["pixel_values"]
image.shape

torch.Size([1, 3, 512, 512])

In [57]:
with torch.no_grad():
    image = image.cuda()
    token_ids = image_tokenizer.encode(image)

token_ids = token_ids.squeeze(0).cpu().numpy()

In [58]:
token_ids.shape

(64, 64)

In [None]:
data = {
    "images": token_ids,
    "texts": prompt
}

# Emu3FeatureDataset

In [4]:
from emu3.mllm import Emu3Config, Emu3Tokenizer, Emu3ForCausalLM
from emu3.train.datasets import Emu3FeatureDataset

tokenizer = Emu3Tokenizer.from_pretrained(
    "BAAI/Emu3-Stage1",
    model_max_length=5120,
    padding_side="right",
    use_fast=False,
)

In [5]:
from dataclasses import dataclass, field
from typing import Optional, List

class DataArguments:
    data_path: Optional[str] = None
    null_prompt_prob: float = 0.05
    apply_loss_on_only_vision: bool = True
    apply_loss_on_only_text: bool = False
    ignore_index: int = -100
    visual_token_pattern: str = "<|visual token {token_id:0>6d}|>"
    codebook_size: Optional[int] = 32768

args = DataArguments()

In [6]:
def format_image_prompt(image_tokens):
    h, w = image_tokens.shape
    imgstr = to_imgstr(image_tokens)

    image_prompt = (
        tokenizer.boi_token +
        f"{h}*{w}" +
        tokenizer.img_token +
        imgstr +
        tokenizer.eol_token +
        tokenizer.eof_token +
        tokenizer.eoi_token
    )

    return image_prompt

def to_imgstr(image_tokens):
    image_token_str = [
        [
            args.visual_token_pattern.format(token_id=token_id)
            for token_id in token_row
        ]
        for token_row in image_tokens
    ]
    image_row_str = ["".join(token_row) for token_row in image_token_str]
    imgstr = tokenizer.eol_token.join(image_row_str)
    return imgstr

In [8]:
tokenizer.encode(args.visual_token_pattern.format(token_id=0))[0], tokenizer.encode(args.visual_token_pattern.format(token_id=args.codebook_size - 1))[0]

(151854, 184621)

In [75]:
image_prompt = format_image_prompt(token_ids)
input = tokenizer.bos_token + prompt + image_prompt

input

'<|extra_203|>What is the missing number of the part denoted with a question mark?<|image start|>90*90<|image token|><|visual token 003122|><|visual token 013609|><|visual token 001510|><|visual token 013765|><|visual token 001510|><|visual token 013765|><|visual token 009197|><|visual token 013765|><|visual token 009197|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual token 006830|><|visual 

In [76]:
sample = tokenizer(
    input,
    padding="max_length",
    return_token_type_ids=False,
    return_tensors="pt",
)

In [77]:
sample['input_ids'].shape

torch.Size([1, 8215])

# Prepare data

In [24]:
def smart_resize(image, image_area: int = 720 * 720):
    w, h = image.size
    current_area = h * w
    target_ratio = (image_area / current_area) ** 0.5

    th = int(round(h * target_ratio))
    tw = int(round(w * target_ratio))

    image = image.resize((tw, th))
    return image

In [9]:
model_path = '/lid/home/saydalie/multimodal_cot/Emu3-models/Emu3-VisionTokenizer/snapshots/c81f916ad371289e205310a7539255e8a9396488'
data_path = '/lid/home/saydalie/multimodal_cot/LLM-PuzzleTest/PuzzleVQA/data/'
output_path = '/lid/home/saydalie/multimodal_cot/Emu3/data/'
image_area = 720 * 720

In [6]:
image_processor = Emu3VisionVQImageProcessor.from_pretrained(model_path)
image_tokenizer = Emu3VisionVQModel.from_pretrained(model_path, device_map="cuda:0")
image_tokenizer.eval()

Emu3VisionVQModel(
  (encoder): Emu3VisionVQEncoder(
    (conv_in): Conv2d(3, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0): Module(
        (block): ModuleList(
          (0-1): 2 x Emu3VisionVQResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
            (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (act): Emu3VisionVQActivation()
          )
        )
        (attn): ModuleList()
        (downsample): Emu3VisionVQDownsample(
          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (1): Module(
        (block): ModuleList(
          (0): Emu3VisionVQResnetBlock(
            (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
        

In [34]:
pattern_name = "circle_size_number"
input_path = os.path.join(data_path, f'train/{pattern_name}.json')

In [35]:
with open(input_path, 'r') as input_file:
    data = [json.loads(line) for line in input_file]

In [37]:
data[0]

{'image': 'images/circle_size_number/circle_size_number_0000.png',
 'question': 'What is the missing number of the part denoted with a question mark?',
 'options': [4, 5, 6, 1],
 'answer': '5',
 'caption': "There are 6 numbered circles with varying sizes arranged in a ring with number [2, 6, 2, 5, 6, '?'] in a clockwise order.",
 'explanation': 'We observe that the size of the circle is related to the number in the circle. The circle with the largest value 6 seems to be the biggest and the circle with the smallest value 2 seems to be the smallest. Thus, the pattern is that the larger the number the larger the circle.',
 'deduction': 'Based on the pattern that the larger the number the larger the circle, the missing number of the circle denoted with a question mark should be 5.'}

In [25]:
image_path = data_path + 'train/' + data[0]['image_options_path'][0]
image = Image.open(image_path).convert("RGB")
image = smart_resize(image, image_area)

image = image_processor(image, return_tensors="pt")["pixel_values"]

In [29]:
with torch.no_grad():
    image = image.cuda()
    token_ids = image_tokenizer.encode(image)

# token_ids = token_ids.squeeze(0).cpu().numpy()

In [33]:
token_ids.squeeze(0).cpu().numpy().shape

(90, 90)

In [45]:
token_ids[0][0][:10]

tensor([ 2567, 13609,  1510, 13765,  1510, 13765,  9197, 13765,  9197,  6830],
       device='cuda:0')

In [48]:
token_ids[0].T[0][:10]

tensor([ 2567, 11466,  4895,   924,  7780, 15325,  1082, 15961, 15961, 15961],
       device='cuda:0')

In [58]:
for img in torch.cat([token_ids, token_ids[0].T.unsqueeze(0)], dim=0):
    print(img.shape)

torch.Size([90, 90])
torch.Size([90, 90])


In [62]:
# Sample input
data = {
    "texts": "Here is an <image>. Another <image> appears. And one more <image>.",
    "images": [
        "<|image start|>90*90<|image token|><|visual token 003122|><|image end|>",
        "<|image start|>90*90<|image token|><|visual token 002435|><|image end|>",
        "<|image start|>90*90<|image token|><|visual token 003134|><|image end|>"
    ]
}

# Function to format image prompts (for demonstration)
def format_image_prompt(image_token):
    return f"{image_token}"

# Extract image prompts
image_prompts = [format_image_prompt(token) for token in data["images"]]

# Replace each <image> with the corresponding image prompt
prompt = data["texts"]
for image_prompt in image_prompts:
    prompt = prompt.replace("<image>", image_prompt, 1)

# Output result
print(prompt)

Here is an <|image start|>90*90<|image token|><|visual token 003122|><|image end|>. Another <|image start|>90*90<|image token|><|visual token 002435|><|image end|> appears. And one more <|image start|>90*90<|image token|><|visual token 003134|><|image end|>.


In [56]:
DATA_PATH = '/lid/home/saydalie/multimodal_cot/LLM-PuzzleTest/PuzzleVQA/data/'

with open(DATA_PATH + f'train/{pattern_name}.json') as f:
    input_data = json.load(f)

JSONDecodeError: Extra data: line 2 column 1 (char 784)