In [None]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from torch.nn import functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from scipy.ndimage import gaussian_filter

class CombinedTransformerVisualizer:
    def __init__(self, model, processor, device='cuda'):
        self.model = model
        self.processor = processor
        self.device = device
        self.gradients = None
        self.activations = None
        self.attention_maps = []
        
        # Target the output of the last transformer block
        self.target_layer = self.model.vision_model.encoder.layers[-1]
        
        def forward_hook(module, input, output):
            self.activations = output[0]
            
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        # Register hooks for GradCAM
        self.forward_handle = self.target_layer.register_forward_hook(forward_hook)
        self.backward_handle = self.target_layer.register_full_backward_hook(backward_hook)
        
        # Register hooks for attention maps on all encoder layers
        self.attention_hooks = []
        for layer in self.model.vision_model.encoder.layers:
            hook = layer.self_attn.register_forward_hook(
                lambda module, input, output: self._attention_hook(module, output)
            )
            self.attention_hooks.append(hook)

    def _attention_hook(self, module, output):
        """Hook to capture attention maps during forward pass."""
        if isinstance(output, tuple) and len(output) > 1:
            attention_weights = output[1]  # Usually, the attention weights are in the second element
            
            if attention_weights is not None:
                print("Captured attention weights:", attention_weights.shape)  # Debug print for attention weights
                self.attention_maps.append(attention_weights.detach())  # Save the attention weights
            else:
                print("Warning: Attention weights are None.")  # If no weights, print a warning
        else:
            print("Unexpected output format from attention layer:", output)

    def _perform_attention_rollout(self):
        """Perform attention rollout across all layers."""
        if not self.attention_maps:
            raise ValueError("No attention maps captured. Check if the forward pass was successful.")
            
        # Average attention heads per layer
        averaged_attentions = [attn.mean(dim=1) for attn in self.attention_maps]
        
        # Start with identity matrix
        batch_size, seq_len, _ = averaged_attentions[0].shape
        accumulated = torch.eye(seq_len).unsqueeze(0).to(self.device)
        accumulated = accumulated.repeat(batch_size, 1, 1)
        
        # Accumulate attention through layers
        for attn in averaged_attentions:
            accumulated = torch.bmm(attn, accumulated)
        
        # Get attention for tokens (exclude CLS token)
        rollout = accumulated[:, 0, 1:]
        
        return rollout

    def _process_attention_map(self, attention_map, image_size):
        """Process attention map for visualization."""
        attn = attention_map.cpu().numpy()
        grid_size = int(np.sqrt(attn.shape[-1]))
        attn = attn.reshape(grid_size, grid_size)
        attn = cv2.resize(attn, (image_size[0], image_size[1]))
        attn = gaussian_filter(attn, sigma=2)
        attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
        return attn

    def apply_threshold(self, cam, threshold=0.2):
        """Apply threshold to focus on high attention regions."""
        cam[cam < threshold] = 0
        return cam

    def generate_visualizations(self, image_path, save_path=None):
        """Generate both GradCAM and attention visualizations."""
        # Load and process image
        image = Image.open(image_path).convert('RGB')
        original_size = image.size
        image_resized = image.resize((384, 384), Image.Resampling.LANCZOS)
        inputs = self.processor(images=image_resized, return_tensors="pt").to(self.device)
        
        # Clear previous attention maps
        self.attention_maps = []
        
        # Forward pass for attention maps
        with torch.no_grad():
            outputs = self.model.vision_model(
                inputs['pixel_values'],
                output_attentions=True,  # Ensure this is set to True
                return_dict=True
            )
        
        # Get attention rollout
        rollout = self._perform_attention_rollout()
        attention_map = self._process_attention_map(rollout[0], original_size)
        
        # Clear gradients and perform forward/backward pass for GradCAM
        self.model.zero_grad()
        outputs = self.model.vision_model(**inputs)
        target = outputs.last_hidden_state.mean(dim=1).sum()
        target.backward()
        
        if self.gradients is None or self.activations is None:
            print("Error: Gradients or activations are None.")
            return
        
        # Calculate GradCAM
        pooled_gradients = torch.mean(self.gradients, dim=1)
        cam = torch.zeros(self.activations.shape[1], dtype=self.activations.dtype).to(self.device)
        
        for i in range(1, self.activations.shape[1]):
            cam[i] = torch.sum(pooled_gradients[0] * self.activations[0, i])
        
        # Process GradCAM
        cam = F.relu(cam)
        cam = cam.detach().cpu().numpy()
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        cam = self.apply_threshold(cam)
        
        # Reshape and resize GradCAM
        grid_size = int(np.sqrt(len(cam) - 1))
        cam_reshaped = cam[1:].reshape(grid_size, grid_size)
        cam_resized = cv2.resize(cam_reshaped, original_size)
        
        # Create attention overlay
        image_array = np.array(image)
        attention_heatmap = cv2.applyColorMap(np.uint8(attention_map * 255), cv2.COLORMAP_JET)
        attention_overlay = cv2.addWeighted(image_array, 0.7, attention_heatmap, 0.3, 0)
        
        # Create GradCAM overlay
        gradcam_heatmap = cv2.applyColorMap(np.uint8(cam_resized * 255), cv2.COLORMAP_JET)
        gradcam_overlay = cv2.addWeighted(image_array, 0.7, gradcam_heatmap, 0.3, 0)
        
        # Create GradCAM on Attention Overlay
        gradcam_on_attention = cv2.addWeighted(attention_overlay, 0.5, gradcam_overlay, 0.5, 0)
        
        # Visualize results
        plt.figure(figsize=(25, 5))
        
        plt.subplot(1, 6, 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')
        
        plt.subplot(1, 6, 2)
        plt.imshow(cam_resized, cmap='jet')
        plt.title('GradCAM Heatmap')
        plt.axis('off')
        
        plt.subplot(1, 6, 3)
        plt.imshow(attention_map, cmap='jet')
        plt.title('Attention Map')
        plt.axis('off')
        
        plt.subplot(1, 6, 4)
        plt.imshow(cv2.cvtColor(gradcam_overlay, cv2.COLOR_BGR2RGB))
        plt.title('GradCAM Overlay')
        plt.axis('off')
        
        plt.subplot(1, 6, 5)
        plt.imshow(cv2.cvtColor(attention_overlay, cv2.COLOR_BGR2RGB))
        plt.title('Attention Overlay')
        plt.axis('off')
        
        plt.subplot(1, 6, 6)
        plt.imshow(cv2.cvtColor(gradcam_on_attention, cv2.COLOR_BGR2RGB))
        plt.title('GradCAM on Attention Overlay')
        plt.axis('off')
        
        if save_path:
            plt.savefig(save_path)
            plt.close()
        
        return {
            'gradcam': cam_resized,
            'attention_map': attention_map,
            'gradcam_overlay': gradcam_overlay,
            'attention_overlay': attention_overlay,
            'gradcam_on_attention': gradcam_on_attention
        }

    def __del__(self):
        """Clean up hooks."""
        self.forward_handle.remove()
        self.backward_handle.remove()
        for hook in self.attention_hooks:
            hook.remove()


def analyze_image(model_path, image_path, device='cuda'):
    try:
        model = BlipForConditionalGeneration.from_pretrained(model_path)
        processor = BlipProcessor.from_pretrained(model_path)
        model.to(device)
        
        visualizer = CombinedTransformerVisualizer(model, processor, device)
        results = visualizer.generate_visualizations(
            image_path,
            save_path='combined_visualization.png'
        )
        print("Visualization saved as 'combined_visualization.png'")
        return results
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        raise


if __name__ == "__main__":
    model_path = "./blip_radiology_finetuned_best"  # Ensure the model path is correct
    image_path = "radiology_images/image_552.jpg"  # Ensure the image path is correct
    analyze_image(model_path, image_path)


In [1]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from torch.nn import functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from scipy.ndimage import gaussian_filter

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class CombinedTransformerVisualizer:
    def __init__(self, model, processor, device='cuda'):
        self.model = model
        self.processor = processor
        self.device = device
        self.gradients = None
        self.activations = None
        self.attention_maps = []
        
        # Target the output of the last transformer block
        self.target_layer = self.model.vision_model.encoder.layers[-1]
        
        def forward_hook(module, input, output):
            self.activations = output[0]
            
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        # Register hooks for GradCAM
        self.forward_handle = self.target_layer.register_forward_hook(forward_hook)
        self.backward_handle = self.target_layer.register_full_backward_hook(backward_hook)
        
        # Register hooks for attention maps on all encoder layers
        self.attention_hooks = []
        for layer in self.model.vision_model.encoder.layers:
            hook = layer.self_attn.register_forward_hook(
                lambda module, input, output: self._attention_hook(module, output)
            )
            self.attention_hooks.append(hook)

    def _attention_hook(self, module, output):
        """Hook to capture attention maps during forward pass."""
        if isinstance(output, tuple) and len(output) > 1:
            attention_weights = output[1]  # Usually, the attention weights are in the second element
            
            if attention_weights is not None:
                print("Captured attention weights:", attention_weights.shape)  # Debug print for attention weights
                self.attention_maps.append(attention_weights.detach())  # Save the attention weights
            else:
                print("Warning: Attention weights are None.")  # If no weights, print a warning
        else:
            print("Unexpected output format from attention layer:", output)

    def _perform_attention_rollout(self):
        """Perform attention rollout across all layers."""
        if not self.attention_maps:
            raise ValueError("No attention maps captured. Check if the forward pass was successful.")
            
        # Average attention heads per layer
        averaged_attentions = [attn.mean(dim=1) for attn in self.attention_maps]
        
        # Start with identity matrix
        batch_size, seq_len, _ = averaged_attentions[0].shape
        accumulated = torch.eye(seq_len).unsqueeze(0).to(self.device)
        accumulated = accumulated.repeat(batch_size, 1, 1)
        
        # Accumulate attention through layers
        for attn in averaged_attentions:
            accumulated = torch.bmm(attn, accumulated)
        
        # Get attention for tokens (exclude CLS token)
        rollout = accumulated[:, 0, 1:]
        
        return rollout

    def _process_attention_map(self, attention_map, image_size):
        """Process attention map for visualization."""
        attn = attention_map.cpu().numpy()
        grid_size = int(np.sqrt(attn.shape[-1]))
        attn = attn.reshape(grid_size, grid_size)
        attn = cv2.resize(attn, (image_size[0], image_size[1]))
        attn = gaussian_filter(attn, sigma=2)
        attn = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
        return attn

    def apply_threshold(self, cam, threshold=0.2):
        """Apply threshold to focus on high attention regions."""
        cam[cam < threshold] = 0
        return cam

    def generate_visualizations(self, image_path, save_path=None):
        """Generate both GradCAM and attention visualizations."""
        # Load and process image
        image = Image.open(image_path).convert('RGB')
        original_size = image.size
        image_resized = image.resize((384, 384), Image.Resampling.LANCZOS)
        inputs = self.processor(images=image_resized, return_tensors="pt").to(self.device)
        
        # Clear previous attention maps
        self.attention_maps = []
        
        # Forward pass for attention maps
        with torch.no_grad():
            outputs = self.model.vision_model(
                inputs['pixel_values'],
                output_attentions=True,  # Ensure this is set to True
                return_dict=True
            )
        
        # Get attention rollout
        rollout = self._perform_attention_rollout()
        attention_map = self._process_attention_map(rollout[0], original_size)
        
        # Clear gradients and perform forward/backward pass for GradCAM
        self.model.zero_grad()
        outputs = self.model.vision_model(**inputs)
        target = outputs.last_hidden_state.mean(dim=1).sum()
        target.backward()
        
        if self.gradients is None or self.activations is None:
            print("Error: Gradients or activations are None.")
            return
        
        # Calculate GradCAM
        pooled_gradients = torch.mean(self.gradients, dim=1)
        cam = torch.zeros(self.activations.shape[1], dtype=self.activations.dtype).to(self.device)
        
        for i in range(1, self.activations.shape[1]):
            cam[i] = torch.sum(pooled_gradients[0] * self.activations[0, i])
        
        # Process GradCAM
        cam = F.relu(cam)
        cam = cam.detach().cpu().numpy()
        cam = (cam - np.min(cam)) / (np.max(cam) - np.min(cam) + 1e-8)
        cam = self.apply_threshold(cam)
        
        # Reshape and resize GradCAM
        grid_size = int(np.sqrt(len(cam) - 1))
        cam_reshaped = cam[1:].reshape(grid_size, grid_size)
        cam_resized = cv2.resize(cam_reshaped, original_size)
        
        # Create attention overlay
        image_array = np.array(image)
        attention_heatmap = cv2.applyColorMap(np.uint8(attention_map * 255), cv2.COLORMAP_JET)
        attention_overlay = cv2.addWeighted(image_array, 0.7, attention_heatmap, 0.3, 0)
        
        # Create GradCAM overlay
        gradcam_heatmap = cv2.applyColorMap(np.uint8(cam_resized * 255), cv2.COLORMAP_JET)
        gradcam_overlay = cv2.addWeighted(image_array, 0.7, gradcam_heatmap, 0.3, 0)
        
        # Create GradCAM on Attention Overlay
        gradcam_on_attention = cv2.addWeighted(attention_overlay, 0.5, gradcam_overlay, 0.5, 0)
        
        # Visualize results
        plt.figure(figsize=(25, 5))
        
        plt.subplot(1, 6, 1)
        plt.imshow(image)
        plt.title('Original Image')
        plt.axis('off')
        
        plt.subplot(1, 6, 2)
        plt.imshow(cam_resized, cmap='jet')
        plt.title('GradCAM Heatmap')
        plt.axis('off')
        
        plt.subplot(1, 6, 3)
        plt.imshow(attention_map, cmap='jet')
        plt.title('Attention Map')
        plt.axis('off')
        
        plt.subplot(1, 6, 4)
        plt.imshow(cv2.cvtColor(gradcam_overlay, cv2.COLOR_BGR2RGB))
        plt.title('GradCAM Overlay')
        plt.axis('off')
        
        plt.subplot(1, 6, 5)
        plt.imshow(cv2.cvtColor(attention_overlay, cv2.COLOR_BGR2RGB))
        plt.title('Attention Overlay')
        plt.axis('off')
        
        plt.subplot(1, 6, 6)
        plt.imshow(cv2.cvtColor(gradcam_on_attention, cv2.COLOR_BGR2RGB))
        plt.title('GradCAM on Attention Overlay')
        plt.axis('off')
        
        if save_path:
            plt.savefig(save_path)
            plt.close()
        
        return {
            'gradcam': cam_resized,
            'attention_map': attention_map,
            'gradcam_overlay': gradcam_overlay,
            'attention_overlay': attention_overlay,
            'gradcam_on_attention': gradcam_on_attention
        }

    def __del__(self):
        """Clean up hooks."""
        self.forward_handle.remove()
        self.backward_handle.remove()
        for hook in self.attention_hooks:
            hook.remove()


In [4]:
def analyze_image(model_path, image_path, device='cuda'):
    try:
        model = BlipForConditionalGeneration.from_pretrained(model_path)
        processor = BlipProcessor.from_pretrained(model_path)
        model.to(device)
        
        visualizer = CombinedTransformerVisualizer(model, processor, device)
        results = visualizer.generate_visualizations(
            image_path,
            save_path='combined_visualization.png'
        )
        print("Visualization saved as 'combined_visualization.png'")
        return results
    except Exception as e:
        print(f"Error analyzing image: {str(e)}")
        raise

In [5]:
if __name__ == "__main__":
    model_path = "../train/models/blip_radiology_finetuned"  
    image_path = "../train/radiology_images/image_1294.jpg"  
    analyze_image(model_path, image_path)

Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured attention weights: torch.Size([1, 16, 577, 577])
Captured atten