In [1]:
import argparse
import pickle
import os

import numpy as np
import pandas as pd

In [2]:
def create_parser():
    """Build the command-line parser for the DNN prediction step.

    Four string options are registered, all of them required:
    ``--fingerprints``, ``--dnn``, ``--preproc`` and ``--save_to``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    # (flag, help text) pairs, in the order they are registered.
    option_help = (
        ('--fingerprints', 'CSV with the fingerprints to predict'),
        ('--dnn', 'pickled dnn'),
        ('--preproc', 'pickled preproc'),
        ('--save_to', 'save resulting dataframe to this CSV file'),
    )
    cli = argparse.ArgumentParser(description="predict with DNN")
    for flag, text in option_help:
        cli.add_argument(flag, type=str, required=True, help=text)
    return cli

In [3]:
parser = create_parser()

# The notebook has no real command line, so the argument list is spelled out
# here and handed to parse_args() directly.
# Alternative fingerprint input that was previously used:
#   'data/CMM_vectorfingerprints.csv'
model_dir = os.path.join('data', 'saved_models', 'v0')
args = parser.parse_args([
    '--fingerprints', 'filtered_data/SMRT_fingerprints_filtered_updated.csv',
    '--preproc', os.path.join(model_dir, 'preprocessor.pkl'),
    '--dnn', os.path.join(model_dir, 'dnn.pkl'),
    '--save_to', 'res1',
])

In [4]:
# Read the fingerprint table; `pid` is parsed as str rather than a number.
fingerprints_df = pd.read_csv(args.fingerprints, dtype={'pid': str})

# Keep the two identifier columns aside so they can be re-attached to the
# predictions later.
pid = fingerprints_df.pid.values
cmm_id = fingerprints_df.CMM_id.values

# Every remaining column (original order preserved) forms the feature matrix,
# converted to float32.
is_feature = ~fingerprints_df.columns.isin(['pid', 'CMM_id'])
fingerprints = fingerprints_df.loc[:, is_feature].values.astype(np.float32)

In [5]:
# Load the preprocessor object from the pickle file passed as --preproc.
# NOTE(review): pickle.load can execute arbitrary code while deserializing,
# so this file must come from a trusted source.
with open(args.preproc, 'rb') as f:
    preprocessor = pickle.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Overwrite the loaded preprocessor's `fgp_cols` with every column index of the
# fingerprint matrix (0 .. n_columns - 1).
# NOTE(review): presumably `fgp_cols` tells transform() which columns hold
# fingerprint values — confirm against the preprocessor class.
preprocessor.fgp_cols = np.arange(fingerprints.shape[1])

In [7]:
# Load the model object from the pickle file passed as --dnn.
# NOTE(review): pickle.load can execute arbitrary code while deserializing,
# so this file must come from a trusted source.
with open(args.dnn, 'rb') as f:
    dnn = pickle.load(f)

In [8]:
# Run the loaded preprocessor over the float32 fingerprint matrix; the result
# is the input handed to the model.
X_preprocessed = preprocessor.transform(fingerprints)

In [9]:
# Feed the preprocessed matrix to the loaded model; the output is stored as the
# 'prediction' column of the results table.
predictions = dnn.predict(X_preprocessed)

In [10]:
# Put the identifiers and the model output side by side, one row per input
# fingerprint, in the column order pid, cmm_id, prediction.
results = pd.DataFrame(
    dict(pid=pid, cmm_id=cmm_id, prediction=predictions)
)

In [11]:
# Preview the first rows of the results table as a quick sanity check.
results.head()

Unnamed: 0,pid,cmm_id,prediction
0,5139,0,91.345947
1,3505,1,659.763428
2,2159,2,597.697754
3,1340,3,585.665283
4,3344,4,578.020996


In [12]:
# Write the predictions to the CSV path supplied through the required
# --save_to argument; previously that argument was parsed but ignored in favour
# of a hard-coded path.
output_path = args.save_to
output_dir = os.path.dirname(output_path)
if output_dir:
    # to_csv does not create missing parent directories, so make sure the
    # target directory exists first (skipped for bare file names).
    os.makedirs(output_dir, exist_ok=True)
# The DataFrame index is still written as the first, unnamed column.
results.to_csv(output_path)