In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CustomCNN(nn.Module):
    """Small convolutional encoder that turns an image batch into flat embeddings.

    The network is four ReLU-activated convolutions interleaved with three
    max-pool layers, followed by a global average pool and a flatten. For a
    batched ``(N, 3, H, W)`` input the result is an ``(N, 64)`` tensor.

    For a 320x320 input the spatial side shrinks as
    63 -> 31 -> 29 -> 14 -> 7 -> 3 -> 1 through conv1, pool1, conv2, conv3,
    pool2, conv4 and pool3 respectively. The input must be large enough for
    every stage to still produce a non-empty feature map.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 3 -> 12 channels, 7x7 kernel, stride 5, no padding,
        # followed by a 2x2 max-pool with stride 2.
        self.conv1 = nn.Conv2d(3, 12, kernel_size=7, stride=5, padding=0)
        self.pool1 = nn.MaxPool2d(2, stride=2)

        # Stage 2: 12 -> 24 channels, 3x3 kernel, stride 1 (no pooling after it).
        self.conv2 = nn.Conv2d(12, 24, kernel_size=3, stride=1)

        # Stage 3: 24 -> 32 channels, 3x3 kernel, stride 2,
        # followed by a 2x2 max-pool with stride 2.
        self.conv3 = nn.Conv2d(24, 32, kernel_size=3, stride=2)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        # Stage 4: 32 -> 64 channels, 3x3 kernel, stride 2,
        # followed by a 3x3 max-pool with stride 3.
        self.conv4 = nn.Conv2d(32, 64, kernel_size=3, stride=2)
        self.pool3 = nn.MaxPool2d(3, stride=3)

        # Head: average whatever spatial extent is left down to 1x1, then
        # flatten. `fc` holds no weights despite its name; it is only a Flatten.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Flatten()

    def forward(self, x):
        """Run `x` through all stages and return the flattened embedding."""
        feats = self.pool1(F.relu(self.conv1(x)))
        feats = F.relu(self.conv2(feats))
        feats = self.pool2(F.relu(self.conv3(feats)))
        feats = self.pool3(F.relu(self.conv4(feats)))
        return self.fc(self.global_pool(feats))

In [4]:
def triplet_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    margin: float = 1.0,
    squared: bool = False,
) -> torch.Tensor:
    """
    Compute the batch-mean triplet margin loss.

    By default the (unsquared) Euclidean distance is used:
        loss = mean(max(||anchor - positive|| - ||anchor - negative|| + margin, 0))

    With ``squared=True`` both distances are squared before the comparison:
        loss = mean(max(||anchor - positive||^2 - ||anchor - negative||^2 + margin, 0))

    Args:
        anchor: Embeddings of the reference samples.
        positive: Embeddings that should lie close to ``anchor``.
        negative: Embeddings that should lie far from ``anchor``.
        margin: Minimum gap required between the two distances.
        squared: If True, use squared L2 distances. Defaults to False, which
            keeps the behaviour this function has always had.

    Returns:
        A scalar tensor holding the mean loss over all triplets.

    Note:
        ``F.pairwise_distance`` adds its default ``eps`` to the difference
        before taking the norm, so identical inputs yield a tiny non-zero
        distance rather than exactly 0.
    """
    positive_dist = F.pairwise_distance(anchor, positive, p=2)
    negative_dist = F.pairwise_distance(anchor, negative, p=2)
    if squared:
        positive_dist = positive_dist.pow(2)
        negative_dist = negative_dist.pow(2)
    # relu(x) == max(x, 0): triplets that already satisfy the margin contribute 0.
    loss = torch.relu(positive_dist - negative_dist + margin)
    return loss.mean()

In [None]:
if __name__ == "__main__":
    # Seed the global RNG before anything random happens so that the model's
    # initial weights, the dummy inputs and therefore the printed loss are
    # identical on every Restart & Run All.
    torch.manual_seed(0)

    model = CustomCNN()

    # Dummy anchor / positive / negative images:
    # batch size 1, 3 channels, 320x320 pixels each.
    anchor = torch.randn(1, 3, 320, 320)
    positive = torch.randn(1, 3, 320, 320)
    negative = torch.randn(1, 3, 320, 320)

    # Embed all three inputs with the same network (shared weights).
    anchor_embedding = model(anchor)
    positive_embedding = model(positive)
    negative_embedding = model(negative)

    # Compute and report the triplet loss for this single triplet.
    loss = triplet_loss(anchor_embedding, positive_embedding, negative_embedding)
    print(f"Triplet Loss: {loss.item()}")

Triplet Loss: 1.0008314847946167
