# Coding Question: Guided Protein Diffusion with Feynman-Kac Steering

In [1]:
import evodiff
import scipy.spatial.distance as dist

  import pkg_resources


In [2]:
from evodiff.pretrained import D3PM_UNIFORM_38M

# Load the pretrained D3PM checkpoint; return_all=True yields the model together with
# the pieces the samplers below need (tokenizer, timestep count, transition matrices).
checkpoint = D3PM_UNIFORM_38M(return_all=True)
# Unpack in the order the checkpoint tuple uses -- note Q_bar comes before Q.
(model, collater, tokenizer, scheme,
 timestep, Q_bar, Q) = checkpoint

sohl-dickstein


In [211]:
from evodiff.generate import generate_d3pm

# Unguided baseline: a single sequence from the plain D3PM reverse process, on CPU.
seq_len = 100

tokenized_sample, generated_sequence = generate_d3pm(
    model, tokenizer, Q, Q_bar, timestep, seq_len,
    batch_size=1,
    device='cpu',
)

100%|██████████| 499/499 [01:01<00:00,  8.16it/s]

final seq ['MNNRVKGDVLLNSQLLKYRELAEDCQLTAYTTSDQQRHPWFTLLREQVRTLTVGRTLARNLSGIDEALVTTARRDGRTQIVVATSAESKSRWRSLAGSSR']





In [3]:


def d(x, y):
    """Hamming distance: count the positions at which two equal-length sequences differ."""
    assert len(x) == len(y)
    mismatches = 0
    for left, right in zip(x, y):
        if left != right:
            mismatches += 1
    return mismatches

def reward(x, motif='MSTQ'):
    """Motif reward: number of leading positions of ``x`` that match ``motif``.

    For ``len(x) >= len(motif)`` this equals ``len(motif) - d(x[:len(motif)], motif)``.
    With the default motif that is ``4 - d(x[:4], 'MSTQ')``, a value in 0..4.
    A sequence shorter than the motif is scored on the positions it has,
    instead of failing the equal-length assertion in ``d``.

    Args:
        x: sequence (e.g. an amino-acid string).
        motif: target prefix to reward; defaults to 'MSTQ'.

    Returns:
        int: count of positions i < min(len(x), len(motif)) with x[i] == motif[i].
    """
    # zip stops at the shorter input, so only the motif-length prefix of x is compared.
    return sum(a == b for a, b in zip(x, motif))


# Sanity check: 'MSAAAAAAAG' matches 'MSTQ' only at the first two positions (M, S).
assert reward('MSAAAAAAAG') == 2
reward('MSAAAAAAAG')

2

In [66]:
import torch
from tqdm import tqdm
import numpy as np

def _d3pm_reverse_step(sample, p, tokenizer, Q, Q_bar, t, K, device):
    """Draw x_{t-1} for every particle from the D3PM reverse kernel at step ``t``.

    Args:
        sample: LongTensor [B, L] holding the current tokens x_t.
        p: float64 tensor [B, L, K]; per-position softmax of the model output for x_t.
        tokenizer: must provide ``one_hot`` (token ids -> one-hot of shape [L, K]).
        Q, Q_bar: transition matrices indexed as ``Q[t]`` and ``Q_bar[t - 1]``
            (presumably [T, K, K] float64 -- TODO confirm).
        t: current timestep (>= 1).
        K: vocabulary size used by the diffusion process.
        device: device on which the one-hot tensors are created.

    Returns:
        LongTensor [B, L] of x_{t-1}. At ``t == 1`` only the first ``K - 6`` ids can be
        drawn (the standard amino acids per the original comment; presumably the last
        6 ids are special tokens -- confirm against the tokenizer).
    """
    x_tminus1 = sample.clone()
    for i, s in enumerate(sample):
        x_t_b = tokenizer.one_hot(s)
        if not isinstance(x_t_b, torch.Tensor):
            x_t_b = torch.tensor(x_t_b, device=device, dtype=torch.float64)
        else:
            x_t_b = x_t_b.to(device=device, dtype=torch.float64)

        A = torch.mm(x_t_b, torch.t(Q[t]))  # [L x K]
        Q_expand = Q_bar[t - 1].unsqueeze(0).expand(A.shape[0], K, K)  # [L x K x K]
        B_pred = p[i].unsqueeze(2) * Q_expand
        q_t = A.unsqueeze(1) * B_pred  # [L x K x K]
        # NOTE(review): p[i] enters both B_pred and this bmm, so the x_0 prediction is
        # effectively squared. Kept unchanged to match the unguided sampler's construction;
        # confirm this is intended.
        p_theta_marg = torch.bmm(q_t.transpose(1, 2), p[i].unsqueeze(2)).squeeze(-1)  # [L x K]
        p_theta_marg = p_theta_marg / p_theta_marg.sum(dim=1, keepdim=True)

        if t == 1:
            # Final step: restrict the draw to the standard amino-acid ids.
            x_tminus1[i] = torch.multinomial(p_theta_marg[:, :K - 6], num_samples=1).squeeze(1)
        else:
            x_tminus1[i] = torch.multinomial(p_theta_marg, num_samples=1).squeeze(1)
    return x_tminus1


