In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
from tqdm import tqdm

from peptidevae.load_vae import load_vae, vae_decode, vae_forward
from importlib import reload



In [2]:
# Single place to repoint the notebook when the project lives elsewhere; every
# path below is derived from it, and the resulting strings are unchanged.
PROJECT_DIR = "/Users/aldenrose/DiffusionProject"

small_csv = f"{PROJECT_DIR}/small_data.csv"
big_csv = f"{PROJECT_DIR}/combined_data.csv"
VAE_PKL_LOCATION = (
    f"{PROJECT_DIR}/peptidevae/checkpoints/"
    "dim128_k1_kl0001_eff256_dff256_pious-sea-2_model_state_epoch_118.pkl"
)

In [3]:
# NOTE(review): not referenced by any other cell visible in this notebook; the
# U-Net cell further down is built with dim=seq_length (256) instead — confirm
# which width is intended.
UNET_DIM = 128

In [4]:
# Load the VAE, and the `dataobj` helper returned alongside it, from the
# checkpoint path configured above.
# NOTE(review): the checkpoint filename contains "dim128" while dim=256 is
# passed here — confirm the two are meant to differ.
vae, dataobj = load_vae(VAE_PKL_LOCATION, dim=256, max_string_length=50)

In [61]:
# Reload the module *before* binding the class name. A name obtained with
# `from data import ...` keeps pointing at the pre-reload class object, so
# importing it first would always build the dataset from stale code after
# data.py has been edited.
import data
reload(data)
from data import PeptideLatentDataset

ds = PeptideLatentDataset(small_csv, vae, dataobj)

Encoding peptides: 100%|██████████| 1/1 [00:02<00:00,  2.04s/it]


In [1]:
# Debugging shortcut (per the original comment): read latents and labels from
# CSV files in the working directory.
# NOTE(review): the recorded output of this cell is a NameError because `torch`
# had not been imported in the kernel session where it ran (In [1]). Both
# `latents` and `labels` are also reassigned by later cells.
latents = torch.tensor(pd.read_csv('latents.csv').to_numpy(dtype='float32'))
labels = torch.tensor(pd.read_csv('labels.csv').to_numpy(dtype='int8').squeeze(-1))

NameError: name 'torch' is not defined

In [5]:
# NOTE(review): this relative path is resolved against the kernel's working
# directory. `big_csv` in the config cell names a combined_data.csv under the
# project directory — presumably the same file; confirm and reuse it.
df = pd.read_csv('combined_data.csv')

In [6]:
# Extract the two columns used downstream as NumPy arrays in one pass.
seq, labels = (df[column].to_numpy() for column in ('sequence', 'extinct'))

In [9]:
# Run the VAE forward pass on the first 1000 sequences with gradient tracking
# disabled; keeps both the latents and the VAE loss it returns.
# NOTE(review): the 1000 here must stay in sync with `labels[:1000]` in the
# classifier-validation cell.
with torch.no_grad():
    latents, vae_loss = vae_forward(seq[:1000], dataobj, vae)

