In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

# Define the SimCLR Model
class SimCLRModel(nn.Module):
    """Torchvision backbone followed by a small MLP projection head.

    The constructor named by ``base_model`` is looked up on
    ``torchvision.models`` and called with ``pretrained=True``. Its ``fc``
    layer is swapped for ``nn.Identity`` so the backbone emits the features
    that would otherwise feed the classifier. A ``Linear -> ReLU -> Linear``
    head then maps those features to ``out_dim`` dimensions.

    Args:
        base_model: Name of a model constructor in ``torchvision.models``.
            The model it builds must expose an ``fc`` attribute that has
            ``in_features`` (the default ``'resnet50'`` does).
        out_dim: Output width of the projection head.
    """

    def __init__(self, base_model='resnet50', out_dim=128):
        super().__init__()

        # Resolve the constructor by name, then build the backbone.
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour
        # of `weights=`; left unchanged here -- confirm the installed version
        # before migrating.
        build_backbone = getattr(models, base_model)
        backbone = build_backbone(pretrained=True)

        # Record the classifier's input width before discarding the classifier.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.encoder = backbone

        # Two-layer MLP head with a fixed hidden width of 512.
        hidden_layer = nn.Linear(feature_dim, 512)
        activation = nn.ReLU()
        output_layer = nn.Linear(512, out_dim)
        self.projection_head = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, x):
        """Encode ``x`` and project the resulting features.

        Args:
            x: Input passed straight to the encoder (presumably a batch of
                image tensors -- confirm against the caller).

        Returns:
            A tuple ``(h, z)`` where ``h`` is the raw encoder output and ``z``
            is the projection-head output normalized along its last dimension
            with ``F.normalize``.
        """
        features = self.encoder(x)
        projected = self.projection_head(features)
        normalized_projection = F.normalize(projected, dim=-1)
        return features, normalized_projection

# Function to load and preprocess the image
def load_image(image_path, transform):
    """Open an image file, force it to RGB, and apply ``transform``.

    Args:
        image_path: Path (or any value accepted by ``PIL.Image.open``) of the
            image to read.
        transform: Callable applied to the RGB ``PIL.Image``; its return value
            is returned unchanged.

    Returns:
        Whatever ``transform`` returns for the RGB image.
    """
    # `Image.open` is lazy and can keep the underlying file open (notably for
    # multi-frame formats, or if `convert` raises). The context manager
    # guarantees the handle is released. `convert` returns a new, fully loaded
    # image, so it stays usable after the source is closed.
    with Image.open(image_path) as source_image:
        # Normalizes grayscale / palette / alpha images to three channels.
        rgb_image = source_image.convert("RGB")
    return transform(rgb_image)

# Define the image preprocessing transformations
# Per-channel statistics passed to Normalize (the ImageNet mean and std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Spatial size every image is resized to before being fed to the model.
INPUT_SIZE = (224, 224)

# Preprocessing pipeline: resize -> tensor -> channel-wise normalization.
_resize_step = transforms.Resize(INPUT_SIZE)
_to_tensor_step = transforms.ToTensor()
_normalize_step = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
transform = transforms.Compose([_resize_step, _to_tensor_step, _normalize_step])


In [2]:
# Load and preprocess a PNG image, then add a batch dimension (batch size = 1)
image_path = '../images/3597_blur_avg.png'  # Path to the PNG image
image_tensor = load_image(image_path, transform).unsqueeze(0)

# Initialize the SimCLR model
model = SimCLRModel(base_model='resnet50', out_dim=128)

# Modules start in training mode, where BatchNorm layers normalize with the
# statistics of the current batch and update their running averages. Switch to
# evaluation mode so inference uses the stored running statistics instead.
model.eval()

# Get embeddings from the image; the projected output is not needed here
with torch.no_grad():  # Disable gradient calculation for inference
    embeddings, _ = model(image_tensor)

# Print the shape of the embeddings
print("Embeddings shape:", embeddings.shape)  # Should print [1, 2048] for ResNet-50

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/anthonyhsu/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:07<00:00, 14.1MB/s]


Embeddings shape: torch.Size([1, 2048])


In [3]:
# Load and preprocess a JPEG image, then add a batch dimension (batch size = 1)
image_path = '../images/dog.jpg'  # Path to the JPEG image
image_tensor = load_image(image_path, transform).unsqueeze(0)

# Initialize a fresh SimCLR model so this cell does not depend on whatever
# earlier cells did to `model`
model = SimCLRModel(base_model='resnet50', out_dim=128)

# Modules start in training mode, where BatchNorm layers normalize with the
# statistics of the current batch and update their running averages. Switch to
# evaluation mode so inference uses the stored running statistics instead.
model.eval()

# Get embeddings from the image; the projected output is not needed here
with torch.no_grad():  # Disable gradient calculation for inference
    embeddings2, _ = model(image_tensor)

# Print the shape of the embeddings
print("Embeddings shape:", embeddings2.shape)  # Should print [1, 2048] for ResNet-50

Embeddings shape: torch.Size([1, 2048])


In [5]:
# Element-wise equality mask between the two embedding tensors.
# NOTE(review): exact float equality is a weak way to compare embeddings; if
# the goal is a similarity measure, F.cosine_similarity or torch.allclose may
# be more informative -- confirm intent.
torch.eq(embeddings, embeddings2)

tensor([[False, False, False,  ..., False, False, False]])