In [1]:
# Target sequence in one-letter codes -- presumably the amino-acid (protein)
# sequence fed to the codon-design code imported below; TODO confirm.
s = "MTSMTQSLREVIKAMTKARNFERVLGKITLVSAAPGKVICEMKVEEEHTNAIGTLHGGLTATLVDNISTMALLCTERGAPGVSVDMNITYMSPAKLGEDIVITAHVLKQGKTLAFTSVDLTNKATGKLIAQGRHTKHLGN"
# Sanity-check the sequence length (the recorded cell output is 140).
len(s)

140

In [2]:
from utils import compute_reward
import torch
from env import CodonDesignEnv
from preprocessor import CodonSequencePreprocessor
from torchgfn.src.gfn.gflownet import TBGFlowNet
from torchgfn.src.gfn.modules import DiscretePolicyEstimator
from torchgfn.src.gfn.utils.modules import MLP
from torchgfn.src.gfn.samplers import Sampler
from utils import load_config
import numpy as np
import matplotlib.pyplot as plt

In [3]:
def is_pareto_efficient_3d(costs):
    """
    Return a boolean mask selecting the Pareto-efficient rows of ``costs``.

    Every column is treated as a cost to be MINIMIZED, so an objective that
    should be maximized must be negated by the caller. For the objectives
    used in this notebook the columns are:
        - (-CAI)  [we want to maximize CAI]
        -  MFE    [we want to minimize MFE]
        - (-GC)   [we want to maximize GC]

    A row is discarded when some other row is <= it in every column and
    strictly < in at least one (i.e. it is dominated). Exact duplicates of an
    efficient row are all kept. Despite the name, the loop does not depend on
    there being exactly three columns.

    costs: array-like of shape (N, M).
    Returns: Boolean mask of shape (N,), True for Pareto-efficient points.
    """
    costs = np.asarray(costs)
    is_efficient = np.ones(costs.shape[0], dtype=bool)

    for i, c in enumerate(costs):

        if is_efficient[i]:

            # Among the points still in play, keep only those that beat `c`
            # (strictly lower cost) in at least one objective, or that tie
            # with it exactly. Everything else is dominated by `c`.
            candidates = costs[is_efficient]
            beats_c_somewhere = np.any(candidates < c, axis=1)
            ties_with_c = np.all(candidates == c, axis=1)
            is_efficient[is_efficient] = beats_c_somewhere | ties_with_c

            # `c` ties with itself, so this is already True; kept explicit
            # so the invariant is obvious to the reader.
            is_efficient[i] = True

    return is_efficient

In [None]:
def plot_pairwise_scatter_with_pareto(results, x_key="CAI", y_key="MFE"):
    """
    Scatter plot of y_key vs x_key (e.g., MFE vs CAI), one marker/colour per
    config, with the Pareto-efficient points overlaid.

    The Pareto front is computed by ``is_pareto_efficient_3d`` over three
    objectives: x_key (negated, i.e. maximized), y_key (minimized) and GC
    (negated, i.e. maximized). With the default keys that is
    CAI (maximize), MFE (minimize), GC (maximize).

    results: iterable of dicts, each with a "name" (used as legend label) and
        a "metrics" mapping holding sequences under x_key, y_key and "GC"
        (assumed to be of equal length per config -- TODO confirm with caller).
    Returns: None.
    Side effects: saves "<Y_KEY>_vs_<X_KEY>_with_Pareto.png" at 300 dpi
        relative to the current working directory and shows the figure.
    """

    markers = ['o', 's', '^', 'D', 'v', 'P', 'X', '*']
    colors = plt.cm.tab10.colors

    all_points_2d = []
    all_points_3d = []

    # Explicit figure/axes instead of the implicit pyplot "current axes".
    fig, ax = plt.subplots(figsize=(7, 6))

    for i, res in enumerate(results):
        xs = np.array(res["metrics"][x_key])
        ys = np.array(res["metrics"][y_key])
        gcs = np.array(res["metrics"]["GC"])  # GC content
        # Cycle through the marker/colour palettes when there are more
        # configs than distinct styles.
        m = markers[i % len(markers)]
        c = colors[i % len(colors)]

        all_points_2d.extend(zip(xs, ys))
        all_points_3d.extend(zip(xs, ys, gcs))

        ax.scatter(xs, ys,
                   marker=m, color=c,
                   alpha=0.4,
                   label=res["name"],
                   edgecolors='none',
                   s=40)

    # With no points at all there is no front to draw; skipping also avoids
    # indexing an empty 1-D array as if it were 2-D.
    if all_points_3d:
        # Compute Pareto front on costs to minimize:
        # maximize x => minimize -x, maximize GC => minimize -GC.
        cost_array = np.array([
            [-x, y, -gc] for x, y, gc in all_points_3d
        ])
        pareto_mask = is_pareto_efficient_3d(cost_array)
        pareto_points = np.array(all_points_2d)[pareto_mask]

        # Sort Pareto front for plotting (by x_key, e.g. CAI) so the dashed
        # line is drawn left to right.
        pareto_sorted = pareto_points[np.argsort(pareto_points[:, 0])]

        # Plot Pareto front
        ax.plot(pareto_sorted[:, 0], pareto_sorted[:, 1], color="black", linestyle="--", label="Pareto Front", linewidth=1.5)
        ax.scatter(pareto_sorted[:, 0], pareto_sorted[:, 1], color="black", edgecolor='k', marker='X', s=60, label="Pareto Points")

    # Axes and labels
    ax.invert_yaxis()  # Lower y (e.g. MFE) is better, so draw it towards the top
    ax.set_xlabel(x_key.upper())
    ax.set_ylabel(y_key.upper())
    ax.set_title(f"{y_key.upper()} vs {x_key.upper()} by Config\nwith Pareto Front (incl. GC)")
    ax.legend(loc='best', fontsize='small', framealpha=0.8)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    fig.savefig(f"{y_key.upper()}_vs_{x_key.upper()}_with_Pareto.png", dpi=300)
    plt.show()

In [1]:
# Tiny throwaway mapping used to check how dict keys are listed.
samples = dict(d=1, dd=2)

In [3]:
# Iterating a dict yields its keys in insertion order.
list(samples)

['d', 'dd']