In [2]:
import torch
from model import UNet
import utils
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing import image as Image
from keras.applications.xception import preprocess_input
import uuid

localizationWeights = r'../checkpoints/MBM/epoch_10.pth.tar'
classifyModel = r'../checkpoints/LUSC/my_model.h5'


def detect(srcImgPath):
    model = UNet(2).to(utils.device())
    model.load_state_dict(torch.load(localizationWeights, map_location=utils.device())['state_dict'])
    model.eval()

    image = utils.read_image(srcImgPath)
    transform = T.Compose([T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    x = transform(image).unsqueeze(0)
    points = None
    with torch.no_grad():
        x = x.to(utils.device())
        pred = utils.ensure_array(model(x)).squeeze(0)
        points = utils.extract_points_from_direction_field_map(pred, lambda1=0.7, step=10)
    
    # save image and points
    image_array = np.array(image)
    plt.figure(dpi=500)
    plt.imshow(image_array)
    plt.axis('off')
    points_array = np.array(points)
    plt.plot(points_array[:, 1], points_array[:, 0], marker='o', markerfacecolor='#f9f738', markeredgecolor='none', markersize=2,
             linestyle='none')
    uid = str(uuid.uuid4())
    fileName = ''.join(uid.split('-')) + ".png"
    savePath = 'F:/demo/res/img/' + fileName
    plt.savefig(savePath, bbox_inches='tight')
    plt.close()

    return fileName, image, points

def getPatch(image, point):
    xMax, yMax = image.size
    x, y = point
    l = (x - 30) if ((x - 30) > 0) else 0
    r = (x + 30) if ((x + 30) < xMax) else xMax
    h = (y - 30) if ((y - 30) > 0) else 0
    b = (y + 30) if ((y + 30) < yMax) else yMax
    box = (l, h, r, b)

    patch = image.crop(box)
    return patch

def classify(model ,image):
    x = np.array(image)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)

    predict = model.predict(x)
    print(predict)

    if round(predict[0,0])==0:
        result = 'ying'
    else:
        result = 'yang'

    return result

    

def analysis(srcImgPath):
    # 获取细胞位置
    fileName, image, points = detect(srcImgPath)
    # print(len(points))

    model = keras.models.load_model(classifyModel)
    yang = 0
    ying = 0
    for point in points:
        patch = getPatch(image, point)
        patch = tf.image.resize(patch, [80, 80])
        patch = tf.cast(patch, tf.float32)
        patch = patch / 255

        result = classify(model, patch)
        if result == 'yang':
            yang += 1
        else:
            ying += 1
    return fileName, yang, ying



srcImgPath = 'F:/demo/res/img/359067147c464632b4c02d5ea777d44f.jpg'
print(analysis(srcImgPath=srcImgPath))


143
[[0.5088523]]
[[0.49920043]]
[[0.5104897]]
[[0.50842905]]
[[0.5151948]]
[[0.5405941]]
[[0.5643462]]
[[0.5446451]]
[[0.511352]]
[[0.51998687]]
[[0.49892744]]
[[0.4854251]]
[[0.5154675]]
[[0.53059345]]
[[0.5155532]]
[[0.46792305]]
[[0.51950634]]
[[0.5159077]]
[[0.5416336]]
[[0.53890324]]
[[0.52203554]]
[[0.5297162]]
[[0.50373024]]
[[0.5179671]]
[[0.5281589]]
[[0.5054146]]
[[0.5110869]]
[[0.51732033]]
[[0.5106951]]
[[0.5170113]]
[[0.5124434]]
[[0.51249826]]
[[0.4988435]]
[[0.5360318]]
[[0.51405597]]
[[0.5118872]]
[[0.51142406]]
[[0.51700574]]
[[0.50053]]
[[0.5144951]]
[[0.50949633]]
[[0.5187515]]
[[0.5117243]]
[[0.50005704]]
[[0.51406884]]
[[0.5129771]]
[[0.51334536]]
[[0.50303227]]
[[0.5384573]]
[[0.5098624]]
[[0.5124483]]
[[0.5145661]]
[[0.5156941]]
[[0.5119009]]
[[0.51977986]]
[[0.5173405]]
[[0.5323397]]
[[0.48755965]]
[[0.50595504]]
[[0.5033086]]
[[0.51626164]]
[[0.5202695]]
[[0.5184417]]
[[0.5280922]]
[[0.523526]]
[[0.5156605]]
[[0.5173372]]
[[0.51723987]]
[[0.52701414]]
[[0.5343