In [1]:
import warnings
warnings.filterwarnings("ignore")

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer
import torch
from torchvision import transforms
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
import os

from models.MultiStage_FtResnet import CompositeModel
from util import preprocess_image

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INFO] Using device: {device}")

[INFO] Using device: cuda


### Custom visual encoder

In [2]:
class CustomQwenVisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
    """
    Custom Visual Encoder for Qwen that can be modified as needed.
    Replace the forward method with an empty pass-through logic.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Forward pass for CustomQwenVisualEncoder.
        Currently, it acts as a pass-through, returning the input hidden_states unchanged.

        Args:
            hidden_states (torch.Tensor): Input tensor (typically the preprocessed image from the processor).
            grid_thw (torch.Tensor): Information about the grid of patches (may be used for final dimension).

        Returns:
            torch.Tensor: Unchanged input hidden_states.
        """

        return hidden_states

### Models settings and setup

In [3]:
# Model name and paths
qwen_model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
qwen_processor_folder = os.path.join("..", "assets", "qwen_processor")


# Composite model state path
composite_model_state_path = os.path.join("..", "models", "layernorm_best_model.pth")
composite_model_state_dict = torch.load(composite_model_state_path, weights_only=False)
#composite_model_state_path = os.path.join("..", "models", "composite_model_checkpoints", "best_model_epoch_16_loss_2.6507.pth")
#composite_model_state_dict = torch.load(composite_model_state_path, weights_only=False)['model_state_dict']


# Processor Qwen
print("[INFO] Loading Qwen tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(qwen_processor_folder, trust_remote_code=True)
print("[INFO] Qwen tokenizer loaded.")


# Qwen Model
print("\n[INFO] Loading Qwen model...")
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_model_name,
    torch_dtype=torch.float16,
)
qwen_model.to(device)
print("[INFO] Qwen model loaded.")


# Composite model
composite_model = CompositeModel().to(device).half()
composite_model.load_state_dict(composite_model_state_dict)


# Custom visual encoder
custom_visual_encoder = CustomQwenVisualEncoder(config=qwen_model.config.vision_config)
custom_visual_encoder.eval()
qwen_model.model.visual = custom_visual_encoder
print("\n[INFO] Qwen visual encoder replaced with custom encoder.")


# Predefined inputs for Qwen model
qwen_inputs = torch.load(os.path.join("..", "assets", "qwen_inputs.pt"))

[INFO] Loading Qwen tokenizer...
[INFO] Qwen tokenizer loaded.

[INFO] Loading Qwen model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[INFO] Qwen model loaded.

[INFO] Qwen visual encoder replaced with custom encoder.


### Inference

In [4]:
# Image setup
image_size = 384
image_path = os.path.join("..", "test_images", "satellite.jpg")
image = preprocess_image(image_path, image_size)


# Custom pixel values
transform = transforms.Compose([transforms.ToTensor(),])
resnet_inputs = transform(image).unsqueeze(0).to(device, torch.float16)

with torch.no_grad():
    custom_pixel_values = composite_model(resnet_inputs)
    custom_pixel_values = custom_pixel_values.squeeze(0)


# Generate inputs for Qwen from saved file and replace pixel_values
qwen_inputs['pixel_values'] = custom_pixel_values


# Inference
print("[INFO] Generating...")
with torch.no_grad():
    outputs = qwen_model.generate(
        **qwen_inputs,
        max_new_tokens=128,
    )

generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(qwen_inputs['input_ids'], outputs)
]
output_text = tokenizer.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("\n--- Output ---")
print(output_text[0])

[INFO] Generating...

--- Output ---
The image depicts a large, rectangular structure that appears to be a solar panel array. The panels are arranged in rows and columns, forming a grid-like pattern. Each panel is covered with photovoltaic cells designed to convert sunlight into electricity. The panels are mounted on a metal frame, which is supported by multiple vertical poles. The background shows a clear sky with no visible clouds, indicating good weather conditions for solar energy production.

The solar panels are positioned at an angle to maximize their exposure to the sun's rays, which is typical for solar farms or large-scale solar power installations. The ground around the panels is relatively flat, suggesting that the
