In [2]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats

import gdiffusion.bayesopt as bayesopt
from gdiffusion.classifier.extinct_predictor import EsmClassificationHead

# Resolve the compute device once; later cells pass it to the model loaders and to torch.randn.
device = util.util.get_device()
print(f"device: {device}")

# peptide diffusion
# Checkpoint / asset locations consumed by the model-loading cell. They are relative
# strings, so they resolve against the kernel's current working directory.
DIFFUSION_PATH = "saved_models/peptide_model_v1-20.pt"  # passed to gd.create_peptide_diffusion_model
PEPTIDE_VAE_PATH = "saved_models/peptide_vae/peptide-vae.ckpt"  # passed to gd.load_vae_peptides
PEPTIDE_VAE_VOCAB_PATH = "saved_models/peptide_vae/vocab.json"  # vocab_path for gd.load_vae_peptides
EXTINCT_PREDICTOR_PATH = "saved_models/extinct_model8417"  # read with torch.load as the classifier

  from .autonotebook import tqdm as notebook_tqdm


device: cuda


In [8]:
# Load models and helper functions
_diffusion_model = gd.create_peptide_diffusion_model(model_path=DIFFUSION_PATH, device=device)
diffusion = gd.create_ddim_sampler(diffusion_model=_diffusion_model, sampling_timesteps=50)

# The checkpoint holds a fully pickled module (we call .eval()/.to() on the result), not a
# state_dict, so it needs weights_only=False. Stated explicitly because the default became
# True in PyTorch 2.6. map_location keeps the load working on machines that lack the device
# the checkpoint was saved from.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you trust.
classifier = torch.load(
    EXTINCT_PREDICTOR_PATH,
    map_location=device,
    weights_only=False,
).eval().to(device)

peptide_vae = gd.load_vae_peptides(path_to_vae_statedict=PEPTIDE_VAE_PATH, vocab_path=PEPTIDE_VAE_VOCAB_PATH)


def decode(z):
    """Map latent(s) `z` to peptides via gd.latent_to_peptides using the loaded VAE."""
    return gd.latent_to_peptides(z, vae=peptide_vae)


def encode(peptide_str):
    """Map peptide string(s) to latents via gd.peptides_to_latent using the loaded VAE."""
    return gd.peptides_to_latent(peptide_str, vae=peptide_vae)

def classify(z):
    """Return class probabilities for `z`: the classifier's logits softmaxed over dim 1."""
    logits = classifier(z)
    return logits.softmax(dim=1)

def log_prob_fn_extinct(z: torch.Tensor) -> torch.Tensor:
    """Return the scalar sum, over the batch, of the class-1 log-probability of `z`.

    The classifier logits are log-softmaxed over dim 1 and column 1 is summed.
    Column 1 is presumably the "extinct" class -- the label order is not visible here.
    """
    class_log_probs = F.log_softmax(classifier(z), dim=1)
    target_column = class_log_probs.select(1, 1)
    return target_column.sum()

def eval_probs(z):
    """Print the class probabilities for `z` and the percentage predicted as class 1.

    Class 1 is the column `log_prob_fn_extinct` optimises, so it is presumably the
    "extinct" label -- confirm against the classifier's training labels.

    Args:
        z: batch of latents accepted by `classify`.

    Returns:
        None. Results are printed.
    """
    # Diagnostics only, so skip autograd bookkeeping.
    with torch.no_grad():
        probs = classify(z)
    print(f"Diffusion Probs: {probs}")

    predicted = torch.argmax(probs, dim=1)
    # Count class-1 predictions explicitly: summing the argmax indices is only
    # correct for exactly two classes.
    frac_extinct = (predicted == 1).float().mean().item()
    # The label says "Percent", so print a real percentage rather than a 0-1 tensor repr.
    print(f"Percent Extinct: {frac_extinct:.1%}")

def sample_random(batch_size):
    """Draw `batch_size` standard-normal vectors of width 256 on the notebook's `device`."""
    latent_shape = (batch_size, 256)
    return torch.randn(latent_shape, device=device)

# This is to include the automatic reshape option
def sample(batch_size: int, guidance_scale: float = 1.0, cond_fn=None):
    """Run `diffusion.sample` and return its result reshaped to (batch_size, 256).

    All three arguments are forwarded to the sampler unchanged as keyword arguments.
    """
    raw_latents = diffusion.sample(
        batch_size=batch_size,
        guidance_scale=guidance_scale,
        cond_fn=cond_fn,
    )
    return raw_latents.reshape(batch_size, 256)

# Conditioning function handed to the sampler as `cond_fn`, built from the class-1
# log-probability objective. latent_dim=256 matches the reshape in `sample`;
# clip_grad=False presumably leaves the guidance gradient unclipped.
cond_fn_extinct = gd.get_cond_fn(log_prob_fn=log_prob_fn_extinct, latent_dim=256, clip_grad=False)


Model created successfully
- Total parameters: 225,056,257
- Trainable parameters: 225,056,257
- Model size: 858.5 MB
- Device: cuda:0
- Model Name: LatentDiffusionModel
- Device: cuda:0
- Model Name: LatentDiffusionModel
loading model from saved_models/peptide_vae/peptide-vae.ckpt
Enc params: 2,675,904
Dec params: 360,349


In [10]:
# Draw 16 latents through the DDIM sampler, steered by cond_fn_extinct at guidance scale 8.0.
# NOTE(review): no RNG seed is set in the visible cells, so z_ddim differs between runs.
z_ddim = sample(batch_size=16, guidance_scale=8.0, cond_fn=cond_fn_extinct)

DDIM Sampling Loop Time Step: 100%|██████████| 50/50 [00:01<00:00, 25.31it/s]
