In [11]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CocoDetection
import matplotlib.pyplot as plt
import torchvision.models as models

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [12]:
import torch
from torch import nn
import torchvision.models as models

class Backbone(nn.Module):
    def __init__(self, name="resnet50", pretrained=False, num_channels=256):
        super(Backbone, self).__init__()
        # Load a ResNet backbone
        resnet = getattr(models, name)(pretrained=pretrained)
        # Use layers until layer4 as the backbone
        self.body = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )
        # Reduce dimensionality from 2048 to the specified number of channels (256)
        self.conv = nn.Conv2d(2048, num_channels, kernel_size=1)

    def forward(self, x):
        features = self.body(x)  # Extract features from ResNet
        print(features.shape)
        features = self.conv(features)  # Reduce channels to 256
        print(features.shape)
        return features


# Instantiate the backbone
backbone = Backbone()
input_tensor = torch.randn(1, 3, 224, 224)  # Simulated input
features = backbone(input_tensor)
print("Backbone output shape:", features.shape)  # Expected: (1, 256, 7, 7)



torch.Size([1, 2048, 7, 7])
torch.Size([1, 256, 7, 7])
Backbone output shape: torch.Size([1, 256, 7, 7])


In [13]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, num_channels=256, height=50, width=50):
        super(PositionalEncoding, self).__init__()
        self.height = height
        self.width = width
        self.num_channels = num_channels

        # Create positional encodings
        pe = torch.zeros(num_channels, height, width)
        y_pos = torch.arange(0, height, dtype=torch.float32).unsqueeze(1).repeat(1, width)
        x_pos = torch.arange(0, width, dtype=torch.float32).unsqueeze(0).repeat(height, 1)

        div_term = torch.exp(torch.arange(0, num_channels, 2).float() * (-math.log(10000.0) / num_channels))
        pe[0::2, :, :] = torch.sin(y_pos.unsqueeze(0) * div_term[:, None, None])
        pe[1::2, :, :] = torch.cos(x_pos.unsqueeze(0) * div_term[:, None, None])

        self.register_buffer("pe", pe)

    def forward(self, x):
        """
        x: Input feature map of shape (batch_size, num_channels, height, width)
        """
        batch_size, num_channels, height, width = x.size()
        assert num_channels == self.num_channels, "Feature map channels must match positional encoding channels."
        assert height <= self.height and width <= self.width, "Feature map spatial size exceeds positional encoding."

        return x + self.pe[:, :height, :width]

pos_enc = PositionalEncoding()

# Test the Transformer
features = torch.randn(1, 256, 7, 7)  # Backbone feature map
pos_encoded_features = pos_enc(features)  # Add positional encoding

In [30]:
class Transformer(nn.Module):
    def __init__(self, d_model=256, num_queries=100, nhead=8, num_encoder_layers=6, num_decoder_layers=6):
        super(Transformer, self).__init__()
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead), num_layers=num_encoder_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead), num_layers=num_decoder_layers
        )
        self.query_embed = nn.Embedding(num_queries, d_model)  # Learnable queries (100 queries for 100 objects)

    def forward(self, src, mask=None):
        """
        src: Feature map from the backbone (flattened)
        mask: Optional mask for padding
        """
        feature_map = self.encoder(src, src_key_padding_mask=mask)  # Transformer encoder output. Feature map. 
        print(feature_map.shape)
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, src.size(1), 1)  # Shape: (num_queries, batch, d_model)
        output = self.decoder(queries, feature_map)  # Decoder maps queries to final outputs
        return output

transformer = Transformer()
pos_encoded_features = torch.randn(1, 256, 50, 50)
flattened_features = pos_encoded_features.flatten(2).permute(2, 0, 1)  # (H*W, Batch, Channels)

output = transformer(flattened_features)  # Transformer output
print("Transformer output shape:", output.shape)  # (num_queries, batch, d_model)

torch.Size([2500, 1, 256])
Transformer output shape: torch.Size([100, 1, 256])


