# 常规赛：PALM病理性近视预测 8月第五名方案
## 解压数据集

In [1]:
%cd /home/aistudio/data/data93479
!unzip ./常规赛：PALM病理性近视预测.zip
!mv -f /home/aistudio/data/data93479/常规赛：PALM病理性近视预测/* /home/aistudio/data

  inflating: __MACOSX/常规赛：PALM病理性近视预测/._Train  

## 清除保存的模型及结果文件
- 确认这么做的后果之前，请勿运行该代码。

In [2]:
%cd /home/aistudio/work
!rm -rf savepoint
!rm Classification_Results.csv
!mkdir -p savepoint

/home/aistudio/work
rm: cannot remove 'Classification_Results.csv': No such file or directory


## 安装PPIM
本次需要使用PPIM中的预训练模型DeiT-B distilled（`deit_b_distilled`，输入尺寸224×224）
- 持久化至external-libraries

In [3]:
#持久化ppim
!mkdir /home/aistudio/external-libraries
!pip install ppim -i https://pypi.python.org/pypi -t /home/aistudio/external-libraries

Looking in indexes: https://pypi.python.org/pypi
Collecting ppim
[?25l  Downloading https://files.pythonhosted.org/packages/e7/35/369dc6956de64359703bb49d20721e0ae963c1183bb4c88535470f2efe93/ppim-1.1.0-py3-none-any.whl (66kB)
[K     |████████████████████████████████| 71kB 10.0kB/s ta 0:00:01
[?25hCollecting wget (from ppim)
  Downloading https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25ldone
[?25h  Created wheel for wget: filename=wget-3.2-cp37-none-any.whl size=9675 sha256=3cb290b7c55d66311f01fc15174bd5be4efeddd6902551731eb622603e2ef4fd
  Stored in directory: /home/aistudio/.cache/pip/wheels/d1/e3/b6/e6be72d63f667cef0226c3eedff3e6658ba97d5be7d9df25dd
Successfully built wget
Installing collected packages: wget, ppim
Successfully installed ppim-1.1.0 wget-3.2


将安装的PPIM导入到环境

In [4]:
import sys

# Directory that `pip install -t` populated in the previous cell.
EXTERNAL_LIBS = '/home/aistudio/external-libraries'
# Guard so re-running this cell does not push duplicate entries onto sys.path.
if EXTERNAL_LIBS not in sys.path:
    sys.path.append(EXTERNAL_LIBS)

## 准备工作
导入包，设置随机种子。

In [5]:
import os
import math
import random
import numpy as np
import pandas as pd
from PIL import Image
import paddle

def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and Paddle, in that order, with `seed`."""
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)

SEED = 0  # global seed for random / NumPy / Paddle
set_seed(SEED)
# Use the GPU when this Paddle build supports CUDA; otherwise fall back to CPU
# so the notebook still runs (slowly) instead of failing here.
paddle.set_device('gpu' if paddle.is_compiled_with_cuda() else 'cpu')

CUDAPlace(0)

读取训练集数据的label

In [6]:
# Read the label sheet once; names and labels come from the same rows, so
# trX_list[i] is the image whose label is trY[i].
label_df = pd.read_excel('../data/Train/Classification.xlsx')
trX_list = label_df['imgName'].tolist()
trY = label_df[['Label']].to_numpy(dtype=np.float32)  # (N, 1) float32 for BCELoss
print(len(trX_list))
print(trY.shape)

800
(800, 1)


## 生成paddle数据集类
- 对训练集图片的预处理如下：
    - resize至[224,224,3]
    - 随机旋转（±30度）
    - 随机水平翻转
    - 随机竖直翻转
    - 改变通道位置Channel Last->Channel First
    - 输入归一化
- 对测试集图片的预处理如下：
    - resize至[224,224,3]
    - 改变通道位置Channel Last->Channel First
    - 输入归一化

In [8]:
from paddle.io import Dataset
import paddle.vision.transforms as T

class PALMTrData(Dataset):
    """Training set: fundus images paired with their PM labels.

    Args:
        path: directory that contains the training images.
        pic_list: image file names; ``pic_list[i]`` is labelled by ``y[i]``.
        y: label array indexed per sample (here the ``(N, 1)`` float32 array
            built from Classification.xlsx); each row is copied on access.

    Raises:
        ValueError: if ``pic_list`` and ``y`` have different lengths.
    """

    def __init__(self, path, pic_list, y):
        super(PALMTrData, self).__init__()
        if len(pic_list) != len(y):
            raise ValueError('pic_list has %d entries but y has %d rows'
                             % (len(pic_list), len(y)))
        self.path = path
        self.y = y
        self.pic_list = pic_list
        # Augmentation: resize, random rotation in [-30, 30] degrees, random
        # horizontal/vertical flips, HWC image -> CHW tensor, ImageNet normalization.
        self.tf = T.Compose([
            T.Resize((224, 224), interpolation='bicubic'),
            T.RandomRotation(30, interpolation='bicubic'),
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        print(len(self.pic_list))

    def __getitem__(self, idx):
        """Return ``(augmented image tensor, label row copy)`` for sample ``idx``."""
        pic_name = self.pic_list[idx]
        # `with` closes the file handle deterministically; converting to RGB
        # guarantees the 3 channels the 3-value Normalize above expects.
        with Image.open(os.path.join(self.path, pic_name)) as img:
            pic = img.convert('RGB')
        pic = self.tf(pic)
        # Copy so downstream batching never aliases the shared label array.
        label = self.y[idx].copy()
        return pic, label

    def __len__(self):
        return len(self.pic_list)

class PALMTsData(Dataset):
    """Unlabelled test set: every entry of ``path`` is served as one sample.

    Samples are served in ``self.pic_list`` order (as returned by
    ``os.listdir``), so predictions should be paired with ``self.pic_list``.
    """

    def __init__(self, path):
        super(PALMTsData, self).__init__()
        self.path = path
        # NOTE: listdir order is filesystem-dependent; downstream code should
        # reuse self.pic_list instead of listing the directory a second time.
        self.pic_list = os.listdir(path)
        # Deterministic preprocessing only: resize, HWC -> CHW tensor, ImageNet normalization.
        self.tf = T.Compose([
            T.Resize((224, 224), interpolation='bicubic'),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        print(len(self.pic_list))

    def __getitem__(self, idx):
        """Return ``(image tensor, 0)``; the 0 is a placeholder label."""
        pic_name = self.pic_list[idx]
        # Close the file promptly and force 3 channels for the 3-value Normalize.
        with Image.open(os.path.join(self.path, pic_name)) as img:
            pic = img.convert('RGB')
        pic = self.tf(pic)
        return pic, 0

    def __len__(self):
        return len(self.pic_list)

# Paths are relative to the working directory set earlier by `%cd /home/aistudio/work`.
tr_path = '../data/Train/fundus_image'
ts_path = '../data/PALM-Testing400-Images'

trData = PALMTrData(path=tr_path, pic_list=trX_list, y=trY)
tsData = PALMTsData(path=ts_path)

800
400


## 构建模型
- 模型以预训练的ResNet-50与DeiT-B distilled（`deit_b_distilled`，224输入）为主体，二者对同一输入并联计算，输出拼接为[1000+1000]=2000维
- 下游分类器为2层全连接网络（2000->100->1）：拼接特征与隐藏层均以ReLU激活，并各辅以Dropout=0.2防止过拟合
- 下游任务被设定为二分类问题，输出经Sigmoid得到单个患病概率

In [9]:
from paddle import nn
import paddle.nn.functional as F
from ppim import deit_b_distilled
from paddle.vision.models import resnet50

class PALM(nn.Layer):
    """Two-backbone binary classifier.

    A distilled DeiT-B and a ResNet-50 (both loaded with pretrained weights)
    each process the same input. Their outputs are concatenated into the
    2000-d vector expected by ``lr1`` and passed through a ReLU/Dropout MLP
    head (2000 -> 100 -> 1). A final sigmoid yields one probability per image.
    """

    def __init__(self):
        super(PALM, self).__init__()
        # Creation order is kept fixed: it determines parameter initialization
        # under the global seed and the auto-generated parameter names.
        self.deit = deit_b_distilled(pretrained=True)
        self.resnet = resnet50(pretrained=True)
        self.do1 = nn.Dropout(p=0.2)
        self.lr1 = nn.Linear(in_features=2000, out_features=100)
        self.do2 = nn.Dropout(p=0.2)
        self.lr2 = nn.Linear(in_features=100, out_features=1)

    def forward(self, x):
        """Map an image batch to sigmoid probabilities of shape ``[batch, 1]``."""
        fused = paddle.concat([self.deit(x), self.resnet(x)], axis=-1)
        hidden = self.lr1(self.do1(F.relu(fused)))
        hidden = self.do2(F.relu(hidden))
        return F.sigmoid(self.lr2(hidden))

palm = PALM()

100%|██████████| 511043/511043 [00:07<00:00, 68974.38it/s]
100%|██████████| 151272/151272 [00:02<00:00, 63011.04it/s]


## 模型训练参数
- 优化器：adam
- Loss：二分类使用BCELoss
- 学习率：固定5e-5

In [10]:
from paddle.optimizer import Adam
from paddle.nn import BCELoss
from paddle.metric import Accuracy

class BinaryAccuracy(paddle.metric.Metric):
    """Accuracy for a single sigmoid output thresholded at ``threshold``.

    ``paddle.metric.Accuracy`` ranks class scores with top-k. With a
    one-column output the top-1 index is always 0, so that metric degenerates
    to "fraction of samples labelled 0". This metric instead thresholds the
    probability and compares it with the 0/1 label.
    """

    def __init__(self, threshold=0.5, name='acc'):
        super(BinaryAccuracy, self).__init__()
        self.threshold = threshold
        self._name = name
        self.reset()

    def update(self, pred, label):
        """Accumulate one batch; ``pred``/``label`` arrive as NumPy arrays."""
        pred = np.asarray(pred).reshape(-1)
        label = np.asarray(label).reshape(-1)
        hits = (pred >= self.threshold) == (label >= 0.5)
        self.correct += int(hits.sum())
        self.total += int(hits.size)
        return float(hits.mean()) if hits.size else 0.0

    def reset(self):
        self.correct = 0
        self.total = 0

    def accumulate(self):
        """Return accuracy over all batches since the last reset (0.0 if none)."""
        return self.correct / self.total if self.total else 0.0

    def name(self):
        return [self._name]


# Static input/label specs: NCHW 224x224 images and one float target per sample.
inputs = paddle.static.InputSpec([-1, 3, 224, 224], dtype='float32', name='input')
label = paddle.static.InputSpec([-1, 1], dtype='float32', name='label')
model = paddle.Model(palm, inputs, label)
model.summary()
model.prepare(optimizer=Adam(learning_rate=5e-5, parameters=model.parameters()),
              loss=BCELoss(),
              metrics=BinaryAccuracy())

----------------------------------------------------------------------------------------
        Layer (type)             Input Shape          Output Shape         Param #    
          Conv2D-5            [[1, 3, 224, 224]]    [1, 768, 14, 14]       590,592    
        PatchEmbed-1          [[1, 3, 224, 224]]     [1, 196, 768]            0       
         Dropout-1             [[1, 198, 768]]       [1, 198, 768]            0       
        LayerNorm-1            [[1, 198, 768]]       [1, 198, 768]          1,536     
          Linear-1             [[1, 198, 768]]       [1, 198, 2304]       1,771,776   
         Dropout-2           [[1, 12, 198, 198]]   [1, 12, 198, 198]          0       
          Linear-2             [[1, 198, 768]]       [1, 198, 768]         590,592    
         Dropout-3             [[1, 198, 768]]       [1, 198, 768]            0       
        Attention-1            [[1, 198, 768]]       [1, 198, 768]            0       
         Identity-1            [[1, 198, 

## 训练模型
- 训练、预测批大小：16
- 训练回合：50
- 数据异步读取：2个进程
- 训练集：全部训练数据
- 验证集：无

In [11]:
# Train on the full labelled set (no validation split): 50 epochs,
# batch size 16, two DataLoader worker processes.
model.fit(
    train_data=trData,
    batch_size=16,
    epochs=50,
    verbose=1,
    num_workers=2,
)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
  return (isinstance(seq, collections.Sequence) and
  "When training, we now always track global mean and variance.")
  format(lhs_dtype, rhs_dtype, lhs_dtype))


Epoch 2/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 3/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 4/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 5/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 6/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 7/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 8/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 9/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 10/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 11/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 12/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 13/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 14/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 15/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 16/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 17/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 18/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 19/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 20/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 21/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 22/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 23/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 24/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 25/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 26/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 27/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 28/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 29/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 30/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 31/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 32/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 33/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 34/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 35/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 36/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 37/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 38/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 39/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 40/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 41/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 42/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 43/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 44/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 45/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 46/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 47/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 48/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 49/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Epoch 50/50


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:




## 模型预测
- 对测试集数据预测
- 预测批大小：16
- 预测数据异步读取：2个进程

In [12]:
# Run inference over the test Dataset with the same batch size / worker count as training.
preds = model.predict(
    tsData,
    batch_size=16,
    num_workers=2,
)

Predict begin...


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:


Predict samples: 400


## 保存预测结果
- 记得按文件名重新排序

In [13]:
# `model.predict` returns one entry per model output, each a sequence of
# per-batch arrays. Concatenating the batches of the single output works even
# when the last batch is short; np.array over the nested list would not.
pred = np.concatenate(preds[0], axis=0).reshape(-1)
# Pair predictions with the exact file order the test Dataset iterated over,
# rather than a second os.listdir call whose order is not guaranteed to match.
pred_list = list(tsData.pic_list)
assert len(pred_list) == pred.shape[0], (len(pred_list), pred.shape)
print(pred.shape)
ans = pd.DataFrame({'FileName': pred_list, 'PM Risk': pred})
ans = ans.sort_values(by='FileName', ignore_index=True)
ans.to_csv('./Classification_Results.csv', index=False)

(400,)


## 保存模型

In [14]:
# Export for inference only (training=False): optimizer state is not saved.
model.save('./savepoint/final', training=False)

  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))
  op_type, op_type, EXPRESSION_MAP[method_name]))


## 注意事项与特殊说明
由于GitHub对大文件上传限制，模型训练参数checkpoint无法上传。
- 完整repo详见AI studio项目：https://aistudio.baidu.com/aistudio/projectdetail/2318275