In [4]:
"""
Author: Valentina Matos (Johns Hopkins - Wirtz/Kiemen Lab)
Date: May 29, 2024
"""

import os
import tensorflow as tf
import numpy as np
from glob import glob
from PIL import Image
import matplotlib.pyplot as plt
from DeepLabV3 import DeepLabV3Plus
import pickle
import time
from scipy.ndimage import binary_fill_holes
from Semanticseg import semantic_seg


In [5]:
# ---- Inputs ----
# Folder holding the images (PNG, or JPG as a fallback) to be classified.
pthim = r'\\10.99.68.52\kiemendata\Valentina Matos\coda to python\test model\5x'
# Folder holding the trained model file net.pkl.
pthDL = r'\\10.99.68.52\Kiemendata\Valentina Matos\coda to python\test model\model test tiles'

# Output flags. Neither is read by the cells in this section; presumably they toggle
# optional colorized outputs (TODO confirm). color_overlay_HE should be True in the real code.
color_overlay_HE, color_mask = True, False


In [None]:
# Unpickle the trained network together with its metadata from net.pkl.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load net.pkl files produced by a trusted training run.
with open(os.path.join(pthDL, 'net.pkl'), 'rb') as f:
    data = pickle.load(f)

# Unpack the stored entries; a missing key raises KeyError.
model = data['model']            # network passed to semantic_seg
classNames = data['classNames']
sxy = data['sxy']                # used as the tile edge length by the classification loop
nblack = data['nblack']
nwhite = data['nwhite']
cmap = data['cmap']
nm = data['nm']                  # used to name the output folder


In [None]:
# Per-model output folder inside the image folder (created if it does not exist yet).
out_folder_name = 'classification_' + nm
outpth = os.path.join(pthim, out_folder_name)
os.makedirs(outpth, exist_ok=True)

In [None]:
# Margin (in pixels) trimmed from every side of each classified tile; the
# classification loop spaces tiles sxy - 2*b apart so the trimmed centres abut.
b = 100

# Collect the images to classify: PNGs, or JPGs when the folder has no PNGs.
# Both lists are sorted so the processing order (and the "image i of N"
# progress messages) is deterministic; raw glob order is arbitrary.
imlist = sorted(glob(os.path.join(pthim, '*.png')))
if not imlist:
    imlist = sorted(glob(os.path.join(pthim, '*.jpg')))
if not imlist:
    print("No PNG or JPG image files found in", pthim)
print('   ')

In [None]:
# Classify every image in `imlist` tile by tile and save one uint8 label map per
# image as <outpth>/<image stem>.png. Images whose label map already exists are skipped.
# TODO: the color_overlay_HE / color_mask outputs are not produced by this cell.


def _load_tissue_mask(image, stem, out_dir):
    """Return a hole-filled boolean tissue mask for one image.

    A precomputed mask <out_dir>/TA/<stem>.png (or .tif) is used when it exists;
    any nonzero pixel counts as tissue. Otherwise tissue is approximated as the
    pixels whose grayscale value is below 220.

    Args:
        image: the opened PIL image being classified.
        stem: the image file name without its extension.
        out_dir: this model's classification output folder.
    """
    for ext in ('.png', '.tif'):
        ta_path = os.path.join(out_dir, 'TA', stem + ext)
        if os.path.isfile(ta_path):
            with Image.open(ta_path) as ta_img:
                tissue = np.asarray(ta_img.convert('L')) != 0
            return binary_fill_holes(tissue)
    tissue = np.asarray(image.convert('L')) < 220
    return binary_fill_holes(tissue)


def _tile_starts(length, tile, step):
    """Offsets of `tile`-long windows spaced `step` apart along an axis of `length`.

    The final window is placed flush with the far edge so the whole axis is
    covered. Requires length >= tile and step >= 1.
    """
    last = length - tile
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


def _classify_tiles(rgb, tissue, model, sxy, b):
    """Classify an (H, W, 3) image tile by tile and return an (H, W) uint8 label map.

    Each sxy x sxy tile is passed to semantic_seg. Only the tile's centre (a
    b-pixel margin trimmed from every side) is written to the output, and tiles
    are spaced sxy - 2*b apart so those centres cover the image. Tiles with fewer
    than 100 tissue pixels are not classified and stay 0. The image is zero-padded
    by b pixels on each side, plus extra padding at the bottom/right when it is
    smaller than one tile, so pixels at the borders are classified too.

    Raises:
        ValueError: if sxy <= 2*b (tiles could not advance).
    """
    h, w = tissue.shape
    step = sxy - 2 * b
    if step <= 0:
        raise ValueError(f'Tile size sxy={sxy} must be larger than 2*b={2 * b}')
    pad_h = max(0, sxy - (h + 2 * b))
    pad_w = max(0, sxy - (w + 2 * b))
    rgb_p = np.pad(rgb, ((b, b + pad_h), (b, b + pad_w), (0, 0)))
    tissue_p = np.pad(tissue, ((b, b + pad_h), (b, b + pad_w)))

    labels = np.zeros(tissue_p.shape, dtype=np.uint8)
    for s1 in _tile_starts(tissue_p.shape[0], sxy, step):
        for s2 in _tile_starts(tissue_p.shape[1], sxy, step):
            if np.sum(tissue_p[s1:s1 + sxy, s2:s2 + sxy]) < 100:
                continue  # too little tissue: this region stays labelled 0
            tile_labels = np.asarray(
                semantic_seg(rgb_p[s1:s1 + sxy, s2:s2 + sxy, :], image_size=1024, model=model))
            # NOTE(review): assumes semantic_seg returns an (sxy, sxy) or (sxy, sxy, 1)
            # label map aligned with the input tile, and that image_size=1024 is
            # consistent with sxy -- TODO confirm.
            if tile_labels.ndim == 3 and tile_labels.shape[-1] == 1:
                tile_labels = tile_labels[..., 0]
            labels[s1 + b:s1 + sxy - b, s2 + b:s2 + sxy - b] = tile_labels[b:sxy - b, b:sxy - b]
    # Drop the padding so the label map matches the original image size.
    return labels[b:b + h, b:b + w]


classification_st = time.time()
for i, img_path in enumerate(imlist):
    img_name = os.path.basename(img_path)
    stem = os.path.splitext(img_name)[0]  # works for any extension length
    out_file = os.path.join(outpth, stem + ".png")
    print(f'  Starting classification of image {i+1} of {len(imlist)}: {img_name}')
    if os.path.isfile(out_file):
        print(f'  Image {img_name} already classified by this model')
        continue
    with Image.open(img_path) as image:
        TA = _load_tissue_mask(image, stem, outpth)
        rgb = np.asarray(image.convert('RGB'))
    imclassify = _classify_tiles(rgb, TA, model, sxy, b)
    # Save the result where the "already classified" check above looks for it.
    Image.fromarray(imclassify).save(out_file)
print(f'  Classification loop finished in {time.time() - classification_st:.1f} s')

                
                
        
            
        
        
    