In [1]:
from tensorflow import keras
from tensorflow.keras import models
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input

from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
from IPython.display import display

## CFG

In [12]:
class CFG:
    debug=True

    dataset_dir="../input/ranzcr-clip-catheter-line-classification/"
    models_dir="./models/" if debug else "../input/efficientnet-lightgbm-models/"

    num_features=100
    target_cols=['ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal', 'NGT - Abnormal', 'NGT - Borderline', 'NGT - Incompletely Imaged',           'NGT - Normal', 'CVC - Abnormal', 'CVC - Borderline', 'CVC - Normal', 'Swan Ganz Catheter Present']

In [3]:
test=pd.read_csv(f"{CFG.dataset_dir}sample_submission.csv")

if CFG.debug:
    n_samples=10
    train,test=train_test_split(test,test_size=n_samples)

display(test)

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
1784,1.2.826.0.1.3680043.8.498.44840699418247685242...,0,0,0,0,0,0,0,0,0,0,0
2896,1.2.826.0.1.3680043.8.498.18001429910476560724...,0,0,0,0,0,0,0,0,0,0,0
3257,1.2.826.0.1.3680043.8.498.69576628924126473682...,0,0,0,0,0,0,0,0,0,0,0
369,1.2.826.0.1.3680043.8.498.92431639625019636008...,0,0,0,0,0,0,0,0,0,0,0
3471,1.2.826.0.1.3680043.8.498.38317279627863831147...,0,0,0,0,0,0,0,0,0,0,0
2188,1.2.826.0.1.3680043.8.498.25557651291138934660...,0,0,0,0,0,0,0,0,0,0,0
3081,1.2.826.0.1.3680043.8.498.70177330790614858744...,0,0,0,0,0,0,0,0,0,0,0
295,1.2.826.0.1.3680043.8.498.58158240367325542195...,0,0,0,0,0,0,0,0,0,0,0
2659,1.2.826.0.1.3680043.8.498.10981929683941499728...,0,0,0,0,0,0,0,0,0,0,0
2705,1.2.826.0.1.3680043.8.498.65527966462470053855...,0,0,0,0,0,0,0,0,0,0,0


## EfficientNetB7による特徴量の取得

In [4]:
nn_model=models.load_model(f"{CFG.models_dir}efficientnetB7.h5")



In [7]:
pred_list=[]
for uid in tqdm_notebook(test["StudyInstanceUID"]):
    img_pil=image.load_img(f"{CFG.dataset_dir}test/{uid}.jpg", target_size=(256,256)) # PIL形式で画像を読み込み、リサイズ
    img=image.img_to_array(img_pil)
    img=np.expand_dims(img,axis=0)
    nn_pred=nn_model.predict(preprocess_input(img))
    pred_list.append(nn_pred[0])

nn_pred=np.array(pred_list)

  0%|          | 0/10 [00:00<?, ?it/s]

## Dense層による推論

In [8]:
with open(f"{CFG.models_dir}eff_dense/model_structure","rt") as f:
    model_json_str=f.read()

dense_model=models.model_from_json(model_json_str)
dense_model.load_weights(f"{CFG.models_dir}eff_dense/checkpoint")
dense_model.compile(optimizer="adam",loss="binary_crossentropy",metrics=[keras.metrics.AUC(multi_label=True)])
dense_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dropout (Dropout)            (None, 2560)              0         
_________________________________________________________________
dense (Dense)                (None, 11)                28171     
Total params: 28,171
Trainable params: 28,171
Non-trainable params: 0
_________________________________________________________________


In [9]:
dense_pred=pd.DataFrame(dense_model.predict(nn_pred),columns=CFG.target_cols)
submission=pd.concat([test["StudyInstanceUID"].reset_index(drop=True),dense_pred],axis=1)
display(submission)

submission.to_csv("submission.csv",index=False)

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present
0,1.2.826.0.1.3680043.8.498.44840699418247685242...,0.000158,0.000534,0.013527,0.00018,0.001717,0.009325,0.003662,0.058051,0.265117,0.645058,0.00046
1,1.2.826.0.1.3680043.8.498.18001429910476560724...,0.000304,0.001394,0.001446,0.000387,0.000104,0.001173,0.002135,0.061959,0.302805,0.675501,2.8e-05
2,1.2.826.0.1.3680043.8.498.69576628924126473682...,0.02027,0.232128,0.766469,0.003808,0.032536,0.15563,0.835688,0.108055,0.326746,0.729117,0.127217
3,1.2.826.0.1.3680043.8.498.92431639625019636008...,0.000359,0.001464,0.01992,0.000719,0.001915,0.014275,0.009068,0.0851,0.247977,0.615497,0.001619
4,1.2.826.0.1.3680043.8.498.38317279627863831147...,0.000626,0.003693,0.013934,0.00323,0.005809,0.001621,0.038497,0.087462,0.423562,0.422017,0.001078
5,1.2.826.0.1.3680043.8.498.25557651291138934660...,4.3e-05,0.00115,0.002267,0.000117,0.001322,0.002956,0.004927,0.081146,0.214727,0.737497,0.002859
6,1.2.826.0.1.3680043.8.498.70177330790614858744...,0.001676,0.006071,0.131546,0.002174,0.008494,0.046295,0.02461,0.103482,0.166792,0.721191,0.011287
7,1.2.826.0.1.3680043.8.498.58158240367325542195...,2.5e-05,0.00066,0.010396,0.000269,0.001001,0.004335,0.002785,0.0193,0.261929,0.626538,0.000549
8,1.2.826.0.1.3680043.8.498.10981929683941499728...,0.000629,0.003709,0.025297,0.000365,0.001448,0.005777,0.021417,0.105889,0.264922,0.547494,0.001498
9,1.2.826.0.1.3680043.8.498.65527966462470053855...,0.000323,0.006679,0.010553,5.5e-05,0.000878,0.027839,0.004885,0.10738,0.265243,0.615343,0.021705
