In [57]:
import hydra
import re
import pyrootutils
import os
import torch

from omegaconf import OmegaConf
import json
from typing import Optional
import transformers
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

# Text markers that delimit an image span in the token stream; they are looked up
# in the tokenizer vocabulary when decoding generated images further down.
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
# Template for a single image-code token: zero-padded, 5-digit code index.
IMG_TOKEN = '<img_{:05d}>'

# Chat-format pieces; `user_token` / `assistant_token` are substituted into the prompt templates.
sep = "\n"
user_token = "USER"
assistant_token = "ASSISTANT"

IMG_FLAG = '<image>'
# NOTE(review): the name is misspelt ("TOKNES"); it is kept because other cells or files may
# reference it. Presumably the number of code tokens per image — the reward function below
# hard-codes the same value (32). TODO confirm.
NUM_IMG_TOKNES = 32
# Presumably the size of the image codebook — TODO confirm against the SEED tokenizer.
NUM_IMG_CODES = 8192
# Offset subtracted from generated token ids to obtain image-code ids before `decode_image`.
image_id_shift = 32000

def generate(tokenizer, input_tokens, generation_config, model):
    """Generate a continuation for one prompt and return only the new token ids.

    Only for batch_size=1: the first (and only) sequence returned by
    ``model.generate`` is used.

    Args:
        tokenizer: callable that tokenizes ``input_tokens`` and exposes ``.input_ids``.
        input_tokens: the fully formatted prompt string (special tokens are
            expected to be present already; none are added here).
        generation_config: passed through unchanged to ``model.generate``.
        model: object exposing ``.device`` and ``.generate``.

    Returns:
        1-D sequence of generated ids with the prompt prefix removed.
    """
    encoded = tokenizer(input_tokens, add_special_tokens=False, return_tensors='pt')
    prompt_ids = encoded.input_ids.to(model.device)
    prompt_length = prompt_ids.shape[1]

    full_sequence = model.generate(
        input_ids=prompt_ids,
        generation_config=generation_config
    )[0]

    # The returned sequence starts with the prompt; keep only what was generated after it.
    return full_sequence[prompt_length:]

def replace_img_tags(input_text):
    """Collapse every ``<img>...</img>`` span in ``input_text`` to the literal ``<IMAGE>``.

    Matching is case-insensitive and non-greedy, and (as before) does not cross
    newlines. The substitution is done by the regex itself so that every span
    the pattern matches is actually replaced; the previous implementation
    re-searched with a case-sensitive ``str.replace`` and therefore left
    spans such as ``<IMG>...</IMG>`` untouched.

    Args:
        input_text: text that may contain image spans.

    Returns:
        The text with each image span replaced by ``<IMAGE>``.
    """
    img_pattern = re.compile(r'<img>(.*?)</img>', re.IGNORECASE)
    return img_pattern.sub('<IMAGE>', input_text)

In [2]:
import sys
sys.path.append("/lid/home/saydalie/multimodal_cot/SEED/")

from models.seed_llama_tokenizer import SeedLlamaTokenizer
from models.transforms import get_transform
from models.model_tools import get_pretrained_llama_causal_model

In [58]:
# Local checkpoint handed to the tokenizer loader as `encoder_url`.
# NOTE(review): absolute, user-specific path — consider a configurable base directory.
encoder_path="/lid/home/saydalie/multimodal_cot/SEED/checkpoints/seed-tokenizer-2/seed_quantizer.pt"

# Build the SEED tokenizer. The meaning of the keyword arguments is defined by the
# project's `SeedLlamaTokenizer`; presumably `load_diffusion`/`diffusion_path` enable the
# image decoder used later via `tokenizer.decode_image` — TODO confirm.
tokenizer = SeedLlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path="AILab-CVC/seed-tokenizer-2",
    fp16=True,
    load_diffusion=True,
    encoder_url=encoder_path,
    padding_side = 'left',
    diffusion_path="stabilityai/stable-diffusion-2-1-unclip"
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LlamaTokenizer'. 
The class this function is called from is 'SeedLlamaTokenizer'.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.e

