In [None]:
!pip install apache_beam[gcp]

In [None]:
# Install system libraries
!apt-get update
!apt-get install -y libvips-dev
!pip install pyvips
!pip install openslide

In [None]:
!pip install openslide_python.tar.gz

In [None]:
import csv
import os
from typing import Any, Dict, Iterable, List, NamedTuple, NewType, Optional, Sequence, Tuple

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler, RunInference
import numpy as np
from numpy import typing as npt
import openslide
import pandas as pd
import pyvips
import tensorflow as tf

# Authenticate this Colab runtime with the user's Google account; presumably
# required so the gs:// paths used below can be read — confirm.
from google.colab import auth
auth.authenticate_user()

In [None]:
# Slide metadata (one row per image). Alternative Kaggle location:
# /kaggle/input/mayo-clinic-strip-ai/train.csv
csv_path = 'gs://mayo-clinic-data/train.csv'
# A single example slide.
# NOTE(review): this global is not read by any other visible cell — remove if unused.
image_path = 'gs://mayo-clinic-data/006388_0.tif'
train_csv = pd.read_csv(csv_path)
train_csv

Unnamed: 0,image_id,center_id,patient_id,image_num,label
0,006388_0,11,006388,0,CE
1,008e5c_0,11,008e5c,0,CE
2,00c058_0,11,00c058,0,LAA
3,01adc5_0,11,01adc5,0,LAA
4,026c97_0,4,026c97,0,CE
...,...,...,...,...,...
749,fe9645_0,3,fe9645,0,CE
750,fe9bec_0,4,fe9bec,0,LAA
751,ff14e0_0,6,ff14e0,0,CE
752,ffec5c_0,7,ffec5c,0,LAA


In [None]:
def key_patient_info(row_dict):
    """Pair a parsed CSV row with its patient id so it can be grouped by key."""
    patient_key = row_dict['patient_id']
    return patient_key, row_dict

def _bytes_feature(value):
    """Wrap a bytes value, or a list/tuple of them, in a ``tf.train.Feature``.

    ``tf.train.BytesList`` only accepts ``bytes`` items and raises ``TypeError``
    for ``str``. Any ``str`` item is therefore UTF-8 encoded first, so ids and
    labels read from CSV text can be stored directly.

    Args:
        value: a single bytes/str value or a list/tuple of them.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    encoded = [v.encode('utf-8') if isinstance(v, str) else v for v in value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))

def _float_feature(value):
    """Build a ``tf.train.Feature`` holding a FloatList; a scalar becomes a one-item list."""
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Build a ``tf.train.Feature`` holding an Int64List; a scalar becomes a one-item list."""
    values = value if isinstance(value, list) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)

def load_downscaled_image(
    local_file_path: str,
    downsample_ratio: int = 16,
) -> npt.NDArray[np.uint8]:
    """Open an image with pyvips and return it shrunk by ``downsample_ratio`` as a NumPy array.

    NOTE(review): the parameter name suggests a local path is expected; confirm
    whether callers passing gs:// URLs work with pyvips.
    """
    scale = 1 / downsample_ratio
    source = pyvips.Image.new_from_file(local_file_path)
    return source.resize(scale).numpy()

def is_non_empty(
    tile: np.ndarray,
    background_threshold: float = 0.1,
    background_level: float = 190,
) -> bool:
    """
    Determine if a tile is non-empty based on the fraction of background pixels.

    A pixel is treated as background when its mean intensity across channels
    is greater than ``background_level`` (bright, near-white areas). The tile
    counts as non-empty when the fraction of background pixels is strictly
    below ``background_threshold``.

    Args:
        tile: pixel array of shape (H, W, C), or (H, W) for single-channel data.
        background_threshold: background fraction (0-1) at or above which the
            tile is considered empty.
        background_level: intensity above which a pixel counts as background.

    Returns:
        True when the tile has few enough background pixels; False otherwise,
        including for a tile with no pixels at all.
    """
    pixels = np.asarray(tile)
    if pixels.size == 0:
        return False
    # Collapse channels to one intensity per pixel; 2-D input already is one.
    gray = pixels if pixels.ndim == 2 else np.mean(pixels, axis=2)
    background_mask = gray > background_level
    # Fraction of *pixels* that are background. Dividing the mask count by
    # tile.size (which also counts every channel value) would understate it
    # by a factor of C.
    background_ratio = float(np.mean(background_mask))
    return bool(background_ratio < background_threshold)

