<a href="https://colab.research.google.com/github/Devadeut/Neural-Networks-Hands-On-Projects/blob/main/Fairness_Contrastive_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torchvision.models import resnet50
from torchvision.transforms.functional import to_pil_image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
!pip install pydicom
import pydicom
from PIL import Image



Collecting pydicom
  Downloading pydicom-2.4.4-py3-none-any.whl (1.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 8.2 MB/s eta 0:00:00
Installing collected packages: pydicom
Successfully installed pydicom-2.4.4


## DATA

In [None]:
# @title Data Processing


# Define the transformation pipeline
# Shared preprocessing + augmentation pipeline, applied in order:
#   1. scale the shorter image side to 256 px, then take a 224x224 centre crop;
#   2. light random augmentation (rotation within +/-10 degrees, horizontal
#      flip with probability 0.5) -- these run on every call, so they also
#      fire if this pipeline is reused at evaluation time;
#   3. convert the PIL image to a (C, H, W) float tensor.
transform_pipeline = transforms.Compose(
    transforms=[
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.RandomRotation(degrees=10),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
)

def _equalize_hist(image_u8):
    """Histogram-equalize a 2-D uint8 array using the same algorithm as ``cv2.equalizeHist``.

    The lowest occupied grey level maps to 0 and the remaining levels are
    spread over [0, 255] according to the cumulative histogram. An empty or
    single-grey-level image is returned unchanged.
    """
    hist = np.bincount(image_u8.ravel(), minlength=256)
    occupied = np.flatnonzero(hist)
    if occupied.size == 0:  # empty image: nothing to equalize
        return image_u8.copy()
    first = occupied[0]
    remaining = image_u8.size - hist[first]
    if remaining == 0:  # constant image: OpenCV also leaves it unchanged
        return image_u8.copy()
    cdf = np.cumsum(hist)
    lut = np.rint((cdf - hist[first]) * (255.0 / remaining))
    # Bins below `first` are unused; clipping keeps them (and rounding) in range.
    lut = np.clip(lut, 0, 255).astype(np.uint8)
    return lut[image_u8]


def dicom_to_jpg(dicom_path, output_path="output.jpg"):
    """Load a DICOM file and return it as a histogram-equalized 8-bit PIL image.

    Args:
        dicom_path: Path to the DICOM file.
        output_path: Where to also save the result as a JPEG (quality 95).
            Defaults to ``"output.jpg"`` (the original behavior). Pass ``None``
            to skip writing, e.g. when called repeatedly from a DataLoader,
            where every call would overwrite the same file.

    Returns:
        PIL image built from the equalized uint8 array.
    """
    dicom_image = pydicom.dcmread(dicom_path)
    image_array = dicom_image.pixel_array.astype(np.float64)

    # Clip negatives and scale to [0, 255]. Guard an all-zero image, which
    # previously divided by zero and produced NaNs.
    image_array = np.maximum(image_array, 0)
    peak = image_array.max()
    if peak > 0:
        image_array = image_array / peak * 255.0

    # MONOCHROME1 renders low values as white; invert so that high = bright.
    if getattr(dicom_image, "PhotometricInterpretation", None) == "MONOCHROME1":
        image_array = 255.0 - image_array

    # Histogram equalization. NumPy implementation: cv2 is not imported in
    # this notebook, so the previous cv2.equalizeHist call raised NameError.
    image_eq = _equalize_hist(image_array.astype(np.uint8))

    pil_img = Image.fromarray(image_eq)

    if output_path is not None:
        pil_img.save(output_path, "JPEG", quality=95)

    return pil_img

def preprocess_and_augment(image_path, transform=None):
    """Load an image (DICOM or any PIL-readable format) and apply the transform pipeline.

    Args:
        image_path: Path to the image, as ``str`` or ``os.PathLike``. Paths
            ending in ``.dcm`` (any letter case) are decoded with ``dicom_to_jpg``.
        transform: Callable applied to the loaded PIL image. Defaults to the
            module-level ``transform_pipeline``.

    Returns:
        Whatever ``transform`` returns. With the default pipeline, this is a
        3-channel tensor.
    """
    image_path = str(image_path)  # accept pathlib.Path as well as str

    # ResNet-50 expects 3 input channels, but DICOM output and many X-ray
    # JPG/PNG files are single-channel, so always convert to RGB.
    if image_path.lower().endswith('.dcm'):
        image = dicom_to_jpg(image_path).convert("RGB")
    else:
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once the pixels have been copied by convert().
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

    if transform is None:
        transform = transform_pipeline
    return transform(image)

# Example usage:
# processed_image = preprocess_and_augment('path_to_image.dcm')


In [None]:
# @title Dataloading


from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random

class TripletDataset(Dataset):
    """Dataset yielding ``(anchor, positive, negative, anchor_label)`` image triplets.

    The positive shares the anchor's label; the negative has a different label.

    Args:
        image_paths: Sequence of image paths; ``__getitem__`` indexes into it.
        labels: Dict mapping every path in ``image_paths`` to its label.
        transform: Stored as ``self.transform`` but not applied here; images
            always go through ``preprocess_and_augment``.
            NOTE(review): wire this up if a custom transform is needed.

    Raises:
        ValueError: if ``image_paths`` covers fewer than two distinct labels,
            so no negative can ever be drawn.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        # Group positions *in image_paths* by label. The previous version
        # compared a 0-d array wrapping dict_values, which never matched a label.
        self.labels_to_indices = {}
        for idx, path in enumerate(image_paths):
            self.labels_to_indices.setdefault(labels[path], []).append(idx)
        # Only labels that actually occur in image_paths, so every candidate
        # negative label has at least one index to draw from.
        self.labels_set = set(self.labels_to_indices)
        if len(self.labels_set) < 2:
            raise ValueError(
                "TripletDataset needs at least two distinct labels to draw negatives"
            )

    def _sample_triplet_indices(self, index):
        """Return ``(positive_index, negative_index)`` for the anchor at ``index``.

        The positive is a different item of the same class when one exists.
        A singleton class reuses the anchor as its own positive; the previous
        rejection loop never terminated in that case.
        """
        anchor_label = self.labels[self.image_paths[index]]
        candidates = self.labels_to_indices[anchor_label]

        positive_index = index
        if len(candidates) > 1:
            # Terminates: at least one candidate differs from `index`.
            while positive_index == index:
                positive_index = random.choice(candidates)

        negative_label = random.choice(list(self.labels_set - {anchor_label}))
        negative_index = random.choice(self.labels_to_indices[negative_label])
        return positive_index, negative_index

    def __getitem__(self, index):
        anchor_path = self.image_paths[index]
        anchor_label = self.labels[anchor_path]
        positive_index, negative_index = self._sample_triplet_indices(index)

        # Load images and apply preprocessing/augmentation.
        anchor_img = preprocess_and_augment(anchor_path)
        positive_img = preprocess_and_augment(self.image_paths[positive_index])
        negative_img = preprocess_and_augment(self.image_paths[negative_index])

        return anchor_img, positive_img, negative_img, anchor_label

    def __len__(self):
        return len(self.image_paths)




## MODEL

In [None]:
# @title Backbone
# Model Backbone
class Backbone(nn.Module):
    """ResNet-50 feature extractor with the ImageNet classification layer removed.

    Args:
        pretrained: If True (default), load torchvision's ImageNet-1k V1
            weights. These are the weights the deprecated ``pretrained=True``
            flag selected. If False, use random initialization.

    Attributes:
        feature_dim: Width of the feature vector returned by ``forward``. It is
            read from the original ``fc`` layer before that layer is removed.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Explicit `weights=` replaces the deprecated `pretrained=` argument.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = resnet50(weights=weights)
        self.feature_dim = self.base_model.fc.in_features
        self.base_model.fc = nn.Identity()  # drop the classifier; forward returns features

    def forward(self, x):
        return self.base_model(x)

In [None]:
# @title Contrastive Head
class ContrastiveHead(nn.Module):
    """Linear map from backbone features to the contrastive embedding space.

    Args:
        feature_dim: Input feature width. The default of 2048 matches the
            ResNet-50 ``Backbone`` output. The previous default of 1024 caused
            a shape mismatch when paired with that backbone.
        embedding_dim: Output embedding width.
    """

    def __init__(self, feature_dim=2048, embedding_dim=128):
        super().__init__()
        self.fc = nn.Linear(feature_dim, embedding_dim)

    def forward(self, x):
        return self.fc(x)

In [None]:
# @title ContrastiveLoss
class ContrastiveLoss(nn.Module):
    """Temperature-scaled contrastive loss over anchor/positive/negative embeddings.

    Let ``s(u, v) = u . v / temperature``. For each anchor ``a_i`` and each
    positive ``p_j`` in the batch, the term
    ``-(s(a_i, p_j) - logsumexp_k s(a_i, n_k))`` is computed. Terms are summed
    over positives and averaged over anchors.

    Inputs are expected as 2-D ``(batch, dim)`` tensors sharing ``dim``, as
    required by the ``matmul`` with ``.t()``. Embeddings are used as given;
    they are not normalized here.

    Args:
        temperature: Similarity scale. NOTE(review): 0.07 is a commonly used
            value in contrastive learning; tune it for this data.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchor, positives, negatives):
        # Pairwise scaled similarities: (A, P) and (A, N).
        anchor_dot_positives = torch.matmul(anchor, positives.t()) / self.temperature
        anchor_dot_negatives = torch.matmul(anchor, negatives.t()) / self.temperature

        # Log-sum-exp over negatives per anchor, kept as (A, 1) to broadcast over positives.
        negatives_logsumexp = torch.logsumexp(anchor_dot_negatives, dim=1, keepdim=True)

        # Vectorized form of the former double loop. The sum over positives is
        # now averaged over anchors, as intended; previously .mean() was taken
        # of an already-summed scalar.
        per_pair = anchor_dot_positives - negatives_logsumexp
        return -per_pair.sum(dim=1).mean()


In [None]:
# @title Projection Head  and Classification Head
class ProjectionHead(nn.Module):
    """One linear layer followed by a ReLU, mapping embeddings to projections.

    Args:
        embedding_dim: Input width.
        projection_dim: Output width.
    """

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.fc = nn.Linear(embedding_dim, projection_dim)

    def forward(self, x):
        projected = self.fc(x)
        return torch.relu(projected)  # clamp negatives to zero

#
class ClassificationHead(nn.Module):
    """Binary classifier: a single linear unit squashed by a sigmoid.

    ``forward`` returns probabilities, not logits, of shape ``(..., 1)``. Pair
    it with ``nn.BCELoss``, as the training cell does, rather than
    ``BCEWithLogitsLoss``.

    Args:
        projection_dim: Input width.
    """

    def __init__(self, projection_dim):
        super().__init__()
        self.fc = nn.Linear(projection_dim, 1)

    def forward(self, x):
        logits = self.fc(x)
        return logits.sigmoid()

In [None]:
# @title Model Training
# @markdown  computes the loss and updates the model parameters.

# Initialize the resnet50 backbone and the contrastive head
# Initialize the ResNet-50 backbone and the contrastive head.
backbone = Backbone()
# ResNet-50 with its fc layer replaced by Identity emits 2048-d features.
contrastive_head = ContrastiveHead(feature_dim=2048, embedding_dim=128)

# Assumes `image_paths` (list of image paths) and `labels` (dict path -> label)
# are defined earlier. BCELoss below needs the labels to be 0/1 -- TODO confirm.
triplet_dataset = TripletDataset(image_paths=image_paths, labels=labels, transform=transform_pipeline)
triplet_dataloader = DataLoader(triplet_dataset, batch_size=32, shuffle=True)

# Projection and classification heads for the downstream task.
projection_head = ProjectionHead(embedding_dim=128, projection_dim=128)
classification_head = ClassificationHead(projection_dim=128)

# ClassificationHead already applies a sigmoid, so use BCELoss (not BCEWithLogitsLoss).
loss_function = nn.BCELoss()

# Optimize every module on the forward path. contrastive_head used to be left
# out of the optimizer, so its weights were never updated.
trainable_modules = [backbone, contrastive_head, projection_head, classification_head]
optimizer = torch.optim.Adam(
    [param for module in trainable_modules for param in module.parameters()],
    lr=0.0001,
)

num_epochs = 10
for module in trainable_modules:
    module.train()

# Training loop for the downstream task.
# NOTE(review): the positive/negative images of each triplet are currently
# unused -- no contrastive objective is optimized here, only BCE on anchors.
for epoch in range(num_epochs):
    # Batches are (anchor, positive, negative, label) from TripletDataset.
    # `batch_labels` avoids shadowing the global `labels` dict, which would
    # break re-running this cell.
    for anchor_images, _positive_images, _negative_images, batch_labels in triplet_dataloader:
        optimizer.zero_grad()
        embeddings = contrastive_head(backbone(anchor_images))
        projections = projection_head(embeddings)
        predictions = classification_head(projections).squeeze(1)
        loss = loss_function(predictions, batch_labels.float())
        loss.backward()
        optimizer.step()

# Function to compute metrics
def compute_metrics(y_true, y_pred, threshold=0.5):
    """Compute binary-classification metrics from predicted scores.

    Args:
        y_true: Ground-truth binary labels (array-like).
        y_pred: Predicted scores or probabilities (array-like), same length as
            ``y_true``.
        threshold: A prediction counts as positive when strictly greater than
            this value (default 0.5, as before).

    Returns:
        Tuple ``(accuracy, precision, recall, f1, auc)``. AUC is computed from
        the raw scores, not the thresholded predictions.

    Raises:
        ValueError: from ``roc_auc_score`` when ``y_true`` contains a single class.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)  # lets plain lists work with the `>` comparison
    y_hat = y_pred > threshold  # threshold once, reuse for every label-based metric

    accuracy = accuracy_score(y_true, y_hat)
    # zero_division=0 returns the same 0.0 sklearn's default would, without the warning.
    precision = precision_score(y_true, y_hat, zero_division=0)
    recall = recall_score(y_true, y_hat, zero_division=0)
    f1 = f1_score(y_true, y_hat, zero_division=0)
    auc = roc_auc_score(y_true, y_pred)
    return accuracy, precision, recall, f1, auc

# Example evaluation on validation set
# with torch.no_grad():
#     y_true = []
#     y_pred = []
#     for images, labels in validation_loader:
#         embeddings = contrastive_head(backbone(images))
#         projections = projection_head(embeddings)
#         predictions = classification_head(projections).squeeze(1)
#         y_true.extend(labels.numpy())
#         y_pred.extend(predictions.numpy())
#     metrics = compute_metrics(np.array(y_true), np.array(y_pred))
#     print(f"Accuracy: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, F1: {metrics[3]}, AUC: {metrics[4]}")
