In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import pickle
import os

In [6]:
class TFRecordConverter():
  """Convert selected columns of a tabular file into a TFRecord file.

  Each selected column is stored whole as one ``tf.train.Feature`` (a float,
  int64 or bytes list), and all features are packed into a single
  ``tf.train.Example``. The kind of every column ('float', 'int' or 'str')
  is pickled next to the TFRecord file, in the order of ``features``.
  """

  def __init__(self, features, file_name, file_format, csv_sep=None):
    """
    Initialize the object for TFRecordConverter
    Arguments:
    features - list of columns to use from file
    file_name - the path to file to convert
    file_format - format of the file (only 'csv' is loaded)
    csv_sep - if csv file, separator to be used
    """
    self.feat = features
    self.filename = file_name
    self.fileformat = file_format
    # Always define ``data`` so that an unsupported format produces a clear
    # error later instead of an AttributeError on a missing attribute.
    self.data = None
    if file_format == 'csv':
      self.data = pd.read_csv(file_name, sep=csv_sep, usecols=features)

  def _output_path(self, extension):
    """Return the input path with its file extension replaced by `extension`."""
    # splitext only touches the trailing extension; splitting on the format
    # string would also cut at 'csv' appearing elsewhere in the path.
    base, _ = os.path.splitext(self.filename.strip())
    return base + '.' + extension

  def _column_kind(self, column):
    """Return 'float', 'int' or 'str' for a loaded column, based on its dtype.

    Raises:
    TypeError - if the column dtype is none of the supported kinds
    """
    series = self.data[column]
    if pd.api.types.is_float_dtype(series):
      return 'float'
    if pd.api.types.is_integer_dtype(series):
      return 'int'
    if pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series):
      return 'str'
    raise TypeError(
        "Column %r has unsupported dtype %s" % (column, series.dtype))

  def create_feature_lists(self):
    """
    Creates the dictionary of features that can be used by tf.train.Example

    Populates self.features (column name -> tf.train.Feature) and
    self.types (kind of each column, in the order of the requested features).

    Raises:
    ValueError - if no data was loaded (file format other than 'csv')
    TypeError - if a column has an unsupported dtype
    """
    if self.data is None:
      raise ValueError(
          "No data loaded: unsupported file format %r" % (self.fileformat,))
    self.features = {}
    self.types = []
    for name in self.feat:
      kind = self._column_kind(name)
      column = self.data[name]
      # tolist() yields native Python scalars, which the proto lists accept.
      if kind == 'float':
        feature = tf.train.Feature(
            float_list=tf.train.FloatList(value=column.tolist()))
      elif kind == 'int':
        feature = tf.train.Feature(
            int64_list=tf.train.Int64List(value=column.tolist()))
      else:
        encoded = column.str.encode('utf-8').tolist()
        feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))
      self.features[name] = feature
      self.types.append(kind)

  def create_example_and_write(self):
    """
    Creates tf.train.Example object and writes it to disk

    Writes the serialized example to '<input path without extension>.tfrecord'
    and pickles the list of column kinds to the matching '.pkl' file.
    """
    self.create_feature_lists()
    # Protobuf messages only accept keyword arguments.
    example = tf.train.Example(
        features=tf.train.Features(feature=self.features))
    # tf.python_io only exists in TF 1.x; prefer tf.io when it is available.
    try:
      writer_cls = tf.io.TFRecordWriter
    except AttributeError:
      writer_cls = tf.python_io.TFRecordWriter
    with writer_cls(self._output_path('tfrecord')) as writer:
      writer.write(example.SerializeToString())
    with open(self._output_path('pkl'), 'wb') as handle:
      pickle.dump(self.types, handle)