In [1]:
import biotite.sequence as biotite_seq
import biotite.sequence.align as align
# Biotite's standard protein substitution matrix; in this notebook it is only
# referenced by the commented-out get_novelty helper.
substitution_matrix = align.SubstitutionMatrix.std_protein_matrix()
import glob
import pandas as pd
import wandb
import itertools
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from matplotlib.cm import coolwarm
from matplotlib.colors import Normalize

In [2]:
# Module-level tolerance; get_performance takes its own `eps` argument instead.
eps = 1e-3
# Client for the W&B public API, used by get_performance to fetch run histories.
api = wandb.Api()

In [6]:

def find_pkl_file(directory):
    """Return the path of the first ``.pkl`` file found under ``directory``.

    The tree is walked top-down with ``os.walk``. Sub-directories and file
    names are visited in sorted order, so the result is deterministic when
    several ``.pkl`` files exist (raw ``os.walk`` order is filesystem-dependent).
    If ``directory`` is itself a path to a ``.pkl`` file, it is returned
    unchanged. This lets a caller select a specific pickle when one log folder
    holds several.

    Args:
        directory: a directory to search recursively, or a ``.pkl`` file path.

    Returns:
        The matching path as a string, or ``None`` if no ``.pkl`` file is found.
    """
    if os.path.isfile(directory) and directory.endswith('.pkl'):
        return directory
    for root, dirs, files in os.walk(directory):
        # Sorting `dirs` in place controls the order in which os.walk descends.
        dirs.sort()
        for file in sorted(files):
            if file.endswith('.pkl'):
                return os.path.join(root, file)
    return None

def get_diversity(seqs):
    """Pairwise Euclidean distances between every distinct pair of samples.

    Args:
        seqs: a sequence of equal-length numeric samples (lists, numpy arrays,
            or 1-D torch tensors). Scalar samples are treated as 1-D points.

    Returns:
        1-D float tensor with the n*(n-1)/2 distances ``d(seqs[i], seqs[j])``
        for ``i < j``, in row-major upper-triangle order. Pairs of identical
        samples contribute 0. They are kept rather than filtered out, so
        duplicates lower the diversity instead of being ignored. Empty when
        fewer than two samples are given. Callers wanting one scalar should
        take ``.mean()``.
    """
    if len(seqs) < 2:
        return torch.empty(0)
    if isinstance(seqs[0], torch.Tensor):
        # torch.tensor() cannot build a tensor from a list of multi-element tensors.
        states = torch.stack([torch.as_tensor(s) for s in seqs])
    else:
        states = torch.as_tensor(np.asarray(seqs))
    if not states.is_floating_point():
        # torch.cdist only supports floating-point inputs (e.g. integer grid states).
        states = states.float()
    if states.dim() == 1:
        states = states.unsqueeze(-1)
    dist_matrix = torch.cdist(states, states, p=2)
    # Index the strict upper triangle explicitly instead of filtering on != 0,
    # which would also discard real zero distances between duplicate samples.
    rows, cols = torch.triu_indices(len(states), len(states), offset=1)
    return dist_matrix[rows, cols]

# def get_novelty(dataset_seqs, sampled_seqs):
#     sampled_seqs = [biotite_seq.ProteinSequence(seq) for seq in sampled_seqs]
#     dataset_seqs = [biotite_seq.ProteinSequence(seq) for seq in dataset_seqs]
#     min_dists = []
#     for sample in sampled_seqs:
#         dists = []
#         sample_repeated = itertools.repeat(sample, len(dataset_seqs))
#         for s_0, x_0 in zip(sample_repeated, dataset_seqs):
#              alignment = align.align_optimal(s_0, x_0, substitution_matrix, local=False, max_number=1)[0]
#              dists.append(align.get_sequence_identity(alignment))
#         min_dists.append(min(dists))
#     min_dists = torch.FloatTensor(min_dists)
#     return torch.mean(min_dists)

In [7]:
# HARTMANN
# Task settings read by the cells below:
#   oracle_maximize - passed to get_performance: pick the top-k by highest (True) or lowest (False) energy
#   k               - size of the top-k set used for the reward / diversity metrics
#   AL_BATCH_SIZE   - number of cumulative samples treated as one AL iteration
oracle_maximize, k, AL_BATCH_SIZE = True, 10, 10

