In [1]:
## packages
import os
import sys
# get parent dir
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname('USskintumor'))))
import torch
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image

import flask
from flask import Flask, request, render_template

import numpy as np
from scipy import misc
import imageio
import copy

from model.combined_CNN_for_CAM import conv3x3, combined_cnn, _combined_model
from loss_functions.focal_loss import  *

In [2]:
## CUDA
# The service runs inference on the GPU only, so fail fast at start-up when
# no CUDA device is visible instead of failing on the first request.
if not torch.cuda.is_available():
    raise Exception('cuda is not available')
device = torch.device('cuda')

In [3]:
## model define, transform
def combined_net(**kwargs):
    """Build the 3-class combined CNN used by this service.

    ``transfer_learning=True`` and ``num_classes=3`` are always passed to
    ``_combined_model``; any other keyword arguments are forwarded unchanged
    (passing either of those two again raises ``TypeError``).
    """
    return _combined_model(num_classes=3, transfer_learning=True, **kwargs)

# Preprocessing applied to every uploaded image before inference:
# single-channel grayscale -> bicubic resize to 224x224 -> tensor in [0, 1]
# -> Normalize(mean=0.5, std=0.5), i.e. values mapped to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    # Use the explicit enum rather than the bare PIL code 3; torchvision
    # deprecates integer interpolation codes. BICUBIC is what 3 maps to.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [4]:
## main page routing
app = Flask(__name__)


@app.route('/')
@app.route('/index')
def index():
    """Serve the upload page by rendering the ``index.html`` template."""
    return render_template('index.html')

In [5]:
## prediction protocol
# Model output index -> diagnosis name shown to the user and used as folder name.
CLASS_NAMES = {0: 'epidermal cyst', 1: 'lipoma', 2: 'pilomatricoma'}
MODEL_PATH = '../final_saved/0603/focal_loss/1.pth'
STATIC_IMAGE_DIR = './static/image'
# Number of leading channels of (g_att1, g_att2, x8) weighted into the CAM.
CAM_CHANNELS = (128, 256, 512)
# CAM values below this fraction of the map's maximum are flattened to its minimum.
CAM_KEEP_FRACTION = 0.7
CAM_SIZE = (224, 224)

# Holds the loaded model so the checkpoint is read once, not on every request.
_model_cache = {}

# NOTE(review): `cv2` and `plt` are not imported in this notebook's import cell;
# they presumably arrive via `from loss_functions.focal_loss import *`.
# TODO: import them explicitly.


def _load_model():
    """Return the eval-mode CUDA model, reading the checkpoint only on first call."""
    model = _model_cache.get('model')
    if model is None:
        model = combined_net().cuda()
        model.load_state_dict(torch.load(MODEL_PATH))
        model.eval()
        _model_cache['model'] = model
    return model


def _split_upload_name(filename):
    """Split a client-supplied file name into (stem, extension).

    Directory components (either slash style) are stripped so the name cannot
    point outside the output folder. The extension keeps its leading dot and
    defaults to '.png' when absent, because cv2/matplotlib choose the encoder
    from the extension.
    """
    base = os.path.basename(filename.replace('\\', '/'))
    stem, ext = os.path.splitext(base)
    return stem, ext or '.png'


def _format_label(class_name, probability):
    """Build the result line shown on the page (probability as a rounded percentage)."""
    percent = round(round(probability, 3) * 100, 2)
    return 'Predicted : ' + class_name + ' ' + '   /   Probability : ' + str(percent) + ' %'


def _rgb_to_bgr(img):
    """Reorder imageio's RGB(A) channels into the BGR(A) order cv2.imwrite expects."""
    if img.ndim == 3 and img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if img.ndim == 3 and img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
    return img  # single channel: nothing to reorder


def _class_activation_map(model, pred_id, feature_maps):
    """Return a uint8 class-activation map of size CAM_SIZE for class `pred_id`.

    feature_maps: the three auxiliary model outputs (g_att1, g_att2, x8), in
    that order; each is weighted by the matching classifier1/2/3 weights of
    `pred_id`, resized, and the three stage maps are summed.
    """
    stage_weights = (model.classifier1.weight,
                     model.classifier2.weight,
                     model.classifier3.weight)
    whole_cam = None
    for weight, fmap, n_ch in zip(stage_weights, feature_maps, CAM_CHANNELS):
        w = weight.detach().cpu().numpy()[pred_id][:n_ch]
        fm = fmap.detach().cpu().numpy()[0][:n_ch]
        # Channel-weighted sum, sum_i w[i] * fm[i] -> one 2-D map for this stage.
        stage_cam = np.tensordot(w, fm, axes=1)
        # NOTE(review): a per-stage scale, np.mean of model.classifier's weights
        # for this stage and pred_id, used to be computed but its product was
        # never assigned, so it never affected the map. It is omitted to keep
        # the rendered CAM unchanged; apply `stage_cam *= scale` if intended.
        stage_cam = cv2.resize(stage_cam, CAM_SIZE, interpolation=cv2.INTER_CUBIC)
        whole_cam = stage_cam if whole_cam is None else whole_cam + stage_cam

    # Class 1 (lipoma) map is inverted; looks like an empirical adjustment.
    # TODO confirm the rationale.
    if pred_id == 1:
        whole_cam = whole_cam.max() - whole_cam

    threshold = whole_cam.max() * CAM_KEEP_FRACTION
    whole_cam = np.where(whole_cam < threshold, whole_cam.min(), whole_cam)
    return cv2.normalize(whole_cam, None, 0, 255,
                         norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)