missing keys:  511 unexpected keys: 146


Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████| 9/9 [00:02<00:00,  3.69it/s]


In [59]:
from types import MethodType

# Patch `tokenizer.batch_decode` so that it always runs with `skip_special_tokens=False`,
# whatever the caller passes.
#
# The unpatched bound method is stashed on the tokenizer the first time this cell runs.
# Without that, re-running the cell would rebind the global `batch_decode` to the
# already-patched method, and the wrapper would end up calling itself forever
# (RecursionError on the next decode).
if not hasattr(tokenizer, "_unpatched_batch_decode"):
    tokenizer._unpatched_batch_decode = tokenizer.batch_decode
batch_decode = tokenizer._unpatched_batch_decode

def wrapped_batch_decode(self, *args, **kwargs):
    """Call the original ``batch_decode`` with ``skip_special_tokens`` forced to False.

    Any ``skip_special_tokens`` value supplied by the caller is overwritten.
    All other positional and keyword arguments are forwarded unchanged.
    """
    kwargs["skip_special_tokens"] = False
    return self._unpatched_batch_decode(*args, **kwargs)

tokenizer.batch_decode = MethodType(wrapped_batch_decode, tokenizer)

# Load Dataset

In [14]:
import json

# NOTE(review): absolute, user-specific path — consider a configurable data directory.
data_path = "/lid/home/saydalie/multimodal_cot/SEED/data/ReSQ/train_resq.json"
# Read as UTF-8 explicitly so the result does not depend on the platform's default encoding.
with open(data_path, "r", encoding="utf-8") as file:
    # The examples live under the top-level 'data' key.
    data = json.load(file)['data']

In [15]:
# Training prompt template. Placeholders filled by the dataset-building cell:
# {bos}, {user_token}, {assistant_token}, {candidate_answers}, {question}.
# NOTE(review): the name `prompt` is rebound to a different template in the "Generate"
# section; re-running the dataset cell after that one would silently use the other text.
prompt="""{bos}{user_token}: I now describe a scene and ask a question about it. First, think about the reasoning process using an interleaved combination of images and text. You should generate an image when necessary to support reasoning, then describe insights from the image. Finally, provide with the answer. The final answer is either one of {candidate_answers}. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process with images and text here </think> <answer> the final answer here </answer>.
{question}
{assistant_token}:"""

In [16]:
# One training record per question: the formatted prompt plus the first listed answer.
train_dataset = []

for sample in data:
    sentences = sample['story']

    for qa in sample['questions']:
        # Only the first `num_1st_context_sentences` sentences of the story form the scene.
        context = ' '.join(sentences[:qa['num_1st_context_sentences']])

        filled_prompt = prompt.format(
            bos=tokenizer.bos_token,
            user_token=user_token,
            assistant_token=assistant_token,
            candidate_answers=qa['candidate_answers'],
            question=f"{context} {qa['question']}")

        train_dataset.append({'prompt': filled_prompt, 'answer': qa['answer'][0]})

In [17]:
from datasets import Dataset

# Convert the list of {'prompt', 'answer'} records into a `datasets.Dataset`.
# NOTE(review): this rebinds `train_dataset` from a list to a Dataset, so the name means
# different things depending on which cells have run.
train_dataset = Dataset.from_list(train_dataset)
train_dataset

Dataset({
    features: ['prompt', 'answer'],
    num_rows: 1008
})

# Load the model

In [39]:
# NOTE(review): absolute, user-specific checkpoint path — consider making it configurable.
model_name_or_path="/lid/home/saydalie/multimodal_cot/SEED/checkpoints/seed-llama-8b-sft-comm"

# Load the causal LM through the project helper. The string "bf16" is interpreted by that
# helper (presumably mapped to torch.bfloat16 — TODO confirm in models.model_tools).
model = get_pretrained_llama_causal_model(
    pretrained_model_name_or_path=model_name_or_path,
    torch_dtype="bf16",
    low_cpu_mem_usage=True
)

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:03<00:00,  1.10s/it]


In [40]:
from peft import LoraConfig, get_peft_model

