In [None]:
# Python packages. `%pip` installs into the environment of the running kernel,
# whereas `!pip` may resolve to a different interpreter on PATH.
%pip install apache_beam
# The PyPI distribution that provides `import openslide` is `openslide-python`.
%pip install openslide-python
# !pip install librsvg
# !pip install libiconv
# Native libraries: libvips backs pyvips; openslide-tools pulls in libopenslide,
# which openslide-python loads at import time.
!apt-get update && apt-get install -y libvips libvips-dev openslide-tools
%pip install pyvips
# !pip install --upgrade pyvips

In [None]:
import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import RunInference

import tensorflow as tf

import pandas as pd
import numpy as np
from numpy import typing as npt
from PIL import Image

from typing import (
    Iterable,
    NamedTuple,
    NewType,
    Sequence,
    Optional,
    Dict, Any, List
)
import csv
import openslide
import pyvips
import os

In [None]:
# Candidate input manifests, from the full competition set down to
# single-image smoke tests. Pick one explicitly below instead of relying on
# "the last assignment wins".
CSV_PATH_OPTIONS = {
    'all': '/kaggle/input/mayo-clinic-strip-ai/train.csv',  # all images
    'large': '/kaggle/input/train-csv-3/train_csv_2.csv',  # 1 image - large
    'medium': '/kaggle/input/d/adityashukla1/train-small-2/train_small.csv',  # 1 image - medium
    'small': '/kaggle/input/train-csv-4/train_csv_3.csv',  # 1 image - small
}

# Manifest actually fed to the pipeline (same value the original cell ended with).
csv_path = CSV_PATH_OPTIONS['small']

# Directory holding the whole-slide `<image_id>.tif` files.
tiff_files_location = '/kaggle/input/mayo-clinic-strip-ai/train'
# Single sample slide, kept for ad-hoc inspection.
image_path = '/kaggle/input/mayo-clinic-strip-ai/train/006388_0.tif'
# Full training manifest as a DataFrame, kept for ad-hoc inspection.
train_csv = pd.read_csv(CSV_PATH_OPTIONS['all'])

# File-name prefix of the TFRecord shards written by the pipeline.
output_prefix = 'patient_info'

In [None]:
def key_patient_info(row_dict):
    """Turn a parsed CSV row into a `(patient_id, row)` pair for per-patient grouping."""
    patient_key = row_dict['patient_id']
    return patient_key, row_dict

class PatientDataCombineFn(beam.CombineFn):
    """Collapse all CSV rows of one patient into a single metadata record.

    The accumulator is a dict with the distinct `center_id` values, the
    distinct `label` values and the list of `image_id` values seen so far.
    """

    def create_accumulator(self):
        """Return an empty accumulator."""
        return {
            'center_ids': set(),
            'labels': set(),
            'image_ids': []
        }

    def add_input(self, acc, element):
        """Fold one row dict into the accumulator (`label` is optional)."""
        acc['center_ids'].add(element['center_id'])
        if 'label' in element:
            acc['labels'].add(element['label'])
        acc['image_ids'].append(element['image_id'])
        return acc

    def merge_accumulators(self, accs):
        """Union several partial accumulators into a fresh one."""
        result = self.create_accumulator()
        for acc in accs:
            result['center_ids'].update(acc['center_ids'])
            result['labels'].update(acc['labels'])
            result['image_ids'].extend(acc['image_ids'])
        return result

    def extract_output(self, acc):
        """Build the final per-patient record.

        Returns a dict with keys `center_id`, `label` and `image_ids`.
        When a patient has several distinct centers or labels, the smallest
        value is chosen: picking "the first element of a set" depends on hash
        ordering and is not reproducible between runs or workers. Missing
        values become None instead of raising IndexError. `image_ids` is
        sorted because the arrival order of rows is not guaranteed.
        """
        return {
            'center_id': min(acc['center_ids'], default=None),
            'label': min(acc['labels'], default=None),
            'image_ids': sorted(acc['image_ids'])
        }

def select_image_id(patient_id, record):
    """Re-key a `(patient_id, record)` pair by the record's image id.

    `patient_id` is accepted only because the pair is unpacked by the caller;
    it is not used.
    """
    image_key = record['image_id']
    return image_key, record

