In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import guided_diffusion.guided_diffusion_1d
from peptidevae.load_vae import load_vae, vae_decode
import classifier.load_classifier as lc

import peptide_dataset
import importlib
from tqdm import tqdm
import time
import peptidevae.load_vae

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Re-execute the project modules so on-disk edits are picked up when this cell
# is re-run without restarting the kernel.
# NOTE: names bound earlier with `from peptidevae.load_vae import ...` are not
# rebound by reload(); they keep referring to the function objects imported
# before the reload.
importlib.reload(peptide_dataset)
importlib.reload(peptidevae.load_vae)
importlib.reload(lc)
importlib.reload(guided_diffusion.guided_diffusion_1d)
import pandas as pd

import torch.nn as nn
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


Using: cuda
Using: cuda


In [6]:
# Get VAE for viewing output: should take ~30 seconds, kinda slow
# load_vae returns the model together with a second object (`dataobj`) that is
# later handed to vae_decode alongside the model.
# NOTE(review): the checkpoint file name contains "dim128" while dim=256 is
# passed here — confirm which value the checkpoint was trained with.
vae, dataobj = load_vae("peptidevae/checkpoints/dim128_k1_kl0001_eff256_dff256_pious-sea-2_model_state_epoch_118.pkl", dim=256, max_string_length=50)

  state_dict = torch.load(path_to_vae_statedict, map_location=device)


In [3]:
# Width of one latent vector; used for every reshape in this notebook and as
# the diffusion model's seq_length.
latent_dim = 256
# Batch size handed to LatentDataModule and to train_diffusion.
train_batch_size = 16
# Number of latents drawn per call by the sampling helpers.
sample_batch_size = 4

In [3]:

# Build the latent data module with the training batch size; 0.9 is passed as
# the train/validation split.
data_module = peptide_dataset.LatentDataModule(batch_size=train_batch_size, train_val_split=0.9)

Loaded full dataset of 10274723 examples


In [9]:

def decode_latent(latent_batch):
    """Decode a batch of latent vectors with the loaded peptide VAE.

    The previous body was a stub that always returned ``None`` (the real call
    was commented out), so every caller silently lost its decoded output.

    Args:
        latent_batch: batch of latents, passed unchanged to ``vae_decode``.

    Returns:
        Whatever ``vae_decode`` returns for the batch (the recorded output of
        the dataloader check shows a list of peptide strings).
    """
    # Decoding is inference only, so no autograd graph is needed.
    with torch.no_grad():
        return vae_decode(latent_batch, vae, dataobj, device=device)

def sample_diffusion(diffusion_model):
    """Draw ``sample_batch_size`` unguided latents from the diffusion model.

    The model is switched to eval mode before sampling.

    Returns:
        A ``(latents, decoded)`` pair: the samples reshaped to
        ``(sample_batch_size, latent_dim)`` and ``decode_latent`` applied to
        them.
    """
    diffusion_model.eval()
    with torch.no_grad():
        raw_samples = diffusion_model.sample(batch_size=sample_batch_size)
        # The sampler's output is flattened to one row per latent.
        flat_latents = raw_samples.reshape(sample_batch_size, latent_dim)
        decoded = decode_latent(flat_latents)
    return flat_latents, decoded
    
def sample_with_guidance(diffusion_model, classifier, guidance_scale=4.0, target_class=1, batch_size=sample_batch_size):
    """Sample latents from the diffusion model, steered by a classifier gradient.

    Args:
        diffusion_model: model exposing ``eval()`` and
            ``sample(batch_size=..., cond_fn=..., guidance_kwargs=...)``.
        classifier: model handed to ``lc.predict`` to obtain logits.
        guidance_scale: multiplier applied to the classifier gradient.
        target_class: selects the direction of the guidance term. For any
            value other than 0 the term returned to the sampler is
            ``+guidance_scale * d(sum of logits)/dx``; for 0 it is the
            negative of that.
        batch_size: number of latents to draw. Previously this argument was
            ignored and the global ``sample_batch_size`` was always used.

    Returns:
        A ``(latents, decoded)`` pair: the samples reshaped to
        ``(batch_size, latent_dim)`` and ``decode_latent`` applied to them.
    """
    # The sign depends only on target_class, so compute it once.
    sign = 1 if target_class == 0 else -1

    def cond_fn(x, t):
        # Flatten to one row per sample; -1 lets the batch size follow the
        # tensor the sampler actually passes in rather than a global constant.
        x_in = x.reshape(-1, latent_dim)
        x_in = x_in.detach().requires_grad_(True)

        # The sampler runs under no_grad, so gradients are re-enabled locally.
        with torch.enable_grad():
            _, logits, _ = lc.predict(classifier, x_in, device)
            loss = sign * logits.sum()
            grad = torch.autograd.grad(loss, x_in)[0]

        # Negative guidance scale because the guidance term minimizes `loss`.
        # Reshape back to (batch, 1, latent_dim) as in the original code.
        grad = grad.reshape(-1, 1, latent_dim)
        return -guidance_scale * grad

    diffusion_model.eval()
    with torch.no_grad():
        latents = diffusion_model.sample(batch_size=batch_size, cond_fn=cond_fn, guidance_kwargs={})
        latents = latents.reshape(batch_size, latent_dim)
        return latents, decode_latent(latents)

