In [4]:
"""
Author: Valentina Matos (Johns Hopkins - Wirtz/Kiemen Lab)
Date: May 29, 2024
"""

import os
import tensorflow as tf
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from DeepLabV3 import DeepLabV3Plus
import pickle


In [5]:
# ---- Inputs ---------------------------------------------------------------
# Folder with the images to classify.
pthim = r'\\10.99.68.52\kiemendata\Valentina Matos\coda to python\test model\5x'
# Folder with the trained model file (net.pkl).
pthDL = r'\\10.99.68.52\Kiemendata\Valentina Matos\coda to python\test model\model test tiles'

# Optional extra outputs, both disabled for this test run.
col = False    # also save a colorized label image per input image
colHE = False  # also save a classification check overlay; set to True in the real code


In [None]:
# Load the trained network weights and their metadata from net.pkl.
# SECURITY: pickle.load can execute arbitrary code; only open net.pkl files
# produced by a trusted training run.
with open(os.path.join(pthDL, 'net.pkl'), 'rb') as f:
    data = pickle.load(f)

# Unpack the fields (evaluated left to right, so a missing key fails in the
# same order as before).
saved_weights, classNames, sxy = data['net'], data['classNames'], data['sxy']  # sxy: tile side in pixels
nblack, nwhite = data['nblack'], data['nwhite']  # not used in the cells shown here
cmap, nm = data['cmap'], data['nm']  # colormap for colorized output; model name for the output folder

In [None]:
# Rebuild the DeepLabV3+ architecture for sxy x sxy RGB tiles, with one output
# per class, and load the trained weights into it.
num_classes = len(classNames)
input_shape = (sxy, sxy, 3)
model = DeepLabV3Plus(shape=input_shape, num_classes=num_classes)
model.set_weights(saved_weights)

In [None]:
def segment_image_tile(model, tile, input_shape):
    """Classify a single image tile and return its per-pixel class-index map.

    Args:
        model: Segmentation model; the argmax over the last axis of its output
            is taken as the class of each pixel.
        tile: Tile with pixel values in 0-255 (tensor or array). In this
            notebook it is a (sxy, sxy, 3) RGB slice of the padded image.
        input_shape: Expected model input shape. Currently unused; kept so
            existing callers keep working.

    Returns:
        Integer tensor of class indices with the tile's spatial shape.
    """
    # Scale to [0, 1]. Presumably this matches the training preprocessing;
    # TODO confirm.
    tile = tf.cast(tile, tf.float32) / 255.0
    batch = tf.expand_dims(tile, axis=0)  # add batch dimension
    # Call the model directly rather than via Model.predict(). predict() adds
    # per-call overhead and prints a progress bar for every tile, which floods
    # the notebook when thousands of tiles are classified.
    scores = model(batch, training=False)
    labels = tf.argmax(scores, axis=-1)
    return tf.squeeze(labels, axis=0)  # drop the batch dimension

In [None]:
# Output folder for the label images, named after the model.
# The model name is read straight from the loaded net.pkl data, not from the
# global `nm`. Re-running this cell then always gives the same folder, even if
# `nm` has been reassigned since (e.g. reused as a loop variable).
outpth = os.path.join(pthim, f"classification_{data['nm']}")
os.makedirs(outpth, exist_ok=True)

In [None]:
# Images to classify: regular files in pthim whose extension is .tif, .tiff or
# .png (case-insensitive). The list is sorted because os.listdir order is
# arbitrary, and a fixed order keeps runs reproducible.
imlist = sorted(
    f for f in os.listdir(pthim)
    if f.lower().endswith(('.tif', '.tiff', '.png'))
    and os.path.isfile(os.path.join(pthim, f))
)

# Tile border in pixels. The outer b pixels of every tile prediction are
# discarded, and tiles are stepped by sxy - 2*b, so neighbouring tiles overlap
# by 2*b pixels.
b = 100