def load_downscaled_image(local_file_path, downsample_ratio=16):
    """Open an image with pyvips and return a shrunken copy as a NumPy array.

    Args:
        local_file_path: path of the image file to open.
        downsample_ratio: factor by which both dimensions are reduced.

    Returns:
        The result of `.numpy()` on the image resized by `1 / downsample_ratio`.
    """
    print(f'>> Load original image: {local_file_path}')
    source = pyvips.Image.new_from_file(local_file_path)
    print(f">> Original image size: h-{source.height} w-{source.width}")
    shrink_factor = 1 / downsample_ratio
    return source.resize(shrink_factor).numpy()

def extract_tiles(input_image_path, non_empty_tile_indices, tile_size=448):
    """Lazily read square RGB tiles from level 0 of a slide.

    Args:
        input_image_path: path of the slide, opened with openslide.
        non_empty_tile_indices: iterable of `(row, column)` grid indices; each
            index is multiplied by `tile_size` to obtain the pixel offset.
        tile_size: side length of every tile, in pixels.

    Yields:
        One RGB image per index, in the order the indices are given.
    """
    print(f""">> extract tiles: {input_image_path}
>> non-empty tile indices: {non_empty_tile_indices}""")
    tile_dims = (tile_size, tile_size)
    with openslide.open_slide(input_image_path) as slide:
        # read_region takes the (x, y) of the top-left corner.
        origins = (
            (grid_col * tile_size, grid_row * tile_size)
            for grid_row, grid_col in non_empty_tile_indices
        )
        for origin in origins:
            yield slide.read_region(origin, 0, tile_dims).convert("RGB")

def produce_tiles(
    image_id,
    metadata,
    tiff_files_location,
    max_tiles=None,
    downsample_ratio=16,
    tile_size=448,
    background_threshold=0.7,
):
    """Yield the non-background full-resolution tiles of one slide.

    The slide is first loaded at reduced resolution and scanned on a grid to
    find tiles that are not mostly background; only those tiles are then read
    at full resolution.

    The scan grid uses cells of `tile_size // downsample_ratio` pixels, so that
    grid cell `(row, col)` of the downscaled image covers the same area as the
    full-resolution tile `(row, col)` read by `extract_tiles`. Scanning the
    downscaled image in `tile_size` steps (as was done previously) tests
    regions `downsample_ratio` times larger than, and located elsewhere from,
    the tiles that are actually extracted.

    Args:
        image_id: slide identifier; the file read is `<image_id>.tif`.
        metadata: row dict providing `patient_id`, `center_id` and optionally
            `label`.
        tiff_files_location: directory containing the slide files.
        max_tiles: if truthy, keep at most this many tiles (in row-major order).
        downsample_ratio: shrink factor used for the background scan.
        tile_size: side length, in full-resolution pixels, of emitted tiles.
        background_threshold: a tile is kept when its fraction of background
            pixels is strictly below this value.

    Yields:
        Dicts with keys `image_id`, `patient_id`, `center_id`, `label`, `tile`.

    Raises:
        ValueError: if `tile_size` is not a positive multiple of
            `downsample_ratio`, in which case the two grids cannot line up.
    """
    if downsample_ratio <= 0 or tile_size <= 0 or tile_size % downsample_ratio:
        raise ValueError(
            f"tile_size ({tile_size}) must be a positive multiple of "
            f"downsample_ratio ({downsample_ratio})"
        )

    image_path = os.path.join(tiff_files_location, f"{image_id}.tif")
    # Size of one full-resolution tile once projected onto the downscaled image.
    scaled_tile_size = tile_size // downsample_ratio

    # Load downscaled image
    scaled_image = load_downscaled_image(image_path, downsample_ratio)
    print(f'>> scaled_image shape: {scaled_image.shape}')
    height, width, _ = scaled_image.shape
    rows = height // scaled_tile_size
    cols = width // scaled_tile_size
    print(f"rows and cols: {rows, cols}")

    non_empty_tiles = []

    for row in range(rows):
        for col in range(cols):
            top = row * scaled_tile_size
            left = col * scaled_tile_size
            tile = scaled_image[top:top + scaled_tile_size, left:left + scaled_tile_size]
            # A pixel counts as background when every channel is brighter than 190.
            background_pixels = np.all(tile > 190, axis=2)
            background_ratio = np.mean(background_pixels)
            if background_ratio < background_threshold:
                non_empty_tiles.append((row, col))

    # One summary line instead of one line per grid cell: the corrected grid has
    # downsample_ratio**2 times more cells than before.
    print(f">> non-empty tiles: {len(non_empty_tiles)} of {rows * cols} "
          f"(bg threshold: {background_threshold})")

    if max_tiles:
        non_empty_tiles = non_empty_tiles[:max_tiles]

    for tile_img in extract_tiles(image_path, non_empty_tiles, tile_size):
        print(f">> produce tiles: {image_path}")
        yield {
            'image_id': image_id,
            'patient_id': metadata['patient_id'],
            'center_id': metadata['center_id'],
            'label': metadata.get('label'),
            'tile': tile_img
        }

