In [1]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats


import gdiffusion.bayesopt as bayesopt
from gdiffusion.classifier.extinct_predictor import EsmClassificationHead

device = util.util.get_device()
print(f"device: {device}")

# peptide diffusion
DIFFUSION_PATH = "saved_models/peptide_model_v1-20.pt"
PEPTIDE_VAE_PATH = "saved_models/peptide_vae/peptide-vae.ckpt"
PEPTIDE_VAE_VOCAB_PATH = "saved_models/peptide_vae/vocab.json"
EXTINCT_PREDICTOR_PATH = "saved_models/extinct_model8417"

  from .autonotebook import tqdm as notebook_tqdm


device: cuda


In [2]:
classifier = torch.load(EXTINCT_PREDICTOR_PATH)
classifier.eval().to(device)

diffusion = gd.create_peptide_diffusion_model(DIFFUSION_PATH, device=device)
peptide_vae = gd.load_vae_peptides(path_to_vae_statedict=PEPTIDE_VAE_PATH, vocab_path=PEPTIDE_VAE_VOCAB_PATH)


Model created successfully
- Total parameters: 225,056,257
- Trainable parameters: 225,056,257
- Model size: 858.5 MB
- Device: cuda:0
- Model Name: LatentDiffusionModel

Model created successfully
- Total parameters: 225,056,257
- Trainable parameters: 225,056,257
- Model size: 858.5 MB
- Device: cuda:0
- Model Name: LatentDiffusionModel
loading model from saved_models/peptide_vae/peptide-vae.ckpt
Enc params: 2,675,904
Dec params: 360,349


In [62]:
decode = lambda z: gd.latent_to_peptides(z, vae=peptide_vae)
encode = lambda peptide_str: gd.peptides_to_latent(peptide_str, vae=peptide_vae)

def sample_random(batch_size):
    return torch.randn(size=(batch_size, 256), device=device)

def classify(z):
    return torch.softmax(classifier(z), dim=1)

def sample(batch_size, cond_fn=None):
    return diffusion.sample(batch_size=batch_size, cond_fn=cond_fn).reshape(batch_size, 256)

def eval_probs(z):
    probs = classify(z)
    print(f"Diffusion Probs: {probs}")
    argmax = torch.argmax(probs, dim=1)
    print(f"Percent Extinct: {argmax.sum() / len(argmax)}")

In [None]:
z = diffusion.sample(batch_size=16).reshape(-1, 256)
eval_probs(z)

DDPM Sampling loop time step: 100%|██████████| 1000/1000 [00:40<00:00, 24.44it/s]


In [46]:
# Random sampling always results in not extinct (class 0) peptide
z_rand = sample_random(batch_size=16)
classify(z_rand)

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.],
        [1., 0.]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [63]:
def log_prob_fn_extinct(z):
    batch_size, latent_dim = z.shape
    logits = classifier(z)
    log_prob_sum = F.log_softmax(input=logits, dim=1).sum(dim=0)
    log_prob_sum[0] *= -1
    log_prob = log_prob_sum.sum(dim=0)
    return log_prob

cond_fn_extinct = gd.get_cond_fn(
    log_prob_fn=log_prob_fn_extinct, 
    guidance_strength=1.0, 
    clip_grad=True, 
    clip_grad_max=1.0,
    latent_dim=256
)

z_guided = diffusion.sample(batch_size=16, cond_fn=cond_fn_extinct)
z_guided = z_guided.reshape(-1, 256)

DDPM Sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.14it/s]