In [26]:
class DetectionHead(nn.Module):
    def __init__(self, d_model=256, num_classes=91):
        super(DetectionHead, self).__init__()
        
        # Classification head: Predicts class probabilities for each query
        self.class_head = nn.Linear(d_model, num_classes)
        
        # Bounding box head: Predicts [cx, cy, w, h] for each query
        self.bbox_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 4)  # Final output: 4 bounding box coordinates
        )
    
    def forward(self, decoder_outputs):
        """
        decoder_outputs: Tensor of shape (num_queries, batch_size, d_model)
        """
        # Classification predictions
        class_logits = self.class_head(decoder_outputs)  # (num_queries, batch_size, num_classes)
        
        # Bounding box predictions
        bbox_coords = self.bbox_head(decoder_outputs)   # (num_queries, batch_size, 4)
        bbox_coords = bbox_coords.sigmoid()  # Normalize box coordinates to [0, 1]
        
        return class_logits, bbox_coords

In [28]:
import torch
import torch.nn as nn

class DETR(nn.Module):
    def __init__(self, num_classes=91, num_queries=100, hidden_dim=256):
        """
        DETR Model integrating:
        - Backbone for feature extraction.
        - Positional Encoding.
        - Transformer Encoder-Decoder.
        - Detection Head for bounding box and class prediction.

        Args:
            num_classes (int): Number of object classes.
            num_queries (int): Number of object queries.
            hidden_dim (int): Transformer embedding dimension.
        """
        super(DETR, self).__init__()

        # Use the previously defined Backbone class
        self.backbone = Backbone(num_channels=hidden_dim)

        # Positional Encoding
        self.positional_encoding = PositionalEncoding(hidden_dim)

        # Transformer Encoder-Decoder
        self.transformer = Transformer(d_model=hidden_dim, num_queries=num_queries)

        # Detection Head (Classification + Bounding Box Regression)
        self.detection_head = DetectionHead(hidden_dim, num_classes)

        # Object Queries: Learnable embeddings
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

    def forward(self, images):
        """
        Forward pass through DETR:
        1. Extract features using CNN backbone.
        2. Add positional encoding.
        3. Process features through Transformer encoder-decoder.
        4. Decode predictions into class probabilities and bounding boxes.

        Args:
            images (Tensor): Batch of images (B, C, H, W).

        Returns:
            dict: {'pred_logits': class probabilities, 'pred_boxes': bounding boxes}
        """
        # Extract features using Backbone
        features = self.backbone(images)  # Shape: (B, 256, H', W')

        # Add positional encoding
        pos_encoded_features = self.positional_encoding(features)

        # Flatten feature map for Transformer
        batch_size, _, height, width = features.shape
        src = pos_encoded_features.flatten(2).permute(2, 0, 1)  # Shape: (H'*W', B, hidden_dim)

        # Create object queries
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)  # Shape: (num_queries, B, hidden_dim)

        # Pass through Transformer
        transformer_output = self.transformer(src, queries)  # Shape: (num_queries, B, hidden_dim)

        # Convert Transformer output into predictions
        class_logits, bbox_coords = self.detection_head(transformer_output)  # (num_queries, B, num_classes), (num_queries, B, 4)

        return {"pred_logits": class_logits.permute(1, 0, 2), "pred_boxes": bbox_coords.permute(1, 0, 2)}

# Instantiate the model
model = DETR(num_classes=91, num_queries=100)
# print(model)

In [16]:
# Define the model
d_model = 256
num_classes = 91
num_queries = 100

# Instantiate the Transformer and Detection Heads
transformer = Transformer(d_model=d_model)
detection_head = DetectionHead(d_model=d_model, num_classes=num_classes)

# Simulated input
features = torch.randn(1, 256, 50, 50)  # Backbone feature map
pos_encoded_features = pos_enc(features)
flattened_features = pos_encoded_features.flatten(2).permute(2, 0, 1)  # (H*W, batch, d_model)

# Pass through Transformer
decoder_output = transformer(flattened_features)  # (num_queries, batch, d_model)

