In [1]:
class DataGeneratorClassification(tf.keras.utils.Sequence):
    """Keras ``Sequence`` serving contiguous ``(X, Y)`` batches from an xarray dataset.

    The dataset opened from ``data_fn`` must expose the variables ``X`` and
    ``Y``; batch ``i`` is the slice ``[i * batch_size, (i + 1) * batch_size)``
    along their first axis. Only ``X`` is passed through ``input_transform``;
    ``Y`` is returned exactly as stored.

    The names ``tf``, ``np``, ``xr``, ``OneHotEncoder``, ``load_pickle``,
    ``Normalizer``, ``InputNormalizer`` and ``DictNormalizer`` must already be
    in scope when this class is defined (e.g. ``import tensorflow as tf``).

    Args:
        data_fn: Path(s) handed to ``xr.open_mfdataset``.
        input_vars: Input variable names, forwarded to ``InputNormalizer``.
        output_vars: Output variable names, forwarded to ``DictNormalizer``.
        percentile_path: Pickle file read with ``load_pickle``; the entry
            ``['Percentile'][data_name]`` is stored as ``percentile_bins``.
        data_name: Key selecting the percentile entry for this dataset.
        norm_fn: Optional normalization file opened with ``xr.open_dataset``.
            Required when ``input_transform`` is a tuple or
            ``output_transform`` is a dict.
        input_transform: ``None`` (use ``Normalizer()``), a 2-tuple whose
            items are forwarded to ``InputNormalizer``, or an already
            initialized transformer exposing ``transform``.
        output_transform: ``None`` (use ``Normalizer()``), a dict forwarded to
            ``DictNormalizer``, or an already initialized transformer.
        batch_size: Number of samples per batch.
        shuffle: Whether ``on_epoch_end`` shuffles ``self.indices``.
        xarray: Accepted for interface compatibility; currently unused because
            the h5py reload path is disabled.
        var_cut_off: Forwarded to ``InputNormalizer``.
        normalize_flag: Forwarded to ``InputNormalizer``.
        bin_size: Number of bins; the one-hot encoder is fitted on the
            integer classes ``0 .. bin_size + 1``.

    Raises:
        ValueError: If a transform specification needs the normalization
            dataset but ``norm_fn`` was not given.
    """

    def __init__(self, data_fn, input_vars, output_vars, percentile_path, data_name,
                 norm_fn=None, input_transform=None, output_transform=None,
                 batch_size=1024, shuffle=True, xarray=False, var_cut_off=None,
                 normalize_flag=True, bin_size=100):
        super().__init__()

        # Just copy over the attributes
        self.data_fn, self.norm_fn = data_fn, norm_fn
        self.input_vars, self.output_vars = input_vars, output_vars
        self.batch_size, self.shuffle = batch_size, shuffle
        self.bin_size = bin_size

        # NOTE(review): load_pickle is defined elsewhere; if it unpickles the
        # file, percentile_path must come from a trusted source.
        self.percentile_bins = load_pickle(percentile_path)['Percentile'][data_name]

        # scikit-learn renamed ``sparse`` to ``sparse_output``; try the new
        # keyword first and fall back for older releases.
        try:
            self.enc = OneHotEncoder(sparse_output=False)
        except TypeError:
            self.enc = OneHotEncoder(sparse=False)
        classes = np.arange(self.bin_size + 2)
        self.enc.fit(classes.reshape(-1, 1))

        # Open datasets. norm_ds is always defined so later checks are explicit.
        self.data_ds = xr.open_mfdataset(data_fn)
        self.norm_ds = xr.open_dataset(norm_fn) if norm_fn is not None else None

        # Compute number of samples and batches; a trailing partial batch is
        # dropped by the floor division.
        self.n_samples = self.data_ds.X.shape[0]
        self.n_batches = self.n_samples // self.batch_size

        self.n_inputs, self.n_outputs = 64, 64

        # Initialize input and output normalizers/transformers
        if input_transform is None:
            self.input_transform = Normalizer()
        elif isinstance(input_transform, tuple):
            if self.norm_ds is None:
                raise ValueError(
                    "norm_fn is required when input_transform is given as a tuple")
            self.input_transform = InputNormalizer(
                self.norm_ds, normalize_flag, input_vars,
                input_transform[0], input_transform[1], var_cut_off)
        else:
            # Assume an initialized normalizer is passed
            self.input_transform = input_transform

        if output_transform is None:
            self.output_transform = Normalizer()
        elif isinstance(output_transform, dict):
            if self.norm_ds is None:
                raise ValueError(
                    "norm_fn is required when output_transform is given as a dict")
            self.output_transform = DictNormalizer(
                self.norm_ds, output_vars, output_transform)
        else:
            # Assume an initialized normalizer is passed
            self.output_transform = output_transform

        # Make ``indices`` available before the first epoch ends; it is left
        # unshuffled here so construction does not consume random state.
        self.indices = np.arange(self.n_batches)

    def __len__(self):
        """Return the number of full batches per epoch."""
        return self.n_batches

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, Y)``.

        ``X`` is the transformed input slice. ``Y`` is the raw slice of the
        ``Y`` variable: neither ``output_transform`` nor the one-hot encoder
        is applied here. Batches are addressed by position in the dataset;
        ``self.indices`` is not consulted.
        """
        # Compute start and end indices for batch
        start_idx = index * self.batch_size
        end_idx = start_idx + self.batch_size

        # Grab batch from data
        batch_X = self.data_ds['X'][start_idx:end_idx].values
        Y = self.data_ds['Y'][start_idx:end_idx].values

        # Normalize inputs only
        X = self.input_transform.transform(batch_X)
        return X, Y

    def on_epoch_end(self):
        """Rebuild ``self.indices`` and shuffle it in place when ``shuffle`` is set."""
        self.indices = np.arange(self.n_batches)
        if self.shuffle:
            np.random.shuffle(self.indices)

NameError: name 'tf' is not defined