Diffusion Probs: tensor([[0.0000e+00, 1.0000e+00],
        [2.5574e-04, 9.9974e-01],
        [1.0876e-05, 9.9999e-01],
        [8.9827e-02, 9.1017e-01],
        [8.8552e-07, 1.0000e+00],
        [4.0323e-29, 1.0000e+00],
        [3.1427e-01, 6.8573e-01],
        [6.6403e-01, 3.3597e-01],
        [1.9429e-05, 9.9998e-01],
        [3.0211e-04, 9.9970e-01],
        [4.0222e-03, 9.9598e-01],
        [1.0535e-06, 1.0000e+00],
        [1.4940e-01, 8.5060e-01],
        [1.6418e-11, 1.0000e+00],
        [1.1970e-02, 9.8803e-01],
        [5.1481e-08, 1.0000e+00]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.9375


In [69]:
decode(z_guided)

['PKFPNAHMPGMLLLMYTLAIIMGVVRIPFLKPWIVGD',
 'AENNGGEMMGLPFLTWSQGTSLIDLLFSSIRERPGSSVYSV',
 'HSRPVAVLVALIVVTEENLDERLYFSVKLSDLGLYIIMV',
 'FTTFLGVALLFLQGLSNNLLPSISPSLLYLTISLTLFLTMLM',
 'IVTKYVFYAVCTPEQPTREMAINQSCQLATKTVTVRVFHHVPPLVIV',
 'HYKKMSNYAFVVKLCVARVQRERRARERPRERAPRTDDEEPLVTM',
 'TSVTWMGMFLWLLIHTWAYTYININPTI',
 'KTPPAATSMAWTMWYV',
 'LYMMMWLKNMSSLHMWSIISFLLLSLLTGLTLTL',
 'TSIILLALTSTFTSLGLLQELISTNQTLLYSLTMA',
 'NALGLMLLWSTPVLLVPITWLSGAMINKSLPMLPQTL',
 'RWKAPPELGGELGVLSRRRIMLPTLTPPLIAYGTPNIILVFPENNTFH',
 'WTVNSLSGSYRVWM',
 'GMVMMALKLAVGESNVGASGVFLFVFSGSLLEWSGGLVVVLLSVVPVD',
 'LTPALMILIMYLKTLLVLTTTLLANTTNMLLEALAMAL',
 'AMILTVSITNSFMEGQAVVEVMFGVSYLLGYYHILMWLTDHV']

In [75]:
# sample from diffusion (no guidance)
eval_probs(z)

# sample from diffusion (with guidance)
eval_probs(z_guided)

Diffusion Probs: tensor([[0.9697, 0.0303],
        [0.8986, 0.1014],
        [0.7736, 0.2264],
        [0.4445, 0.5555],
        [0.5576, 0.4424],
        [0.9958, 0.0042],
        [0.8961, 0.1039],
        [0.3110, 0.6890],
        [0.0114, 0.9886],
        [0.6271, 0.3729],
        [0.6524, 0.3476],
        [0.5959, 0.4041],
        [0.9967, 0.0033],
        [0.9662, 0.0338],
        [0.8483, 0.1517],
        [0.9548, 0.0452]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.1875
Diffusion Probs: tensor([[0.0000e+00, 1.0000e+00],
        [2.5574e-04, 9.9974e-01],
        [1.0876e-05, 9.9999e-01],
        [8.9827e-02, 9.1017e-01],
        [8.8552e-07, 1.0000e+00],
        [4.0323e-29, 1.0000e+00],
        [3.1427e-01, 6.8573e-01],
        [6.6403e-01, 3.3597e-01],
        [1.9429e-05, 9.9998e-01],
        [3.0211e-04, 9.9970e-01],
        [4.0222e-03, 9.9598e-01],
        [1.0535e-06, 1.0000e+00],
        [1.4940e-01, 8.5060e-01],
        [1.6418e-11, 1.0000e+00],
    

In [None]:
probs_z_rand_extinct = torch.softmax(input=classifier(z_rand), dim=-1)[:, 1]
probs_z_extinct = torch.softmax(classifier(z), dim=-1)[:, 1]
probs_z_guided_extinct = torch.softmax(classifier(z_guided), dim=-1)[:, 1]
print(probs_z)

tensor([0.0303, 0.1014, 0.2264, 0.5555, 0.4424, 0.0042, 0.1039, 0.6890, 0.9886,
        0.3729, 0.3476, 0.4041, 0.0033, 0.0338, 0.1517, 0.0452],
       device='cuda:0', grad_fn=<SelectBackward0>)


In [None]:
vis.display_histograms()