In [46]:
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List

import torch

import transformers
import tokenizers

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token
import os
import torch.nn as nn
from PIL import Image
from torchvision import transforms
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [47]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# model_path = "/datadrive_a/jihai/LLaVA/scripts/v1_5/checkpoints/llava-v1.5-7b-lora"

# tokenizer, model, image_processor, context_len = load_pretrained_model(
#     model_path=model_path,
#     model_base='/datadrive_a/jihai/azure_storage2/vigstandard_data/jihai/checkpoint/vicuna-7b-v1.5/vicuna-7b-v1.5',
#     model_name=get_model_name_from_path(model_path),
#     # device="cuda:2"
# )

In [48]:
class ModelArguments:
    model_name_or_path = "/datadrive_a/tmp/vicuna-7b-v1.5/vicuna-7b-v1.5"
    version = "v0"
    freeze_backbone = True
    tune_mm_mlp_adapter = False
    vision_tower = 'openai/clip-vit-large-patch14'
    vision_tower_path = '/datadrive_a/jihai/tmp'
    mm_vision_select_layer = -2  # default to the last layer
    pretrain_mm_mlp_adapter = None
    mm_projector_type = "linear"
    mm_use_im_start_end = False
    mm_use_im_patch_token = True
    mm_patch_merge_type = "flat"
    mm_vision_select_feature = "patch"
    vision_tower_gen = 'same'
    vision_tower_gen_path = None
    image_loss='mse'
    pretrained_mm_mlp_adapter=None
    mm_projector_head_output_size=None

# Usage example in Jupyter:
args = ModelArguments()
print(args.model_name_or_path)
print(args.mm_projector_type)

/datadrive_a/tmp/vicuna-7b-v1.5/vicuna-7b-v1.5
linear


In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    args.model_name_or_path,
    padding_side="right"
)
model = LlavaLlamaForCausalLM_ImgGen.from_pretrained(
    args.model_name_or_path,
)

You are using a model of type llama to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.82s/it]


In [13]:


model.get_model().initialize_vision_modules(
    model_args=args
)

openai/clip-vit-large-patch14 is already loaded, `load_model` called again, skipping.


In [52]:

