In [36]:
# computing the number of PBs 
import numpy as np
def N(n, k):
    '''
    n : sequence length
    k : max token length
    '''
    if n == 0:
        return  1
    else : 
        return np.sum([N(i,k) for i in range(n - k, n)])


# The n-bonacci formula : 

Let n be the number of letters in a string $S$, and suppose we have : 
- a complete library with all possible subtoken of S
- This implies a GFN structure where all the parents 
. Let $N(S)$ be the number of backward trajectories that lead to $S$. Then : 

$$N(S) = \sum_{i = 0}^{n-1} N(i) \quad \text{Where}  \quad N(0) = 1$$

In [142]:
import numpy as np
import torch 
def compute_logN( terminal_string, alpha, library):
        ''' 
        Returns the weighted number of trajectories, according to Kolya's formula.

        Args :
            terminal_string : str, the terminal string
            alpha : float, temperature parameter
        Returns:
            logN : dict, where:
                - keys are all possible substates that can be used to construct terminal_string,
                - Value is the log-weighted number of trajectories that go through that substate, given terminal_string.
            mask_back_action : tensor of size len(terminal_string) x len(actions), where mask_back_action[i,j] is 1 if action j is a parent of terminal_string[:i]
        '''    
        atomic_traj = list(terminal_string.replace('<EOS>',''))
        logN = {'': 0}
        mask_back_action = torch.ones(len(terminal_string), len(library))  # Parents mask. Is 1 if the action is a parent of the current string
        actions = { action: len(action) for action in library}
        for i in range(1, len(atomic_traj) + 1):
            parents_i = []
            for action, j in actions.items():
                if terminal_string[i - j  :i]  == action:
                    mask_back_action[i - 1, library.index(action)] = 0
                    parents_i.append(terminal_string[:i - j])

            logN_parents_i = torch.Tensor([ logN[s] for s in parents_i])
            logN[ terminal_string[:i] ] = alpha + torch.logsumexp(logN_parents_i, dim = 0).item()
        return logN, mask_back_action



def get_logpb_state(string, terminal_string, alpha, logN, mask_back_action, library):
    '''
    Computes the logpb of each action in the library given the current state ( = string), according to Kolya's formula.

    Args :
        string : str, the current state
        terminal_string : str, the terminal state
        alpha : float, temperature parameter
        N : dict, the weighted number of trajectories given terminal_string
        mask_back_action : tensor of size len(terminal_string) x len(actions), where mask_back_action[i,j] is 1 if action j is a parent of terminal_string[:i]
    
    Returns :
        logpb : tensor of size len(actions), where logpb[j] is the logpb of choosing action j for the current state string.
    '''
    assert terminal_string[:len(string)] == string 
    assert list(logN.keys())[-1] == terminal_string.replace('<EOS>','')
    logpb = - torch.ones(len(library))*float("Inf")
    if string[-5:] == '<EOS>':
        logpb[0] = 1 # Action of removing the EOS
    elif len(string) > 0:
        mask = mask_back_action[len(string) - 1]
        ixs = np.where(mask==0)[0]
        for j in ixs:
            logpb[j] = alpha + logN[string[:- len(library[j])]] - logN[string]
    elif len(string) == 0:
        logpb = torch.zeros(len(library))
        # When no action is available, just fill with uniform because
        # it won't be picked anyway in the backward_step.
        # Doing this avoids having nan when computing probabilities
    return logpb
    

def get_logpb_traj(trajectory, alpha, logN, library):
    ''' 
    Computes the logpb of each action in the library given one trajectory, according to Kolya's formula. 
    Args : 
        traj : list of states, representing a trajectory sampled by a Gflownet
        alpha : float, temperature parameter
    Returns :
        logpbs : tensor of size len(traj)-1 x len(actions), where logpbs[i,j] is the logpb of choosing action j for state i
    '''
    if trajectory[-1][-5:] == '<EOS>':
        traj = trajectory[:-1]
    else:
        traj = trajectory
    if list(logN.keys())[-1] != traj[-1] :
        raise ValueError('The trajectory does not end with the terminal state for which logN was computed.')
    logpbs = - torch.ones(len(traj[:-1]))*float("Inf")
    for i in range(1,len(traj)):
        logpbs[i - 1] = alpha + logN[traj[i-1]] - logN[traj[i]]
    return logpbs.sum().item()

1) How slow is this, in terms of $L$ the library size and $T$ the string size ? 

In [160]:
from torch.distributions import Categorical

