In [3]:
"""
!pip install -q condacolab
import condacolab
condacolab.install()

!conda create -n llava python=3.10 -y
!conda run -n llava pip install torch==2.0.1 torchvision==0.15.2
!conda run -n llava pip install transformers==4.31.0
!conda run -n llava pip install tokenizers==0.13.3
!conda run -n llava pip install numpy==1.26.0
!conda run -n llava pip install accelerate==0.21.0

# Sccessfully created conda environment named llava, but the default environment is still python 3.11 without torch module, the default Colab environment
!python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"

# We have to add !conda run -n llava to excute the script in llava environment. This is a bit uncomfortable.
!conda run -n llava python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"
"""

'\n!pip install -q condacolab\nimport condacolab\ncondacolab.install()\n\n!conda create -n llava python=3.10 -y\n!conda run -n llava pip install torch==2.0.1 torchvision==0.15.2\n!conda run -n llava pip install transformers==4.31.0\n!conda run -n llava pip install tokenizers==0.13.3\n!conda run -n llava pip install numpy==1.26.0\n!conda run -n llava pip install accelerate==0.21.0\n\n# Sccessfully created conda environment named llava, but the default environment is still python 3.11 without torch module, the default Colab environment\n!python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"\n\n# We have to add !conda run -n llava to excute the script in llava environment. This is a bit uncomfortable.\n!conda run -n llava python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"\n'

In [None]:
import builtins
import io
import contextlib

def normalize_key(key: str) -> str:
    """printキーを実行時と同じ形に揃える"""
    return key.encode("utf-8").decode("unicode_escape").strip()

def traced_print_factory(original_print):
    store = {}

    def traced_print(*args, **kwargs):
        if args and isinstance(args[0], str) and len(args) > 1:
            key = normalize_key(args[0])
            buf = io.StringIO()
            tmp_kwargs = dict(kwargs)
            tmp_kwargs["file"] = buf
            original_print(*args[1:], **tmp_kwargs)
            value = buf.getvalue().rstrip()
            store[key] = value
        # 画面出力は常に行う
        original_print(*args, **kwargs)

    return traced_print, store


def run_and_capture(func, *args, **kwargs):
    buffer = io.StringIO()
    original_print = builtins.print
    traced_print, store = traced_print_factory(original_print)

    with contextlib.redirect_stdout(buffer):
        builtins.print = traced_print
        try:
            func(*args, **kwargs)
        finally:
            builtins.print = original_print

    logs = buffer.getvalue()
    return logs, store

def embed_print_outputs(code: str, mapping: dict[str, str]) -> str:
    """元のコードに print 出力を埋め込む"""
    new_lines = []
    for line in code.splitlines():
        stripped = line.strip()
        if stripped.startswith("print(") and stripped[6:].startswith('"'):
            if "," not in stripped:
                new_lines.append(line)
                continue

            try:
                raw_key = stripped.split('"', 2)[1]
                key = normalize_key(raw_key)
            except IndexError:
                key = None

            indent = line[:len(line) - len(line.lstrip())]  # インデント保持

            if key and key in mapping:
                value = mapping[key]
                if len(value) <= 40:
                    new_lines.append(f"{line}  # {value}")
                else:
                    new_lines.append(line)
                    new_lines.append(f'{indent}"""')
                    for vline in value.splitlines():
                        new_lines.append(f"{indent}{vline}")
                    new_lines.append(f'{indent}"""')
            else:
                new_lines.append(f"{line}  # not found")
        else:
            new_lines.append(line)
    return "\n".join(new_lines)

In [4]:
!python -c "import torch, transformers; print(torch.__version__, transformers.__version__)"

2.0.1+cu117 4.31.0


In [5]:
import transformers
import torch
from dataclasses import dataclass, field
from typing import Optional

In [6]:
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    version: Optional[str] = field(default="v0")
    freeze_backbone: bool = field(default=False)
    tune_mm_mlp_adapter: bool = field(default=False)
    vision_tower: Optional[str] = field(default=None) # default to None
    mm_vision_select_layer: Optional[int] = field(default=-1)   # default to the last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
    mm_projector_type: Optional[str] = field(default='linear')
    mm_use_im_start_end: bool = field(default=False)
    mm_use_im_patch_token: bool = field(default=True)
    mm_patch_merge_type: Optional[str] = field(default='flat')
    mm_vision_select_feature: Optional[str] = field(default="patch")

