In [28]:
import sys
sys.path.insert(0, "../")  # must run before the chunkgfn imports below

import lightning as L
import matplotlib.pyplot as plt
import pandas as pd
import rdkit
import selfies as sf
import torch
from einops import repeat, rearrange
from polyleven import levenshtein
from rdkit import Chem
from rdkit.Chem import QED
from scipy.stats import linregress, spearmanr
from torch.distributions import Categorical
from tqdm.notebook import tqdm
from wandb.proto import wandb_internal_pb2
from wandb.sdk.internal import datastore

from chunkgfn.datamodules.selfies_sequence import SELFIESSequenceModule
from chunkgfn.gfn.tb_gfn import TBGFN
from chunkgfn.gfn.tb_gfn_variable import TBGFN_Variable
from chunkgfn.gfn.tb_gfn_chunk_replacement import TBGFN_Chunk_Replacement


In [2]:
def get_samples(
    jobid,
    seed,
    gfn_approach,
    n_samples=2**10,
    runs_dir="/network/scratch/o/oussama.boussif/chunkgfn/logs/selfies/runs",
):
    """Roll out a batch of trajectories from a trained GFlowNet checkpoint.

    The data module and the sampler are both restored from
    ``{runs_dir}/{jobid}/checkpoints/last.ckpt``; ``n_samples`` trajectories
    are then unrolled in parallel from ``dm.s0`` until every one is done.

    Args:
        jobid: Identifier of the run directory under ``runs_dir``.
        seed: Value passed to ``L.seed_everything`` right before sampling.
        gfn_approach: One of ``"chunk_normal"`` (``TBGFN_Variable``),
            ``"chunk_replacement"`` (``TBGFN_Chunk_Replacement``) or
            ``"no_chunk"`` (``TBGFN``).
        n_samples: Number of trajectories sampled in a single batch.
        runs_dir: Directory that contains one sub-directory per run.

    Returns:
        Tuple ``(logreward, state, actions, trajectories, dones)``:
        ``logreward`` is ``dm.compute_logreward`` of the final states,
        ``state`` holds the final states, and ``actions`` / ``trajectories`` /
        ``dones`` are the per-step values stacked along ``dim=1``.
        ``trajectories`` and ``dones`` contain one more entry than ``actions``
        because the final state (with an all-True done flag) is appended.

    Raises:
        ValueError: If ``gfn_approach`` is not one of the supported names.
    """
    # Validate up front: previously an unknown approach left `gfn` unbound and
    # only failed later with an UnboundLocalError, after the data module had
    # already been loaded.
    valid_approaches = ("chunk_normal", "chunk_replacement", "no_chunk")
    if gfn_approach not in valid_approaches:
        raise ValueError(
            f"Unknown gfn_approach {gfn_approach!r}; expected one of {valid_approaches}."
        )

    ckpt_path = f"{runs_dir}/{jobid}/checkpoints/last.ckpt"
    dm = SELFIESSequenceModule.load_from_checkpoint(ckpt_path)
    if gfn_approach == "chunk_normal":
        gfn = TBGFN_Variable.load_from_checkpoint(ckpt_path)
    elif gfn_approach == "chunk_replacement":
        gfn = TBGFN_Chunk_Replacement.load_from_checkpoint(ckpt_path)
    else:  # "no_chunk"
        gfn = TBGFN.load_from_checkpoint(ckpt_path)
    dm.action_frequency = torch.zeros(len(dm.actions))

    L.seed_everything(seed)
    with torch.no_grad():
        s0 = dm.s0.to(gfn.device)
        state = repeat(s0, " ... -> b ...", b=n_samples)
        bs = state.shape[0]

        # Per-step records of the rollout.
        actions = []
        trajectories = []
        dones = []
        done = torch.zeros((bs)).to(state).bool()

        while not done.all():
            # One embedding per library action.
            # NOTE(review): this looks loop-invariant unless forward_step
            # mutates dm.action_indices; it is kept inside the loop so the
            # sequence of model calls (and any RNG use) is unchanged.
            library_embeddings = torch.cat(
                [
                    gfn.action_model(
                        torch.LongTensor(indices).to(gfn.device).unsqueeze(0)
                    )
                    for indices in dm.action_indices.values()
                ],
                dim=0,
            )
            action_embedding = gfn.forward_model(dm.preprocess_states(state))
            dim = action_embedding.shape[-1]
            # Scaled dot product between state and action embeddings.
            logits = torch.einsum(
                "bd, nd -> bn", action_embedding, library_embeddings
            ) / (dim**0.5)

            # Invalid actions get a very low logit so they are effectively
            # never sampled.
            valid_actions_mask = dm.get_forward_mask(state)
            logits = torch.where(
                valid_actions_mask,
                logits,
                torch.tensor(-1e6).to(logits),
            )

            act = Categorical(logits=logits).sample()
            new_state, done = dm.forward_step(state, act)

            actions.append(act)
            trajectories.append(state)
            dones.append(done.clone())

            state = new_state.clone()

        logreward = dm.compute_logreward(state).to(state.device)
        # Close the rollout with the final state and an all-True done flag.
        trajectories.append(state)
        dones.append(torch.ones((bs)).to(state).bool())
        trajectories = torch.stack(trajectories, dim=1)
        actions = torch.stack(actions, dim=1)
        dones = torch.stack(dones, dim=1)
    return logreward, state, actions, trajectories, dones

