In [None]:
import torch
from lightly.data import LightlyDataset
from lightly.models import SimCLR
from lightly.models.modules.heads import SimCLRProjectionHead
from PIL import Image
from torchvision import models
from torchvision import transforms

# Define a SimCLR model using a ResNet backbone
class SimCLRModel(torch.nn.Module):
    """SimCLR network: a ResNet-50 feature extractor followed by a projection head.

    The backbone is torchvision's ResNet-50 built with ``pretrained=False``
    (random initialisation) with its final child module -- the ``fc``
    classification layer -- removed, so it outputs pooled 2048-channel
    features. These are passed to lightly's ``SimCLRProjectionHead`` with
    dimensions 2048 -> 2048 -> 128.
    """

    def __init__(self):
        super().__init__()
        # NOTE: torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=None``; kept here for compatibility with older releases.
        resnet = models.resnet50(pretrained=False)
        # Drop the last child (the ImageNet classifier), keeping the
        # convolutional stages and the adaptive average pool.
        self.backbone = torch.nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SimCLRProjectionHead(2048, 2048, 128)

    def forward(self, x):
        """Project a batch of images into the SimCLR embedding space.

        Args:
            x: Image batch; presumably a float tensor of shape
                (batch, 3, H, W) -- the ResNet backbone needs 4-D input.

        Returns:
            Tensor of shape (batch, 128) from the projection head.
        """
        # (batch, 2048, 1, 1) -> (batch, 2048). A bare ``squeeze()`` would also
        # drop the batch axis when batch == 1, giving a 1-D tensor that the
        # projection head (BatchNorm1d in recent lightly versions) rejects.
        features = torch.flatten(self.backbone(x), start_dim=1)
        return self.projection_head(features)

# Example usage: build the model (its ResNet-50 backbone is created with
# pretrained=False, so the weights start out randomly initialised).
model = SimCLRModel()

# Preprocessing applied to each PIL image before the forward pass:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def get_simclr_embedding(image_path):
    """Return the SimCLR projection of a single image as a NumPy vector.

    The image is opened with PIL, converted to RGB, preprocessed with the
    module-level ``transform`` and run through the module-level ``model`` as
    a batch of one, without gradient tracking.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        numpy.ndarray: 1-D array with the model output for this image
        (128 values for the ``SimCLRModel`` defined in this file).

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        PIL.UnidentifiedImageError: If the file is not a readable image.
    """
    # Close the file handle deterministically. ``convert`` returns a new,
    # fully loaded image, so it remains usable after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # add a batch dimension of 1

    # Inference must run in eval mode. In training mode, BatchNorm raises on a
    # single-sample batch and would also update its running statistics.
    # Restore the caller's mode afterwards so any training code is unaffected.
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        with torch.no_grad():
            output = model(batch.to(device))
    finally:
        model.train(was_training)

    # Drop only the batch axis; .cpu() makes .numpy() work for GPU models too.
    return output.squeeze(0).cpu().numpy()

# Example call. The path below is a placeholder: replace it with a real image
# file, otherwise opening it raises FileNotFoundError.
embedding = get_simclr_embedding(image_path="path_to_black_hole_image.jpg")