diff --git a/merlin/dataloader/tensorflow.py b/merlin/dataloader/tensorflow.py index 6c258a48..30626891 100644 --- a/merlin/dataloader/tensorflow.py +++ b/merlin/dataloader/tensorflow.py @@ -17,7 +17,7 @@ from merlin.core.compat import tensorflow as tf from merlin.dataloader.array import ArrayLoader -from merlin.table import Device, NumpyColumn, TensorColumn, TensorflowColumn, TensorTable +from merlin.table import TensorColumn, TensorflowColumn, TensorTable from merlin.table.conversions import _dispatch_dlpack_fns, convert_col @@ -104,7 +104,6 @@ def convert_batch(self, batch): tf_inputs = {} if inputs is not None: inputs_table = self.create_table(inputs) - column_type = TensorflowColumn if Device.GPU == inputs_table.device else NumpyColumn for col_name, col in inputs_table.items(): tf_inputs[col_name] = self.convert_col(col, column_type)