# Highest-fidelity ("sf") runs: (local log dir, W&B run path) per seed.
# Seeds 2 and 3 used different seeds but started at the same time, so they logged to the same folder.
# NOTE(review): find_pkl_file() returns the first .pkl it finds under a folder, so seeds 2 and 3
#   will likely load the same pickle. Confirm the folder holds per-seed files and select the right one.
# FIXME(review): sf_run_path3 lives in "nikita0209/AMP-DKL", unlike every other Hartmann run
#   ("alexhg/Hartmann budget data set/..."). It looks like a copy/paste leftover; confirm the run id.
sf_logdir1, sf_run_path1 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-05_01-08-33",
    "alexhg/Hartmann budget data set/nioe5ca4",
)
sf_logdir2, sf_run_path2 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-04_06-27-13",
    "alexhg/Hartmann budget data set/u3jlu0gs",
)
sf_logdir3, sf_run_path3 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-04_06-27-13",
    "nikita0209/AMP-DKL/zgoc6a5q",
)
sf_logdirs = [sf_logdir1, sf_logdir2, sf_logdir3]
sf_run_paths = [sf_run_path1, sf_run_path2, sf_run_path3]

# Multi-fidelity ("mf") runs, same layout.
# Commented-out AMP dataset paths (not used by this notebook):
# mf_train_dataset = "/home/mila/n/nikita.saxena/activelearning/storage/amp/mf/data_train.csv"
# mf_test_dataset = "/home/mila/n/nikita.saxena/activelearning/storage/amp/mf/data_test.csv"
mf_logdir1, mf_run_path1 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-05_10-51-53",
    "alexhg/Hartmann budget data set/bytcboio",
)
mf_logdir2, mf_run_path2 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-04_06-27-11",
    "alexhg/Hartmann budget data set/n6i5yw72",
)
mf_logdir3, mf_run_path3 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-04_06-27-12",
    "alexhg/Hartmann budget data set/wkimrg0l",
)
mf_logdirs = [mf_logdir1, mf_logdir2, mf_logdir3]
mf_run_paths = [mf_run_path1, mf_run_path2, mf_run_path3]

In [None]:
# BRANIN
# Task settings (same meaning as in the Hartmann cell). Running this cell after the
# Hartmann cell overwrites those settings and run lists.
# NOTE(review): the plot cell's title is hard-coded to 'Hartmann'; update it for Branin.
oracle_maximize, k, AL_BATCH_SIZE = False, 50, 30

# Highest-fidelity ("sf") runs: (local log dir, W&B run path) per seed.
sf_logdir1, sf_run_path1 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-03_06-35-40",
    "alexhg/Branin budget data set/p6nu2nzx",
)
sf_logdir2, sf_run_path2 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-03_06-32-45",
    "alexhg/Branin budget data set/yvk5tox2",
)
sf_logdir3, sf_run_path3 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-03_18-49-33",
    "alexhg/Branin budget data set/27japxej",
)
sf_logdirs = [sf_logdir1, sf_logdir2, sf_logdir3]
sf_run_paths = [sf_run_path1, sf_run_path2, sf_run_path3]

# Multi-fidelity ("mf") runs, same layout.
mf_logdir1, mf_run_path1 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-04_22-24-48",
    "alexhg/Branin budget data set/cijnbh4w",
)
mf_logdir2, mf_run_path2 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-03_06-41-16",
    "alexhg/Branin budget data set/lwzlgxvj",
)
mf_logdir3, mf_run_path3 = (
    "/network/scratch/h/hernanga/logs/gflownet/2023-05-03_06-41-50",
    "alexhg/Branin budget data set/g557nx3y",
)
mf_logdirs = [mf_logdir1, mf_logdir2, mf_logdir3]
mf_run_paths = [mf_run_path1, mf_run_path2, mf_run_path3]

