In [1]:
import warnings
warnings.filterwarnings("ignore")

from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, AutoImageProcessor, ResNetModel
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
import os
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"[INFO] Using device: {device}")

[INFO] Using device: cuda


### Adaptive Layer

In [2]:
class MultiStageAdapter(nn.Module):
    """
    Adapter that turns feature maps from ResNet stages into a sequence of embeddings for Qwen.

    Each stage is projected to ``out_dim`` channels with a 1x1 convolution and bilinearly resized to the
    spatial size of the last stage. The projected maps are concatenated along the channel axis and mixed by
    a small fusion network (1x1 conv -> GELU -> 1x1 conv). The fused map is flattened over its spatial
    dimensions, linearly interpolated to the requested sequence length, and finally layer-normalized.
    """

    def __init__(self, stage_channels=(256, 512, 1024, 2048), out_dim=2048, hidden_multiplier=2):
        """
        Constructor for MultiStageAdapter.

        Args:
            stage_channels (Sequence[int]): Channel dimension of each ResNet stage, in stage order.
                The default is an immutable tuple so it cannot be shared and mutated across instances;
                lists are still accepted.
            out_dim (int): Desired output dimension for Qwen embeddings.
            hidden_multiplier (int): Multiplier for the hidden dimension in the fusion network.
        """

        super().__init__()

        # Snapshot the caller's sequence so later mutation of it cannot affect this module
        stage_channels = tuple(stage_channels)

        # 1x1 convolutions to project each stage to out_dim
        self.projections = nn.ModuleList([
            nn.Conv2d(c, out_dim, kernel_size=1) for c in stage_channels
        ])

        # Fusion network, a small MLP with GELU activation between two linear layers implemented as 1x1 convolutions
        self.fusion = nn.Sequential(
            nn.Conv2d(out_dim * len(stage_channels), out_dim * hidden_multiplier, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(out_dim * hidden_multiplier, out_dim, kernel_size=1)
        )

        # Final Layer Normalization over the channel (embedding) dimension
        self.final_norm = nn.LayerNorm(out_dim)

    def forward(self, stage0, stage1, stage2, stage3, target_seq_len=196):
        """
        Forward pass for MultiStageAdapter.

        Args:
            stage0, stage1, stage2, stage3 (torch.Tensor): Feature maps from the 4 ResNet stages,
                each of shape (B, C_i, H_i, W_i).
            target_seq_len (int): Desired sequence length for the output embeddings.

        Returns:
            torch.Tensor: Layer-normalized embeddings of shape (B, target_seq_len, out_dim).
        """

        # The last stage defines the common spatial size every stage is resized to
        _, _, Ht, Wt = stage3.shape

        # Project each stage to out_dim and resize it to the size of the last stage.
        # zip() pairs stages with projections in order and stops at the shorter of the two.
        proj_feats = [
            F.interpolate(proj(feat), size=(Ht, Wt), mode='bilinear', align_corners=False)
            for feat, proj in zip((stage0, stage1, stage2, stage3), self.projections)
        ]

        # Concatenate along the channel dimension and apply fusion network
        fused = torch.cat(proj_feats, dim=1)  # (B, out_dim*num_stages, Ht, Wt)
        fused = self.fusion(fused)            # (B, out_dim, Ht, Wt)

        # Flatten spatial dimensions, interpolate to target sequence length and permute to (B, L, C)
        seq = fused.flatten(2)                # (B, out_dim, Ht*Wt)
        seq = F.interpolate(seq, size=target_seq_len, mode='linear', align_corners=False)
        seq = seq.permute(0, 2, 1)            # (B, L, C)

        # Apply final Layer Normalization
        return self.final_norm(seq)

### Composite model

In [3]:
class CompositeModel(nn.Module):
    """
    Composite model that integrates ResNet for image feature extraction and MultiStageAdapter
    to convert these features into embeddings suitable for Qwen.
    """

    def __init__(self):
        """
        Constructor for CompositeModel.
        Loads the pretrained "microsoft/resnet-50" backbone and creates a MultiStageAdapter
        with its default settings.
        """

        super().__init__()
        self.resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
        self.adapter = MultiStageAdapter()

    def forward(self, pixel_values, target_seq_len=196):
        """
        Forward pass for CompositeModel.
        It captures the output of every ResNet encoder stage with temporary forward hooks and
        passes the first four of them through the MultiStageAdapter.

        Args:
            pixel_values (torch.Tensor): Input image tensor.
            target_seq_len (int): Desired sequence length for the output embeddings.

        Returns:
            torch.Tensor: Output embeddings suitable for Qwen.

        Raises:
            KeyError: If the backbone did not produce outputs for stages 0 to 3.
        """

        # Fresh per-call storage, so concurrent or repeated calls never see stale features
        intermediate_outputs = {}

        def get_hook(idx):
            # Bind idx per stage; a bare closure over the loop variable would late-bind
            def hook(module, input, output):
                intermediate_outputs[f"stage_{idx}"] = output
            return hook

        hooks = []
        try:
            # Register one hook per encoder stage
            for idx, stage in enumerate(self.resnet.encoder.stages):
                hooks.append(stage.register_forward_hook(get_hook(idx)))

            # Forward pass through ResNet; only the hooked intermediate features are used
            self.resnet(pixel_values)
        finally:
            # Always detach the hooks, even if the forward pass raised, so they do not
            # accumulate on the backbone across calls
            for h in hooks:
                h.remove()

        # Extract features from each stage
        stage0, stage1, stage2, stage3 = (
            intermediate_outputs["stage_0"],
            intermediate_outputs["stage_1"],
            intermediate_outputs["stage_2"],
            intermediate_outputs["stage_3"],
        )

        # Pass the features through the MultiStageAdapter
        return self.adapter(stage0, stage1, stage2, stage3, target_seq_len)

### Custom visual encoder

In [4]:
class CustomQwenVisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
    """
    Visual encoder for Qwen whose forward pass is an identity function.

    The class inherits from Qwen2_5_VisionTransformerPretrainedModel and is constructed through
    the parent constructor, but ``forward`` is overridden to return its input unchanged, so
    whatever is supplied as ``hidden_states`` is handed back as the encoder output.
    """

    def __init__(self, output_dim, config, *inputs, **kwargs):
        """
        Constructor for CustomQwenVisualEncoder.

        Args:
            output_dim: Accepted for the caller's convenience; it is not used or stored here.
            config: Vision configuration forwarded to the parent constructor.
            *inputs, **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, *inputs, **kwargs)
        

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Forward pass for CustomQwenVisualEncoder.
        It acts as a pass-through, returning the input hidden_states unchanged.

        Args:
            hidden_states (torch.Tensor): Input tensor, returned as-is.
            grid_thw (torch.Tensor): Accepted for signature compatibility; not read by this method.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            torch.Tensor: The very same hidden_states object that was passed in.
        """

        # Identity: no patch embedding, attention blocks or merger are applied
        return hidden_states

### Models settings and setup

In [5]:
# Model names and paths
qwen_model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
resnet_model_name = "microsoft/resnet-50"
composite_model_state_path = os.path.join("..", "models", "layernorm_best_model.pth")


# Processor Qwen (only the tokenizer is kept; image features come from the composite model)
print("[INFO] Loading processors...")
tokenizer = AutoProcessor.from_pretrained(qwen_model_name, trust_remote_code=True).tokenizer
print("[INFO] Qwen tokenizer loaded.")


# Qwen Model
print("\n[INFO] Loading Qwen model...")
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    qwen_model_name,
    torch_dtype=torch.float16,
)
qwen_model.to(device)
print("[INFO] Qwen model loaded.")