# Distinct static type for patient identifiers (a plain str at runtime).
PatientId = NewType("PatientId", str)
# Pixel data of a single tile.
Image = npt.NDArray[np.uint8]

class TileEntry(NamedTuple):
    """A single tile image, tagged with the patient it was cut from."""

    # Identifier of the owning patient; copied onto the tile's embedding later.
    patient_id: PatientId
    # The tile's pixels.
    image: Image


class Embedding(NamedTuple):
    """Pooled EfficientNet features describing one tile."""

    # Result of global max pooling over the backbone output.
    max_embedding: tf.Tensor
    # Result of global average pooling over the backbone output.
    avg_embedding: tf.Tensor


class EmbeddingEntry(NamedTuple):
    """Inference output record: one tile's embedding plus its patient id."""

    # Patient the embedded tile belongs to.
    patient_id: PatientId
    # The tile's pooled feature vectors.
    embedding: Embedding


def embed_tiles(
    model: tf.keras.Model,
    tiles_batch: Sequence[Image],
) -> Iterable[Embedding]:
    """
    Embed a batch of tile images with the EfNet aggregation model.

    Tiles must be 448x448 RGB, twice EfNet's 224x224 input, and are resized
    down 2x before the forward pass. One ``Embedding`` is yielded per tile, in
    input order. This is a generator, so nothing runs until it is iterated.
    """
    efnet_side = 224
    # Stack the NumPy tiles into one tensor and pin the expected shape.
    batch = tf.convert_to_tensor(tiles_batch)
    batch = tf.ensure_shape(batch, [None, efnet_side * 2, efnet_side * 2, 3])
    # Bring the tiles down to the resolution EfNet consumes.
    model_input = tf.image.resize(batch, [efnet_side, efnet_side])
    outputs = model(model_input, training=False)
    pooled_avg = tf.ensure_shape(outputs["avg"], [None, 1280])
    pooled_max = tf.ensure_shape(outputs["max"], [None, 1280])
    yield from (
        Embedding(avg_embedding=avg_vec, max_embedding=max_vec)
        for avg_vec, max_vec in zip(pooled_avg, pooled_max)
    )