# LoRA adapters on the attention and MLP projection layers.
# `modules_to_save` additionally lists the embeddings, LM head and norm layers; presumably
# these are trained in full rather than through adapters, which would explain the large
# trainable-parameter count printed below — TODO confirm against the PEFT documentation.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type='CAUSAL_LM',
    lora_dropout=0.05,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj'],
    modules_to_save=['embed_tokens', 'lm_head', 'input_layernorm', 'post_attention_layernorm', 'norm']
)

# Rebinds `model` to the PEFT-wrapped model.
model = get_peft_model(model, lora_config)

model.print_trainable_parameters()

trainable params: 369,512,448 || all params: 7,175,053,312 || trainable%: 5.1500


# Reward Functions

In [23]:
import re

def image_reward_one(completion, *, num_img_tokens=32, **kwargs):
    """Score how well-formed the ``<img>...</img>`` spans of one completion are.

    A span is a *valid image* when its content is exactly ``num_img_tokens``
    image-code tokens of the form ``<img_ddddd>`` (whitespace between tokens is
    optional). Tag matching is case-insensitive throughout and spans may contain
    newlines.

    Compared with the previous version this:
      * counts opening/closing tags with the same case-insensitivity the span
        regex uses (before, ``<IMG> ... </img>`` was reported as unbalanced);
      * lets a span cross line breaks;
      * requires the span content to be image-code tokens — 32 arbitrary words
        used to be rewarded as a valid image;
      * exposes the expected token count instead of hard-coding 32.

    Args:
        completion: generated text to score.
        num_img_tokens: number of image-code tokens a valid image must contain.
        **kwargs: ignored; accepted for reward-function call compatibility.

    Returns:
        0.0 if there is no ``<img>...</img>`` span or every span is invalid,
        0.5 if opening and closing tags are unbalanced or only some spans are valid,
        1.0 if tags are balanced and every span is a valid image.
    """
    flags = re.IGNORECASE | re.DOTALL
    span_pattern = re.compile(r'<img>(.*?)</img>', flags)
    # Exactly `num_img_tokens` code tokens, nothing else but whitespace.
    valid_body = re.compile(r'\s*(?:<img_\d{5}>\s*){%d}' % num_img_tokens, flags)

    spans = span_pattern.findall(completion)
    if not spans:
        # no <img>...</img> pair at all
        return 0.0

    # `<img_01234>` never matches these patterns: the tag name is followed by `_`, not `>`.
    num_open = len(re.findall(r'<img>', completion, re.IGNORECASE))
    num_close = len(re.findall(r'</img>', completion, re.IGNORECASE))
    if num_open != num_close:
        # at least one pair is closed properly, but some tag is dangling
        return 0.5

    num_invalid_images = sum(1 for body in spans if valid_body.fullmatch(body) is None)

    if num_invalid_images == 0:
        # every span is a valid image
        return 1.0

    if num_invalid_images < len(spans):
        # a mix of valid and invalid images
        return 0.5

    # every span is invalid
    return 0.0

def image_reward_one_naive(completion, **kwargs):
    """Return 1.0 if the completion contains any ``<img>...</img>`` span, else 0.0.

    The span content is not inspected. Matching is case-insensitive and does not
    cross newlines. The completion is printed on every call.

    Args:
        completion: generated text to score.
        **kwargs: ignored.
    """
    print(completion)
    has_image_span = re.search(r'<img>(.*?)</img>', completion, re.IGNORECASE) is not None
    return 1.0 if has_image_span else 0.0
    
def image_reward(completions, **kwargs):
    """Reward each completion for containing an image span.

    Scores every completion with ``image_reward_one_naive`` and returns the
    scores in the same order as ``completions``. Extra keyword arguments are
    accepted but not used.
    """
    rewards = []
    for completion in completions:
        rewards.append(image_reward_one_naive(completion))
    return rewards