# Composite model
composite_model = CompositeModel().to(device).half()
# map_location remaps the checkpoint tensors onto the selected device, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine
composite_model.load_state_dict(torch.load(composite_model_state_path, map_location=device))
# Inference only: make sure every submodule is in evaluation mode
composite_model.eval()


# Custom visual encoder
custom_visual_encoder = CustomQwenVisualEncoder(output_dim=qwen_model.config.hidden_size, config=qwen_model.config.vision_config)
custom_visual_encoder.eval()
qwen_model.model.visual = custom_visual_encoder
print("\n[INFO] Qwen visual encoder replaced with custom encoder.")


# Predefined inputs for Qwen model, placed on the same device as the models
qwen_inputs = torch.load(os.path.join("..", "assets", "qwen_inputs.pt"), map_location=device)

[INFO] Loading processors...


The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


[INFO] Qwen tokenizer loaded.

[INFO] Loading Qwen model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[INFO] Qwen model loaded.

[INFO] Qwen visual encoder replaced with custom encoder.


### Inference

In [6]:
# Load the test image as RGB, scale it so that its shorter side is 384 px,
# then cut the largest centred square out of it
image_path = os.path.join("..", "test_images", "satellite.jpg")
image = Image.open(image_path).convert("RGB")
W, H = image.size
scale = 384 / min(W, H)
new_W, new_H = int(W * scale), int(H * scale)
image = image.resize((new_W, new_H), resample=Image.BICUBIC)
W, H = image.size
crop_size = min(W, H)
left, top = (W - crop_size) // 2, (H - crop_size) // 2
image = image.crop((left, top, left + crop_size, top + crop_size))


# Turn the crop into a (1, 3, 384, 384) half-precision tensor on the target device
transform = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
resnet_inputs = transform(image).unsqueeze(0).to(device, torch.float16)

# Run ResNet + adapter without gradients and drop the batch dimension
with torch.no_grad():
    custom_pixel_values = composite_model(resnet_inputs).squeeze(0)


# Reuse the saved Qwen inputs, swapping in the embeddings computed above
qwen_inputs['pixel_values'] = custom_pixel_values


# Generate up to 128 new tokens
print("[INFO] Generating...")
with torch.no_grad():
    outputs = qwen_model.generate(**qwen_inputs, max_new_tokens=128)

# For every sequence, skip the first len(in_ids) positions so only newly generated ids remain
generated_ids_trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(qwen_inputs['input_ids'], outputs)
]
output_text = tokenizer.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print("\n--- Output ---")
print(output_text[0])

[INFO] Generating...

--- Output ---
The image depicts a large, rectangular structure that appears to be a solar panel array. The panels are arranged in rows and columns, forming a grid-like pattern. Each panel is covered with photovoltaic cells designed to convert sunlight into electricity. The panels are mounted on a metal frame, which is supported by multiple vertical poles. The background shows a clear sky with no visible clouds, indicating good weather conditions for solar energy production.

The solar panels are positioned at an angle to maximize their exposure to the sun's rays, which is typical for solar farms or large-scale solar power installations. The ground around the panels is relatively flat, suggesting that the