In [9]:
def get_performance(logdir, run_path, oracle_maximize, is_mf=False, eps=1e-3, top_k=None, batch_size=None):
    """Per-AL-iteration top-k reward, top-k diversity and cumulative cost of one run.

    Samples and energies come from the ``.pkl`` file found under ``logdir``
    (keys ``'cumulative_sampled_samples'`` / ``'cumulative_sampled_energies'``).
    Costs come from the ``post_al_cum_cost`` history of the W&B run ``run_path``.
    After every ``batch_size`` cumulative samples (one AL iteration), the
    ``top_k`` samples seen so far are selected by energy and summarised.

    Args:
        logdir: directory searched for the samples pickle.
        run_path: W&B run path ("entity/project/run_id").
        oracle_maximize: if True the highest energies are the top-k, otherwise the lowest.
        is_mf: currently unused; kept for call-site compatibility.
        eps: currently unused (it only served a disabled W&B cross-check); kept for compatibility.
        top_k: size of the top-k set; defaults to the notebook-level ``k``.
        batch_size: cumulative samples per AL iteration; defaults to ``AL_BATCH_SIZE``.

    Returns:
        ``(reward, diversity, cost)``: 1-D numpy arrays with one entry per AL
        iteration. They hold, respectively:
        - the mean top-k energy;
        - the mean pairwise distance within the top-k (from ``get_diversity``);
        - the cumulative cost logged for that iteration.

    Raises:
        FileNotFoundError: if no ``.pkl`` file is found under ``logdir``.
    """
    top_k = k if top_k is None else top_k
    batch_size = AL_BATCH_SIZE if batch_size is None else batch_size

    pkl_file = find_pkl_file(logdir)
    if pkl_file is None:
        raise FileNotFoundError("No .pkl file found under {}".format(logdir))
    # NOTE: unpickling can execute code; only load trusted experiment logs.
    culm_pkl = pd.read_pickle(pkl_file)
    culm_samples = culm_pkl['cumulative_sampled_samples']
    # argsort / fancy indexing below need a tensor (no-op if it already is one).
    culm_energies = torch.as_tensor(culm_pkl['cumulative_sampled_energies'])

    run = api.run(run_path)
    cost_history = run.history(keys=["post_al_cum_cost"])
    # The cumulative cost may be logged several times per AL iteration; keep the
    # distinct (sorted) values, ignoring rows where it was not logged.
    post_al_cum_cost = np.unique(cost_history['post_al_cum_cost'].dropna())

    # stop is len + 1 so the last AL iteration (upper bound == len) is included.
    steps = np.arange(batch_size, len(culm_samples) + 1, batch_size, dtype=int)
    n_iters = min(len(steps), len(post_al_cum_cost))
    if len(steps) != len(post_al_cum_cost):
        print("WARNING: {}: {} AL iterations of samples but {} logged costs; using the first {}".format(
            run_path, len(steps), len(post_al_cum_cost), n_iters))

    metric_energy, metric_diversity, metric_cost = [], [], []
    for idx in range(n_iters):
        upper_bound = int(steps[idx])
        samples_so_far = culm_samples[:upper_bound]
        energies_so_far = culm_energies[:upper_bound]

        idx_topk = torch.argsort(energies_so_far, descending=oracle_maximize)[:top_k]
        samples_topk = [samples_so_far[i] for i in idx_topk.tolist()]

        metric_energy.append(energies_so_far[idx_topk].float().mean().item())
        # get_diversity returns the vector of pairwise distances; reduce it to its mean.
        metric_diversity.append(get_diversity(samples_topk).mean().item())
        metric_cost.append(post_al_cum_cost[idx])

    reward = np.array(metric_energy)
    diversity = np.array(metric_diversity)
    cost = np.array(metric_cost)
    return reward, diversity, cost

In [11]:
# for i in range(3):
#     setattr(sys.modules[__name__], "sf{}".format(i+1), 32+i)

In [None]:
# Compute the per-AL-iteration metrics for each seed of the highest-fidelity (sf)
# and multi-fidelity (mf) runs. Every series is then right-padded with its last
# value to a common length `max_iter`, so the seeds can be stacked and averaged.
# Results are exposed as module globals sf_tuple1..3 / mf_tuple1..3, each a
# (reward, diversity, cost) tuple of 1-D arrays.
import sys


def _pad_to_length(values, length):
    """Right-pad a 1-D array to `length` by repeating its last element."""
    values = np.asarray(values)
    if len(values) == 0:
        raise ValueError("cannot pad an empty metric series (run produced no AL iterations)")
    return np.append(values, np.repeat(values[-1], max(length - len(values), 0)))


_results = {}
for seed, (sf_dir, sf_path, mf_dir, mf_path) in enumerate(
        zip(sf_logdirs, sf_run_paths, mf_logdirs, mf_run_paths), start=1):
    # Use the task's configured optimisation direction, and the MF logs/runs for the MF tuple.
    _results["sf_tuple{}".format(seed)] = get_performance(sf_dir, sf_path, oracle_maximize=oracle_maximize)
    _results["mf_tuple{}".format(seed)] = get_performance(mf_dir, mf_path, oracle_maximize=oracle_maximize, is_mf=True)

# The longest run over BOTH methods, so no series needs a negative amount of padding.
max_iter = max(len(metric_tuple[0]) for metric_tuple in _results.values())
for name, metric_tuple in _results.items():
    setattr(sys.modules[__name__], name, tuple(_pad_to_length(series, max_iter) for series in metric_tuple))


