In [None]:
'''Main function for letter, college and spam datasets.
'''

In [1]:
# Necessary packages
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
import argparse
import numpy as np

In [3]:
from data_loader import data_loader
from gain import gain
from utils import rmse_loss

Using TensorFlow backend.


In [6]:
def main(args):
  """Load a dataset, impute its missing entries with GAIN, and report RMSE.

  Args:
    - args: mapping with the keys below
      - data_name: dataset name forwarded to `data_loader` (e.g. 'letter', 'spam')
      - miss_rate: probability of missing components
      - batch_size: batch size (forwarded to `gain`)
      - hint_rate: hint rate (forwarded to `gain`)
      - alpha: hyperparameter (forwarded to `gain`)
      - iterations: iterations (forwarded to `gain`)

  Returns:
    - imputed_data_x: imputed data, exactly as returned by `gain`
    - rmse: Root Mean Squared Error as computed by `rmse_loss`

  Raises:
    - KeyError: if any of the keys listed above is missing from `args`.
  """
  dataset = args['data_name']
  rate = args['miss_rate']

  # Only these keys are handed to the imputer; lookup order matches the key list.
  hyperparam_keys = ('batch_size', 'hint_rate', 'alpha', 'iterations')
  gain_parameters = {key: args[key] for key in hyperparam_keys}

  # Load data and introduce missingness; presumably returns
  # (original data, data with missing entries, missingness mask).
  original, with_missing, mask = data_loader(dataset, rate)

  imputed = gain(with_missing, gain_parameters)
  score = rmse_loss(original, imputed, mask)

  print()
  print('RMSE Performance: ' + str(np.round(score, 4)))

  return imputed, score

In [31]:
# Experiment configuration consumed by `main` (one entry per documented key).
args = dict(
  data_name='spam',   # dataset to load
  miss_rate=0.2,      # probability of missing components
  batch_size=128,     # batch size
  hint_rate=0.9,      # hint rate
  alpha=100,          # hyperparameter
  iterations=10000,   # iterations
)

In [32]:
# Run the full pipeline with the configuration above; unpack (imputed data, RMSE).
imputed_data, rmse = main(args=args)

100%|██████████| 10000/10000 [00:59<00:00, 169.47it/s]



RMSE Performance: 0.0542


In [None]:
# Write the imputed data to CSV for analysis and visualisation.
# The previous `imputed_data.to_csv(index=False)` never produced a file:
# - `gain` presumably returns a NumPy array, which has no `.to_csv`, so it raised AttributeError;
# - even on a pandas object, `to_csv()` without a path only returns a string.
# `np.asarray` accepts either an ndarray or a DataFrame, so this works for both.
IMPUTED_CSV_PATH = 'imputed_data.csv'
np.savetxt(IMPUTED_CSV_PATH, np.asarray(imputed_data), delimiter=',')