@dataclass
class DataArguments:
    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    image_aspect_ratio: str = 'square'

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    freeze_mm_mlp_adapter: bool = field(default=False)
    mpt_attn_impl: Optional[str] = field(default="triton")
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)


  warn("The installed version of bitsandbytes was compiled without GPU support. "


/opt/venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32
[2025-09-23 07:57:55,357] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  import pkg_resources
  import pkg_resources


In [7]:
from transformers import HfArgumentParser

args_dict = {
    #"deepspeed": "./scripts/zero2.json",
    "model_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "version": "plain",
    "data_path": "/workspaces/LLaVA/blip_laion_cc_sbu_1.json",
    "image_folder": "/workspaces/LLaVA/images/",
    "vision_tower": "openai/clip-vit-large-patch14-336",
    "mm_projector_type": "mlp2x_gelu",
    "tune_mm_mlp_adapter": True,
    "mm_vision_select_layer": -2,
    "mm_use_im_start_end": False,
    "mm_use_im_patch_token": False,
    "bf16": True,
    "output_dir": "./checkpoints/llava-TinyLlama-1.1B-Chat-v1.0",

    # TrainingArguments 相当
    "num_train_epochs": 1,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "evaluation_strategy": "no",
    "save_strategy": "steps",
    "save_steps": 1,
    "save_total_limit": 1,
    "learning_rate": 1e-3,
    "weight_decay": 0.0, # I don't know why 0.0
    "warmup_ratio": 0.03,
    "lr_scheduler_type": "cosine",
    "logging_steps": 1,
    "tf32": False, # switched from True for TinyLlama
    "model_max_length": 2048,
    "gradient_checkpointing": True,
    "dataloader_num_workers": 2,
    "lazy_preprocess": True,
    "report_to": "none",
}

In [8]:
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_dict(args_dict)
print("model_args\n", model_args)
print("data_args\n", data_args)
print("training_args\n", training_args)

model_args
 ModelArguments(model_name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
data_args
 DataArguments(data_path='/workspaces/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=False, image_folder='/workspaces/LLaVA/images/', image_aspect_ratio='square')
training_args
 TrainingArguments(
_n_gpu=0,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=True,
bf16_full_eval=False,
bits=16,
cache_dir=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=2,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=No

In [9]:
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

In [10]:
local_rank = training_args.local_rank
print("local_rank\n", local_rank)
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
print("compute_dtype\n", compute_dtype)
bnb_model_from_pretrained_args = {} # bitsandbytes
print("bnb_model_from_pretrained_args\n", bnb_model_from_pretrained_args)

local_rank
 0
compute_dtype
 torch.bfloat16
bnb_model_from_pretrained_args
 {}


In [11]:
"""
from transformers import CLIPModel

normal_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336")
print("normal_clip_model\n", normal_clip_model)
"""

'\nfrom transformers import CLIPModel\n\nnormal_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336")\nprint("normal_clip_model\n", normal_clip_model)\n'

In [12]:
"""
from transformers import CLIPImageProcessor

image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
print("image_processor\n", image_processor)
"""

'\nfrom transformers import CLIPImageProcessor\n\nimage_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")\nprint("image_processor\n", image_processor)\n'

In [13]:
"""
from PIL import Image
import requests
from io import BytesIO
from transformers import CLIPImageProcessor
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt

# 画像 URL
url = "https://llava-vl.github.io/static/images/view.jpg"

# 画像を取得
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")

# 前処理
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
processed = processor(img, return_tensors="pt")

# tensor: shape (1, 3, H, W), 値は正規化済み
pix = processed["pixel_values"][0]

# 正規化を戻す
mean = torch.tensor(processor.image_mean).unsqueeze(1).unsqueeze(2)
std = torch.tensor(processor.image_std).unsqueeze(1).unsqueeze(2)
pix = pix * std + mean

# 0-1 範囲にクリップ
pix = pix.clamp(0.0, 1.0)

# 画像生成
to_pil = T.ToPILImage()
img_processed = to_pil(pix)

# ==== Colab 上で可視化 ====
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

axes[0].imshow(img)
axes[0].set_title("Original")
axes[0].axis("off")

axes[1].imshow(img_processed)
axes[1].set_title("Processed (normalized etc.)")
axes[1].axis("off")

plt.show()
"""

'\nfrom PIL import Image\nimport requests\nfrom io import BytesIO\nfrom transformers import CLIPImageProcessor\nimport torch\nimport torchvision.transforms as T\nimport matplotlib.pyplot as plt\n\n# 画像 URL\nurl = "https://llava-vl.github.io/static/images/view.jpg"\n\n# 画像を取得\nresponse = requests.get(url)\nimg = Image.open(BytesIO(response.content)).convert("RGB")\n\n# 前処理\nprocessor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")\nprocessed = processor(img, return_tensors="pt")\n\n# tensor: shape (1, 3, H, W), 値は正規化済み\npix = processed["pixel_values"][0]\n\n# 正規化を戻す\nmean = torch.tensor(processor.image_mean).unsqueeze(1).unsqueeze(2)\nstd = torch.tensor(processor.image_std).unsqueeze(1).unsqueeze(2)\npix = pix * std + mean\n\n# 0-1 範囲にクリップ\npix = pix.clamp(0.0, 1.0)\n\n# 画像生成\nto_pil = T.ToPILImage()\nimg_processed = to_pil(pix)\n\n# ==== Colab 上で可視化 ====\nfig, axes = plt.subplots(1, 2, figsize=(12, 6))\n\naxes[0].imshow(img)\naxes[0].set_title("Original")\naxe

In [14]:
"""
from transformers import CLIPVisionModel

clip_vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")
print("clip_vision_tower\n", clip_vision_tower)
"""

'\nfrom transformers import CLIPVisionModel\n\nclip_vision_tower = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14-336")\nprint("clip_vision_tower\n", clip_vision_tower)\n'

In [15]:
"""
config_clip_vision_tower = clip_vision_tower.config
print("config_clip_vision_tower\n", config_clip_vision_tower)
"""

'\nconfig_clip_vision_tower = clip_vision_tower.config\nprint("config_clip_vision_tower\n", config_clip_vision_tower)\n'

In [16]:
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
import torch.nn as nn
# __init__
# load_model

# result = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
class CLIPVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        # result = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
        print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
        print("def CLIPVisionTower.__init__(self, vision_tower, args, delay_load=False)")
        print("self\n", type(self))
        print("vision_tower\n", vision_tower) # openai/clip-vit-large-patch14-336
        print("args\n", args) # ModelArguments(model_name_or_path='lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
        print("delay_load\n", delay_load) # False
        super().__init__()

        self.is_loaded = False

        print("self.is_loaded\n", self.is_loaded) # False

        self.vision_tower_name = vision_tower
        print("self.vision_tower_name\n", self.vision_tower_name) # openai/clip-vit-large-patch14-336
        self.select_layer = args.mm_vision_select_layer
        print("self.select_layer\n", self.select_layer) # -2
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        print("self.select_feature\n", self.select_feature) # patch

        print(f"【COND】 delay_load={delay_load}")
        if not delay_load:
            # 【ENTER】
            print("【ENTER】if not delay_load:")
            self.load_model()
        elif getattr(args, 'unfreeze_mm_vision_tower', False):
            print("【ENTER】elif getattr(args, 'unfreeze_mm_vision_tower', False):")
            self.load_model()
            print("【EXIT】elif getattr(args, 'unfreeze_mm_vision_tower', False):")
        else:
            print("【ENTER】else of if not delay_load/elif getattr(args, 'unfreeze_mm_vision_tower', False):")
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
            print("self.cfg_only\n", self.cfg_only)
            print("【EXIT】else of if not delay_load/elif getattr(args, 'unfreeze_mm_vision_tower', False):")


    def load_model(self):

        print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
        print("def CLIPVisionTower.load_model(self)")
        print("self\n", type(self))
        print("self.vision_tower_name\n", self.vision_tower_name) # openai/clip-vit-large-patch14-336
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        print("self.image_processor\n", self.image_processor)
        """
        CLIPImageProcessor {
        "crop_size": {
            "height": 336,
            "width": 336
        },
        "do_center_crop": true,
        "do_convert_rgb": true,
        "do_normalize": true,
        "do_rescale": true,
        "do_resize": true,
        "feature_extractor_type": "CLIPFeatureExtractor",
        "image_mean": [
            0.48145466,
            0.4578275,
            0.40821073
        ],
        "image_processor_type": "CLIPImageProcessor",
        "image_std": [
            0.26862954,
            0.26130258,
            0.27577711
        ],
        "resample": 3,
        "rescale_factor": 0.00392156862745098,
        "size": {
            "shortest_edge": 336
        }
        }
        """
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        print("self.vision_tower\n", self.vision_tower)
        """
        CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
            (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
            )
            (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (encoder): CLIPEncoder(
            (layers): ModuleList(
                (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                    (activation_fn): QuickGELUActivation()
                    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                )
            )
            )
            (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
        )
        """
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True
        print("self.is_loaded\n", self.is_loaded) # True

In [17]:
import os

def build_vision_tower(vision_tower_cfg, **kwargs):
    # vision_tower = build_vision_tower(model_args)
    print("current file path", "llava/llava/model/multimodal_encoder/builder.py")
    print("def build_vision_tower(vision_tower_cfg, **kwargs)")
    print("vision_tower_cfg\n", vision_tower_cfg) # ModelArguments(model_name_or_path='lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
    print("kwargs\n", kwargs) # {}
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    print("vision_tower from vision_tower_cfg\n", vision_tower) # openai/clip-vit-large-patch14-336
    # ローカルに存在しない場合はFalse。存在する場合の例: /ubuntu/home/user/model/openai/clip-vit-large-patch14-336
    is_absolute_path_exists = os.path.exists(vision_tower)
    print("is_absolute_path_exists\n", is_absolute_path_exists) # False
    print(f"【COND】 is_absolute_path_exists={is_absolute_path_exists} vision_tower={vision_tower}") # is_absolute_path_exists=False vision_tower=openai/clip-vit-large-patch14-336
    if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
        # 【ENTER】
        print("【ENTER】if is_absolute_path_exists or vision_tower.startswith('openai') or vision_tower.startswith('laion') or 'ShareGPT4V' in vision_tower:")
        result = CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
        print("result (return)\n", result) # CLIPVisionTowerクラスのselfに登録されたモジュール一覧を出力する
        print("【EXIT】if is_absolute_path_exists or vision_tower.startswith('openai') or vision_tower.startswith('laion') or 'ShareGPT4V' in vision_tower:")
        return result

    print("print(risk): print(vision_tower) disabled for safety")
    raise ValueError(f'Unknown vision tower: {vision_tower}')

In [18]:
"""
build_vision_tower(model_args)
"""

'\nbuild_vision_tower(model_args)\n'

In [19]:
import re

def build_vision_projector(config, delay_load=False, **kwargs):

    print("current file path", "llava/llava/model/multimodal_projector/builder.py")
    print("def build_vision_projector(config, delay_load=False, **kwargs)")
    print("config\n", config)
    """
    config
    LlavaConfig {
    "_name_or_path": "lmsys/vicuna-7b-v1.5",
    "architectures": [
        "LlamaForCausalLM"
    ],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 11008,
    "max_position_embeddings": 4096,
    "mm_hidden_size": 1024,
    "mm_patch_merge_type": "flat",
    "mm_projector_type": "mlp2x_gelu",
    "mm_vision_select_feature": "patch",
    "mm_vision_select_layer": -2,
    "mm_vision_tower": "openai/clip-vit-large-patch14-336",
    "model_type": "llava_llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 32,
    "pad_token_id": 0,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": null,
    "tie_word_embeddings": false,
    "torch_dtype": "float16",
    "transformers_version": "4.31.0",
    "use_cache": false,
    "use_mm_proj": true,
    "vocab_size": 32000
    }
    """
    print("delay_load\n", delay_load) # False
    print("kwargs\n", kwargs) # {}
    projector_type = getattr(config, 'mm_projector_type', 'linear')
    print("projector_type from config\n", projector_type) # mlp2x_gelu

    print("【COND】 projector_type\n", projector_type) # mlp2x_gelu
    if projector_type == 'linear':
      pass

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    print("【COND】mlp_gelu_match\n", mlp_gelu_match)
    if mlp_gelu_match:
        #【ENTER】if mlp_gelu_match:
        print("【ENTER】if mlp_gelu_match:")
        mlp_depth = int(mlp_gelu_match.group(1))
        print("mlp_depth from mlp_gelu_match.group(1)\n", mlp_depth)
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        print("modules after first Linear\n", modules)
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        print("modules before Sequential\n", modules)
        result = nn.Sequential(*modules) # * はリストをアンパックして引数に展開する
        print("result (return)\n", result)
        """
        Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=4096, bias=True)
        )
        """
        print("【EXIT】if mlp_gelu_match:")
        return result

    print("【COND】projector_type\n", projector_type)
    if projector_type == 'identity':
      pass

    print("print(risk): print(projector_type) disabled for safety")
    raise ValueError(f'Unknown projector type: {projector_type}')

In [20]:
# LlavaMetaModel
# __init__
# get_vision_tower
# initialize_vision_modules
# unpad_image

class LlavaMetaModel:

    def __init__(self, config):

        print("current file path", "llava/model/llava_arch.py")
        print("LlavaMetaModel.__init__(self, config)")
        print("config\n", config)
        # LlamaModelの__init_を呼び出す
        super(LlavaMetaModel, self).__init__(config)

        print(f"【COND】 mm_vision_tower={hasattr(config, 'mm_vision_tower')}")
        if hasattr(config, "mm_vision_tower"):
            print("【ENTER】if hasattr(config, 'mm_vision_tower'):")
            self.vision_tower = build_vision_tower(config, delay_load=True)
            print("self.vision_tower\n", self.vision_tower)
            self.mm_projector = build_vision_projector(config)
            print("self.mm_projector\n", self.mm_projector)

            print("self.config.mm_patch_merge_type\n", self.config.mm_patch_merge_type)
            print(f"【COND】 unpad_in_mm_patch_merge_type={'unpad' in getattr(config, 'mm_patch_merge_type', '')}")
            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
              pass

In [21]:
from transformers import LlamaConfig, LlamaModel

class LlavaConfig(LlamaConfig):
    model_type = "llava_llama"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):

        print("current file path", "llava/llava/model/language_model/llava_llama.py")
        print("def LlavaLlamaModel.__init__(self, config: LlamaConfig)")
        print("self\n", type(self))
        print("config\n", config)
        super(LlavaLlamaModel, self).__init__(config)

In [22]:
# LlavaMetaForCausalLM
# get_vision_tower
# encode_images
# prepare_inputs_labels_for_multimodal
# initialize_vision_tokenizer

class LlavaMetaForCausalLM:

    def get_vision_tower(self):
        print("current file path", "llava/model/llava_arch.py")
        print("class LlavaMetaForCausalLM(ABC).get_vision_tower(self)")
        result = self.get_model().get_vision_tower()
        print("LlavaMetaForCausalLM(ABC).get_vision_tower(self) result (return)\n", result)
        """
        CLIPVisionTower(
        (vision_tower): CLIPVisionModel(
            (vision_model): CLIPVisionTransformer(
            (embeddings): CLIPVisionEmbeddings(
                (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
                (position_embedding): Embedding(577, 1024)
            )
            (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (encoder): CLIPEncoder(
                (layers): ModuleList(
                (0-23): 24 x CLIPEncoderLayer(
                    (self_attn): CLIPAttention(
                    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                    )
                    (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                    (mlp): CLIPMLP(
                    (activation_fn): QuickGELUActivation()
                    (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                    (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                    )
                    (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                )
                )
            )
            (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
        )
        )
        """
        return result

In [23]:
from typing import List, Optional, Tuple, Union
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import LlamaForCausalLM

class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config):

        print("current file path", "llava/llava/model/language_model/llava_llama.py")
        print("def LlavaLlamaForCausalLM.__init__(self, config)")
        print("self\n", type(self))
        # config は https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
        print("config\n", config)
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        # LlavaLlamaModelの初期化あと、LlavaMetaModelの初期化も呼ばれる。
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        print("self.model\n", self.model)
        """
        self.model
        LlavaLlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
            (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
                (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
                (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
                (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
                (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
                (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
                (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
                (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
                (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
                (act_fn): SiLUActivation()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
            )
        )
        (norm): LlamaRMSNorm()
        )
        """
        print("self.pretraining_tp\n", self.pretraining_tp) # 1
        print("self.vocab_size\n", self.vocab_size) # 32_000
        print("self.lm_head\n", self.lm_head) # Linear(in_features=4096, out_features=32000, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

In [24]:
"""
from transformers import AutoConfig

# 公式 LLaMA-2-7B の config をロード
llama_config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

print(llama_config)
"""

'\nfrom transformers import AutoConfig\n\n# 公式 LLaMA-2-7B の config をロード\nllama_config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")\n\nprint(llama_config)\n'

In [25]:
"""
from transformers import AutoConfig

# まず config.json をロードして Config クラスを自動判別
config = AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir
)

print("model_args.model_name_or_path\n", model_args.model_name_or_path)
print("training_args.cache_dir\n", training_args.cache_dir)
print("")
print("Loaded config:\n", config)
"""

'\nfrom transformers import AutoConfig\n\n# まず config.json をロードして Config クラスを自動判別\nconfig = AutoConfig.from_pretrained(\n    model_args.model_name_or_path,\n    cache_dir=training_args.cache_dir\n)\n\nprint("model_args.model_name_or_path\n", model_args.model_name_or_path)\nprint("training_args.cache_dir\n", training_args.cache_dir)\nprint("")\nprint("Loaded config:\n", config)\n'

In [26]:
import inspect
print(inspect.getattr_static(LlavaLlamaModel, "__init__"))

<function LlavaLlamaModel.__init__ at 0x7efff1124040>


In [27]:
def print_mro(cls):
    print(f"MRO for {cls.__name__}:\n")
    for i, c in enumerate(cls.mro()):
        print(f"{i:2d}: {c.__module__}.{c.__name__}")

print_mro(LlavaLlamaModel)

MRO for LlavaLlamaModel:

 0: __main__.LlavaLlamaModel
 1: __main__.LlavaMetaModel
 2: transformers.models.llama.modeling_llama.LlamaModel
 3: transformers.models.llama.modeling_llama.LlamaPreTrainedModel
 4: transformers.modeling_utils.PreTrainedModel
 5: torch.nn.modules.module.Module
 6: transformers.modeling_utils.ModuleUtilsMixin
 7: transformers.generation.utils.GenerationMixin
 8: transformers.utils.hub.PushToHubMixin
 9: builtins.object


In [28]:
print_mro(LlavaMetaModel)

MRO for LlavaMetaModel:

 0: __main__.LlavaMetaModel
 1: builtins.object


In [29]:
print_mro(LlavaLlamaForCausalLM)

MRO for LlavaLlamaForCausalLM:

 0: __main__.LlavaLlamaForCausalLM
 1: transformers.models.llama.modeling_llama.LlamaForCausalLM
 2: transformers.models.llama.modeling_llama.LlamaPreTrainedModel
 3: transformers.modeling_utils.PreTrainedModel
 4: torch.nn.modules.module.Module
 5: transformers.modeling_utils.ModuleUtilsMixin
 6: transformers.generation.utils.GenerationMixin
 7: transformers.utils.hub.PushToHubMixin
 8: __main__.LlavaMetaForCausalLM
 9: builtins.object


In [30]:
model = LlavaLlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    **bnb_model_from_pretrained_args
)

You are using a model of type llama to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaForCausalLM.__init__(self, config)
self
 <class '__main__.LlavaLlamaForCausalLM'>
config
 LlavaConfig {
  "_name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llava_llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.31.0",
  "use_cache": true,
  "vocab_size": 32000
}

current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaModel.__init__(self, config: LlamaConfig)
sel

Some weights of LlavaLlamaForCausalLM were not initialized from the model checkpoint at TinyLlama/TinyLlama-1.1B-Chat-v1.0 and are newly initialized: ['model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.16.self_attn.rotary_emb.inv_freq', 'model.layers.5.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.6.self_attn.rotary_emb.inv_freq', 'model.layers.15.self_attn.rotary_emb.inv_freq', 'model.layers.10.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_a

In [31]:
print("model\n", model)

model
 LlavaLlamaForCausalLM(
  (model): LlavaLlamaModel(
    (embed_tokens): Embedding(32000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm)

In [32]:
model.enable_input_require_grads()

In [33]:
import dataclasses
from typing import List
from enum import auto, Enum

class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()

@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False


conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)


conv_templates = {
    "plain": conv_llava_plain,
}

In [34]:
import inspect
print(inspect.getattr_static(LlamaForCausalLM, "from_pretrained"))
print(inspect.getattr_static(LlamaForCausalLM, "enable_input_require_grads"))

<classmethod(<function PreTrainedModel.from_pretrained at 0x7efff13c7f40>)>
<function PreTrainedModel.enable_input_require_grads at 0x7efff13c70a0>


In [35]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)

In [36]:
print("pad_token:", tokenizer.pad_token)
print("pad_token_id:", tokenizer.pad_token_id)
print("unk_token:", tokenizer.unk_token)
print("unk_token_id:", tokenizer.unk_token_id)
print("tokenizer\n", tokenizer)

pad_token: </s>
pad_token_id: 2
unk_token: <unk>
unk_token_id: 0
tokenizer
 LlamaTokenizer(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False)}, clean_up_tokenization_spaces=False)


In [37]:
tokenizer.pad_token = tokenizer.unk_token

In [38]:
print("pad_token:", tokenizer.pad_token)
print("pad_token_id:", tokenizer.pad_token_id)
print("unk_token:", tokenizer.unk_token)
print("unk_token_id:", tokenizer.unk_token_id)
print("tokenizer\n", tokenizer)

pad_token: <unk>
pad_token_id: 0
unk_token: <unk>
unk_token_id: 0
tokenizer
 LlamaTokenizer(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False)


In [39]:
default_conversation = conv_templates[model_args.version]
print("default_conversation\n", default_conversation)

default_conversation
 Conversation(system='', roles=('', ''), messages=(), offset=0, sep_style=<SeparatorStyle.PLAIN: 4>, sep='\n', sep2=None, version='Unknown', skip_next=False)


In [40]:
print("model_args.vision_tower\n", model_args.vision_tower)

model_args.vision_tower
 openai/clip-vit-large-patch14-336


In [41]:
def get_model(self):

    print("current file path", "llava/llava/model/language_model/llava_llama.py")
    print("def LlavaLlamaForCausalLM.get_model(self)")
    print("self\n", type(self))
    print("self.model (return)\n", self.model)
    return self.model

In [42]:
LlavaLlamaForCausalLM.get_model = get_model

In [43]:
initial_model = model.get_model()

current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaForCausalLM.get_model(self)
self
 <class '__main__.LlavaLlamaForCausalLM'>
