### Experimental design: Attention
Use of an attention mechanism to aggregate patch-level features into a patient-level diagnosis.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class HelicobacterClassifier(nn.Module):
    """Patch-level CNN: three conv + ReLU + max-pool stages, then two linear layers.

    ``fc1`` expects a flattened 128 x 32 x 32 feature map, which is what the
    three 2x2 poolings produce for a 3 x 256 x 256 input patch. ``forward``
    returns raw logits of shape (B, 2); no softmax is applied.
    """

    def __init__(self):
        super(HelicobacterClassifier, self).__init__()

        # conv1 + pooling (B, C, H, W) -> (B, 32, H/2, W/2)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        # conv2 + pooling (B, 32, H/2, W/2) -> (B, 64, H/4, W/4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # conv3 + pooling (B, 64, H/4, W/4) -> (B, 128, H/8, W/8)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 32 * 32, 512)
        self.fc2 = nn.Linear(512, 2)

    def forward(self, x):
        """Compute class logits for a batch of patches.

        Args:
            x: tensor of shape (B, 3, H, W). H and W must be 256 for the
                flattened feature map to match ``fc1``.

        Returns:
            Tensor of shape (B, 2) with unnormalised class scores.

        Raises:
            RuntimeError: if the spatial size does not yield 128 * 32 * 32
                features per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. The previous view(-1, 128 * 32 * 32) could
        # silently merge several wrongly-sized samples into one row (e.g. four
        # 128x128 patches became a single "sample"); keeping the batch axis
        # makes such inputs fail loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

# Define the Attention Module
class AttentionModule(nn.Module):
    """Attention pooling over a bag of patch feature vectors.

    A two-layer MLP assigns one scalar score to every patch; the scores are
    softmax-normalised over the patch axis and used as weights for a weighted
    sum of the patch features.
    """

    def __init__(self, feature_dim, hidden_dim):
        """
        Args:
            feature_dim: size of each patch feature vector.
            hidden_dim: width of the hidden layer of the scoring MLP.
        """
        super(AttentionModule, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Output scalar attention score
        )

    def forward(self, patch_features):
        """Aggregate patch features with learned attention weights.

        Args:
            patch_features: tensor of shape [N, D] (one bag) or [..., N, D]
                (leading batch dimensions), where N is the number of patches.

        Returns:
            Tuple ``(aggregated_features, weights)`` with shapes [..., D] and
            [..., N]; the weights sum to 1 along the patch axis.
        """
        # One score per patch: [..., N, 1] -> [..., N]
        scores = self.attention(patch_features).squeeze(-1)
        # Normalise over the patch axis (the last axis of `scores`). For a
        # single [N, D] bag this is the same as dim=0; addressing the axis from
        # the end also keeps batched bags from being normalised across the batch.
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of patch features over the patch axis.
        aggregated_features = torch.sum(weights.unsqueeze(-1) * patch_features, dim=-2)
        return aggregated_features, weights


# Define the Patient-Level Classifier
class PatientLevelClassifier(nn.Module):
    """Patient-level head: attention pooling of patch features, then a linear classifier."""

    def __init__(self, feature_dim, hidden_dim, num_classes):
        """
        Args:
            feature_dim: size of each patch feature vector.
            hidden_dim: hidden width handed to the attention module.
            num_classes: number of output logits.
        """
        super().__init__()
        self.attention = AttentionModule(feature_dim, hidden_dim)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, patch_features):
        """Return ``(logits, attention_weights)`` for the given patch features.

        The patch features are first collapsed into a single vector by the
        attention module; that vector is then fed to the linear layer.
        """
        pooled, attn_weights = self.attention(patch_features)
        return self.fc(pooled), attn_weights

class FeatureExtractor(nn.Module):
    """Frozen feature extractor built from a trained ``HelicobacterClassifier``.

    Runs the patch classifier up to and including ``fc1`` (the classification
    layer ``fc2`` is left out) and returns the 512-dimensional ``fc1`` output.
    """

    def __init__(self, patch_model_path, output_layer_dim):
        """
        Args:
            patch_model_path: path to a saved ``state_dict`` of
                ``HelicobacterClassifier``.
            output_layer_dim: not used by this class; kept so existing callers
                keep working. The feature size is fixed by ``fc1``.
        """
        super(FeatureExtractor, self).__init__()
        # Initialize the model architecture
        self.patch_model = HelicobacterClassifier()

        # Load the state dictionary. weights_only=True restricts unpickling to
        # tensors and plain containers, which is all a state_dict holds, and
        # avoids the warning torch emits when the flag is left unset.
        state_dict = torch.load(
            patch_model_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        self.patch_model.load_state_dict(state_dict)

        # Remove the final classification layer (fc2) to extract features.
        # Each convolution is followed by ReLU exactly as in
        # HelicobacterClassifier.forward; without these activations the stack
        # computes something different from the network that was trained.
        self.feature_extractor = nn.Sequential(
            self.patch_model.conv1,
            nn.ReLU(),
            self.patch_model.pool,
            self.patch_model.conv2,
            nn.ReLU(),
            self.patch_model.pool,
            self.patch_model.conv3,
            nn.ReLU(),
            self.patch_model.pool,
            nn.Flatten(),
            self.patch_model.fc1  # Use fc1 (pre-activation) as the final feature layer
        )

    def forward(self, x):
        """Return ``fc1`` features for a batch of patches, computed without gradients."""
        # Extract features
        with torch.no_grad():
            features = self.feature_extractor(x)
        return features


In [None]:
import pandas as pd
# NOTE(review): absolute, machine-specific paths; they must be adjusted before
# running this cell on another machine.
path_to_annotated = '/Users/marino/Documents/GitHub/MED-GIA/MED-GIA/HP_WSI-CoordAnnotatedPatches.xlsx'
path_to_Patient_Diagnois = '/Users/marino/Documents/GitHub/MED-GIA/MED-GIA/PatientDiagnosis.csv'
patient_diagnosisDF = pd.read_csv(path_to_Patient_Diagnois)
annotated_patchesDF = pd.read_excel(path_to_annotated)

print(patient_diagnosisDF.head())
print(annotated_patchesDF.head())

# Keep only patients diagnosed ALTA or NEGATIVA, and only patches whose
# 'Presence' value is non-zero.
patient_diagnosisDF = patient_diagnosisDF[patient_diagnosisDF['DENSITAT'].isin(['ALTA', 'NEGATIVA'])]
annotated_patchesDF = annotated_patchesDF[annotated_patchesDF['Presence'] != 0]


# Group by patient id and count the patches with Presence == 1 and Presence == -1
grouped = annotated_patchesDF.groupby('Pat_ID').agg(
    number_of_positive_patches=('Presence', lambda x: (x == 1).sum()),
    number_of_negative_patches=('Presence', lambda x: (x == -1).sum())
).reset_index()

# Attach each patient's DENSITAT from patient_diagnosisDF, matching Pat_ID to
# CODI; the inner join drops patients that are missing from either table.
grouped = grouped.merge(patient_diagnosisDF, left_on='Pat_ID', right_on='CODI', how='inner')
grouped = grouped.drop(columns=['CODI'])

# OBJECTIVE: have a dataframe with the following columns: patient_id, number_of_positive_patches, number_of_negative_patches, diagnosis, prediction

print(patient_diagnosisDF)

# Spot check: diagnosis recorded for patient B22-03
densitat = patient_diagnosisDF[patient_diagnosisDF['CODI'] == "B22-03"]
print(list(densitat['DENSITAT']))


import pathlib
# Root folder of the hold-out set (relative to the notebook's working directory)
path_to_holdout = 'HoldOut'

# Names of the sub-directories of the hold-out folder with their last two
# characters removed (presumably a two-character suffix after the patient id;
# confirm against the directory naming). Sorted because iterdir() yields
# entries in arbitrary order, which would make the output vary between runs.
holdout_directories = sorted(
    entry.name[:-2]
    for entry in pathlib.Path(path_to_holdout).iterdir()
    if entry.is_dir()
)

from PIL import Image
import pathlib
import pandas as pd
import torch
import torchvision.transforms as transforms
import numpy as np


def transform_image(image, size):
    """Return the result of ``image.resize(size)``.

    ``image`` only needs to provide a ``resize`` method (e.g. a PIL image);
    ``size`` is passed through unchanged.
    """
    resized = image.resize(size)
    return resized

class HoldoutDataset():
    """Per-patient dataset of patch features for the hold-out set.

    Every sub-directory of the hold-out folder is treated as one patient; the
    patient id is the directory name without its last two characters. Only
    patients whose diagnosis is ALTA (label 1) or NEGATIVA (label -1) are kept.
    Indexing returns the stacked features of all PNG patches of one patient,
    computed on the fly with the frozen ``FeatureExtractor``.
    """

    def __init__(self, patch_model_path, output_layer_dim, device="cpu",
                 holdout_dir='HoldOut',
                 diagnosis_csv='/Users/marino/Documents/GitHub/MED-GIA/MED-GIA/PatientDiagnosis.csv'):
        """
        Args:
            patch_model_path: path to the saved patch-classifier state_dict.
            output_layer_dim: forwarded to ``FeatureExtractor``.
            device: device on which patch features are computed.
            holdout_dir: folder containing one sub-directory per patient.
            diagnosis_csv: CSV with the columns ``CODI`` and ``DENSITAT``.
        """
        import warnings

        self.device = device
        self.patch_model = FeatureExtractor(patch_model_path, output_layer_dim).to(device)

        # Dataset details
        self.path_to_holdout = holdout_dir
        # Directories only (stray files such as .DS_Store are ignored), sorted
        # so the index -> patient mapping is the same on every run.
        self.patient_directories = sorted(
            entry for entry in pathlib.Path(self.path_to_holdout).iterdir() if entry.is_dir()
        )
        self.path_to_Patient_Diagnois = diagnosis_csv
        self.patient_diagnosisDF = self._encode_diagnoses(pd.read_csv(self.path_to_Patient_Diagnois))

        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
        ])

        known_ids = self.patient_diagnosisDF['CODI'].values
        self.dictionary = {}
        # NOTE(review): if two directories map to the same patient id, the
        # later one replaces the earlier one; confirm this is intended.
        for patient in self.patient_directories:
            patient_id = patient.name[:-2]
            if patient_id not in known_ids:
                continue
            images = sorted(x for x in patient.iterdir() if x.is_file() and x.name.endswith('.png'))
            if not images:
                # torch.stack() raises on an empty list, so a patient without
                # patches could never be returned by __getitem__.
                warnings.warn(f"No PNG patches found for patient {patient_id}; skipping.")
                continue
            diagnosis = self.patient_diagnosisDF[self.patient_diagnosisDF['CODI'] == patient_id]['DiagnosisGT'].values[0]
            self.dictionary[patient_id] = (images, diagnosis)

    @staticmethod
    def _encode_diagnoses(diagnosis_df):
        """Keep ALTA/NEGATIVA rows and encode them as 1/-1 in a 'DiagnosisGT' column.

        Works on a copy, so the input frame is left untouched and no value is
        assigned into a filtered slice of another frame.
        """
        kept = diagnosis_df[diagnosis_df['DENSITAT'].isin(['ALTA', 'NEGATIVA'])].copy()
        kept['DENSITAT'] = [1 if x == 'ALTA' else -1 for x in kept['DENSITAT']]
        return kept.rename(columns={'DENSITAT': 'DiagnosisGT'})

    def __len__(self):
        """Number of patients in the dataset."""
        return len(self.dictionary)

    def __getitem__(self, idx):
        """Return ``(patch_features, patient_id, diagnosis)`` for the idx-th patient.

        ``patch_features`` stacks one feature vector per patch image and lives
        on the CPU. An out-of-range ``idx`` raises IndexError, which is what
        ends plain ``for ... in dataset`` iteration.
        """
        patient_id = list(self.dictionary.keys())[idx]
        images, diagnosis = self.dictionary[patient_id]

        patch_features = []
        for image_path in images:
            # convert() returns a new in-memory image, so the file handle can
            # be closed right away.
            with Image.open(image_path) as image_file:
                image = image_file.convert("RGB")
            image = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                feature = self.patch_model(image)
            patch_features.append(feature.cpu().squeeze(0))

        patch_features = torch.stack(patch_features)  # Shape: [N, D]
        return patch_features, patient_id, diagnosis


# Settings for the patch-level feature extractor
patch_model_path = 'classifier.pth'  # saved weights of the trained patch classifier
output_layer_dim = 512  # feature size produced by the patch model
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the hold-out dataset and report how many patients it contains
holdout_dataset = HoldoutDataset(patch_model_path, output_layer_dim, device=device)
print(len(holdout_dataset))

# Displayed as the cell's output
holdout_directories

     CODI  DENSITAT
0  B22-01     BAIXA
1  B22-02     BAIXA
2  B22-03  NEGATIVA
3  B22-04  NEGATIVA
4  B22-05  NEGATIVA
    Pat_ID  Section_ID  Window_ID      i      j    h    w  Presence
0  B22-129           0        659   7477  11978  256  256        -1
1   B22-68           0        131   6597  12009  256  256        -1
2   B22-68           0        141   5100  10737  256  256        -1
3   B22-68           0        290   5015  14908  256  256        -1
4   B22-68           0        298  11626  13928  256  256        -1
        CODI  DENSITAT
2     B22-03  NEGATIVA
3     B22-04  NEGATIVA
4     B22-05  NEGATIVA
5     B22-06  NEGATIVA
6     B22-07  NEGATIVA
..       ...       ...
304  B22-311      ALTA
305  B22-312      ALTA
306  B22-313      ALTA
307  B22-314  NEGATIVA
308  B22-315  NEGATIVA

[237 rows x 2 columns]
['NEGATIVA']


  state_dict = torch.load(patch_model_path, map_location=torch.device('cpu'))


82


['B22-161',
 'B22-88',
 'B22-04',
 'B22-82',
 'B22-286',
 'B22-44',
 'B22-65',
 'B22-261',
 'B22-225',
 'B22-100',
 'B22-282',
 'B22-314',
 'B22-271',
 'B22-75',
 'B22-196',
 'B22-135',
 'B22-310',
 'B22-14',
 'B22-231',
 'B22-31',
 'B22-10',
 'B22-198',
 'B22-96',
 'B22-207',
 'B22-226',
 'B22-247',
 'B22-266',
 'B22-85',
 'B22-128',
 'B22-222',
 'B22-203',
 'B22-81',
 'B22-66',
 'B22-281',
 'B22-62',
 'B22-09',
 'B22-03',
 'B22-49',
 'B22-309',
 'B22-209',
 'B22-262',
 'B22-243',
 'B22-285',
 'B22-268',
 'B22-07',
 'B22-236',
 'B22-78',
 'B22-238',
 'B22-272',
 'B22-295',
 'B22-132',
 'B22-36',
 'B22-213',
 'B22-13',
 'B22-32',
 'B22-72',
 'B22-257',
 'B22-19',
 'B22-136',
 'B22-259',
 'B22-17',
 'B22-159',
 'B22-246',
 'B22-48',
 'B22-206',
 'B22-263',
 'B22-229',
 'B22-69',
 'B22-227',
 'B22-02',
 'B22-146',
 'B22-08',
 'B22-267',
 'B22-06',
 'B22-202',
 'B22-269',
 'B22-169',
 'B22-242',
 'B22-208',
 'B22-237',
 'B22-12',
 'B22-73',
 'B22-212',
 'B22-233',
 'B22-294',
 'B22-252',


### Attention mechanism

In [None]:
# Dimensions of the patient-level model
output_layer_dim = 512  # size of the fc1 features produced by FeatureExtractor
hidden_dim = 128        # hidden width of the attention scoring network
num_classes = 2         # binary classification (e.g., POSITIVE/NEGATIVE)


# Build the patient-level classifier and move it to the selected device.
# NOTE(review): no weights are loaded in this cell, so unless they are loaded
# elsewhere the classifier starts from its random initialisation.
patient_classifier = PatientLevelClassifier(output_layer_dim, hidden_dim, num_classes)
patient_classifier = patient_classifier.to(device)


print(f"Dataset loaded with {len(holdout_dataset)} patients.")


Dataset loaded with 82 patients.


### Main

In [None]:
# Set the patient classifier to evaluation mode
patient_classifier.eval()

# Iterate over patients in the dataset.
# NOTE(review): every item computes the features of all patches of a patient,
# so this loop is slow; and unless trained weights were loaded beforehand, the
# predictions come from a randomly initialised classifier.
for patch_features, patient_id, diagnosis_gt in holdout_dataset:
    patch_features = patch_features.to(device)

    # Predict patient-level diagnosis
    with torch.no_grad():
        logits, attention_weights = patient_classifier(patch_features)

    # argmax yields a class index (0 or 1) while the ground truth is encoded as
    # -1 / 1, so translate the index into the label encoding before reporting.
    # Assumes index 1 is the positive class and index 0 the negative one;
    # training must use the same convention (TODO confirm).
    class_index = logits.argmax(dim=-1).item()
    diagnosis_pred = 1 if class_index == 1 else -1
    print(f"Patient ID: {patient_id}, Ground Truth: {int(diagnosis_gt)}, Prediction: {diagnosis_pred}")

Patient ID: B22-88, Ground Truth: -1, Prediction: 1
Patient ID: B22-04, Ground Truth: -1, Prediction: 1


KeyboardInterrupt: 