@dataclass
class DataArguments:
    data_path: str = field(default='/datadrive_a/jihai/data/multimodalout/smart_watch_train.json',
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = True
    is_multimodal: bool = True
    multimodal_out: bool=True
    image_folder: Optional[str] = field(default='/datadrive_a/jihai/data/multimodalout/smart_watch_image_train')
    image_aspect_ratio: str = 'square'
    understanding_only: bool=False
    generation_only: bool=True
    image_shape: List[int] = field(
        default_factory=lambda: [3,224,224],
    )
    num_image_token: int = 256 #how many token will one image take up
    dataset: str = 'smartwatch'

data_args = DataArguments()
data_args.image_processor=nn.Identity()
data_args.image_processor_gen=nn.Identity()
print(data_args)

DataArguments(data_path='/datadrive_a/jihai/data/multimodalout/smart_watch_train.json', lazy_preprocess=True, is_multimodal=True, multimodal_out=True, image_folder='/datadrive_a/jihai/data/multimodalout/smart_watch_image_train', image_aspect_ratio='square', understanding_only=False, generation_only=True, image_shape=[3, 224, 224], num_image_token=256, dataset='smartwatch')


In [53]:

def preprocess_v1_with_gen(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    num_image_token: int = 6 #how many token will one image take up
) -> Dict:
    """Preprocess conversations for generation.
    original preprocess_v1() will return:
    {'input_ids': tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 29871,  -200, 29871,    13,  5618,   931,
           338,   372,   297,   278,  1967, 29973,   319,  1799,  9047, 13566,
         29901,   739,   338, 29871, 29896, 29896, 29901, 29906, 29945, 29901,
         29941, 29900,     2]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,   739,   338, 29871, 29896, 29896, 29901, 29906, 29945, 29901,
         29941, 29900,     2]])}
    when '-200' is in the target (answer) for generation, need to:
     1. replace it with the id of '<image>' as the indicator for generation.
     2. add num_of_image_tokens of '-100' after the id of '<image>' as image token holders
     3. add img_begin index to data_dict such that hidden_state[:,img_begin:img_begin+num_of_image_tokens] will be the image tokens
    
    
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        print('has image')
        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
        print(input_ids)
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    #-----------add preprocess for generation:
    img_token_start=None #mark whether there is image to generate, and the generated image position
    if DEFAULT_IMAGE_TOKEN in sources[0][1]['value']:
        #sp token '<image>'
        gen_indicator=torch.LongTensor([[529,  3027, 29958]])
        # 查找 -200 的位置
        positions = (input_ids == -200).nonzero(as_tuple=True)
        # 对 input_ids 进行处理
        for pos in zip(*positions):    
            # 在其后面插入 num_image_token 个 -100
            insert_position = pos[1]  # delete -200, insert gen_indicator and num_image_token -300. -300 need to be replaced by actual image token.
            img_token_start=insert_position+gen_indicator.size(1)
            input_ids = torch.cat((input_ids[:, :insert_position], 
                                    gen_indicator,
                                    torch.full((input_ids.size(0), num_image_token), -300, dtype=torch.long), 
                                    input_ids[:, insert_position+1:]), dim=1)

        # 对 labels 进行处理（相同的处理方式）
        for pos in zip(*positions):
            insert_position = pos[1]  # delete -200, insert gen_indicator and num_image_token -100
            targets = torch.cat((targets[:, :insert_position], 
                                gen_indicator,
                                torch.full((targets.size(0), num_image_token), -100, dtype=torch.long), 
                                targets[:, insert_position+1:]), dim=1)

    return dict(
        img_token_start=img_token_start,
        input_ids=input_ids,
        labels=targets,
    )

class LazySupervisedDataset_SmartWatch(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):
        super(LazySupervisedDataset_SmartWatch, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))
        if data_args.generation_only:
            list_data_dict = [e for e in list_data_dict if e['task']=="generation"]
        if data_args.understanding_only:
            list_data_dict = [e for e in list_data_dict if (e['task']=="vqa" or e['task']=="caption")]

        #rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            if sources[0]['task']=='generation':
                processor = self.data_args.image_processor_gen
            else:
                processor=self.data_args.image_processor
            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result
                image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            transform = transforms.ToTensor()
            image = transform(image)
      
            #     image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            # else:
            #     image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]


            sources = copy.deepcopy([e["conversations"] for e in sources])
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        #bypass preprocess(), use default for llava1.5: preprocess_v1
        #print(('image' in self.list_data_dict[i]))
        data_dict = preprocess_v1_with_gen(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]),
            num_image_token=self.data_args.num_image_token
        )
        
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0],
                             img_token_start=data_dict["img_token_start"])
            #print(f"data_dict[img_token_start]:{data_dict['img_token_start']}")

        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
            #print(image.shape)
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            # this will skip in my_prepare_inputs_labels_for_multimodal, not input to the model
            data_dict['image'] = torch.zeros(self.data_args.image_shape)
        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer
    data_args: DataArguments

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, img_token_start = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels", "img_token_start"))

        #generation padding sample
        input_ids_pad=torch.ones((5+self.data_args.num_image_token,), dtype=torch.long)
        labels_pad=torch.ones((5+self.data_args.num_image_token,), dtype=torch.long)
        input_ids_pad[5:]=-300
        labels_pad[5:]=-100
        input_ids.append(input_ids_pad)
        labels.append(labels_pad)

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        
        img_token_start_pad=torch.ones((1,), dtype=torch.long)*5
        # input_ids = torch.cat((input_ids, input_ids_pad), dim=0)
        # labels = torch.cat((labels, labels_pad), dim=0)
        img_token_start.append(img_token_start_pad)
        image_pad=torch.zeros(self.data_args.image_shape, dtype=torch.float32,device=input_ids.device)

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            img_token_start=img_token_start
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
            batch['images']= torch.cat((batch['images'], image_pad.unsqueeze(0)), dim=0)
        #print(batch['images'].shape)
        return batch

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    if data_args.dataset=='smartwatch':
        train_dataset = LazySupervisedDataset_SmartWatch(tokenizer=tokenizer,
                                    data_path=data_args.data_path,
                                    data_args=data_args)
    # elif data_args.dataset=='segment_digit':
    #     train_dataset = LazySupervisedDataset_ImgGen(tokenizer=tokenizer,
    #                                 data_path=data_args.data_path,
    #                                 data_args=data_args)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer,data_args=data_args)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)

In [54]:
from torch.utils.data import DataLoader

# 创建数据模块
data_module = make_supervised_data_module(tokenizer, data_args)

# 提取 train_dataset 和 data_collator
train_dataset = data_module['train_dataset']
data_collator = data_module['data_collator']

# 创建 DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=data_collator)

# 获取一个batch的数据
data_iter=iter(train_dataloader)
batch = next(data_iter)

# 打印 batch 内容以进行调试
print(batch)

has image
tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901,  1815,   366,  5706,   263,  1967,   310,
           263, 15040,  6505, 29892,   373,   607,   278,  1857,   931,   338,
         29871, 29900, 29953, 29901, 29900, 29953, 29901, 29945, 29906, 29892,
           411, 11015,  1361,   297, 18350, 29899,  1182,   523, 29899,  1182,
           523, 29892, 23425, 11785, 29899,  4366, 29899,  1182,   523,  1473,
          1361, 29892,   411,  6501, 29899,  1182,   523, 29899, 26031,  8296,
          1426,  2955,  3239, 29892, 16384, 29871, 29945, 29941, 29995,  3081,
          3233, 29889,   319,  1799,  9047, 13566, 29901,   450,  1967,   338,
          5759, 29889, 29871,    13,  -200,     2]])
{'input_ids': tensor([[    1,   319, 13563,  1546,   263, 12758,  14

In [55]:
#batch = next(data_iter)
print(batch)
for s in batch['img_token_start']:
    if s is not None:
        s.cuda()
print(batch['img_token_start'])
print(batch['images'].dtype)

{'input_ids': tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901,  1815,   366,  5706,   263,  1967,   310,
           263, 15040,  6505, 29892,   373,   607,   278,  1857,   931,   338,
         29871, 29900, 29953, 29901, 29900, 29953, 29901, 29945, 29906, 29892,
           411, 11015,  1361,   297, 18350, 29899,  1182,   523, 29899,  1182,
           523, 29892, 23425, 11785, 29899,  4366, 29899,  1182,   523,  1473,
          1361, 29892,   411,  6501, 29899,  1182,   523, 29899, 26031,  8296,
          1426,  2955,  3239, 29892, 16384, 29871, 29945, 29941, 29995,  3081,
          3233, 29889,   319,  1799,  9047, 13566, 29901,   450,  1967,   338,
          5759, 29889, 29871,    13,   529,  3027, 29958,  -300,  -300,  -300,
          -300,  -300,  -300,  -300,  

In [7]:
batch['images']=torch.zeros(batch['images'].shape)

In [10]:
print(model.get_model().vision_tower.A)

tensor([[-0.8065, -0.5087],
        [-0.1101,  0.3153]])


In [18]:
model.get_model().cuda()
(
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    inputs_embeds,
    labels
) = model.my_prepare_inputs_labels_for_multimodal(
    input_ids=batch['input_ids'].cuda(),
    position_ids=None,
    attention_mask=batch['attention_mask'].cuda(),
    past_key_values=None,
    labels=batch['labels'].cuda(),
    images=batch['images'].cuda(),
    img_start_token=batch['img_token_start'],
    image_sizes=None
)

In [21]:
model=model.float()
model.cuda()

LlavaLlamaForCausalLM_ImgGen(
  (model): LlavaLlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Lla

In [22]:
import torch.nn.functional as F
output_attentions=None
output_hidden_states=None
output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
)
return_dict=None
use_cache=None
return_dict = return_dict if return_dict is not None else model.config.use_return_dict

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = model.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
)

hidden_states = outputs[0]
if model.config.pretraining_tp > 1:
    lm_head_slices = model.lm_head.weight.split(model.vocab_size // model.config.pretraining_tp, dim=0)
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(model.config.pretraining_tp)]
    logits = torch.cat(logits, dim=-1)
else:
    logits = model.lm_head(hidden_states)

logits = logits.float()

loss = None



In [23]:
img_token_start=batch['img_token_start']

In [27]:
from torch.nn import CrossEntropyLoss, MSELoss
if labels is not None:
    # 将 img_token_start 转为张量，如果是 None 则用 -1 占位
    start_positions = torch.tensor([pos if pos is not None else -1 for pos in img_token_start], dtype=torch.long, device=inputs_embeds.device)

    print(f"start_positions: {start_positions}")
    # torch.save(start_positions, "/datadrive_a/jihai/tmp/start_positions.pt")
    
    # 创建一个掩码，标记有效的起始位置
    valid_mask = start_positions >= 0

    # 筛选出有效的 batch 索引和对应的开始位置
    batch_indices = valid_mask.nonzero(as_tuple=True)[0]
    start_indices = start_positions[valid_mask]

    # 使用高级索引从 inputs_embeds 中抽取 img_token_start 到 img_token_start + seq_len 的部分
    img_row_indices = batch_indices.unsqueeze(1)
    img_col_indices = start_indices.unsqueeze(1) + torch.arange(6,device=start_indices.device).unsqueeze(0)
    print(f"img_col_indices: {img_col_indices}")
    print(f"img_row_indices: {img_row_indices}")
    #img_embed_targets = inputs_embeds[img_row_indices, img_col_indices].detach()
    images=batch['images'].cuda()
    print(images.shape)
    img_targets=images[img_row_indices].squeeze()
    print(f"img_targets: {img_targets}")
    print(img_targets.shape)
    
    #-------only for debug
    hidden_states=hidden_states.cuda(0)
    # print(hidden_states)
    # print(f"hidden_states_shape: {hidden_states.shape}")
    # print(f"labels_shape:{labels.shape}")
    # print(f"inputs_embeds_shape:{inputs_embeds.shape}")

    #get img embedding
    img_col_indices-=1
    img_embed_outputs = hidden_states[img_row_indices, img_col_indices]
    print(img_embed_outputs.shape)

    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    shift_logits = shift_logits.view(-1, model.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
    loss_gen_fn = MSELoss()
    loss_gen=loss_gen_fn(model.im_head(img_embed_outputs), img_targets)
    print(f"img_embed_outputs: {img_embed_outputs}")
    print(f"loss_img: {loss_gen}")
    if torch.isnan(loss_gen): 
        loss_gen = torch.tensor(0.0, device=loss.device)
    
    print(f"loss_language: {loss}")

# if not return_dict:
#     output = (logits,) + outputs[1:]
#     return (loss,) + output if loss is not None else output

start_positions: tensor([82, 81, -1], device='cuda:0')
img_col_indices: tensor([[82, 83, 84, 85, 86, 87],
        [81, 82, 83, 84, 85, 86]], device='cuda:0')
img_row_indices: tensor([[0],
        [1]], device='cuda:0')
torch.Size([3, 6, 7])
img_targets: tensor([[[0.6684, 0.6684, 0.0000, 0.6684, 0.6684, 0.0000, 0.6684],
         [0.6684, 0.6684, 0.6684, 0.6684, 0.0000, 0.0000, 0.6684],
         [0.6684, 0.6684, 0.0000, 0.6684, 0.6684, 0.0000, 0.6684],
         [0.6684, 0.6684, 0.6684, 0.6684, 0.0000, 0.6684, 0.6684],
         [0.0000, 0.6684, 0.6684, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.6684, 0.6684, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.6114, 0.6114, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6114, 0.6114, 0.6114, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.6114, 0.6114, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6114, 0.6114, 0.6114, 0.6114, 0.0000, 0.0000, 0.6114],
         [0.6114, 0.0000, 0.6114, 0.6114, 0.0000, 0.6114, 0.6114],
       

In [26]:

print(os.environ.get("CUDA_LAUNCH_BLOCKING"))  # 输出 "1" 表示已成功设置
print(input_ids)
print(inputs_embeds.dtype)
print(batch['images'].dtype)
print(batch['labels'].dtype)
print(labels.shape)
print(labels)

1
None
torch.float32
torch.float32
torch.int64
torch.Size([1, 81])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           450,  1967,  3697,   263,  2479,  5183, 29871, 29896, 29941, 29901,
         29900, 29946, 29901, 29896, 29900,   297,  3708,   552,  2927, 29889,
             2]], device='cuda:0')


In [28]:
print(model.device)
model=model.float()

model(input_ids=batch['input_ids'].cuda(), attention_mask=batch['attention_mask'].cuda(), labels=batch['labels'].cuda(), img_token_start=batch['img_token_start'],images=batch['images'].cuda())

cuda:0


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:3)

In [7]:
batch = next(data_iter)

# 打印 batch 内容以进行调试
print(batch)

{'input_ids': tensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901,  3251,   403,   385,  1967,  6445, 29871,
         29900, 29896, 29901, 29945, 29896, 29901, 29906, 29941,   773,  3708,
           552, 13340, 29889,   319,  1799,  9047, 13566, 29901,   530,  1967,
           310, 29871, 29900, 29896, 29901, 29945, 29896, 29901, 29906, 29941,
           411,  3694,   297,  3708,   552,   338,  4318, 29889, 29871,    13,
           529,  3027, 29958,  -300,  -300,  -300,  -300,  -300,  -300,     2]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -

In [25]:
print(batch['images'].sum())

tensor(0.)


In [None]:
conversations=['<image>']
print(tokenizer(
        conversations,
        return_tensors="pt",
        padding="longest",
        max_length=tokenizer.model_max_length,
        truncation=True,
        ))
# reversed=tokenizer.batch_decode(data_dict['input_ids'][:,:35], skip_special_tokens=True)
# print(reversed)
# reversed=tokenizer.batch_decode(data_dict['input_ids'][:,36:], skip_special_tokens=True)
# print(reversed)
# reversed=tokenizer.batch_decode(data_dict['labels'][:,-12:], skip_special_tokens=True)
# print(reversed)
reversed=tokenizer.batch_decode(torch.LongTensor([[   529,  3027, 29958]]), skip_special_tokens=False)
print(reversed)

{'input_ids': tensor([[    1,   529,  3027, 29958]]), 'attention_mask': tensor([[1, 1, 1, 1]])}


In [16]:
image_features=torch.zeros(1,6,4096)+1
input_ids=torch.LongTensor([[    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
         29889,  3148,  1001, 29901,  3529,  1510,   278,  1967,   310, 29871,
         29896, 29896, 29901, 29906, 29945, 29901, 29941, 29900, 29889,   319,
          1799,  9047, 13566, 29901,  2266,   338,   278,  1967, 29889,   529,
          3027, 29958,  -300,  -300,  -300,  -300,  -300,  -300,     2]])
labels=torch.LongTensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  2266,   338,   278,  1967, 29889,   529,
          3027, 29958,  -100,  -100,  -100,  -100,  -100,  -100,     2]])
attention_mask=input_ids.ne(tokenizer.pad_token_id)
print(tokenizer.pad_token_id)
print(attention_mask)
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
print(position_ids)

0
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True]])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68])


In [17]:
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
print(input_ids)
print(labels)

[tensor([    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
        21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
          322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
        29889,  3148,  1001, 29901,  3529,  1510,   278,  1967,   310, 29871,
        29896, 29896, 29901, 29906, 29945, 29901, 29941, 29900, 29889,   319,
         1799,  9047, 13566, 29901,  2266,   338,   278,  1967, 29889,   529,
         3027, 29958,  -300,  -300,  -300,  -300,  -300,  -300,     2])]
[tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  2266,   338,   278,  1967, 298

In [18]:
new_input_embeds = []
new_labels = []
cur_image_idx = 0
img_start_token=torch.LongTensor([62])
for batch_idx, cur_input_ids in enumerate(input_ids):
    print(cur_input_ids)
    num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
    print(num_images)
    if num_images == 0:
        cur_image_features = image_features[cur_image_idx]
        if img_start_token[batch_idx] is not None: #for generation process
            cur_input_ids[cur_input_ids == -300] = 0 #replace the image token with 0 so that embedding can process
            print(cur_input_ids)
            cur_input_embeds = torch.randn(cur_input_ids.size(0), 4096)
            print(cur_input_embeds[:,0])
            cur_input_embeds[img_start_token[batch_idx]:img_start_token[batch_idx]+cur_image_features.shape[0]] =cur_image_features
            print(cur_input_embeds[:,0])
        else:
            cur_input_embeds_1 = model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
        new_input_embeds.append(cur_input_embeds)
        new_labels.append(labels[batch_idx])
        cur_image_idx += 1

tensor([    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
        21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
          322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
        29889,  3148,  1001, 29901,  3529,  1510,   278,  1967,   310, 29871,
        29896, 29896, 29901, 29906, 29945, 29901, 29941, 29900, 29889,   319,
         1799,  9047, 13566, 29901,  2266,   338,   278,  1967, 29889,   529,
         3027, 29958,  -300,  -300,  -300,  -300,  -300,  -300,     2])
tensor(0)
tensor([    1,   319, 13563,  1546,   263, 12758,  1404,   322,   385, 23116,
        21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
          322,  1248,   568,  6089,   304,   278,  1404, 29915, 29879,  5155,
        29889,  3148,  1001, 29901,  3529,  1510,   278,  1967,   310, 29871,
        29896, 29896, 29901, 29906, 29945, 29901, 29941, 29900, 29889,   319,
         1799,  9047, 13566, 29901,  2266,   338,   278,  19