self.model (return)
 LlavaLlamaModel(
  (embed_tokens): Embedding(32000, 2048, padding_idx=0)
  (layers): ModuleList(
    (0-21): 22 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=256, bias=False)
        (v_proj): Linear(in_features=2048, out_features=256, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_l

In [44]:
def config(self):

    print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
    print("def CLIPVisionTower.config(self)")
    print("self\n", type(self))
    print("self.is_loaded\n", self.is_loaded) # True
    print(f"【COND】 is_loaded={self.is_loaded}")
    if self.is_loaded:
        # 【ENTER】
        print("【ENTER】if self.is_loaded:")
        result = self.vision_tower.config
        print("result (return)\n", type(result))
        print("【EXIT】if self.is_loaded:")
    else:
      pass
    print("result (return)\n", result)
    """
    CLIPVisionConfig {
    "_name_or_path": "openai/clip-vit-large-patch14-336",
    "attention_dropout": 0.0,
    "dropout": 0.0,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "image_size": 336,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "layer_norm_eps": 1e-05,
    "model_type": "clip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "patch_size": 14,
    "projection_dim": 768,
    "transformers_version": "4.31.0"
    }
    """
    return result

In [45]:
def hidden_size(self):

    print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
    print("def CLIPVisionTower.hidden_size(self)")
    print("self\n", type(self))
    result = self.config.hidden_size
    print("result (return), self.config.hidden_size\n", result) # 1024
    return result

In [46]:
CLIPVisionTower.config = property(config)

In [47]:
CLIPVisionTower.hidden_size = property(hidden_size)

In [48]:
def initialize_vision_modules(self, model_args, fsdp=None):

  print("current file path", "llava/model/llava_arch.py")
  print("def initialize_vision_modules(self, model_args, fsdp=None)")
  print("model_args\n", model_args) #  ModelArguments(model_name_or_path='lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
  print("fsdp\n", fsdp) # []
  vision_tower = model_args.vision_tower
  print("vision_tower from model_args\n", vision_tower) # openai/clip-vit-large-patch14-336
  mm_vision_select_layer = model_args.mm_vision_select_layer
  print("mm_vision_select_layer from model_args\n", mm_vision_select_layer) # -2
  mm_vision_select_feature = model_args.mm_vision_select_feature
  print("mm_vision_select_feature from model_args\n", mm_vision_select_feature) # patch
  pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
  print("pretrain_mm_mlp_adapter from model_args\n", pretrain_mm_mlp_adapter) # None
  mm_patch_merge_type = model_args.mm_patch_merge_type
  # 下記はself.config.mm_vision_towerに関するもの。self.vision_towerは依然としてNone
  self.config.mm_vision_tower = vision_tower
  print("self.config.mm_vision_tower\n", self.config.mm_vision_tower) # None

  print("【COND】 self.get_vision_tower()\n", self.get_vision_tower()) # None
  print(f"【COND】 get_vision_tower_is_None={self.get_vision_tower() is None}")
  if self.get_vision_tower() is None:
      #【ENTER】self.vision_tower, self.get_vision_towerはNoneなのでこの分岐に入る。
      print("【ENTER】if self.get_vision_tower() is None:")
      print("[ENTER] self.get_vision_tower() is None")
      # build_vision_tower(model_args) はちょっと奥の依存関係が深い
      vision_tower = build_vision_tower(model_args)
      print("vision_tower after build_vision_tower\n", vision_tower)
      """
      CLIPVisionTower(
      (vision_tower): CLIPVisionModel(
      (vision_model): CLIPVisionTransformer(
          (embeddings): CLIPVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
          (position_embedding): Embedding(577, 1024)
          )
          (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder): CLIPEncoder(
          (layers): ModuleList(
              (0-23): 24 x CLIPEncoderLayer(
              (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
              )
              (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              )
          )
          )
          (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      )
      )
      """
      # 分散学習(FSDP)を使うかどうか. 今回は [] 空のリストとなるので、Noneではないが、len(fsdp) == 0
      print("【COND】 fsdp\n", fsdp) # []
      print(f"【COND】 fsdp_is_not_None={fsdp is not None} len_fsdp={len(fsdp) if fsdp is not None else 'N/A'}") # fsdp_is_not_None=True len_fsdp=0
      if fsdp is not None and len(fsdp) > 0:
        pass
      else:
          # 【ENTER】else of if fsdp is not None and len(fsdp) > 0:
          print("【COND】 else_fsdp_is_not_None_and_len_fsdp_gt_0=True")
          print("【ENTER】else of if fsdp is not None and len(fsdp) > 0:")
          self.vision_tower = vision_tower
          print("self.vision_tower\n", self.vision_tower)
          """
          CLIPVisionTower(
          (vision_tower): CLIPVisionModel(
              (vision_model): CLIPVisionTransformer(
              (embeddings): CLIPVisionEmbeddings(
                  (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
                  (position_embedding): Embedding(577, 1024)
              )
              (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (encoder): CLIPEncoder(
                  (layers): ModuleList(
                  (0-23): 24 x CLIPEncoderLayer(
                      (self_attn): CLIPAttention(
                      (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                      )
                      (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                      (mlp): CLIPMLP(
                      (activation_fn): QuickGELUActivation()
                      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                      )
                      (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                  )
                  )
              )
              (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              )
          )
          )
          """
          print("【EXIT】else of if fsdp is not None and len(fsdp) > 0:")

      print("【EXIT】if self.get_vision_tower() is None:")
  else:
    pass

  self.config.use_mm_proj = True
  print("self.config.use_mm_proj set to True") # True
  self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
  print("self.config.mm_projector_type\n", self.config.mm_projector_type) # mlp2x_gelu
  self.config.mm_hidden_size = vision_tower.hidden_size
  print("self.config.mm_hidden_size\n", self.config.mm_hidden_size) # 1024
  self.config.mm_vision_select_layer = mm_vision_select_layer
  print("self.config.mm_vision_select_layer\n", self.config.mm_vision_select_layer) # -2
  self.config.mm_vision_select_feature = mm_vision_select_feature
  print("self.config.mm_vision_select_feature\n", self.config.mm_vision_select_feature) # patch
  self.config.mm_patch_merge_type = mm_patch_merge_type
  print("self.config.mm_patch_merge_type\n", self.config.mm_patch_merge_type) # flat

  # mm_projector_is_None=True
  print(f"【COND】 mm_projector_is_None={getattr(self, 'mm_projector', None) is None}")
  if getattr(self, 'mm_projector', None) is None:
      # 【ENTER】
      print("【ENTER】if getattr(self, 'mm_projector', None) is None:")
      self.mm_projector = build_vision_projector(self.config)
      """
      Sequential(
        (0): Linear(in_features=1024, out_features=2048, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=2048, out_features=2048, bias=True)
      )
      """
      print("self.mm_projector after build_vision_projector\n", self.mm_projector)
      print("mm_patch_merge_type\n", mm_patch_merge_type) # flat
      print(f"【COND】 unpad_in_mm_patch_merge_type={'unpad' in mm_patch_merge_type}")
      if 'unpad' in mm_patch_merge_type:
        pass
      print("【EXIT】if getattr(self, 'mm_projector', None) is None:")
  else:
    pass

  print(f"【COND】 pretrain_mm_mlp_adapter_is_not_None={pretrain_mm_mlp_adapter is not None}")
  if pretrain_mm_mlp_adapter is not None:
    pass

In [49]:
LlavaMetaModel.initialize_vision_modules = initialize_vision_modules

In [50]:
def get_vision_tower(self):

    print("current file path", "llava/model/llava_arch.py")
    print("def get_vision_tower(self)")
    vision_tower = getattr(self, 'vision_tower', None)
    print("vision_tower (raw)\n", vision_tower)
    """
    CLIPVisionTower(
    (vision_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
        (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
        )
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): CLIPEncoder(
            (layers): ModuleList(
            (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                (activation_fn): QuickGELUActivation()
                (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
            )
        )
        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
    )
    )
    """
    print("type(vision_tower)\n", type(vision_tower))
    print(f"【COND】 type_vision_tower_is_list={type(vision_tower) is list}")  # False
    if type(vision_tower) is list:
        # 【SKIP】
        print("【ENTER】if type(vision_tower) is list:")
        vision_tower = vision_tower[0]
        print("【EXIT】if type(vision_tower) is list:")
    print("vision_tower (return)\n", vision_tower)
    """
    vision_tower (return)
    CLIPVisionTower(
    (vision_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
        (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
        )
        (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (encoder): CLIPEncoder(
            (layers): ModuleList(
            (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                (activation_fn): QuickGELUActivation()
                (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
            )
        )
        (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
    )
    )
    """
    return vision_tower

In [51]:
LlavaMetaModel.get_vision_tower = get_vision_tower

In [52]:
initial_model.initialize_vision_modules(
    model_args=model_args,
    fsdp=training_args.fsdp
)

current file path llava/model/llava_arch.py
def initialize_vision_modules(self, model_args, fsdp=None)
model_args
 ModelArguments(model_name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
fsdp
 []
vision_tower from model_args
 openai/clip-vit-large-patch14-336
mm_vision_select_layer from model_args
 -2
mm_vision_select_feature from model_args
 patch
pretrain_mm_mlp_adapter from model_args
 None
self.config.mm_vision_tower
 openai/clip-vit-large-patch14-336
current file path llava/model/llava_arch.py
def get_vision_tower(self)
vision_tower (raw)
 None
type(vision_tower)
 <class 'NoneType'>
【COND】 type_vision_tower_is_list=False
vision_tower (return)
 None
【COND】 self.get

self.vision_tower
 CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(577, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=1024, out_f

In [53]:
vision_tower = model.get_vision_tower()
print("vision_tower\n", vision_tower)
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

data_args.image_processor = vision_tower.image_processor
print("data_args.image_processor\n", data_args.image_processor)
data_args.is_multimodal = True
print("data_args.is_multimodal\n", data_args.is_multimodal) # True

model.config.image_aspect_ratio = data_args.image_aspect_ratio
print("model.config.image_aspect_ratio\n", model.config.image_aspect_ratio) # square
model.config.tokenizer_padding_side = tokenizer.padding_side
print("model.config.tokenizer_padding_side\n", model.config.tokenizer_padding_side) # right
model.config.tokenizer_model_max_length = tokenizer.model_max_length
print("model.config.tokenizer_model_max_length\n", model.config.tokenizer_model_max_length) # 2048

current file path llava/model/llava_arch.py
class LlavaMetaForCausalLM(ABC).get_vision_tower(self)
current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaForCausalLM.get_model(self)
self
 <class '__main__.LlavaLlamaForCausalLM'>
self.model (return)
 LlavaLlamaModel(
  (embed_tokens): Embedding(32000, 2048, padding_idx=0)
  (layers): ModuleList(
    (0-21): 22 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=256, bias=False)
        (v_proj): Linear(in_features=2048, out_features=256, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (down_proj): Linear(in_feat

In [54]:
model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
print(f"【COND】 tune_mm_mlp_adapter={model_args.tune_mm_mlp_adapter}") # True
if model_args.tune_mm_mlp_adapter:
    # 【ENTER】 tune_mm_mlp_adapter=True なので、この分岐に入る
    print("【ENTER】if model_args.tune_mm_mlp_adapter:")
    # モデル全体の全パラメータを「学習不可（requires_grad=False）」にする
    # これで通常の重みは全て凍結される
    model.requires_grad_(False)
    for p in model.get_model().mm_projector.parameters():
        # mm_projector（画像特徴量→テキスト特徴量への変換層）の全パラメータだけを「学習可能（requires_grad=True）」に戻す
        # これで mm_projector のみ学習されることになる
        print("model.get_model().mm_projector.parameters()", model.get_model().mm_projector.parameters())
        p.requires_grad = True
    print("【EXIT】if model_args.tune_mm_mlp_adapter:")

【COND】 tune_mm_mlp_adapter=True
【ENTER】if model_args.tune_mm_mlp_adapter:
current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaForCausalLM.get_model(self)
self
 <class '__main__.LlavaLlamaForCausalLM'>
self.model (return)
 LlavaLlamaModel(
  (embed_tokens): Embedding(32000, 2048, padding_idx=0)
  (layers): ModuleList(
    (0-21): 22 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=256, bias=False)
        (v_proj): Linear(in_features=2048, out_features=256, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
        (down_proj): Linear(in_features=5632, out_features=2

In [55]:
model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
print(f"【COND】 freeze_mm_mlp_adapter={training_args.freeze_mm_mlp_adapter}") # False
if training_args.freeze_mm_mlp_adapter:
  pass

print(f"【COND】 bits={training_args.bits}") # 16
if training_args.bits in [4, 8]:
  pass

【COND】 freeze_mm_mlp_adapter=False
【COND】 bits=16


In [56]:
def initialize_vision_tokenizer(self, model_args, tokenizer):
    print("current file path", "llava/model/llava_arch.py")
    print("def initialize_vision_tokenizer(self, model_args, tokenizer)")
    print("model_args\n", model_args) # ModelArguments(model_name_or_path='lmsys/vicuna-7b-v1.5', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
    print("tokenizer\n", tokenizer) # LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False)

    print(f"【COND】 mm_use_im_patch_token={model_args.mm_use_im_patch_token}") # False
    if model_args.mm_use_im_patch_token:
      pass

    if model_args.mm_use_im_start_end: # False
      pass

    elif model_args.mm_use_im_patch_token: # False
      pass

In [57]:
LlavaLlamaForCausalLM.initialize_vision_tokenizer = initialize_vision_tokenizer

In [58]:
model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
print("model_args.mm_use_im_start_end", model_args.mm_use_im_start_end)
model.config.mm_projector_lr = training_args.mm_projector_lr
print("training_args.mm_projector_lr", training_args.mm_projector_lr)
training_args.use_im_start_end = model_args.mm_use_im_start_end
print("training_args.use_im_start_end", training_args.use_im_start_end)
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
print("model_args.mm_use_im_patch_token", model_args.mm_use_im_patch_token)
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
print("【EXIT】if model_args.vision_tower is not None:")

model_args.mm_use_im_start_end False
training_args.mm_projector_lr None
training_args.use_im_start_end False
model_args.mm_use_im_patch_token False
current file path llava/model/llava_arch.py
def initialize_vision_tokenizer(self, model_args, tokenizer)
model_args
 ModelArguments(model_name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', version='plain', freeze_backbone=False, tune_mm_mlp_adapter=True, vision_tower='openai/clip-vit-large-patch14-336', mm_vision_select_layer=-2, pretrain_mm_mlp_adapter=None, mm_projector_type='mlp2x_gelu', mm_use_im_start_end=False, mm_use_im_patch_token=False, mm_patch_merge_type='flat', mm_vision_select_feature='patch')
tokenizer
 LlamaTokenizer(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, l

In [59]:
def rank0_print(*args):

    print("current file path", "llava/train/train.py")
    print("def rank0_print(*args)")
    print("args\n", args) # ('Formatting inputs...Skip in lazy mode',)
    if local_rank == 0:
        print(*args)

In [60]:
from torch.utils.data import Dataset
import json

class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments):

        print("current file path", "llava/train/train.py")
        print("def LazySupervisedDataset.__init__(self, data_path, tokenizer, data_args)")
        print("data_path\n", data_path) # /content/LLaVA/blip_laion_cc_sbu_1.json
        print("tokenizer\n", type(tokenizer)) # <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
        print("data_args\n", data_args) # DataArguments(data_path='/content/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/content/LLaVA/images', image_aspect_ratio='square')
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))
        # 今回は1サンプルだけなのでprintしても危険ではない
        print("list_data_dict", list_data_dict)

        rank0_print("Formatting inputs...Skip in lazy mode") # Formatting inputs...Skip in lazy mode
        self.tokenizer = tokenizer
        print("self.tokenizer\n", self.tokenizer)
        self.list_data_dict = list_data_dict
        print("self.list_data_dict\n", self.list_data_dict)
        self.data_args = data_args
        print("self.data_args\n", self.data_args)

In [61]:
def __len__(self):

    print("current file path", "llava/train/train.py")
    print("def LazySupervisedDataset.__len__(self)")
    return len(self.list_data_dict)

In [62]:
LazySupervisedDataset.__len__ = __len__

In [63]:
from typing import Sequence
from typing import Dict

# Trainer > def _get_dataloader > dataloader_params = {..."collate_fn": data_collator,...}
# self.accelerator.prepare(DataLoader(dataset, **dataloader_params)) で呼ばれる

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

        print("current file path", "llava/train/train.py")
        print("def DataCollatorForSupervisedDataset.__call__(self, instances)")
        print("instances\n", instances)
        #  [(torch.Size([24]), torch.Size([24]), torch.Size([3, 336, 336]))]
        print("shape of each instance's input_ids and labels, and images(if any):", [(x['input_ids'].shape, x['labels'].shape, x.get('image', None).shape if 'image' in x else None) for x in instances])
        # データローダーが None を返すことがあるので、Noneのサンプルを除外。
        instances = [x for x in instances if x is not None]
        # input_idsとlabelsのそれぞれについてリストを作成。タプルをつくる。
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        # input_idsはtokenizerのpad_token_id(0)でパディング
        print("self.tokenizer.pad_token_id\n", self.tokenizer.pad_token_id)
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        # labelsはIGNORE_INDEX(-100)でパディング
        print("IGNORE_INDEX\n", IGNORE_INDEX)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        print("input_ids.shape (after pad_sequence and truncate)\n", input_ids.shape)
        print("input_ids (after pad_sequence and truncate)\n", input_ids)
        labels = labels[:, :self.tokenizer.model_max_length]
        print("labels.shape (after pad_sequence and truncate)\n", labels.shape)
        print("labels (after pad_sequence and truncate)\n", labels)
        # .ne() は "not equal" → pad_token_id(=0) じゃない部分を 1、pad 部分を 0 にする。モデルが pad 部分を読まないように制御するマスクです。
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
            print("batch['images'].shape\n", batch['images'].shape)
        
        print("batch (return)\n", batch)
        print("shape of each batch's input_ids and labels, and images(if any):", [(batch['input_ids'].shape, batch['labels'].shape, batch.get('images', None).shape if 'images' in batch else None)])
        return batch

In [64]:
from typing import Dict

def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:

    print("current file path", "llava/train/train.py")
    print("def make_supervised_data_module(tokenizer, data_args)")
    print("tokenizer\n", type(tokenizer))
    print("data_args\n", data_args) #  DataArguments(data_path='/content/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/content/LLaVA/images', image_aspect_ratio='square')
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                data_path=data_args.data_path,
                                data_args=data_args)
    print("train_dataset\n", train_dataset) # <llava.train.train.LazySupervisedDataset object at 0x7ed6341f4880>
    print("len(train_dataset)\n", len(train_dataset)) # 1
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    print("data_collator\n", data_collator) # DataCollatorForSupervisedDataset(tokenizer=LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False))
    result = dict(train_dataset=train_dataset,
                  eval_dataset=None,
                  data_collator=data_collator)
    print("def make_supervised_data_module: result (return)\n", result) # {'train_dataset': <llava.train.train.LazySupervisedDataset object at 0x7ed6341f4880>, 'eval_dataset': None, 'data_collator': DataCollatorForSupervisedDataset(tokenizer=LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False))}
    return result

In [65]:
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
print("data_module\n", data_module)

current file path llava/train/train.py
def make_supervised_data_module(tokenizer, data_args)
tokenizer
 <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
data_args
 DataArguments(data_path='/workspaces/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/workspaces/LLaVA/images/', image_aspect_ratio='square')
current file path llava/train/train.py
def LazySupervisedDataset.__init__(self, data_path, tokenizer, data_args)
data_path
 /workspaces/LLaVA/blip_laion_cc_sbu_1.json
tokenizer
 <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
data_args
 DataArguments(data_path='/workspaces/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/workspaces/LLaVA/images/', image_aspect_ratio='square')
list_data_dict [{'id': '000406392', 'image': 'GCC_train_000406392.jpg', 'conversations': [{'from': 'human', 'value': 'Give a brief description of the image.\n<image>'}, {'from': 'gpt', 'valu

In [66]:
from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    ShardedDDPOption,
    logger,
)
    

In [67]:
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):

    print("current file path", "llava/mm_utils.py")
    print("def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None)")
    print("prompt\n", prompt) # <image>the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair
    print("tokenizer\n", tokenizer) #  LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False)
    print("image_token_index\n", image_token_index) # -200
    print("return_tensors\n", return_tensors) # pt
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    print("input_ids (return)\n", input_ids)
    return input_ids

In [68]:
import copy

def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:

    print("current file path", "llava/train/train.py")
    print("def preprocess_plain(sources, tokenizer)")
    print("sources\n", sources) # [[{'from': 'human', 'value': '<image>\nGive a brief description of the image.'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]]
    print("tokenizer\n", type(tokenizer)) # <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
    # add end signal and concatenate together
    conversations = []
    print("conversations initial\n", conversations) # []
    for source in sources:
        print("source current loop\n", source) 
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]['value']
        source[0]['value'] = DEFAULT_IMAGE_TOKEN
        conversation = source[0]['value'] + source[1]['value'] + default_conversation.sep
        print("conversation current loop\n", conversation)
        conversations.append(conversation)
    print("conversations (final)\n", conversations) #  ['<image>the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair\n']
    # tokenize conversations
    input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    print("input_ids\n", input_ids) # [tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114, 411,  2654, 11315,    13])]
    for idx, tensor in enumerate(input_ids):
        if hasattr(tensor, 'shape'):
            print(f"input_ids[{idx}].shape\n", tensor.shape) # torch.Size([24])
    targets = copy.deepcopy(input_ids)
    print("targets\n", targets) # [tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114, 411,  2654, 11315,    13])]
    for idx, tensor in enumerate(targets):
        if hasattr(tensor, 'shape'):
            print(f"targets[{idx}].shape\n", tensor.shape) # torch.Size([24])
    print("sources\n", sources) # [[{'from': 'human', 'value': '<image>'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]]
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) # prompt <image>
        target[:tokenized_len] = IGNORE_INDEX

    print("input_ids (return)\n", input_ids) # [tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114, 411,  2654, 11315,    13])]
    print("targets (return)\n", targets) #  [tensor([ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114, 411,  2654, 11315,    13])]
    return dict(input_ids=input_ids, labels=targets)

In [69]:
def _add_speaker_and_signal(header, source, get_conversation=True):

    print("current file path", "llava/train/train.py")
    print("def _add_speaker_and_signal(header, source, get_conversation=True)")
    print("header\n", header)
    print("source\n", source)
    print("get_conversation\n", get_conversation)
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = default_conversation.roles[1]
        else:
            from_str = 'unknown'
        sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
                             sentence["value"] + END_SIGNAL)
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation

In [70]:
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:

    print("current file path", "llava/train/train.py")
    print("def _tokenize_fn(strings, tokenizer)")
    print("strings\n", strings)
    print("tokenizer\n", type(tokenizer))
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    for idx, tensor in enumerate(input_ids):
        if hasattr(tensor, 'shape'):
            print(f"input_ids[{idx}].shape\n", tensor.shape)
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )

In [71]:
def _mask_targets(target, tokenized_lens, speakers):

    print("current file path", "llava/train/train.py")
    print("def _mask_targets(target, tokenized_lens, speakers)")
    print("target\n", target)
    print("tokenized_lens\n", tokenized_lens)
    print("speakers\n", speakers)
    # cur_idx = 0
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len

In [72]:
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:

    print("current file path", "llava/train/train.py")
    print("def preprocess(sources, tokenizer, has_image=False)")
    print("sources\n", sources) # [[{'from': 'human', 'value': '<image>\nGive a brief description of the image.'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]]
    print("tokenizer\n", type(tokenizer)) # <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
    print("has_image\n", has_image) # True
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    if default_conversation.sep_style == SeparatorStyle.PLAIN:
        return preprocess_plain(sources, tokenizer) # True
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
        for idx, tensor in enumerate(input_ids):
            if hasattr(tensor, 'shape'):
                print(f"input_ids[{idx}].shape\n", tensor.shape)
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    if isinstance(targets, list):
        for idx, tensor in enumerate(targets):
            if hasattr(tensor, 'shape'):
                print(f"targets[{idx}].shape\n", tensor.shape)
    elif hasattr(targets, 'shape'):
        print("targets.shape\n", targets.shape)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    print("return dict(input_ids=input_ids, labels=targets)\n", dict(input_ids=input_ids, labels=targets))
    return dict(input_ids=input_ids, labels=targets)

In [73]:
def preprocess_multimodal(
    sources: Sequence[str],
    data_args: DataArguments
) -> Dict:

    print("current file path", "llava/train/train.py")
    print("def preprocess_multimodal(sources, data_args)")
    print("sources\n", sources) # [[{'from': 'human', 'value': 'Give a brief description of the image.\n<image>'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]]
    print("data_args\n", data_args) # DataArguments(data_path='/content/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/content/LLaVA/images', image_aspect_ratio='square')
    is_multimodal = data_args.is_multimodal 
    print("is_multimodal\n", is_multimodal) # True
    if not is_multimodal:
        pass

    for source in sources:
        print("source current loop\n", source)
        for sentence in source:
            print("sentence current loop\n", sentence)
            print("【COND】 if DEFAULT_IMAGE_TOKEN in sentence['value']:", DEFAULT_IMAGE_TOKEN in sentence['value'])
            print("sentence['value']\n", sentence['value'])
            print("DEFAULT_IMAGE_TOKEN\n", DEFAULT_IMAGE_TOKEN)
            if DEFAULT_IMAGE_TOKEN in sentence['value']:
                print("【ENTER】if DEFAULT_IMAGE_TOKEN in sentence['value']:")
                sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
                sentence['value'] = sentence['value'].strip()
                if "mmtag" in default_conversation.version:
                    sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '<Image>' + DEFAULT_IMAGE_TOKEN + '</Image>')
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
    print("sources (final return)\n", sources)
    return sources

In [74]:
import copy
from PIL import Image

# Trainer > def _get_dataloader > dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
def __getitem__(self, i) -> Dict[str, torch.Tensor]:

    print("current file path", "llava/train/train.py")
    print("def LazySupervisedDataset.__getitem__(self, i)")
    print("i\n", i) # 0
    sources = self.list_data_dict[i]
    print("sources\n", sources)
    print("【COND】 isinstance(i, int):", isinstance(i, int))
    if isinstance(i, int):
        print("【ENTER】if isinstance(i, int):")
        sources = [sources]
        print("sources (after)\n", sources)
        print("【EXIT】if isinstance(i, int):")
    assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
    print("【COND】 'image' in sources[0]:", 'image' in sources[0])
    if 'image' in sources[0]:
        print("【ENTER】if 'image' in sources[0]:")
        image_file = self.list_data_dict[i]['image']
        print("image_file\n", image_file)
        image_folder = self.data_args.image_folder
        print("image_folder\n", image_folder)
        processor = self.data_args.image_processor
        print("processor\n", processor)
        image_path = os.path.join(image_folder, image_file)
        print("image_path\n", image_path)
        try:
            print("Trying to open image...")
            image = Image.open(image_path).convert('RGB')
            print("Image opened successfully.")
        except Exception as e:
            print(f"Error opening image: {e}")
            # 画像がなければこのサンプルはスキップ
            print("Skipping this sample due to image loading error.")
            return None 
        print("【COND】 self.data_args.image_aspect_ratio", self.data_args.image_aspect_ratio) # square
        if self.data_args.image_aspect_ratio == 'pad':
            pass
        else:
            print("【ENTER】else (self.data_args.image_aspect_ratio != 'pad')")
            print("image (before)\n", image)
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            print("image (after processor.preprocess)\n", image)
        print("sources (before preprocess_multimodal)\n", sources)
        sources = preprocess_multimodal(
            copy.deepcopy([e["conversations"] for e in sources]),
            self.data_args)
        print("sources (after preprocess_multimodal)\n", sources)
    else:
        pass

    print("Calling preprocess...")
    data_dict = preprocess(
        sources,
        self.tokenizer,
        has_image=('image' in self.list_data_dict[i]))
    print("data_dict (after preprocess)\n", data_dict)
    print("【COND】 isinstance(i, int):", isinstance(i, int))
    if isinstance(i, int):
        data_dict = dict(input_ids=data_dict["input_ids"][0],
                            labels=data_dict["labels"][0])

    # image exist in the data
    if 'image' in self.list_data_dict[i]:
        data_dict['image'] = image
    elif self.data_args.is_multimodal:
        # image does not exist in the data, but the model is multimodal
        crop_size = self.data_args.image_processor.crop_size
        data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
    return data_dict

In [75]:
LazySupervisedDataset.__getitem__ = __getitem__

In [76]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)

print("tokenizer\n", tokenizer)

tokenizer
 LlamaTokenizer(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False)}, clean_up_tokenization_spaces=False)




In [77]:
train_dataset = LazySupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args)

current file path llava/train/train.py
def LazySupervisedDataset.__init__(self, data_path, tokenizer, data_args)
data_path
 /workspaces/LLaVA/blip_laion_cc_sbu_1.json
tokenizer
 <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
data_args
 DataArguments(data_path='/workspaces/LLaVA/blip_laion_cc_sbu_1.json', lazy_preprocess=True, is_multimodal=True, image_folder='/workspaces/LLaVA/images/', image_aspect_ratio='square')
list_data_dict [{'id': '000406392', 'image': 'GCC_train_000406392.jpg', 'conversations': [{'from': 'human', 'value': 'Give a brief description of the image.\n<image>'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]}]
current file path llava/train/train.py
def rank0_print(*args)
args
 ('Formatting inputs...Skip in lazy mode',)
Formatting inputs...Skip in lazy mode
self.tokenizer
 LlamaTokenizer(name_or_path='TinyLlama/TinyLlama-1.1B-Chat-v1.0', vocab_size=32000, mo

In [78]:
sample_data_dict = train_dataset.__getitem__(0)
print("sample_data_dict\n", sample_data_dict)

current file path llava/train/train.py
def LazySupervisedDataset.__getitem__(self, i)
i
 0
sources
 {'id': '000406392', 'image': 'GCC_train_000406392.jpg', 'conversations': [{'from': 'human', 'value': 'Give a brief description of the image.\n<image>'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]}
【COND】 isinstance(i, int): True
【ENTER】if isinstance(i, int):
sources (after)
 [{'id': '000406392', 'image': 'GCC_train_000406392.jpg', 'conversations': [{'from': 'human', 'value': 'Give a brief description of the image.\n<image>'}, {'from': 'gpt', 'value': 'the divine queen in her elaborate masks canvas print featuring the face and hands of a woman with red hair'}]}]
【EXIT】if isinstance(i, int):
【COND】 'image' in sources[0]: True
【ENTER】if 'image' in sources[0]:
image_file
 GCC_train_000406392.jpg
image_folder
 /workspaces/LLaVA/images/
processor
 CLIPImageProcessor {
  "crop_size": {
    "height": 336

In [79]:
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

In [80]:
instances = [sample_data_dict]
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
batch = data_collator(instances)
print("batch\n", batch)

current file path llava/train/train.py
def DataCollatorForSupervisedDataset.__call__(self, instances)
instances
 [{'input_ids': tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
        10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
          411,  2654, 11315,    13]), 'labels': tensor([ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
        10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
          411,  2654, 11315,    13]), 'image': tensor([[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
         [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
         [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
         ...,
         [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
         [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
         [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],

        [[ 0.3190,  0.3

In [81]:
images = batch['images']
print("images shape\n", images.shape) # torch.Size([1, 3, 336, 336])
print("images\n", images)

images shape
 torch.Size([1, 3, 336, 336])
images
 tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
          ...,
          [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],

         [[ 0.3190,  0.3190,  0.3190,  ..., -0.3864, -0.0112,  0.2139],
          [ 0.3190,  0.3190,  0.3190,  ..., -0.0712,  0.1539,  0.3190],
          [ 0.3190,  0.3190,  0.3190,  ...,  0.2890,  0.3640,  0.4390],
          ...,
          [-1.0167, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017]],

         [[ 0.9656,  0.9656,  0.9656,  ...,  0.0982

In [82]:
def encode_images(self, images):
    print("current file path", "llava/model/llava_arch.py")
    print("def LlavaMetaForCausalLM(ABC).encode_images(self, images)")
    print("images\n", images)
    image_features = self.get_model().get_vision_tower()(images)
    image_features = self.get_model().mm_projector(image_features)
    print("image_features (return) shape\n", image_features.shape)
    print("image_features (return)\n", image_features)
    return image_features

In [83]:
LlavaMetaForCausalLM.encode_images = encode_images

In [84]:
# image_forward_outs から、指定した層の特徴量 (B, 577, 1024) を取り出したのち、パッチ特徴量 (B, 576, 1024) のみを返す。
def feature_select(self, image_forward_outs):

    print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
    print("def CLIPVisionTower.feature_select(self, image_forward_outs)")
    print("image_forward_outs\n", image_forward_outs) # 24層のtuple
    image_features = image_forward_outs.hidden_states[self.select_layer]
    print("image_features (after select_layer)\n", type(image_features))
    if hasattr(image_features, 'shape'):
        print("image_features.shape\n", image_features.shape) # torch.Size([1, 577, 1024])
    print(f"【COND】 select_feature={self.select_feature}") # patch
    if self.select_feature == 'patch':
        print("【ENTER】if self.select_feature == 'patch':")
        print("original image_features\n", image_features)
        """
        tensor([[[ 0.2236,  0.2432, -0.5938,  ...,  0.4863, -0.5273, -0.2041],
                [-0.0469, -0.1836, -0.0273,  ...,  0.3535,  0.3750,  0.3047],
                [-0.2598,  1.1484,  0.4844,  ...,  0.4961, -0.1719, -0.5117],
                ...,
                [ 1.7188,  0.9688,  0.8828,  ..., -0.2441, -0.8672,  1.3047],
                [ 0.7891, -0.3984,  0.6797,  ..., -0.3594, -0.9922,  0.3164],
                [ 1.5000,  0.6250,  0.3672,  ..., -0.5469, -0.4902,  0.9766]]],
            device='cuda:0', dtype=torch.bfloat16)
        """
        image_features = image_features[:, 1:]
        print("after process\n", image_features)
        """
        tensor([[[-0.0469, -0.1836, -0.0273,  ...,  0.3535,  0.3750,  0.3047],
                [-0.2598,  1.1484,  0.4844,  ...,  0.4961, -0.1719, -0.5117],
                [ 1.0625, -0.0635, -0.3730,  ...,  0.0220,  0.0820,  0.4805],
                ...,
                [ 1.7188,  0.9688,  0.8828,  ..., -0.2441, -0.8672,  1.3047],
                [ 0.7891, -0.3984,  0.6797,  ..., -0.3594, -0.9922,  0.3164],
                [ 1.5000,  0.6250,  0.3672,  ..., -0.5469, -0.4902,  0.9766]]],
            device='cuda:0', dtype=torch.bfloat16)
        """
        print("【EXIT】if self.select_feature == 'patch':")
    elif self.select_feature == 'cls_patch':
        pass
    else:
        pass
    print("selected image_feature shape\n", image_features.shape) 
    print("image_features (return)\n", image_features)
    """
    image_features (return)
    tensor([[[-0.0469, -0.1836, -0.0273,  ...,  0.3535,  0.3750,  0.3047],
            [-0.2598,  1.1484,  0.4844,  ...,  0.4961, -0.1719, -0.5117],
            [ 1.0625, -0.0635, -0.3730,  ...,  0.0220,  0.0820,  0.4805],
            ...,
            [ 1.7188,  0.9688,  0.8828,  ..., -0.2441, -0.8672,  1.3047],
            [ 0.7891, -0.3984,  0.6797,  ..., -0.3594, -0.9922,  0.3164],
            [ 1.5000,  0.6250,  0.3672,  ..., -0.5469, -0.4902,  0.9766]]],
        device='cuda:0', dtype=torch.bfloat16)
    """
    if hasattr(image_features, 'shape'):
        print("image_features.shape\n", image_features.shape) # torch.Size([1, 576, 1024])
    return image_features


In [85]:
CLIPVisionTower.feature_select = feature_select

In [86]:
@torch.no_grad() 
def forward(self, images):

    print("current file path", "llava/llava/model/multimodal_encoder/clip_encoder.py")
    print("def CLIPVisionTower.forward(self, images)")
    print("images shape\n", images.shape) # torch.Size([1, 3, 336, 336])
    print("images\n", images)
    
    if hasattr(images, 'shape'):
        print("images.shape\n", images.shape) # torch.Size([1, 3, 336, 336])
    print(f"【COND】 type_images_is_list={type(images) is list}") # False
    if type(images) is list:
        pass
    else:
        # 【ENTER】
        print("【ENTER】else (type(images) is not list):")
        print("original images\n", images)
        image_forward_outs = self.vision_tower(images.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype), output_hidden_states=True)
        print("after process image_forward_outs\n", type(image_forward_outs)) # 24層のtuple
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        print("after process image_features\n", type(image_features)) # <class 'torch.Tensor'>
        print("【EXIT】else (type(images) is not list):")

    print("image_features (return)\n", image_features)
    """
    image_features (return)
    tensor([[[-0.0469, -0.1836, -0.0273,  ...,  0.3535,  0.3750,  0.3047],
            [-0.2598,  1.1484,  0.4844,  ...,  0.4961, -0.1719, -0.5117],
            [ 1.0625, -0.0635, -0.3730,  ...,  0.0220,  0.0820,  0.4805],
            ...,
            [ 1.7188,  0.9688,  0.8828,  ..., -0.2441, -0.8672,  1.3047],
            [ 0.7891, -0.3984,  0.6797,  ..., -0.3594, -0.9922,  0.3164],
            [ 1.5000,  0.6250,  0.3672,  ..., -0.5469, -0.4902,  0.9766]]],
        device='cuda:0', dtype=torch.bfloat16)
    """
    if hasattr(image_features, 'shape'):
        print("image_features.shape\n", image_features.shape) # 
    return image_features

In [87]:
CLIPVisionTower.forward = forward

In [None]:
image_features = model.encode_images(images)

current file path llava/model/llava_arch.py
def LlavaMetaForCausalLM(ABC).encode_images(self, images)
images
 tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
          ...,
          [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],

         [[ 0.3190,  0.3190,  0.3190,  ..., -0.3864, -0.0112,  0.2139],
          [ 0.3190,  0.3190,  0.3190,  ..., -0.0712,  0.1539,  0.3190],
          [ 0.3190,  0.3190,  0.3190,  ...,  0.2890,  0.3640,  0.4390],
          ...,
          [-1.0167, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0

In [None]:
def prepare_inputs_labels_for_multimodal(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    images, image_sizes=None
):
    print("current file path", "llava/model/llava_arch.py")
    """
    llava/llava/model/language_model/llava_llama.py
    """
    print("def LlavaMetaForCausalLM(ABC).prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None)")  # not found
    print("input_ids\n", input_ids)
    """
    tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
             10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
               411,  2654, 11315,    13]])
    """

    print("position_ids\n", position_ids)  # None
    print("attention_mask\n", attention_mask)
    """
    tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True]])
    """

    print("past_key_values\n", past_key_values)  # None
    print("labels\n", labels)
    """
    tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
             10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
               411,  2654, 11315,    13]])
    """

    print("images\n", images)
    """
    tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
              [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
              [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
              ...,
              [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
              [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
              [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],
    
             [[ 0.3190,  0.3190,  0.3190,  ..., -0.3864, -0.0112,  0.2139],
              [ 0.3190,  0.3190,  0.3190,  ..., -0.0712,  0.1539,  0.3190],
              [ 0.3190,  0.3190,  0.3190,  ...,  0.2890,  0.3640,  0.4390],
              ...,
              [-1.0167, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
              [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
              [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017]],
    
             [[ 0.9656,  0.9656,  0.9656,  ...,  0.0982,  0.4537,  0.6670],
              [ 0.9656,  0.9656,  0.9656,  ...,  0.3968,  0.6101,  0.7523],
              [ 0.9656,  0.9656,  0.9656,  ...,  0.7523,  0.8092,  0.8377],
              ...,
              [-0.3711, -0.3853, -0.3995,  ..., -0.4279, -0.4279, -0.4279],
              [-0.3711, -0.3711, -0.3853,  ..., -0.4279, -0.4279, -0.4279],
              [-0.3853, -0.3711, -0.3711,  ..., -0.4279, -0.4279, -0.4279]]]])
    """

    print("image_sizes\n", image_sizes)  # None
    vision_tower = self.get_vision_tower()
    print("vision_tower\n", vision_tower)
    """
    CLIPVisionTower(
      (vision_tower): CLIPVisionModel(
        (vision_model): CLIPVisionTransformer(
          (embeddings): CLIPVisionEmbeddings(
            (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
            (position_embedding): Embedding(577, 1024)
          )
          (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (encoder): CLIPEncoder(
            (layers): ModuleList(
              (0-23): 24 x CLIPEncoderLayer(
                (self_attn): CLIPAttention(
                  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
                  (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
                )
                (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (mlp): CLIPMLP(
                  (activation_fn): QuickGELUActivation()
                  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
                  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
                )
                (layer_norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              )
            )
          )
          (post_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    """

    print(f"【COND】 vision_tower_is_None={vision_tower is None} images_is_None={images is None} input_ids_shape_1_eq_1={input_ids.shape[1] == 1}")
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        pass

    print("【COND】type(images)\n", type(images))  # <class 'torch.Tensor'>
    print("【COND】images.ndim\n", images.ndim)  # 4
    if type(images) is list or images.ndim == 5:
        pass
    else:
        # 【ENTER】
        print("【ENTER】else of if type(images) is list or images.ndim == 5:")
        image_features = self.encode_images(images)
        print("image_features after encode_images shape \n", image_features.shape)  # torch.Size([1, 576, 2048])
        print("image_features after encode_images\n", image_features)
        """
        tensor([[[-0.1943,  0.1157, -0.0747,  ...,  0.0027, -0.1691, -0.3439],
                 [ 0.0437,  0.1717, -0.0998,  ...,  0.0930, -0.1386, -0.0731],
                 [-0.0505,  0.1592, -0.0982,  ...,  0.0866, -0.1123, -0.2177],
                 ...,
                 [-0.0182,  0.0850, -0.0556,  ...,  0.0622, -0.1969,  0.0129],
                 [-0.0651,  0.0586, -0.1218,  ..., -0.0614, -0.1158, -0.0104],
                 [ 0.0863,  0.0081, -0.1651,  ..., -0.2040, -0.0455,  0.0618]]],
               grad_fn=<ViewBackward0>)
        """
        print("【EXIT】else of if type(images) is list or images.ndim == 5:")

    # TODO: image start / end is not implemented here to support pretraining.
    if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
        print("【ENTER】if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):")  # not found
        raise NotImplementedError

    # Let's just add dummy tensors if they do not exist,
    # it is a headache to deal with None all the time.
    # But it is not ideal, and if you have a better idea,
    # please open an issue / submit a PR, thanks.

    print("labels before\n", labels)
    """
    tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
             10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
               411,  2654, 11315,    13]])
    """

    print("position_ids before\n", position_ids)  # None

    print("attention_mask before\n", attention_mask)
    """
    tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True]])
    """

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        pass
    else:
        # 【ENTER】
        print("【ENTER】else of if attention_mask is None:")
        attention_mask = attention_mask.bool()
 
        print("attention_mask（after）shape \n", attention_mask.shape)  # torch.Size([1, 24])
        print("attention_mask (after)\n", attention_mask)
        """
        tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True]])
        """
        print("【EXIT】else of if attention_mask is None:")
    if position_ids is None:
        print("【ENTER】if position_ids is None:")
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

        print("position_ids (after) shape \n", position_ids.shape)  # torch.Size([24])
        print("position_ids (after)\n", position_ids)
        """
        tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
                18, 19, 20, 21, 22, 23])
        """
        print("【EXIT】if position_ids is None:")
    print(f"【COND】 labels_is_None={labels is None}")
    if labels is None:
        pass

    # remove the padding using attention_mask -- FIXME
    _input_ids = input_ids
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
    print("input_ids after removing padding\n", input_ids)
    """
    [tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
            10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
              411,  2654, 11315,    13])]
    """

    print("labels after removing padding\n", labels)
    """
    [tensor([ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
            10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
              411,  2654, 11315,    13])]
    """


    new_input_embeds = []
    new_labels = []
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        print("cur_input_ids shape\n", cur_input_ids.shape)   # torch.Size([24])
        print("cur_input_ids\n", cur_input_ids)
        """
        tensor([    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
                10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
                  411,  2654, 11315,    13])
        """
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        print("【COND】num_images:", num_images)  # tensor(1)
        if num_images == 0:
            print("【ENTER】if num_images == 0:")
            cur_image_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            print("【EXIT】if num_images == 0:")
            continue

        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        print("image_token_indices\n", image_token_indices)  # [-1, 1, 24]
        print("len image_token_indices", len(image_token_indices))   # 3
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        print("cur_labels\n", cur_labels)
        """
        tensor([ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
                10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
                  411,  2654, 11315,    13])
        """
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1): # 2回ループ。1回目 START から IMAGE_TOKEN_INDEXの手前まで、2回目はIMAGE_TOKEN_INDEX より先から 最後まで
            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
            cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
        print("cur_input_ids_noim (after)\n", cur_input_ids_noim)
        """
        [tensor([1]), tensor([  278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596,
                23425,   278,  3700,   322,  6567,   310,   263,  6114,   411,  2654,
                11315,    13])]
        """
        print("cur_labels_noim (after) \n", cur_labels_noim)
        """
        [tensor([-100]), tensor([  278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596,
                23425,   278,  3700,   322,  6567,   310,   263,  6114,   411,  2654,
                11315,    13])]
        """
        split_sizes = [x.shape[0] for x in cur_labels_noim]
        print("split_sizes\n", split_sizes)  # [1, 22]
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
        print("cur_input_embeds shape\n", cur_input_embeds.shape)  # torch.Size([23, 2048])
        print("cur_input_embeds\n", cur_input_embeds)
        """
        tensor([[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
                 -6.5231e-04, -4.9973e-04],
                [ 7.0801e-03,  1.0452e-03,  6.0425e-03,  ...,  3.9673e-03,
                  1.2817e-03, -1.1215e-03],
                [-2.2949e-02, -2.6226e-05,  6.8359e-03,  ..., -2.4658e-02,
                 -9.4604e-03,  1.5869e-02],
                ...,
                [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                 -8.3618e-03, -9.4604e-03],
                [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                 -2.5177e-03, -8.0566e-03],
                [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                 -7.3242e-04,  2.7924e-03]], requires_grad=True)
        """
        cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
        print("cur_input_embeds_no_im\n", cur_input_embeds_no_im)
        """
        (tensor([[-0.0011,  0.0019, -0.0017,  ...,  0.0002, -0.0007, -0.0005]],
               grad_fn=<SplitWithSizesBackward0>), tensor([[ 7.0801e-03,  1.0452e-03,  6.0425e-03,  ...,  3.9673e-03,
                  1.2817e-03, -1.1215e-03],
                [-2.2949e-02, -2.6226e-05,  6.8359e-03,  ..., -2.4658e-02,
                 -9.4604e-03,  1.5869e-02],
                [-2.3499e-03,  1.4893e-02, -2.0447e-03,  ..., -8.6060e-03,
                  2.3193e-03,  3.0670e-03],
                ...,
                [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                 -8.3618e-03, -9.4604e-03],
                [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                 -2.5177e-03, -8.0566e-03],
                [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                 -7.3242e-04,  2.7924e-03]], grad_fn=<SplitWithSizesBackward0>))
        """
        cur_new_input_embeds = []
        cur_new_labels = []

        for i in range(num_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            print(f"【COND】 i={i} num_images={num_images}")
            if i < num_images:
                print("【ENTER】if i < num_images:")
                cur_image_features = image_features[cur_image_idx]
                cur_image_idx += 1
                cur_new_input_embeds.append(cur_image_features)
                cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                print("【EXIT】if i < num_images:")

        print("cur_new_input_embeds (before cat) shape\n", [x.shape for x in cur_new_input_embeds])
        """
        [torch.Size([1, 2048]), torch.Size([576, 2048]), torch.Size([22, 2048])]
        """
        print("cur_new_input_embeds (before cat)\n", cur_new_input_embeds)
        """
        [tensor([[-0.0011,  0.0019, -0.0017,  ...,  0.0002, -0.0007, -0.0005]],
               grad_fn=<SplitWithSizesBackward0>), tensor([[-0.1943,  0.1157, -0.0747,  ...,  0.0027, -0.1691, -0.3439],
                [ 0.0437,  0.1717, -0.0998,  ...,  0.0930, -0.1386, -0.0731],
                [-0.0505,  0.1592, -0.0982,  ...,  0.0866, -0.1123, -0.2177],
                ...,
                [-0.0182,  0.0850, -0.0556,  ...,  0.0622, -0.1969,  0.0129],
                [-0.0651,  0.0586, -0.1218,  ..., -0.0614, -0.1158, -0.0104],
                [ 0.0863,  0.0081, -0.1651,  ..., -0.2040, -0.0455,  0.0618]],
               grad_fn=<SelectBackward0>), tensor([[ 7.0801e-03,  1.0452e-03,  6.0425e-03,  ...,  3.9673e-03,
                  1.2817e-03, -1.1215e-03],
                [-2.2949e-02, -2.6226e-05,  6.8359e-03,  ..., -2.4658e-02,
                 -9.4604e-03,  1.5869e-02],
                [-2.3499e-03,  1.4893e-02, -2.0447e-03,  ..., -8.6060e-03,
                  2.3193e-03,  3.0670e-03],
                ...,
                [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                 -8.3618e-03, -9.4604e-03],
                [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                 -2.5177e-03, -8.0566e-03],
                [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                 -7.3242e-04,  2.7924e-03]], grad_fn=<SplitWithSizesBackward0>)]
        """

        print("cur_new_labels (before cat) shape\n", [x.shape for x in cur_new_labels])
        """
        [torch.Size([1]), torch.Size([576]), torch.Size([22])]
        """
        print("cur_new_labels (before cat)\n", cur_new_labels)
        """
        [tensor([-100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
                -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]), tensor([  278, 25616, 26624,   297,   902, 19430, 11105, 29879, 10508,  1596,
                23425,   278,  3700,   322,  6567,   310,   263,  6114,   411,  2654,
                11315,    13])]
        """
        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        print("cur_new_input_embeds (after cat) shape\n", cur_new_input_embeds.shape)  # torch.Size([599, 2048])
        print("cur_new_input_embeds (after cat)\n", cur_new_input_embeds)
        """
        tensor([[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
                 -6.5231e-04, -4.9973e-04],
                [-1.9428e-01,  1.1569e-01, -7.4740e-02,  ...,  2.6653e-03,
                 -1.6907e-01, -3.4387e-01],
                [ 4.3680e-02,  1.7172e-01, -9.9813e-02,  ...,  9.3004e-02,
                 -1.3859e-01, -7.3106e-02],
                ...,
                [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                 -8.3618e-03, -9.4604e-03],
                [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                 -2.5177e-03, -8.0566e-03],
                [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                 -7.3242e-04,  2.7924e-03]], grad_fn=<CatBackward0>)
        """

        print("cur_new_labels (after cat) shape\n", cur_new_labels.shape)  # torch.Size([599])
        print("cur_new_labels (after cat)\n", cur_new_labels)
        """
        tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,   278, 25616, 26624,
                  297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,
                  322,  6567,   310,   263,  6114,   411,  2654, 11315,    13])
        """

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)
        print("new_input_embeds (so far) shape\n", [x.shape for x in new_input_embeds])  # [torch.Size([599, 2048])]
        print("new_input_embeds (so far)\n", new_input_embeds)
        """
        [tensor([[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
                 -6.5231e-04, -4.9973e-04],
                [-1.9428e-01,  1.1569e-01, -7.4740e-02,  ...,  2.6653e-03,
                 -1.6907e-01, -3.4387e-01],
                [ 4.3680e-02,  1.7172e-01, -9.9813e-02,  ...,  9.3004e-02,
                 -1.3859e-01, -7.3106e-02],
                ...,
                [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                 -8.3618e-03, -9.4604e-03],
                [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                 -2.5177e-03, -8.0566e-03],
                [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                 -7.3242e-04,  2.7924e-03]], grad_fn=<CatBackward0>)]
        """

        print("new_labels (so far) shape\n", [x.shape for x in new_labels])  # [torch.Size([599])]
        print("new_labels (so far)\n", new_labels)
        """
        [tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                 -100,  -100,  -100,  -100,  -100,  -100,  -100,   278, 25616, 26624,
                  297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,
                  322,  6567,   310,   263,  6114,   411,  2654, 11315,    13])]
        """

    # Truncate sequences to max length as image embeddings can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    print(f"【COND】 tokenizer_model_max_length_is_not_None={tokenizer_model_max_length is not None}")
    if tokenizer_model_max_length is not None:
        print("【ENTER】if tokenizer_model_max_length is not None:")
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
        print("【EXIT】if tokenizer_model_max_length is not None:")

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    print("max_len\n", max_len)  # 599
    batch_size = len(new_input_embeds)
    print("batch_size\n", batch_size)  # 1

    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    print("new_labels_padded (before) shape\n", new_labels_padded.shape)  # torch.Size([1, 599])
    print("new_labels_padded (before)\n", new_labels_padded)
    """
    tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
             -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]])
    """
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    print("attention_mask (before) shape\n", attention_mask.shape)  # torch.Size([1, 599])
    print("attention_mask (before)\n", attention_mask)
    """
    tensor([[False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False, False,
             False, False, False, False, False, False, False, False, False]])
    """
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
    print("position_ids (before) shape\n", position_ids.shape)  # torch.Size([1, 599])
    print("position_ids (before)\n", position_ids)
    """
    tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
             0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
    """

    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        print(f"【COND】 padding_side={getattr(self.config, 'tokenizer_padding_side', 'right')} cur_len={cur_len} max_len={max_len}")
        if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
            pass
        else:
            print("【ENTER】else (padding_side != 'left'):")
            #【ENTER】
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                # :cur_len に、代入
                new_labels_padded[i, :cur_len] = cur_new_labels 
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            print("new_input_embeds_padded (so far) shape\n", [x.shape for x in new_input_embeds_padded])  # [torch.Size([599, 2048])]
            print("new_input_embeds_padded (so far)\n", new_input_embeds_padded)
            """
            [tensor([[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
                     -6.5231e-04, -4.9973e-04],
                    [-1.9428e-01,  1.1569e-01, -7.4740e-02,  ...,  2.6653e-03,
                     -1.6907e-01, -3.4387e-01],
                    [ 4.3680e-02,  1.7172e-01, -9.9813e-02,  ...,  9.3004e-02,
                     -1.3859e-01, -7.3106e-02],
                    ...,
                    [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
                     -8.3618e-03, -9.4604e-03],
                    [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
                     -2.5177e-03, -8.0566e-03],
                    [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
                     -7.3242e-04,  2.7924e-03]], grad_fn=<CatBackward0>)]
            """

            print("new_labels_padded (so far) shape\n", new_labels_padded.shape)  # torch.Size([1, 599])
            print("new_labels_padded (so far)\n", new_labels_padded)
            """
            tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                      -100,  -100,  -100,  -100,  -100,  -100,  -100,   278, 25616, 26624,
                       297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,
                       322,  6567,   310,   263,  6114,   411,  2654, 11315,    13]])
            """

            print("attention_mask (so far) shape\n", attention_mask.shape)  # torch.Size([1, 599])
            print("attention_mask (so far)\n", attention_mask)
            """
            tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True, True,
                     True, True, True, True, True, True, True, True, True, True, True]])
            """

            print("position_ids (so far) shape\n", position_ids.shape)  # torch.Size([1, 599])
            print("position_ids (so far)\n", position_ids)
            """
            tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
                      14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
                      28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
                      42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
                      56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
                      70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                      84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
                      98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
                     112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
                     126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
                     140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
                     154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
                     168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181,
                     182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,
                     196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
                     210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
                     224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
                     238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
                     252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265,
                     266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279,
                     280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293,
                     294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307,
                     308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321,
                     322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335,
                     336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
                     350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363,
                     364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377,
                     378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391,
                     392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405,
                     406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419,
                     420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433,
                     434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447,
                     448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461,
                     462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475,
                     476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489,
                     490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503,
                     504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517,
                     518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531,
                     532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545,
                     546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559,
                     560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573,
                     574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587,
                     588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598]])
            """
            print("【EXIT】else (padding_side != 'left'):")

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
    print("new_input_embeds (after) shape\n", new_input_embeds.shape)  # torch.Size([1, 599, 2048])
    print("new_input_embeds (after)\n", new_input_embeds)
    """
    tensor([[[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
              -6.5231e-04, -4.9973e-04],
             [-1.9428e-01,  1.1569e-01, -7.4740e-02,  ...,  2.6653e-03,
              -1.6907e-01, -3.4387e-01],
             [ 4.3680e-02,  1.7172e-01, -9.9813e-02,  ...,  9.3004e-02,
              -1.3859e-01, -7.3106e-02],
             ...,
             [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
              -8.3618e-03, -9.4604e-03],
             [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
              -2.5177e-03, -8.0566e-03],
             [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
              -7.3242e-04,  2.7924e-03]]], grad_fn=<StackBackward0>)
    """

    print(f"【COND】 _labels_is_None={_labels is None}") 
    if _labels is None:
        #【SKIP】
        print("【ENTER】if _labels is None:")
        new_labels = None
        print("【EXIT】if _labels is None:")
    else:
        # 【ENTER】
        print("【ENTER】else of if _labels is None:")
        new_labels = new_labels_padded
        print("new_labels (after)\n", new_labels)
        """
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
                  -100,  -100,  -100,  -100,  -100,  -100,  -100,   278, 25616, 26624,
                   297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,
                   322,  6567,   310,   263,  6114,   411,  2654, 11315,    13]])
        """
        print("【EXIT】else of if _labels is None:")

    print(f"【COND】 _attention_mask_is_None={_attention_mask is None}") 
    if _attention_mask is None:
        # 【SKIP】
        print("【ENTER】if _attention_mask is None:")
        attention_mask = None
        print("【EXIT】if _attention_mask is None:")
    else:
        # 【ENTER】
        print("【ENTER】else of if _attention_mask is None:")
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        print("attention_mask (after)2\n", attention_mask)
        """
        tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True, True,
                 True, True, True, True, True, True, True, True, True, True, True]])
        """

    print(f"【COND】 _position_ids_is_None={_position_ids is None}")
    if _position_ids is None:
        print("【ENTER】if _position_ids is None:")
        position_ids = None
        print("【EXIT】if _position_ids is None:")

    print("position_ids (return)\n", position_ids)  # None
    print("attention_mask (return)\n", attention_mask)
    """
    tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True, True,
             True, True, True, True, True, True, True, True, True, True, True]])
    """
    print("past_key_values (return)\n", past_key_values)  # None
    print("new_input_embeds (return)\n", new_input_embeds)
    """
    tensor([[[-1.0910e-03,  1.9302e-03, -1.6632e-03,  ...,  1.9932e-04,
              -6.5231e-04, -4.9973e-04],
             [-1.9428e-01,  1.1569e-01, -7.4740e-02,  ...,  2.6653e-03,
              -1.6907e-01, -3.4387e-01],
             [ 4.3680e-02,  1.7172e-01, -9.9813e-02,  ...,  9.3004e-02,
              -1.3859e-01, -7.3106e-02],
             ...,
             [ 2.1240e-02, -2.2705e-02, -1.4221e-02,  ..., -2.8229e-03,
              -8.3618e-03, -9.4604e-03],
             [ 3.7079e-03, -3.6011e-03,  9.0332e-03,  ..., -1.3672e-02,
              -2.5177e-03, -8.0566e-03],
             [-6.3705e-04, -1.0605e-03, -1.1841e-02,  ...,  2.1935e-04,
              -7.3242e-04,  2.7924e-03]]], grad_fn=<StackBackward0>)
    """
    print("new_labels (return)\n", new_labels)
    """
    tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
              -100,  -100,  -100,  -100,  -100,  -100,  -100,   278, 25616, 26624,
               297,   902, 19430, 11105, 29879, 10508,  1596, 23425,   278,  3700,
               322,  6567,   310,   263,  6114,   411,  2654, 11315,    13]])
    """
    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

In [None]:
LlavaMetaForCausalLM.prepare_inputs_labels_for_multimodal = prepare_inputs_labels_for_multimodal

In [None]:
print("model_device\n", model.device) # cpu

input_ids = batch['input_ids'].to(device=model.device)
print("input_ids shape\n", input_ids.shape) # torch.Size([1, 24])
print("input_ids\n", input_ids)
"""
tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
"""

attention_mask = batch['attention_mask'].to(device=model.device)
print("attention_mask shape\n", attention_mask.shape) # torch.Size([1, 24])
print("attention_mask\n", attention_mask)
"""
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True]])
"""

labels = batch['labels'].to(device=model.device)
print("labels shape\n", labels.shape) # torch.Size([1, 3, 336, 336])
print("labels\n", labels)
"""
tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
"""

images = batch['images'].to(device=model.device)
print("images shape\n", images.shape)
print("images\n", images)
"""
tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
          ...,
          [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],
          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],

         [[ 0.3190,  0.3190,  0.3190,  ..., -0.3864, -0.0112,  0.2139],
          [ 0.3190,  0.3190,  0.3190,  ..., -0.0712,  0.1539,  0.3190],
          [ 0.3190,  0.3190,  0.3190,  ...,  0.2890,  0.3640,  0.4390],
          ...,
          [-1.0167, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],
          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017]],

         [[ 0.9656,  0.9656,  0.9656,  ...,  0.0982,  0.4537,  0.6670],
          [ 0.9656,  0.9656,  0.9656,  ...,  0.3968,  0.6101,  0.7523],
          [ 0.9656,  0.9656,  0.9656,  ...,  0.7523,  0.8092,  0.8377],
          ...,
          [-0.3711, -0.3853, -0.3995,  ..., -0.4279, -0.4279, -0.4279],
          [-0.3711, -0.3711, -0.3853,  ..., -0.4279, -0.4279, -0.4279],
          [-0.3853, -0.3711, -0.3711,  ..., -0.4279, -0.4279, -0.4279]]]])
"""


model_device
 cpu
input_ids shape
 torch.Size([1, 24])
input_ids
 tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
attention_mask shape
 torch.Size([1, 24])
attention_mask
 tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True]])
labels shape
 torch.Size([1, 24])
labels
 tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
images shape
 torch.Size([1, 3, 336, 336])
images
 tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],
         

'\ntensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],\n          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],\n          [ 0.0325,  0.0325,  0.0325,  ..., -0.0113,  0.0471,  0.0909],\n          ...,\n          [-1.0331, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],\n          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623],\n          [-1.0477, -1.0331, -1.0331,  ..., -1.0623, -1.0623, -1.0623]],\n\n         [[ 0.3190,  0.3190,  0.3190,  ..., -0.3864, -0.0112,  0.2139],\n          [ 0.3190,  0.3190,  0.3190,  ..., -0.0712,  0.1539,  0.3190],\n          [ 0.3190,  0.3190,  0.3190,  ...,  0.2890,  0.3640,  0.4390],\n          ...,\n          [-1.0167, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],\n          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017],\n          [-1.0317, -1.0167, -1.0167,  ..., -1.0017, -1.0017, -1.0017]],\n\n         [[ 0.9656,  0.9656,  0.9656,  ...,  0.0982,  0.4537,  0.6670],\n          

In [None]:
position_ids = None
past_key_values = None
image_sizes = None

In [None]:
"""
from print_factory.print_factory import run_and_capture, embed_print_outputs

logs, mapping = run_and_capture(
    model.prepare_inputs_labels_for_multimodal,
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    labels,
    images,
    image_sizes
)

with open("print_factory/original_code.py") as f:
    code = f.read()

new_code = embed_print_outputs(code, mapping)

with open("print_factory/after_code.py", "w") as f:
    f.write(new_code)

print("done")
"""

done


In [None]:
(
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    inputs_embeds,
    labels
) = model.prepare_inputs_labels_for_multimodal(
    input_ids,
    position_ids,
    attention_mask,
    past_key_values,
    labels,
    images,
    image_sizes
)

current file path llava/model/llava_arch.py
def LlavaMetaForCausalLM(ABC).prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None)
input_ids
 tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
position_ids
 None
attention_mask
 tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True]])
past_key_values
 None
labels
 tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
images
 tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7120, -0.3616, -0.1280],
          [ 0.0325,  0.0325,  0.0325,  ..., -0.3908, -0.1718, -0.0259],
     

after process image_forward_outs
 <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'>
current file path llava/llava/model/multimodal_encoder/clip_encoder.py
def CLIPVisionTower.feature_select(self, image_forward_outs)
image_forward_outs
 BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.0879, -0.1445, -0.2949,  ...,  0.3906, -0.7188,  0.0356],
         [ 0.4629,  0.0078,  0.0859,  ..., -0.0815,  0.1113, -0.3906],
         [-0.2188,  1.0938,  1.0469,  ...,  0.2119, -0.3477, -0.8477],
         ...,
         [ 1.7500,  1.0859,  1.6719,  ..., -0.2891, -0.6797,  1.1016],
         [ 0.9141, -0.1050,  1.3594,  ..., -0.4141, -1.2031,  0.4043],
         [ 1.4531,  1.1641,  1.0156,  ..., -0.3867, -0.6094,  0.6719]]],
       dtype=torch.bfloat16), pooler_output=tensor([[ 0.1738, -0.1426, -0.6133,  ...,  0.9453, -1.4844,  0.0967]],
       dtype=torch.bfloat16), hidden_states=(tensor([[[ 0.0342, -0.0408, -0.1670,  ...,  0.3203, -0.1475, -0.0201],
         [-0.1187, -0.0493, -

In [None]:
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    image_sizes: Optional[List[List[int]]] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

    print("current file path", "llava/llava/model/language_model/llava_llama.py")
    print("def LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict)")
    print("input_ids\n", input_ids)
    """
    tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
            10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
            411,  2654, 11315,    13]], device='cuda:0')        
    """
    if hasattr(input_ids, 'shape'):
        print("input_ids.shape\n", input_ids.shape) # torch.Size([1, 24])
    print("attention_mask\n", attention_mask)
    """
    tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
            True, True, True, True, True, True, True, True, True, True, True, True]],
        device='cuda:0')
    """
    print("position_ids\n", position_ids) # None
    print("past_key_values\n", past_key_values) # None
    print("inputs_embeds\n", inputs_embeds) # None
    if hasattr(inputs_embeds, 'shape'):
        print("inputs_embeds.shape\n", inputs_embeds.shape)
    print("labels\n", labels)
    """
    tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
            10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
            411,  2654, 11315,    13]], device='cuda:0')
    """
    print("use_cache\n", use_cache) # None
    print("output_attentions\n", output_attentions) # None
    print("output_hidden_states\n", output_hidden_states) # None
    print("images\n", images)
    """
    tensor([[[[ 0.0325,  0.0325,  0.0325,  ..., -0.7109, -0.3613, -0.1279],
            [ 0.0325,  0.0325,  0.0325,  ..., -0.3906, -0.1719, -0.0259],
            [ 0.0325,  0.0325,  0.0325,  ..., -0.0112,  0.0471,  0.0908],
            ...,
            [-1.0312, -1.0312, -1.0312,  ..., -1.0625, -1.0625, -1.0625],
            [-1.0469, -1.0312, -1.0312,  ..., -1.0625, -1.0625, -1.0625],
            [-1.0469, -1.0312, -1.0312,  ..., -1.0625, -1.0625, -1.0625]],

            [[ 0.3184,  0.3184,  0.3184,  ..., -0.3867, -0.0112,  0.2139],
            [ 0.3184,  0.3184,  0.3184,  ..., -0.0713,  0.1543,  0.3184],
            [ 0.3184,  0.3184,  0.3184,  ...,  0.2891,  0.3633,  0.4395],
            ...,
            [-1.0156, -1.0156, -1.0156,  ..., -1.0000, -1.0000, -1.0000],
            [-1.0312, -1.0156, -1.0156,  ..., -1.0000, -1.0000, -1.0000],
            [-1.0312, -1.0156, -1.0156,  ..., -1.0000, -1.0000, -1.0000]],

            [[ 0.9648,  0.9648,  0.9648,  ...,  0.0981,  0.4531,  0.6680],
            [ 0.9648,  0.9648,  0.9648,  ...,  0.3965,  0.6094,  0.7539],
            [ 0.9648,  0.9648,  0.9648,  ...,  0.7539,  0.8086,  0.8359],
            ...,
            [-0.3711, -0.3848, -0.4004,  ..., -0.4277, -0.4277, -0.4277],
            [-0.3711, -0.3711, -0.3848,  ..., -0.4277, -0.4277, -0.4277],
            [-0.3848, -0.3711, -0.3711,  ..., -0.4277, -0.4277, -0.4277]]]],
        device='cuda:0', dtype=torch.bfloat16)
    """
    if hasattr(images, 'shape'):
        print("images.shape\n", images.shape) # torch.Size([1, 3, 336, 336])
    print("image_sizes\n", image_sizes) # None
    print("return_dict\n", return_dict) # None

    print(f"【COND】 inputs_embeds_is_None={inputs_embeds is None}") # True
    if inputs_embeds is None:
        # 【ENTER】
        print("【ENTER】if inputs_embeds is None:")
        (
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            inputs_embeds,
            labels
        ) = self.prepare_inputs_labels_for_multimodal(
            input_ids,
            position_ids,
            attention_mask,
            past_key_values,
            labels,
            images,
            image_sizes
        )
        print("【EXIT】if inputs_embeds is None:")

    print("input_ids (after prepare_inputs_labels_for_multimodal)\n", input_ids)

    print("position_ids (after prepare_inputs_labels_for_multimodal)\n", position_ids)

    print("attention_mask shape (after prepare_inputs_labels_for_multimodal)\n", attention_mask.shape)
    print("attention_mask (after prepare_inputs_labels_for_multimodal)\n", attention_mask)


    print("past_key_values (after prepare_inputs_labels_for_multimodal)\n", past_key_values)

    print("inputs_embeds shape (after prepare_inputs_labels_for_multimodal)\n", None if inputs_embeds is None else inputs_embeds.shape)
    print("inputs_embeds (after prepare_inputs_labels_for_multimodal)\n", inputs_embeds)

    print("labels shape (after prepare_inputs_labels_for_multimodal)\n", labels.shape)
    print("labels (after prepare_inputs_labels_for_multimodal)\n", labels)

    #  LlamaForCausalLM.forward(self, ...)で明示
    # Trainer > def train > def inner_training_loop > def training_step > model(**inputs) > model.forward
    result = LlamaForCausalLM.forward(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )
    print("Return of def LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict)")
    #print("result of LlavaLlamaForCausalLM.forward (return)\n", result)
    print("logits tensor shape  LlavaLlamaForCausalLM.forward\n", result.logits.shape)
    print("logits tensor (first 10 tokens)  LlavaLlamaForCausalLM.forward\n", result.logits[0, :10, :])
    print("loss (return)  LlavaLlamaForCausalLM.forward \n", result.loss)
    """
    CausalLMOutputWithPast(loss=tensor(7.3094, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[  0.8242,   0.1855,  -0.7031,  ...,   1.6719,   2.6719,   1.1875],
            [ -9.0000,  -2.1562,   8.9375,  ...,  -6.4375,  -6.9688,  -5.9375],
            [-12.4375,  -7.8750,   3.5625,  ..., -10.2500, -10.3750, -11.1875],
            ...,
            [ -6.7812,  -3.1406,   4.2188,  ...,  -4.6562,  -3.5312,  -4.8750],
            [ -7.5312,  -4.7188,   4.1562,  ...,  -4.6250,  -4.5625,  -5.5000],
            [ -4.3438,  -0.9023,   2.0625,  ...,  -3.5312,  -4.0625,  -2.5469]]],
        device='cuda:0', grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)
    """
    return result

NameError: name 'torch' is not defined

In [None]:
LlavaLlamaForCausalLM.forward = forward

NameError: name 'forward' is not defined

In [None]:
inputs = {
    "input_ids": batch['input_ids'].to(device=model.device),
    "attention_mask": batch['attention_mask'].to(device=model.device),
    "labels": batch['labels'].to(device=model.device),
    "images": batch['images'].to(device=model.device),
    "position_ids": None,
    "past_key_values": None,
    "image_sizes": None,
}

In [None]:
outputs = model(**inputs)

current file path llava/llava/model/language_model/llava_llama.py
def LlavaLlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, images, image_sizes, return_dict)
input_ids
 tensor([[    1,  -200,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
input_ids.shape
 torch.Size([1, 24])
attention_mask
 tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True]])
position_ids
 None
past_key_values
 None
inputs_embeds
 None
labels
 tensor([[ -100,  -100,   278, 25616, 26624,   297,   902, 19430, 11105, 29879,
         10508,  1596, 23425,   278,  3700,   322,  6567,   310,   263,  6114,
           411,  2654, 11315,    13]])
use_cache
 None
output_attentions
 N

In [None]:
logs, mapping = run_and_capture(
    model.forward,
    **inputs
)

with open("print_factory/original_code.py") as f:
    code = f.read()

new_code = embed_print_outputs(code, mapping)

with open("print_factory/after_code.py", "w") as f:
    f.write(new_code)

print("done")

ModuleNotFoundError: No module named 'print_factory.print_factory'; 'print_factory' is not a package

In [None]:
def maybe_zero_3(param, ignore_status=False, name=None):

    print("current file path", "llava/train/llava_trainer.py")
    print("def maybe_zero_3(param, ignore_status=False, name=None)")
    print("param maybe_zero_3\n", param)
    print("ignore_status maybe_zero_3\n", ignore_status)
    print("name maybe_zero_3\n", name)
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    print(f"【COND】 hasattr_ds_id={hasattr(param, 'ds_id')}")
    if hasattr(param, "ds_id"):
        print("【ENTER】if hasattr(param, 'ds_id'):")
        print(f"【COND】 ds_status={getattr(param, 'ds_status', None)}, ignore_status={ignore_status}")
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            print("【ENTER】if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:")
            print(f"【COND】 ignore_status={ignore_status}")
            if not ignore_status:
                print("【ENTER】if not ignore_status:")
                print(name, 'no ignore status')
                print("【EXIT】if not ignore_status:")
            print("【EXIT】if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
            print("param (after GatheredParameters)\n", param)
        print("【EXIT】if hasattr(param, 'ds_id'):")
    else:
        print("【ENTER】else (not hasattr(param, 'ds_id')):")
        param = param.detach().cpu().clone()
        print("param (after else)\n", param)
        print("【EXIT】else (not hasattr(param, 'ds_id')):")
    print("param (def maybe_zero_3 at llava_trainer.py return)\n", param)
    return param

In [None]:
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):

    print("current file path", "llava/train/llava_trainer.py")
    print("def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match)")
    print("named_params get_mm_adapter_state_maybe_zero_3\n", named_params)
    print("keys_to_match get_mm_adapter_state_maybe_zero_3\n", keys_to_match)
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    print("to_return def get_mm_adapter_state_maybe_zero_3 \n", to_return)
    return to_return

In [None]:
def _save_checkpoint(self, model, trial, metrics=None):

    print("current file path", "llava/train/llava_trainer.py")
    print("def _save_checkpoint(self, model, trial, metrics=None)")
    print("self\n", self) # <llava.train.llava_trainer.LLaVATrainer object at 0x7ed6341f4490>
    print("model\n", model)

    print("trial\n", trial) # None
    print("metrics\n", metrics) # None
    print(f"【COND】 tune_mm_mlp_adapter={getattr(self.args, 'tune_mm_mlp_adapter', False)}") # True
    if getattr(self.args, 'tune_mm_mlp_adapter', False):
        # 【ENTER】
        print("【ENTER】if getattr(self.args, 'tune_mm_mlp_adapter', False):")
        from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
        print("checkpoint_folder = f\"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}\"\n", checkpoint_folder)

        run_dir = self._get_output_dir(trial=trial)
        print("run_dir = self._get_output_dir(trial=trial)", run_dir)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        print("output_dir = os.path.join(run_dir, checkpoint_folder)", output_dir)

        # Only save Adapter
        keys_to_match = ['mm_projector', 'vision_resampler']
        print(f"【COND】 use_im_start_end={getattr(self.args, 'use_im_start_end', False)}") # False
        if getattr(self.args, "use_im_start_end", False):
            pass

        weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

        print(f"【COND】 local_rank={self.args.local_rank}") # 0
        if self.args.local_rank == 0 or self.args.local_rank == -1:
            # 【ENTER】
            print("【ENTER】if self.args.local_rank == 0 or self.args.local_rank == -1:")
            self.model.config.save_pretrained(output_dir)
            torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
            print("【EXIT】if self.args.local_rank == 0 or self.args.local_rank == -1:")
        print("【EXIT】if getattr(self.args, 'tune_mm_mlp_adapter', False):")
    else:
        pass

In [None]:
def train():

    print("current file path", "llava/train/train.py")
    print("def train()")
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments))
    print("original parser\n", parser)
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print("model_args\n", model_args)
    print("data_args\n", data_args)
    print("training_args\n", training_args)
    local_rank = training_args.local_rank
    print("local_rank\n", local_rank)
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    print("compute_dtype\n", compute_dtype)
    bnb_model_from_pretrained_args = {}
    print("bnb_model_from_pretrained_args\n", bnb_model_from_pretrained_args)
    # 【SKIP】bfloat16 なので 以下の if 文はスキップされる
    print(f"【COND】 bits={training_args.bits}")
    if training_args.bits in [4, 8]:
      pass

    print(f"【COND】 vision_tower={model_args.vision_tower}")
    # 【ENTER】 vision_tower=openai/clip-vit-large-patch14-336 なので、この分岐に入る
    if model_args.vision_tower is not None:
        print("【ENTER】if model_args.vision_tower is not None:")
        print(f"【COND】 mpt_in_model_name_or_path={'mpt' in model_args.model_name_or_path}")
        #【SKIP】model_args.model_name_or_path に mptは含まれていないので、この分岐はskipされる
        if 'mpt' in model_args.model_name_or_path:
          pass

        #【ENTER】 model_args.model_name_or_path に mptは含まれていないので、この分岐に入る
        else:
            print("【COND】 not_mpt_in_model_name_or_path={'mpt' not in model_args.model_name_or_path}")
            print("【ENTER】else of if 'mpt' in model_args.model_name_or_path:")
            # PreTrainedModel.from_pretrained
            model = LlavaLlamaForCausalLM.from_pretrained(
                model_args.model_name_or_path,
                cache_dir=training_args.cache_dir,
                **bnb_model_from_pretrained_args
            )
            print("model defined as LlavaLlamaForCausalLM \n", model)
            print("【EXIT】else of if 'mpt' in model_args.model_name_or_path:")
        print("【EXIT】if model_args.vision_tower is not None:")
    # 【SKIP】 vision_tower=clip-vit-large-patch14-336 なので、この分岐には入らない
    else:
      pass

    print(f"【COND】 freeze_backbone={model_args.freeze_backbone}")
    # 【SKIP】 freeze_backbone=False なので、この分岐はskipされる
    if model_args.freeze_backbone:
        pass

    # 【SKIP】 bfloat16 なので 以下の if 文はスキップされる
    print(f"【COND】 bits={training_args.bits}")
    if training_args.bits in [4, 8]:
      pass

    print(f"【COND】 gradient_checkpointing={training_args.gradient_checkpointing}")
    # 【ENTER】 gradient_checkpointing=True なので、この分岐に入る
    if training_args.gradient_checkpointing:
        print("【ENTER】if training_args.gradient_checkpointing:")
        print(f"【COND】 has_enable_input_require_grads={hasattr(model, 'enable_input_require_grads')}")
        # 【ENTER】 model に enable_input_require_grads メソッドがあるので、この分岐に入る
        if hasattr(model, "enable_input_require_grads"):
            print("【ENTER】if hasattr(model, 'enable_input_require_grads'):")
            # PreTrainedModel.enable_input_require_grads
            # 元々 全ての重みについて True
            model.enable_input_require_grads()
            print("【EXIT】if hasattr(model, 'enable_input_require_grads'):")
        # 【SKIP】 model に enable_input_require_grads メソッドがあるので、この分岐はskipされる
        else:
          pass

        print("【EXIT】if training_args.gradient_checkpointing:")

    print(f"【COND】 lora_enable={training_args.lora_enable}")
    # 【SKIP】 lora_enable=False なので、この分岐はskipされる
    if training_args.lora_enable:
      pass

    print(f"【COND】 mpt_in_model_name_or_path={'mpt' in model_args.model_name_or_path}")
    # 【SKIP】model_args.model_name_or_path に mptは含まれていないので、この分岐はskipされる
    if 'mpt' in model_args.model_name_or_path:
      pass

    #【ENTER】 model_args.model_name_or_path に mptは含まれていないので、この分岐に入る
    else:
        print("【COND】 not_mpt_in_model_name_or_path={'mpt' not in model_args.model_name_or_path}")
        print("【ENTER】else of if 'mpt' in model_args.model_name_or_path:")
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
        )
        print("tokenizer defined by AutoTokenizer.from_pretrained \n", tokenizer)
        print("【EXIT】else of if 'mpt' in model_args.model_name_or_path:")

    print(f"【COND】 version={model_args.version}")
    # 【SKIP】 version=plain なので、この分岐はskipされる
    if model_args.version == "v0":
      pass

    # 【SKIP】 version=plain なので、この分岐はskipされる
    elif model_args.version == "v0.5":
      pass
    # 【ENTER】 version=plain なので、この分岐に入る
    else:
        print("【ENTER】else of if model_args.version == 'v0' and elif 'v0.5':")
        tokenizer.pad_token = tokenizer.unk_token
        print(f"【COND】 version_in_conv_templates={model_args.version in conv_templates}")
        # 【ENTER】 model_args.version=plain は conversation_lib.conv_templates に含まれている（"plain": conv_llava_plain）ので、この分岐に入る
        if model_args.version in conv_templates:
            print("【ENTER】if model_args.version in conversation_lib.conv_templates:")
            default_conversation = conv_templates[model_args.version]
            print(f"conversation_lib.default_conversation set to {model_args.version}")
            print("【EXIT】if model_args.version in conversation_lib.conv_templates:")
        # 【SKIP】 model_args.version=plain は conversation_lib.conv_templates に含まれているので、この分岐はskipされる
        else:
          pass
        print("【EXIT】else of if model_args.version == 'v0' and elif 'v0.5':")

    print(f"【COND】 vision_tower={model_args.vision_tower}")
    # 【ENTER】 vision_tower=openai/clip-vit-large-patch14-336 なので、この分岐に入る
    if model_args.vision_tower is not None:
        print("【ENTER】if model_args.vision_tower is not None:")
        model.get_model().initialize_vision_modules(
            model_args=model_args,
            fsdp=training_args.fsdp
        )

        vision_tower = model.get_vision_tower()
        vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)

        data_args.image_processor = vision_tower.image_processor
        data_args.is_multimodal = True

        model.config.image_aspect_ratio = data_args.image_aspect_ratio
        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
        print(f"【COND】 tune_mm_mlp_adapter={model_args.tune_mm_mlp_adapter}") # True
        if model_args.tune_mm_mlp_adapter:
            # 【ENTER】 tune_mm_mlp_adapter=True なので、この分岐に入る
            print("【ENTER】if model_args.tune_mm_mlp_adapter:")
            # モデル全体の全パラメータを「学習不可（requires_grad=False）」にする
            # これで通常の重みは全て凍結される
            model.requires_grad_(False)
            for p in model.get_model().mm_projector.parameters():
                # mm_projector（画像特徴量→テキスト特徴量への変換層）の全パラメータだけを「学習可能（requires_grad=True）」に戻す
                # これで mm_projector のみ学習されることになる
                print("model.get_model().mm_projector.parameters()", model.get_model().mm_projector.parameters())
                p.requires_grad = True
            print("【EXIT】if model_args.tune_mm_mlp_adapter:")

        model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
        print(f"【COND】 freeze_mm_mlp_adapter={training_args.freeze_mm_mlp_adapter}") # False
        if training_args.freeze_mm_mlp_adapter:
          pass

        print(f"【COND】 bits={training_args.bits}") # 16
        if training_args.bits in [4, 8]:
          pass

        model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
        print("model_args.mm_use_im_start_end", model_args.mm_use_im_start_end)
        model.config.mm_projector_lr = training_args.mm_projector_lr
        print("training_args.mm_projector_lr", training_args.mm_projector_lr)
        training_args.use_im_start_end = model_args.mm_use_im_start_end
        print("training_args.use_im_start_end", training_args.use_im_start_end)
        model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
        print("model_args.mm_use_im_patch_token", model_args.mm_use_im_patch_token)
        model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
        print("【EXIT】if model_args.vision_tower is not None:")

    print(f"【COND】 bits={training_args.bits}") # 16
    if training_args.bits in [4, 8]:
        pass

    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args)
    print("data_module\n", data_module) # {'train_dataset': <llava.train.train.LazySupervisedDataset object at 0x7ed6341f4880>, 'eval_dataset': None, 'data_collator': DataCollatorForSupervisedDataset(tokenizer=LlamaTokenizer(name_or_path='lmsys/vicuna-7b-v1.5', vocab_size=32000, model_max_length=2048, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False), 'pad_token': '<unk>'}, clean_up_tokenization_spaces=False))}

    trainer = LLaVATrainer(model=model,
                    tokenizer=tokenizer,
                    args=training_args,
                    **data_module)
    print("trainer\n", trainer) # <llava.train.llava_trainer.LLaVATrainer object at 0x7ed6341f4490>

    print("【COND】list(pathlib.Path(training_args.output_dir).glob('checkpoint-*'))\n", list(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))) # [PosixPath('checkpoints/llava-v1.5-7b-pretrain/checkpoint-250'), PosixPath('checkpoints/llava-v1.5-7b-pretrain/checkpoint-1')]
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        # 【ENTER】
        print("【ENTER】if list(pathlib.Path(training_args.output_dir).glob(checkpoint-*)):")
        trainer.train(resume_from_checkpoint=False)
        print("【EXIT】if list(pathlib.Path(training_args.output_dir).glob(checkpoint-*)):")
    else:
        print("【ENTER】else of if list(pathlib.Path(training_args.output_dir).glob(checkpoint-*)):")
        trainer.train()
        print("【EXIT】else of if list(pathlib.Path(training_args.output_dir).glob(checkpoint-*)):")
    trainer.save_state()

    model.config.use_cache = True
    print("model.config.use_cache = True", model.config.use_cache) # True

    print(f"【COND】lora_enable={training_args.lora_enable}") # False
    if training_args.lora_enable:
      pass
    else:
        # 【ENTER】
        print("【ENTER】else of if training_args.lora_enable:")
        print("trainer", trainer) # <class 'llava.train.llava_trainer.LLaVATrainer'>
        safe_save_model_for_hf_trainer(trainer=trainer,
                                       output_dir=training_args.output_dir)
        print("【EXIT】else of if training_args.lora_enable:")