In [None]:
# Aggregate across the three seeds for each method (sf / mf) and metric:
#   cum_<method>_<metric> - the seeds' series stacked along a new leading axis
#   <method>_<metric>     - mean over seeds
#   std_<method>_<metric> - population standard deviation over seeds (np.std default)
metrics = ['reward', 'diversity', 'cost']
for metric in metrics:
    metric_pos = metrics.index(metric)
    for method in ('sf', 'mf'):
        stacked = np.stack(
            [getattr(sys.modules[__name__], "{}_tuple{}".format(method, seed))[metric_pos] for seed in (1, 2, 3)],
            axis=0,
        )
        setattr(sys.modules[__name__], "cum_{}_{}".format(method, metric), stacked)
        setattr(sys.modules[__name__], "{}_{}".format(method, metric), np.mean(stacked, axis=0))
        setattr(sys.modules[__name__], "std_{}_{}".format(method, metric), np.std(stacked, axis=0))
del metric_pos, method, stacked

# cum_sf_reward = np.stack([sf_tuple1[0], sf_tuple2[0], sf_tuple3[0]], axis=0)
# avg_sf_reward = np.mean(cum_sf_reward, axis=0)
# std_sf_reward = np.std(cum_sf_reward, axis=0)

# cum_sf_diversity = np.stack([sf_tuple1[1], sf_tuple2[1], sf_tuple3[1]], axis=0)
# avg_sf_diversity = np.mean(cum_sf_diversity, axis=0)
# std_sf_diversity = np.std(cum_sf_diversity, axis=0)

# cum_sf_cost = np.stack([sf_tuple1[2], sf_tuple2[2], sf_tuple3[2]], axis=0)
# avg_sf_cost = np.mean(cum_sf_cost, axis=0)
# std_sf_cost = np.std(cum_sf_cost, axis=0)

# cum_mf_reward = np.stack([mf_tuple1[0], mf_tuple2[0], mf_tuple3[0]], axis=0)
# avg_mf_reward = np.mean(cum_mf_reward, axis=0)
# std_mf_reward = np.std(cum_mf_reward, axis=0)

# cum_mf_diversity = np.stack([mf_tuple1[1], mf_tuple2[1], mf_tuple3[1]], axis=0)
# avg_mf_diversity = np.mean(cum_mf_diversity, axis=0)
# std_mf_diversity = np.std(cum_mf_diversity, axis=0)

# cum_mf_cost = np.stack([mf_tuple1[3], mf_tuple2[3], mf_tuple3[3]], axis=0)
# avg_mf_cost = np.mean(cum_mf_cost, axis=0)
# std_mf_cost = np.std(cum_mf_cost, axis=0)

In [None]:
# Reward-vs-cost curves for the multi-fidelity (blue) and highest-fidelity (red) runs.
# Lines are the across-seed means and shaded bands are +/- 1 std. The markers on
# each curve are coloured by the top-k diversity at that iteration.
fig, ax = plt.subplots()

mf_plot = ax.plot(mf_cost, mf_reward, color='blue', label='Multi-Fidelity')
sf_plot = ax.plot(sf_cost, sf_reward, color='red', label='Highest-Fidelity')
# Each std band is drawn in the same colour as its own curve.
ax.fill_between(sf_cost, sf_reward - std_sf_reward, sf_reward + std_sf_reward, alpha=0.2, color='red')
ax.fill_between(mf_cost, mf_reward - std_mf_reward, mf_reward + std_mf_reward, alpha=0.2, color='blue')

# NOTE(review): the title is hard-coded; change it when plotting the Branin configuration.
ax.set_title('Hartmann')
ax.set_xlabel('Cost')
ax.set_ylabel('Top{} Reward'.format(k))
ax.grid(True, linestyle='--')
# Scientific notation for the cost tick labels.
ax.ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
ax.legend(loc='lower right')

# One shared normalisation so both scatters and the colorbar use the same diversity scale;
# nan-aware in case a run had too few top-k samples to compute a diversity.
div_norm = Normalize(
    vmin=np.nanmin(np.concatenate([np.ravel(sf_diversity), np.ravel(mf_diversity)])),
    vmax=np.nanmax(np.concatenate([np.ravel(sf_diversity), np.ravel(mf_diversity)])),
)
div_cmap = 'viridis'
sm = plt.cm.ScalarMappable(cmap=div_cmap, norm=div_norm)

mf_scatter = ax.scatter(mf_cost, mf_reward, c=mf_diversity, cmap=div_cmap, norm=div_norm, marker='^')
sf_scatter = ax.scatter(sf_cost, sf_reward, c=sf_diversity, cmap=div_cmap, norm=div_norm)

# `sm` is not attached to any Axes, so tell colorbar which Axes to take space from
# (recent Matplotlib versions raise an error without it).
cbar = fig.colorbar(sm, ax=ax)
cbar.ax.set_ylabel('Diversity')

plt.show()