1. EDA

2. Define dataset utilities

In [18]:
def pair_comparison(a, b):
    """
    Indicate whether two values are equal.

    Args:
    - a: First value to compare
    - b: Second value to compare

    Returns:
    - 1 when a == b, otherwise 0
    """
    return 1 if a == b else 0


In [19]:
def gram_matrix(list_of_score):
    """
    Build the pairwise-equality matrix for a list of scores.

    Entry [i][j] of the result is
    pair_comparison(list_of_score[i], list_of_score[j]), so the matrix
    has one row and one column per score.

    Args:
    - list_of_score: A list of scores

    Returns:
    - A list of lists (n x n, where n is the number of scores) holding
      the pairwise comparison results
    """
    # One row per score; each row compares that score against every score
    return [
        [pair_comparison(row_score, col_score) for col_score in list_of_score]
        for row_score in list_of_score
    ]


In [20]:
# Sanity check: build the pairwise-equality matrix for a small sample of scores.
# Positions holding equal scores (e.g. the 1s at indices 0 and 7) should get a 1,
# and every diagonal entry should be 1 because each score equals itself.
list_of_score = [1,2,3,2,3,4,2,1,4]
gram_matrix(list_of_score)

[[1, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 1, 0, 1, 0, 0, 1, 0, 0],
 [0, 0, 1, 0, 1, 0, 0, 0, 0],
 [0, 1, 0, 1, 0, 0, 1, 0, 0],
 [0, 0, 1, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 1],
 [0, 1, 0, 1, 0, 0, 1, 0, 0],
 [1, 0, 0, 0, 0, 0, 0, 1, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 1]]

In [21]:
# Define dataset

2. Model

In [22]:
from torch import nn
import torch
import torchvision.models as models
import timm

class Setting_2_model(nn.Module):
    def __init__(self, model_name: str, embed_dim: int, pretrained: bool = True):
        """
        A custom model for Setting 2: a backbone selected by `model_name`
        whose final classification layer is replaced by a linear projection
        to `embed_dim` dimensions.

        Supported names: 'resnet50', 'resnet101', 'resnet152', 'densenet121',
        and any name starting with 'vit' (forwarded to timm.create_model).

        Args:
        - model_name: Name of the backbone to be used
        - embed_dim: Dimension of the output embeddings
        - pretrained: Whether to load pre-trained weights for the backbone
          (defaults to True, matching the previous hard-coded behaviour)

        Raises:
        - ValueError: If `model_name` is not one of the supported backbones
        """
        super().__init__()

        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; the flag is kept for compatibility with the
        # installed version -- confirm before upgrading torchvision.
        if model_name.startswith('resnet'):
            resnet_builders = {
                'resnet50': models.resnet50,
                'resnet101': models.resnet101,
                'resnet152': models.resnet152,
            }
            builder = resnet_builders.get(model_name)
            if builder is None:
                raise ValueError(f"Unsupported ResNet model: {model_name}")
            self.model = builder(pretrained=pretrained)

            # Swap the classifier for an embedding projection
            num_features = self.model.fc.in_features
            self.model.fc = nn.Linear(num_features, embed_dim)

        elif model_name.startswith('densenet'):
            if model_name != 'densenet121':
                raise ValueError(f"Unsupported DenseNet model: {model_name}")
            self.model = models.densenet121(pretrained=pretrained)

            # Swap the classifier for an embedding projection
            num_features = self.model.classifier.in_features
            self.model.classifier = nn.Linear(num_features, embed_dim)

        elif model_name.startswith('vit'):
            self.model = timm.create_model(model_name, pretrained=pretrained)

            # Swap the classification head for an embedding projection
            num_features = self.model.head.in_features
            self.model.head = nn.Linear(num_features, embed_dim)

        else:
            raise ValueError(f"Unsupported model: {model_name}")

    def forward(self, images):
        """
        Forward pass of the model.

        Each image is passed through the backbone separately, the resulting
        embeddings are concatenated along dim 0, L2-normalised per row, and
        multiplied with their own transpose.

        Args:
        - images: An iterable of image tensors. Each item may already carry
          a batch axis, or be a single 3-D image, in which case a batch axis
          of size 1 is added. A single batched tensor is also accepted,
          since iterating it yields 3-D slices.

        Returns:
        - gram_matrix: Square matrix of dot products between the normalised
          embeddings (i.e. their pairwise cosine similarities)
        """
        embeddings = []
        for image in images:
            # The embeddings are concatenated along dim 0 below, so every
            # backbone call must see a leading batch axis; add one to 3-D images.
            if image.dim() == 3:
                image = image.unsqueeze(0)
            embeddings.append(self.model(image))
        embeddings_tensor = torch.cat(embeddings, dim=0)

        # Normalize each embedding row to unit L2 norm
        embeddings_normalized = torch.nn.functional.normalize(embeddings_tensor, p=2, dim=1)

        # Compute the Gram matrix of the normalized embeddings
        gram_matrix = torch.matmul(embeddings_normalized, embeddings_normalized.transpose(0, 1))
        return gram_matrix


3. Loss

In [23]:
from torch import nn

# Computes the contrastive learning loss
class ConstrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        """
        Margin-based loss on the distance between a predicted Gram matrix
        and its ground-truth counterpart.

        Args:
        - margin: Distance below which the loss is zero
        """
        super().__init__()
        self.margin = margin

    def forward(self, gram_matrix_predicted, gram_matrix_ground_truth):
        """
        Compute the loss for one pair of Gram matrices.

        Args:
        - gram_matrix_predicted: Predicted Gram matrix
        - gram_matrix_ground_truth: Ground truth Gram matrix

        Returns:
        - loss: The Frobenius distance between the two matrices minus the
          margin, clamped so it is never negative
        """
        # Frobenius distance between the two Gram matrices
        residual = gram_matrix_predicted - gram_matrix_ground_truth
        frobenius_gap = torch.norm(residual, p='fro')

        # Only the part of the distance that exceeds the margin is penalised
        return torch.clamp(frobenius_gap - self.margin, min=0.0)


In [24]:
# Sanity check for the loss: compare two 0/1 pairwise-equality matrices
# built from label lists.
import torch

# Ground-truth labels -> pairwise-equality matrix as a float tensor
gt = [1,2,3,2,3,4,2,1,4]
gt = torch.tensor(gram_matrix(gt),dtype=torch.float)

# "Predicted" labels differ from the ground truth only at index 7 (2 instead of 1)
pred = [1,2,3,2,3,4,2,2,4]
pred = torch.tensor(gram_matrix(pred), dtype=torch.float)

# Loss with the default margin
criterion = ConstrastiveLoss()
loss = criterion(pred, gt)

print(loss)

tensor(1.8284)
