In [1]:
import torchvision
from torchvision import transforms
import torch

def create_simclr_transforms(img_size, mean, std):
    """Build the augmenting train pipeline and the deterministic test pipeline.

    Args:
        img_size (int): Side length of the square images produced by both pipelines.
        mean (tuple): Per-channel mean used by the test pipeline's normalization
            (e.g., CIFAR-10: (0.4914, 0.4822, 0.4465)).
        std (tuple): Per-channel standard deviation used by the test pipeline's
            normalization (e.g., CIFAR-10: (0.2023, 0.1994, 0.2010)).

    Returns:
        tuple: (train_transform, test_transform), both ``transforms.Compose`` objects.
    """

    # GaussianBlur only accepts odd, positive kernel sizes. int(0.1 * img_size)
    # alone is 0 for img_size < 10 and even for sizes such as 20 or 64, so clamp
    # to at least 1 and bump even values to the next odd number. For
    # img_size = 32 this still yields 3.
    blur_kernel_size = max(int(0.1 * img_size), 1)
    if blur_kernel_size % 2 == 0:
        blur_kernel_size += 1

    # Augmentations shared by every training view (consider adding more as needed)
    common_transforms = [
        transforms.RandomResizedCrop(img_size),  # Random area/aspect-ratio crop, resized to img_size
        transforms.RandomHorizontalFlip(),      # Random horizontal flip
        transforms.RandomGrayscale(p=0.2),    # Randomly convert to grayscale
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),  # Color jitter
        transforms.RandomRotation(degrees=15),   # Random rotations
        transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),  # Random affine transformations
        transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=(0.01, 0.02)),  # Gaussian blur
    ]
    # Train-specific augmentations (introduce more diversity here)
    # NOTE(review): unlike the test pipeline, this pipeline does not apply
    # Normalize(mean, std); train tensors therefore stay in the [0, 1] range
    # produced by ToTensor. Confirm this asymmetry is intended.
    train_transforms = common_transforms + [
        transforms.RandomCrop(img_size, padding=4),  # Random cropping with padding
        transforms.ToTensor(),
    ]

    # Test transforms (no data augmentation, only resizing and normalization)
    test_transforms = [
        transforms.Resize(img_size),  # Resize while maintaining aspect ratio
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    train_transform = transforms.Compose(train_transforms)
    test_transform = transforms.Compose(test_transforms)

    return train_transform, test_transform

# Specify image size and normalization parameters for your dataset
img_size = 32  # Example for CIFAR-10
mean = (0.4914, 0.4822, 0.4465)  # Example for CIFAR-10
std = (0.2023, 0.1994, 0.2010)  # Example for CIFAR-10

train_transform, test_transform = create_simclr_transforms(img_size, mean, std)

# Load CIFAR-10 dataset with the created transforms.
# These two lines must stay active: the DataLoaders below read `trainset` and
# `testset`, so leaving them commented out raises NameError on a fresh kernel.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Create DataLoaders with appropriate parameters
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [2]:
import torch
from torchvision.models import resnet18

def create_simclr_model(pretrained=True, input_shape=(3, 32, 32)):
    """Build a ResNet-18 whose backbone is frozen and whose `fc` is a projection head.

    Args:
        pretrained (bool, optional): Forwarded to ``resnet18`` to request
            pretrained weights. Defaults to True.
        input_shape (tuple, optional): Input image shape (channels, height, width).
            Accepted for interface compatibility; the body does not read it.
            Defaults to (3, 32, 32).

    Returns:
        torch.nn.Module: The ResNet-18 with its final layer replaced by a
        Linear(in_features, 128) -> ReLU -> Linear(128, 32) -> ReLU head.
    """

    backbone = resnet18(pretrained=pretrained)

    # Turn off gradients for every parameter outside the "fc." prefix.
    for param_name, tensor in backbone.named_parameters():
        if param_name.startswith("fc."):
            continue
        tensor.requires_grad = False

    # The head consumes whatever width the original classifier consumed.
    feature_dim = backbone.fc.in_features

    # Swap the classifier for the projection head; its freshly created
    # parameters are trainable by default.
    backbone.fc = torch.nn.Sequential(
        torch.nn.Linear(feature_dim, 128),
        torch.nn.ReLU(inplace=True),
        torch.nn.Linear(128, 32),
        torch.nn.ReLU(inplace=True),
    )

    return backbone

In [3]:
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss between two views.

    Row ``i`` of ``z1`` and row ``i`` of ``z2`` are treated as a positive pair;
    every other row of ``z2`` acts as a negative for ``z1[i]``. This is the
    one-directional, cross-view form: similarities are computed from ``z1`` to
    ``z2`` only.

    Args:
        z1 (torch.Tensor): Embeddings of the first view (shape: batch_size, feature_dim).
        z2 (torch.Tensor): Embeddings of the second view (shape: batch_size, feature_dim).
        temperature (float, optional): Divisor applied to the cosine similarities
            before the softmax (default: 0.5).

    Returns:
        torch.Tensor: Scalar mean loss over the batch.
    """

    # L2-normalize so that the dot products below are cosine similarities.
    z1_norm = F.normalize(z1, dim=1)
    z2_norm = F.normalize(z2, dim=1)

    # (batch_size, batch_size) similarity matrix: entry [i, j] compares z1[i]
    # with z2[j]. The previous einsum("bi,bj->bij", z1, z2.T) only ran when
    # batch_size == feature_dim and produced a 3-D tensor that cross_entropy
    # cannot pair with 1-D labels.
    logits = torch.matmul(z1_norm, z2_norm.T) / temperature

    # The positive for row i sits on the diagonal, i.e. at column i.
    batch_size = z1.shape[0]
    labels = torch.arange(batch_size, device=z1.device)

    return F.cross_entropy(logits, labels, reduction='mean')

In [4]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
# Build the network through create_simclr_model so training uses the frozen
# ResNet-18 backbone with the projection head, not the stock 1000-way classifier.
model = create_simclr_model(pretrained=True)

# Only parameters that still require gradients (the projection head) are optimized.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=0.001, weight_decay=1e-5)  # AdamW with weight decay

# train_simclr steps the scheduler inside its batch loop, so the cosine schedule
# has to span every batch of every epoch; with T_max = len(trainloader) the
# learning rate would reach eta_min after one epoch and then climb back up.
NUM_EPOCHS = 100  # matches the default `epochs` of train_simclr
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS * len(trainloader), eta_min=0.0)



In [None]:
import torch
from torch.utils.data import DataLoader
# Training in this notebook runs on the CPU.
device = "cpu"
# Module.to() moves the parameters in place and returns the module itself.
model = model.to(device=device)

def train_simclr(model, trainloader, optimizer, scheduler, device, nt_xent_loss, epochs=100):
    """Train `model` contrastively on two views of every batch.

    For each batch, the images yielded by the loader form the first view and a
    per-sample re-augmented copy forms the second view. Both views go through
    the model in one forward pass and the loss is computed between the two
    halves of the resulting embeddings.

    The augmentation pipeline is rebuilt from ``create_simclr_transforms`` using
    the notebook-level ``img_size``, ``mean`` and ``std``.

    Args:
        model: The SimCLR model.
        trainloader: The training data loader; must yield (images, targets)
            batches where `images` is a tensor batch.
        optimizer: The optimizer (e.g., AdamW).
        scheduler: The learning rate scheduler (e.g., CosineAnnealingLR), or
            None. It is stepped once per batch.
        device: The device to use (e.g., 'cuda' or 'cpu').
        nt_xent_loss: Callable taking (first_view, second_view) embeddings and
            returning a scalar loss.
        epochs: Number of training epochs (default: 100).

    Returns:
        None
    """
    # create_simclr_transforms returns (train_transform, test_transform); only
    # the training pipeline is relevant here.
    train_pipeline, _ = create_simclr_transforms(img_size, mean, std)

    # The loader already yields tensors, so ToTensor (which only accepts PIL
    # images / ndarrays) must be dropped; Normalize, if present, is dropped too
    # because the batch has already been through the loader's own pipeline.
    view_augment = transforms.Compose([
        step for step in train_pipeline.transforms
        if not isinstance(step, (transforms.ToTensor, transforms.Normalize))
    ])

    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        for i, (images, _targets) in enumerate(trainloader):
            # Build the second view. Only the augmentation is gradient-free;
            # the forward pass below must track gradients for backward().
            # Augmenting sample by sample gives every image its own random
            # parameters instead of one shared draw for the whole batch.
            with torch.no_grad():
                augmented_images = torch.stack([view_augment(image) for image in images])

            # Move data to device
            images = images.to(device)
            augmented_images = augmented_images.to(device)

            # Concatenate both views so they share a single forward pass
            inputs = torch.cat([images, augmented_images], dim=0)

            # Forward pass and embedding extraction
            embeddings = model(inputs)

            # Split the embeddings back into the two views; row k of each half
            # comes from the same source image.
            first_view = embeddings[:images.size(0)]
            second_view = embeddings[images.size(0):]
            loss = nt_xent_loss(first_view, second_view)

            # Backpropagation and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update learning rate (if using scheduler)
            if scheduler is not None:
                scheduler.step()

            # Optional progress logging
            if i % 100 == 0:
                print(f"Batch {i}/{len(trainloader)} - Loss: {loss.item()}")

        # Additional epoch-level logging or evaluation (optional)

if __name__ == "__main__":
    # Kick off training with the model, loader, optimizer, scheduler and device
    # prepared in the cells above; `epochs` is left at its default.
    train_simclr(
        model=model,
        trainloader=trainloader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        nt_xent_loss=nt_xent_loss,
    )

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision.datasets as datasets


# Sample images from your dataset
num_images = 5

# Data preprocessing for visualization.
# Normalize is deliberately left out: ColorJitter and imshow both expect float
# images in [0, 1], and mean/std-normalized tensors fall outside that range.
transform = transforms.Compose([
    transforms.Resize(256),  # Adjust size as needed
    transforms.ToTensor(),
])

images = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
sample_idx = torch.randint(len(images), size=(num_images,))
# The dataset returns one (image, label) pair per integer index, so fetch the
# samples one at a time and keep only the image tensor.
sample_images = [images[int(idx)][0] for idx in sample_idx]

# Define augmentations to visualize
augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(224),  # Adjust size as needed
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.RandomRotation(degrees=15),
    transforms.GaussianBlur(kernel_size=3, sigma=0.5),
]

# Apply augmentations and create plot: one row per image, first column is the
# un-augmented image, remaining columns show one augmentation each.
fig, axes = plt.subplots(nrows=num_images, ncols=len(augmentations) + 1, figsize=(12, 12))

for i, image in enumerate(sample_images):
    ax = axes[i, 0]
    ax.imshow(image.permute(1, 2, 0).numpy())  # Convert from CxHxW to HxWxC
    ax.set_title(f"Original Image {i + 1}")
    ax.axis('off')

    for j, aug in enumerate(augmentations):
        augmented_image = aug(image.clone())
        ax = axes[i, j + 1]
        ax.imshow(augmented_image.permute(1, 2, 0).numpy())
        ax.set_title(f"Augmentation {j + 1}")
        ax.axis('off')

plt.tight_layout()
plt.show()