# Smoke test for `image_reward` (which currently scores with `image_reward_one_naive`):
# the first three completions contain at least one <img>...</img> span, the last two do not.
completions = [
    # one <img>...</img> span
    '<think> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> <img_00680> </img> <think> <answer> the final answer here </answer>',
    # two <img>...</img> spans
    '<think> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> <img_00680> </img> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> </img> <think> <answer> the final answer here </answer>',
    # one <img>...</img> span, shorter than the first example
    '<think> <img> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> <img_00680> </img> <think> <answer> the final answer here </answer>',
    # image codes without the surrounding <img>/</img> tags
    '<think> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <think> <answer> the final answer here </answer>',
    # no image content at all
    '<think> reasoning process with images and text here </think> <answer> the final answer '
]
image_reward(completions=completions)

<think> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> <img_00680> </img> <think> <answer> the final answer here </answer>
<think> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_06429> <img_00680> </img> <img> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <img_00680> <img_04240> <img_06429> <img_00680> <img_02123> <im

[1.0, 1.0, 1.0, 0.0, 0.0]

In [24]:
import re

def format_reward(completions, **kwargs):
    """Reward completions that follow ``<think>...</think> <answer>...</answer>``.

    A completion scores 1.0 when it contains a ``<think>`` block that is followed
    (optionally after whitespace) by an ``<answer>`` block; anything may precede
    or follow that structure. Otherwise it scores 0.0.

    The pattern is compiled with ``re.DOTALL`` so that reasoning or answers
    spanning several lines are recognised; previously any newline inside (or
    before) the blocks made a correctly formatted completion score 0.0.

    Args:
        completions: list of generated strings.
        **kwargs: ignored; accepted for reward-function call compatibility.

    Returns:
        List of floats (1.0 or 0.0), one per completion, in input order.
    """
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [1.0 if pattern.search(completion) else 0.0 for completion in completions]

# Smoke test for `format_reward`: well-formed, missing closing </answer>, and
# well-formed with leading whitespace plus trailing special tokens.
completions = [
    '<think> reasoning process with images and text here </think> <answer> the final answer here </answer>', 
    '<think> reasoning process with images and text here </think> <answer> the final answer ',
    ' <think> The white thing behind the tree is a car. </think><answer> Yes. </answer> </s><s>'
]
format_reward(completions=completions)

[1.0, 0.0, 1.0]

In [25]:
import re

def normalize_answer(answer):
    """Reduce an answer to its lowercase ASCII letters and digits.

    Every character other than ``a-z``, ``A-Z`` and ``0-9`` (whitespace,
    punctuation, non-ASCII characters) is dropped, then the remainder is
    lower-cased.
    """
    kept = [ch for ch in answer if ch.isascii() and ch.isalnum()]
    return ''.join(kept).lower()
    
def accuracy_reward(completions, **kwargs):
    """Reward completions whose ``<answer>`` content equals the ground truth.

    The first ``<answer>...</answer>`` block of each completion (it may span
    lines) is compared with the corresponding entry of ``kwargs["answer"]``
    after both are passed through ``normalize_answer``.

    Args:
        completions: list of generated strings.
        **kwargs: must contain ``answer``, the ground-truth answers aligned
            with ``completions``.

    Returns:
        List with 1.0 for a normalized exact match and 0.0 otherwise
        (including when no ``<answer>`` block is present).
    """
    answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def _score(completion, correct_answer):
        found = answer_pattern.search(completion)
        if found is None:
            return 0.0
        is_correct = normalize_answer(found.group(1)) == normalize_answer(correct_answer)
        return 1.0 if is_correct else 0.0

    return [_score(completion, gold) for completion, gold in zip(completions, kwargs["answer"])]

# Smoke test for `accuracy_reward`: a match despite case/punctuation/whitespace
# differences, a wrong answer, and a completion with no closed <answer> block.
completions = [
    '<think> reasoning process with images and text here </think> <answer>   yes. </answer>', 
    '<think> reasoning process with images and text here </think> <answer> Yes </answer>', 
    '<think> reasoning process with images and text here </think> <answer> the final answer '
]
# Ground-truth answers, aligned with `completions`.
answer = [
    'Yes', 
    'No',
    'No'
]
accuracy_reward(completions=completions, answer=answer)

[1.0, 0.0, 0.0]

# GRPO

In [41]:
from trl import GRPOConfig

