In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

2.2. CLIP-based Semantic Segmentation
Recently, the Contrastive Language-Image Pretraining
(CLIP) model [51] has been adopted in semantic segmen-
tation tasks thanks to the generalized knowledge learned
from a large corpus of image-text pairs. Given the gener-
alization capability, a number of zero-shot/open-vocabulary
approaches [17, 22, 34, 38, 47, 52, 67, 70, 71] exploit CLIP to
segment the classes which are unseen during training. How-
ever, these methods still require mask annotations during
training. To minimize the annotation effort, CLIP has also
been adopted to improve unsupervised methods [24, 58, 81].
Nevertheless, the segmentation performance is still unsat-
isfactory and is not desired for further applications. On
the other hand, CLIP has also been utilized to benefit
WSSS [42, 49, 65, 69, 72]. These works mainly focus on
designing text prompts or prompt learning techniques for
the text encoder. However, they either consider only the
foreground class prompts, or rely on general background
prompts defined by additional manual efforts and heuris-
tic human knowledge. Moreover, such manually-defined
prompts may not fully exploit the knowledge in the CLIP
latent space. In contrast, with no need for any manual
efforts, our proposed SemPLeS framework automatically
learns prompts embedded with class-associated semantic
knowledge discovered from the CLIP latent space.

In [26]:
# Define model parameters
batch_size = 2
channels = 3
height = 64
width = 64
prompt_dim = 512

# Create dummy inputs
X_f = torch.randn(batch_size, channels, height, width)  # Foreground image
X_b = torch.randn(batch_size, channels, height, width)  # Background image
t_k = torch.randn(batch_size, prompt_dim)  # Text prompt embeddings
p_k = torch.randn(batch_size, prompt_dim)  # Learnable background prompts

print(X_f.shape)
print(X_b.shape)
print(t_k.shape)
print(p_k.shape)

# Define ContrastivePromptLearning class
class ContrastivePromptLearning(nn.Module):
    def __init__(self, image_encoder, text_encoder, temperature=0.07):
        super().__init__()
        self.E_I = image_encoder  # Image encoder
        self.E_T = text_encoder  # Text encoder
        self.temperature = temperature

    def forward(self, X_f, X_b, t_k, p_k):
        """
        Contrastive Prompt Learning to learn background prompts
        Args:
            X_f: Foreground image (B, C, H, W)
            X_b: Background image (B, C, H, W) 
            t_k: Text prompt embeddings (B, prompt_dim)
            p_k: Learnable background prompts (B, prompt_dim)
        """
        # Flatten images before passing to linear encoder
        batch_size = X_f.shape[0]
        X_f_flat = X_f.view(batch_size, -1)  # Flatten to (B, C*H*W)
        X_b_flat = X_b.view(batch_size, -1)  # Flatten to (B, C*H*W)
        
        # Get image embeddings using CLIP image encoder
        v_f = self.E_I(X_f_flat)  # Foreground image embedding
        v_b = self.E_I(X_b_flat)  # Background image embedding
        
        # Get text embeddings using CLIP text encoder
        u_t = self.E_T(t_k)  # Text prompt embedding
        u_b = self.E_T(p_k)  # Background prompt embedding

        # Compute cosine similarities scaled by temperature
        sim_f_t = F.cosine_similarity(v_f, u_t, dim=-1) / self.temperature
        sim_b_p = F.cosine_similarity(v_b, u_b, dim=-1) / self.temperature
        sim_f_p = F.cosine_similarity(v_f, u_b, dim=-1) / self.temperature
        sim_b_t = F.cosine_similarity(v_b, u_t, dim=-1) / self.temperature

        # Compute contrastive loss
        L_contrast = -torch.log(torch.exp(sim_f_t) / (torch.exp(sim_f_t) + torch.exp(sim_f_p))) \
                    -torch.log(torch.exp(sim_b_p) / (torch.exp(sim_b_p) + torch.exp(sim_b_t)))
        
        return L_contrast.mean()


torch.Size([2, 3, 64, 64])
torch.Size([2, 3, 64, 64])
torch.Size([2, 512])
torch.Size([2, 512])


In [18]:


# Initialize model 
model = ContrastivePromptLearning(image_encoder, text_encoder)

# Forward pass
loss = model(X_f, X_b, t_k, p_k)

# Check output
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0  # scalar
assert not torch.isnan(loss)
assert not torch.isinf(loss)

print("ContrastivePromptLearning test passed!")

ContrastivePromptLearning test passed!


In [34]:
import PIL.Image
import requests
from torchvision import transforms
from pathlib import Path

# Create dummy encoders
image_encoder = nn.Linear(64 * 64 * 3, prompt_dim)  # Changed from 3 * 64 * 64 to match flattened input size
text_encoder = nn.Linear(prompt_dim, prompt_dim)

# Load and preprocess the image
image_path = Path("background.png")
image = PIL.Image.open(image_path)
print(image.shape)

# Define image transforms
preprocess = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
])

# Preprocess image
img_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension

# Create dummy background image by slightly perturbing original
background = img_tensor + 0.1 * torch.randn_like(img_tensor)

# Text prompt
text_prompt = "Photo of a train"
text_embedding = torch.randn(1, prompt_dim)  # Simulate text embedding
background_prompt = torch.randn(1, prompt_dim)  # Learnable background prompt

# Forward pass with real data
loss = model(img_tensor, background, text_embedding, background_prompt)

print(f"Loss on real data: {loss.item():.4f}")

AttributeError: 'PngImageFile' object has no attribute 'shape'