In [1]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
# Pick the compute device once: CUDA when PyTorch reports it as usable,
# otherwise fall back to the CPU. The message mirrors the choice made.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("GPU is available" if device.type == "cuda" else "GPU is not available")

GPU is available


In [3]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Data-pipeline settings, named so the explanatory comments below cannot
# drift away from the values that are actually passed to the loader.
IMAGE_SIZE = 224    # side length the images are resized to (VGG-16 input size)
BATCH_SIZE = 512    # images per DataLoader batch
NUM_WORKERS = 4     # DataLoader worker subprocesses

# Transform: resize to IMAGE_SIZE x IMAGE_SIZE and convert to a tensor.
# NOTE(review): no mean/std normalisation is applied here, although the
# encoder defaults to ImageNet-pretrained weights — confirm this is intended.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),  # Resize for VGG-16
    transforms.ToTensor(),                        # Convert to Tensor
])

# Load the CIFAR-100 training split (downloaded into ./data if missing)
cifar100_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)

# Create a DataLoader for the CIFAR-100 training data
dataloader = DataLoader(cifar100_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Explanation:
# - batch_size=BATCH_SIZE (512): loads 512 images per batch.
# - shuffle=True: reshuffles the dataset at every epoch.
# - num_workers=NUM_WORKERS (4): uses 4 subprocesses to load data in parallel.

Files already downloaded and verified


In [4]:
import torch
import torch.nn as nn
from torchvision.models import vgg16

class Encoder(nn.Module):
    """Feature extractor that keeps only the ``features`` stack of VGG-16.

    The classifier (fully connected) part of the VGG-16 model is not used.
    """

    def __init__(self, pretrained=True):
        """Build the encoder.

        Args:
            pretrained: forwarded unchanged to ``vgg16``; with the default
                ``True`` the ImageNet-pretrained weights are requested.
        """
        super().__init__()

        backbone = vgg16(pretrained=pretrained)

        # Only the convolutional part is retained as a sub-module.
        self.features = backbone.features

        # Freezing is not done here. To fine-tune only part of the network,
        # requires_grad could be switched off on a slice of self.features,
        # e.g. on the parameters of self.features[:10].

    def forward(self, x):
        """Return the feature map produced by ``self.features`` for ``x``.

        Args:
            x: image batch of shape (B, C, H, W); C is 3 for RGB input.

        Returns:
            The output of the VGG-16 convolutional stack, shaped
            (B, 512, H_out, W_out) — usually 7x7 spatially for 224x224 input.
        """
        return self.features(x)


In [5]:
import torch.nn as nn

class Predictor(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping context features to
    predicted target features.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: width of the hidden layer.
        output_dim: size of each predicted feature vector.
    """

    def __init__(self, input_dim=512, hidden_dim=4096, output_dim=512):
        super().__init__()

        # Sub-modules are registered in the order fc1, relu, fc2.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply fc1, the ReLU non-linearity and fc2 to ``x`` in that order."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


In [6]:
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    The averaged ("shadow") weights are kept in ``self.shadow`` keyed by
    parameter name. ``apply_shadow`` temporarily loads them into the model and
    ``restore`` puts the live weights back. Only parameters with
    ``requires_grad=True`` are tracked.
    """

    def __init__(self, model, decay=0.99):
        """
        Initialize EMA with a given model and decay rate.

        Args:
        - model: The model whose parameters are read and averaged.
        - decay: Weight given to the previous average on each update
          (default: 0.99); the current parameters get ``1 - decay``.
        """
        self.model = model  # Model whose parameters are averaged / swapped
        self.decay = decay  # EMA decay factor
        self.shadow = {}    # name -> moving average of that parameter
        self.backup = {}    # name -> live weights saved by apply_shadow()

    def register(self):
        """Initialize the shadow weights with copies of the current weights.

        Must be called before ``update`` or ``apply_shadow``; those methods
        look names up in ``self.shadow`` and raise ``KeyError`` otherwise.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Fold the model's current parameters into the shadow weights.

        For every tracked parameter:
        ``shadow = decay * shadow + (1 - decay) * param``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self):
        """Load the shadow weights into the model, backing up the live ones.

        The values are copied into the existing parameter storage rather than
        rebinding ``param.data`` to the shadow tensor. Rebinding would make the
        parameter and the stored average share memory, so any in-place write to
        the parameter while the shadow is applied would silently corrupt the
        average.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # Backup current parameters
                self.backup[name] = param.data.clone()
                # Copy (not alias) the EMA weights into the parameter
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Put back the weights saved by the last ``apply_shadow`` call.

        Raises ``KeyError`` if ``apply_shadow`` was not called first. The
        backup is cleared afterwards.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # Copy the saved live weights back into the parameter storage
                param.data.copy_(self.backup[name])
        # Clear the backup
        self.backup = {}


In [7]:
import torch
import torch.nn as nn
import numpy as np

# Learnable per-patch positional embedding table
class PositionalEmbedding(nn.Module):
    """Holds one learnable embedding vector per cell of a square patch grid.

    Args:
        grid_size: number of patches along one side of the grid
            (e.g. 7 gives 7 * 7 = 49 positions).
        embed_dim: length of each embedding vector (e.g. 512).
    """

    def __init__(self, grid_size, embed_dim):
        super().__init__()
        self.grid_size = grid_size
        self.embed_dim = embed_dim

        # One row per grid position, initialised from a standard normal.
        num_positions = grid_size * grid_size
        self.positional_embeddings = nn.Parameter(torch.randn(num_positions, embed_dim))

    def forward(self):
        """Return the embedding table, shape (grid_size * grid_size, embed_dim)."""
        return self.positional_embeddings


In [8]:
def generate_blocks_without_positional_embeddings(image, context_encoder, M=4, N=49, context_ratio=0.85, device=torch.device("cpu")):
    """
    Split one image into a square grid of patches and sample context/target patches.

    Args:
        image: tensor in [C, H, W] format.
        context_encoder: unused; kept so existing call sites keep working.
        M: maximum number of target patches to return.
        N: total number of grid cells; must be a positive perfect square
           (e.g. 49 for a 7x7 grid).
        context_ratio: fraction of the N patches used as context
           (``int(context_ratio * N)`` patches).
        device: device the returned patches live on.

    Returns:
        (context_patches, target_patches):
        - context_patches: [num_context, C, patch_h, patch_w], in the random
          order drawn by ``np.random.choice``.
        - target_patches: [<=M, C, patch_h, patch_w], the lowest-indexed
          patches (row-major grid order) that were not chosen as context.

    Raises:
        ValueError: if N is not a positive perfect square, if the image is
            smaller than the grid, or if the context or target selection
            would be empty.
    """
    grid_size = int(round(float(np.sqrt(N)))) if N > 0 else 0
    if grid_size <= 0 or grid_size * grid_size != N:
        # A non-square N would build fewer than N patches while still sampling
        # indices from range(N), failing later with an intermittent IndexError.
        raise ValueError(f"N must be a positive perfect square, got {N}")

    num_channels, H, W = image.shape
    patch_h, patch_w = H // grid_size, W // grid_size
    if patch_h == 0 or patch_w == 0:
        raise ValueError(
            f"image of size {H}x{W} is too small for a {grid_size}x{grid_size} grid"
        )

    # Move the image once instead of moving every patch separately.
    image = image.to(device)

    # Step 1: split the image into patches in row-major order.
    # Trailing rows/columns that do not fill a whole patch are dropped.
    cropped = image[:, :grid_size * patch_h, :grid_size * patch_w]
    patches = (
        cropped.reshape(num_channels, grid_size, patch_h, grid_size, patch_w)
        .permute(1, 3, 0, 2, 4)  # -> [row, col, C, patch_h, patch_w]
        .reshape(N, num_channels, patch_h, patch_w)
    )

    # Step 2: pick context indices at random; targets are the sorted remainder.
    num_context_patches = int(context_ratio * N)
    context_indices = np.random.choice(N, size=num_context_patches, replace=False)
    target_indices = np.setdiff1d(np.arange(N), context_indices)[:M]
    if len(context_indices) == 0 or len(target_indices) == 0:
        raise ValueError(
            "context_ratio and M must leave at least one context patch and one target patch"
        )

    # Step 3: gather the selected patches with a single indexing op each.
    context_idx = torch.as_tensor(context_indices, dtype=torch.long, device=patches.device)
    target_idx = torch.as_tensor(target_indices, dtype=torch.long, device=patches.device)
    context_patches = patches[context_idx]  # [num_context_patches, C, patch_h, patch_w]
    target_patches = patches[target_idx]    # [<=M, C, patch_h, patch_w]

    return context_patches, target_patches


In [9]:
import torch.optim as optim

# Build the two VGG-based encoders and the predictor on the selected device.
# Construction order is kept: context encoder, target encoder, predictor.
context_encoder = Encoder().to(device)
target_encoder = Encoder().to(device)
predictor = Predictor().to(device)

# EMA helper attached to the target encoder; register() snapshots the
# target encoder's current trainable weights as the initial shadow weights.
ema = EMA(target_encoder, decay=0.99)
ema.register()

# Adam optimises the context encoder and the predictor only; the target
# encoder's parameters are not handed to the optimizer.
optimizer = optim.Adam(
    [*context_encoder.parameters(), *predictor.parameters()],
    lr=1e-3,
)



In [10]:
import torch.nn.functional as F
from tqdm import tqdm

# Training loop with global average pooling and progress tracking
NUM_EPOCHS = 10        # number of passes over the dataloader
LOG_EVERY = 10         # print the batch loss every this many batches
CHECKPOINT_EVERY = 5   # save model weights every this many epochs

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}]")

    # Progress bar for tracking each batch
    for batch_idx, (images, _) in enumerate(tqdm(dataloader, desc="Training", leave=False)):
        # NOTE: only the first image of every batch is used for the update,
        # so only that image is moved to the device.
        first_image = images[0].to(device)

        # Generate context and target blocks without positional embeddings
        context_patches, target_patches = generate_blocks_without_positional_embeddings(
            first_image, context_encoder=context_encoder, M=4, N=49, device=device
        )

        # Only the first len(target_patches) context patches are paired with
        # targets, so encode just those, as one batch rather than one forward
        # pass per patch. The encoder is assumed to treat batch items
        # independently (plain VGG-16 conv stack), so the features match
        # per-patch encoding.
        num_pairs = len(target_patches)
        context_features = context_encoder(context_patches[:num_pairs])
        context_features = F.adaptive_avg_pool2d(context_features, (1, 1)).flatten(1)  # [num_pairs, 512]

        # Target features are regression targets only: compute them without
        # autograd so no gradients flow into (or pile up on) the target encoder.
        with torch.no_grad():
            target_features = target_encoder(target_patches)
            target_features = F.adaptive_avg_pool2d(target_features, (1, 1)).flatten(1)  # [num_pairs, 512]

        # Predict target features using predictor
        predicted_target_features = predictor(context_features)

        # Compute loss (MSE between predicted and actual target features)
        loss = F.mse_loss(predicted_target_features, target_features)
        epoch_loss += loss.item()

        # Backpropagation and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Momentum (EMA) update of the target encoder from the context encoder:
        #   target = decay * target + (1 - decay) * context
        # The `ema` helper wraps target_encoder itself, which the optimizer
        # never changes, so calling ema.update() only averaged the target with
        # itself. Here `ema` is used solely for its decay rate.
        with torch.no_grad():
            for target_param, context_param in zip(
                target_encoder.parameters(), context_encoder.parameters()
            ):
                target_param.mul_(ema.decay).add_(context_param, alpha=1.0 - ema.decay)

        # Print batch loss periodically
        if batch_idx % LOG_EVERY == 0:
            print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}")

    # Print average loss for the epoch
    avg_epoch_loss = epoch_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] Average Loss: {avg_epoch_loss:.4f}")

    # Save checkpoints every few epochs
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(context_encoder.state_dict(), f"context_encoder_epoch_{epoch + 1}.pth")
        torch.save(target_encoder.state_dict(), f"target_encoder_epoch_{epoch + 1}.pth")
        torch.save(predictor.state_dict(), f"predictor_epoch_{epoch + 1}.pth")


Epoch [1/10]


                                                

KeyboardInterrupt: 