In [1]:
import os,sys
sys.path.insert(0, r'~\trainer')
sys.path.insert(0, r'~\core')

import time
from natsort import natsorted
import pandas as pd
import numpy as np
import itertools
import random

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models

from PIL import Image

from sklearn.metrics import roc_curve,auc
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt

from core.mean_teacher_main import TextFileDataset

import xlwt
import xlrd
import xlutils.copy
import  pandas  as pd
from xlrd import open_workbook
import shutil

In [2]:
def load_model(model_path, num_classes=2):
    """Build a ResNet-50 classifier and load trained weights into it.

    The network is wrapped in ``nn.DataParallel`` *before* the state dict is
    loaded, exactly as the original code did, so the checkpoint is expected to
    have been saved from a ``DataParallel`` model (keys prefixed with
    ``module.``) -- TODO confirm against the training script.

    Args:
        model_path: Path to a checkpoint file containing a ``state_dict``.
        num_classes: Number of outputs of the final fully connected layer.
            Defaults to 2, which is the behaviour of the original function.

    Returns:
        The model, moved to ``cuda:0`` when CUDA is available (CPU otherwise)
        and switched to evaluation mode.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Parameters of a freshly built network already have requires_grad=True,
    # so the explicit loop that re-set the flag was a no-op and is dropped.
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # No unconditional ``.cuda()`` here: it raised on CPU-only machines even
    # though a CPU fallback device had been selected above. ``model.to(device)``
    # below performs the placement for both cases.
    model = nn.DataParallel(model)

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model

def load_model_3(model_path):
    """Build a 3-output ResNet-50 classifier and load trained weights into it.

    Identical to the 2-output loader except for the size of the final fully
    connected layer. The network is wrapped in ``nn.DataParallel`` before the
    state dict is loaded, so the checkpoint is expected to carry
    ``module.``-prefixed keys -- TODO confirm against the training script.

    Args:
        model_path: Path to a checkpoint file containing a ``state_dict``.

    Returns:
        The model, moved to ``cuda:0`` when CUDA is available (CPU otherwise)
        and switched to evaluation mode.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Parameters of a freshly built network already have requires_grad=True,
    # so the explicit loop that re-set the flag was a no-op and is dropped.
    model = models.resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 3)

    # No unconditional ``.cuda()`` here: it raised on CPU-only machines even
    # though a CPU fallback device had been selected above.
    model = nn.DataParallel(model)

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint)

    model.to(device)
    model.eval()
    return model