def convert_to_tile_entry(tile_dict):
    """Convert a tile dict into a `TileEntry`.

    Reads the `tile` and `patient_id` keys; the tile is converted to a uint8
    NumPy array. Any other keys of `tile_dict` are ignored.
    """
    pixels = np.array(tile_dict["tile"], dtype=np.uint8)
    return TileEntry(patient_id=tile_dict["patient_id"], image=pixels)


# Distinct static type for patient identifiers; at runtime it is a plain str.
PatientId = NewType("PatientId", str)
# Alias for a uint8 NumPy array, used to annotate tile pixel data.
# NOTE(review): this rebinds the module-level name `Image`, which shadows the
# earlier `from PIL import Image`. After this line `Image.open(...)` and similar
# PIL calls no longer work in this notebook; consider renaming the alias
# (together with every annotation that uses it).
Image = npt.NDArray[np.uint8]

class TileEntry(NamedTuple):
    """Schema for a tile entry.

    Built by `convert_to_tile_entry` and consumed by
    `TileEmbeddingModelHandler.run_inference`.
    """

    # Identifier of the patient the tile belongs to.
    patient_id: PatientId
    # Tile pixels as a uint8 array. `embed_tiles` requires a stacked batch of
    # these to have shape [batch, 448, 448, 3].
    image: Image


class Embedding(NamedTuple):
    """Schema for aggregated embeddings.

    Holds the two pooled views of the backbone output, either for a single
    tile (from `embed_tiles`) or aggregated per patient (from
    `CombineEmbeddingsFn.extract_output`).
    """

    # Max-pooled embedding vector.
    max_embedding: tf.Tensor
    # Average-pooled embedding vector.
    avg_embedding: tf.Tensor


class EmbeddingEntry(NamedTuple):
    """Schema for prediction entry.

    Pairs the embedding of one tile with the patient the tile came from;
    produced by `TileEmbeddingModelHandler.run_inference`.
    """

    # Identifier of the patient the embedded tile belongs to.
    patient_id: PatientId
    # Max/avg embedding pair computed for the tile.
    embedding: Embedding


def embed_tiles(
    model: tf.keras.Model,
    tiles_batch: Sequence[Image],
) -> Iterable[Embedding]:
    """
    Embed a batch of tile images with the given model.

    The batch must stack to shape [batch, 448, 448, 3]; it is resized to
    224x224 before being passed to `model`, whose output is expected to be a
    mapping with "avg" and "max" entries of shape [batch, 1280].
    Yields one `Embedding` per input tile, in input order.
    """
    model_side = 224
    tile_side = model_side * 2

    # NumPy -> TensorFlow, with a shape guard on the incoming tiles.
    batch_tensor = tf.convert_to_tensor(tiles_batch)
    batch_tensor = tf.ensure_shape(batch_tensor, [None, tile_side, tile_side, 3])

    # Tiles are twice the size the model consumes, so halve them here.
    shrunk_batch = tf.image.resize(batch_tensor, [model_side, model_side])

    outputs = model(shrunk_batch, training=False)
    avg_batch = tf.ensure_shape(outputs["avg"], [None, 1280])
    max_batch = tf.ensure_shape(outputs["max"], [None, 1280])

    # Pair up the per-tile rows of both aggregations.
    for avg_vector, max_vector in zip(avg_batch, max_batch):
        yield Embedding(max_embedding=max_vector, avg_embedding=avg_vector)