def create_library(atomic_tokens, n_chunk, max_size_chunk):
    #TODO : Add condition on n_chunk 
    library = atomic_tokens.copy()
    i = 0
    while i < n_chunk :
        # Choose a length for the chunk
        p = 0.5
        logits = torch.Tensor([np.log(p)*i for i in range(max_size_chunk - 2)])
        size_chunk = Categorical(logits = logits).sample() + 2
        # Create a chunk 
        ixs = torch.randint(0, len(atomic_tokens), (size_chunk,))
        chunk = ''.join([atomic_tokens[ix] for ix in ixs])
        if chunk not in library:
            library.append(chunk)
            i +=1

    return library

def create_terminal_state(atomic_tokens, size):
    terminal_state = ''
    for _ in range(size):
        ix = torch.randint(0, len(atomic_tokens), (1,))
        terminal_state += atomic_tokens[ix]
    return terminal_state

In [173]:
atomic_tokens = ['a', 'b', 'c', 'd']
library = create_library(atomic_tokens = atomic_tokens, n_chunk = 10, max_size_chunk = 6)
terminal_string = create_terminal_state(atomic_tokens, 20)
alpha = -1
logN, mask_back_action = compute_logN(terminal_string, alpha, library)

In [174]:
logN

{'': 0,
 'd': -1.0,
 'dc': -2.0,
 'dca': -1.68673837184906,
 'dcab': -2.68673837184906,
 'dcabd': -3.6867384910583496,
 'dcabdb': -4.68673849105835,
 'dcabdba': -5.68673849105835,
 'dcabdbac': -6.68673849105835,
 'dcabdbacd': -7.68673849105835,
 'dcabdbacda': -7.373476982116699,
 'dcabdbacdab': -8.3734769821167,
 'dcabdbacdaba': -9.3734769821167,
 'dcabdbacdabac': -10.3734769821167,
 'dcabdbacdabacb': -11.3734769821167,
 'dcabdbacdabacbb': -11.06021499633789,
 'dcabdbacdabacbbc': -12.06021499633789,
 'dcabdbacdabacbbcb': -13.06021499633789,
 'dcabdbacdabacbbcbc': -14.06021499633789,
 'dcabdbacdabacbbcbcc': -15.06021499633789,
 'dcabdbacdabacbbcbccd': -16.06021499633789}

In [172]:
import time

st = time.time()
for i in range(64):
    logN, mask_back_action = compute_logN(terminal_string, alpha, library)
et = time.time()
print('Execution time:', et - st, 'seconds')


Execution time: 0.051905155181884766 seconds


2 - How much does this PB help compared to uniform PB ? 

In [None]:
def get_log_uniform_

# Try parallelizing compute_logN 

In [1]:
from datamodules.base_sequence import BaseSequenceModule

In [22]:
import torch
actions_len = torch.Tensor([1, 4, 2, 3 ])

parents_actions = torch.randint(0, 2, (10, 4))

token_len = actions_len.unsqueeze(0) * parents_actions

In [24]:
logn = torch.randn(10, 10)
x = logn[:,:6]

In [27]:
torch.where(token_len == 0,)

tensor([[4., 5., 5., 5.],
        [5., 5., 3., 2.],
        [5., 1., 5., 2.],
        [4., 5., 5., 2.],
        [5., 1., 5., 2.],
        [4., 5., 5., 5.],
        [4., 1., 3., 5.],
        [4., 5., 5., 2.],
        [5., 1., 3., 5.],
        [4., 5., 3., 2.]])

In [25]:
torch.scatter(x, 1, 5-token_len)

tensor([[-1.0318, -0.3077,  0.2367,  0.5855,  0.5622,  0.3680],
        [-0.4552, -0.2963, -0.3099,  0.6454, -0.4992, -0.5301],
        [ 0.7691, -0.3932,  1.3314,  2.4756, -0.2416,  0.5387],
        [ 0.1188,  0.6430, -1.2160,  1.5628, -0.7457, -1.3795],
        [-0.3198, -1.9737, -0.1230, -0.5882,  0.6833, -0.3028],
        [-0.7885,  1.3844,  2.0651, -1.4581, -0.7226, -0.2698],
        [-1.9130,  1.3702, -0.6828, -0.0737, -1.2886,  0.0132],
        [ 1.1104,  0.4858, -1.1197,  0.3565,  0.4411, -1.5587],
        [ 1.3762,  0.2785,  0.2340,  1.0169, -0.3358,  0.1590],
        [ 1.1453,  0.9993, -0.5899, -1.4315, -1.6016, -0.3610]])

In [18]:
actions_len[parents_actions[1]]

tensor([2.])