<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/SimCLR_and_BERT_for_Vision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# Define the SimCLR class
class SimCLR(nn.Module):
    """SimCLR-style encoder followed by an MLP projection head.

    The supplied ``base_model`` serves as the encoder once its ``fc`` layer has
    been replaced with ``nn.Identity``, so it outputs the features that used to
    feed that layer. A two-layer MLP (Linear -> ReLU -> Linear) then maps those
    features to ``projection_dim`` dimensions.

    Args:
        base_model: Backbone exposing an ``fc`` submodule with an
            ``in_features`` attribute (e.g. a torchvision ResNet). It is
            modified in place: its ``fc`` is replaced with ``nn.Identity``, so
            the caller's model object changes as well.
        projection_dim: Output size of the projection head.
        hidden_dim: Width of the projection head's hidden layer. Defaults to
            2048, the value that was previously hard-coded.
    """

    def __init__(self, base_model, projection_dim, hidden_dim=2048):
        super().__init__()
        # Read in_features before fc is replaced; nn.Identity has no in_features.
        self.encoder_in_features = base_model.fc.in_features
        self.base_model = base_model
        # Drop the classification head so the backbone returns raw features.
        self.base_model.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(self.encoder_in_features, hidden_dim),  # hidden projection layer
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim),  # output projection layer
        )

    def forward(self, x):
        """Encode ``x`` with the backbone and return the projected embeddings.

        Returns:
            Tensor whose last dimension is ``projection_dim``. This assumes the
            backbone yields ``(batch, encoder_in_features)`` features, e.g. a
            ResNet with ``fc`` removed; confirm this for other backbones.
        """
        h = self.base_model(x)  # encoder features
        return self.projector(h)

# Example usage with a randomly initialised ResNet-18 as the base model.
# `weights=None` replaces `pretrained=False`, which torchvision deprecated in
# 0.13 and which emits a deprecation warning on current releases.
base_model = models.resnet18(weights=None)

# Instantiate the SimCLR model
simclr_model = SimCLR(base_model, projection_dim=128)

# Example input: batch_size=32, channels=3, height=224, width=224
input_data = torch.randn(32, 3, 224, 224)
projections = simclr_model(input_data)

# Print the shape of the projections
print(projections.shape)  # Expected: torch.Size([32, 128])