In [51]:
# 测试染色体分类的预处理以及分类代码是否可用
# 下为预处理代码
# -*- coding: utf-8 -*-
from __future__ import print_function

import sys
import os
import pickle
import argparse
import itertools

import cv2

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# from datasets.simple import *
# from resnet import *
from transforms import *
from plot import *
from resnet import *

In [54]:
# 测试预处理
def read_image_list_cv2():
    image_root = '/home/voyager/project/chromosome-classifier/input'
    image_list = os.listdir(image_root)
    #print(image_list)
    image_list = [image for image in image_list if 'JPG' in image]
    images = [cv2.imread(os.path.join(image_root, image_file)) for image_file in image_list]
    return images

def read_image_list_pil():
    image_root = '/home/voyager/project/chromosome-classifier/input'
    image_list = os.listdir(image_root)
    #print(image_list)
    image_list = [image for image in image_list if 'JPG' in image]
    images = [Image.open(os.path.join(image_root, image_file)) for image_file in image_list]
    print(type(images[0]))
    return images


def show_image_list(images, idx=None):
    """images: list[ndarray] or list[PIL.Image]"""
    for image in images:
        plt.imshow(image)
        plt.show()
        
def get_transform(original_size=1024, resize=448):
    return transforms.Compose([
        transforms.Grayscale(),
        AutoLevel(0.7, 0.0001),
        transforms.CenterCrop(size=original_size),
        transforms.Resize(resize),
        transforms.ToTensor()
    ])

def get_pretrained_model(model_path=None, pretrained=False, num_classes=2):
    print('initing model:{}'.format('resnet101'))
    model = resnet101(pretrained, num_classes=num_classes)
    if model_path is None or os.path.exists(model_path) is False:
        model_path='/home/voyager/project/chromosome-classifier/checkpoint/epoch_0079_loss_0.198197.pth'
        print('invalid model path, using:{}'.format(model_path))
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['net'])
    #print(model)
    return model

def test_preprocess():
    images = read_image_list_pil()
    trans = get_transform()
    image = trans(images[0])
    image = image.reshape(448, 448)
    print(type(image), image.shape)
    plt.imshow(image, cmap='gray')
    plt.show()
    
def test_model():
    images = read_image_list_pil()
    trans = get_transform() # (1, 448, 448)
    image = trans(images[0])
    print(image.shape)
    image = image.float().unsqueeze(dim=0)
    print(image.shape)
    model = get_pretrained_model()
    output = model(image)
    print(output)
    print(output.shape) # (1, 50) 
    
def main():
    test_model()

main()

<class 'PIL.JpegImagePlugin.JpegImageFile'>
torch.Size([1, 448, 448])
torch.Size([1, 1, 448, 448])
initing model:resnet101
invalid model path, using:/home/voyager/project/chromosome-classifier/checkpoint/epoch_0079_loss_0.198197.pth
tensor([[ -2080.2769,  -2079.8442,  -2113.2104,  -2115.7053,  -2211.8979,
          -3046.1548,  -3050.4468,  -3091.0552,  -3092.9292,  -3096.3279,
          -3275.7056,  -3285.5588,  -3330.2534,  -3333.0608,  -3331.1553,
          -3270.6162,  -3289.6643,  -3328.6631,  -3351.9819,  -3636.6816,
          -3297.5435,  -3383.8367,  -4200.8862,  -4660.6973,  -5525.1001,
         -55919.3594, -55873.8398, -55849.0000, -55924.9297, -56212.0195,
         -53047.5664, -53005.5078, -52911.2734, -52984.0898, -53256.3711,
         -50608.1523, -50229.5234, -49848.7031, -49859.3086, -50580.4492,
         -50727.0195, -50375.4531, -50338.3906, -52120.3828, -54030.5391,
         -50945.1680, -52144.1172, -60907.6055, -73091.3672, -79029.1250]],
       grad_fn=<ViewBackw

In [None]:
# 使用数据测试效果
from PIL import Image
from ClassifyCore import ClassifyCore

model_path = '/home/voyager/jpz/chromosome/exp/res34_bs64_epoch200_lr0.0001/checkpoint/epoch_0020_perc_95.789599.pth'
model_type = 'resnet34'
preprocess = 'autolevel'

core = ClassifyCore(model_path=model_path, model_type=model_type, preprocess=preprocess)

image_root = '/home/voyager/project/chromosome-classifier/input'
image_list = os.listdir(image_root)
image_list = [image for image in image_list if 'JPG' in image]

images = read_image_list_pil()
for path in image_list:
    output = core.classify(path)