In [1]:
from transformers.models.mllama.modeling_mllama import MllamaVisionModel
from PIL import Image
from transformers import AutoProcessor
from astropy.table import Table
from tqdm.notebook import tqdm, trange
import os

In [2]:
# read all csv table
csv_data = Table.read("/mnt/data/CVPR2025/task1_data/test_no_classification.hdf5")
csv_data.keys()
print(csv_data["TARGETID"].shape, type(csv_data["TARGETID"]))

(21051,) <class 'astropy.table.column.Column'>


In [3]:
# load a vision model
checkpoint="/mnt/data/CVPR2025/task1_data/Llama-3.2-11B-Vision-Instruct"
model = MllamaVisionModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [4]:
model.cuda()
model.eval()

MllamaVisionModel(
  (patch_embedding): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), padding=valid, bias=False)
  (gated_positional_embedding): MllamaPrecomputedPositionEmbedding(
    (tile_embedding): Embedding(9, 8197120)
  )
  (pre_tile_positional_embedding): MllamaPrecomputedAspectRatioEmbedding(
    (embedding): Embedding(9, 5120)
  )
  (post_tile_positional_embedding): MllamaPrecomputedAspectRatioEmbedding(
    (embedding): Embedding(9, 5120)
  )
  (layernorm_pre): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (layernorm_post): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  (transformer): MllamaVisionEncoder(
    (layers): ModuleList(
      (0-31): 32 x MllamaVisionEncoderLayer(
        (self_attn): MllamaVisionSdpaAttention(
          (q_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=False)

In [11]:
image_root_dir = "/mnt/data/CVPR2025/task1_data/images/images"
all_image_paths = []
for id_ in tqdm(csv_data["TARGETID"]):
    img_path = os.path.join(image_root_dir, "{}.png".format(id_))
    all_image_paths.append(img_path)
    # inputs = processor(images=test_image, return_tensors="pt")
    # with torch.no_grad():
    #     for k in inputs:
    #         inputs[k] = inputs[k].cuda()
    #     output = model(**inputs)
    #     all_features.append(output[0].cpu().squeeze(0).squeeze(1))

  0%|          | 0/21051 [00:00<?, ?it/s]

In [13]:
all_image_paths[:4]

['/mnt/data/CVPR2025/task1_data/images/images/39632936139492141.png',
 '/mnt/data/CVPR2025/task1_data/images/images/39632936139492779.png',
 '/mnt/data/CVPR2025/task1_data/images/images/39632936143685968.png',
 '/mnt/data/CVPR2025/task1_data/images/images/39632936143686134.png']

In [16]:
import torch
# test image inputs
test_image_paths = all_image_paths[:8]
test_images = [Image.open(x) for x in test_image_paths]
inputs = processor(images=test_images, return_tensors="pt")
with torch.no_grad():
    for k in inputs:
        inputs[k] = inputs[k].cuda()
    output = model(**inputs)
print(output[0].shape)

torch.Size([1, 8, 4, 1601, 7680])


In [9]:
output[0].device

device(type='cuda', index=0)

In [19]:
import numpy as np
# 初始化存储所有特征的列表
all_image_features = []

# 设置批次大小
batch_size = 8

# 将模型转换为 bf16
model = model.bfloat16()

# 计算需要处理的批次数
num_batches = len(all_image_paths) // batch_size
if len(all_image_paths) % batch_size != 0:
    num_batches += 1

# 使用tqdm显示处理进度
for i in tqdm(range(num_batches)):
    # 获取当前批次的图像路径
    start_idx = i * batch_size
    end_idx = min((i + 1) * batch_size, len(all_image_paths))
    batch_image_paths = all_image_paths[start_idx:end_idx]
    
    # 读取并处理图像
    batch_images = [Image.open(x) for x in batch_image_paths]
    inputs = processor(images=batch_images, return_tensors="pt")
    
    # 将输入移到GPU
    with torch.no_grad():
        for k in inputs:
            inputs[k] = inputs[k].cuda()
        
        # 模型推理
        output = model(**inputs)
        
        # 将输出移到CPU并转换为numpy数组
        features = output[0].cpu().float().squeeze(0).numpy()
        all_image_features.append(features)

# 合并所有特征
all_image_features = np.concatenate(all_image_features, axis=1)

print("处理完成!")
print(f"特征形状: {all_image_features.shape}")

  0%|          | 5/2632 [00:10<1:34:50,  2.17s/it]