Testing if the Diffusion Model Works:


In [1]:
# Changes working directory to root to allow imports to work

import os

def change_directory_to_root(root_name: str = "GDProject"):
    """Changes the directory to the root of the project"""
    current_folder = os.getcwd().split('/')[-1]
    if current_folder != root_name:
        os.chdir('..')

    print(f"New Current Directory is {os.getcwd()}")

change_directory_to_root()

New Current Directory is /home/alden/Research/GDProject


In [2]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats

from gdiffusion.classifier.logp_predictor import LogPPredictor
from gdiffusion.classifier.extinct_predictor import EsmClassificationHead
# Paths for all of the models:
from util.model_paths import *

device = util.get_device(print_device=True)

Device: cuda


In [12]:
# Load Diffusion
peptide_diffusion_model = gd.PeptideDiffusionModel(unet_state_dict_path=PEPTIDE_DIFFUSION_MODEL_PATH)
diffusion = gd.DDIMSampler(diffusion_model=peptide_diffusion_model, sampling_timesteps=100)

[Molecule Diffusion]: UNet Successfully loaded
- Total parameters: 225,056,257
- Trainable parameters: 225,056,257
- Model size: 858.5 MB


In [13]:
# Load VAE and classifier
vae = gd.PeptideVAE()
classifier : EsmClassificationHead = torch.load(EXTINCT_CLASSIFIER_PATH, weights_only=False).to(device)


 Loading Peptide VAE:
------------------------------------------------
Loaded VAE Vocab from saved_models/peptide_vae/vocab.json
Getting State Dict...
Loading model from saved_models/peptide_vae/peptide-vae.ckpt
Enc params: 2,675,904
Dec params: 360,349
------------------------------------------------



In [14]:
# Load models and helper functions


def eval_probs(z):
    probs = classifier.classify(z)
    print(f"Diffusion Probs: {probs}")
    argmax = torch.argmax(probs, dim=1)
    print(f"Percent Extinct: {argmax.sum() / len(argmax)}")

def sample_random(batch_size):
    return torch.randn(size=(batch_size, 256), device=device)

cond_fn_extinct = gd.get_cond_fn(
    log_prob_fn=classifier.log_prob_fn_extinct,
    clip_grad=False,
    latent_dim=256,
)

In [39]:
z_ddim = diffusion.sample(batch_size=16, guidance_scale=16.0, cond_fn=cond_fn_extinct)
eval_probs(z_ddim)


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:04<00:00, 24.61it/s]

Diffusion Probs: tensor([[2.6474e-12, 1.0000e+00],
        [2.1940e-26, 1.0000e+00],
        [1.6263e-40, 1.0000e+00],
        [9.3578e-23, 1.0000e+00],
        [1.9555e-14, 1.0000e+00],
        [1.0000e+00, 0.0000e+00],
        [7.1258e-13, 1.0000e+00],
        [1.0000e+00, 0.0000e+00],
        [2.6347e-41, 1.0000e+00],
        [1.8379e-14, 1.0000e+00],
        [1.4375e-13, 1.0000e+00],
        [2.5544e-15, 1.0000e+00],
        [6.5575e-16, 1.0000e+00],
        [9.1375e-14, 1.0000e+00],
        [9.1540e-17, 1.0000e+00],
        [9.1808e-35, 1.0000e+00]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.875





In [40]:
z_ddim = diffusion.sample(batch_size=16)
eval_probs(z_ddim)


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:04<00:00, 24.77it/s]

Diffusion Probs: tensor([[0.9961, 0.0039],
        [0.9706, 0.0294],
        [0.1963, 0.8037],
        [0.4406, 0.5594],
        [0.4786, 0.5214],
        [0.9468, 0.0532],
        [0.3373, 0.6627],
        [0.6366, 0.3634],
        [0.7163, 0.2837],
        [0.9550, 0.0450],
        [0.9447, 0.0553],
        [0.8912, 0.1088],
        [0.8357, 0.1643],
        [0.9235, 0.0765],
        [0.9980, 0.0020],
        [0.9448, 0.0552]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Percent Extinct: 0.25