In [25]:
# Creates dataloader and checks that it works
# NOTE: reloading the module does not rebuild `data_module`, which was
# constructed in an earlier cell from the previously loaded class.
importlib.reload(peptide_dataset)

data_loader = data_module.full_dataloader
# Sanity check: decode only the first batch, then stop iterating.
for batch in data_loader:
    # Each batch is indexed by key; the latents are stored under 'latent'.
    latent = batch['latent']
    print(decode_latent(latent))
    break


Using: cuda
['FLPQGTPSPLIPMLLILETISLFIQPMALAVRLTANITAGHLLIHL', 'VMATAFMGYVLPWGQMSFWGATVITNLLSAIPYIGPTLVEWIW', 'QDIRKMGGMMYTLPFTSSCLMIGTLALTGMPFMTGFYSKDHII', 'AFMGYVLPWGQMSFWGATVITNLLSAIPYIGTTLVEW', 'MLTMIPILMKTTNPRSTEAATKYFMTQATASMMLMMALTINLVYS', 'LLVLFIMFQLKVSNHMYPMNPELIKPKLKEQKTPWE', 'QCPKPTLQQISHIAQQLGLEKDVVRVWFCNRRQKGKRSSSDYSQREDF', 'LASATNTWEIQQL', 'IQQAFSHTQAPTLPLLGLILAATGKSAQ', 'MAIAMLSLLSLFFYLRLAYHSTIILPPNSSNH', 'DVIRESTFQGHHTTTVQKGLRYGMVLFIVSEVFFFLGFFW', 'MISHIVTYYSGKKEPFGYMGMVWAMVSIGFLGFIVWA', 'PILIAMAFLMLTERKILGYMQLRKGPNVVGPYGL', 'IPMITNSLT', 'PWASQTSKLPTMLITALL', 'PPLSGFLPKWMIIQEMTKNSLIIMPTMMAI']


In [None]:
# Create the classifier (pre-trained)
# load_model returns two values; only the first (the model) is kept.
classf, _ = lc.load_model("classifier/best_model.pt")

# Simple test to check classifier is working
# Test latents come from a .npy file and their labels from the 'labels' column
# of the accompanying CSV.
test_latents = np.load('classifier/test_peptides_latents.npy') 
labels = pd.read_csv('classifier/test_peptides.csv')['labels'].to_list()
# Convert the array to a torch tensor and move it to the selected device.
test_latents = torch.Tensor(test_latents).to(device)

# lc.predict returns (predictions, logits, probabilities); accuracy is the
# fraction of predictions that equal the stored labels.
predictions, logits, probabilities = lc.predict(classf, test_latents, device)
accuracy = np.mean(predictions.to('cpu').numpy() == labels)
print(f"Accuracy: {accuracy:.2f}")

# predictions, _, _ = lc.predict(classf, test_latents, device)
# print(f"Prediction: {predictions.sum() / len(predictions)}")

def predict_exctinct(latents, classf):
    """Classify ``latents`` and print one line per sample.

    Each printed line shows the predicted class and the probability assigned
    to class 0, expressed as a percentage.

    Returns:
        The ``(predictions, probabilities)`` pair obtained from ``lc.predict``.
    """
    predictions, _, probabilities = lc.predict(classf, latents, device)
    for idx, predicted_class in enumerate(predictions):
        class0_percent = probabilities[idx][0] * 100.0
        print(f"Predicted class {predicted_class} with probability of predictions=0 at {class0_percent}%")
    return predictions, probabilities