# Pass through Detection Head
class_logits, bbox_coords = detection_head(decoder_output)

print("Class logits shape:", class_logits.shape)  # (num_queries, batch, num_classes)
print("Bounding box coordinates shape:", bbox_coords.shape)  # (num_queries, batch, 4)

Class logits shape: torch.Size([100, 1, 91])
Bounding box coordinates shape: torch.Size([100, 1, 4])


In [17]:
# Function to convert bounding boxes from [cx, cy, w, h] to [x_min, y_min, x_max, y_max]
def box_cxcywh_to_xyxy(box):
    x_c, y_c, w, h = box.unbind(-1)
    x_min = x_c - 0.5 * w
    y_min = y_c - 0.5 * h
    x_max = x_c + 0.5 * w
    y_max = y_c + 0.5 * h
    return torch.stack([x_min, y_min, x_max, y_max], dim=-1)

def generalized_iou(pred_boxes, target_boxes):
    """
    Compute Generalized IoU between predicted and target boxes.
    """
    # Convert [cx, cy, w, h] to [x_min, y_min, x_max, y_max]
    pred_boxes = box_cxcywh_to_xyxy(pred_boxes)
    target_boxes = box_cxcywh_to_xyxy(target_boxes)

    # Compute areas
    pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
    target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])

    # Compute intersection
    inter_min = torch.max(pred_boxes[:, None, :2], target_boxes[:, :2])  # (num_queries, num_targets, 2)
    inter_max = torch.min(pred_boxes[:, None, 2:], target_boxes[:, 2:])  # (num_queries, num_targets, 2)
    inter_wh = (inter_max - inter_min).clamp(min=0)  # (num_queries, num_targets, 2)
    inter_area = inter_wh[:, :, 0] * inter_wh[:, :, 1]

    # Compute union
    union_area = pred_area[:, None] + target_area - inter_area

    # IoU
    iou = inter_area / union_area

    # Compute enclosing box
    enclose_min = torch.min(pred_boxes[:, None, :2], target_boxes[:, :2])
    enclose_max = torch.max(pred_boxes[:, None, 2:], target_boxes[:, 2:])
    enclose_wh = (enclose_max - enclose_min).clamp(min=0)
    enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]

    # Generalized IoU
    giou = iou - (enclose_area - union_area) / enclose_area
    return giou

# Generate random predicted and ground truth boxes
pred_boxes = torch.rand(5, 4)
gt_boxes = torch.rand(5, 4)

# Convert to IoU-compatible format
pred_boxes_xyxy = box_cxcywh_to_xyxy(pred_boxes)
gt_boxes_xyxy = box_cxcywh_to_xyxy(gt_boxes)

# Compute Generalized IoU
giou = generalized_iou(pred_boxes_xyxy, gt_boxes_xyxy)

print("\nGeneralized IoU matrix between predictions and ground truth:")
print(giou)


Generalized IoU matrix between predictions and ground truth:
tensor([[-0.5531, -0.0137, -0.5323, -0.3554, -0.0014],
        [ 0.2663, -0.2661, -0.1653, -0.5216, -0.4692],
        [-0.1817,  0.0229, -0.1699, -0.4190, -0.2534],
        [ 0.0293, -0.0562, -0.2851, -0.5267, -0.1449],
        [ 0.2996,  0.0889,  0.2839, -0.2208, -0.1199]])


In [18]:
from scipy.optimize import linear_sum_assignment

