# Signal Processing Techniques for Contrastive Learning in CV

## What is Contrastive Learning?

Contrastive learning is a representation-learning technique, most often used in a self-supervised setting, that learns useful representations from unlabeled data by pulling similar (positive) samples together and pushing dissimilar (negative) samples apart in an embedding space. It has gained popularity in computer vision tasks such as image classification and object detection because it exploits the inherent structure and relationships within the data.

In this notebook, we use the MNIST dataset as an example to demonstrate how to train and evaluate contrastive networks.

## Setup

Before we begin, let's import the necessary libraries and prepare the dataset.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

# Set random seeds for reproducibility
torch.manual_seed(42)

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

## Model Architecture
For our contrastive network, we will use a simple convolutional neural network (CNN). It consists of two convolutional layers, each followed by a ReLU and 2×2 max pooling, and then three fully connected layers that map each 28×28 image to a 10-dimensional embedding.
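The `32 * 7 * 7` input size of `fc1` in the cell below follows from the standard output-size formula for convolution and pooling layers, $\lfloor (n + 2p - k)/s \rfloor + 1$. A small hypothetical helper (not part of the notebook) traces a 28×28 MNIST image through the layers:

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

size = 28                                             # MNIST images are 28x28
size = out_size(size, kernel=3, stride=1, padding=1)  # conv1: 28 -> 28
size = out_size(size, kernel=2, stride=2)             # max_pool2d(x, 2): 28 -> 14
size = out_size(size, kernel=3, stride=1, padding=1)  # conv2: 14 -> 14
size = out_size(size, kernel=2, stride=2)             # max_pool2d(x, 2): 14 -> 7

flat_features = 32 * size * size                      # conv2 has 32 output channels
print(size, flat_features)                            # 7 1568
```

The `padding=1` on the 3×3 convolutions preserves the spatial size, so only the two pooling steps shrink it (28 → 14 → 7).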

In [2]:
import torch.nn as nn

# Define the contrastive network model
class ContrastiveNet(nn.Module):
    def __init__(self):
        super(ContrastiveNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, 2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, 2)
        x = x.view(-1, 32 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create an instance of the contrastive network
model = ContrastiveNet()

## Contrastive Loss
The contrastive loss function aims to pull the embeddings of similar pairs together and push those of dissimilar pairs apart. This encourages the network to learn meaningful representations that capture the underlying structure of the data. In its classic margin-based form, the loss over a batch of $N$ pairs is

Loss = $\frac{1}{2N} \sum_{i=1}^{N} \left[ y_i \, d_i^2 + (1 - y_i) \max(0, m - d_i)^2 \right]$, where $d_i$ is the Euclidean distance between the embeddings of the two samples in pair $i$, $y_i$ is $1$ for a similar pair and $0$ for a dissimilar one, and $m$ is the margin. Similar pairs are penalized for any nonzero distance, while dissimilar pairs are penalized only when they are closer than $m$.

The `ContrastiveLoss` class below uses a closely related formulation over (anchor, positive, negative) triples, $\frac{1}{N} \sum_{i=1}^{N} \max(0, d(a_i, p_i) - d(a_i, n_i) + m)$, which has the same form as the triplet loss introduced later in this notebook.
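For comparison with the triplet-style `ContrastiveLoss` class below, here is a minimal NumPy sketch of the classic margin-based pairwise contrastive loss. It pulls similar pairs ($y_i = 1$) toward zero distance and penalizes dissimilar pairs ($y_i = 0$) only while they are closer than the margin $m$. The function name and toy inputs are illustrative, not part of the notebook:

```python
import numpy as np

def pairwise_contrastive_loss(emb1, emb2, y, margin=1.0):
    """Margin-based contrastive loss over N pairs:
    (1 / 2N) * sum_i [ y_i * d_i^2 + (1 - y_i) * max(0, margin - d_i)^2 ]

    emb1, emb2: (N, D) embeddings of the two members of each pair.
    y: (N,) array, 1 for similar pairs and 0 for dissimilar pairs.
    """
    d = np.linalg.norm(emb1 - emb2, axis=1)                        # Euclidean distance per pair
    similar_term = y * d ** 2                                      # pull similar pairs together
    dissimilar_term = (1 - y) * np.maximum(0.0, margin - d) ** 2   # push dissimilar pairs past the margin
    return np.sum(similar_term + dissimilar_term) / (2 * len(d))

emb1 = np.array([[0.0, 0.0], [0.0, 0.0]])
emb2 = np.array([[3.0, 4.0], [0.5, 0.0]])
y = np.array([1, 0])  # first pair similar, second pair dissimilar
print(pairwise_contrastive_loss(emb1, emb2, y))  # 6.3125
```

With these toy inputs the similar pair contributes $5^2 = 25$ and the dissimilar pair $(1 - 0.5)^2 = 0.25$, giving $25.25 / (2 \cdot 2) = 6.3125$.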

In [3]:
# Define the contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

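    def forward(self, anchor, positive, negative):
        # Euclidean distance between each anchor and its positive / negative
        distance_positive = nn.functional.pairwise_distance(anchor, positive)
        distance_negative = nn.functional.pairwise_distance(anchor, negative)
        # Hinge on the gap: zero once each negative is at least `margin` farther away than its positive
        loss = torch.mean(torch.relu(distance_positive - distance_negative + self.margin))
        return loss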

# Create an instance of the contrastive loss function
criterion = ContrastiveLoss()

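## Training the Contrastive Network
Now, let's train the contrastive network with this loss. In the cell below, each image's "positive" and "negative" are simply other images from the same batch, and the labels are not used, so the network has no consistent signal for separating them. As a result, the loss stays close to the margin of 1.0 throughout training, as the log shows.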

In [4]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim

# Set the optimizer and learning rate
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion.to(device)

for epoch in range(num_epochs):
    running_loss = 0.0
    
    for images, _ in train_loader:
        images = images.to(device)
        
        # Perform forward pass
        embeddings = model(images)
        
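        # Generate "positive" and "negative" samples. Both are embeddings of *other*
        # images in the batch (reversed order / shifted by one position), not
        # augmented views or same-class samples.
        positive_samples = torch.flip(embeddings, dims=[0])
        negative_samples = torch.cat((embeddings[1:], embeddings[0].unsqueeze(0)), dim=0)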
        
        # Compute the contrastive loss
        loss = criterion(embeddings, positive_samples, negative_samples)
        
        # Perform backward pass and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    # Print the average loss for each epoch
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

Epoch [1/10], Loss: 1.0012
Epoch [2/10], Loss: 1.0003
Epoch [3/10], Loss: 0.9999
Epoch [4/10], Loss: 0.9994
Epoch [5/10], Loss: 1.0015
Epoch [6/10], Loss: 1.0008
Epoch [7/10], Loss: 1.0005
Epoch [8/10], Loss: 1.0006
Epoch [9/10], Loss: 1.0003
Epoch [10/10], Loss: 0.9992
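In the training cell above, the positives and negatives are other images from the same batch. A common alternative, and the idea behind augmentation-based methods such as SimCLR, is to make the positive an augmented view of the anchor itself and the negative a view of a different image. This is an assumption, not what this notebook does. Below is a framework-agnostic NumPy sketch; `augment` and `make_triplets` are hypothetical helpers, and the shift-plus-noise augmentation is purely illustrative:

```python
import numpy as np

def augment(images, rng, max_shift=2, noise_std=0.1):
    """Hypothetical augmentation: random circular shift of up to `max_shift`
    pixels in each direction, plus Gaussian noise. images: (B, H, W)."""
    out = np.empty_like(images)
    for i, img in enumerate(images):
        dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
        out[i] = np.roll(img, shift=(dy, dx), axis=(0, 1))
    return out + rng.normal(0.0, noise_std, size=images.shape)

def make_triplets(images, rng, **aug_kwargs):
    """Anchor: the original image. Positive: an augmented view of the SAME image.
    Negative: an augmented view of a DIFFERENT image (the previous one in the batch)."""
    anchors = images
    positives = augment(images, rng, **aug_kwargs)
    negatives = np.roll(augment(images, rng, **aug_kwargs), shift=1, axis=0)
    return anchors, positives, negatives

rng = np.random.default_rng(0)
batch = rng.random((8, 28, 28))                    # random stand-in for a batch of 28x28 images
a, p, n = make_triplets(batch, rng, max_shift=0)   # noise-only views, for a quick check
d_pos = np.linalg.norm((a - p).reshape(len(a), -1), axis=1)
d_neg = np.linalg.norm((a - n).reshape(len(a), -1), axis=1)
print(np.all(d_pos < d_neg))  # True: each positive is a noisy copy of its own anchor
```

In the PyTorch loop, the equivalent change would be to embed an augmented copy of `images` separately, rather than reordering `embeddings`.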


## Evolution from Contrastive to Triplet Loss
The evolution from contrastive loss to triplet loss was driven by limitations of the contrastive loss in learning good feature representations. While contrastive loss works well for pairwise comparisons, it struggles when negative pairs vastly outnumber positive pairs, which leads to slow convergence and inefficient training. It also constrains each pair in isolation: similar pairs are pulled toward zero distance, and dissimilar pairs must exceed a fixed, absolute margin. The triplet loss was introduced to address these challenges with a relative constraint instead: for each anchor, the positive only needs to be closer than the negative by at least a margin.

The triplet loss is a type of metric learning loss that aims to learn feature embeddings such that similar samples are closer together, while dissimilar samples are pushed further apart in the embedding space. It is particularly useful in tasks like face recognition, where the goal is to distinguish between different individuals.

The key idea behind triplet loss is to form triplets of samples, consisting of an anchor sample, a positive sample (from the same class as the anchor), and a negative sample (from a different class). The loss is then computed based on the distances between these samples in the embedding space.

Loss = $\max (\text{dist}(\text{anchor}, \text{positive}) - \text{dist}(\text{anchor}, \text{negative}) + \text{margin}, 0)$

In [14]:
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = torch.norm(anchor - positive, p=2, dim=1)
        distance_negative = torch.norm(anchor - negative, p=2, dim=1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

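The following cell implements a custom dataset that returns (anchor, positive, negative) triplets. For each anchor, it draws a random positive with the same label and a random negative with a different label.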

In [15]:
from torch.utils.data import DataLoader, Dataset

class MNISTTripletDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, index):
        anchor, anchor_label = self.dataset[index]
        
        # Find positive sample with the same label as anchor
        positive_index = torch.randint(high=len(self.dataset), size=(1,)).item()
        while self.dataset[positive_index][1] != anchor_label:
            positive_index = torch.randint(high=len(self.dataset), size=(1,)).item()
        positive, _ = self.dataset[positive_index]

        # Find negative sample with different label from anchor
        negative_index = torch.randint(high=len(self.dataset), size=(1,)).item()
        while self.dataset[negative_index][1] == anchor_label:
            negative_index = torch.randint(high=len(self.dataset), size=(1,)).item()
        negative, _ = self.dataset[negative_index]

        return anchor, positive, negative

    def __len__(self):
        return len(self.dataset)
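The rejection loops in `__getitem__` above call `self.dataset[...]`, which loads and transforms a full image on every attempt just to read its label. With ten roughly balanced MNIST classes, that is about ten image loads on average for the positive alone. The sketch below performs the same sampling on the label array only. It assumes the labels are available up front (torchvision's MNIST exposes them as `dataset.targets`), and `build_label_index` and `sample_triplet` are hypothetical helpers:

```python
import numpy as np

def build_label_index(labels):
    """Map each label to the array of dataset indices carrying that label."""
    return {int(label): np.flatnonzero(labels == label) for label in np.unique(labels)}

def sample_triplet(index, labels, label_to_indices, rng):
    """Return (anchor, positive, negative) indices using only the label array.

    Positive: uniform over samples sharing the anchor's label (as in the original loop).
    Negative: rejection sampling on labels only, so no image is loaded.
    """
    anchor_label = int(labels[index])
    positive = int(rng.choice(label_to_indices[anchor_label]))
    negative = int(rng.integers(len(labels)))
    while labels[negative] == anchor_label:
        negative = int(rng.integers(len(labels)))
    return index, positive, negative

# Random stand-in for MNIST targets (e.g. train_dataset.targets.numpy() in the notebook)
rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=1000)
label_to_indices = build_label_index(labels)
triplets = [sample_triplet(i, labels, label_to_indices, rng) for i in range(5)]
for a, p, n in triplets:
    print(labels[a], labels[p], labels[n])  # first two labels match, third differs
```

Inside `MNISTTripletDataset`, the index would be built once in `__init__`, and `__getitem__` would then load only the three selected images.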

Define the triplet network:

In [17]:
class TripletNetwork(nn.Module):
    def __init__(self):
        super(TripletNetwork, self).__init__()
        self.embedding = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.embedding(x)

Start training!

In [None]:
from torchvision import datasets, transforms

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Create the triplet dataset
train_triplet_dataset = MNISTTripletDataset(train_dataset)
test_triplet_dataset = MNISTTripletDataset(test_dataset)

# Create data loaders
batch_size = 64
train_loader = DataLoader(train_triplet_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_triplet_dataset, batch_size=batch_size, shuffle=False)

# Initialize the network and the loss function
model = TripletNetwork()
criterion = TripletLoss()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion.to(device)

# Set optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for anchor, positive, negative in train_loader:
        anchor = anchor.view(anchor.size(0), -1).to(device)
        positive = positive.view(positive.size(0), -1).to(device)
        negative = negative.view(negative.size(0), -1).to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        anchor_embedding = model(anchor)
        positive_embedding = model(positive)
        negative_embedding = model(negative)

        # Compute the loss
        loss = criterion(anchor_embedding, positive_embedding, negative_embedding)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Evaluation loop
model.eval()
test_loss = 0.0
with torch.no_grad():
    for anchor, positive, negative in test_loader:
        anchor = anchor.view(anchor.size(0), -1).to(device)
        positive = positive.view(positive.size(0), -1).to(device)
        negative = negative.view(negative.size(0), -1).to(device)

        anchor_embedding = model(anchor)
        positive_embedding = model(positive)
        negative_embedding = model(negative)

        loss = criterion(anchor_embedding, positive_embedding, negative_embedding)
        test_loss += loss.item()

test_loss /= len(test_loader)
print(f"Test Loss: {test_loss:.4f}")

Epoch [1/10], Loss: 0.1295
Epoch [2/10], Loss: 0.0504
Epoch [3/10], Loss: 0.0364
Epoch [4/10], Loss: 0.0273
Epoch [5/10], Loss: 0.0276
Epoch [6/10], Loss: 0.0233
Epoch [7/10], Loss: 0.0211
Epoch [8/10], Loss: 0.0194
Epoch [9/10], Loss: 0.0174
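The evaluation loop above reports only the average triplet loss. A more interpretable metric is triplet accuracy: the fraction of test triplets in which the anchor is closer to the positive than to the negative. This notebook does not compute it, so the following is an added sketch. In the PyTorch loop, the three embedding tensors could be converted with `.cpu().numpy()` and passed to this hypothetical helper:

```python
import numpy as np

def triplet_accuracy(anchor, positive, negative):
    """Fraction of triplets whose anchor is closer to the positive than to the negative.
    All inputs: (N, D) embedding arrays."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    return float(np.mean(d_pos < d_neg))

anchor = np.array([[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]])
positive = np.array([[0.1, 0.0], [2.0, 0.0], [1.0, 1.2]])
negative = np.array([[1.0, 0.0], [0.5, 0.0], [3.0, 3.0]])
acc = triplet_accuracy(anchor, positive, negative)
print(acc)  # 2 of the 3 triplets are ordered correctly
```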


## Applications of Contrastive Learning
Contrastive learning has been successful in a variety of computer vision settings, including the following:

1. Limited labeled data: Contrastive learning can be effective when there is a scarcity of labeled data. By leveraging large amounts of unlabeled data, contrastive learning can learn useful representations that generalize well to downstream tasks.

2. Similarity-based tasks: Contrastive learning is particularly suitable for tasks that involve measuring similarity or dissimilarity between examples. It excels in tasks like image retrieval, clustering, and nearest neighbor search, where the goal is to find similar instances.

3. Unsupervised feature learning: When there is no explicit supervision available, contrastive learning can be used to learn meaningful representations from raw data. It enables the model to capture underlying structures, patterns, and semantics in the data without requiring explicit labels.

4. Data augmentation: Contrastive learning can be combined with data augmentation techniques to increase the diversity of positive and negative examples. By augmenting the data with various transformations, the model is exposed to a wider range of variations, making the learned representations more robust.

5. Transfer learning: Contrastive learning can serve as a pretraining step for transfer learning. By pretraining on a large unlabeled dataset using contrastive learning, the model can learn general features that can be fine-tuned on specific downstream tasks with limited labeled data.
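For the similarity-based tasks in item 2, retrieval typically reduces to a nearest-neighbour search over learned embeddings. The following is a minimal brute-force sketch using cosine similarity; the function name and toy 2-D embeddings are illustrative assumptions:

```python
import numpy as np

def nearest_neighbors(query, database, k=3):
    """Indices of the k database embeddings most similar to `query` under
    cosine similarity. query: (D,), database: (M, D)."""
    q = query / np.linalg.norm(query)
    db = database / np.linalg.norm(database, axis=1, keepdims=True)
    similarity = db @ q                    # cosine similarity with every item
    return np.argsort(-similarity)[:k]     # highest similarity first

database = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [0.9, 0.1],
    [-1.0, 0.0],
])
print(nearest_neighbors(np.array([0.95, 0.2]), database, k=2))  # [2 0]
```

Normalizing both sides makes the ranking depend only on direction, not on embedding magnitude.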

In this notebook, I want to dive into some signal processing techniques for similarity-based tasks specifically.

## Feature Extraction

Speaking of feature extraction, several techniques come to mind. Smoothing methods, such as a rolling-window (moving) average or low-pass filtering in the Fourier domain, filter out noise while keeping the peaks in the signal that we care about. Transforms that expose a signal's structure are also useful:
1. Fourier Transform: Convert the input signal from the time domain to the frequency domain using the Fourier transform. This helps in analyzing the frequency components present in the signal.
2. Wavelet Transform: Decompose the signal into different frequency components using the wavelet transform. This provides a multi-resolution representation of the signal.
3. Spectrogram: Compute the spectrogram of the signal by applying the short-time Fourier transform (STFT). This represents the signal's frequency content over time.

In [None]:
import numpy as np
from scipy.fft import fft
from scipy.signal import spectrogram
import pywt

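# Example signal (a short toy sequence for illustration)
signal = np.array([1, 2, 3, 4, 5])

# Fourier Transform: complex amplitude of each frequency component
fourier_transform = fft(signal)
print("Fourier Transform:", fourier_transform)

# Wavelet Transform (single-level Haar DWT): approximation (cA) and detail (cD) coefficients
cA, cD = pywt.dwt(signal, "haar")
print("Wavelet Transform (cA, cD):", cA, cD)

# Spectrogram via the STFT. The default segment length (nperseg=256) is longer than
# this 5-sample signal, so set it explicitly; the result has a single time segment.
frequencies, times, spectrogram_values = spectrogram(signal, nperseg=len(signal))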
print("Spectrogram:")
print("Frequencies:", frequencies)
print("Times:", times)
print("Spectrogram Values:")
print(spectrogram_values)
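The cell above illustrates the three transforms but not the smoothing step mentioned at the start of this section. Here is a minimal NumPy sketch of two common smoothers on a synthetic signal: a rolling-window mean, and a low-pass filter that zeroes FFT bins above a cutoff. The 2 Hz / 30 Hz test signal, window size, and 10 Hz cutoff are arbitrary illustrative choices:

```python
import numpy as np

def moving_average(x, window=5):
    """Centered rolling-window mean; mode='same' keeps the output aligned with x
    (the first and last few samples are averaged against implicit zeros)."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="same")

def fft_lowpass(x, fs, cutoff_hz):
    """Zero every frequency bin above `cutoff_hz` and transform back."""
    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / fs)
    spectrum[freqs > cutoff_hz] = 0
    return np.fft.irfft(spectrum, n=len(x))

fs = 100                                          # sampling rate in Hz
t = np.arange(200) / fs                           # 2 s -> 200 samples
clean = np.sin(2 * np.pi * 2 * t)                 # slow 2 Hz component to keep
noisy = clean + 0.3 * np.sin(2 * np.pi * 30 * t)  # fast 30 Hz component to remove

smoothed_ma = moving_average(noisy, window=5)
smoothed_fft = fft_lowpass(noisy, fs=fs, cutoff_hz=10)
print(np.max(np.abs(smoothed_fft - clean)))       # close to zero (floating-point error only)
```

The FFT filter recovers the 2 Hz component almost exactly here only because both tones complete a whole number of cycles in the window. On real, non-periodic signals, zeroing bins can cause ringing, and a smoother filter design (e.g. `scipy.signal.butter` with `scipy.signal.filtfilt`) is usually gentler.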

## Maximizing Information Entropy of the Peak Signals Using Contrastive Learning

Here, contrastive learning is used to maximize the information entropy of the peak signals by increasing the similarity between positive samples and the dissimilarity between positive and negative samples.

In this setting, peak signals are identified by their distinct features or high response values, and these peaks represent important, informative patterns in the data. Contrastive learning encourages the positive samples (those containing peaks) to be similar to each other; by maximizing this similarity, the network learns to capture the features or patterns the peaks have in common. For instance, in AFM Force-Z curve classification, a single-rupture event is characterized by one such peak, while other curves lack this characteristic.

In fact, in AFM image recognition tasks, after baseline correction and flattening, we feed in a second channel alongside the original function, the derivative of the signal, to ensure that the contrastive networks understand the characteristic of the "peak."
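The two-channel idea can be sketched in NumPy as follows. Everything here is an illustrative assumption rather than the pipeline used in this work: the synthetic curve, baseline correction by subtracting a line fitted to an assumed flat region, and `np.gradient` as the derivative estimate. The helper names are hypothetical:

```python
import numpy as np

def baseline_correct(force, n_baseline):
    """Subtract a straight line fitted to the first `n_baseline` points
    (assumed to be the flat, non-contact part of the curve)."""
    x = np.arange(len(force))
    slope, intercept = np.polyfit(x[:n_baseline], force[:n_baseline], deg=1)
    return force - (slope * x + intercept)

def two_channel_input(force, n_baseline):
    """Stack the corrected curve and its derivative as a (2, N) array."""
    corrected = baseline_correct(force, n_baseline)
    derivative = np.gradient(corrected)  # central differences per sample
    return np.stack([corrected, derivative])

# Synthetic curve: tilted baseline, then a linear force build-up that ruptures at sample 150
n = 200
x = np.arange(n)
force = 0.01 * x                        # instrument tilt / drift
force[100:150] += 0.2 * np.arange(50)   # loading ramp before the rupture
features = two_channel_input(force, n_baseline=100)
print(features.shape)                   # (2, 200)
print(np.argmin(features[1]))           # 150: the rupture is the sharpest drop
```

In the derivative channel, the abrupt drop at the rupture becomes a single sharp spike, while the flat baseline maps to zero. This makes the "peak" explicit alongside the raw curve.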