In [None]:
def to_strings(states: torch.Tensor) -> list[str]:
    """Decode a batch of one-hot states into strings over A/C/G/U.

    Each state is truncated at the first row that contains the padding value
    (-1); the remaining rows are decoded by taking the argmax over the last
    dimension.

    Args:
        states (torch.Tensor[batch_size, max_len, dim]): Batch of states.
    Returns:
        raw (list[str]): One decoded string per state.
    """
    alphabet = ("A", "C", "G", "U")
    pad_row = torch.full((len(alphabet),), -1.0)
    decoded = []
    for seq in states.cpu():
        # Positions equal to the padding value; the first hit's row index is
        # where the padding starts.
        pad_hits = torch.nonzero(seq == pad_row)
        if pad_hits.shape[0] > 0:
            seq = seq[: pad_hits[0, 0]]

        token_ids = seq.argmax(dim=-1)
        decoded.append("".join(alphabet[idx] for idx in token_ids))
    return decoded

In [17]:
# Median reward of the 100 best samples for each "chunk_normal" run,
# given as (jobid, seed) pairs.
# After the loop, logreward/state/... hold the samples of the last run
# (4703147), which the cells below read.
chunk_normal_runs = [(4703138, 2024), (4703129, 1998), (4703147, 42)]
for jobid, seed in chunk_normal_runs:
    logreward, state, actions, trajectories, dones = get_samples(
        jobid, seed, "chunk_normal", n_samples=2**14
    )
    print(torch.topk(logreward, k=100).values.exp().median())

Seed set to 2024


tensor(0.6092, device='cuda:0')


Seed set to 1998


tensor(0.6105, device='cuda:0')


Seed set to 42


tensor(0.6111, device='cuda:0')


In [19]:
# Restore the data module of run 4703147 so its decoding helpers
# (to_strings, _string_to_selfie, actions) can be used below.
dm = SELFIESSequenceModule.load_from_checkpoint(
    f"/network/scratch/o/oussama.boussif/chunkgfn/logs/selfies/runs/{4703147}/checkpoints/last.ckpt"
)

In [45]:
# Decode the 100 highest-reward sampled states: strip the "<EOS>" marker,
# convert to SELFIES, then run the result through sf.decoder.
# The next cell parses these strings as SMILES.
selfies_strings = [
    sf.decoder(dm._string_to_selfie(raw.replace("<EOS>", "")))
    for raw in dm.to_strings(state[logreward.topk(k=100).indices])
]

In [46]:
# Pair every decoded molecule with its QED score.
# Chem and QED are imported explicitly at the top of the notebook: a bare
# `import rdkit` does not load these submodules, so `rdkit.Chem.QED` only
# resolved when some other module had already imported them.
[(smiles, QED.qed(Chem.MolFromSmiles(smiles))) for smiles in selfies_strings]