# Configure training arguments using GRPOConfig
# Constraint noted by the author: (num_gpus * batch_size) % num_generations == 0
training_args = GRPOConfig(
    output_dir='seed-llama-8b-GRPO-test',
    num_train_epochs=10,
    learning_rate=5e-7,
    remove_unused_columns=False,  # to access the `answer` column in accuracy_reward
    eval_strategy="no",
    weight_decay=5e-2,
    warmup_ratio=0.01,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    torch_empty_cache_steps=1,
    gradient_checkpointing=True,
    dataloader_num_workers=4,
    bf16=True,
    fp16=False,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-5,
    # deepspeed=DEEPSPEED_CONFIG,
    # # Inference Optimization
    # Parameters that control the data preprocessing
    temperature=0.9,
    max_completion_length=512,
    num_generations=2,
    max_prompt_length=256,
    # Parameters related to reporting and saving
    save_strategy="steps",
    save_steps=0.05,  # a float below 1 is a fraction of the total training steps
    logging_strategy='steps',
    logging_steps=1,
    log_level='warning',
    # This is a boolean field. The previous value, the string "no", is truthy, so the
    # NaN/Inf filter stayed enabled — the opposite of what was intended.
    logging_nan_inf_filter=False,
    push_to_hub=False,
    report_to="none"
)

In [42]:
# from trl import GRPOTrainer

# class GRPOTrainerCustom(GRPOTrainer):
#     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
#         if return_outputs:
#             raise ValueError("The GRPOTrainer does not support returning outputs")
#         # Compute the per-token log probabilities for the model

#         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
#         completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
#         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
#         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
#         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

#         per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
#         print('per_token_logps')
#         print(per_token_logps.detach())
#         print()

#         # Compute the KL divergence between the model and the reference model
#         ref_per_token_logps = inputs["ref_per_token_logps"]
#         print('ref_per_token_logps')
#         print(ref_per_token_logps)
#         print()
#         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
#         print('per_token_kl')
#         print(per_token_kl)
#         print()

#         # x - x.detach() allows for preserving gradients from x
#         advantages = inputs["advantages"]
#         print('rewards/image_reward')
#         print(self._metrics["rewards/image_reward"][-1])
#         print()
#         print('advantages')
#         print(advantages)
#         print()
#         per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
#         per_token_loss = -(per_token_loss - self.beta * per_token_kl)
#         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
#         print("-"*20)

#         # Log the metrics
#         completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
#         self._metrics["completion_length"].append(completion_length)

#         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
#         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

#         return loss

In [43]:
from trl import GRPOTrainer

# Build the GRPO trainer on the PEFT-wrapped model, with the patched SEED tokenizer as
# the processing class.
# NOTE(review): only `image_reward` is used; `format_reward` and `accuracy_reward` are
# defined above but not passed here — confirm this is intentional for this test run.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[image_reward],
    args=training_args,
    train_dataset=train_dataset
)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [44]:
# Start GRPO training. The completions printed below come from the `print` inside
# `image_reward_one_naive`; the recorded run was stopped manually (KeyboardInterrupt).
trainer.train()

You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.
  return fn(*args, **kwargs)


 <think> The black object is a clock. The grey thing is most likely the bricks of a building. These objects are in front of the wall, which is a portion of the sky. <answer> Yes.</think></s><s>
 <think> <yes>, the black thing is the item next to the grey thing (desk).</think> </s><s><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>
The reasoning process and answer: <img> <img_03650> <img_02315> <img_07845> <img_04504> <img_00680> <img_01283> <img_01579> <img_00441> <img_02250> <img_02315> <img_02157> <img_01579> <img_01579> <img_02728> <img_07232> <img_07652> <img_02157> <img_00851> <img_02157> <img_00441> <img_01579> <img_01579> <img_02728> <img_02157> <img_05293> <img_07100> <img_00441> <img_00441> <img_02728> <img_07100> <img_01251> <img_00441> </img> A boy's blue and orange sweater is near him. <img> <img_03735> <img_00680> <img_02399> <img_07312> <img_00680> <img_01283> <img_01579> <img_00441> <img_01579> <img_02157> <im

KeyboardInterrupt: 

# Generate

In [18]:
from transformers.generation.configuration_utils import GenerationConfig

