In [1]:
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model

from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
import lightgbm as lgb

import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from IPython.display import display
import pickle

# When True: load models from a local directory, subsample the test set and
# use the small PCA configuration (see the later cells that branch on `debug`).
debug = True

# Competition input data (images + sample submission).
dataset_dir = "../input/ranzcr-clip-catheter-line-classification/"

# Trained model artifacts: local folder while debugging, uploaded dataset otherwise.
if debug:
    models_dir = "./models/"
else:
    models_dir = "../input/efficientnet-lightgbm-models/"

# Submission label columns; the per-target model files are numbered in this order.
target_cols = [
    'ETT - Abnormal',
    'ETT - Borderline',
    'ETT - Normal',
    'NGT - Abnormal',
    'NGT - Borderline',
    'NGT - Incompletely Imaged',
    'NGT - Normal',
    'CVC - Abnormal',
    'CVC - Borderline',
    'CVC - Normal',
    'Swan Ganz Catheter Present',
]

In [2]:
# Load the saved Keras network (file name indicates EfficientNet-B7) used as the feature extractor.
nn_model = load_model(f"{models_dir}efficientnetB7.h5")

________________
block6k_drop (Dropout)          (None, None, None, 3 0           block6k_project_bn[0][0]         
__________________________________________________________________________________________________
block6k_add (Add)               (None, None, None, 3 0           block6k_drop[0][0]               
                                                                 block6j_add[0][0]                
__________________________________________________________________________________________________
block6l_expand_conv (Conv2D)    (None, None, None, 2 884736      block6k_add[0][0]                
__________________________________________________________________________________________________
block6l_expand_bn (BatchNormali (None, None, None, 2 9216        block6l_expand_conv[0][0]        
__________________________________________________________________________________________________
block6l_expand_activation (Acti (None, None, None, 2 0           block6l_expand_bn[0][0]    

In [12]:
# Test-set template: StudyInstanceUID plus one placeholder column per target.
test = pd.read_csv(dataset_dir + "sample_submission.csv")

if debug:
    # Debug runs only use one fold of roughly `n_samples` rows.
    n_samples = 100
    n_splits = test.shape[0] // n_samples
    # StratifiedKFold requires n_splits >= 2; a file this small is used whole.
    if n_splits >= 2:
        s_kfold = StratifiedKFold(n_splits=n_splits)
        # Only the held-out indices of the first fold are needed.
        _, target_index = next(s_kfold.split(test, test.loc[:, target_cols[0]]))
        test = test.iloc[target_index, :]

# Give the frame a clean 0..n-1 index. Later cells join it column-wise
# (pd.concat(axis=1)) with frames built from numpy arrays, which carry a fresh
# RangeIndex. Stale index labels would misalign those rows and produce NaNs.
test = test.reset_index(drop=True)

display(test)

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.46923145579096002617...,0,0,0,0,0,0,0,0,0,0,0
1,1.2.826.0.1.3680043.8.498.84006870182611080091...,0,0,0,0,0,0,0,0,0,0,0
2,1.2.826.0.1.3680043.8.498.12219033294413119947...,0,0,0,0,0,0,0,0,0,0,0
3,1.2.826.0.1.3680043.8.498.84994474380235968109...,0,0,0,0,0,0,0,0,0,0,0
4,1.2.826.0.1.3680043.8.498.35798987793805669662...,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
98,1.2.826.0.1.3680043.8.498.36767479584462721608...,0,0,0,0,0,0,0,0,0,0,0
99,1.2.826.0.1.3680043.8.498.68424334879992213995...,0,0,0,0,0,0,0,0,0,0,0
100,1.2.826.0.1.3680043.8.498.81916789914212432516...,0,0,0,0,0,0,0,0,0,0,0
101,1.2.826.0.1.3680043.8.498.16318550184938466094...,0,0,0,0,0,0,0,0,0,0,0


## 中間出力の取得に関する注意
元画像を一度pngなどの画像ファイルに保存してから予測を行う場合と、保存せずに直接予測する場合とでは、予測結果が変化する。  
そのため、submit時と同じ中間層出力が得られるように、png等の画像ファイルを介さず、ndarrayを直接保存しなければならない。  
各段階において、train時とinference時の処理が同じ入力に対して同じ出力を返すことを確認する。

In [13]:
# Run every test image through the network, one image at a time.
# Images are decoded and resized in memory and never re-saved to an image file.
# This keeps the outputs identical to the in-memory pipeline.
pred_list = []
for uid in tqdm(test["StudyInstanceUID"]):
    # Load with PIL and resize to 256x256.
    img_pil = image.load_img(f"{dataset_dir}test/{uid}.jpg", target_size=(256, 256))
    # Add a leading batch axis: (1, 256, 256, channels).
    img = np.expand_dims(image.img_to_array(img_pil), axis=0)
    # Keras documents Model.predict as not intended for per-item loops.
    # predict_on_batch runs the single batch without predict()'s per-call setup.
    batch_pred = nn_model.predict_on_batch(preprocess_input(img))
    pred_list.append(np.asarray(batch_pred)[0])

# One row of network outputs per test image, in the same order as `test`.
nn_pred = np.array(pred_list)

100%|██████████| 103/103 [01:08<00:00,  1.50it/s]Wall time: 1min 8s



In [14]:
# Network outputs as a feature table, one row per test image.
df_features = pd.DataFrame(nn_pred)

# Selects which saved PCA (and, later, per-target model) artifacts are loaded.
n_components = 5 if debug else 75

# NOTE: pickle.load can execute arbitrary code; only load model files you produced yourself.
# The context manager closes the file instead of leaking the handle.
with open(f"{models_dir}pca_model_{n_components}.pickle", "rb") as f:
    pca = pickle.load(f)

features_pca = pca.transform(df_features)

# Build the PCA frame on test's own index so that the column-wise concat
# pairs each UID with its features, even if `test` was not index-reset.
X = pd.concat(
    [test["StudyInstanceUID"], pd.DataFrame(features_pca, index=test.index)],
    axis=1,
)
display(X)

Unnamed: 0,StudyInstanceUID,0,1,2,3,4
0,1.2.826.0.1.3680043.8.498.46923145579096002617...,0.728175,2.847438,0.206963,-2.879821,-0.446215
1,1.2.826.0.1.3680043.8.498.84006870182611080091...,-0.501735,-1.522047,-1.555522,1.397800,-2.336977
2,1.2.826.0.1.3680043.8.498.12219033294413119947...,3.700490,-2.404421,2.754051,1.582664,-0.763454
3,1.2.826.0.1.3680043.8.498.84994474380235968109...,0.863464,4.276123,-0.795283,-0.717602,1.767195
4,1.2.826.0.1.3680043.8.498.35798987793805669662...,1.282943,0.244260,-1.558035,4.970026,-1.207666
...,...,...,...,...,...,...
98,1.2.826.0.1.3680043.8.498.36767479584462721608...,2.371390,-2.579474,-2.505542,0.082699,-0.739868
99,1.2.826.0.1.3680043.8.498.68424334879992213995...,2.705163,-1.174013,-4.312785,1.780306,-1.393938
100,1.2.826.0.1.3680043.8.498.81916789914212432516...,-3.033429,-1.636155,-1.073318,1.331120,-1.597037
101,1.2.826.0.1.3680043.8.498.16318550184938466094...,-0.160146,-2.977732,-0.789130,-0.668567,-1.455618


In [15]:
# Start from the test template so StudyInstanceUID and the column order are preserved.
submission = test.copy()

# One pickled model per target column; the files are numbered from 1 in target_cols order.
for i, col_name in enumerate(target_cols):
    model_path = f"{models_dir}pca_{n_components}/lgb_model_{i+1}.pickle"
    # NOTE: pickle.load can execute arbitrary code; only load trusted model files.
    # The context manager closes each file rather than leaking one handle per target.
    with open(model_path, "rb") as f:
        model = pickle.load(f)
    pred = model.predict(features_pca)
    submission.loc[:, col_name] = pred

submission.to_csv("submission.csv", index=False)
display(submission)

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.46923145579096002617...,0.001709,0.065413,0.331421,0.016904,0.025233,0.187935,0.230046,0.115777,0.309227,0.680503,0.069771
1,1.2.826.0.1.3680043.8.498.84006870182611080091...,0.000356,0.005605,0.127226,0.003753,0.013479,0.030368,0.074232,0.084774,0.266309,0.510254,0.003910
2,1.2.826.0.1.3680043.8.498.12219033294413119947...,0.001401,0.005874,0.120480,0.007454,0.009114,0.036550,0.082347,0.093505,0.224893,0.782794,0.006320
3,1.2.826.0.1.3680043.8.498.84994474380235968109...,0.000614,0.057838,0.358220,0.005298,0.010208,0.105797,0.180199,0.097956,0.333728,0.821112,0.054912
4,1.2.826.0.1.3680043.8.498.35798987793805669662...,0.000154,0.011674,0.087914,0.002892,0.017892,0.030927,0.052697,0.037634,0.299188,0.730797,0.008753
...,...,...,...,...,...,...,...,...,...,...,...,...
98,1.2.826.0.1.3680043.8.498.36767479584462721608...,0.000068,0.002542,0.011914,0.002767,0.004642,0.010962,0.011018,0.085745,0.195656,0.798350,0.001777
99,1.2.826.0.1.3680043.8.498.68424334879992213995...,0.000211,0.013153,0.015458,0.002064,0.005878,0.014850,0.010181,0.044312,0.239251,0.773484,0.001111
100,1.2.826.0.1.3680043.8.498.81916789914212432516...,0.000798,0.007657,0.188954,0.004886,0.013479,0.044310,0.074574,0.079104,0.274924,0.735905,0.005811
101,1.2.826.0.1.3680043.8.498.16318550184938466094...,0.000170,0.004034,0.030597,0.006464,0.005701,0.014071,0.022411,0.101254,0.234324,0.720430,0.003278