def get_extinct_prediction_percent(diffusion, guidance_scale, num_samples=64):
    """Sample guided latents and report the mean predicted class.

    Args:
        diffusion: diffusion model to sample from; it is put in eval mode.
        guidance_scale: classifier-guidance strength forwarded to
            ``sample_with_guidance``.
        num_samples: number of latents requested from the sampler. Previously
            this argument was ignored and 64 was hard-coded.

    Returns:
        ``predictions.sum() / len(predictions)``, the same value that is
        printed. The function used to return nothing.
    """
    # Bind the module-level classifier to a new local name. Assigning to
    # `classf` inside this function would make it a local variable and raise
    # UnboundLocalError when the right-hand side is evaluated.
    classifier_model = classf.to(device)
    classifier_model.eval()
    diffusion.eval()

    latents, _ = sample_with_guidance(diffusion_model=diffusion, classifier=classifier_model, guidance_scale=guidance_scale, batch_size=num_samples)
    predictions, _, _ = lc.predict(classifier_model, latents, device)

    fraction = predictions.sum() / len(predictions)
    print(f"Prediction: {fraction}")
    return fraction



In [None]:
def train_diffusion(diffusion_model, dataloader=data_loader, batch_size=train_batch_size, epochs=10, lr=1e-4, device=device):
    """Train the diffusion model on latent batches with Adam.

    Args:
        diffusion_model: model whose forward pass returns the training loss.
        dataloader: iterable of batches that expose the latents under
            ``batch['latent']``; must support ``len()``.
        batch_size: kept for backward compatibility. The batch dimension is
            now inferred from each batch, so a final partial batch no longer
            breaks the reshape.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        device: device the model and each batch are moved to.
    """
    model = diffusion_model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch}")

        for batch_idx, batch in progress_bar:
            latent = batch['latent'].to(device)

            # IMPORTANT: the dataloader stores objects of shape (b, n), but the
            # UNET / diffusion want (b, 1, n). Using -1 instead of a fixed
            # batch size keeps this valid when the last batch is smaller.
            latent = latent.reshape(-1, 1, latent_dim)

            # When we sample, we will unshape this
            optimizer.zero_grad()
            loss = model(latent)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch: {batch_idx}: Batch Loss: {batch_loss}")

            if batch_idx % 10000 == 0:
                get_extinct_prediction_percent(diffusion_model, guidance_scale=5.0)
                # The sampling helpers switch the model to eval mode; switch
                # back so the remaining batches of this epoch train normally.
                model.train()

        print(f"Epoch {epoch}, Average Loss: {epoch_loss / len(dataloader):.6f}")


Using device: cuda


  checkpoint = torch.load(model_path, map_location=device)


Accuracy: 0.91


In [None]:
# Train the diffusion model
# Pick up on-disk edits to the project modules before building the models.
importlib.reload(guided_diffusion.guided_diffusion_1d)
importlib.reload(lc)
importlib.reload(peptidevae.load_vae)

# Release cached GPU memory left over from earlier cells.
torch.cuda.empty_cache()
# NOTE(review): `epochs` is not referenced by the visible train_diffusion call,
# which passes epochs=1 as a literal — confirm whether it is still needed.
epochs = 1
print(f"Using device: {device}")

unet_dim = latent_dim # for now, matches latent dim, may change
# Single-channel 1D U-Net; dim_mults lists the width multiplier for each stage.
unet_model = guided_diffusion.guided_diffusion_1d.Unet1D(
    dim = unet_dim,
    channels=1,
    dim_mults=(1, 2, 4, 8)
).to(device)

# Diffusion wrapper around the U-Net: sequences of length latent_dim,
# 1000 timesteps, trained with the 'pred_v' objective.
diffusion_model = guided_diffusion.guided_diffusion_1d.GaussianDiffusion1D(
    unet_model,
    seq_length=latent_dim,
    timesteps=1000,
    objective='pred_v'
).to(device)

# Second name for the pre-trained classifier, moved to the selected device.
classifier = classf.to(device)


Using device: cuda


sampling loop time step: 100%|██████████| 1000/1000 [00:13<00:00, 73.17it/s]

Prediction: 0.0





tensor([1.0000, 0.7466, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000,
        0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
        1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 0.0000,
        0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000,
        1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000,
        0.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000,
        0.0000, 1.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 0.0000,
        1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000,
        0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000,
        1.0000, 0.0033, 1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000,
        1.0000, 0.4705, 0.0000, 1.0000, 

In [29]:
# Start training: one epoch over the full latent dataloader.
train_diffusion(diffusion_model, dataloader=data_loader, batch_size=train_batch_size, epochs=1)


Epoch 0:   0%|          | 1/642171 [00:00<107:29:56,  1.66it/s]

Epoch 0, Batch: 0: Batch Loss: 0.8120216131210327


Epoch 0:   0%|          | 72/642171 [00:06<16:37:05, 10.73it/s]


KeyboardInterrupt: 