diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py index 64440fca..6f080635 100644 --- a/deeprank/learn/NeuralNet.py +++ b/deeprank/learn/NeuralNet.py @@ -993,7 +993,7 @@ def _plot_boxplot_class(self,figname): for pts,t in zip(out,tar): r = F.softmax(torch.FloatTensor(pts), dim=0).data.numpy() data[t].append(r[1]) - confusion[t][r[1]>0.5] += 1 + confusion[t][bool(r[1]>0.5)] += 1 #print(" {:5s}: {:s}".format(l,str(confusion))) diff --git a/deeprank/learn/rankingMetrics.py b/deeprank/learn/rankingMetrics.py index a60bf891..d783a76a 100644 --- a/deeprank/learn/rankingMetrics.py +++ b/deeprank/learn/rankingMetrics.py @@ -16,8 +16,8 @@ def hitrate(rs): Example: - >>> r = [0,1,1] - >>> hit_rate(r,nr) + >>> rs = [0,1,1] + >>> hitrate(r) Attributes: @@ -27,14 +27,34 @@ def hitrate(rs): Returns: hirate (array): [recall@1,recall@2,...] """ - nr = np.max((1,np.sum(rs))) + nr = np.max((1, np.sum(rs))) return np.cumsum(rs) / nr +def success(rs): + """Success for positions ≤ k. + + Example: + >>> rs = [0, 0, 1, 0, 1, 0] + >>> success(rs) + [0, 0, 1, 1, 1, 1] + + Args: + rs (array): binary relevance array + + Returns: + success (array): [success@≤1, success@≤2,...] + """ + success = np.cumsum(rs) > 0 + + return success.astype(np.int) + + def avprec(rs): - return [average_precision(rs[:i]) for i in range(1,len(rs))] + return [average_precision(rs[:i]) for i in range(1, len(rs))] + -def recall(rs,nr): +def recall(rs, nr): """recall rate First element is rank 1, Relevance is binray @@ -56,6 +76,7 @@ def recall(rs,nr): return np.sum(rs)/nr + def mean_reciprocal_rank(rs): """Score is reciprocal of the rank of the first relevant item @@ -272,4 +293,4 @@ def ndcg_at_k(r, k, method=0): dcg_max = dcg_at_k(sorted(r, reverse=True), k, method) if not dcg_max: return 0. - return dcg_at_k(r, k, method) / dcg_max \ No newline at end of file + return dcg_at_k(r, k, method) / dcg_max diff --git a/deeprank/utils/cal_hitrate_successrate.py b/deeprank/utils/cal_hitrate_successrate.py new file mode 100644 index 00000000..74bd9683 --- /dev/null +++ b/deeprank/utils/cal_hitrate_successrate.py @@ -0,0 +1,169 @@ +import numpy as np +import pandas as pd +from deeprank.learn import rankingMetrics + + +def evaluate(data): + ''' + Calculate success rate and hit rate. + + + data: a data frame. + + label caseID modelID target DR HS + Test 1AVX 1AVX_ranair-it0_5286 0 0.503823 6.980802 + Test 1AVX 1AVX_ti5-itw_354w 1 0.502845 -95.158100 + Test 1AVX 1AVX_ranair-it0_6223 0 0.511688 -11.961460 + + + out_df: a data frame. + success: binary variable, indicating whether this case is a success when evaluating its top N models. + + out_df : + label caseID success_DR hitRate_DR success_HS hitRate_HS + train 1ZHI 1 0.1 0 0.01 + train 1ZHI 1 0.2 1 0.3 + + where success =[0, 0, 1, 1, 1,...]: starting from rank 3 this case is a success + + ''' + + out_df = pd.DataFrame() + labels = data.label.unique() # ['train', 'test', 'valid'] + + for l in labels: + # l = 'train', 'test' or 'valid' + + out_df_tmp = pd.DataFrame() + + df = data.loc[data.label == l].copy() + methods = df.columns + methods = methods[4:] + df_grped = df.groupby('caseID') + + for M in methods: + # df_sorted = df_one_case.apply(pd.DataFrame.sort_values, by= M, ascending=True) + + success = [] + hitrate = [] + caseIDs = [] + for caseID, df_one_case in df_grped: + df_sorted = df_one_case.sort_values(by=M, ascending=True) + hitrate.extend(rankingMetrics.hitrate( + df_sorted['target'].astype(np.int))) + success.extend(rankingMetrics.success( + df_sorted['target'].astype(np.int))) + caseIDs.extend([caseID] * len(df_one_case)) + + # hitrate = df_sorted['target'].apply(rankingMetrics.hitrate) # df_sorted['target']: class IDs for each model + # success = hitrate.apply(rankingMetrics.success) # success =[0, 0, 1, 1, 1,...]: starting from rank 3 this case is a success + + out_df_tmp['label'] = [l] * len(df) # train, valid or test + out_df_tmp['caseID'] = caseIDs + out_df_tmp[f'success_{M}'] = success + out_df_tmp[f'hitRate_{M}'] = hitrate + + out_df = pd.concat([out_df, out_df_tmp]) + + return out_df + + +def ave_evaluate(data): + ''' + Calculate the average of each column over all cases. + + INPUT: + data = + label caseID success_HS hitRate_HS success_DR hitRate_DR + + train 1AVX 0.0 0.0 0.0 0.0 + train 1AVX 1.0 1.0 1.0 1.0 + + train 2ACB 0.0 0.0 0.0 0.0 + train 2ACB 1.0 1.0 1.0 1.0 + + test 7CEI 0.0 0.0 0.0 0.0 + test 7CEI 1.0 1.0 1.0 1.0 + + test 5ACD 0.0 0.0 0.0 0.0 + test 5ACD 1.0 1.0 1.0 1.0 + + OUTPUT: + new_data = + label caseID success_HS hitRate_HS success_DR hitRate_DR + + train 1AVX 0.0 0.0 0.0 0.0 + train 1AVX 1.0 1.0 1.0 1.0 + + train 2ACB 0.0 0.0 0.0 0.0 + train 2ACB 1.0 1.0 1.0 1.0 + + test 7CEI 0.0 0.0 0.0 0.0 + test 7CEI 1.0 1.0 1.0 1.0 + + test 5ACD 0.0 0.0 0.0 0.0 + test 5ACD 1.0 1.0 1.0 1.0 + + ''' + + new_data = pd.DataFrame() + for l, perf_per_case in data.groupby('label'): + # l = 'train', 'test' or 'valid' + + # count the model number for each case + grouped = perf_per_case.groupby('caseID') + num_models = grouped.apply(len) + num_cases = len(grouped) + + # -- + top_N = min(num_models) + perf_ave = pd.DataFrame() + perf_ave['label'] = [l] * top_N + + for col in perf_per_case.columns[2:]: + # perf_per_case.columns = ['label', 'caseID', 'success_HS', 'hitRate_HS', 'success_DR', 'hitRate_DR'] + perf_ave[col] = np.zeros(top_N) + + for _, perf_case in grouped: + perf_ave[col] = perf_ave[col][0:top_N] + \ + np.array(perf_case[col][0:top_N]) + + perf_ave[col] = perf_ave[col]/num_cases + + new_data = pd.concat([new_data, perf_ave]) + + return new_data + + +def add_rank(df): + ''' + INPUT (a data frame): + label success_DR hitRate_DR success_HS hitRate_HS + Test 0.0 0.000000 0.0 0.000000 + Test 0.0 0.000000 1.0 0.012821 + + Train 0.0 0.000000 1.0 0.012821 + Train 0.0 0.000000 1.0 0.025641 + + OUTPUT: + label success_DR hitRate_DR success_HS hitRate_HS rank + Test 0.0 0.000000 0.0 0.000000 0.000949 + Test 0.0 0.000000 1.0 0.012821 0.001898 + + Train 0.0 0.000000 1.0 0.012821 0.002846 + Train 0.0 0.000000 1.0 0.025641 0.003795 + + ''' + + # -- add the 'rank' column to df + rank = [] + for _, df_per_label in df.groupby('label'): + num_mol = len(df_per_label) + rank_raw = np.array(range(num_mol)) + 1 + rank.extend(rank_raw/num_mol) + df['rank'] = rank + + df['label'] = pd.Categorical(df['label'], categories=[ + 'Train', 'Valid', 'Test']) + + return df diff --git a/deeprank/utils/get_h5subset.py b/deeprank/utils/get_h5subset.py new file mode 100755 index 00000000..be6b4a51 --- /dev/null +++ b/deeprank/utils/get_h5subset.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +""" +Extract first N groups of a hdf5 to a new hdf5 file. + +Usage: python {0} +Example: python {0} ./001_1GPW.hdf5 ./001_1GPW_sub10.hdf5 10 +""" +import sys +import h5py + +USAGE = __doc__.format(__file__) + + +def check_input(args): + if len(args) != 3: + sys.stderr.write(USAGE) + sys.exit(1) + + +def get_h5subset(fin, fout, n): + """Extract first number of groups and write to a new hdf5 file. + + Args: + fin (hdf5): input hdf5 file. + fout (hdf5): output hdf5 file. + n (int): first n groups to write. + """ + n = int(n) + h5 = h5py.File(fin, "r") + h5out = h5py.File(fout, "w") + print(f"First {n} groups in {fin}:") + for i in list(h5)[0:n]: + print(i) + h5.copy(h5[i], h5out) + + print() + print(f"Groups in {fout}:") + print(list(h5out)) + h5.close() + h5out.close() + print() + print(f"{fout} generated.") + + +if __name__ == "__main__": + check_input(sys.argv[1:]) + fin, fout, n = sys.argv[1:] + get_h5subset(fin, fout, n) diff --git a/deeprank/utils/plot_hitrate.py b/deeprank/utils/plot_hitrate.py deleted file mode 100644 index ccc4a3f5..00000000 --- a/deeprank/utils/plot_hitrate.py +++ /dev/null @@ -1,207 +0,0 @@ -import os -from deeprank.learn import rankingMetrics -import matplotlib.pyplot as plt -import matplotlib.ticker as mtick -import numpy as np -import h5py -import sys -import torch -import torch.nn as nn -import torch.nn.functional as F -import pandas as pd -from ggplot import * - -def plot_boxplot_todo(hdf5,epoch=None,figname=None,inverse = False): - - ''' - Plot a boxplot of predictions VS targets useful ' - to visualize the performance of the training algorithm - This is only usefull in classification tasks - - Args: - figname (str): filename - - ''' - - print('\n --> Box Plot : ', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - h5data = h5[last_epoch_key] - - print(f"Generate boxplot for {last_epoch_key} epoch ...") - - - n_panels = len(labels) - data = pd.DataFrame() - - for l in labels: - - if l in h5data: - - tar = h5data[l]['targets'] - raw_out = h5data[l]['outputs'] - - num_hits = list(tar.value).count(1) - total_num=len(tar) - print(f"According to 'targets' -> num of hits for {l}: {num_hits} out of {len(tar.value)}") - m = nn.Softmax(dim = 0) - final_out = np.array(m(torch.FloatTensor(raw_out))) - data_df = pd.DataFrame(list(zip([l]*total_num,raw_out,tar,final_out[:,1])), columns = ['label', 'raw_out', 'target', 'prediction']) - data = pd.concat([data, data_df] ) - print(data) - p= ggplot(aes(x = "target", y = "prediction"), data=data) +geom_boxplot() + facet_grid(None, "label") - p.save(figname) - - - - -def plot_boxplot(hdf5,epoch=None,figname=None,inverse = False): - - ''' - Plot a boxplot of predictions VS targets useful ' - to visualize the performance of the training algorithm - This is only usefull in classification tasks - - Args: - figname (str): filename - - ''' - - print('\n --> Box Plot : ', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - h5data = h5[last_epoch_key] - - print(f"Generate boxplot for {last_epoch_key} epoch ...") - - - nwin = len(h5data) - - fig, ax = plt.subplots(1, nwin, sharey=True, squeeze=False) - - iwin = 0 - for l in labels: - - if l in h5data: - - tar = h5data[l]['targets'].value - out = h5data[l]['outputs'].value - - num_hits = list(tar).count(1) - print(f"According to 'targets' -> num of hits for {l}: {num_hits} out of {len(tar)}") - - data = [[], []] - for pts,t in zip(out,tar): - r = F.softmax(torch.FloatTensor(pts), dim=0).data.numpy() - #print(f"prediction: {pts}; target: {t}; r: {r}") - data[t].append(r[1]) - - ax[0, iwin].boxplot(data) - ax[0, iwin].set_xlabel(l) - ax[0, iwin].set_xticklabels(['0', '1']) - iwin += 1 - - fig.savefig(figname, bbox_inches='tight') - plt.close() - - - - -def plot_hit_rate(hdf5,epoch=None,figname=None,inverse = False): - '''Plot the hit rate of the different training/valid/test sets - - The hit rate is defined as: - the percentage of positive decoys that are included among the top m decoys. - a positive decoy is a native-like one with a i-rmsd <= 4A - - Args: - figname (str): filename for the plot - irmsd_thr (float, optional): threshold for 'good' models - - ''' - - - print('\n --> Hit Rate :', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - data = h5[last_epoch_key] - - print(f"Generate hit rate plot for {last_epoch_key} epoch ...") - - # plot - fig,ax = plt.subplots() - for l in labels: - # l = train, valid or test - if l in data: - if 'hit' in data[l]: - - #-- count num_hit - hit_labels = data[l]['hit'].value # hit labels for each model: [0 1 0 0 1...] - num_hits = list(hit_labels).count(1) - print(f"According to 'hit' -> num of hits for {l}: {num_hits} out of {len(hit_labels)}") - - #-- calculate and plot hit rate - hitrate = rankingMetrics.hitrate(data[l]['hit']) - m = len(hitrate) - x = np.linspace(0,100,m) - plt.plot(x,hitrate,c = color_plot[l],label=l+' M=%d' %m) - legend = ax.legend(loc='upper left') - ax.set_xlabel('Top M (%)') - ax.set_ylabel('Hit Rate') - - fmt = '%.0f%%' - xticks = mtick.FormatStrFormatter(fmt) - ax.xaxis.set_major_formatter(xticks) - - fig.savefig(figname) - plt.close() - -if __name__ == '__main__': - - if len(sys.argv) !=4: - print(f"Usage: {sys.argv[0]} epoch_data.hdf5 epoch fig_name") - sys.exit() - hdf5 = sys.argv[1] #'epoch_data.hdf5' - epoch = int(sys.argv[2]) # 9 - figname = sys.argv[3] - plot_hit_rate(hdf5,epoch=epoch,figname=figname + '.hitrate.png',inverse = False) - plot_boxplot(hdf5,epoch=None,figname=figname + '.boxplot.png',inverse = False) diff --git a/deeprank/utils/plot_hitrate_boxplot.py b/deeprank/utils/plot_hitrate_boxplot.py deleted file mode 100755 index 9fef1580..00000000 --- a/deeprank/utils/plot_hitrate_boxplot.py +++ /dev/null @@ -1,403 +0,0 @@ -# 1. plot prediction scores for class 0 and 1 using two-panel box plots -# 2. plot hit rate plot -import os -from deeprank.learn import rankingMetrics -import matplotlib.pyplot as plt -import matplotlib.ticker as mtick -import numpy as np -import h5py -import sys -import torch -import torch.nn as nn -import torch.nn.functional as F -import pandas as pd -from ggplot import * -import glob -import re -import pdb -from tqdm import tqdm - -def plot_boxplot_todo(hdf5,epoch=None,figname=None,inverse = False): - - ''' - Plot a boxplot of predictions VS targets useful ' - to visualize the performance of the training algorithm - This is only usefull in classification tasks - - Args: - figname (str): filename - - ''' - - print('\n --> Box Plot : ', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - h5data = h5[last_epoch_key] - - print(f"Generate boxplot for {last_epoch_key} epoch ...") - - - n_panels = len(labels) - data = pd.DataFrame() - - for l in labels: - - if l in h5data: - - tar = h5data[l]['targets'] - raw_out = h5data[l]['outputs'] - - num_hits = list(tar.value).count(1) - total_num=len(tar) - print(f"According to 'targets' -> num of hits for {l}: {num_hits} out of {len(tar.value)}") - m = nn.Softmax(dim = 0) - final_out = np.array(m(torch.FloatTensor(raw_out))) - data_df = pd.DataFrame(list(zip([l]*total_num,raw_out,tar,final_out[:,1])), columns = ['label', 'raw_out', 'target', 'prediction']) - data = pd.concat([data, data_df] ) - print(data) - p= ggplot(aes(x = "target", y = "prediction"), data=data) +geom_boxplot() + facet_grid(None, "label") - p.save(figname) - -def sort_modelIDs_by_deeprank(modelIDs, deeprank_score): - out = F.softmax(torch.FloatTensor(deeprank_score), dim=1).data.numpy()[:,1] -# modelIDs_sorted = [y for x, y in sorted(zip(out,modelIDs))] -# modelIDs_sorted = modelIDs_sorted[::-1] #reverse the list - - xue = pd.DataFrame(list(zip(modelIDs, out)), columns = ['modelID', 'final_S']) - xue_sorted = xue.sort_values(by='final_S', ascending=False) - modelIDs_sorted = list(xue_sorted['modelID']) - return modelIDs_sorted - - - - - -def plot_boxplot(hdf5,epoch=None,figname=None,inverse = False): - - ''' - Plot a boxplot of predictions VS targets useful ' - to visualize the performance of the training algorithm - This is only usefull in classification tasks - - Args: - figname (str): filename - - ''' - - print('\n --> Box Plot : ', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - h5data = h5[last_epoch_key] - - print(f"Generate boxplot for {last_epoch_key} epoch ...") - - - nwin = len(h5data) - - fig, ax = plt.subplots(1, nwin, sharey=True, squeeze=False) - - iwin = 0 - for l in labels: - - if l in h5data: - - tar = h5data[l]['targets'].value - out = h5data[l]['outputs'].value - - num_hits = list(tar).count(1) - print(f"According to 'targets' -> num of hits for {l}: {num_hits} out of {len(tar)}") - - data = [[], []] - for pts,t in zip(out,tar): - r = F.softmax(torch.FloatTensor(pts), dim=0).data.numpy() - #print(f"prediction: {pts}; target: {t}; r: {r}") - data[t].append(r[1]) - - ax[0, iwin].boxplot(data) - ax[0, iwin].set_xlabel(l) - ax[0, iwin].set_xticklabels(['0', '1']) - iwin += 1 - - fig.savefig(figname, bbox_inches='tight') - plt.close() - - -def plot_hit_rate_withHS(hdf5,HS_DIR=None, epoch=None,figname=None,inverse = False): - '''Plot the hit rate of the different training/valid/test sets with HS (haddock scores) - - The hit rate is defined as: - the percentage of positive decoys that are included among the top m decoys. - a positive decoy is a native-like one with a i-rmsd <= 4A - - Args: - HS_DIR (str): the directory where HS files are stored - figname (str): filename for the plot - irmsd_thr (float, optional): threshold for 'good' models - - ''' - - - print('\n --> Hit Rate :', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green', 'HS-train': 'red', 'HS-valid':'blue','HS-test': 'green'} - line_styles = {'train':'-', 'valid':'-', 'test':'-', 'HS-train':'--', 'HS-valid':'--', 'HS-test':'--'} - labels = ['train','valid','test', 'HS'] - - #-- read haddock data - stats = read_haddockScoreFL(HS_DIR) - haddockS = stats['haddock-score']# haddockS[modelID] = score - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - data = h5[last_epoch_key] - - print(f"Generate hit rate plot for {last_epoch_key} epoch ...") - - # plot - fig,ax = plt.subplots() - for l in labels: - # l = train, valid or test - if l in data: - if 'hit' in data[l]: - - #-- calculate and plot hit rate)for haddock -# pdb.set_trace() - hit_labels_deeprank = data[l]['hit'].value # np.ndarray, hit labels for each model: [0 1 0 0 1...] - modelIDs_deeprank = sort_modelIDs_by_deeprank(list(data[l]['mol'][:,1]), data[l]['outputs']) # np.ndarry, models IDs ranked by deeprank - -# xue = pd.DataFrame(list(zip(modelIDs_deeprank, hit_labels_deeprank)), columns=['modelIDs_DR', 'hit_labels_DR']) -# xue.to_csv('xue.tsv', sep="\t") - [hit_labels_HS, modelIDs_woHS] = get_hit_labels_HS(haddockS, hit_labels_deeprank, modelIDs_deeprank) - hitrate_HS = rankingMetrics.hitrate(hit_labels_HS) - - m = len(hitrate_HS) - x = np.linspace(0,100,m) - legend='HS-' + l - print(f"legend:{legend}") - ax.plot(x,hitrate_HS,color = color_plot[legend], linestyle = line_styles[legend], label=f'{legend}'+' M=%d' %m) -# pdb.set_trace() - - #-- remove refe pdb from hit rate calcuatioin as we do not have HS for refe - print(f"Models w/o haddock scores: {modelIDs_woHS}.") - print(f"Now remove them from calculating hit rate!") - indices_to_remove = [ modelIDs_deeprank.index(x) for x in modelIDs_woHS] - hit_labels_deeprank = np.delete(hit_labels_deeprank, indices_to_remove) - - #-- calculate and plot hit rate for deeprank -# hitrate_deeprank = rankingMetrics.hitrate(data[l]['hit']) - hitrate_deeprank = rankingMetrics.hitrate(hit_labels_deeprank) - m = len(hitrate_deeprank) - x = np.linspace(0,100,m) - ax.plot(x,hitrate_deeprank,c = color_plot[l],linestyle = line_styles[l], label=l+' M=%d' %m) - - #-- count num_hit - num_hits = list(hit_labels_deeprank).count(1) - print(f"According to 'hit' -> num of hits for {l}: {num_hits} out of {len(hit_labels_deeprank)}") - - - #-- write to csv file - #pdb.set_trace() - - - - - legend = ax.legend(loc='upper left') - ax.set_xlabel('Top M (%)') - ax.set_ylabel('Hit Rate') - - fmt = '%.0f%%' - xticks = mtick.FormatStrFormatter(fmt) - ax.xaxis.set_major_formatter(xticks) - - fig.savefig(figname) - plt.close() - - -def get_hit_labels_HS(haddockS, hit_labels_deeprank, modelIDs_deeprank): - # reorder hit_labels_deeprank based on haddock scores - - HS=[] - modelIDs_woHS = [] - modelIDs_HS = [] - - for modelID in modelIDs_deeprank: - if modelID in haddockS: - HS.append(haddockS[modelID]) - modelIDs_HS.append(modelID) - else: - modelIDs_woHS.append(modelID) - - - #-- remove refe pdb from hit rate calcuatioin as we do not have HS for refe - print(f"Models w/o haddock scores: {modelIDs_woHS}.") - print(f"Now remove them from calculating hit rate!") - indices_to_remove = [ modelIDs_deeprank.index(x) for x in modelIDs_woHS] - hit_labels_deeprank = np.delete(hit_labels_deeprank, indices_to_remove) - - data = pd.DataFrame(list(zip(modelIDs_HS, HS, hit_labels_deeprank)), columns = ['modelID', 'HS', 'hit_labels']) - data_sorted = data.sort_values(by='HS', ascending = True) - hit_labels_HS = data_sorted['hit_labels'] - - return hit_labels_HS, modelIDs_woHS - - - -def read_haddockScoreFL(HS_DIR): - ''' - input: str. /home/lixue/DBs/BM5-haddock24/stats - output: dict. stats['haddock-score'][modelID] = score - - stat file format: - - #struc haddock-score i-RMSD Einter Enb Evdw+0.1Eelec Evdw Eelec Eair Ecdih Ecoup Esani Evean Edani #NOEviol #Dihedviol #Jviol #Saniviol #veanviol #Daniviol bsa dH Edesolv - 1A2K_cm-itw_31w.pdb -124.227921 18.868 -322.996 -323.353 -98.3385 -73.3369 -250.016 0.356572 0 0 0 0 0 0 0 0 0 0 0 2094.49 -30.0497 -0.887821 - 1A2K_cm-itw_187w.pdb -123.982600 18.968 -383.472 -384.327 -76.94 -42.7859 -341.541 0.855228 0 0 0 0 0 0 0 0 0 0 0 1671.79 -71.7494 -12.8885 - - ''' - - stat_FLs = glob.glob(f"{HS_DIR}/*.stats") - stats = {} - - stat_FLs = tqdm(stat_FLs, desc='read stat files for haddock scores', disable = False) - for statFL in stat_FLs: - - f = open(statFL,'r') - for line in f: - - line = line.rstrip() - line = line.strip() - - if re.search('^#',line): - headers = re.split('\s+',line) - headers = headers[1:] - continue - values = re.split('\s+',line) - modelID = re.sub('.pdb','', values.pop(0)) - - if len(headers) != len(values): - sys.exit(f'header field number {len(headers)} is different from the value field number {len(values)}. Check the format of {statFL}') - - for idx, h in enumerate(headers): - if h not in stats: - stats[h]={} - - stats[h][modelID] = float(values[idx]) - - if not stats or not headers or not values: - sys.exit(f"headers or values or stats not defined. Check the format of {statFL}") - f.close() - - return stats - - -def plot_hit_rate(hdf5,epoch=None,figname=None,inverse = False): - '''Plot the hit rate of the different training/valid/test sets - - The hit rate is defined as: - the percentage of positive decoys that are included among the top m decoys. - a positive decoy is a native-like one with a i-rmsd <= 4A - - Args: - figname (str): filename for the plot - irmsd_thr (float, optional): threshold for 'good' models - - ''' - - - print('\n --> Hit Rate :', figname, '\n') - - color_plot = {'train':'red','valid':'blue','test':'green'} - labels = ['train','valid','test'] - - #-- read data - h5 = h5py.File(hdf5,'r') - if epoch is None: - keys = list(h5.keys()) - last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] - else: - last_epoch_key = 'epoch_%04d' %epoch - if last_epoch_key not in h5: - print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) - h5.close() - return - data = h5[last_epoch_key] - - print(f"Generate hit rate plot for {last_epoch_key} epoch ...") - - # plot - fig,ax = plt.subplots() - for l in labels: - # l = train, valid or test - if l in data: - if 'hit' in data[l]: - - #-- count num_hit - hit_labels = data[l]['hit'].value # hit labels for each model: [0 1 0 0 1...] - num_hits = list(hit_labels).count(1) - print(f"According to 'hit' -> num of hits for {l}: {num_hits} out of {len(hit_labels)}") - - #-- calculate and plot hit rate - hitrate = rankingMetrics.hitrate(data[l]['hit']) - m = len(hitrate) - x = np.linspace(0,100,m) - plt.plot(x,hitrate,c = color_plot[l],label=l+' M=%d' %m) - legend = ax.legend(loc='upper left') - ax.set_xlabel('Top M (%)') - ax.set_ylabel('Hit Rate') - - fmt = '%.0f%%' - xticks = mtick.FormatStrFormatter(fmt) - ax.xaxis.set_major_formatter(xticks) - - fig.savefig(figname) - plt.close() - -def main(): - if len(sys.argv) !=4: - print(f"Usage: {python sys.argv[0]} epoch_data.hdf5 epoch fig_name") - sys.exit() - hdf5 = sys.argv[1] #'epoch_data.hdf5' - epoch = int(sys.argv[2]) # 9 - figname = sys.argv[3] - plot_hit_rate_withHS(hdf5,HS_DIR='/home/lixue/DBs/BM5-haddock24/stats', epoch=epoch,figname=figname + '.hitrate_wHS.png',inverse = False) - plot_hit_rate(hdf5,epoch=epoch,figname=figname + '.hitrate.png',inverse = False) - plot_boxplot(hdf5,epoch=epoch,figname=figname + '.boxplot.png',inverse = False) -if __name__ == '__main__': - main() - - diff --git a/deeprank/utils/plot_utils.py b/deeprank/utils/plot_utils.py new file mode 100755 index 00000000..4228ab69 --- /dev/null +++ b/deeprank/utils/plot_utils.py @@ -0,0 +1,580 @@ +# 1. plot prediction scores for class 0 and 1 using two-panel box plots +# 2. hit rate plot +# 3. success rate plot +import numpy as np +import h5py +import sys +import torch +import torch.nn.functional as F +import pandas as pd +import re +from itertools import zip_longest + +from cal_hitrate_successrate import evaluate +from cal_hitrate_successrate import ave_evaluate +from cal_hitrate_successrate import add_rank + +import warnings +from rpy2.rinterface import RRuntimeWarning +warnings.filterwarnings("ignore", category=RRuntimeWarning) + +from rpy2.robjects.lib.ggplot2 import * +from rpy2.robjects import pandas2ri +import rpy2.robjects as ro + + +def zip_equal(*iterables): + sentinel = object() + for combo in zip_longest(*iterables, fillvalue=sentinel): + if sentinel in combo: + # if an element is None + raise ValueError(f'Iterables have different lengths: {combo}') + yield combo + + +def plot_boxplot(df,figname=None,inverse = False): + + ''' + Plot a boxplot of predictions vs. targets. Useful + to visualize the performance of the training algorithm. + This is only useful in classification tasks. + + INPUT (pd.DataFrame): + + label modelID target DR sourceFL + Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + Test 1AVX_ranair-it0_6223 0 0.511688 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + ''' + + print('\n --> Box Plot : ', figname, '\n') + + data = df + + font_size = 20 + #line = "#1F3552" + + text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + + colormap_raw =[['0','ivory3'], + ['1','steelblue']] + + colormap = ro.StrVector([elt[1] for elt in colormap_raw]) + colormap.names = ro.StrVector([elt[0] for elt in colormap_raw]) + + p= ggplot(data) + \ + aes_string(x='target', y='DR' , fill='target' ) + \ + geom_boxplot( width = 0.2, alpha = 0.7) + \ + facet_grid(ro.Formula('.~label')) +\ + scale_fill_manual(values = colormap ) + \ + theme_bw() +\ + theme(**{'plot.title' : text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size = font_size), + 'legend.position': 'right'} ) +\ + scale_x_discrete(name = "Target") + + # p.plot() + ggplot2.ggsave(figname, dpi = 100) + return p + + +def read_epoch_data(DR_h5FL, epoch): + ''' + # read epoch data into a data frame + + OUTPUT (pd.DataFrame): + + label modelID target DR sourceFL + 0 Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + 1 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + ''' + + #-- 1. read deeprank output data for the specific epoch + h5 = h5py.File(DR_h5FL,'r') + if epoch is None: + print (f"epoch is not provided. Use the last epoch data.") + keys = list(h5.keys()) + last_epoch_key = list(filter(lambda x: 'epoch_' in x,keys))[-1] + else: + last_epoch_key = 'epoch_%04d' %epoch + if last_epoch_key not in h5: + print('Incorrect epcoh name\n Possible options are: ' + ' '.join(list(h5.keys()))) + h5.close() + return + data = h5[last_epoch_key] + + + #-- 2. convert into pd.DataFrame + labels = list(data) # labels = ['train', 'test', 'valid'] + + # write a dataframe of DR and label + to_plot = pd.DataFrame() + for l in labels: + # l = train, valid or test + source_hdf5FLs = data[l]['mol'][:,0] + modelIDs = list(data[l]['mol'][:,1]) + DR_rawOut = data[l]['outputs'] + DR = F.softmax(torch.FloatTensor(DR_rawOut), dim = 1) + DR = np.array(DR[:,0]) # the probability of a model being negative + + targets = data[l]['targets'][()] + targets = targets.astype(np.str) + + to_plot_tmp = pd.DataFrame(list(zip_equal(source_hdf5FLs, modelIDs, targets, DR)), columns = ['sourceFL', 'modelID', 'target', 'DR']) + to_plot_tmp['label'] = l.capitalize() + to_plot = to_plot.append(to_plot_tmp) + + to_plot['target'] = pd.Categorical(to_plot['target'], categories=['0', '1']) + to_plot['label'] = pd.Categorical(to_plot['label'], categories=['Train', 'Valid', 'Test']) + + cols = ['label', 'modelID', 'target', 'DR', 'sourceFL'] + to_plot = to_plot[cols] + + + return to_plot + +def merge_HS_DR(DR_df, haddockS): + + ''' + INPUT 1 (DR_df: a data frame): + + label modelID target DR sourceFL + 0 Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + 1 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + 2 Test 1AVX_ranair-it0_6223 0 0.511688 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + + INPUT 2: haddockS[modelID] = score + + OUTPUT (a data frame): + + label caseID modelID target score_method1 score_method2 + Test 1ZHI 1ZHI_294w 0 9.758 -19.3448 + Test 1ZHI 1ZHI_89w 1 17.535 -11.2127 + Train 1ACB 1ACB_9w 1 14.535 -19.2127 + ''' + + + #-- merge HS with DR predictions, model IDs and class IDs + modelIDs = DR_df['modelID'] + HS, idx_keep = get_HS(modelIDs, haddockS) + + data = DR_df.iloc[idx_keep,:].copy() + data['HS'] = HS + data['caseID'] = [re.split('_', x)[0] for x in data['modelID']] + + + #-- reorder columns + col_ori = data.columns + col = ['label', 'caseID', 'modelID', 'target', 'sourceFL'] + col.extend( [x for x in col_ori if x not in col]) + data = data[col] + + return data + + +def read_haddockScoreFL(HS_h5FL): + + print(f"Reading haddock score files: {HS_h5FL} ...") + data = pd.read_hdf(HS_h5FL) + + stats = {} + stats['haddock-score'] = {} +# stats['i-RMSD'] = {} + + modelIDs = [ re.sub('.pdb','',x) for x in data['modelID'] ] # remove .pdb from model ID + stats['haddock-score'] = dict(zip_equal(modelIDs, data['haddock-score'])) +# stats['i-RMSD'] = dict(zip(modelIDs, data['i-RMSD'])) # some i-RMSDs are wrong!!! Reported an issue. + + return stats + +def plot_DR_iRMSD(df, figname=None): + ''' + Plot a scatter plot of DeepRank score vs. iRMSD for train, valid and test + + INPUT (a data frame): + + label caseID modelID target sourceFL DR irmsd HS + Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 + Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 + + ''' + print('\n --> Scatter plot of DR vs. iRMSD:', figname, '\n') + + # plot + + font_size = 16 + text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + p = ggplot(df) + aes_string(y = 'irmsd', x = 'DR') +\ + facet_grid(ro.Formula('.~label')) + \ + geom_point(alpha = 0.5) + \ + theme_bw() +\ + theme(**{'plot.title' : text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size = font_size + 2), + 'axis.text.y': element_text(size = font_size + 2)} ) + \ + scale_y_continuous(name = "i-RMSD") + + #p.plot() + ggplot2.ggsave(figname, height = 7 , width = 7 * 1.5, dpi = 100) + return p + + + +def plot_HS_iRMSD(df, figname=None): + ''' + Plot a scatter plot of HS vs. iRMSD for train, valid and test + + INPUT (a data frame): + + label caseID modelID target sourceFL DR irmsd HS + Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 + Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 + + ''' + print('\n --> Scatter plot of HS vs. iRMSD:', figname, '\n') + + # plot + font_size = 16 + text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + p= ggplot(df) + aes_string(y = 'irmsd', x = 'HS') +\ + facet_grid(ro.Formula('.~label')) + \ + geom_point(alpha = 0.5) + \ + theme_bw() +\ + theme(**{'plot.title' : text_style, + 'text': text_style, + 'axis.title': text_style, + 'axis.text.x': element_text(size = font_size + 2), + 'axis.text.y': element_text(size = font_size + 2)} ) + \ + scale_y_continuous(name = "i-RMSD") + + #p.plot() + ggplot2.ggsave(figname, height = 7 , width = 7 * 1.5, dpi=100) + return p + + +def plot_successRate_hitRate (df, figname=None,inverse = False): + '''Plot the hit rate and success_rate of the different training/valid/test sets with HS (haddock scores) + + The hit rate is defined as: + the percentage of positive decoys that are included among the top m decoys. + a positive decoy is a native-like one with a i-rmsd <= 4A + + Args: + DR_h5FL (str): the hdf5 file generated by DeepRank. + HS_h5FL (str): the hdf5 file that saves data from haddock *.stats files + figname (str): filename for the plot + + Steps: + 0. Input data: + + label caseID modelID target DR HS + 0 Test 1AVX 1AVX_ranair-it0_5286 0 0.503823 6.980802 + 1 Test 1AVX 1AVX_ti5-itw_354w 1 0.502845 -95.158100 + 2 Test 1AVX 1AVX_ranair-it0_6223 0 0.511688 -11.961460 + + + 1. For each case, calculate hit rate and success. Success is a binary, indicating whether this case is success when evaluating its top N models. + + caseID success_DR hitRate_DR success_HS hitRate_HS + 1ZHI 1 0.1 0 0.01 + 1ZHI 1 0.2 1 0.3 + ... + + 1ACB 0 0 1 0.3 + 1ACB 1 0.2 1 0.4 + ... + 2. Calculate success rate and hit rate over all cases. + + + ''' + + #-- 1. calculate success rate and hit rate + performance_per_case = evaluate(df) + performance_ave = ave_evaluate(performance_per_case) + performance_ave = add_rank(performance_ave) + + #-- 2. plot + plot_evaluation(performance_ave, figname) + + +def plot_evaluation(df, figname): + ''' + INPUT: + label success_DR hitRate_DR success_HS hitRate_HS rank + Test 0.0 0.000000 0.0 0.000000 0.000949 + Test 0.0 0.000000 1.0 0.012821 0.001898 + + Train 0.0 0.000000 1.0 0.012821 0.002846 + Train 0.0 0.000000 1.0 0.025641 0.003795 + + ''' + + #---------- hit rate plot ------- + figname1 = figname + '.hitRate.png' + print(f'\n --> Hit Rate plot:', figname1, '\n') + hit_rate_plot(df) + ggplot2.ggsave(figname1, height = 7 , width = 7 * 1.2, dpi = 100) + + + #---------- success rate plot ------- + figname2 = figname + '.successRate.png' + print(f'\n --> Success Rate plot:', figname2, '\n') + + success_rate_plot(df) + ggplot2.ggsave(figname2, height = 7 , width = 7 * 1.2, dpi=100) + + + +def hit_rate_plot(df): + ''' + INPUT: + label success_DR hitRate_DR success_HS hitRate_HS rank + Test 0.0 0.000000 0.0 0.000000 0.000949 + Test 0.0 0.000000 1.0 0.012821 0.001898 + + Train 0.0 0.000000 1.0 0.012821 0.002846 + Train 0.0 0.000000 1.0 0.025641 0.003795 + + ''' + + #-- melt df + df_melt = pd.melt(df, id_vars=['label', 'rank']) + idx1 = df_melt.variable.str.contains('^hitRate') + df_tmp = df_melt.loc[idx1,:].copy() + df_tmp.columns = ['Sets', 'rank', 'Methods', 'hit_rate'] + + tmp = list(df_tmp['Methods']) + df_tmp.loc[:,'Methods']= [re.sub('hitRate_','',x) for x in tmp] # success_DR -> DR + + font_size = 20 + breaks = pd.to_numeric(np.arange(0,1.01,0.25)) + xlabels = list(map(lambda x: str('%d' % (x*100)) + ' % ', np.arange(0,1.01,0.25)) ) + text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + + p = ggplot(df_tmp) + \ + aes_string(x='rank', y = 'hit_rate', color='Sets', linetype= 'Methods') + \ + geom_line(size=1) + \ + labs(**{'x': 'Top models (%)', 'y': 'Hit Rate'}) + \ + theme_bw() + \ + theme(**{'legend.position': 'right', + 'plot.title': text_style, + 'text': text_style, + 'axis.text.x': element_text(size = font_size), + 'axis.text.y': element_text(size = font_size)}) +\ + scale_x_continuous(**{'breaks':breaks, 'labels': xlabels}) + + return p + +def success_rate_plot(df): + ''' + # INPUT: a pandas data frame + label success_HS hitRate_HS success_DR hitRate_DR + 0 valid 1.0 1.0 0.0 0.0 + 1 valid 0.0 1.0 0.0 0.0 + ''' + + #-- add the 'rank' column to df + rank = [] + for _, df_per_label in df.groupby('label'): + num_mol = len(df_per_label) + rank_raw = np.array(range(num_mol )) + 1 + rank.extend(rank_raw/num_mol ) + df['rank'] = rank + + #-- melt df + df_melt = pd.melt(df, id_vars=['label', 'rank']) + idx1 = df_melt.variable.str.contains('^success_') + df_tmp = df_melt.loc[idx1,:].copy() + df_tmp.columns = ['Sets', 'rank', 'Methods', 'success_rate'] + + tmp = list(df_tmp['Methods']) + df_tmp.loc[:,'Methods']= [re.sub('success_','',x) for x in tmp] # success_DR -> DR + + font_size = 20 + breaks = pd.to_numeric(np.arange(0,1.01,0.25)) + xlabels = list(map(lambda x: str('%d' % (x*100)) + ' % ', np.arange(0,1.01,0.25)) ) + text_style = element_text(size = font_size, family = "Tahoma", face = "bold") + + p = ggplot(df_tmp) + \ + aes_string(x='rank', y = 'success_rate', color='Sets', linetype= 'Methods') + \ + geom_line(size=1) + \ + labs(**{'x': 'Top models (%)', 'y': 'Success Rate'}) + \ + theme_bw() + \ + theme(**{'legend.position': 'right', + 'plot.title': text_style, + 'text': text_style, + 'axis.text.x': element_text(size = font_size), + 'axis.text.y': element_text(size = font_size)}) +\ + scale_x_continuous(**{'breaks':breaks, 'labels': xlabels}) + +# p.plot() + return p + +def get_irmsd( source_hdf5, modelIDs): + + irmsd = [] + for h5FL, modelID in zip_equal(source_hdf5, modelIDs): + # h5FL = '/home/lixue/DBs/BM5-haddock24/hdf5/000_1AY7.hdf5' + f = h5py.File(h5FL, 'r') + irmsd.append(f[modelID]['targets/IRMSD'][()]) + f.close() + return irmsd + + + +def get_HS(modelIDs,haddockS): + HS=[] + idx_keep = [] + + for idx, modelID in enumerate(modelIDs): + if modelID in haddockS: + HS.append(haddockS[modelID]) + idx_keep.append(idx) + return HS, idx_keep + +def add_irmsd(df): + + ''' + INPUT (a data frame): + df: + label caseID modelID sourceFL target score_method1 score_method2 + train 1ZHI 1ZHI_294w ..../hdf5/000_1ZHI.hdf5 0 9.758 -19.3448 + test 1ACB 1ACB_89w ..../hdf5/000_1ACB.hdf5 1 17.535 -11.2127 + + OUTPUT (a data frame): + df: + label caseID modelID irmsd target score_method1 score_method2 + train 1ZHI 1ZHI_294w 12.1 0 9.758 -19.3448 + train 1ZHI 1ZHI_89w 1.3 1 17.535 -11.2127 + ... + test 1ACB 1ACB_89w 2.4 1 17.535 -11.2127 + ''' + + modelIDs = df['modelID'] + source_hdf5FLs = df['sourceFL'] + irmsd = np.array(get_irmsd(source_hdf5FLs, modelIDs)) + df['irmsd'] = irmsd + return df + + + +def prepare_df(deeprank_h5FL, HS_h5FL, epoch): + + ''' + OUTPUT: a data frame: + + label caseID modelID target sourceFL DR irmsd HS + Test 1AVX 1AVX_ranair-it0_5286 0 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.503823 25.189108 6.980802 + Test 1AVX 1AVX_ti5-itw_354w 1 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 0.502845 3.668682 -95.158100 + ''' + #-- read deeprank_h5FL epoch data into pd.DataFrame (DR_df) + DR_df = read_epoch_data(deeprank_h5FL, epoch) + + ''' + DR_df (a data frame): + + label modelID target DR sourceFL + 0 Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + 1 Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + 2 Test 1AVX_ranair-it0_6223 0 0.511688 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 + ''' + + #-- add iRMSD column to DR_df + DR_df = add_irmsd(DR_df) + + #-- report the number of hits for train/valid/test + hit_statistics(DR_df) + + #-- add HS to DR_df (note: bound complexes do not have HS) + stats = read_haddockScoreFL(HS_h5FL) + haddockS = stats['haddock-score']# haddockS[modelID] = score + DR_HS_df = merge_HS_DR(DR_df, haddockS) + + ''' + DR_HS_df (a data frame): + + data: + label caseID modelID sourceFL target score_method1 score_method2 + train 1ZHI 1ZHI_294w /home/lixue/DBs/BM5-haddock24/hdf5/000_1ZHI.hdf5 0 9.758 -19.3448 + test 1ACB 1ACB_89w /home/lixue/DBs/BM5-haddock24/hdf5/000_1ACB.hdf5 1 17.535 -11.2127 + ''' + + return DR_HS_df + +def hit_statistics(df): + + ''' + Report the number of hits for Train, valid and test. + + INPUT (a data frame): + + label modelID target DR sourceFL irmsd + Test 1AVX_ranair-it0_5286 0 0.503823 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 25.189108 + Test 1AVX_ti5-itw_354w 1 0.502845 /home/lixue/DBs/BM5-haddock24/hdf5/000_1AVX.hdf5 3.668682 + ''' + + labels = ['Train', 'Valid', 'Test'] + grouped = df.groupby('label') + + #-- 1. count num_hit based on i-rmsd + num_hits = grouped['irmsd'].apply(lambda x: len(x[x<=4])) + num_models = grouped.apply(len) + + for label in labels: + print(f"According to 'i-RMSD' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") + + print("") + #-- 2. count num_hit based on the 'target' column + num_hits = grouped['target'].apply(lambda x: len(x[x=='1'])) + num_models = grouped.apply(len) + + for label in labels: + print(f"According to 'targets' -> num of hits for {label}: {num_hits[label]} out of {num_models[label]} models") + + print("") + #-- 3. report num_cases_wo_hit + df_tmp = df.copy() + df_tmp['caseID'] = df['modelID'].apply(get_caseID) + grouped = df_tmp.groupby(['label', 'caseID']) + num_hits = grouped['target'].apply(lambda x: len(x[x==1])) + grp = num_hits.groupby('label') + num_cases_total = grp.apply(lambda x: len(x)) + num_cases_wo_hit = grp.apply(lambda x: len(x==0)) + + for label in labels: + print(f"According to 'targets' -> {num_cases_wo_hit[label]} out of {num_cases_total[label]} cases do not have any hits for {label}") + print("") + +def get_caseID(modelID): + # modelID = 1AVX_ranair-it0_5286 + # caseID = 1AVX + + tmp = re.split('_', modelID) + caseID = tmp[0] + return caseID + +def main(HS_h5FL= '/home/lixue/DBs/BM5-haddock24/stats/stats.h5'): + if len(sys.argv) !=4: + print(f"Usage: python {sys.argv[0]} epoch_data.hdf5 epoch fig_name") + sys.exit() + deeprank_h5FL = sys.argv[1] #the output h5 file from deeprank: 'epoch_data.hdf5' + epoch = int(sys.argv[2]) # 9 + figname = sys.argv[3] + + pandas2ri.activate() + + df = prepare_df(deeprank_h5FL, HS_h5FL, epoch) + + #-- plot + plot_HS_iRMSD(df, figname=figname + '.epo' + str(epoch) +'.irsmd_HS.png') + plot_DR_iRMSD(df, figname=figname + '.epo' + str(epoch) + '.irsmd_DR.png') + plot_boxplot(df, figname=figname + '.epo' + str(epoch) + '.boxplot.png',inverse = False) + plot_successRate_hitRate(df[['label', 'caseID', 'modelID', 'target', 'DR','HS']].copy(), figname=figname + '.epo' + str(epoch) ,inverse = False) + +if __name__ == '__main__': + main() + + diff --git a/deeprank/utils/run_slurmFLs.py b/deeprank/utils/run_slurmFLs.py new file mode 100755 index 00000000..e77647b8 --- /dev/null +++ b/deeprank/utils/run_slurmFLs.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# Li Xue +# 20-Feb-2019 10:50 + +''' +Split multiple jobs into batches and submit to cartesius. + +INPUT: a file that contains all the jobs, for example, + + python /projects/0/deeprank/change_BIN_CLASS.py /projects/000_1ACB.hdf5 & + python /projects/0/deeprank/change_BIN_CLASS.py /projects/000_1AK4.hdf5 & + ... + +''' +import re +import os +import glob +import subprocess +from shlex import quote +from shlex import split +import time + +logDIR='/projects/0/deeprank/BM5/scripts/slurm/change_BINCLASS/hdf5_withGridFeature' +slurmDIR=logDIR +num_cores = 24 # run 24 cores for each slurm job +batch_size = num_cores # number of jobs per slurm file + + +def write_slurmscript(all_job_FL, batch_size, slurmDIR='tmp', logDIRi='tmp'): + + all_job_FL = quote(all_job_FL) + slurmDIR = quote(slurmDIR) + + #- split all_jobs.sh into mutliple files + command = f'cp {all_job_FL} {slurmDIR}' + command = split(command) + subprocess.check_call(command) + + command = f'split -a 3 -d -l {batch_size} --additional-suffix=.slurm {slurmDIR}/{all_job_FL} {slurmDIR}/batch' + print(command) + command = split(command) + subprocess.check_call(command) +# subprocess.check_call(['split' , '-a', '3' ,'-d', f'-l {batch_size}', '--additional-suffix=.slurm' ,f"{slurmDIR}/{all_job_FL}", f"{slurmDIR}/batch"]) + + #-- add slurm header and tail to each file + + batchID = 0 + for slurmFL in glob.glob(f'{slurmDIR}/batch*'): + logFL = slurmFL + '.out' + write_slurm_header(slurmFL, batchID, batch_size, logFL) + write_slurm_tail(slurmFL) + print(slurmFL + ' generated ') + +def submit_slurmscript(slurm_dir, batch_size = 100): + # submit slurm scripts in batches + # each batch waits for the previous batch to finish first. + slu_FLs = glob.glob(slurm_dir + "/*.slurm") + + jobIDs=[] + newjobIDs=[] + num = 0 + for slu_FL in slu_FLs: + + outFL=os.path.splitext(slu_FL)[0] + '.out' + + if os.path.isfile(outFL): + print(f"{outFL} exists. Skip submitting slurm file.") + continue + + num = num + 1 + + if num <= batch_size: +# command = ['sbatch', slu_FLs[i] ] +# print (" ".join(command)) + slu_FL = quote(slu_FL) + command = f'sbatch {slu_FL}' + print(command) + command = split(command) + jobID = subprocess.check_output(command) + jobID = re.findall(r'\d+', str(jobID)) + jobID = jobID[0] + print (num) + print (jobID) + newjobIDs.append(jobID) # these IDs will used for dependency=afterany + + if num >batch_size: +# command=['sbatch', '--dependency=afterany:'+ ":".join(jobIDs), slu_FLs[i] ] +# print (" ".join(command)) + + command = 'sbatch --dependency=afterany:' + ':'.join(jobIDs) + f'{slu_FLs[i]}' + print(command) + command = split(command) + jobID = subprocess.check_output(command) + jobID = re.findall(r'\d+', str(jobID)) + jobID = jobID[0] + print (num) + print (jobID) + newjobIDs.append(jobID) + + if num%batch_size ==0: + print (newjobIDs) + jobIDs=newjobIDs + newjobIDs=[] + print ("------------- new batch --------- \n") + + time.sleep(1) + +# def submit_slrumscript(slurmDIR): +# # submit one by one. +# # Each job waits for the previous job to finish first. +# slu_FLs = glob.glob(slurmDIR+"*.h5.slurm") +# +# jobID_prev='' +# for slurmFL in slu_FLs: +# if jobID_prev == '': +# # submit the first slurm file +# command=['sbatch', slurmFL] +# print (command) +# jobID = subprocess.check_output(command) +# jobID = parse_jobID(jobID) +# jobID_prev = jobID +# else: +# command=['sbatch', f"--dependency=afterany:{jobID_prev}", slurmFL ] +# print (command) +# jobID = subprocess.check_output(command) +# print(jobID) +# jobID = parse_jobID(jobID) +# jobID_prev = jobID +# time.sleep(5) +# + +def parse_jobID(jobID): + # input: b'Submitted batch job 5442433\n ' + # output: 5442433 + jobID = re.findall(r'\d+', str(jobID)) + jobID = jobID[0] + return (jobID) + +def write_slurm_header(slurmFL, batchID, batch_size, logFL): + + #- 1. prepare the header string + header='' + + header = header + "#!/usr/bin/bash\n" + header = header + "#SBATCH -p normal\n" + + jobName = 'batch' + str(batchID) + ".h5" + header = header + "#SBATCH -J " + jobName + "\n" + header = header + "#SBATCH -N 1\n" + header = header + f"#SBATCH --ntasks-per-node={num_cores}\n" + header = header + "#SBATCH -t 04:00:00\n" + + header = header + "#SBATCH -o " + logFL + "\n" + header = header + "#SBATCH -e " + logFL + "\n" + + common_part = """ + start=`date +%s` + + """ + header = header + common_part + + + #- 2. add the header to slurmFL + f = open(slurmFL,'r') + content = f.readlines() + f.close() + + content.insert(0, header) + + f = open(slurmFL,'w') + f.write(''.join(content)) + f.close() + + print(f"slurm header added to {slurmFL}") + + +def write_slurm_tail(slurmFL): + + tail = """ + wait + end=`date +%s` + + runtime=$((end-start)) + echo + echo "total runtime: $runtime sec" + """ + + f = open(slurmFL,'a+') + f.write(tail) + f.close() + + +if not os.path.isdir(slurmDIR): + os.makedirs(slurmDIR) + + +write_slurmscript('all_jobs.sh', batch_size, slurmDIR, logDIR) +#submit_slurmscript(slurmDIR, 200) + diff --git a/setup.py b/setup.py index f7ad9263..7b33875b 100644 --- a/setup.py +++ b/setup.py @@ -15,10 +15,11 @@ 'h5py', 'tqdm', 'pandas', - 'matplotlib', 'torchsummary' ], + 'matplotlib', + 'torchsummary'], - extras_require= { + extras_require={ 'test': ['nose', 'coverage', 'pytest', - 'pytest-cov','codacy-coverage','coveralls'], + 'pytest-cov', 'codacy-coverage', 'coveralls'], } ) diff --git a/test/hitrate_successrate/scores_raw.tsv b/test/hitrate_successrate/scores_raw.tsv new file mode 100644 index 00000000..3ba07336 --- /dev/null +++ b/test/hitrate_successrate/scores_raw.tsv @@ -0,0 +1,19 @@ +label caseID modelID target DR HS +Train 1ACB 1ACB_cm-it0_1000 1 0.5 -90 +Train 1ACB 1ACB_cm-it0_1001 0 0.4 -20 +Train 1ACB 1ACB_cm-it0_1002 0 0.0 -30 +Train 1ACB 1ACB_cm-it0_1003 0 0.1 -40 +Train 1ACB 1ACB_cm-it0_1004 0 0.2 -50 + +Test 2Z0E 2Z0E_cm-it0_1000 0 0.4 -90 +Test 2Z0E 2Z0E_cm-it0_1001 1 0.5 -60 +Test 2Z0E 2Z0E_cm-it0_1002 1 0.0 -10 +Test 2Z0E 2Z0E_cm-it0_1003 0 0.1 -20 +Test 2Z0E 2Z0E_cm-it0_1004 0 0.2 -30 + +Valid 1UUL 1UUL_cm-it0_1000 1 0.4 -90 +Valid 1UUL 1UUL_cm-it0_1001 1 0.5 -60 +Valid 1UUL 1UUL_cm-it0_1002 1 0.0 -10 +Valid 1UUL 1UUL_cm-it0_1003 0 0.1 -20 +Valid 1UUL 1UUL_cm-it0_1004 0 0.2 -30 + diff --git a/test/hitrate_successrate/success_hitrate_ANS.tsv b/test/hitrate_successrate/success_hitrate_ANS.tsv new file mode 100644 index 00000000..cde16d7e --- /dev/null +++ b/test/hitrate_successrate/success_hitrate_ANS.tsv @@ -0,0 +1,16 @@ +label caseID success_DR hitRate_DR success_HS hitRate_HS +Train 1ACB 0 0.0 1 1.0 +Train 1ACB 0 0.0 1 1.0 +Train 1ACB 0 0.0 1 1.0 +Train 1ACB 0 0.0 1 1.0 +Train 1ACB 1 1.0 1 1.0 +Test 2Z0E 1 0.5 0 0.0 +Test 2Z0E 1 0.5 1 0.5 +Test 2Z0E 1 0.5 1 0.5 +Test 2Z0E 1 0.5 1 0.5 +Test 2Z0E 1 1.0 1 1.0 +Valid 1UUL 1 0.3333333333333333 1 0.3333333333333333 +Valid 1UUL 1 0.3333333333333333 1 0.6666666666666666 +Valid 1UUL 1 0.3333333333333333 1 0.6666666666666666 +Valid 1UUL 1 0.6666666666666666 1 0.6666666666666666 +Valid 1UUL 1 1.0 1 1.0 diff --git a/test/test_hitrate_successrate.py b/test/test_hitrate_successrate.py new file mode 100644 index 00000000..9ce70bf8 --- /dev/null +++ b/test/test_hitrate_successrate.py @@ -0,0 +1,62 @@ +import unittest +from deeprank.utils.cal_hitrate_successrate import evaluate +import pandas as pd + + +""" +Some requirement of the naming of the files: + 1. case ID canNOT have underscore '_', e.g., '1ACB_CD' + 2. decoy file name should have this format: 2w83-AB_20.pdb (caseID_xxx.pdb) + 3. pssm file name should have this format: 2w83-AB.A.pssm (caseID.chainID.pssm or caseID.chainID.pdb.pssm) +""" + + +class TestGenerateData(unittest.TestCase): + """Test the calculation of hit rate and success rate.""" + + rawScoreFL = 'hitrate_successrate/scores_raw.tsv' + groundTruth_FL = 'hitrate_successrate/success_hitrate_ANS.tsv' + + def test_1_hitrate_success_averaged_over_cases(self): + + def compare_hitrate_success_one_case(expected_df, real_df, caseID): + + expected_df = expected_df.reset_index() + real_df = real_df.reset_index() + + columns = ['success_DR', 'hitRate_DR', 'success_HS', 'hitRate_HS'] + + for col in columns: + expected = expected_df[col] + real = real_df[col] + error_msg = f"{col} for {label} {caseID} is not correct!" + assert (expected == real).all(), error_msg + + # calculate hitrate and success + rawScore = pd.read_csv(self.rawScoreFL, sep='\t') + hitrate_success_df = evaluate(rawScore) + + # compare with the grount truth + groundTruth_df = pd.read_csv(self.groundTruth_FL, sep='\t') + + labels = groundTruth_df['label'].unique() + caseIDs = {} + truth_grp = groundTruth_df.groupby('label') + + for label, df in truth_grp: + caseIDs[label] = df['caseID'].unique() + + for label in labels: + for caseID in caseIDs[label]: + idx = (groundTruth_df['label'] == label) & ( + groundTruth_df['caseID'] == caseID) + expected_df = groundTruth_df[idx] + + idx = (hitrate_success_df['label'] == label) & ( + hitrate_success_df['caseID'] == caseID) + real_df = hitrate_success_df[idx] + compare_hitrate_success_one_case(expected_df, real_df, caseID) + + +if __name__ == "__main__": + unittest.main()