In [3]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats


import gdiffusion.bayesopt as bayesopt
from gdiffusion.classifier.extinct_predictor import EsmClassificationHead

device = util.util.get_device()
print(f"device: {device}")

# peptide diffusion
DIFFUSION_PATH = "saved_models/peptide_model_v1-20.pt"
PEPTIDE_VAE_PATH = "saved_models/peptide_vae/peptide-vae.ckpt"
PEPTIDE_VAE_VOCAB_PATH = "saved_models/peptide_vae/vocab.json"
EXTINCT_PREDICTOR_PATH = "saved_models/extinct_model8417"

device: cuda


In [4]:
classifier = torch.load(EXTINCT_PREDICTOR_PATH)
classifier.eval().to(device)

diffusion = gd.create_peptide_diffusion_model(DIFFUSION_PATH, device=device)
peptide_vae = gd.load_vae_peptides(path_to_vae_statedict=PEPTIDE_VAE_PATH, vocab_path=PEPTIDE_VAE_VOCAB_PATH)


Model created successfully
- Total parameters: 225,056,257
- Trainable parameters: 225,056,257
- Model size: 858.5 MB
- Device: cuda:0
- Model Name: LatentDiffusionModel
- Device: cuda:0
- Model Name: LatentDiffusionModel
loading model from saved_models/peptide_vae/peptide-vae.ckpt
Enc params: 2,675,904
Dec params: 360,349


In [5]:
decode = lambda z: gd.latent_to_peptides(z, vae=peptide_vae)
encode = lambda peptide_str: gd.peptides_to_latent(peptide_str, vae=peptide_vae)

def sample_random(batch_size):
    return torch.randn(size=(batch_size, 256), device=device)

def classify(z):
    return torch.softmax(classifier(z), dim=1)

def sample(batch_size, cond_fn=None):
    return diffusion.sample(batch_size=batch_size, cond_fn=cond_fn).reshape(batch_size, 256)

def eval_probs(z):
    probs = classify(z)
    print(f"Diffusion Probs: {probs}")
    argmax = torch.argmax(probs, dim=1)
    print(f"Percent Extinct: {argmax.sum() / len(argmax)}")

In [6]:
def log_prob_fn_extinct(z):
    batch_size, latent_dim = z.shape
    logits = classifier(z)
    log_prob_sum = F.log_softmax(input=logits, dim=1).sum(dim=0)
    log_prob_sum[0] *= -1
    log_prob = log_prob_sum.sum(dim=0)
    return log_prob

cond_fn_extinct = gd.get_cond_fn(
    log_prob_fn=log_prob_fn_extinct, 
    guidance_strength=1.0, 
    clip_grad=True, 
    clip_grad_max=1.0,
    latent_dim=256
)

z_guided = diffusion.sample(batch_size=16, cond_fn=cond_fn_extinct)
z_guided = z_guided.reshape(-1, 256)

DDPM Sampling loop time step: 100%|██████████| 1000/1000 [00:41<00:00, 24.26it/s]


In [7]:
eval_probs(z_guided)

Diffusion Probs: tensor([[2.8466e-02, 9.7153e-01],
        [1.8587e-10, 1.0000e+00],
        [1.3243e-01, 8.6757e-01],
        [2.5237e-08, 1.0000e+00],
        [6.3647e-07, 1.0000e+00],
        [1.7829e-25, 1.0000e+00],
        [2.4681e-05, 9.9998e-01],
        [6.6248e-01, 3.3752e-01],
        [1.5947e-02, 9.8405e-01],
        [3.0473e-03, 9.9695e-01],
        [2.7940e-02, 9.7206e-01],
        [2.2397e-02, 9.7760e-01],
        [3.2246e-01, 6.7754e-01],
        [4.7122e-07, 1.0000e+00],
        [6.3131e-05, 9.9994e-01],
        [8.7907e-03, 9.9121e-01]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.9375


In [8]:
unet = diffusion.model

In [16]:
from gdiffusion.diffusion.beta_scheduler import BetaScheduleSigmoid
from gdiffusion.diffusion.util import *
from functools import partial
from collections import namedtuple
ModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])


