<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.sandbox.google.com/github/Google-Health/imaging-research/blob/master/path-foundation/linear-classifier-demo.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/Google-Health/imaging-research/tree/master/path-foundation"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>


# Path Foundation Linear Probe Demo
This notebook is a demonstration of generating and using embeddings from the Path Foundation API to train a linear classifier. This API enables users to compute embeddings for histopathology images. The contents include how to build an API request to generate embeddings from stored patches and train a linear model using the embeddings. Note: This notebook is for API demonstration purposes only. As with all machine-learning use-cases it is critical to consider training and evaluation datasets that reflect the expected distribution of the intended use case.

**Additional details**: For this demo, whole slide images (WSIs) available from the dataset below were split into train and evaluation sets. A subset of patches was sampled randomly from across all available slides, and embeddings were generated via the Path Foundation model.

**Dataset**: This notebook uses the [CAMELYON16](https://camelyon16.grand-challenge.org/) dataset, which contains WSIs from lymph node specimens with and without metastatic breast cancer. Any work that uses this dataset should consider additional details along with usage and citation requirements listed on [their website](https://camelyon17.grand-challenge.org/Data/).

**Dataset citation**: Babak Ehteshami Bejnordi; Mitko Veta; Paul Johannes van Diest; Bram van Ginneken; Nico Karssemeijer; Geert Litjens; Jeroen A. W. M. van der Laak; and the CAMELYON16 Consortium. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA. 2017;318(22):2199–2210. DOI: 10.1001/jama.2017.14585
# Prerequisites
You must have access to the Pathology Foundation Tool. See the project's [README](https://github.com/Google-Health/imaging-research/blob/master/path-foundation/README.md) for details.




## Imports and constants


In [None]:
!pip install ez-wsi-dicomweb
!pip install hcls-imaging-ml-toolkit-ez-wsi

In [None]:
# Imports

from collections import defaultdict
from dataclasses import asdict, dataclass
from enum import Enum
import io
import json
import random
from typing import Any, List
import warnings

import ez_wsi_dicomweb
from ez_wsi_dicomweb import dicom_slide
from ez_wsi_dicomweb import dicom_web_interface
import ez_wsi_dicomweb.dicomweb_credential_factory as dicomweb_credential_factory
import ez_wsi_dicomweb.pixel_spacing as pixel_spacing
from google.cloud import storage
from google.cloud.storage.blob import Blob
import google.cloud.storage
import hcls_imaging_ml_toolkit.dicom_path as dicom_path
import hcls_imaging_ml_toolkit.tags as tags
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import requests
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection
import sklearn.pipeline
import sklearn.preprocessing

In [None]:
from google.colab import auth
# Authenticate user for access. There will be a popup asking you to sign in with your user account and approve access.
# Colab-only: this cell depends on the google.colab package.
auth.authenticate_user()

In [None]:
# Constants
# Values marked `# @param` are exposed as Colab form fields.


# GCS location of the pre-generated patch lists (read by the download cell).
PROJECT_ID = 'hai-cd3-foundations'  # Project that contains the stored patches
BUCKET_NAME = (  # Bucket that contains the patches
    'hai-cd3-foundations-pathology-vault-entry'
)
# Location of the DICOM store holding the slides; combined into
# DICOM_WEB_STORE_URL below.
DATASET_PROJECT_ID = 'hai-cd3-foundations'  # @param {type: 'string'}
DATASET_LOCATION = 'us-west1'  # @param {type: 'string'}
DATASET_ID = 'pathology'  # @param {type: 'string'}
STORE_ID = 'camelyon'  # @param {type: 'string'}
# Patch lists are read from PATCHES_DIR_NAME + CANCER_FILE and
# PATCHES_DIR_NAME + NON_CANCER_FILE inside BUCKET_NAME.
PATCHES_DIR_NAME = 'patches/'  # @param {type: 'string'}
# NOTE(review): EMBEDDINGs_DIR_NAME is not referenced elsewhere in this notebook.
EMBEDDINGs_DIR_NAME = 'embeddings/'  # @param {type: 'string'}
CANCER_FILE = 'all_cancer_patches.json'  # @param {type: 'string'}
NON_CANCER_FILE = 'all_non_cancer_patches.json'  # @param {type: 'string'}
# Number of patches randomly sampled (and embedded) per split and label. Each
# split must contain at least this many patches of that label.
TRAINING_CANCER_PATCH_COUNT = 250  # @param {type: 'integer'}
TRAINING_NON_CANCER_PATCH_COUNT = 250  # @param {type: 'integer'}
EVAL_CANCER_PATCH_COUNT = 50  # @param {type: 'integer'}
EVAL_NON_CANCER_PATCH_COUNT = 50  # @param {type: 'integer'}

# Generated using above values
DICOM_WEB_STORE_URL = f'projects/{DATASET_PROJECT_ID}/locations/{DATASET_LOCATION}/datasets/{DATASET_ID}/dicomStores/{STORE_ID}'

# Object prefix of the example GCS image; a file extension ('.png', '.jpeg')
# is appended where it is used.
GCS_PATH = 'gs://hai-cd3-foundations-pathology-vault-entry/test_patches/test_patch' # @param {type: 'string'}


In [None]:
# Constants that should not be modified

# Patches are PATCH_SIZE x PATCH_SIZE pixels, read at 20X magnification.
PATCH_SIZE = 224
TARGET_PIXEL_SPACING = pixel_spacing.PixelSpacing.FromMagnificationString('20X')
# Number of slides (drawn from those containing cancer patches) held out for
# evaluation.
EVAL_RESERVED_SLIDES = (
    EVAL_CANCER_PATCH_COUNT + 15
)  # slides reserved for the eval set. Add some buffer in case patch count is much higher than the reserved slide count.

# API config
PROD_API_PROJECT_ID = 'hai-cd3-foundations'
PROD_API_LOCATION = 'us-central1'
# Endpoint used for DICOM-store requests (ENCODER_ENDPOINT_URL below).
ENDPOINT_ID = '160'
# Endpoint used for GCS-image requests (see the GCS section).
GCS_ENDPOINT_ID = '161'
ENCODER_ENDPOINT_URL = f'https://{PROD_API_LOCATION}-aiplatform.googleapis.com/v1/projects/{PROD_API_PROJECT_ID}/locations/{PROD_API_LOCATION}/endpoints/{ENDPOINT_ID}:predict'
# Sent as the request 'parameters' (model_size / model_kind) in every payload.
MODEL_SIZE = 'MEDIUM'
MODEL_KIND = 'LOW_PIXEL_SPACING'

 ## Additional setup

In [None]:
# Defines a helper Dataclass for converting Embeddings from JSON


@dataclass
class Embedding:
  """One patch embedding plus the DICOM location it was computed from.

  Patches are assumed to be square, 224 x 224 pixels.
  """
  dicom_study_uid: str
  dicom_series_uid: str
  x_origin: int  # X of the patch's top-left corner.
  y_origin: int  # Y of the patch's top-left corner.
  instance_uids: list[str]
  embedding: list[float]

  def to_json(self) -> str:
    """Serializes to a JSON object string.

    Study/series UIDs are whitespace-stripped, and the embedding vector is
    stored as a nested JSON *string* (decoded again by embedding_from_json).
    """
    record = {}
    record['dicom_study_uid'] = self.dicom_study_uid.strip()
    record['dicom_series_uid'] = self.dicom_series_uid.strip()
    record['x_origin'] = self.x_origin
    record['y_origin'] = self.y_origin
    record['instance_uids'] = self.instance_uids
    record['embedding'] = json.dumps(self.embedding)
    return json.dumps(record)


@dataclass
class Embeddings:
  """Ordered collection of Embedding records."""
  embeddings: list[Embedding]

  def to_json(self) -> str:
    """Serializes as a JSON list whose items are each Embedding's JSON string."""
    serialized_items = [item.to_json() for item in self.embeddings]
    return json.dumps(serialized_items)

  def concat(self, other):
    """Returns a new collection: this one's embeddings followed by `other`'s."""
    merged = self.embeddings + other.embeddings
    return Embeddings(embeddings=merged)


def embedding_from_json(json_str: str) -> Embedding:
  """Parses one string produced by Embedding.to_json back into an Embedding.

  The 'embedding' field holds a nested JSON string and is decoded a second time.
  """
  record = json.loads(json_str)
  study, series, x, y, instances, encoded_vector = (
      record[key]
      for key in (
          'dicom_study_uid',
          'dicom_series_uid',
          'x_origin',
          'y_origin',
          'instance_uids',
          'embedding',
      )
  )
  return Embedding(study, series, x, y, instances, json.loads(encoded_vector))


def embeddings_from_json(json_str: str) -> Embeddings:
  """Parses the output of Embeddings.to_json (a JSON list of JSON strings)."""
  decoded_items = list(map(embedding_from_json, json.loads(json_str)))
  return Embeddings(embeddings=decoded_items)


def embeddings_dataclass_from_response(response: str) -> Embeddings:
  """Converts a DICOM-endpoint :predict response body into Embeddings.

  Only entries of 'predictions' that are lists are read; each list item is a
  per-series result whose 'patch_embeddings' yield one Embedding apiece, in
  response order. Non-list predictions are skipped.

  Args:
    response: Raw JSON text of the HTTP response.
  """
  payload = json.loads(response)
  return Embeddings(embeddings=[
      Embedding(
          dicom_study_uid=series_result['dicom_study_uid'],
          dicom_series_uid=series_result['dicom_series_uid'],
          x_origin=patch_result['patch_coordinate']['x_origin'],
          y_origin=patch_result['patch_coordinate']['y_origin'],
          instance_uids=series_result['instance_uids'],
          embedding=patch_result['embeddings'],
      )
      for prediction in payload['predictions']
      if isinstance(prediction, list)
      for series_result in prediction
      for patch_result in series_result['patch_embeddings']
  ])

In [None]:
# Defines helper Dataclasses for (de)serializing patch lists as JSON
# (the pre-generated patch lists are stored as JSON files in GCS).


@dataclass
class Patch:
  """Location of one square patch within a slide's DICOM series."""
  slide_id: str
  study_instance_uid: str
  series_instance_uid: str
  x_origin: int  # X of the patch's top-left corner.
  y_origin: int  # Y of the patch's top-left corner.

  def to_json(self) -> str:
    """Serializes all fields to a JSON object string."""
    return json.dumps(asdict(self))


@dataclass
class PatchCollection:
  """A list of Patch records, serialized as {"patches": [{...}, ...]}."""
  patches: list[Patch]

  def to_json(self) -> str:
    """Serializes the collection; nested Patch dataclasses become JSON objects."""
    return json.dumps(asdict(self))


def patch_collection_from_json(json_str: str) -> PatchCollection:
  """Parses {"patches": [{<Patch fields>}, ...]} into a PatchCollection."""
  payload = json.loads(json_str)
  parsed = list(map(lambda fields: Patch(**fields), payload['patches']))
  return PatchCollection(patches=parsed)

In [None]:
# Shared DICOMweb client (also used when building API requests), plus a helper
# that displays the image patch behind an embedding.


dcf = dicomweb_credential_factory.CredentialFactory()
dwi = dicom_web_interface.DicomWebInterface(dcf)

def render_patch_from_embedding(embedding: Embedding, plot_name: str = ''):
  """Fetches the patch an Embedding was computed from and shows it inline.

  The patch is read from the DICOM store at TARGET_PIXEL_SPACING as a
  PATCH_SIZE x PATCH_SIZE square at the embedding's top-left origin.

  Args:
    embedding: Embedding whose source patch should be displayed.
    plot_name: Title drawn above the image.
  """
  slide = dicom_slide.DicomSlide(
      dwi=dwi,
      path=dicom_path.Path(
          project_id=DATASET_PROJECT_ID,
          location=DATASET_LOCATION,
          dataset_id=DATASET_ID,
          store_id=STORE_ID,
          study_uid=embedding.dicom_study_uid,
          series_uid=embedding.dicom_series_uid,
      ),
      enable_client_slide_frame_decompression=True,
  )
  patch = slide.get_patch(
      pixel_spacing=TARGET_PIXEL_SPACING,
      x=embedding.x_origin,
      y=embedding.y_origin,
      width=PATCH_SIZE,
      height=PATCH_SIZE,
  )
  pixels = patch.image_bytes()
  _, ax = plt.subplots(figsize=(2, 2))
  ax.imshow(pixels)
  ax.set_title(plot_name)
  ax.axis('off')
  plt.show()

In [None]:
# Generate Auth Token

# IPython shell capture: the command's stdout comes back as a list of lines,
# and the first line is the access token.
auth_token = !gcloud auth print-access-token
# Used in the Authorization header and embedded as 'bearer_token' in request
# payloads below. NOTE(review): gcloud access tokens typically expire after a
# while; re-run this cell if later requests start failing with 401.
AUTH_TOKEN = auth_token[0]

## Running the API on Google Cloud Storage Images

In [None]:
# @title Initial Helper Functions and Setup
# :predict URL of the endpoint that accepts GCS image URLs (GCS_ENDPOINT_ID).
encoder_endpoint_url = (
    f'https://{PROD_API_LOCATION}-aiplatform.googleapis.com/v1'
    f'/projects/{PROD_API_PROJECT_ID}/locations/{PROD_API_LOCATION}'
    f'/endpoints/{GCS_ENDPOINT_ID}:predict'
)

def tile_image(
    image_height: int,
    image_width: int,
    patch_size: int = PATCH_SIZE,
    patch_overlap: int = 0,  # how many pixels adjacent patches share
    run_edge: bool = True
) -> list[dict[str, int]]:
  """Computes a grid of square patch coordinates covering an image.

  Patches are placed with a stride of `patch_size - patch_overlap`, and only
  patches lying entirely inside the image are emitted.

  Args:
    image_height: Image height in pixels.
    image_width: Image width in pixels.
    patch_size: Side length of each square patch in pixels.
    patch_overlap: Number of pixels adjacent patches overlap.
    run_edge: If True, add one extra column/row flush with the right/bottom
      edge when the regular grid leaves an uncovered strip there.

  Returns:
    Dicts with 'x_origin', 'y_origin', 'width' and 'height' keys (the
    `patch_coordinates` format sent to the API), ordered column by column.

  Raises:
    ValueError: If `patch_overlap >= patch_size`, or if the image is smaller
      than one patch in either dimension.
  """
  step = patch_size - patch_overlap
  if step <= 0:
    raise ValueError("patch_size must be greater than patch_overlap")
  # Check each dimension separately: `(h or w) < size` only ever checked h,
  # letting a too-narrow image produce negative origins.
  if image_height < patch_size or image_width < patch_size:
    raise ValueError(
        f"Image ({image_height}x{image_width}) is too small for"
        f" {patch_size}x{patch_size} patches"
    )

  def _origins(extent: int) -> list[int]:
    # Every stride-aligned start whose patch fits entirely within `extent`.
    origins = list(range(0, extent - patch_size + 1, step))
    if run_edge and origins[-1] + patch_size < extent:
      origins.append(extent - patch_size)
    return origins

  return [
      {
          'x_origin': x,
          'y_origin': y,
          'width': patch_size,
          'height': patch_size,
      }
      for x in _origins(image_width)
      for y in _origins(image_height)
  ]


def generate_tile_payload(
    model_size: str,
    model_kind: str,
    project_name: str,
    gcs_image_url: str,
    auth_token: str,
    patch_coordinates: list[dict[str, int]],
) -> dict[str, Any]:
  """Builds a :predict request body for embedding patches of one GCS image.

  Args:
    model_size: Sent as parameters.model_size (e.g. MODEL_SIZE).
    model_kind: Sent as parameters.model_kind (e.g. MODEL_KIND).
    project_name: Sent as the instance's 'project_name'.
    gcs_image_url: gs:// URL of the image to embed.
    auth_token: Bearer token sent inside the instance as 'bearer_token'.
    patch_coordinates: Patch dicts ('x_origin', 'y_origin', 'width',
      'height'), e.g. as produced by tile_image().

  Returns:
    A JSON-serializable dict with one instance covering all patches.
  """
  instance = {
      'project_name': project_name,
      'gcs_image_url': gcs_image_url,
      'bearer_token': auth_token,
      'ez_wsi_state': {},
      'patch_coordinates': patch_coordinates,
  }
  return {
      'parameters': {'model_size': model_size, 'model_kind': model_kind},
      'instances': [instance],
  }
### Fetch image for request to get total width and height
def get_gcs_image_dimensions(project_name, gcs_path):
  """Returns (height, width) in pixels of an image stored in GCS.

  The whole object is downloaded and then opened with PIL.

  Args:
    project_name: GCP project used for the storage client.
    gcs_path: Full gs:// URL of the image object.
  """
  client = storage.Client(project=project_name)
  blob = Blob.from_string(gcs_path, client=client)

  # download_as_bytes replaces the deprecated download_as_string alias.
  image_bytes = blob.download_as_bytes()

  with PIL.Image.open(io.BytesIO(image_bytes)) as img:
    width, height = img.size  # PIL reports size as (width, height).

  return height, width

# Measure the example image so tile_image() can lay patches over it. The .png
# copy is read here; .tiff and .jpeg copies of the example image also exist.
# NOTE(review): the request below embeds the .jpeg copy, so this assumes both
# copies have identical dimensions.
project_name = DATASET_PROJECT_ID
height, width = get_gcs_image_dimensions(project_name, f'{GCS_PATH}.png')
total_pixels = height * width
print(f"Image height: {height}")
print(f"Image width: {width}")

class gcs_Embedding:
  """Embedding of one patch of an image addressed by a GCS URL.

  Patches are assumed to be square, 224 x 224 pixels.
  """
  # The X and Y origin represent the top left point in a square patch.
  x_origin: int
  y_origin: int

  def __init__(self, x_origin=0, y_origin=0, embedding=None, gcs_image_url=""):
    self.x_origin = x_origin
    self.y_origin = y_origin
    # Build a fresh list per instance: a `[]` default would be one list object
    # shared by every instance created without an explicit embedding.
    self.embedding = [] if embedding is None else embedding
    self.gcs_image_url = gcs_image_url

  def to_json(self) -> str:
    """Serializes to JSON; the embedding vector is stored as a JSON string."""
    return json.dumps({
        'x_origin': self.x_origin,
        'y_origin': self.y_origin,
        'embedding': json.dumps(self.embedding),
        'gcs_image_url': self.gcs_image_url,
    })

def gcs_embeddings_dataclass_from_response(response: str) -> Embeddings:
  """Converts a GCS-endpoint :predict response body into gcs_Embedding records.

  Only entries of 'predictions' that are lists are read; each list item is a
  per-image result whose 'patch_embeddings' yield one gcs_Embedding apiece,
  tagged with that result's 'gcs_image_url'. Non-list predictions are skipped.

  Args:
    response: Raw JSON text of the HTTP response.
  """
  parsed = json.loads(response)
  collected = []
  for prediction in parsed['predictions']:
    if not isinstance(prediction, list):
      continue
    for image_result in prediction:
      image_url = image_result['gcs_image_url']
      collected.extend(
          gcs_Embedding(
              x_origin=patch_result['patch_coordinate']['x_origin'],
              y_origin=patch_result['patch_coordinate']['y_origin'],
              embedding=patch_result['embeddings'],
              gcs_image_url=image_url,
          )
          for patch_result in image_result['patch_embeddings']
      )
  return Embeddings(embeddings=collected)

In [None]:
# @title Make requests
headers = {'Authorization': f'Bearer {AUTH_TOKEN}'}

random_results = {}
sequential_results = {}
NUM_PATCHES = 1
# Tile the whole example image (no extra edge row/column) and request one
# embedding per tile.
# NOTE(review): `height`/`width` were measured on the .png copy of the image
# while this request embeds the .jpeg copy; this assumes identical dimensions.
tile_patch_payload = generate_tile_payload(
  model_size=MODEL_SIZE,
  model_kind=MODEL_KIND,
  project_name=DATASET_PROJECT_ID,
  gcs_image_url= GCS_PATH + '.jpeg',
  auth_token=AUTH_TOKEN,
  patch_coordinates=tile_image(
      image_height=height,
      image_width=width,
      run_edge=False
  ),
)
response = requests.post(
  encoder_endpoint_url, headers=headers, json=tile_patch_payload
)
# Surface HTTP failures here rather than as a KeyError when the next cell
# parses a body without 'predictions'.
response.raise_for_status()


In [None]:
# @title Process Response
# Parse the GCS endpoint's JSON reply into per-patch embedding records.
raw_gcs_reply = response.text
gcs_embedding = gcs_embeddings_dataclass_from_response(raw_gcs_reply)

## Download & Organize Patches Into Train and Eval Lists

In [None]:
# @title Downloads the pre-generated patches


client = storage.Client(project=PROJECT_ID)
bucket = client.bucket(BUCKET_NAME)


def download_and_convert_patches(blob_path: str) -> PatchCollection:
  """Downloads a JSON patch list from BUCKET_NAME and parses it.

  Args:
    blob_path: Object name inside BUCKET_NAME, e.g.
      PATCHES_DIR_NAME + CANCER_FILE.

  Returns:
    The parsed PatchCollection.

  Raises:
    FileNotFoundError: If no object exists at `blob_path`.
  """
  blob = bucket.get_blob(blob_path)
  if blob is None:
    # get_blob signals a missing object with None rather than an exception.
    raise FileNotFoundError(f'gs://{BUCKET_NAME}/{blob_path} does not exist')
  # download_as_bytes replaces the deprecated download_as_string alias;
  # json.loads accepts the returned bytes directly.
  return patch_collection_from_json(blob.download_as_bytes())


# Downloads patch collections
cancer_patch_collection = download_and_convert_patches(
    PATCHES_DIR_NAME + CANCER_FILE
)
non_cancer_patch_collection = download_and_convert_patches(
    PATCHES_DIR_NAME + NON_CANCER_FILE
)

In [None]:
# @title Split into Training and Eval lists
# Split by slide for eval and separate patches into training and eval lists according to patch labels.


# Bucket patches by slide_id
def build_patches_by_slide_id(
    patch_collection: PatchCollection,
) -> dict[str, list[Patch]]:
  """Groups a collection's patches by slide_id, keeping their input order.

  Returns a defaultdict(list), so lookups of unknown slide IDs yield [].
  """
  grouped = defaultdict(list)
  for candidate in patch_collection.patches:
    bucket_for_slide = grouped[candidate.slide_id]
    bucket_for_slide.append(candidate)
  return grouped


def select_random_slide_ids(
    patches_by_slide: dict[str, list[Patch]], num_slides: int
) -> list[str]:
  """Returns up to `num_slides` slide IDs chosen uniformly at random.

  Uses the module-level `random` generator, so results differ between runs
  unless it has been seeded.
  """
  candidates = [*patches_by_slide]
  random.shuffle(candidates)
  return candidates[:num_slides]


def get_patches_from_slide_ids(
    patches_by_slide: dict[str, list[Patch]],
    selected_slide_ids: list[str],
    include_selected: bool = True,
) -> list[Patch]:
  """Flattens the patches of the selected (or non-selected) slides.

  Args:
    patches_by_slide: Patches grouped by slide ID.
    selected_slide_ids: Slide IDs to include or exclude.
    include_selected: If True keep only the selected slides; otherwise keep
      only the slides not in `selected_slide_ids`.

  Returns:
    The matching patches, in slide order and then per-slide patch order.
  """
  # A set gives O(1) membership instead of scanning the list per slide.
  selected = set(selected_slide_ids)

  def keep(slide_id: str) -> bool:
    return (slide_id in selected) if include_selected else (
        slide_id not in selected
    )

  return [
      patch
      for slide_id, slide_patches in patches_by_slide.items()
      if keep(slide_id)
      for patch in slide_patches
  ]


cancer_slide_to_patches = build_patches_by_slide_id(
    patch_collection=cancer_patch_collection
)
non_cancer_slide_to_patches = build_patches_by_slide_id(
    patch_collection=non_cancer_patch_collection
)

# Hold out a random subset of the slides that contain cancer patches. Patches
# of both labels from those slides go to eval, and everything else goes to
# training, so no slide contributes to both sets.
eval_reserved_slides = select_random_slide_ids(
    patches_by_slide=cancer_slide_to_patches, num_slides=EVAL_RESERVED_SLIDES
)

training_cancer_patches = get_patches_from_slide_ids(
    patches_by_slide=cancer_slide_to_patches,
    selected_slide_ids=eval_reserved_slides,
    include_selected=False,
)
training_non_cancer_patches = get_patches_from_slide_ids(
    patches_by_slide=non_cancer_slide_to_patches,
    selected_slide_ids=eval_reserved_slides,
    include_selected=False,
)

eval_cancer_patches = get_patches_from_slide_ids(
    patches_by_slide=cancer_slide_to_patches,
    selected_slide_ids=eval_reserved_slides,
    include_selected=True,
)
eval_non_cancer_patches = get_patches_from_slide_ids(
    patches_by_slide=non_cancer_slide_to_patches,
    selected_slide_ids=eval_reserved_slides,
    include_selected=True,
)

print(f'Total Training Non cancer patches: {len(training_non_cancer_patches)}')
print(f'Total Training Cancer patches: {len(training_cancer_patches)}')
print(f'Total Eval Non cancer patches: {len(eval_non_cancer_patches)}')
print(f'Total Eval Cancer patches: {len(eval_cancer_patches)}')

## Using the API on Google Cloud DICOM Store Images

In [None]:
# @title Initial Helper Functions and Setup

def generate_embeddings_payload(
    patch_count: int, input_patches: List[Patch]
) -> dict[str, Any]:
  """Builds a DICOM-endpoint :predict request body for a random patch sample.

  Samples `patch_count` patches without replacement, groups them by DICOM
  series, and emits one request instance per series.

  Args:
    patch_count: Number of patches to sample.
    input_patches: Candidate patches to sample from.

  Returns:
    A dict with 'parameters' (model size/kind) and 'instances'.

  Raises:
    ValueError: If `patch_count` exceeds the number of available patches
      (or is negative).
  """
  if patch_count > len(input_patches):
    # random.sample would raise too, but without saying which setting to change.
    raise ValueError(
        f'Requested {patch_count} patches but only {len(input_patches)} are'
        ' available; lower the corresponding *_PATCH_COUNT constant.'
    )
  selected_patches = random.sample(input_patches, patch_count)

  # Group patches by series so each series becomes a single API instance.
  patches_by_series = _group_patches_by_series(selected_patches)
  instances = [
      _create_instance_data(series_patches)
      for series_patches in patches_by_series.values()
  ]

  return {
      'parameters': {'model_size': MODEL_SIZE, 'model_kind': MODEL_KIND},
      'instances': instances,
  }


def _group_patches_by_series(patches: List[Patch]) -> dict[str, List[Patch]]:
  """Groups patches by series_instance_uid, keeping input order per series."""
  by_series = defaultdict(list)
  for current in patches:
    series_key = current.series_instance_uid
    by_series[series_key].append(current)
  return by_series


def _create_instance_data(patches: List[Patch]) -> dict[str, Any]:
  """Builds one API request instance for patches of a single DICOM series.

  Study/series UIDs come from the first patch; the caller is expected to pass
  patches from one series only (as _group_patches_by_series produces).
  """
  anchor = patches[0]
  slide = dicom_slide.DicomSlide(
      dwi=dwi,
      path=_create_dicom_slide_path(anchor),
      enable_client_slide_frame_decompression=True,
  )
  return {
      'dicom_web_store_url': DICOM_WEB_STORE_URL,
      'dicom_study_uid': anchor.study_instance_uid,
      'dicom_series_uid': anchor.series_instance_uid,
      'bearer_token': AUTH_TOKEN,
      'ez_wsi_state': {},
      'instance_uids': _get_instance_uids(slide),
      'patch_coordinates': _format_patch_coordinates(patches),
  }


def _create_dicom_slide_path(patch: Patch) -> dicom_path.Path:
  """Returns the DICOM store path of the series that contains `patch`."""
  store_location = dict(
      project_id=DATASET_PROJECT_ID,
      location=DATASET_LOCATION,
      dataset_id=DATASET_ID,
      store_id=STORE_ID,
  )
  return dicom_path.Path(
      study_uid=patch.study_instance_uid,
      series_uid=patch.series_instance_uid,
      **store_location,
  )


def _get_instance_uids(ds: dicom_slide.DicomSlide) -> List[str]:
  """Returns the SOP Instance UIDs of the slide level at TARGET_PIXEL_SPACING."""
  level = ds.get_level_by_pixel_spacing(TARGET_PIXEL_SPACING)
  return [
      instance.dicom_object.get_value(tags.SOP_INSTANCE_UID)
      for _, instance in level.instances.items()
  ]


def _format_patch_coordinates(patches: List[Patch]) -> List[dict[str, int]]:
  """Converts patches into the API's 'patch_coordinates' entries.

  Every entry is a PATCH_SIZE x PATCH_SIZE square at the patch's origin.
  """
  coordinates = []
  for item in patches:
    coordinates.append(dict(
        x_origin=item.x_origin,
        y_origin=item.y_origin,
        width=PATCH_SIZE,
        height=PATCH_SIZE,
    ))
  return coordinates

In [None]:
# @title Create payload for the API using the lists of patches defined above.
# Note:  May take approximately 5 minutes


# Each call samples its patches with random.sample, so the order of these calls
# determines which patches are drawn for a given random state. Each call also
# opens one DicomSlide per sampled series, presumably where most of this cell's
# time goes.
eval_cancer_payload = generate_embeddings_payload(
    patch_count=EVAL_CANCER_PATCH_COUNT, input_patches=eval_cancer_patches
)
eval_non_cancer_payload = generate_embeddings_payload(
    patch_count=EVAL_NON_CANCER_PATCH_COUNT,
    input_patches=eval_non_cancer_patches,
)
training_cancer_payload = generate_embeddings_payload(
    patch_count=TRAINING_CANCER_PATCH_COUNT,
    input_patches=training_cancer_patches,
)
training_non_cancer_payload = generate_embeddings_payload(
    patch_count=TRAINING_NON_CANCER_PATCH_COUNT,
    input_patches=training_non_cancer_patches,
)

In [None]:
# @title Generate Embeddings for the patches in the Training and Eval sets
# Note: May take approximately 5 Minutes

# Because we are fetching patches from random locations across many slides, this
# may take several minutes. In scenarios where multiple patches are retrieved
# sequentially across a whole slide image, performance will be faster for this step.

headers = {'Authorization': f'Bearer {AUTH_TOKEN}'}


def _request_embeddings(payload: dict[str, Any]) -> str:
  """POSTs one payload to the DICOM embeddings endpoint and returns the body.

  Raises:
    requests.HTTPError: On a 4xx/5xx status, so a failed request stops here
      instead of surfacing later as a KeyError while parsing the response.
  """
  reply = requests.post(ENCODER_ENDPOINT_URL, headers=headers, json=payload)
  reply.raise_for_status()
  return reply.text


eval_cancer_response = _request_embeddings(eval_cancer_payload)
eval_non_cancer_response = _request_embeddings(eval_non_cancer_payload)
training_cancer_response = _request_embeddings(training_cancer_payload)
training_non_cancer_response = _request_embeddings(training_non_cancer_payload)

In [None]:
# @title Process the embeddings using the helper dataclass defined above


# Parse each raw JSON reply into an Embeddings collection (one per split/label).
eval_cancer_embeddings = embeddings_dataclass_from_response(
    response=eval_cancer_response)
eval_non_cancer_embeddings = embeddings_dataclass_from_response(
    response=eval_non_cancer_response)
training_cancer_embeddings = embeddings_dataclass_from_response(
    response=training_cancer_response)
training_non_cancer_embeddings = embeddings_dataclass_from_response(
    response=training_non_cancer_response)

## Train and Evaluate Linear Probe

In [None]:
# Pass the embeddings into scikit-learn


def concatenate_embeddings(embeddings_obj: Embeddings) -> np.array:
  """Stacks the embedding vectors into one array, one row per Embedding."""
  vectors = [record.embedding for record in embeddings_obj.embeddings]
  return np.array(vectors)


def concatenate_series_ids(embeddings_obj: Embeddings) -> np.array:
  """Returns each Embedding's first instance UID as a NumPy array.

  Used as a per-series grouping ID, assuming one instance UID per series.
  """
  first_uids = [record.instance_uids[0] for record in embeddings_obj.embeddings]
  return np.array(first_uids)


# Feature matrices: one row per patch, cancer patches first, then non-cancer.
training_embeddings = np.concatenate([
    concatenate_embeddings(training_cancer_embeddings),
    concatenate_embeddings(training_non_cancer_embeddings),
])
eval_embeddings = np.concatenate([
    concatenate_embeddings(eval_cancer_embeddings),
    concatenate_embeddings(eval_non_cancer_embeddings),
])

# Cross-validation group IDs, aligned row-for-row with the matrices above.
training_ids = np.concatenate([
    concatenate_series_ids(training_cancer_embeddings),
    concatenate_series_ids(training_non_cancer_embeddings),
])
eval_ids = np.concatenate([
    concatenate_series_ids(eval_cancer_embeddings),
    concatenate_series_ids(eval_non_cancer_embeddings),
])

# Labels (1 = cancer, 0 = non-cancer), aligned row-for-row with the matrices
# above. Counts come from the embeddings actually parsed rather than the
# requested *_PATCH_COUNT constants, so a short API response cannot leave the
# labels a different length than the features.
training_labels = np.concatenate((
    np.ones(len(training_cancer_embeddings.embeddings)),
    np.zeros(len(training_non_cancer_embeddings.embeddings)),
))
eval_labels = np.concatenate((
    np.ones(len(eval_cancer_embeddings.embeddings)),
    np.zeros(len(eval_non_cancer_embeddings.embeddings)),
))
if (len(training_labels), len(eval_labels)) != (
    TRAINING_CANCER_PATCH_COUNT + TRAINING_NON_CANCER_PATCH_COUNT,
    EVAL_CANCER_PATCH_COUNT + EVAL_NON_CANCER_PATCH_COUNT,
):
  warnings.warn(
      'The API returned a different number of embeddings than requested: '
      f'{len(training_labels)} training and {len(eval_labels)} eval.'
  )
assert len(training_labels) == len(training_embeddings) == len(training_ids)
assert len(eval_labels) == len(eval_embeddings) == len(eval_ids)

In [None]:
# Train a linear probe: standardize the embeddings, then fit a logistic
# regression whose regularization strength C is chosen by a 5-fold grid search.
# Folds are grouped by `training_ids`, so patches sharing an ID never fall in
# both the fitting and validation folds.
# NOTE(review): the StandardScaler is fit on the full training set before the
# grid search splits it, so validation folds are not held out from scaling.


with warnings.catch_warnings():
  # Silences every warning raised while fitting and predicting (including
  # convergence warnings).
  warnings.simplefilter('ignore')
  clf_pipeline = sklearn.pipeline.Pipeline([
      ('scaler', sklearn.preprocessing.StandardScaler()),
      (
          'logreg',
          sklearn.model_selection.GridSearchCV(
              # `multi_class` is not passed: it is deprecated since
              # scikit-learn 1.5 and removed in 1.7, and for these binary
              # labels the default behaves the same as the former 'ovr'.
              sklearn.linear_model.LogisticRegression(
                  random_state=0,
                  verbose=False,
              ),
              cv=sklearn.model_selection.StratifiedGroupKFold(n_splits=5).split(
                  training_embeddings, y=training_labels, groups=training_ids
              ),
              param_grid={'C': np.logspace(start=-4, stop=4, num=10, base=10)},
              scoring='roc_auc_ovr',
              refit=True,
          ),
      ),
  ]).fit(training_embeddings, training_labels)

  # Probability of label 1 (cancer) for each eval patch.
  test_predictions = clf_pipeline.predict_proba(eval_embeddings)[:, 1]

In [None]:
# Evaluate the linear classifier's performance on the eval patches (ROC AUC).

sklearn.metrics.roc_auc_score(y_true=eval_labels, y_score=test_predictions)

In [None]:
# @title Plot the ROC Curve

# Named `roc_display` rather than `display` so IPython's built-in display()
# is not shadowed for later cells.
roc_display = sklearn.metrics.RocCurveDisplay.from_predictions(
    eval_labels, test_predictions, name="Tumor Classifier"
)
roc_display.ax_.set_title("ROC of Tumor Classifier");

In [None]:
# @title Find Youden's index for threshold selection

# Sweep 100 evenly spaced thresholds over [0, 1] and keep the one maximizing
# Youden's J (sensitivity + specificity - 1).
thresholds = np.linspace(0, 1, 100)
sensitivities, specificities = [], []
for threshold in thresholds:
  predictions = test_predictions > threshold
  # Sensitivity is recall of the positive class; specificity is recall of the
  # negative class, obtained by flipping both label vectors.
  sensitivities.append(
      sklearn.metrics.recall_score(y_true=eval_labels, y_pred=predictions)
  )
  specificities.append(
      sklearn.metrics.recall_score(
          y_true=eval_labels == 0, y_pred=predictions == 0
      )
  )
youden_sums = np.add(sensitivities, specificities)
index = np.argmax(youden_sums)
best_threshold, sens, spec = (
    thresholds[index], sensitivities[index], specificities[index]
)
print(
    f"Best threshold: {round(best_threshold,2)}. Sensitivity is"
    f" {round(sens*100,2)}% and Specificity is {round(spec*100,2)}% "
)

In [None]:
# @title Show the results in a table

# Eval embeddings in the same order as eval_labels / test_predictions (cancer
# first, then non-cancer), so each table row lines up with its source patch.
eval_embeddings_obj = eval_cancer_embeddings.concat(eval_non_cancer_embeddings)

df = pd.DataFrame({
    'ground_truth': eval_labels,
    'model_score': test_predictions,
}).assign(
    tumor_prediction=lambda frame: frame['model_score'] > best_threshold,
    embeddings=list(eval_embeddings_obj.embeddings),
)

df

In [None]:
# @title Visualize True Positives

# Up to five true positives, highest model score (most confident) first.
df_tp = (
    df[df['tumor_prediction'] & (df['ground_truth'] == 1)]
    .sort_values('model_score', ascending=False)
    .head(5)
)
for _, row in df_tp.iterrows():
  print(f'model score is {row.model_score}')
  render_patch_from_embedding(row.embeddings, 'True Positive')

In [None]:
# @title Visualize True Negatives

# Up to five true negatives, lowest model score (most confident) first,
# mirroring the true-positive cell's most-confident-first ordering.
df_tn = (
    df[~df['tumor_prediction'] & (df['ground_truth'] == 0)]
    .sort_values('model_score', ascending=True)
    .head(5)
)
for _, row in df_tn.iterrows():
  print(f'model score is {row.model_score}')
  render_patch_from_embedding(row.embeddings, 'True Negative')

In [None]:
# @title Visualize False Positives

# All false positives, highest model score (most confident error) first.
df_fp = df[
    df['tumor_prediction'] & (df['ground_truth'] == 0)
].sort_values('model_score', ascending=False)
for _, row in df_fp.iterrows():
  print(f'model score is {row.model_score}')
  render_patch_from_embedding(row.embeddings, 'False Positive')

In [None]:
# @title Visualize False Negatives

# All false negatives, lowest model score (most confident error) first.
df_fn = df[
    ~df['tumor_prediction'] & (df['ground_truth'] == 1)
].sort_values('model_score', ascending=True)
for _, row in df_fn.iterrows():
  print(f'model score is {row.model_score}')
  render_patch_from_embedding(row.embeddings, 'False Negative')