# Sampling settings for the manual generation check below: stochastic decoding
# (no beam search), at most 512 new tokens.
# NOTE(review): no random seed is set, so generated outputs are not reproducible.
generation_config = GenerationConfig(
    temperature=0.9,
    num_beams=1,
    max_new_tokens=512,
    top_p=0.7,
    do_sample=True
)

In [19]:
# Prompt used for the manual generation check.
# NOTE(review): this rebinds the global `prompt` that the dataset-building cell formats.
# The text also differs from the training template: it adds "before proceeding with more
# text-based deductions", hard-codes the candidate answers as ['Yes', 'No'], and contains
# the typo "asnwer". Left untouched here because it is model input — confirm whether the
# train/inference mismatch is intended.
prompt="""{bos}{user_token}: I now describe a scene and ask a question about it. First, think about the reasoning process using an interleaved combination of images and text. You should generate an image when necessary to support reasoning, then describe insights from the image before proceeding with more text-based deductions. Finally, provide with the answer. The final asnwer is either one of ['Yes', 'No']. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process with images and text here </think> <answer> the final answer here </answer>.
{question}
{assistant_token}:"""
# Example scene/question pair; both are concatenated into the {question} slot.
scene = "A white hostel with a large tree to its left and a gravel road in front of the wooden entrance."
question = "Is the white thing behind the tree?"

input_tokens = prompt.format(
    bos=tokenizer.bos_token,
    user_token=user_token,
    assistant_token=assistant_token,
    question=f"{scene} {question}"
)

print(input_tokens)

<s>USER: I now describe a scene and ask a question about it. First, think about the reasoning process using an interleaved combination of images and text. You should generate an image when necessary to support reasoning, then describe insights from the image before proceeding with more text-based deductions. Finally, provide with the answer. The final asnwer is either one of ['Yes', 'No']. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process with images and text here </think> <answer> the final answer here </answer>.
A white hostel with a large tree to its left and a gravel road in front of the wooden entrance. Is the white thing behind the tree?
ASSISTANT:


In [46]:
# Move the model to the GPU and print where it ended up.
# NOTE(review): 'cuda' is hard-coded; this cell fails on a machine without a CUDA device.
device = 'cuda'
model.to(device)
print(model.device)

cuda:0


In [51]:
# Sample one completion for the prompt; `generate` returns only the newly generated ids.
generate_ids = generate(tokenizer, input_tokens, generation_config, model)
generate_ids.shape

torch.Size([80])

### Decode

In [52]:
# decode text
# NOTE: `skip_special_tokens=True` has no effect here — the `wrapped_batch_decode` patch
# installed on the tokenizer overwrites that keyword with False, which is why the special
# tokens and image codes still appear in the output below.
output_text = tokenizer.batch_decode(generate_ids.unsqueeze(0), skip_special_tokens=True)
# NOTE(review): `output_text` is a list, while `replace_img_tags` runs a regex over a
# single string; re-enabling the next line would need e.g. `output_text[0]`.
# output_text = replace_img_tags(output_text)
output_text

['<think> Yes. </think> The white thing is a building. The tree is behind it, and the gravel road is in front of it. This means the white thing is behind the tree. <img> <img_06310> <img_00680> <img_00285> <img_00285> <img_00680> <img_01966> <img_00285> <img_01966> <img_06075> <img_02157> <img_02157> <img_00285> <img_06075> <img_02076> <img_06075> <img_01716> <img_02157> <img_01966> <img_02157> <img_01716> <img_02076> <img_00285> <img_06075> <img_02157> <img_01716> <img_01716> <img_01966> <img_01966> <img_06075> <img_06075> <img_01716> <img_01428> </img> </s><s>']

In [1]:
# decode images
boi_list = torch.where(generate_ids == tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0])[0]
eoi_list = torch.where(generate_ids == tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0])[0]

for boi_index, eoi_index in zip(boi_list, eoi_list):
    image_ids = (generate_ids[boi_index+1:eoi_index] - image_id_shift).reshape(1,-1)
    image = tokenizer.decode_image(image_ids)[0]
    image.show()