In [None]:
import numpy as np
import itertools
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
import pickle
import os
try:
    import torch
    from torchvision.utils import make_grid
except ImportError:
    pass

import matplotlib
# Figure-wide styling: 24 pt serif text, 20 pt tick labels, and seaborn's
# 'colorblind' palette as the default colour cycle.
TITLE_FONTSIZE = 24
TICK_FONTSIZE = 20
font = dict(family="serif", weight="normal", size=24)
matplotlib.rc('font', **font)
sns.set_palette('colorblind')

In [None]:
def adjacent_values(vals, q1, q3):
    """Return the Tukey whisker ends ``(lower, upper)`` for one sample.

    With ``IQR = q3 - q1``, the upper whisker is ``q3 + 1.5 * IQR`` clipped into
    ``[q3, max(vals)]`` and the lower whisker is ``q1 - 1.5 * IQR`` clipped into
    ``[min(vals), q1]``.

    Parameters
    ----------
    vals : array-like of numbers, non-empty
        The sample. It does NOT have to be sorted: the extremes are taken with
        ``min``/``max`` rather than ``vals[0]``/``vals[-1]``, because callers in
        this notebook pass unsorted score lists.
    q1, q3 : float
        First and third quartiles of ``vals``.

    Returns
    -------
    tuple
        ``(lower_adjacent_value, upper_adjacent_value)``.
    """
    vals = np.asarray(vals)
    iqr = q3 - q1
    upper_adjacent_value = np.clip(q3 + iqr * 1.5, q3, vals.max())
    lower_adjacent_value = np.clip(q1 - iqr * 1.5, vals.min(), q1)
    return lower_adjacent_value, upper_adjacent_value

In [None]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code -- only use this on
    locally generated, trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_sample_points(path):
    """Return the ``sample_points`` array from the ``.npz`` archive at ``path``.

    The archive is opened in a ``with`` block so its file handle is released
    once the array has been read.
    """
    # allow_pickle=True is needed to read object arrays; presumably the SMILES
    # rows are stored that way -- TODO confirm against the code that writes them.
    with np.load(path, allow_pickle=True) as archive:
        return archive["sample_points"]


def _scores_of(smiles_seq, score_by_smiles):
    """Return, in order, the scores of the truthy (non-empty) SMILES in ``smiles_seq``."""
    return [score_by_smiles[s] for s in smiles_seq if s]


def get_data_dict(k: int, gen_property_dir: str, sample_result_dir: str, pathway_model: str,
                  num_points: int = 11, train_data_dir: str = "./data/chem/orig_model") -> dict:
    """Collect therapeutic-score samples for one ``(pathway_model, k)`` setting.

    Files read:
      * ``<gen_property_dir>/gen_therapeutic_score_<pathway_model>_pathway_model.pkl``:
        mapping SMILES -> therapeutic score for generated molecules;
      * ``<sample_result_dir>/results-uniform-weight.npz``: the first row of its
        ``sample_points`` becomes score list 0;
      * ``<sample_result_dir>/results-<pathway_model>-k-<k>.npz``: each row of its
        ``sample_points`` becomes score lists 1, 2, ...;
      * ``<train_data_dir>/train.txt`` (one SMILES per line; blank lines are
        skipped) and ``<train_data_dir>/therapeutic_score_<pathway_model>.pkl``
        (mapping SMILES -> score for the training set).

    Empty/falsy SMILES entries in the sample arrays are ignored.

    Parameters
    ----------
    k : int
        Value used to select ``results-<pathway_model>-k-<k>.npz`` and to prefix keys.
    gen_property_dir, sample_result_dir : str
        Directories holding the generated-score pickle and the ``.npz`` results.
    pathway_model : str
        Pathway-model name embedded in the file names.
    num_points : int, default 11
        Number of score lists (list 0 from the uniform run, then the optimised
        runs' rows) to include.
    train_data_dir : str, default "./data/chem/orig_model"
        Directory holding ``train.txt`` and the training-score pickle.

    Returns
    -------
    dict
        Insertion-ordered: ``"k_<k>_train"`` (training-set scores) followed by
        ``"k_<k>_0"`` ... ``"k_<k>_<num_points-1>"``, each mapping to a list of scores.

    Raises
    ------
    IndexError
        If fewer than ``num_points`` score lists are available.
    """
    gen_scores = _load_pickle(os.path.join(
        gen_property_dir, "gen_therapeutic_score_{}_pathway_model.pkl".format(pathway_model)))

    uniform_points = _load_sample_points(
        os.path.join(sample_result_dir, "results-uniform-weight.npz"))
    opt_points = _load_sample_points(
        os.path.join(sample_result_dir, "results-{}-k-{}.npz".format(pathway_model, k)))

    # List 0 comes from the uniform-weight run; one more list per optimised row.
    score_lists = [_scores_of(uniform_points[0], gen_scores)]
    score_lists.extend(_scores_of(row, gen_scores) for row in opt_points)
    if len(score_lists) < num_points:
        raise IndexError(
            "expected at least {} score lists for pathway_model={!r}, k={} but found {}".format(
                num_points, pathway_model, k, len(score_lists)))

    with open(os.path.join(train_data_dir, "train.txt")) as f:
        train_smiles = [line.strip() for line in f]
    train_smiles = [s for s in train_smiles if s]  # a blank line would otherwise raise KeyError('')
    train_score_by_smiles = _load_pickle(os.path.join(
        train_data_dir, "therapeutic_score_{}.pkl".format(pathway_model)))
    train_scores = np.array([train_score_by_smiles[s] for s in train_smiles])

    violin_therap_data = {f"k_{k}_train": list(train_scores)}
    for idx in range(num_points):
        violin_therap_data[f"k_{k}_{idx}"] = list(score_lists[idx])
    return violin_therap_data