class TileEmbeddingModelHandler(
    ModelHandler[TileEntry, EmbeddingEntry, tf.keras.Model]
):
    """Beam model handler that embeds tiles with a frozen EfficientNetB0."""

    def load_model(self) -> tf.keras.Model:
        """Build the embedding model: EfficientNetB0 backbone plus avg/max pooling."""
        # 224x224 RGB input.
        image_input = tf.keras.layers.Input(
            shape=(224, 224, 3), name="image", dtype=tf.float32
        )
        # ImageNet-pretrained backbone without the classification head.
        efficientnet = tf.keras.applications.EfficientNetB0(
            include_top=False, weights="imagenet", input_tensor=image_input
        )
        # The backbone is used as a fixed feature extractor.
        efficientnet.trainable = False

        # Reduce the spatial feature map ([batch, 7, 7, 1280]) to two
        # [batch, 1280] vectors, one per pooling strategy.
        feature_map = efficientnet.output
        pooled_outputs = {
            "avg": tf.keras.layers.GlobalAveragePooling2D()(feature_map),
            "max": tf.keras.layers.GlobalMaxPooling2D()(feature_map),
        }

        embedding_model = tf.keras.Model(
            image_input, pooled_outputs, name="EfficientNet"
        )
        embedding_model.compile()
        return embedding_model

    def run_inference(
        self,
        batch: Sequence[TileEntry],
        model: tf.keras.Model,
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[EmbeddingEntry]:
        """Embed every tile of `batch` and tag each result with its patient id.

        `inference_args` is accepted for interface compatibility and ignored.
        """
        patient_ids = [entry.patient_id for entry in batch]
        images = [entry.image for entry in batch]
        # embed_tiles yields embeddings in the same order as `images`.
        for patient_id, embedding in zip(patient_ids, embed_tiles(model, images)):
            yield EmbeddingEntry(patient_id=patient_id, embedding=embedding)



In [None]:
class CombineEmbeddingsFn(beam.CombineFn):
    """Aggregate the tile embeddings of one key into a single `Embedding`.

    The accumulator is a plain list of `Embedding` objects.
    """

    def create_accumulator(self):
        """Start with no embeddings."""
        return []

    def add_input(self, accumulator, embedding_entry: EmbeddingEntry):
        """Append the embedding carried by `embedding_entry`."""
        accumulator.append(embedding_entry.embedding)
        return accumulator

    def merge_accumulators(self, accumulators):
        """Concatenate partial lists into a new list."""
        return [item for partial in accumulators for item in partial]

    def extract_output(self, embeddings: list[Embedding]):
        """Return the element-wise mean of avg embeddings and max of max embeddings.

        Returns None when there is nothing to aggregate.
        """
        if not embeddings:
            return None

        avg_stack = np.stack([item.avg_embedding.numpy() for item in embeddings])
        max_stack = np.stack([item.max_embedding.numpy() for item in embeddings])

        return Embedding(
            avg_embedding=tf.convert_to_tensor(avg_stack.mean(axis=0)),
            max_embedding=tf.convert_to_tensor(max_stack.max(axis=0)),
        )

def _float_feature(value):
    """Wrap an iterable of floats in a `tf.train.Feature`."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def _bytes_feature(value):
    """Wrap a single string, encoded with `str.encode()`, in a `tf.train.Feature`."""
    encoded = value.encode()
    bytes_list = tf.train.BytesList(value=[encoded])
    return tf.train.Feature(bytes_list=bytes_list)

def to_example(patient_id, data):
    """Convert one co-grouped patient record into a `tf.train.Example`.

    Args:
        patient_id: key of the co-grouped record.
        data: mapping with two entries, each a collection of values for this
            patient: `patient_info` (metadata dicts with `label`, `center_id`,
            `image_ids`) and `embeddings` (aggregated `Embedding` objects).

    Returns:
        An Example with bytes features `patient_id`, `label`, `center_id`,
        `image_ids` (comma-joined) and float features `avg_embedding`,
        `max_embedding`.

    A patient may legitimately have no embedding (no tile passed the
    background filter, so nothing was grouped under its key, or the combiner
    returned None). In that case both embedding features are written empty
    instead of failing the whole pipeline. Likewise a missing label or center
    is written as an empty string.
    """
    def as_text(value):
        # BytesList needs text; None means "not available".
        return "" if value is None else str(value)

    # Parse metadata
    patient_info = data["patient_info"][0]
    label = patient_info["label"]
    center_id = patient_info["center_id"]
    image_ids = patient_info["image_ids"]

    # Parse embeddings: at most one aggregated Embedding is expected per patient.
    embeddings = [entry for entry in data["embeddings"] if entry is not None]
    if embeddings:
        embedding = embeddings[0]
        avg_emb = embedding.avg_embedding.numpy().tolist()
        max_emb = embedding.max_embedding.numpy().tolist()
    else:
        print(f">> No embeddings for patient: {patient_id}")
        avg_emb = []
        max_emb = []

    # Build tf.train.Example
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "patient_id": _bytes_feature(as_text(patient_id)),
                "label": _bytes_feature(as_text(label)),
                "center_id": _bytes_feature(as_text(center_id)),
                "image_ids": _bytes_feature(",".join(as_text(i) for i in image_ids)),
                "avg_embedding": _float_feature(avg_emb),
                "max_embedding": _float_feature(max_emb),
            }
        )
    )
    print(f">> TfExample: {example}")
    return example


In [None]:
def run():
    """Build and run the Beam pipeline that turns slides into per-patient TFRecords.

    Reads the manifest at `csv_path`, cuts every listed slide under
    `tiff_files_location` into non-background tiles, embeds the tiles,
    aggregates the embeddings per patient, joins them with the patient
    metadata and writes `tf.train.Example` records to
    `<output_prefix>*.tfrecord`.
    """
    # The header is read up front because ReadFromText skips it below.
    # csv.reader (rather than str.split) keeps quoted fields containing commas intact.
    with open(csv_path, 'r', newline='') as f:
        header = next(csv.reader(f), [])

    def parse_row(line):
        """Parse one CSV line into a dict keyed by the header columns."""
        return dict(zip(header, next(csv.reader([line]), [])))

    with beam.Pipeline() as p:
        csv_lines = p | "ReadCSV" >> beam.io.ReadFromText(csv_path, skip_header_lines=1)
        parsed_row_dicts = (
            csv_lines
            # Blank lines (e.g. a trailing newline) would yield rows without keys.
            | "DropBlankLines" >> beam.Filter(lambda line: line.strip())
            | "ParseRows" >> beam.Map(parse_row)
        )
        keyed_rows = parsed_row_dicts | "KeyByPatient" >> beam.Map(key_patient_info)
        patient_info = keyed_rows | "CombinePatientInfo" >> beam.CombinePerKey(PatientDataCombineFn())

        # Split each slide into full-resolution tiles, keeping only tiles that
        # are not mostly (bright) background.
        raw_tiles = (
            keyed_rows
            | "SelectImageId" >> beam.MapTuple(select_image_id)
            | "ProduceTiles" >> beam.FlatMapTuple(produce_tiles, tiff_files_location)
        )
        tile_entries = raw_tiles | "ToTileEntry" >> beam.Map(convert_to_tile_entry)

        # Create max and avg embeddings for every tile with the EfficientNet model.
        embedded_tiles = tile_entries | "EmbedTiles" >> RunInference(TileEmbeddingModelHandler())
        final_embeddings = (
            embedded_tiles
            | "KeyByPatientId" >> beam.Map(lambda e: (e.patient_id, e))
            | "CombineEmbeddingsPerPatient" >> beam.CombinePerKey(CombineEmbeddingsFn())
        )

        # Join metadata and aggregated embeddings per patient.
        merged_data = (
            {
                'patient_info': patient_info,
                'embeddings': final_embeddings
            }
            | "MergedData" >> beam.CoGroupByKey()
        )

        # Convert to tf.train.Example records.
        examples = merged_data | "ToTfExample" >> beam.MapTuple(to_example)

        # Write the output to TFRecord files for downstream TensorFlow tasks.
        examples | "WriteTFRecords" >> beam.io.tfrecordio.WriteToTFRecord(
            file_path_prefix=output_prefix,
            file_name_suffix=".tfrecord",
            coder=beam.coders.ProtoCoder(tf.train.Example),
        )

# Execute the pipeline; this blocks until the `with beam.Pipeline()` block inside run() exits.
run()