class TileEmbeddingModelHandler(
    ModelHandler[TileEntry, EmbeddingEntry, tf.keras.Model]
):
    """RunInference handler that turns TileEntry images into EmbeddingEntry records.

    The model is a frozen, ImageNet-pretrained EfficientNetB0 backbone with two
    pooling heads ("avg" and "max") over its final feature map.
    """

    def load_model(self) -> tf.keras.Model:
        """Build the frozen EfficientNetB0 feature extractor for 224x224 RGB input."""
        inputs = tf.keras.layers.Input(
            shape=(224, 224, 3), name="image", dtype=tf.float32
        )
        # Headless EfficientNetB0 built directly on `inputs`.
        backbone = tf.keras.applications.EfficientNetB0(
            include_top=False, weights="imagenet", input_tensor=inputs
        )
        # Inference only: keep the pretrained weights frozen.
        backbone.trainable = False
        # Backbone output is [, 7, 7, 1280]; both heads reduce it to [, 1280].
        feature_map = backbone.output
        heads = {
            "avg": tf.keras.layers.GlobalAveragePooling2D()(feature_map),
            "max": tf.keras.layers.GlobalMaxPooling2D()(feature_map),
        }
        model = tf.keras.Model(inputs, heads, name="EfficientNet")
        model.compile()
        return model

    def run_inference(
        self,
        batch: Sequence[TileEntry],
        model: tf.keras.Model,
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[EmbeddingEntry]:
        """Embed each tile in ``batch``; ``inference_args`` is accepted but unused."""
        embeddings = embed_tiles(model, [entry.image for entry in batch])
        # Re-attach each embedding to the patient of the tile it came from.
        for entry, embedding in zip(batch, embeddings):
            yield EmbeddingEntry(patient_id=entry.patient_id, embedding=embedding)

class PatientDataCombineFn(beam.CombineFn):
    """Fold all CSV rows of one patient into a single summary record.

    Each input element is a parsed CSV row dict with at least ``'center_id'``
    and ``'image_id'``, plus ``'label'`` when the CSV has one. The output is
    ``{'center_id': ..., 'label': ..., 'image_ids': [...]}``.
    """

    def create_accumulator(self):
        """Start with no centers, no labels and no images."""
        return {
            'center_ids': set(),
            'labels': set(),
            'image_ids': []
        }

    def add_input(self, acc, element):
        """Record one row's center, label (if present) and image id."""
        acc['center_ids'].add(element['center_id'])
        if 'label' in element:
            acc['labels'].add(element['label'])
        acc['image_ids'].append(element['image_id'])
        return acc

    def merge_accumulators(self, accs):
        """Union the center/label sets and concatenate the image-id lists."""
        result = self.create_accumulator()
        for acc in accs:
            result['center_ids'].update(acc['center_ids'])
            result['labels'].update(acc['labels'])
            result['image_ids'].extend(acc['image_ids'])
        return result

    @staticmethod
    def _pick_one(values):
        """Return a deterministic representative of ``values``, or None if empty.

        ``list(some_set)[0]`` depends on set iteration order. For str values
        that order varies between Python processes (hash randomization), so
        different workers could report different values for the same patient.
        """
        return min(values, key=str) if values else None

    def extract_output(self, acc):
        """Build the per-patient record.

        If a patient's rows disagree on center or label, the smallest value
        (compared as strings) is reported. An empty accumulator yields None
        for both instead of raising IndexError.
        """
        return {
            'center_id': self._pick_one(acc['center_ids']),
            'label': self._pick_one(acc['labels']),
            # Bundle/merge order is not fixed, so sort for reproducible output.
            'image_ids': sorted(acc['image_ids'], key=str),
        }

class GenerateTiles(beam.DoFn):
    """Cut one slide into a grid of candidate tiles.

    The slide is loaded once at ``1 / downsample_ratio`` scale and split into
    cells that each cover ``tile_size`` x ``tile_size`` pixels of the
    full-resolution (level 0) slide. Each cell is emitted with two things:
    its low-resolution pixels, for a cheap emptiness check, and the level-0
    coordinates of its top-left corner, for reading the full tile later.
    """

    def __init__(self, tiff_files_location, tile_size=448, downsample_ratio=16):
        super().__init__()
        self.tiff_files_location = tiff_files_location
        self.downsample_ratio = downsample_ratio
        self.tile_size = tile_size

    def process(self, element):
        """Yield every grid cell of the slide referenced by one keyed CSV row.

        Args:
            element: ``(key, row)`` where ``row['image_id']`` names the
                ``<image_id>.tif`` file under ``tiff_files_location``.

        Yields:
            ``(key, {'image_path': str, 'tile_info': (x, y, tile)})``. ``x``/``y``
            are the level-0 row/column of the cell's top-left pixel, and
            ``tile`` is the downscaled pixel array covering that cell.
        """
        key, row = element
        image_id = row['image_id']
        image_path = os.path.join(self.tiff_files_location, f'{image_id}.tif')
        # NOTE(review): pyvips/openslide open local files; a gs:// location
        # presumably has to be staged to local disk first — confirm.
        img = load_downscaled_image(image_path, downsample_ratio=self.downsample_ratio)
        height, width = img.shape[:2]
        # One full-resolution tile spans tile_size / downsample_ratio pixels of
        # the downscaled image. Stepping by tile_size here would make each
        # preview cover a region downsample_ratio times wider than the tile
        # read later. If tile_size < downsample_ratio, step is clamped to 1 and
        # neighbouring full-resolution tiles overlap.
        step = max(1, self.tile_size // self.downsample_ratio)
        for x in range(0, height, step):
            for y in range(0, width, step):
                preview = img[x:x + step, y:y + step]
                # Scale offsets back up to the level-0 frame used by read_region.
                level0_x = x * self.downsample_ratio
                level0_y = y * self.downsample_ratio
                yield key, {'image_path': image_path, 'tile_info': (level0_x, level0_y, preview)}

class ExtractNonEmptyTiles(beam.DoFn):
    """Keep tissue-bearing tiles and re-read them at full resolution as PNG bytes."""

    def __init__(self, tile_size=448, background_threshold=0.10):
        super().__init__()
        self.tile_size = tile_size
        self.background_threshold = background_threshold

    def process(self, element):
        """Emit the full-resolution PNG for one candidate tile if it is not background.

        Args:
            element: ``(key, {'image_path': str, 'tile_info': (x, y, tile)})``.
                ``tile`` is a preview array passed to ``is_non_empty``. ``x`` is
                used as the row and ``y`` as the column of the location given to
                ``read_region``, which interprets it in level-0 pixels.

        Yields:
            ``(key, png_bytes)`` of a ``tile_size`` x ``tile_size`` RGB tile,
            only when the preview passes ``is_non_empty``.
        """
        key, ele_dict = element
        x, y, preview = ele_dict['tile_info']
        # Check the cheap preview first. Tiles that fail are dropped without
        # ever opening the slide file.
        if not is_non_empty(preview, background_threshold=self.background_threshold):
            return
        # openslide takes an (x=column, y=row) location at level 0.
        position = (y, x)
        size = (self.tile_size, self.tile_size)
        with openslide.open_slide(ele_dict['image_path']) as slide:
            tile_img = slide.read_region(position, 0, size).convert("RGB")
        # PIL image -> (H, W, 3) uint8 array, which encode_png accepts.
        encoded_tile = tf.io.encode_png(np.asarray(tile_img))
        yield key, encoded_tile.numpy()

class ProduceTiles(beam.PTransform):
    """Composite transform: keyed CSV rows -> ``(key, png_bytes)`` for non-empty tiles.

    Chains GenerateTiles and ExtractNonEmptyTiles, forwarding this transform's
    settings to both stages. Previously ``tile_size`` and ``downsample_ratio``
    were accepted but never passed on.
    """

    def __init__(self, tiff_files_location, downsample_ratio=16, tile_size=448, background_threshold=0.10):
        super().__init__()
        self.tiff_files_location = tiff_files_location
        self.downsample_ratio = downsample_ratio
        self.tile_size = tile_size
        self.background_threshold = background_threshold

    def expand(self, pcoll):
        return (
            pcoll
            | "GenerateTiles" >> beam.ParDo(GenerateTiles(
                self.tiff_files_location,
                tile_size=self.tile_size,
                downsample_ratio=self.downsample_ratio,
            ))
            | "ExtractNonEmptyTiles" >> beam.ParDo(ExtractNonEmptyTiles(
                tile_size=self.tile_size,
                background_threshold=self.background_threshold,
            ))
        )

class ConvertToTFRecords(beam.DoFn):
    """Serialize merged ``(key, record)`` elements into ``tf.train.Example`` bytes."""

    @staticmethod
    def _to_bytes(value):
        """UTF-8 encode ``str`` values; anything else (e.g. bytes) passes through."""
        return value.encode('utf-8') if isinstance(value, str) else value

    def process(self, inputs):
        """Convert one ``(key, record)`` element into a serialized Example.

        Args:
            inputs: ``(key, record)`` pair. ``record`` must provide ``'center_id'``
                (int or numeric string), ``'label'`` (str/bytes or None),
                ``'image_ids'`` (iterable of str/bytes) and ``'tiles'``
                (iterable of bytes-like values). The key is stored under the
                ``'image_id'`` feature.

        Yields:
            The ``tf.train.Example`` serialized to bytes.
        """
        # The element is a tuple, so unpack it rather than mixing positional
        # and dict-style indexing on the same object.
        image_id, record = inputs
        # Values parsed from CSV text may arrive as str; Int64List needs an int.
        center_id = int(record['center_id'])
        label = record['label']
        # An unlabeled patient gets an empty bytes list instead of failing.
        label_values = [] if label is None else [self._to_bytes(label)]
        image_ids = [self._to_bytes(i) for i in record['image_ids']]
        tiles = [self._to_bytes(t) for t in record['tiles']]

        all_features = tf.train.Features(feature={
            'image_id': _bytes_feature(self._to_bytes(image_id)),
            'center_id': _int64_feature(center_id),
            'label': _bytes_feature(label_values),
            'image_ids': _bytes_feature(image_ids),
            'tiles': _bytes_feature(tiles),
        })
        example_proto = tf.train.Example(features=all_features)
        yield example_proto.SerializeToString()



In [None]:
# Pipeline execution settings. For a local debug run, switch 'runner' to
# 'DirectRunner' (optionally with num_workers=0 and
# direct_running_mode='multiprocessing').
# NOTE(review): the !apt-get / !pip cells above only install pyvips/openslide
# in this notebook's runtime. Dataflow workers presumably need them too (e.g.
# via a custom container) — confirm before launching.
options = dict(
    runner='DataflowRunner',
    project='my-notebook-372506',
    region='us-east1',
    temp_location='gs://mayo-clinic-data/',
)

# Reuse the column names pandas read from the CSV. The Beam-side reader skips
# the header line, so this gives its parsed rows the same dict keys.
header = train_csv.columns.tolist()
pipeline_options = beam.pipeline.PipelineOptions(flags=[], **options)
def run(
    tiff_files_location='gs://mayo-clinic-demo',
    output_prefix='gs://mayo-clinic-demo/tfrecords/output',
    num_sampled_rows=1,
):
    """Build and run the pipeline: CSV rows -> tiles -> embeddings -> TFRecords.

    Args:
        tiff_files_location: directory holding the ``<image_id>.tif`` slides.
        output_prefix: ``file_path_prefix`` for the written TFRecord shards.
        num_sampled_rows: how many CSV rows are sampled for tiling (demo: 1).

    NOTE(review): ``convert_to_tile_entry`` and ``CombineEmbeddingsFn`` are not
    defined in any visible cell — confirm they exist before running. Also,
    ConvertToTFRecords stores the per-patient combined embeddings under
    ``'tiles'`` as bytes, so CombineEmbeddingsFn must output bytes-like values.
    """
    with beam.Pipeline(options=pipeline_options) as p:
        keyed_rows = (
            p
            | "ReadCSV" >> beam.io.ReadFromText(csv_path, skip_header_lines=1)
            | "ParseCSVRow" >> beam.Map(lambda row: dict(zip(header, row.split(','))))
            | "KeyByPatient" >> beam.Map(key_patient_info)
        )
        patient_info = (
            keyed_rows
            | "CombinePatientRows" >> beam.CombinePerKey(PatientDataCombineFn())
        )

        tiles = (
            keyed_rows
            | "SampleRows" >> beam.combiners.Sample.FixedSizeGlobally(num_sampled_rows)
            # FixedSizeGlobally emits a single list; flatten it. An empty
            # sample simply yields nothing instead of raising IndexError.
            | "UnpackSample" >> beam.FlatMap(lambda rows: rows)
            | "ProduceTiles" >> ProduceTiles(tiff_files_location=tiff_files_location,
                                             background_threshold=0.3)
        )

        tile_entries = tiles | "ToTileEntry" >> beam.Map(convert_to_tile_entry)

        # Create max and avg EfficientNet embeddings for the input tiles.
        embedded_tiles = tile_entries | "EmbedTiles" >> RunInference(TileEmbeddingModelHandler())
        final_embeddings = (
            embedded_tiles
            | "KeyByPatientId" >> beam.Map(lambda e: (e.patient_id, e))
            | "CombineEmbeddingsPerPatient" >> beam.CombinePerKey(CombineEmbeddingsFn())
        )

        merged_examples = (
            (patient_info, final_embeddings)
            | "JoinInfoAndEmbeddings" >> beam.CoGroupByKey()
            # Drop patients whose list of embeddings is empty.
            | "DropPatientsWithoutTiles" >> beam.Filter(lambda ele: ele[1][1])
            | "MergeRecord" >> beam.MapTuple(
                lambda key, value: (key, {**value[0][0], 'tiles': value[1]}))
        )
        (
            merged_examples
            # WriteToTFRecord with BytesCoder needs serialized bytes, not the
            # (key, dict) tuples produced above.
            | "ToTFExample" >> beam.ParDo(ConvertToTFRecords())
            | "WriteTFRecords" >> beam.io.tfrecordio.WriteToTFRecord(
                file_path_prefix=output_prefix,
                file_name_suffix='.tfrecord',
                coder=beam.coders.BytesCoder())
        )


run()