In [1]:
import numpy as np
import pandas as pd
from IPython.display import display
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

In [2]:
class CFG:
    """Run-wide settings for this inference notebook.

    With ``debug`` enabled, the notebook predicts on the first 4 rows of
    train.csv (images from ``train/``) and does not write submission.csv.
    Set ``debug = False`` to predict on the test images and save the submission.
    """

    debug = True

    # Root directory of the competition data (relative path).
    dataset_dir = "../input/ranzcr-clip-catheter-line-classification/"

    # Small batches while debugging, full-size batches for the real run.
    if debug:
        batch_size = 4
    else:
        batch_size = 64

In [3]:
# Debug mode borrows rows from train.csv (whose images live under train/);
# otherwise the sample submission lists the test images to predict.
test_path = CFG.dataset_dir + ("train.csv" if CFG.debug else "sample_submission.csv")
test = pd.read_csv(test_path)

# Keep debug runs tiny: only the first 4 rows.
if CFG.debug:
    test = test.iloc[:4]

In [4]:
# Sentinel telling tf.data to tune parallelism / prefetch buffer size dynamically.
AUTO=tf.data.experimental.AUTOTUNE

def preprocessing(path, image_size=299):
    """Read one JPEG file and turn it into a float image tensor for the model.

    Args:
        path: scalar string tensor holding one image file path (one element of
            the path dataset built in ``build_dataset``).
        image_size: side length in pixels of the square output image. Defaults
            to 299, the size this notebook has always used.

    Returns:
        float32 tensor of shape ``(image_size, image_size, 3)`` with pixel values
        divided by 255.
    """
    file_bytes = tf.io.read_file(path)
    # Use decode_jpeg rather than decode_image: with decode_image, resize raised an
    # error (presumably because decode_image's output has no static rank).
    image = tf.io.decode_jpeg(file_bytes, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, (image_size, image_size))
    # Pin the static shape explicitly; per the original note, TPU execution fails without it.
    image = tf.reshape(image, [image_size, image_size, 3])
    image /= 255.0  # normalize: scale 0-255 pixel values down by 1/255
    return image

def build_dataset(uid):
    """Build a batched tf.data pipeline of preprocessed images for the given study IDs.

    Args:
        uid: array-like of StudyInstanceUID strings that supports ``str + uid``
            (this notebook passes the pandas Series ``test["StudyInstanceUID"]``).

    Returns:
        tf.data.Dataset yielding batches of up to ``CFG.batch_size`` images produced
        by ``preprocessing``, in the same order as ``uid``.
    """
    # Debug mode predicts on training images; otherwise on the test images.
    image_dir = f"{CFG.dataset_dir}train/" if CFG.debug else f"{CFG.dataset_dir}test/"
    path = image_dir + uid + ".jpg"  # element-wise concat: one file path per ID
    dset = tf.data.Dataset.from_tensor_slices(path)
    # preprocessing handles one path at a time, so it must be mapped before batch();
    # putting it after batch() causes an error.
    # With default tf.data options, a parallel map still yields elements in input order.
    # The prediction cell relies on this to match rows to predictions by position.
    dset = dset.map(preprocessing, num_parallel_calls=AUTO)
    dset = dset.batch(CFG.batch_size).prefetch(AUTO)
    return dset

In [5]:
# One batched image pipeline over every StudyInstanceUID in `test`, in row order.
test_dataset = build_dataset(uid=test["StudyInstanceUID"])

In [7]:
# Reload the trained model from a local checkpoint directory and print its architecture.
# Per the summary output, it is a Sequential model: Xception backbone + Dense(11) head.
# NOTE(review): the path is relative to the working directory; confirm the checkpoint
# is present there when this notebook is run for a real submission.
xception=keras.models.load_model("./models/xception_ckpt/")
xception.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
xception (Functional)        (None, 2048)              20861480  
_________________________________________________________________
dense_3 (Dense)              (None, 11)                22539     
Total params: 20,884,019
Trainable params: 20,829,491
Non-trainable params: 54,528
_________________________________________________________________


In [8]:
# Fill the label columns with the model's predicted probabilities and, outside debug
# mode, write submission.csv.

# Label columns, in the order shown in the prediction table output.
# NOTE(review): assumes the model's 11 outputs (dense_3 in the summary) follow this same
# order, i.e. the train.csv column order used for training; confirm with the training notebook.
LABEL_COLS = [
    "ETT - Abnormal",
    "ETT - Borderline",
    "ETT - Normal",
    "NGT - Abnormal",
    "NGT - Borderline",
    "NGT - Incompletely Imaged",
    "NGT - Normal",
    "CVC - Abnormal",
    "CVC - Borderline",
    "CVC - Normal",
    "Swan Ganz Catheter Present",
]
# Fail loudly instead of silently appending new columns if the CSV layout differs.
missing_cols = [col for col in LABEL_COLS if col not in test.columns]
assert not missing_cols, f"label columns missing from input CSV: {missing_cols}"

predictions = xception.predict(test_dataset)
# Rows map positionally: test_dataset was built from test["StudyInstanceUID"] in row order.
assert predictions.shape == (len(test), len(LABEL_COLS)), predictions.shape

submission = test.copy()
# Replace whole columns instead of iloc-assigning into the existing ones. Label columns
# with an integer dtype (presumably train.csv's ground truth in debug mode) then become
# float columns, instead of hitting pandas' incompatible-dtype setitem
# deprecation (warning now, error in future pandas).
submission[LABEL_COLS] = predictions
display(submission)

if not CFG.debug:
    submission.to_csv("submission.csv", index=False)

Unnamed: 0,StudyInstanceUID,ETT - Abnormal,ETT - Borderline,ETT - Normal,NGT - Abnormal,NGT - Borderline,NGT - Incompletely Imaged,NGT - Normal,CVC - Abnormal,CVC - Borderline,CVC - Normal,Swan Ganz Catheter Present,PatientID
0,1.2.826.0.1.3680043.8.498.26697628953273228189...,0.002011,5.2e-05,0.000121,0.050357,0.129483,0.01394829,0.724036,0.006088,0.079409,0.410922,2.9e-05,ec89415d1
1,1.2.826.0.1.3680043.8.498.46302891597398758759...,1.9e-05,0.006022,0.993922,0.030547,0.003023,0.9750408,0.001863,0.034195,0.117789,0.957203,0.041508,bf4c6da3c
2,1.2.826.0.1.3680043.8.498.23819260719748494858...,3e-06,2.2e-05,9e-06,2.8e-05,4.5e-05,6.568887e-07,2.6e-05,0.019722,0.971912,0.01957,3e-06,3fc1c97e5
3,1.2.826.0.1.3680043.8.498.68286643202323212801...,0.000548,0.000164,1.6e-05,0.00042,2.1e-05,1.587168e-05,3e-06,0.575235,0.266533,0.166688,0.000134,c31019814
