<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Example_Masked_Image_Modeling_for_Vision_Foundation_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define the masking function
def mask_image(images, mask_ratio):
    """Zero out a random subset of the elements of an image batch.

    Every element (each channel of each pixel) is masked independently: a
    uniform sample in [0, 1) is drawn per element and the element is masked
    when the sample is below ``mask_ratio``. A ratio <= 0 therefore masks
    nothing and a ratio >= 1 masks everything.

    Args:
        images: 4-D tensor laid out as (batch, channels, height, width).
        mask_ratio: Per-element probability of being masked.

    Returns:
        A ``(masked_images, mask)`` tuple. ``masked_images`` is a new tensor
        equal to ``images`` with the masked elements set to 0 (the input is
        not modified). ``mask`` is a boolean tensor of the same shape and on
        the same device as ``images``, True where an element was masked.

    Raises:
        ValueError: If ``images`` is not 4-dimensional.
    """
    if images.dim() != 4:
        raise ValueError(
            f"expected a 4-D (batch, channels, height, width) tensor, "
            f"got {images.dim()} dimension(s)"
        )

    # Sample directly on the images' device so no host-to-device copy is
    # needed. torch.rand (rather than rand_like) keeps the samples floating
    # point even when the images have an integer dtype.
    mask = torch.rand(images.shape, device=images.device) < mask_ratio

    # masked_fill is out-of-place: it returns a fresh tensor and leaves
    # the caller's images untouched.
    masked_images = images.masked_fill(mask, 0)
    return masked_images, mask

# Define the reconstruction loss function
def reconstruction_loss(outputs, original_images, mask):
    """Mean squared reconstruction error over the masked elements only.

    The squared error is summed over positions where ``mask`` is True and
    divided by the number of such positions. Averaging over *all* elements
    instead would dilute the loss with the unmasked positions (which
    contribute exactly zero), making its magnitude depend on how many
    elements happened to be masked.

    Args:
        outputs: Reconstructed tensor produced by the model.
        original_images: Target tensor, broadcastable with ``outputs``.
        mask: Boolean (or 0/1) tensor broadcastable with ``outputs``;
            non-zero marks the elements that contribute to the loss.

    Returns:
        A scalar tensor. It is 0 when no element is masked.
    """
    squared_error = (outputs - original_images) ** 2
    # Cast so boolean masks can be used as multiplicative weights.
    weights = mask.to(squared_error.dtype)
    masked_error = squared_error * weights

    # Count masked elements at the broadcast shape; clamp to avoid a
    # division by zero (NaN loss) when the mask is empty.
    masked_count = weights.expand_as(masked_error).sum().clamp(min=1.0)
    return masked_error.sum() / masked_count

class MaskedImageModeling(nn.Module):
    """Wrap a reconstruction network with a masked-image training objective.

    Attributes are set from the constructor arguments:
        model: Module that receives the masked images and produces the
            reconstruction.
        mask_ratio: Ratio forwarded to ``mask_image`` on every forward pass.
    """

    def __init__(self, model, mask_ratio=0.15):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.model = model

    def forward(self, images):
        """Mask ``images``, reconstruct them, and return the training loss.

        A fresh random mask is drawn on every call. The returned value is
        whatever ``reconstruction_loss`` computes from the model output, the
        unmasked input images, and the mask.
        """
        corrupted, hidden = mask_image(images, self.mask_ratio)
        reconstruction = self.model(corrupted)
        return reconstruction_loss(reconstruction, images, hidden)

# Example model definition (e.g., a simple CNN)
class SimpleCNN(nn.Module):
    """Small CNN that maps an image batch to a tensor of the same shape.

    Two 3x3 convolutions, each followed by ReLU and 2x2 max pooling, shrink
    the spatial size by a factor of 4 per side. Two fully connected layers
    then map the flattened features to ``in_channels * image_size**2``
    values, which are reshaped back to image dimensions.

    Args:
        in_channels: Number of channels of the input (and output) images.
        image_size: Side length of the square input images. Must be at
            least 4 so that something is left after the two poolings.

    Raises:
        ValueError: If ``image_size`` is smaller than 4.
    """

    def __init__(self, in_channels=3, image_size=32):
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be >= 4, got {image_size}")
        self.in_channels = in_channels
        self.image_size = image_size

        # Each 2x2 max pool floors the side length in half, so two of them
        # leave image_size // 4 pixels per side.
        pooled_size = image_size // 4

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * pooled_size * pooled_size, 128)
        # Output size matches the number of elements in one input image.
        self.fc2 = nn.Linear(128, in_channels * image_size * image_size)

    def forward(self, x):
        """Return a reconstruction with the same shape as the input batch.

        ``x`` is expected to be (batch, in_channels, image_size, image_size).
        """
        # Functional ops avoid constructing new ReLU/MaxPool2d modules on
        # every forward call.
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv1(x)), 2)
        x = nn.functional.max_pool2d(nn.functional.relu(self.conv2(x)), 2)
        x = x.view(x.size(0), -1)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        # Reshape the flat prediction back to image dimensions.
        return x.view(x.size(0), self.in_channels, self.image_size, self.image_size)

# Example usage
# Synthetic data: 10 random 3-channel 32x32 images, served in shuffled
# batches of 2.
dummy_images = torch.randn((10, 3, 32, 32))
image_dataset = TensorDataset(dummy_images)
image_dataloader = DataLoader(dataset=image_dataset, batch_size=2, shuffle=True)

# Wrap the reconstruction network in the masked-image objective and optimize
# the wrapper's parameters with Adam.
model = SimpleCNN()
masked_image_model = MaskedImageModeling(model)
optimizer = optim.Adam(params=masked_image_model.parameters(), lr=1e-4)

# One pass over the data with one optimizer step per batch. TensorDataset
# yields one-element batches, so unpack the image tensor directly.
for (images,) in image_dataloader:
    optimizer.zero_grad()
    loss = masked_image_model(images)
    loss.backward()
    optimizer.step()
    print(f"Loss: {loss.item()}")