def predict(model,img_path,channel_stats,binary_threshold):
    """Classify a single image file and return the predicted label index.

    Args:
        model: Network called on a ``(1, 3, 224, 224)`` tensor; its output is
            soft-maxed along dimension 1.
        img_path: Path of the image file to classify.
        channel_stats: Keyword arguments forwarded to ``transforms.Normalize``
            (the callers in this notebook pass ``mean`` and ``std`` lists).
        binary_threshold: Threshold, in percent (0-100), applied to the
            soft-max score of class 1. Only used when the model has exactly
            two outputs.

    Returns:
        int: For two outputs, 1 if the class-1 score exceeds
        ``binary_threshold`` and 0 otherwise. For more than two outputs, the
        index of the highest score. For a single output, 0.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    eval_transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(**channel_stats),
    ])

    # The context manager closes the file handle. Converting to RGB guarantees
    # three channels, so grayscale / RGBA / palette images no longer break the
    # three-channel normalisation.
    with Image.open(img_path) as image_file:
        image_rgb = image_file.convert("RGB")

    # Add the batch dimension and move the input next to the model.
    image_tensor = eval_transformation(image_rgb).unsqueeze(0).to(device)

    # Pure inference: skip autograd bookkeeping so no graph is kept per image.
    with torch.no_grad():
        out = model(image_tensor)
        percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
    scores = percentage.cpu().numpy().tolist()

    label = 0
    if len(scores) == 2:
        if float(scores[1]) > binary_threshold:
            label = 1
    elif len(scores) > 2:
        label = scores.index(max(scores))
    return label

def to_excel(data_list,file_path,col):
    """Write ``data_list`` into one column of an existing workbook.

    The workbook at ``file_path`` is opened, copied into a writable workbook,
    and saved back to the same path. Values are written to the first sheet,
    starting at row index 1 (row index 0 is left untouched).

    Args:
        data_list: Values to write, one per row.
        file_path: Path of the workbook to read and overwrite.
        col: Zero-based index of the column to fill.
    """
    writable_book = xlutils.copy.copy(xlrd.open_workbook(file_path))
    first_sheet = writable_book.get_sheet(0)
    # Row numbering starts at 1 so the first sheet row is never overwritten.
    for row_index, value in enumerate(data_list, start=1):
        first_sheet.write(row_index, col, value)
    writable_book.save(file_path)

In [3]:
# Load the trained models, one per image feature.
# Every *_model_path below is a placeholder string; replace each one with the
# path of the corresponding checkpoint before running.

boundary_model_path = 'best_model_path'
model_boundary = load_model(boundary_model_path)

surface_model_path = 'best_model_path'
model_surface = load_model(surface_model_path)

# The inference loop reads `model_bleeding` and `model_tone`; the models were
# previously bound only to `model_chuxue` / `model_hongbai`, which raised a
# NameError on a fresh kernel. The old names are kept as aliases.
bleeding_model_path = 'best_model_path'
model_bleeding = load_model(bleeding_model_path)
model_chuxue = model_bleeding

# The tone classifier is the only one loaded with three outputs.
tone_model_path = 'best_model_path'
model_tone = load_model_3(tone_model_path)
model_hongbai = model_tone

elevated_model_path = 'best_model_path'
model_elevated = load_model(elevated_model_path)

depressed_model_path = 'best_model_path'
model_depressed = load_model(depressed_model_path)

In [4]:
# Folder whose files are classified by the inference loop.
# os.listdir / os.path.join do not expand a leading "~", so expand it here;
# a path without a recognised "~" prefix is returned unchanged.
# NOTE(review): the path looks like a placeholder -- confirm the real location.
imgs_path = os.path.expanduser(r"~\data\imgs_feature_extraction_deep_learning")

# Existing workbook that receives the predictions (placeholder value: replace
# it with the real path of the sheet before running).
excel_path = "file path for the sheet saving the output"

In [3]:
# Per-feature normalisation statistics. They do not depend on the image, so
# they are built once here instead of on every loop iteration.
channel_stats_boundary = dict(mean=[0.2061, 0.1973, 0.1918],std=[0.2777, 0.2747, 0.2733])
channel_stats_surface = dict(mean=[0.1801, 0.1735, 0.1492],std=[0.2437, 0.2395, 0.2168])
channel_stats_bleeding = dict(mean=[0.1713, 0.1608, 0.1504],std=[0.255, 0.25, 0.2458])
channel_stats_tone = dict(mean=[0.2018, 0.1929, 0.1886],std=[0.2741, 0.272, 0.2716])
channel_stats_elevated = dict(mean=[0.218, 0.2089, 0.2019],std=[0.2767, 0.2747, 0.274])
channel_stats_depressed = dict(mean=[0.1966, 0.1878, 0.1802],std=[0.2712, 0.2683, 0.2652])

name_list = []
boundary_label_list = []
surface_label_list = []
bleeding_label_list = []
tone_label_list = []
elevated_label_list = []
depressed_label_list = []
ca_noca_label_list = []  # not filled in this cell; kept in case other cells use it

# One entry per feature: (model, normalisation stats, threshold in percent,
# list collecting the labels). The threshold is compared by `predict` with the
# class-1 soft-max score scaled to 0-100; `predict` ignores it for models with
# more than two outputs, which is the case for the tone model.
feature_predictors = [
    (model_boundary, channel_stats_boundary, 65, boundary_label_list),
    (model_surface, channel_stats_surface, 95, surface_label_list),
    (model_bleeding, channel_stats_bleeding, 40, bleeding_label_list),
    (model_tone, channel_stats_tone, 50, tone_label_list),
    # Was `modle_elevated`, a typo that raised NameError.
    (model_elevated, channel_stats_elevated, 80, elevated_label_list),
    (model_depressed, channel_stats_depressed, 35, depressed_label_list),
]

# os.listdir returns entries in arbitrary order; natural sorting makes the row
# order of the exported sheet reproducible between runs.
for file in natsorted(os.listdir(imgs_path)):
    name_list.append(file)
    img_path = os.path.join(imgs_path, file)
    for feature_model, feature_stats, feature_threshold, feature_labels in feature_predictors:
        feature_labels.append(predict(feature_model, img_path, feature_stats, feature_threshold))

# Column 0 holds the file names, columns 1-6 the labels of each feature.
export_columns = [
    name_list,
    boundary_label_list,
    surface_label_list,
    bleeding_label_list,
    tone_label_list,
    elevated_label_list,
    depressed_label_list,
]
for column_index, column_values in enumerate(export_columns):
    to_excel(column_values, excel_path, column_index)