def _estimate_intermediate_rewards(p_x0, tokenizer, reward_fn, n_samples, device):
    """Monte-Carlo estimate of log E_{x_0 ~ p_x0}[exp(reward_fn(x_0))] for each particle.

    Args:
        p_x0: float64 tensor [B, L, K] of per-position probabilities over x_0.
        tokenizer: must provide ``untokenize`` (LongTensor [L] -> sequence).
        reward_fn: maps a decoded sequence to a number.
        n_samples: number of x_0 draws per particle (>= 1).
        device: device for the returned tensor.

    Returns:
        float64 tensor [B] of log-mean-exp reward estimates.
    """
    import math

    estimates = torch.empty(p_x0.shape[0], dtype=torch.float64, device=device)
    for i in range(p_x0.shape[0]):
        x_0_samples = torch.multinomial(p_x0[i], num_samples=n_samples, replacement=True)  # [L, n_samples]
        r_vals = torch.tensor(
            [float(reward_fn(tokenizer.untokenize(x_0_samples[:, j].to('cpu').long())))
             for j in range(n_samples)],
            dtype=torch.float64, device=device,
        )
        # log(mean(exp(r))) computed via logsumexp so large rewards cannot overflow exp().
        estimates[i] = torch.logsumexp(r_vals, dim=0) - math.log(n_samples)
    return estimates


def generate_d3pm_FK(model, tokenizer, Q, Q_bar, timesteps, seq_len, frequency=0.1, batch_size=3,
                     device='cuda', llambda=1, reward_fn=reward, n_samples=50, verbose=True):
    """Sample from the D3PM reverse process with Feynman-Kac (FK) steering.

    ``batch_size`` particles start from uniform random tokens and are denoised in
    parallel. Every ``N = int(1 / frequency)`` steps (with t > 1), each particle x_{t-1} is
    scored by ``r_phi``, a Monte-Carlo estimate of
    log E_{x_0 ~ p_theta(x_0 | x_{t-1})}[exp(reward_fn(x_0))].
    At the next such step the particles are resampled with potentials
    ``exp(llambda * running max of r_phi)``. After the last step they are resampled
    once more with weights ``exp(llambda * reward_fn(x_0))``, divided by the product of
    potentials applied along each particle's ancestry.

    Args:
        model: callable ``model(tokens [B, L], timesteps [B])`` returning logits; only
            the first K channels of the last dimension are used.
        tokenizer: provides ``K``, ``one_hot`` and ``untokenize``.
        Q, Q_bar: transition / cumulative transition matrices indexed by timestep.
        timesteps: number of diffusion steps; steps ``timesteps-1 .. 1`` are run.
        seq_len: length of the generated sequences.
        frequency: resampling frequency, must be > 0. Resampling happens every
            ``int(1 / frequency)`` steps, but at least every step.
        batch_size: number of particles.
        device: torch device for all tensors.
        llambda: potential strength (lambda).
        reward_fn: sequence -> number. The default is bound at definition time to the
            notebook's ``reward``.
        n_samples: x_0 draws per particle for each intermediate reward estimate.
        verbose: print the final rewards, final weights and sequences.

    Returns:
        (sample, untokenized): LongTensor [batch_size, seq_len] of token ids and the
        list of decoded sequences after the final resampling.

    Raises:
        ValueError: if ``frequency <= 0`` or ``n_samples < 1``.
    """
    if frequency <= 0:
        raise ValueError(f"frequency must be > 0, got {frequency}")
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    K = tokenizer.K
    # Clamp to 1 so frequency > 1 cannot produce N == 0 (modulo by zero below).
    N = max(1, int(1 / frequency))

    sample = torch.randint(0, K, (batch_size, seq_len), dtype=torch.long, device=device)
    Q = Q.to(device)
    Q_bar = Q_bar.to(device)

    # Sum of log-potentials applied along each particle's ancestry (for the final correction).
    log_product_of_potentials = torch.zeros(batch_size, dtype=torch.float64, device=device)
    # Running max of r_phi. Starts at -inf rather than 0 so negative rewards are not clipped.
    max_R_phi = torch.full((batch_size,), float('-inf'), dtype=torch.float64, device=device)
    have_intermediate_rewards = False

    with torch.no_grad():
        for t in tqdm(range(timesteps - 1, 0, -1)):
            timesteps_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            is_checkpoint = t > 1 and t % N == 0

            # Resample only once potentials exist. Before the first estimate every weight
            # is equal, and resampling would only discard particle diversity.
            if is_checkpoint and have_intermediate_rewards:
                log_G = llambda * max_R_phi
                # softmax normalises in log space, so large llambda * reward cannot overflow.
                a_t = torch.multinomial(torch.softmax(log_G, dim=0),
                                        num_samples=batch_size, replacement=True)
                sample = sample[a_t]
                max_R_phi = max_R_phi[a_t]
                log_product_of_potentials = log_product_of_potentials[a_t] + log_G[a_t]

            # x_{t-1} ~ p_theta(x_{t-1} | x_t)
            logits = model(sample, timesteps_tensor)
            p = torch.nn.functional.softmax(logits[:, :, :K], dim=-1).to(torch.float64)  # [B, L, K]
            sample = _d3pm_reverse_step(sample, p, tokenizer, Q, Q_bar, t, K, device)

            if is_checkpoint:
                # `sample` now holds x_{t-1}, so the denoiser must be queried at timestep t-1.
                prev_timesteps = torch.full((batch_size,), t - 1, device=device, dtype=torch.long)
                p_x0 = torch.nn.functional.softmax(
                    model(sample, prev_timesteps)[:, :, :K], dim=-1
                ).to(torch.float64)
                r_phi = _estimate_intermediate_rewards(p_x0, tokenizer, reward_fn, n_samples, device)
                max_R_phi = torch.maximum(max_R_phi, r_phi)
                have_intermediate_rewards = True
            elif t == 1:
                final_rewards = torch.tensor(
                    [float(reward_fn(tokenizer.untokenize(sample[i].to('cpu').long())))
                     for i in range(batch_size)],
                    device=device, dtype=torch.float64,
                )
                if verbose:
                    print(final_rewards)
                # Final FK correction: target weight exp(llambda * r(x_0)) divided by the
                # potentials already applied during intermediate resampling.
                w = torch.softmax(llambda * final_rewards - log_product_of_potentials, dim=-1)
                if verbose:
                    print(w)

                a_t = torch.multinomial(w, num_samples=batch_size, replacement=True)
                sample = sample[a_t].clone()

    untokenized = [tokenizer.untokenize(s.to('cpu').long()) for s in sample]
    if verbose:
        print("final seq", untokenized)
    return sample, untokenized