class HungarianMatcher(nn.Module):
    def __init__(self, cost_class=1, cost_bbox=1, cost_giou=1):
        super(HungarianMatcher, self).__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Perform Hungarian matching.
        outputs: Dictionary with 'pred_logits' and 'pred_boxes' from the detection head.
        targets: List of dictionaries with 'labels' and 'boxes' for each image.
        """
        pred_logits = outputs['pred_logits']  # (batch_size, num_queries, num_classes)
        pred_boxes = outputs['pred_boxes']    # (batch_size, num_queries, 4)

        batch_size = pred_logits.shape[0]
        indices = []

        for b in range(batch_size):
            # Extract predictions and ground truth for this batch
            logits = pred_logits[b]  # (num_queries, num_classes)
            boxes = pred_boxes[b]    # (num_queries, 4)
            tgt_labels = targets[b]['labels']  # (num_objects,)
            tgt_boxes = targets[b]['boxes']    # (num_objects, 4)

            # Compute classification cost (cross-entropy)
            cost_class = -logits[:, tgt_labels].softmax(dim=-1)

            # Compute bbox L1 cost
            cost_bbox = torch.cdist(boxes, tgt_boxes, p=1)

            # Compute Generalized IoU cost
            cost_giou = -generalized_iou(boxes, tgt_boxes)

            # Combine costs
            total_cost = (
                self.cost_class * cost_class +
                self.cost_bbox * cost_bbox +
                self.cost_giou * cost_giou
            )

            # Solve assignment problem
            matched_indices = linear_sum_assignment(total_cost.cpu().numpy())
            indices.append((torch.as_tensor(matched_indices[0]), torch.as_tensor(matched_indices[1])))

        return indices

class DETRLoss(nn.Module):
    def __init__(self, matcher, num_classes, weight_dict):
        super(DETRLoss, self).__init__()
        self.matcher = matcher
        self.num_classes = num_classes
        self.weight_dict = weight_dict
        self.bce = nn.CrossEntropyLoss()
        self.l1 = nn.L1Loss()

    def forward(self, outputs, targets):
        indices = self.matcher(outputs, targets)

        # Extract matched predictions and targets
        pred_logits = outputs['pred_logits']
        pred_boxes = outputs['pred_boxes']

        # Initialize losses
        loss_class = 0
        loss_bbox = 0
        loss_giou = 0

        for batch_idx, (pred_idx, target_idx) in enumerate(indices):
            # Classification loss
            target_classes = targets[batch_idx]['labels'][target_idx]
            loss_class += self.bce(pred_logits[batch_idx, pred_idx], target_classes)

            # Bounding box loss
            target_boxes = targets[batch_idx]['boxes'][target_idx]
            loss_bbox += self.l1(pred_boxes[batch_idx, pred_idx], target_boxes)
            loss_giou += 1 - generalized_iou(pred_boxes[batch_idx, pred_idx], target_boxes).mean()

        total_loss = (self.weight_dict['class'] * loss_class +
                      self.weight_dict['bbox'] * loss_bbox +
                      self.weight_dict['giou'] * loss_giou)
        return total_loss

In [20]:
import torch
import numpy as np

# Function to generate random bounding boxes and labels
def generate_synthetic_data(batch_size, num_queries, num_objects, num_classes):
    """
    Generates random predictions and ground truth for testing.
    
    Args:
        batch_size: Number of images in the batch.
        num_queries: Number of object queries (fixed at 100 in DETR).
        num_objects: Number of actual objects in the image.
        num_classes: Number of object classes.
    
    Returns:
        predictions: Dict containing 'pred_logits' and 'pred_boxes'.
        targets: List of dictionaries containing 'labels' and 'boxes' for each image.
    """
    predictions = {
        'pred_logits': torch.randn(batch_size, num_queries, num_classes),  # Random class scores
        'pred_boxes': torch.rand(batch_size, num_queries, 4)  # Random bounding boxes in [0,1]
    }
    
    targets = []
    for _ in range(batch_size):
        num_gt = np.random.randint(1, num_objects + 1)  # Random number of ground truth objects
        
        labels = torch.randint(1, num_classes, (num_gt,))  # Random class labels (excluding "no object" class)
        boxes = torch.rand(num_gt, 4)  # Random bounding boxes in [0,1]

        targets.append({'labels': labels, 'boxes': boxes})
    
    return predictions, targets

In [21]:
# Generate data
batch_size = 2
num_queries = 100
num_objects = 5
num_classes = 10

predictions, targets = generate_synthetic_data(batch_size, num_queries, num_objects, num_classes)

print(f"Generated {batch_size} synthetic samples with max {num_objects} objects per image.")

# Instantiate the Hungarian matcher
matcher = HungarianMatcher(cost_class=1, cost_bbox=1, cost_giou=1)

# Perform matching
matched_indices = matcher(predictions, targets)

# Print results
for batch_idx, (pred_idx, target_idx) in enumerate(matched_indices):
    print(f"\nImage {batch_idx}:")
    print(f"  Matched predictions (query indices): {pred_idx.tolist()}")
    print(f"  Matched ground truth objects (target indices): {target_idx.tolist()}")

# Define loss weightings
weight_dict = {'class': 1, 'bbox': 5, 'giou': 2}

# Instantiate loss function
detr_loss = DETRLoss(matcher, num_classes, weight_dict)

# Compute loss
loss = detr_loss(predictions, targets)

print("\nComputed DETR loss:", loss.item())

Generated 2 synthetic samples with max 5 objects per image.

Image 0:
  Matched predictions (query indices): [3]
  Matched ground truth objects (target indices): [0]

Image 1:
  Matched predictions (query indices): [4, 19, 44]
  Matched ground truth objects (target indices): [0, 1, 2]

Computed DETR loss: 9.935282707214355


## Downloading Dataset

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transforms
transform = transforms.Compose([
    transforms.Resize((800, 800)),
    transforms.ToTensor(),
])

# Download and load VOC dataset (trainval split)
train_dataset = datasets.VOCDetection(
    root='./data',
    year='2012',
    image_set='val',  # You can also use 'train' or 'val'
    download=True,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# Example for running inference and visualizing
image, target = train_dataset[0]
print(image.shape)  # Example image shape (3, 800, 800)
print(target)  # Annotations (bounding boxes, labels, etc.)

In [None]:
def visualize_voc_sample(dataset, num_samples=3):
    """
    Visualizes a few images from the VOC dataset with their bounding boxes.
    
    Args:
        dataset: The VOC dataset.
        num_samples: Number of images to visualize.
    """
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 5))

    for i in range(num_samples):
        img, target = dataset[i]
        img = img.permute(1, 2, 0).numpy()  # Convert from Tensor to NumPy format

        # Get bounding boxes and labels
        boxes = target["annotation"]["object"]
        if isinstance(boxes, dict):  # If only one object, wrap it in a list
            boxes = [boxes]

        fig, ax = plt.subplots(1, figsize=(6, 6))
        ax.imshow(img)
        
        for obj in boxes:
            bbox = obj["bndbox"]
            x_min, y_min, x_max, y_max = map(int, [bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"]])
            label = obj["name"]

            # Draw bounding box
            rect = patches.Rectangle(
                (x_min, y_min), x_max - x_min, y_max - y_min,
                linewidth=2, edgecolor="r", facecolor="none"
            )
            ax.add_patch(rect)
            ax.text(x_min, y_min - 5, label, color="red", fontsize=12, fontweight="bold")

        plt.axis("off")
        plt.show()

# Visualize sample images
visualize_voc_sample(train_dataset)

In [None]:
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

# Assume `DETRModel` is your DETR implementation (replace with your model)
model = DETR()  # Replace with your DETR model initialization
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Learning rate scheduler
scheduler = StepLR(optimizer, step_size=5, gamma=0.7)

# Loss function (assuming it's implemented)
detr_loss = YourDETRLossFunction()  # Replace with actual DETR loss function

# Training function
def train_model(model, dataloader, optimizer, num_epochs=10):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0

        for images, targets in dataloader:
            images = images.to(device)

            # Convert VOC targets to DETR format (bounding boxes and labels)
            boxes, labels = [], []
            for target in targets:
                objects = target["annotation"]["object"]
                if isinstance(objects, dict):  # If only one object, wrap it in a list
                    objects = [objects]
                
                img_boxes = []
                img_labels = []
                for obj in objects:
                    bbox = obj["bndbox"]
                    x_min, y_min, x_max, y_max = map(int, [bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"]])
                    img_boxes.append([x_min, y_min, x_max - x_min, y_max - y_min])  # Convert to (x, y, w, h)
                    img_labels.append(int(obj["name"]))

                boxes.append(torch.tensor(img_boxes, dtype=torch.float32).to(device))
                labels.append(torch.tensor(img_labels, dtype=torch.int64).to(device))

            # Create DETR-style targets
            targets = [{"boxes": b, "labels": l} for b, l in zip(boxes, labels)]

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Compute loss
            loss = detr_loss(outputs, targets)

            # Backpropagation
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

# Run training
train_model(model, train_loader, optimizer, num_epochs=10)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import functional as F

def visualize_attention(model, image, feature_map, decoder_attention_weights, query_idx):
    """
    Visualizes the attention map for a specific query on an input image.
    
    Args:
        model: The DETR model.
        image: The input image (C, H, W).
        feature_map: The encoded feature map from the backbone.
        decoder_attention_weights: The cross-attention weights from the decoder.
        query_idx: Index of the query to visualize.
    """
    # Extract attention map for the given query
    # Attention weights: (num_queries, num_heads, H * W)
    query_attention = decoder_attention_weights[0, query_idx].mean(dim=0)  # Average across heads
    query_attention = query_attention.view(feature_map.shape[2], feature_map.shape[3])  # Reshape to (H, W)
    
    # Normalize attention for visualization
    query_attention = query_attention.detach().cpu().numpy()
    query_attention = (query_attention - query_attention.min()) / (query_attention.max() - query_attention.min())

    # Upsample attention to input image size
    attention_resized = F.resize(torch.tensor(query_attention), size=(image.shape[1], image.shape[2]), interpolation=F.InterpolationMode.BILINEAR)

    # Convert image to NumPy for visualization
    image_np = image.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())  # Normalize to [0, 1]

    # Overlay attention map on the image
    plt.figure(figsize=(8, 8))
    plt.imshow(image_np)
    plt.imshow(attention_resized.numpy(), alpha=0.5, cmap='jet')  # Overlay with transparency
    plt.title(f"Attention Map for Query {query_idx}")
    plt.axis('off')
    plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision.transforms import functional as F

def visualize_self_attention(feature_map, attention_weights, spatial_idx, input_image):
    """
    Visualizes self-attention for a specific spatial position in the feature map.
    
    Args:
        feature_map: The feature map (batch_size, C, H, W).
        attention_weights: Self-attention weights (H * W, H * W).
        spatial_idx: The index of the spatial position to visualize (row * W + col).
        input_image: The input image (C, H, W).
    """
    # Feature map spatial dimensions
    H, W = feature_map.shape[2], feature_map.shape[3]
    
    # Extract attention for the chosen spatial position
    attention = attention_weights[spatial_idx]  # (H * W,)
    attention = attention.view(H, W)  # Reshape to feature map shape
    
    # Normalize for visualization
    attention = attention.detach().cpu().numpy()
    attention = (attention - attention.min()) / (attention.max() - attention.min())
    
    # Upsample attention to input image size
    attention_resized = F.resize(torch.tensor(attention), size=(input_image.shape[1], input_image.shape[2]), interpolation=F.InterpolationMode.BILINEAR)
    
    # Convert input image to NumPy
    input_image_np = input_image.permute(1, 2, 0).cpu().numpy()
    input_image_np = (input_image_np - input_image_np.min()) / (input_image_np.max() - input_image_np.min())
    
    # Plot the attention map
    plt.figure(figsize=(8, 8))
    plt.imshow(input_image_np)
    plt.imshow(attention_resized.numpy(), alpha=0.5, cmap='jet')  # Overlay attention map
    plt.title(f"Self-Attention for Position {spatial_idx} (Feature Map)")
    plt.axis('off')
    plt.show()

# Example usage
# Suppose `feature_map` is (1, C, H, W), `attention_weights` is (H*W, H*W), and `input_image` is (C, H_img, W_img)
# spatial_idx = 50  # Choose a spatial index
# visualize_self_attention(feature_map, attention_weights, spatial_idx, input_image)