In [12]:
class EsmClassificationHead(nn.Module):
    """MLP classification head, slightly modified from the original ESM head.

    Three SiLU-activated linear layers followed by a linear projection to
    class logits. Dropout is applied to the input of every linear layer.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the three hidden layers.
        num_classes: Number of output logits.
        dropout: Dropout probability; the default of 0 makes dropout a no-op.
    """

    def __init__(self, input_dim=256, hidden_dim=2048, num_classes=2, dropout=0.0):
        super().__init__()
        # Attribute names and creation order are kept as they were so existing
        # state_dict checkpoints still load and seeded initialisation matches.
        self.dense = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for ``x`` (last dim = ``input_dim``)."""
        # Each hidden stage is dropout -> linear -> SiLU.
        for layer in (self.dense, self.dense2, self.dense3):
            x = F.silu(layer(self.dropout(x)))
        # Final dropout, then projection to logits (no activation).
        return self.out_proj(self.dropout(x))

In [33]:
def load_model(model_path='best_model.pt'):
    """Load a trained ``EsmClassificationHead`` from a saved ``state_dict``.

    Args:
        model_path: Path to a checkpoint whose content can be passed straight
            to ``load_state_dict``.

    Returns:
        ``(model, device)``: the model in eval mode on ``device``, where
        ``device`` is CUDA when available and CPU otherwise.
    """
    # Check for GPU availability
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Size the head from the checkpoint itself: nn.Linear stores its weight as
    # (out_features, in_features). Fall back to 256 when the key is absent so
    # load_state_dict below still reports the mismatch.
    first_weight = checkpoint.get('dense.weight') if isinstance(checkpoint, dict) else None
    input_dim = first_weight.shape[1] if first_weight is not None else 256

    model = EsmClassificationHead(input_dim=input_dim).to(device)
    model.load_state_dict(checkpoint)
    model.eval()

    return model, device

def predict_classifer(model, embeddings, device, batch_size=100):
    """Run the classifier over ``embeddings`` in mini-batches.

    Args:
        model: Callable mapping a float32 batch to logits with the class
            scores along dim 1.
        embeddings: Tensor, NumPy array or nested sequence with one row per
            item; any numeric dtype is converted to float32.
        device: Device each batch is moved to before the forward pass.
        batch_size: Number of rows per forward pass.

    Returns:
        Tuple of NumPy arrays ``(predictions, logits, probabilities)``: the
        argmax class index per row, the raw logits, and the softmax of the
        logits over dim 1. All three are empty arrays for empty input.
    """
    # as_tensor accepts tensors of any dtype/device as well as arrays and
    # lists, unlike the legacy torch.FloatTensor(...) constructor.
    embeddings = torch.as_tensor(embeddings, dtype=torch.float32)

    if len(embeddings) == 0:
        # Nothing to classify; torch.cat below would reject an empty list.
        return np.array([]), np.array([]), np.array([])

    pred_chunks, logit_chunks, prob_chunks = [], [], []

    with torch.inference_mode():
        for start in range(0, len(embeddings), batch_size):
            batch = embeddings[start:start + batch_size].to(device)
            logits = model(batch)

            logit_chunks.append(logits.cpu())
            prob_chunks.append(torch.softmax(logits, dim=1).cpu())
            pred_chunks.append(torch.argmax(logits, dim=1).cpu())

        # One concatenation per output instead of extending Python lists
        # row by row.
        predictions = torch.cat(pred_chunks).numpy()
        all_logits = torch.cat(logit_chunks).numpy()
        all_probs = torch.cat(prob_chunks).numpy()

    return predictions, all_logits, all_probs

In [36]:
classifier_model, device = load_model("/Users/aldenrose/DiffusionProject/classifer/best_model.pt")
embeddings = latents

# Get predictions
predictions, logits, probabilities = predict_classifer(classifier_model, embeddings, device)

# Load labels for validation. Slice by the number of predictions rather than a
# hard-coded count so the two stay aligned if the encoded subset changes size.
_labels = np.asarray(labels[:len(predictions)])

# A length mismatch would otherwise make `==` misbehave instead of comparing
# element-wise, so fail loudly here.
assert _labels.shape == predictions.shape, (
    f"label/prediction shape mismatch: {_labels.shape} vs {predictions.shape}"
)

accuracy = np.mean(predictions == _labels)

print(f"Accuracy: {accuracy:.2f}")

Using device: cpu
Accuracy: 0.97


In [22]:
from guided_diffusion.guided_diffusion_1d import Unet1D, GaussianDiffusion1D

seq_length = 256
# One constant for both the dummy timestep draw and the diffusion schedule;
# previously the draw used 1000 while the schedule used 500.
NUM_TIMESTEPS = 500

# NOTE(review): `dim` is set to the sequence length here, while UNET_DIM (128)
# is defined earlier and never used — confirm which value is intended.
model = Unet1D(dim=seq_length, channels=1, self_condition=False)

# input dummy tensors
batch_size = 1
channels = 1    # same as model channels
dummy_input = torch.randn(batch_size, channels, seq_length)

# timesteps in [0, NUM_TIMESTEPS), matching the diffusion model built below
dummy_time = torch.randint(0, NUM_TIMESTEPS, (batch_size,)).float()

# check pass works
output = model(dummy_input, dummy_time)
print("UNet output shape:", output.shape)

diffusion_model = GaussianDiffusion1D(
    model=model,
    seq_length=seq_length,
    timesteps=NUM_TIMESTEPS,
)

# Calling the diffusion wrapper on a batch returns a scalar loss.
loss = diffusion_model(dummy_input)
print("Loss:", loss.item())

# Disabled classifier-guidance experiment, kept for reference.
# NOTE(review): torch.randint(0, 1, ...) has an exclusive upper bound, so it
# can only ever produce label 0.
# dummy_classifier = Classifier(seq_length=seq_length, num_classes=2)
# dummy_labels = torch.randint(0, 1, (batch_size,))
# grad = classifier_cond_fn(dummy_input, dummy_time, dummy_classifier, dummy_labels, classifier_scale=1)


UNet output shape: torch.Size([1, 1, 256])
Loss: 1.0439468622207642


In [26]:
# Draw 4 samples from the diffusion model; the recorded run stepped through
# 500 timesteps in about 1m18s.
# NOTE(review): no training step is visible between constructing the model and
# this call, so these are presumably samples from an untrained network.
sampled = diffusion_model.sample(batch_size=4)

sampling loop time step: 100%|██████████| 500/500 [01:18<00:00,  6.38it/s]


In [42]:
# Flatten the samples into rows of 256 values, the input width the classifier
# is built with in load_model (input_dim=256).
# NOTE(review): presumably this drops a singleton channel axis — confirm the
# shape returned by sample().
sampled = sampled.reshape(-1, 256)

In [45]:
# Classify the generated samples. This rebinds `predictions`, `logits` and
# `probabilities`, replacing the values from the validation cell.
predictions, logits, probabilities = predict_classifer(classifier_model, sampled, device)

In [50]:
# Predicted class index per generated sample; the recorded run shows class 0
# for all four.
print(predictions)

[0 0 0 0]
