In [7]:
def pair_comparison(a, b):
    """
    Compare two values and encode the outcome as an integer code.

    Args:
    - a: First value
    - b: Second value

    Returns:
    - 1 if a is greater than b
    - 0 if a differs from b and is not greater than it (a < b for ordinary numbers)
    - 2 if a and b are equal
    """
    if a != b:
        return 1 if a > b else 0
    return 2


In [8]:
import random
def gram_matrix(list_of_score):
    """
    Build the square table of pairwise comparison codes for a list of scores.

    Entry [i][j] holds pair_comparison(list_of_score[i], list_of_score[j]),
    so the table has one row and one column per score.

    Args:
    - list_of_score: A list of scores

    Returns:
    - A list of lists (n x n for n scores) of pair_comparison results;
      an empty list when list_of_score is empty
    """
    # Rows are produced in order of the left operand, columns in order of the
    # right operand, matching the [i][j] convention described above.
    return [
        [pair_comparison(row_score, col_score) for col_score in list_of_score]
        for row_score in list_of_score
    ]


In [9]:
import random
from collections import defaultdict
from itertools import cycle, islice

def split_data_balanced_randomly(data, labels, num_groups, k):
    """
    Split the data and labels randomly into a specified number of groups with balanced label distribution.
    Each group will contain exactly k samples (provided data is non-empty and k > 0),
    with items possibly repeated to ensure balanced distribution.

    Every label first contributes k // num_labels items to every group. When k is
    not a multiple of the number of distinct labels, the remaining k % num_labels
    slots of each group are filled with one extra item from that many distinct,
    randomly chosen labels, so label counts inside a group differ by at most one.

    Args:
    - data: List of data elements
    - labels: List of corresponding labels
    - num_groups: Number of groups to split the data into
    - k: Number of samples in each group

    Returns:
    - image_groups: List of groups containing data elements
    - label_groups: List of groups containing corresponding labels
    """
    # Ensure data and labels have the same length
    assert len(data) == len(labels), "Data and labels must have the same length."

    # Group data by labels
    label_to_data = defaultdict(list)
    for item, label in zip(data, labels):
        label_to_data[label].append(item)

    # Prepare the result lists
    image_groups = [[] for _ in range(num_groups)]
    label_groups = [[] for _ in range(num_groups)]

    # Nothing can be drawn without data, and a non-positive k asks for no samples
    if not label_to_data or k <= 0:
        return image_groups, label_groups

    per_label, remainder = divmod(k, len(label_to_data))

    # Distribute the data into the groups with repeats allowed
    for label, items in label_to_data.items():
        random.shuffle(items)  # Shuffle items within each label group
        item_cycle = iter(items)
        for group_index in range(num_groups):
            for _ in range(per_label):
                try:
                    item = next(item_cycle)
                except StopIteration:
                    # If we run out of items for a label, shuffle and start again
                    random.shuffle(items)
                    item_cycle = iter(items)
                    item = next(item_cycle)
                image_groups[group_index].append(item)
                label_groups[group_index].append(label)

    # Top up each group to exactly k samples. Drawing the extra labels without
    # replacement keeps the per-group label counts within one of each other.
    if remainder:
        all_labels = list(label_to_data)
        for group_index in range(num_groups):
            for label in random.sample(all_labels, remainder):
                image_groups[group_index].append(random.choice(label_to_data[label]))
                label_groups[group_index].append(label)

    return image_groups, label_groups

def split_data_randomly(data, labels, num_groups, k):
    """
    Draw num_groups random groups of exactly k (item, label) pairs each.

    Pairs are sampled uniformly with replacement from the whole data set, so
    the same item may appear several times within one group and across groups.
    No balancing of labels is attempted.

    Args:
    - data: List of data elements
    - labels: List of corresponding labels
    - num_groups: Number of groups to split the data into
    - k: Number of samples in each group

    Returns:
    - image_groups: List of groups containing data elements
    - label_groups: List of groups containing corresponding labels
    """
    # Data and labels are paired positionally, so their lengths must agree
    assert len(data) == len(labels), "Data and labels must have the same length."

    # Pair every item with its label and shuffle the pairs as one unit
    pool = list(zip(data, labels))
    random.shuffle(pool)
    last_index = len(pool) - 1

    image_groups = []
    label_groups = []
    for _ in range(num_groups):
        # k independent draws from the pool (repetition allowed)
        drawn = [pool[random.randint(0, last_index)] for _ in range(k)]
        image_groups.append([pair[0] for pair in drawn])
        label_groups.append([pair[1] for pair in drawn])

    return image_groups, label_groups