In [48]:
class DiffusionSampler(nn.Module):
    def __init__(self, model, latent_dim, num_timesteps=1000, device=None):
        super().__init__()

        self.model = model
        self.dim = latent_dim
        self.num_timesteps = num_timesteps
        self.device = self._get_device(device)
        
        betas = BetaScheduleSigmoid.get_betas(num_timesteps=num_timesteps)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # loss weight
        snr = alphas_cumprod / (1 - alphas_cumprod)
        loss_weight = snr / (snr + 1)
        register_buffer('loss_weight', loss_weight)

    def _get_device(self, device=None):
        if device is not None:
            return device
        
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if device != 'cuda':
            print("Warning: device is {device}, not cuda")

        return device
    
    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    
    def model_predictions(self, x, t):
        v = self.model(x, t)

        x_start = self.predict_start_from_v(x, t, v)
        pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)
    
    def mean_variance(self,x, t):
        preds = self.model_predictions(x, t)
        x_start = preds.pred_x_start
        pred_noise = preds.pred_noise

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def sample(self, batch_size=16, return_all_timesteps=False, cond_fn=None, *args):
        raise NotImplementedError("Sample must be derived")


In [62]:
class DDIMSampler(DiffusionSampler):
    def __init__(self, model, latent_dim, num_timesteps=1000, sampling_timesteps=None, device=None, ddim_sampling_eta=0.,):
        super().__init__(model, latent_dim, num_timesteps, device)
        self.sampling_timesteps = sampling_timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

    def condition_score(self, cond_fn, pred_noise, x_start, x, t):
        alpha_bar = extract(self.alphas_cumprod, t, x.shape)
        new_pred_noise = pred_noise - (1 - alpha_bar).sqrt() * cond_fn(x, t)
        
        new_x_start = self.predict_start_from_noise(x, t, new_pred_noise)

        return new_x_start, pred_noise
    
    @torch.no_grad()
    def sample(self, batch_size=16, return_all_timesteps=False, cond_fn=None):
        shape = (batch_size, 1, self.dim)
        eta = self.ddim_sampling_eta

        times = torch.linspace(-1, self.num_timesteps - 1, steps=self.sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        x = torch.randn(shape, device=self.device)

        for time, time_next in tqdm(time_pairs, desc = 'DDIM Sampling Loop Time Step'):
            time_cond = torch.full((batch_size,), time, device=self.device, dtype=torch.long)
            pred_noise, x_start = self.model_predictions(x, time_cond)

            if cond_fn is not None:
                new_x_start, new_pred_noise = self.condition_score(cond_fn, pred_noise, x_start, x, time_cond)
                x_start = new_x_start
                pred_noise = new_pred_noise

            if time_next < 0:
                x = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(x)

            x = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

        return x


In [70]:
ddim = DDIMSampler(model=unet.to('cuda'), latent_dim=256, num_timesteps=1000, sampling_timesteps=500, device='cuda').to(device)
cond_fn_extinct_10 = gd.get_cond_fn(
    log_prob_fn=log_prob_fn_extinct, 
    guidance_strength=2.0, 
    clip_grad=True, 
    clip_grad_max=1.0,
    latent_dim=256
)
z_ddim = ddim.sample(batch_size=16, cond_fn=cond_fn_extinct_10).reshape(-1, 256)


DDIM Sampling Loop Time Step: 100%|██████████| 500/500 [00:20<00:00, 24.14it/s]


In [71]:
z_ddim

tensor([[ 57.8160, -52.5945,  61.0330,  ..., -47.9392, -81.4191, -80.5274],
        [ 69.0803, -67.9495,  63.3505,  ..., -67.3417, -73.5253, -77.0669],
        [ 68.8275, -58.3579,  60.3463,  ..., -62.5844, -78.9527, -71.2014],
        ...,
        [ 74.1794, -69.7479,  60.5194,  ..., -61.6581, -79.1902, -76.8611],
        [ 64.5007, -68.8190,  55.3076,  ..., -54.2688, -78.8820, -82.4434],
        [ 79.0259, -63.6021,  61.4692,  ..., -65.1815, -75.6009, -80.9547]],
       device='cuda:0')

In [72]:
eval_probs(z_ddim)

Diffusion Probs: tensor([[1.2451e-07, 1.0000e+00],
        [1.0000e+00, 9.0503e-11],
        [9.8874e-01, 1.1256e-02],
        [1.0000e+00, 6.7999e-08],
        [1.6868e-05, 9.9998e-01],
        [1.0000e+00, 1.9122e-06],
        [9.9998e-01, 1.9980e-05],
        [1.0000e+00, 5.3889e-08],
        [1.0000e+00, 5.8288e-09],
        [9.9960e-01, 3.9647e-04],
        [9.7899e-01, 2.1005e-02],
        [1.0000e+00, 1.1943e-10],
        [1.0000e+00, 5.0092e-09],
        [9.9999e-01, 1.3650e-05],
        [9.9999e-01, 5.2688e-06],
        [1.0000e+00, 1.2181e-10]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.125
