In [None]:
from tqdm import tqdm
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tnrange, tqdm
import os
import numpy as np
from atpbar import atpbar
from mantichora import mantichora
import json

# Sentinel passed as num_parallel_reads / num_parallel_calls in the pipeline below
# so that tf.data picks the degree of parallelism at runtime.
# (`tf.data.experimental.AUTOTUNE` is used rather than `tf.data.AUTOTUNE`, which
# only exists in newer TensorFlow releases.)
AUTO = tf.data.experimental.AUTOTUNE

def image_aug(image, one_hot_label):
    """Randomly apply one of the 8 rotation/flip (dihedral) symmetries to an image.

    Intended as a ``tf.data.Dataset.map`` callback applied after ``read_tfrecord``.

    Args:
        image: image tensor, presumably (height, width, channels) float32 in [0, 1]
            as produced by ``read_tfrecord`` -- TODO confirm for any other caller.
        one_hot_label: label tensor; returned unchanged.

    Returns:
        ``(augmented_image, one_hot_label)``.
    """
    # Transpose only half of the time. Combined with the two independent random
    # flips below, this makes all 8 orientations reachable, including the
    # unmodified input. An unconditional transpose would only ever yield 4 of them
    # and never the original orientation.
    # NOTE: both branches have the same static shape only for square images
    # (read_tfrecord resizes to 296x296 by default).
    image = tf.cond(
        tf.random.uniform([]) < 0.5,
        lambda: tf.image.transpose(image),
        lambda: image,
    )
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    return image, one_hot_label
    
def read_tfrecord(example, target_size=(296, 296)):
    """Parse one serialized TFRecord example into an ``(image, one_hot_class)`` pair.

    Args:
        example: scalar string tensor holding one serialized ``tf.train.Example``.
        target_size: ``(height, width)`` that the decoded image is resized to.
            Defaults to ``(296, 296)``.

    Returns:
        ``(image, one_hot_class)``:
        - ``image`` is a float32 tensor of shape ``(*target_size, 3)``, scaled to [0, 1].
        - ``one_hot_class`` is the dense float32 vector decoded from the
          ``one_hot_class`` variable-length feature.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),  # PNG-encoded bytes (tf.string = bytestring)
        "one_hot_class": tf.io.VarLenFeature(tf.float32),  # variable number of floats
        # "class", "label" and "size" are not returned by this function. They stay
        # in the parse spec so that records are parsed and validated exactly as before
        # (FixedLenFeature entries without a default must be present in the record).
        "class": tf.io.FixedLenFeature([], tf.int64),  # scalar
        "label": tf.io.FixedLenFeature([], tf.string),  # one bytestring
        "size": tf.io.FixedLenFeature([2], tf.int64),  # two integers
    }
    parsed = tf.io.parse_single_example(example, features)

    image = tf.image.decode_png(parsed["image"], channels=3)
    image = tf.image.resize(image, list(target_size))  # float32, still in [0, 255]
    image /= 255.0

    # VarLenFeature fields come back as SparseTensors and need densifying.
    one_hot_class = tf.sparse.to_dense(parsed["one_hot_class"])
    return image, one_hot_class
    
# Input pipeline: read the TFRecord shard(s), decode every record, apply the
# random augmentation, then batch.
# Reading with num_parallel_reads=AUTO and experimental_deterministic=False lets
# tf.data read several files at once and apply order-altering optimizations.
# NOTE: the glob pattern below has no wildcard, so it matches at most this one shard.
# NOTE(review): the absolute path is machine-specific; consider a config constant.

option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

filenames = tf.io.gfile.glob(
    '/home/agsl0905/PDL1_HER2_data/tpu_data/retrain_from_scratch/tile_750/valid/PDL1_valid_data_examples_4777_6_.tfrec'
)
dataset = (
    tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    .with_options(option_no_order)
    .map(read_tfrecord, num_parallel_calls=AUTO)
    .map(image_aug, num_parallel_calls=AUTO)
    .batch(32)
)

In [None]:
# Pull every augmented (image, one-hot label) pair out of the pipeline as NumPy
# arrays so they can be plotted.
# NOTE(review): this holds the whole dataset in memory; something like
# `dataset.unbatch().take(n)` would suffice if only a preview is needed.
img_n_label_with_aug = [
    (img_tensor.numpy(), label_tensor.numpy())
    for img_tensor, label_tensor in dataset.unbatch()
]

In [None]:
from bokeh.plotting import figure, output_file, show
from bokeh.models import Range1d, ColumnDataSource, Legend, LegendItem
from bokeh.models.mappers import ColorMapper
from bokeh.io import export_png, output_notebook, output_file
from bokeh.layouts import gridplot
import PIL.Image

def _to_bokeh_rgba(img):
    """Convert an image array into the packed 2-D uint32 RGBA array used by ``image_rgba``.

    The input array is never modified.

    Returns:
        ``(packed, height, width)``. ``packed`` is flipped vertically because Bokeh
        draws row 0 at the bottom of the plot.
    """
    if img.dtype != np.uint8:
        if np.issubdtype(img.dtype, np.floating):
            # Float images are assumed to be in [0, 1] (read_tfrecord divides by 255).
            # `img * 255` allocates a new array. An in-place `img *= 255` would alter
            # the caller's data and re-scale it again on the next call.
            img = img * 255
        # Clip so out-of-range values saturate instead of wrapping on the cast.
        img = np.clip(img, 0, 255).astype(np.uint8)

    rgba = np.array(PIL.Image.fromarray(img).convert('RGBA'))
    height, width = rgba.shape[:2]  # NumPy image shape is (rows, cols) == (height, width)
    # Reinterpret each (R, G, B, A) byte quadruple as a single uint32 -> shape (height, width).
    packed = rgba.view(dtype=np.uint32).reshape(rgba.shape[:-1])
    return packed[::-1], height, width


def bokeh_plot_img_label(zipped_img_array_label, n_to_plot, notebook_output = True, ncols = 5):
    """Show a grid of images, each with its label rendered in a legend below it.

    Args:
        zipped_img_array_label: indexable sequence of ``(image_array, label)`` pairs.
            Images may be uint8, or float (assumed to be in [0, 1]). The arrays are
            not modified.
        n_to_plot: maximum number of pairs to plot. It is clamped to the sequence
            length, and nothing is plotted if the result is <= 0.
        notebook_output: if True, render inline in the notebook; otherwise write
            to ``bokeh.html``.
        ncols: number of columns in the grid.

    Returns:
        None. The grid is displayed with ``show``.
    """
    if notebook_output:
        output_notebook()
    else:
        output_file('bokeh.html')

    n_to_plot = min(n_to_plot, len(zipped_img_array_label))
    if n_to_plot <= 0:
        return  # nothing to plot

    plots = []
    for i in range(n_to_plot):
        img, label = zipped_img_array_label[i]
        packed, height, width = _to_bokeh_rgba(img)

        subp = figure(height = 250, width = 250)
        glyphs = subp.image_rgba(image=[packed], x=0, y=height, dw=width, dh=height)

        # The legend item is bound explicitly to this subplot's glyph renderer.
        legend_item = LegendItem(label = ['Label: ' + str(label)], renderers=[glyphs], index = 0)
        subp.add_layout(Legend(items=[legend_item]), place = 'below')
        plots.append(subp)

    show(gridplot(children = plots, ncols = ncols, merge_tools = True))
    

In [None]:
# Preview the first 50 augmented tiles inline in the notebook.
bokeh_plot_img_label(img_n_label_with_aug, n_to_plot=50, notebook_output=True)