In [None]:
def _read_rgb(path):
    """Load an image file as an (H, W, 3) uint8 RGB array.

    PIL is used because tf.image.decode_image only decodes BMP, GIF, JPEG and
    PNG and fails on .tif inputs. NOTE: very large images can trip PIL's
    decompression-bomb guard (Image.MAX_IMAGE_PIXELS); raise that limit only
    for trusted inputs.
    """
    with Image.open(path) as img:
        return np.array(img.convert('RGB'))


def _classify_image(model, im_rgb, tile_size, border, input_shape):
    """Segment an RGB image with overlapping tiles; return an (H, W) uint8 label map.

    The image is zero-padded by tile_size + border on every side. Tiles of
    tile_size pixels are classified with a stride of tile_size - 2 * border,
    and only each tile's central region (its outer `border` pixels dropped)
    is written into the label map.

    Raises:
        ValueError: if `border` leaves no positive stride between tiles.
    """
    step = tile_size - 2 * border
    if step <= 0:
        raise ValueError(
            f'border ({border}) must be smaller than half the tile size ({tile_size})')
    height, width = im_rgb.shape[:2]
    pad = tile_size + border
    padded = tf.convert_to_tensor(
        np.pad(im_rgb, ((pad, pad), (pad, pad), (0, 0)), mode='constant'))
    rows, cols = height + 2 * pad, width + 2 * pad
    labels = np.zeros((rows, cols), dtype=np.uint8)  # NOTE: assumes fewer than 256 classes
    # Explicit end bound: a [b:-b] slice would be empty when border == 0.
    inner = slice(border, tile_size - border)
    for r in range(0, rows - tile_size, step):
        for c in range(0, cols - tile_size, step):
            tile = padded[r:r + tile_size, c:c + tile_size, :]
            pred = np.asarray(segment_image_tile(model, tile, input_shape))
            labels[r + border:r + tile_size - border,
                   c + border:c + tile_size - border] = pred[inner, inner]
    # Remove the padding so the labels line up with the original image.
    return labels[pad:pad + height, pad:pad + width]


# Use `img_name` (not `nm`) for the loop variable so the model name loaded
# from net.pkl is not overwritten.
for kk, img_name in enumerate(imlist):
    print(f'Starting classification of image {kk + 1} of {len(imlist)}: {img_name}')
    im_rgb = _read_rgb(os.path.join(pthim, img_name))
    imclassify = _classify_image(model, im_rgb, sxy, b, input_shape)

    # Save the raw class indices. PIL writes the uint8 labels unchanged; keras
    # save_img rejects 2-D arrays and by default min-max rescales the values.
    classified_image_path = os.path.join(outpth, img_name)
    Image.fromarray(imclassify).save(classified_image_path)
    print(f'Classified image saved to {classified_image_path}')

    # Optional: colorized label image (class index -> first three cmap columns).
    if col:
        outpthcolor = os.path.join(outpth, 'color')
        os.makedirs(outpthcolor, exist_ok=True)
        imcolor = np.asarray(cmap)[:, :3][imclassify]
        if np.issubdtype(imcolor.dtype, np.floating) and imcolor.max() <= 1.0:
            # NOTE(review): assumes a float colormap is in [0, 1]; confirm.
            imcolor = imcolor * 255
        color_image_path = os.path.join(outpthcolor, img_name)
        Image.fromarray(np.clip(imcolor, 0, 255).astype(np.uint8)).save(color_image_path)
        print(f'Colorized image saved to {color_image_path}')

    if colHE:
        outpth2 = os.path.join(outpth, 'check_classification')
        os.makedirs(outpth2, exist_ok=True)
        overlay_path = os.path.join(outpth2, img_name)
        # Pass the unpadded image so it lines up with imclassify.
        # make_check_annotation_image(im_rgb, imclassify, cmap, 2, overlay_path)  # uncomment in real code
        print(f'Overlay image not generated (make_check_annotation_image is disabled): {overlay_path}')
