In [1]:
import math
import os
from glob import glob

import numpy as np
import pandas as pd
import tensorflow as tf
from IPython.display import display
from tensorflow.io import FixedLenFeature
from tqdm.notebook import tqdm

In [2]:
class CFG:
    """Notebook-wide configuration constants."""

    # Directory that train.csv is read from; the trailing slash is kept so a
    # file name can be appended directly.
    dataset_dir="../input/ranzcr-clip-catheter-line-classification/"

    # Names of the 11 label columns.
    target_cols=[
        'ETT - Abnormal',
        'ETT - Borderline',
        'ETT - Normal',
        'NGT - Abnormal',
        'NGT - Borderline',
        'NGT - Incompletely Imaged',
        'NGT - Normal',
        'CVC - Abnormal',
        'CVC - Borderline',
        'CVC - Normal',
        'Swan Ganz Catheter Present',
    ]

In [3]:
def _bytes_feature(value):
    """Wrap one bytes value in a ``tf.train.Feature`` backed by a one-element ``BytesList``.

    ``value`` may be a Python bytes object or a tensor of the type produced by
    ``tf.constant``; in the latter case it is converted with ``.numpy()`` first.
    """
    # If value is a scalar tensor, explicitly pull out its underlying numpy value.
    tensor_type=type(tf.constant(0))
    if isinstance(value,tensor_type):
        value=value.numpy()
    bytes_list=tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap one integer in a ``tf.train.Feature`` backed by a one-element ``Int64List``."""
    int64_list=tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [4]:
# Parsing schema for the serialized examples: every label column is a scalar
# int64, while the study UID and the image are scalar byte strings.
# The label entries are derived from CFG.target_cols so the schema cannot drift
# away from the configured label list; sorting only fixes the key order.
feature_description={
    col_name:FixedLenFeature([],tf.int64) for col_name in sorted(CFG.target_cols)
}
feature_description["StudyInstanceUID"]=FixedLenFeature([],tf.string)
feature_description["image"]=FixedLenFeature([],tf.string)

def parse_examples(example):
    """Parse serialized example(s) into a dict of tensors using ``feature_description``."""
    parsed=tf.io.parse_example(serialized=example,features=feature_description)
    return parsed

In [5]:
# split a large list of files into chunks, e.g. one chunk per tfrecord
def split_list(_list,division_num):
    """Split ``_list`` into ``division_num`` consecutive, balanced chunks.

    Chunk sizes differ by at most one: the first ``len(_list) % division_num``
    chunks receive one extra element. This avoids empty trailing chunks
    whenever ``len(_list) >= division_num``.

    Args:
        _list: Any sliceable sequence that supports ``len``.
        division_num: Number of chunks to produce.

    Returns:
        A list of ``division_num`` slices of ``_list``, in the original order.
        Chunks are empty only when there are fewer elements than chunks.
        A negative ``division_num`` yields an empty list.

    Raises:
        ZeroDivisionError: If ``division_num`` is 0.
    """
    base_len,remainder=divmod(len(_list),division_num)
    splitted_lists=[]
    start_idx=0
    for i in range(division_num):
        # The first `remainder` chunks absorb the leftover elements.
        end_idx=start_idx+base_len+(1 if i<remainder else 0)
        splitted_lists.append(_list[start_idx:end_idx])
        start_idx=end_idx
    return splitted_lists

In [6]:
# glob returns matches in arbitrary order; sort so that the order in which
# files are processed is reproducible from run to run.
paths=sorted(glob("../input/ranzcr_299x299/*"))

In [7]:
%%time

train=pd.read_csv(f"{CFG.dataset_dir}train.csv")

# Label table indexed by StudyInstanceUID, built once so that each example is
# resolved with an index lookup rather than a full scan of `train`.
# drop_duplicates keeps the first row should a UID appear more than once.
_targets_by_uid=(
    train.drop_duplicates("StudyInstanceUID")
    .set_index("StudyInstanceUID")[CFG.target_cols]
)

def serialize_example(uid,image_bytes):
    """Build and serialize one ``tf.train.Example`` for a study.

    Args:
        uid: StudyInstanceUID as a ``str``; used to look up the labels in
            train.csv and stored (UTF-8 encoded) in the example.
        image_bytes: Image content as bytes, or a tensor that
            ``_bytes_feature`` can convert to bytes.

    Returns:
        The serialized example as a bytes string, containing the features
        "StudyInstanceUID", "image" and one int64 feature per entry of
        ``CFG.target_cols``.

    Raises:
        KeyError: If ``uid`` has no row in train.csv.
    """
    try:
        targets=_targets_by_uid.loc[uid]
    except KeyError:
        raise KeyError(f"StudyInstanceUID {uid!r} has no row in train.csv") from None

    feature={
        "StudyInstanceUID":_bytes_feature(uid.encode()),
        "image":_bytes_feature(image_bytes),
    }

    for col_name in CFG.target_cols:
        # Cast the numpy scalar to a plain Python int before handing it to protobuf.
        feature[col_name]=_int64_feature(int(targets[col_name]))

    example=tf.train.Example(features=tf.train.Features(feature=feature))

    return example.SerializeToString()

def write2tfrecord(recordname="../input/ranzcr_299x299.tfrec",image_paths=None):
    """Write one serialized example per image file into a single TFRecord file.

    The StudyInstanceUID of each image is taken from its file name with the
    extension removed.

    Args:
        recordname: Path of the TFRecord file to create.
        image_paths: Iterable of image file paths; defaults to the
            module-level ``paths`` list when omitted.
    """
    if image_paths is None:
        image_paths=paths

    with tf.io.TFRecordWriter(recordname) as writer:
        for path in tqdm(image_paths):
            uid=os.path.splitext(os.path.basename(path))[0]
            # Raw file contents; the image is stored as-is, without decoding.
            image_bytes=tf.io.read_file(path)

            writer.write(serialize_example(uid,image_bytes))

write2tfrecord()

  0%|          | 0/30083 [00:00<?, ?it/s]

Wall time: 3min 22s