In [None]:
# Input locations: generated-molecule property pickles and sampling results (.npz).
gen_property_dir, sample_result_dir = "gen-mols-property", "sample-results"

In [None]:
from matplotlib.lines import Line2D

# Each panel shows one pathway model. Along x there are 12 violin groups: the
# training set followed by retraining iterations 0..10 (the entries returned by
# get_data_dict, in insertion order). Inside each group, k=4/5/6 sit side by side.
N_GROUPS = 11 + 1
GROUP_SPACING = 4
# The training-set violin is hatched differently from the iteration violins.
VIOLIN_HATCHES = ["\\\\"] + ["--"] * (N_GROUPS - 1)
# (k, fill colour, x-offset of this k inside each group)
K_STYLES = [
    (4, "tab:red", 1.0),
    (5, "tab:green", 0.0),
    (6, "tab:blue", -1.0),
]
# (pathway model, y-limits, panel tag, tag y-position in figure coordinates)
PANELS = [
    ("viable", (2000, 13500), "(a)", 0.85),
    ("modified", (2000, 13500), "(b)", 0.595),
    ("impractical", (12000, 13200), "(c)", 0.34),
]


def _draw_violin_group(axis, samples, positions, color, hatches):
    """Draw one violin per sample at ``positions`` plus median dots and IQR bars.

    ``samples`` may have different lengths, so percentiles are computed per sample.
    """
    vplot = axis.violinplot(samples, showmeans=False, showmedians=False,
                            showextrema=False, positions=positions, widths=0.9)
    for body, hatch in zip(vplot['bodies'], hatches):
        body.set_color(color)
        body.set_hatch(hatch)
        body.set_edgecolor('black')
        body.set_alpha(1.0)
        body.set_linewidth(0.5)
        body.set_zorder(2)  # above the y-grid (zorder 0), below the median dots (zorder 3)

    quartile1, medians, quartile3 = np.array(
        [np.percentile(s, [25, 50, 75]) for s in samples]).T
    axis.scatter(positions, medians, marker='o', color='white', s=0.95, zorder=3)
    axis.vlines(positions, quartile1, quartile3, color='k', linestyle='-', lw=1.5)


fig, (ax, bx, cx) = plt.subplots(3, 1, figsize=(24, 9), sharex=True)
fig.subplots_adjust(wspace=0, hspace=0)
group_centres = GROUP_SPACING * np.arange(N_GROUPS)

for axis, (pathway_model, ylim, tag, tag_y) in zip((ax, bx, cx), PANELS):
    for k, color, offset in K_STYLES:
        violin_therap_data = get_data_dict(k=k, pathway_model=pathway_model,
                                           gen_property_dir=gen_property_dir,
                                           sample_result_dir=sample_result_dir)
        _draw_violin_group(axis, list(violin_therap_data.values()),
                           group_centres + offset, color, VIOLIN_HATCHES)
    axis.set_ylim(ylim)
    axis.grid(axis='y', zorder=0.0)
    axis.tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
    fig.text(0.875, tag_y, tag, fontsize=22)

# Shared x-axis: one tick per group centre, labelled on the bottom panel.
cx.set_xticks(group_centres)
cx.set_xticklabels(['Train'] + [str(i) for i in range(N_GROUPS - 1)])
cx.set_xlabel("Retraining Iteration")
bx.set_ylabel('Therapeutic Score')

fig.legend([Line2D([0], [0], color=color, lw=4) for _, color, _ in K_STYLES],
           [f"k={k}" for k, _, _ in K_STYLES],
           ncol=len(K_STYLES),
           loc='upper center')

fig.savefig("result.pdf", dpi=500, bbox_inches='tight')