<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/MoCo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy

class MoCo(nn.Module):
    """Momentum-contrast model: a query encoder, a momentum key encoder and a key queue.

    The query encoder is built by ``base_encoder(num_classes=dim)`` and is the
    only part that receives gradients. The key encoder starts as a deep copy
    of the query encoder, has ``requires_grad`` disabled on all parameters and
    is updated as a moving average of the query encoder on every forward call.
    Normalized keys from previous batches are kept in a ring buffer of ``K``
    columns and serve as negatives.

    Args:
        base_encoder: Callable accepting ``num_classes`` and returning an
            ``nn.Module`` whose output has ``num_classes`` features.
        dim: Feature dimension of the encoder output and of each queue column.
        K: Number of keys held in the queue.
        m: Momentum coefficient for the key-encoder update.
        T: Temperature the logits are divided by.
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()

        # Create the query and key encoders; the key encoder starts as an
        # exact copy of the query encoder.
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = copy.deepcopy(self.encoder_q)

        # Freeze key encoder (updated via momentum, never by the optimizer)
        for param_k in self.encoder_k.parameters():
            param_k.requires_grad = False

        # MoCo parameters
        self.K = K  # Queue size
        self.m = m  # Momentum for the key encoder update
        self.T = T  # Softmax temperature

        # Create the queue (dim x K): one unit-norm column per stored key.
        self.register_buffer("queue", F.normalize(torch.randn(dim, K), dim=0))

        # Pointer to the next queue column to overwrite
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Update the key encoder in place: ``k = m * k + (1 - m) * q``."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            # In-place update avoids allocating a new tensor per parameter.
            param_k.mul_(self.m).add_(param_q.detach(), alpha=1.0 - self.m)

    def forward(self, im_q, im_k):
        """Compute contrastive logits for a batch of query/key inputs.

        Each call also applies one momentum update to the key encoder and
        writes the new keys into the queue, regardless of train/eval mode.

        Args:
            im_q: Batch fed to the query encoder.
            im_k: Batch fed to the key encoder (same batch size as ``im_q``).

        Returns:
            Tuple ``(logits, labels)``: ``logits`` has shape ``(N, 1 + K)``
            with the positive pair in column 0 and queue negatives in the
            remaining columns, already divided by ``T``; ``labels`` is a long
            tensor of ``N`` zeros on the same device as ``logits``.
        """
        # Compute query features
        q = F.normalize(self.encoder_q(im_q), dim=1)  # Queries: (N, dim)

        # Compute key features; no gradient flows into the key encoder
        with torch.no_grad():
            self._momentum_update_key_encoder()  # Update key encoder
            k = F.normalize(self.encoder_k(im_k), dim=1)  # Keys: (N, dim)

        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

        # Negative logits: NxK (queue snapshot taken before it is updated below)
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits: Nx(1+K), scaled by the temperature
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T

        # Labels: positive key is at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # Update the queue
        self._dequeue_and_enqueue(k)

        return logits, labels

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with ``keys`` (shape ``(N, dim)``).

        The queue is treated as a ring buffer: writes wrap around the end, so
        the batch size does not have to divide ``K``. If the batch holds more
        than ``K`` keys, only the last ``K`` are kept, which is what writing
        them one at a time would leave behind. The pointer always advances by
        the full batch size modulo ``K``.
        """
        batch_size = keys.shape[0]
        if batch_size == 0:
            return

        ptr = int(self.queue_ptr)

        if batch_size > self.K:
            # Earlier keys would be overwritten within this same call; skip them.
            ptr = (ptr + batch_size - self.K) % self.K
            keys = keys[-self.K:]
            batch_size = self.K

        end = ptr + batch_size
        if end <= self.K:
            # Contiguous write, no wrap-around needed
            self.queue[:, ptr:end] = keys.T
        else:
            # Fill to the end of the queue, then continue from column 0
            head = self.K - ptr
            self.queue[:, ptr:] = keys[:head].T
            self.queue[:, :end - self.K] = keys[head:].T

        # Move the pointer
        self.queue_ptr[0] = end % self.K

In [6]:
# Example setup: CIFAR-10 data pipeline, MoCo model and its optimizer.

# Per-channel statistics used by the Normalize step.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

augmentation_steps = [
    transforms.Resize((224, 224)),  # bring images up to 224x224 before cropping
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
]
transform = transforms.Compose(augmentation_steps)

train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

# MoCo with a ResNet50 backbone; run on the GPU when one is available,
# otherwise fall back to the CPU.
base_encoder = models.resnet50
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = MoCo(base_encoder=base_encoder).to(device)

# Only the query encoder is optimized; the key encoder is not handed to SGD.
optimizer = optim.SGD(
    model.encoder_q.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=1e-4,
)

Files already downloaded and verified


In [None]:
# Training loop
num_epochs = 10  # Simple 10 epochs
log_interval = 100  # Print the loss every `log_interval` steps

model.train()
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Move the batch to the target device once and reuse it for both views.
        # NOTE: query and key are the *same* augmented batch here; in practice
        # im_k should be a second, independently augmented view of each image.
        images = images.to(device)
        im_q = im_k = images

        # Forward pass through MoCo
        logits, labels = model(im_q, im_k)

        # Cross-entropy loss (the positive key sits at index 0 of the logits)
        loss = F.cross_entropy(logits, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_interval == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item():.4f}")