In [67]:
seq_len = 100

# FK-steered sampling: 30 particles, resampling every int(1/0.1) = 10 steps, lambda = 3.
tokenized_sample, generated_sequence = generate_d3pm_FK(
    model, tokenizer, Q, Q_bar,
    timesteps=timestep,
    seq_len=seq_len,
    batch_size=30,
    device='cpu',
    frequency=0.1,
    llambda=3.0,
)

100%|██████████| 499/499 [15:00<00:00,  1.80s/it]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 1., 1., 2., 1., 2., 1.,
        1., 2., 1., 2., 2., 1., 1., 1., 2., 1., 1., 1.], dtype=torch.float64)
tensor([0.0165, 0.0165, 0.0165, 0.0165, 0.0165, 0.0159, 0.0165, 0.0165, 0.0165,
        0.0784, 0.0165, 0.0814, 0.0159, 0.0165, 0.0784, 0.0165, 0.0814, 0.0165,
        0.0165, 0.0814, 0.0165, 0.0814, 0.0784, 0.0165, 0.0165, 0.0165, 0.0784,
        0.0165, 0.0165, 0.0165], dtype=torch.float64)
final seq ['MTDRLSVTVVLGAAMAATLLNAAAFSATRAWTHAVDVPRAITNAVTGTAPVSLYTSQLHEETTAWPPQTAMLSLLAESHAGRSASVVSTQRRVDMIDEVG', 'MSVVLLAAVVLGAQTAAEVGQSAPGPETAAWILANDVPRAITLNVTDAGTVHLWHVGLHHLEWANVHVTGLELEYGTEAKDGSAQVMYDYRRVRPLWEYG', 'MTDRLLGTVVLGAAMAAVPDGAAHFSATRDWTHAVDVLRAITIAVTGTPPVVLYTSQLHEETFAWPPFTALLSLLTESSTRHSASVVRDQRRVGRDWEVG', 'MTWRLLLTVALGAAMAAPADAAALFSATRGVTVAVDVLPAITGAVTGTVPVSRYTSELPELAFAWPPQTAMLSLLHESGKLRSASVVETQRRGMFIREVG', 'MSVVLLAAVVLGAQTAAEVGQSAPGPATAAWILAADVPRAITLNVTDAGTVHLWHVGLHHLEWANVHVTGLELEYGTEAKDGSAQVMYDYRRVRPLWEYG', 'MSVVLLAAVVLGAQTA




In [69]:
# Best motif reward among the sequences returned by the FK-steered run.
rewards = list(map(reward, generated_sequence))
print(max(rewards))

2