def _save_cam_overlay(background, cam, path):
    """Save `cam` (jet, alpha 0.4) over `background` (gray) to `path`, then free the figure."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(background, cmap='gray')
        ax.imshow(cam, cmap='jet', alpha=0.4)
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', edgecolor='black', pad_inches=0)
    finally:
        # Without this every request leaks a figure and draws onto the previous one.
        plt.close(fig)


@app.route('/predict', methods=['POST'])
def make_prediction():
    """Classify an uploaded image and render it next to its class-activation map.

    Reads the multipart field ``image``. If it is missing or has no filename,
    index.html is rendered with label 'No Files'. Otherwise the image is
    classified, the uploaded image and a CAM overlay are written under
    STATIC_IMAGE_DIR/<class name>/ as '<stem>_<n><ext>' and 'CAM_<stem>_<n><ext>',
    and index.html is rendered with the prediction label and both image paths
    (relative to the static folder).
    """
    upload = request.files.get('image')
    if not upload:
        return render_template('index.html', label='No Files')

    stem, ext = _split_upload_name(upload.filename)
    original_img = imageio.imread(upload)
    resized = cv2.resize(original_img, (224, 224))

    batch = transform(Image.fromarray(resized)).unsqueeze(0).cuda()
    model = _load_model()
    with torch.no_grad():  # inference only: no autograd graph needed
        output, g_att1, g_att2, x8 = model(batch)

    probs = torch.softmax(output, dim=1)
    pred_id = int(torch.argmax(probs[0]))
    class_name = CLASS_NAMES[pred_id]
    label = _format_label(class_name, probs[0].max().item())

    class_dir = os.path.join(STATIC_IMAGE_DIR, class_name)
    os.makedirs(class_dir, exist_ok=True)
    # Each upload stores two files (image + CAM), so half the count is the running index.
    file_num = len(os.listdir(class_dir)) // 2 + 1
    image_name = '{}_{}{}'.format(stem, file_num, ext)
    cam_name = 'CAM_' + image_name

    cv2.imwrite(os.path.join(class_dir, image_name), _rgb_to_bgr(original_img))
    cam = _class_activation_map(model, pred_id, (g_att1, g_att2, x8))
    _save_cam_overlay(resized, cam, os.path.join(class_dir, cam_name))

    return render_template('index.html', label=label,
                           image_file='image/' + class_name + '/' + image_name,
                           cam_file='image/' + class_name + '/' + cam_name)

In [None]:
if __name__ == '__main__':
    # Start the Flask development server, listening on every network interface.
    # Flask itself warns this server is not for production; deploy behind a WSGI server.
    app.run(host='0.0.0.0')

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses.
 * Running on http://192.168.0.20:5000/ (Press CTRL+C to quit)
192.168.0.20 - - [10/Nov/2021 16:57:09] "GET / HTTP/1.1" 200 -
192.168.0.20 - - [10/Nov/2021 16:57:09] "GET /static/ HTTP/1.1" 404 -
192.168.0.20 - - [10/Nov/2021 16:57:09] "GET /static/style.css HTTP/1.1" 404 -
  nn.init.constant(self.psi.bias.data, 10.0) # initialize the bias for psi
  init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
192.168.0.20 - - [10/Nov/2021 16:57:15] "POST /predict HTTP/1.1" 200 -
192.168.0.20 - - [10/Nov/2021 16:57:15] "GET /static/style.css HTTP/1.1" 404 -
192.168.0.20 - - [10/Nov/2021 16:57:16] "GET /static/image/lipoma/7713683_20200212163809_002_1.png HTTP/1.1" 200 -
192.168.0.20 - - [10/Nov/2021 16:57:16] "GET /static/image/lipoma/CAM_7713683_20200212163809_002_1.png HTTP/1.1" 200 -