In [10]:
from torch.utils.data import Dataset
import pandas as pd
import cv2
import os
import torch
from PIL import Image
import pydicom
import random

from torchvision import transforms
from typing import Tuple

# Per-channel statistics handed to transforms.Normalize in every pipeline below.
mean = [0.6821, 0.4575, 0.2626]
std = [0.1324, 0.1306, 0.1022]

# One preprocessing pipeline per dataset split, keyed by the split name.
data_transforms = dict(
    # Training: resize, then random augmentation (flip, colour jitter, blur),
    # each applied with probability 0.3, before tensor conversion and normalisation.
    training=transforms.Compose([
        transforms.Resize((450, 200)),
        transforms.RandomHorizontalFlip(p=0.3),
        transforms.RandomApply(torch.nn.ModuleList([transforms.ColorJitter()]), p=0.3),
        transforms.RandomApply(torch.nn.ModuleList([transforms.GaussianBlur(kernel_size=3)]), p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
    # Validation: deterministic resize + normalisation only.
    valid=transforms.Compose([
        transforms.Resize((450, 200)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
    # Test: same deterministic pipeline as validation.
    test=transforms.Compose([
        transforms.Resize((450, 200)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
)

class SeveritySimilarityDataset(Dataset):
    def __init__(self, annotation_file_path: str, dataset_dir: str, phase: str = "training", num_per_cluster: int = 5, num_group: int = 10000, input_size: Tuple[int, int] = (224, 224), transforms=None) -> None:
        """
        Dataset class for severity similarity task.

        Each sample is a cluster (group) of images together with the pairwise
        comparison matrix of their labels and a repeated reference image.

        Args:
        - annotation_file_path: Path to the CSV file containing annotations; it must
          provide the columns "split", "study_id", "image_id" and "breast_birads"
        - dataset_dir: Directory containing image data
        - phase: Value of the "split" column to keep. When `transforms` is None it
          must also be a key of `data_transforms` ("training", "valid" or "test")
        - num_per_cluster: Number of images per cluster
        - num_group: Number of clusters to generate
        - input_size: Stored on the instance and forwarded to `_read_image`, which
          currently ignores it (resizing is done by the transform pipeline)
        - transforms: Optional callable applied to each PIL image; defaults to
          `data_transforms[phase]`
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        annotation_df = pd.read_csv(annotation_file_path)
        # Filter data based on the specified phase
        data = annotation_df[annotation_df["split"] == phase]
        self.transforms = data_transforms[phase] if transforms is None else transforms

        # Concatenate study_id and image_id to get image paths
        image_paths_df = data["study_id"] + "/" + data["image_id"] + ".png"
        self.image_paths = image_paths_df.tolist()
        self.num_per_cluster = num_per_cluster
        # Get labels: the last character of each "breast_birads" entry, as an integer
        labels_df = data["breast_birads"]
        self.labels = [int(s[-1]) for s in labels_df.to_list()]

        # Split data into clusters: label-balanced for every phase except "test"
        if phase != "test":
            self.image_cluster_list, self.label_cluster_list = split_data_balanced_randomly(self.image_paths, self.labels, num_group, self.num_per_cluster)
        else:
            self.image_cluster_list, self.label_cluster_list = split_data_randomly(self.image_paths, self.labels, num_group, self.num_per_cluster)

        self.input_size = input_size

    def __len__(self):
        """
        Returns the number of clusters in the dataset.
        """
        return len(self.label_cluster_list)

    def __getitem__(self, index):
        """
        Retrieves a cluster of images, reference images and label information.

        Args:
        - index: Index of the cluster to retrieve

        Returns:
        - images: List of transformed images in the cluster
        - ref_images: List produced by `get_ref` (one reference image repeated)
        - gram_matrix: Float tensor of pairwise label comparisons for the cluster
        - labels: Float tensor with the labels of the cluster
        """
        image_cluster = self.image_cluster_list[index]
        label_cluster = self.label_cluster_list[index]
        images = []
        for image_path in image_cluster:
            abs_image_path = os.path.join(self.dataset_dir, image_path)
            # Read and preprocess image
            images.append(self._read_image(abs_image_path, self.input_size))
        # Compute the pairwise comparison matrix of the cluster's labels
        gram_matrix_ = gram_matrix(label_cluster)
        ref_images = self.get_ref()
        return images, ref_images, torch.tensor(gram_matrix_).to(torch.float), torch.tensor(label_cluster).to(torch.float)

    def get_ref(self):
        """
        Pick one random image whose label is 1 and repeat it num_per_cluster times.

        The image is read (and transformed) once; the returned list holds the
        same tensor object in every position.

        Returns:
        - ref_images: List of length num_per_cluster

        Raises:
        - IndexError: If no image with label 1 is available
        """
        ref_idxs = [i for i, label in enumerate(self.labels) if label == 1]
        if not ref_idxs:
            # random.choice would raise the same exception type with an opaque message
            raise IndexError("No image with label 1 is available to use as reference.")
        random_ref_id = random.choice(ref_idxs)
        ref_image_path = self.image_paths[random_ref_id]
        ref_image = self._read_image(os.path.join(self.dataset_dir, ref_image_path), self.input_size)
        return [ref_image] * self.num_per_cluster

    def _read_image(self, filepath, new_size):
        """
        Load an image, force it to 3 identical channels and apply the transforms.

        Args:
        - filepath: Path of the image file to open
        - new_size: Unused; kept for interface compatibility (resizing is performed
          by `self.transforms`)

        Returns:
        - The output of `self.transforms` converted to torch.float
        """
        # The context manager closes the underlying file handle once the pixel
        # data has been copied into the merged image.
        with Image.open(filepath) as image_pil:
            # Check the image mode and convert to 'L' (grayscale) if necessary
            if image_pil.mode != 'L':
                image_pil = image_pil.convert('L')

            # Build an RGB image by copying the single channel into all three channels
            image_pil = Image.merge('RGB', (image_pil, image_pil, image_pil))

        # Resizing, tensor conversion and normalisation happen in the transform pipeline
        resized_image = self.transforms(image_pil)
        resized_image = resized_image.to(torch.float)

        return resized_image


In [11]:
from torch import nn
import torch
import torchvision.models as models
import timm

class Setting_3_model(nn.Module):
    def __init__(self, model_name: str, embed_dim: int):
        """
        Embedding network built on a pre-trained backbone selected by `model_name`.

        The backbone's final classification layer is replaced by a fresh linear
        layer that outputs `embed_dim` features.

        Args:
        - model_name: Name of the pre-trained model to be used ('resnet50',
          'resnet101', 'resnet152', 'densenet121', or a timm name starting with 'vit')
        - embed_dim: Dimension of the output embeddings

        Raises:
        - ValueError: If `model_name` is not one of the supported backbones
        """
        super().__init__()

        if model_name.startswith('resnet'):
            # Supported torchvision ResNet variants, keyed by name
            resnet_builders = {
                'resnet50': models.resnet50,
                'resnet101': models.resnet101,
                'resnet152': models.resnet152,
            }
            if model_name not in resnet_builders:
                raise ValueError(f"Unsupported ResNet model: {model_name}")
            self.model = resnet_builders[model_name](pretrained=True)

            # Swap the classification head for an embedding projection
            self.model.fc = nn.Linear(self.model.fc.in_features, embed_dim)

        elif model_name.startswith('densenet'):
            if model_name != 'densenet121':
                raise ValueError(f"Unsupported DenseNet model: {model_name}")
            self.model = models.densenet121(pretrained=True)

            self.model.classifier = nn.Linear(self.model.classifier.in_features, embed_dim)

        elif model_name.startswith('vit'):
            # Vision transformers are loaded through timm by their timm model name
            self.model = timm.create_model(model_name, pretrained=True)

            self.model.head = nn.Linear(self.model.head.in_features, embed_dim)

        else:
            raise ValueError(f"Unsupported model: {model_name}")

    def forward(self, images, ref_images, device):
        """
        Embed every image and its paired reference image with the shared backbone.

        `images` and `ref_images` are walked in lockstep; iteration stops at the
        shorter of the two.

        Args:
        - images: A list of input image batches (one entry per position in the group)
        - ref_images: A list of reference image batches, paired with `images`
        - device: Device the inputs are moved to before the backbone is applied

        Returns:
        - embeddings_tensor: Image embeddings stacked along dim 1 (presumably
          (batch, n, embed_dim) when each backbone output is (batch, embed_dim))
        - ref_embedding_tensor: Reference embeddings stacked the same way
        """
        image_embeddings = []
        reference_embeddings = []
        for image, ref_image in zip(images, ref_images):
            # The image is embedded before its reference, position by position
            image_embeddings.append(self.model(image.to(device)))
            reference_embeddings.append(self.model(ref_image.to(device)))

        # Collect the per-position embeddings along a new dimension 1
        stacked_images = torch.stack(image_embeddings, dim=1).to(device)
        stacked_refs = torch.stack(reference_embeddings, dim=1).to(device)
        return stacked_images, stacked_refs



In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreferenceComparisonLoss(nn.Module):
    def __init__(self, margin=2.0):
        """
        Preference Comparison Loss: a pairwise ranking loss on the cosine
        similarity between each embedding and its reference embedding.

        Args:
        - margin: Margin value used for pairs whose label is 2 (equal scores)
        """
        super(PreferenceComparisonLoss, self).__init__()
        self.margin = margin

    def forward(self, output, label, ref):
        """
        Forward pass of the Preference Comparison Loss function.

        For every ordered pair (j, k) with j != k inside a group, with
        d = cos(output_j, ref_j) - cos(output_k, ref_k):
        - label < 2: binary cross-entropy between sigmoid(d) and the label
        - otherwise: clamp(margin - |d|, min=0) ** 2
        The result is the mean over all such pairs in the batch.

        NOTE(review): pairs labelled 2 (equal scores) are penalised when their
        similarity gap is *smaller* than the margin, i.e. they are pushed apart.
        This is kept unchanged here; confirm it is the intended behaviour.

        Args:
        - output: Predicted embeddings from the model, shape (batch_size, n, embedding_dim)
        - label: Pairwise comparison codes, shape (batch_size, n, n)
        - ref: Reference embeddings, shape (batch_size, n, embedding_dim)

        Returns:
        - loss: Scalar tensor; zero (with requires_grad=True) when there is no pair
        """
        # Cosine similarity of every embedding with its reference: (batch_size, n)
        cosine_similarities = F.cosine_similarity(output, ref, dim=2)
        batch_size, n = cosine_similarities.shape

        count_pairs = batch_size * n * (n - 1)
        if count_pairs <= 0:
            # If there are no valid pairs, return zero loss
            return torch.tensor(0.0, requires_grad=True, device=output.device)

        # diff[b, j, k] = similarity of item j minus similarity of item k
        diff = cosine_similarities.unsqueeze(2) - cosine_similarities.unsqueeze(1)

        # Keep the labels on the same device/dtype as the predictions so the loss
        # also works when the model does not run on the CPU.
        pair_labels = label[:batch_size, :n, :n].to(device=diff.device, dtype=diff.dtype)
        is_ordered = pair_labels < 2

        # BCE on logits equals BCE(sigmoid(d), target) but is numerically stabler.
        # Entries with label 2 get a dummy target; their BCE value is discarded below.
        bce_targets = torch.where(is_ordered, pair_labels, torch.zeros_like(pair_labels))
        ordered_loss = F.binary_cross_entropy_with_logits(diff, bce_targets, reduction="none")
        tie_loss = torch.clamp(self.margin - diff.abs(), min=0.0).pow(2)
        pair_loss = torch.where(is_ordered, ordered_loss, tie_loss)

        # Pairs of an item with itself (j == k) do not contribute
        diagonal = torch.eye(n, dtype=torch.bool, device=diff.device)
        pair_loss = pair_loss.masked_fill(diagonal, 0.0)

        # Average loss across the valid pairs
        return pair_loss.sum() / count_pairs


In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torch.optim import lr_scheduler

def train_model(model, train_dataset, val_dataset, checkpoint_folder, num_epochs=10, batch_size=32, learning_rate=0.001, device=None):
    """
    Train the model using the provided datasets.

    After every epoch the current weights are written to `last.pt`; whenever the
    validation loss improves they are also written to `best.pt`. Both files live
    in `checkpoint_folder`.

    Args:
    - model: The model to be trained
    - train_dataset: Dataset for training
    - val_dataset: Dataset for validation
    - checkpoint_folder: Folder to store checkpoints
    - num_epochs: Number of epochs for training
    - batch_size: Batch size for training
    - learning_rate: Learning rate for optimization
    - device: Device to train on (torch.device or string); defaults to the CPU

    Returns:
    - model: Trained model
    - train_losses: List of training losses
    - val_losses: List of validation losses

    Raises:
    - ValueError: If either dataset yields no batches
    """
    # Create the checkpoint folder if it doesn't exist
    os.makedirs(checkpoint_folder, exist_ok=True)
    device = torch.device("cpu") if device is None else torch.device(device)
    print(f"Device: {device}")
    # Define data loaders for training and validation
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Fail fast: the per-epoch averages below divide by the number of batches
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_dataset and val_dataset must each contain at least one sample.")

    # Define loss function and optimizer
    criterion = PreferenceComparisonLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.0)
    # Halve the learning rate every 10 epochs
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # Lists to store training and validation losses
    train_losses = []
    val_losses = []

    # Best validation loss seen so far
    best_val_loss = float('inf')

    model = model.to(device)
    print("Training started...")
    for epoch in range(num_epochs):
        print("*"*100)
        print(f"Epoch [{epoch+1}/{num_epochs}]:")
        model.train()
        running_train_loss = 0.0
        for i, (images, ref_images, target_matrix, _) in enumerate(train_loader):
            optimizer.zero_grad()
            # Forward pass
            target_matrix = target_matrix.to(device)
            output, ref_output = model(images, ref_images, device)
            # Compute loss
            loss = criterion(output, target_matrix, ref_output)
            # Backward pass
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()

            if i % 500 == 0:
                print(f"\t Batch [{i}/{len(train_loader)}], Train Loss: {loss.item():.4f}")

        # Compute average training loss for the epoch
        epoch_train_loss = running_train_loss / len(train_loader)
        train_losses.append(epoch_train_loss)

        # Validation loop
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for i, (images, ref_images, target_matrix, _) in enumerate(val_loader):
                target_matrix = target_matrix.to(device)
                output, ref_output = model(images, ref_images, device)
                loss = criterion(output, target_matrix, ref_output)
                running_val_loss += loss.item()

                if i % 50 == 0:
                    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Batch [{i}/{len(val_loader)}], Val Loss: {loss.item():.4f}")

        # Compute average validation loss for the epoch
        epoch_val_loss = running_val_loss / len(val_loader)
        val_losses.append(epoch_val_loss)

        # Save the model checkpoint for every epoch (last model)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': epoch_val_loss
        }, os.path.join(checkpoint_folder, 'last.pt'))

        # Save the best model checkpoint based on validation loss
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': best_val_loss
            }, os.path.join(checkpoint_folder, 'best.pt'))

        # Print progress
        print(f"Validation, Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}")
        print("*"*100)
        scheduler.step()
    print("Training completed.")

    return model, train_losses, val_losses


def test_model(model, test_dataset, batch_size=32):
    """
    Evaluate the model on the test dataset.

    Batches are handled exactly like in the validation loop of `train_model`:
    each batch is unpacked as (images, ref_images, comparison matrix, labels),
    the model is called with the images, the reference images and the device,
    and the loss is a `PreferenceComparisonLoss`.

    Args:
    - model: The trained model to be evaluated
    - test_dataset: Dataset for testing
    - batch_size: Batch size for testing

    Returns:
    - test_loss: Test loss, averaged over the groups of the dataset

    Raises:
    - ValueError: If the dataset yields no samples
    """
    # Define data loader for testing
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Define loss function (same criterion as used for training/validation)
    criterion = PreferenceComparisonLoss()

    # Evaluate on whatever device the model already lives on (CPU if it has no parameters)
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Set model to evaluation mode
    model.eval()

    # Initialize variables for computing test loss
    running_test_loss = 0.0
    num_samples = 0

    print("Testing started...")
    with torch.no_grad():
        for images, ref_images, target_matrix, _ in tqdm(test_loader):
            target_matrix = target_matrix.to(device)
            output, ref_output = model(images, ref_images, device)
            loss = criterion(output, target_matrix, ref_output)
            # The criterion returns a per-batch mean, so weight it by the batch size
            groups_in_batch = target_matrix.size(0)
            running_test_loss += loss.item() * groups_in_batch
            num_samples += groups_in_batch

    if num_samples == 0:
        raise ValueError("test_dataset is empty; cannot compute a test loss.")

    # Compute test loss
    test_loss = running_test_loss / num_samples

    print(f"Test Loss: {test_loss:.4f}")
    print("Testing completed.")

    return test_loss

In [14]:
# Experiment settings read by the dataset, model and training cells.
config = dict(
    # Data locations
    annotation_data_path="../csv/split_data.csv",
    image_folder_path="/media/jackson/Data/archive/Processed_Images",
    # Model
    model_encoder="resnet50",
    embedding_dim=512,
    # Optimisation
    learning_rate=1e-2,
    num_epoch=20,
    batch_size=1,
    # Group sampling
    num_per_cluster=5,
    num_groups=100,
    # Output
    checkpoint_folder="../weights_setting3/resnet50BasedModel",
)

In [15]:
def _build_dataset(phase, num_group):
    """Create a SeveritySimilarityDataset for one split using the shared config values."""
    return SeveritySimilarityDataset(
        annotation_file_path=config["annotation_data_path"],
        dataset_dir=config["image_folder_path"],
        phase=phase,
        num_per_cluster=config["num_per_cluster"],
        num_group=num_group,
        input_size=(224, 224),
    )


# The validation split uses one hundredth (integer division) of the training group count.
train_dataset = _build_dataset("training", config["num_groups"])
valid_dataset = _build_dataset("valid", config["num_groups"] // 100)

model = Setting_3_model(
    model_name=config["model_encoder"],
    embed_dim=config["embedding_dim"],
)

# Left as a bare expression so the notebook still displays the returned tuple.
train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=valid_dataset,
    checkpoint_folder=config["checkpoint_folder"],
    num_epochs=config["num_epoch"],
    batch_size=config["batch_size"],
    learning_rate=config["learning_rate"],
)



Device: cpu
Training started...
****************************************************************************************************
Epoch [1/20]:
	 Batch [0/100], Train Loss: 0.6932
Epoch [1/20], Validation Batch [0/1], Val Loss: 0.6860
Validation, Train Loss: 0.6932, Val Loss: 0.6860
****************************************************************************************************
****************************************************************************************************
Epoch [2/20]:
	 Batch [0/100], Train Loss: 0.6932
Epoch [2/20], Validation Batch [0/1], Val Loss: 0.7070
Validation, Train Loss: 0.6931, Val Loss: 0.7070
****************************************************************************************************
****************************************************************************************************
Epoch [3/20]:
	 Batch [0/100], Train Loss: 0.6931
Epoch [3/20], Validation Batch [0/1], Val Loss: 0.6950
Validation, Train Loss: 0.6931, Val Loss: 0.6950
*

(Setting_3_model(
   (model): ResNet(
     (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace=True)
     (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
     (layer1): Sequential(
       (0): Bottleneck(
         (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
         (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
         (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (relu): ReLU(inplace=True)
         (downsample): 