[('BrS[C@H1]/C=N[C@@H1][C@H1]Br', 0.7058209190176709),
 ('Br[C@H1]P(O)SP', 0.6991817040419113),
 ('Br[C@@H1]/CS(Br)[C@H1]', 0.6956880832428124),
 ('BrO[C@@H1]OBr', 0.6829070479362431),
 ('BrOC[N+1]C=CBr', 0.6768116944275369),
 ('BrC[C@@H1](O)[N+1]Br', 0.6766100579045485),
 ('[NH1+1][C@@H1]=C(SBr)SF', 0.6740504678815359),
 ('[NH1+1]/CCS(Br)Br', 0.6716175674409508),
 ('BrN[C@@H1][C@@H1]Br', 0.6685518401149211),
 ('BrP[C@H1]SBr', 0.6643524697412498),
 ('Br[N+1][C@@H1]OBr', 0.6634426989084399),
 ('Br[C@@H1][NH1+1][C@@H1]=C(Br)[NH1]', 0.6610619953065637),
 ('Br[C@@H1]=C[C@@H1][C@H1]Br', 0.6592862236472985),
 ('Br[C@H1]P=NPN', 0.6582621068702441),
 ('N(P)PPBr', 0.6578608758938106),
 ('[NH1]P(Br)NC=CF', 0.6564924778698977),
 ('O[C@H1][C@@H1]SBr', 0.6545925023295471),
 ('BrOP[N+1]Br', 0.6520067433183583),
 ('[C@@H1](N)[C@@H1]SBr', 0.648360595706495),
 ('N[C@H1]S[C@@H1]Br', 0.648360595706495),
 ('N[C@@H1]/CSBr', 0.6467286435545225),
 ('CN[C@H1]SBr', 0.6457259053871923),
 ('O/C/CSBr', 0.64522514

In [53]:
# (SMILES, SELFIES) pair for every action from index 31 onwards.
# NOTE(review): 31 is presumably the number of non-chunk actions — confirm
# against the data module.
# Each action is converted to SELFIES once and reused for both tuple entries;
# this assumes _string_to_selfie is deterministic.
chunks = [
    (sf.decoder(selfie), selfie)
    for selfie in map(dm._string_to_selfie, dm.actions[31:])
]
el = chunks[-3][1]  # SELFIES string of the third-to-last entry

one = sf.decoder(el)


In [51]:
# QED score of every learned chunk.
# `chunks` holds (SMILES, SELFIES) tuples, so the SMILES entry is unpacked
# before parsing; passing the whole tuple to MolFromSmiles fails on a fresh
# top-to-bottom run.
# Chem and QED come from the explicit imports at the top of the notebook.
[(smiles, QED.qed(Chem.MolFromSmiles(smiles))) for smiles, _selfie in chunks]

[('SBr', 0.44841819337939126),
 ('BrS', 0.44841819337939126),
 ('[C@@H1]SBr', 0.47430278463097686),
 ('OBr', 0.4600586231697684),
 ('[C@H1]SBr', 0.47430278463097686),
 ('Br[C@@H1]', 0.41944927628282164),
 ('Br[C@H1]', 0.41944927628282164),
 ('BrO', 0.4600586231697684),
 ('BrO', 0.4600586231697684),
 ('Br[N+1]', 0.41325513140013265),
 ('BrN', 0.42962990575410526),
 ('BrS[C@@H1]SBr', 0.7226730225880351),
 ('BrN', 0.42962990575410526),
 ('BrS[C@H1]SBr', 0.7226730225880351),
 ('BrS[C@@H1]SBr', 0.7226730225880351),
 ('BrS[C@@H1]SBr', 0.7226730225880351),
 ('BrS[C@@H1]SBr', 0.7226730225880351),
 ('[C@@H1]=C', 0.36416909733079994),
 ('BrP', 0.4133081920445781)]

In [14]:
# "no_chunk" baseline: only the run with seed 42 is evaluated here; the two
# calls below were left disabled by the author.
#logreward, state, actions, trajectories, dones = get_samples(4703145, 2024, "no_chunk", n_samples=2**14)
#print(torch.topk(logreward, k=100).values.exp().median())
#logreward, state, actions, trajectories, dones = get_samples(4703134, 1998, "no_chunk", n_samples=2**12)
#print(torch.topk(logreward, k=100).values.exp().median())
logreward, state, actions, trajectories, dones = get_samples(
    4703155, 42, "no_chunk", n_samples=2**14
)
# Median reward of the 100 best samples.
print(logreward.topk(k=100).values.exp().median())

Seed set to 42


tensor(0.6473, device='cuda:0')


In [16]:
# Median reward of the 100 best samples for each "chunk_replacement" run,
# given as (jobid, seed) pairs.
# NOTE(review): job id 4703134 also appears in a disabled "no_chunk" call
# elsewhere in this notebook — confirm which approach that run belongs to.
chunk_replacement_runs = [
    # (4703142, 2024),  # disabled in the original notebook
    (4703134, 1998),
    (4703152, 42),
]
for jobid, seed in chunk_replacement_runs:
    logreward, state, actions, trajectories, dones = get_samples(
        jobid, seed, "chunk_replacement", n_samples=2**14
    )
    print(torch.topk(logreward, k=100).values.exp().median())

Seed set to 1998


tensor(0.6341, device='cuda:0')


Seed set to 42


tensor(0.6349, device='cuda:0')
