<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/BYOL_(Bootstrap_Your_Own_Latent)_and_MoCo_(Momentum_Contrast).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Define the BYOL class
class BYOL(nn.Module):
    """Online branch of a BYOL-style model: backbone encoder followed by an MLP projector.

    Only the encoder + projector are implemented here; BYOL's predictor head and its
    momentum-updated target network are not part of this class.

    Args:
        encoder: Backbone module exposing an ``fc`` layer with an ``in_features``
            attribute (e.g. a torchvision ResNet). The module is modified IN PLACE:
            ``encoder.fc`` is replaced by ``nn.Identity()`` so the backbone returns
            the features that used to feed its classifier.
        projector_dim: Output dimension of the projector.
        hidden_dim: Width of the projector's hidden layer. Defaults to
            ``projector_dim`` (the original behaviour). The BYOL paper uses a hidden
            layer wider than the output, which this parameter allows.

    Raises:
        AttributeError: if ``encoder`` has no ``fc`` layer with ``in_features``.
    """

    def __init__(self, encoder, projector_dim, *, hidden_dim=None):
        super().__init__()
        # Read the classifier's input width BEFORE the classifier is stripped off.
        fc = getattr(encoder, "fc", None)
        if fc is None or not hasattr(fc, "in_features"):
            raise AttributeError(
                "encoder must expose an `fc` layer with `in_features` "
                "(e.g. a torchvision ResNet)"
            )
        if hidden_dim is None:
            hidden_dim = projector_dim
        self.encoder_in_features = fc.in_features
        self.encoder = encoder
        # Drop the classification head so the encoder yields raw features.
        self.encoder.fc = nn.Identity()
        self.projector = nn.Sequential(
            nn.Linear(self.encoder_in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projector_dim),
        )

    def forward(self, x):
        """Encode ``x`` with the backbone and project the resulting features.

        Args:
            x: Input batch accepted by ``encoder`` (for a ResNet, presumably images of
                shape (batch, 3, H, W) — TODO confirm for other backbones).

        Returns:
            Projections of shape ``(batch, projector_dim)``, assuming the encoder
            returns ``(batch, encoder_in_features)`` features.
        """
        representations = self.encoder(x)
        return self.projector(representations)

# Example usage with a ResNet-18 encoder. The hub tag is pinned to torchvision
# v0.10.0, whose hubconf takes ``pretrained=True`` (the ``weights=`` API only
# exists in later torchvision releases).
encoder = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

# Instantiate the BYOL model (note: this replaces ``encoder.fc`` in place).
byol_model = BYOL(encoder=encoder, projector_dim=128)

# Random example batch: (batch_size=32, channels=3, height=224, width=224).
input_data = torch.randn(32, 3, 224, 224)

# Shape check only. eval() keeps BatchNorm from blending statistics of random noise
# into the pretrained running statistics, and no_grad() skips building an autograd
# graph that is never used.
byol_model.eval()
with torch.no_grad():
    projections = byol_model(input_data)
byol_model.train()  # restore the default training mode for any training code that follows

# Print the shape of the projections
print(projections.shape)  # Expected: torch.Size([32, 128])