# Protein Structure Generation with Diffusion Models (DDPM)

This notebook implements a **Denoising Diffusion Probabilistic Model (DDPM)** to generate protein contact maps.

**Key Components**:
1.  **Robust Data Pipeline**: Uses the same PDB fetching and contact map generation as our VAE-GAN work.
2.  **U-Net Architecture**: A convolutional U-Net with skip connections and sinusoidal time embeddings that predicts the added noise at each timestep.
3.  **Diffusion Process**: Forward process (adding noise) and reverse process (learning to denoise).

**Hardware**: Intended for **25GB+ VRAM** (e.g., 2x T4 GPUs, as offered on Kaggle); `nn.DataParallel` is used automatically when more than one GPU is visible.

In [None]:
# --- 1. Setup & Dependencies ---
!pip install biopython

import os
import time
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.spatial.distance import pdist, squareform

try:
    from Bio.PDB import PDBList, PDBParser
except ImportError:
    print("Biopython not correctly installed. Please rerun the cell.")

In [None]:
# --- 2. Data Preparation (Identical to VAE-GAN) ---

def fetch_pdb_ids(max_results=1000, min_length=60, max_length=200,
                  max_resolution=3.0, timeout=30):
    """
    Fetches a list of PDB IDs for protein structures from the RCSB PDB Search API.

    Filters: protein-only entries, combined resolution below ``max_resolution``
    (default 3.0 A), and a polymer-entity sample sequence length in the RCSB
    "range" ``min_length``..``max_length`` (default 60-200 residues).

    Args:
        max_results: maximum number of IDs to return. Only this many hits are
            requested from the server (pagination) instead of every match.
        min_length, max_length: sequence-length range for the filter.
        max_resolution: resolution cutoff in Angstrom (strict "less").
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        Up to ``max_results`` PDB IDs. An empty list on a non-200 response;
        a small hard-coded fallback list if the request itself fails.
    """
    import requests

    if max_results <= 0:
        return []

    print(f"Fetching list of up to {max_results} PDB IDs from RCSB...")

    def _terminal(attribute, operator, value):
        # One attribute-filter node of the RCSB search query.
        return {
            "type": "terminal",
            "service": "text",
            "parameters": {
                "attribute": attribute,
                "operator": operator,
                "value": value,
            },
        }

    query = {
        "query": {
            "type": "group",
            "logical_operator": "and",
            "nodes": [
                _terminal("rcsb_entry_info.selected_polymer_entity_types",
                          "exact_match", "Protein (only)"),
                _terminal("rcsb_entry_info.resolution_combined",
                          "less", max_resolution),
                _terminal("entity_poly.rcsb_sample_sequence_length",
                          "range", {"from": min_length, "to": max_length}),
            ],
        },
        "request_options": {
            # Ask only for the hits we will keep rather than every match.
            "paginate": {"start": 0, "rows": max_results},
        },
        "return_type": "entry",
    }

    url = "https://search.rcsb.org/rcsbsearch/v2/query"
    try:
        # Without a timeout a stalled connection would block the notebook forever.
        response = requests.post(url, json=query, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to query RCSB: {response.status_code}")
            return []
        data = response.json()
        ids = []
        for item in data.get("result_set", []):
            if isinstance(item, dict) and 'identifier' in item:
                ids.append(item['identifier'])
            elif isinstance(item, str):
                ids.append(item)
        # With pagination, total_count reports the full number of matches.
        total = data.get("total_count", len(ids))
        print(f"Found {total} potential structures.")
        return ids[:max_results]
    except Exception as e:
        print(f"Error checking RCSB: {e}")
        return ['1AIE', '1B7G', '1D0D', '6VSB'][:max_results]  # Fallback

def download_pdb_data(pdb_ids, download_dir="pdb_data"):
    """
    Downloads PDB-format files for ``pdb_ids`` into ``download_dir``.

    Each file is stored as ``<ID>.pdb``. IDs already present locally (as
    ``<ID>.pdb``, ``<ID>.ent`` or ``pdb<id>.ent``) are skipped. Failures do
    not abort the loop; they are collected, reported and returned.

    Args:
        pdb_ids: iterable of PDB identifiers.
        download_dir: target directory (created if missing).

    Returns:
        List of IDs that could not be downloaded.
    """
    os.makedirs(download_dir, exist_ok=True)
    pdbl = PDBList()
    failed = []

    print(f"Downloading {len(pdb_ids)} proteins to {download_dir}...")
    for pdb_id in pdb_ids:
        final_path = os.path.join(download_dir, f"{pdb_id}.pdb")
        ent_path_upper = os.path.join(download_dir, f"{pdb_id}.ent")
        ent_path_lower = os.path.join(download_dir, f"pdb{pdb_id.lower()}.ent")

        if any(os.path.exists(p) for p in (final_path, ent_path_upper, ent_path_lower)):
            continue

        try:
            pdbl.retrieve_pdb_file(pdb_id, pdir=download_dir, file_format="pdb")
            # The file is expected as pdb<id>.ent; rename it to <ID>.pdb.
            downloaded = os.path.exists(ent_path_lower)
            if downloaded:
                os.rename(ent_path_lower, final_path)
        except Exception as exc:
            # Best effort: one bad entry must not stop the whole download.
            print(f"  Failed to download {pdb_id}: {exc}")
            downloaded = False

        if not downloaded:
            failed.append(pdb_id)

    if failed:
        print(f"Could not download {len(failed)} of {len(pdb_ids)} entries.")
    print("Download complete.")
    return failed

def get_ca_coordinates(pdb_file):
    """
    Extracts alpha-carbon (CA) coordinates from the first model of a PDB file.

    Only amino-acid residues (standard or modified, per Biopython's ``is_aa``)
    contribute. Checking for an atom named 'CA' alone is not enough: calcium
    ions are stored as hetero residues whose atom is also named 'CA', and
    they would otherwise be inserted into the chain as fake residues.
    Residues from all chains of the first model are concatenated in file order.

    Returns:
        np.ndarray of shape (n_residues, 3), or an empty array if the file
        cannot be parsed or contains no amino-acid CA atoms.
    """
    from Bio.PDB.Polypeptide import is_aa

    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', pdb_file)
    except Exception:
        return np.array([])

    first_model = next(iter(structure), None)
    if first_model is None:
        return np.array([])

    ca_coords = []
    for chain in first_model:
        for residue in chain:
            if 'CA' in residue and is_aa(residue, standard=False):
                ca_coords.append(residue['CA'].get_coord())
    return np.array(ca_coords)

def get_contact_map(coords, threshold=8.0, size=64):
    """
    Generates a binary contact map from coordinates, cropped/zero-padded to size x size.

    Residues i and j are in contact when their distance is strictly below
    ``threshold`` (same units as ``coords``). Only the first ``size``
    residues are used; shorter inputs are zero-padded ("no contact").

    Args:
        coords: array-like of shape (n_residues, 3).
        threshold: contact distance cutoff.
        size: side length of the returned map.

    Returns:
        float32 tensor of shape (1, size, size) with values in {0, 1}.
        Inputs with fewer than 10 residues give an all-zero map. Scaling to
        [-1, 1] for diffusion is left to the caller (the dataset).
    """
    if len(coords) < 10:
        return torch.zeros((1, size, size))

    # Residues beyond the crop never reach the output, so crop before
    # computing pairwise distances instead of building the full n x n matrix.
    cropped = np.asarray(coords)[:size]
    dist_matrix = squareform(pdist(cropped))
    contact_map = (dist_matrix < threshold).astype(np.float32)

    result = np.zeros((size, size), dtype=np.float32)
    n = contact_map.shape[0]
    result[:n, :n] = contact_map
    return torch.from_numpy(result).unsqueeze(0)

class PDBContactMapDataset(Dataset):
    """
    Contact-map dataset built from the PDB files found in one directory.

    Each item is a float32 tensor of shape (1, size, size): the binary CA
    contact map from ``get_contact_map`` rescaled from {0, 1} to {-1, 1},
    the value range the diffusion model is trained in.

    Args:
        pdb_dir: directory scanned (non-recursively) for ``.pdb``/``.ent`` files.
        size: side length of every contact map.
        cache: keep computed maps in memory (default True) so each PDB file
            is parsed once rather than on every epoch.
    """

    def __init__(self, pdb_dir, size=64, cache=True):
        # Sorted so that item indices are stable across runs and filesystems.
        self.pdb_files = sorted(
            os.path.join(pdb_dir, f) for f in os.listdir(pdb_dir)
            if f.endswith(('.pdb', '.ent'))
        )
        self.size = size
        self._cache = {} if cache else None

    def __len__(self):
        return len(self.pdb_files)

    def __getitem__(self, idx):
        # Resolve the path outside the try block so an out-of-range index
        # raises IndexError instead of being swallowed (otherwise plain
        # iteration over the dataset would never terminate).
        path = self.pdb_files[idx]
        if self._cache is not None and path in self._cache:
            return self._cache[path].clone()

        try:
            coords = get_ca_coordinates(path)
            # Scale {0, 1} -> {-1, 1} for diffusion.
            item = get_contact_map(coords, size=self.size) * 2.0 - 1.0
        except Exception:
            # Best effort: an unreadable structure becomes an all "no contact"
            # map, which is -1 in the scaled space (same as an all-zero
            # contact map after scaling).
            item = torch.full((1, self.size, self.size), -1.0)

        if self._cache is None:
            return item
        self._cache[path] = item
        # Hand out a copy so callers cannot mutate the cached tensor.
        return item.clone()

In [None]:
# --- 3. Diffusion Model Architecture (U-Net) ---
# Based on standard DDPM implementations

class SinusoidalPositionEmbeddings(nn.Module):
    """
    Transformer-style sinusoidal embedding of diffusion timesteps.

    Maps a 1-D tensor of B timesteps to a (B, dim) tensor: the first
    ``dim // 2`` columns are sines and the next ``dim // 2`` cosines of the
    timestep times geometrically spaced frequencies (1 down to 1/10000 when
    dim >= 4). For odd ``dim`` a zero column is appended so the output width
    always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when half_dim == 1 (dim 2 or 3).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Pad the missing last column so downstream Linear(dim, ...) fits.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class Block(nn.Module):
    """
    One U-Net stage: conv -> (+ time embedding) -> conv -> resample.

    A down block (``up=False``) ends in a stride-2 ``Conv2d`` that halves H
    and W. An up block (``up=True``) expects its input to already contain the
    skip connection concatenated on the channel axis, so its first conv takes
    ``2 * in_ch`` channels, and it ends in a ``ConvTranspose2d`` that doubles
    H and W. In both convs the activation is applied before batch norm.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Keep this creation order and these attribute names: they fix the
        # state_dict keys and the order in which weights are randomly initialised.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        conv1_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(conv1_in, out_ch, 3, padding=1)
        resample = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resample(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply the stage to ``x`` conditioned on the time embedding ``t`` of shape (B, time_emb_dim)."""
        hidden = self.bnorm1(self.relu(self.conv1(x)))
        # (B, out_ch) -> (B, out_ch, 1, 1): a per-channel bias broadcast over H and W.
        time_bias = self.relu(self.time_mlp(t))[..., None, None]
        hidden = self.bnorm2(self.relu(self.conv2(hidden + time_bias)))
        return self.transform(hidden)

class SimpleUnet(nn.Module):
    """
    A simplified U-Net for Diffusion: predicts the noise in x_t at timestep t.

    Layout: a 3x3 projection to ``down_channels[0]``; one down ``Block`` per
    consecutive pair of ``down_channels`` (each halves H and W); one up
    ``Block`` per consecutive pair of ``up_channels`` (each doubles H and W)
    that consumes the matching encoder output as a skip connection; a final
    1x1 convolution to ``out_dim`` channels. Input H and W must therefore be
    divisible by ``2 ** (len(down_channels) - 1)`` (16 with the defaults).

    All arguments are keyword-only and optional; the defaults reproduce the
    original fixed architecture and its state_dict keys.

    Args:
        image_channels: channels of the input image.
        down_channels: encoder widths (non-empty).
        up_channels: decoder widths; must be ``down_channels`` reversed so
            every skip connection has matching channels. Defaults to that.
        out_dim: channels of the prediction.
        time_emb_dim: width of the sinusoidal timestep embedding.

    Raises:
        ValueError: if the channel configuration cannot form valid skips.
    """

    def __init__(self, *, image_channels=1, down_channels=(64, 128, 256, 512, 1024),
                 up_channels=None, out_dim=1, time_emb_dim=32):
        super().__init__()
        down_channels = tuple(down_channels)
        if not down_channels:
            raise ValueError("down_channels must not be empty")
        up_channels = down_channels[::-1] if up_channels is None else tuple(up_channels)
        if up_channels != down_channels[::-1]:
            # Up block i concatenates its input with encoder output -(i+1), so
            # its in_ch must equal that encoder width.
            raise ValueError("up_channels must be down_channels in reverse order")

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])

        # Upsample
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        """
        Args:
            x: noisy images, (B, image_channels, H, W).
            timestep: (B,) integer timesteps.

        Returns:
            (B, out_dim, H, W) noise prediction.
        """
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        residuals = []
        for down in self.downs:
            x = down(x, t)
            residuals.append(x)
        # The last stored residual is the current x, so the first up step
        # concatenates the bottleneck output with itself.
        for up in self.ups:
            x = torch.cat((x, residuals.pop()), dim=1)
            x = up(x, t)
        return self.output(x)


In [None]:
# --- 4. Models & Diffusion Logic ---

# Hyperparameters
# Number of diffusion timesteps (kept small so a dry run stays fast).
T = 300
beta_start, beta_end = 0.0001, 0.02
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Linear beta schedule plus the derived per-timestep quantities used by the
# forward (noising) and reverse (sampling) processes; all live on `device`.
betas = torch.linspace(beta_start, beta_end, T).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t
# alpha_bar_{t-1}, with the value before the first step defined as 1.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

def forward_diffusion_sample(x_0, t, device="cuda"):
    """
    Takes a clean batch and per-sample timesteps and returns the noisy version at t.

    Uses the closed form of the forward process:
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)

    Args:
        x_0: clean batch of shape (B, ...) with any number of trailing dims.
        t: integer tensor of shape (B,) with timesteps in [0, T).
        device: unused; kept so existing callers that pass it keep working.
            The computation happens on ``x_0.device``.

    Returns:
        (x_t, eps), both shaped like ``x_0``.
    """
    noise = torch.randn_like(x_0)
    # Index on the schedule's device, then move the per-sample coefficients to
    # the data's device and reshape to (B, 1, ..., 1) so they broadcast.
    idx = t.to(sqrt_alphas_cumprod.device)
    coeff_shape = (-1,) + (1,) * (x_0.dim() - 1)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[idx].to(x_0.device).reshape(coeff_shape)
    sqrt_one_minus_alphas_cumprod_t = (
        sqrt_one_minus_alphas_cumprod[idx].to(x_0.device).reshape(coeff_shape)
    )

    # Reparameterization trick: mean + std * noise
    return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise, noise

@torch.no_grad()
def sample_plot_image(model, img_size=64, device=None):
    """
    Generates one contact map by reverse diffusion and plots it.

    Starts from x_T ~ N(0, I) and applies the DDPM ancestral update with
    sigma_t^2 = beta_t for i = T-1, ..., 0, using the model's noise
    prediction. The result is clamped to [-1, 1] and mapped to [0, 1] for
    display.

    Args:
        model: noise-prediction network called as ``model(x_t, t)``.
        img_size: side length of the generated map.
        device: device to sample on; ``None`` (default) uses the device of the
            model's parameters, so the function also works on CPU-only hosts.
    """
    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    try:
        img = torch.randn((1, 1, img_size, img_size), device=device)
        for i in reversed(range(T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img_pred_noise = model(img, t)

            # Plain floats, so the schedule's device need not match `device`.
            alpha = alphas[i].item()
            alpha_hat = alphas_cumprod[i].item()
            beta = betas[i].item()

            # DDPM adds fresh noise on every step except the last one, which
            # produces x_0 (t = 1 in the paper's 1-based indexing, i == 0 here).
            if i > 0:
                noise = torch.randn_like(img)
            else:
                noise = torch.zeros_like(img)

            mean = (img - (1 - alpha) / math.sqrt(1 - alpha_hat) * img_pred_noise) / math.sqrt(alpha)
            img = mean + math.sqrt(beta) * noise
    finally:
        # Restore whatever mode the caller had the model in, even on error.
        model.train(was_training)

    # Clip to -1, 1
    img = torch.clamp(img, -1.0, 1.0)
    # Convert to 0, 1 for plot
    img = (img + 1) / 2

    plt.imshow(img.cpu().squeeze().numpy(), cmap='binary')
    plt.title("Generated Protein Contact Map")
    plt.show()

def get_loss(model, x_0, t):
    """
    Noise-prediction training loss for a batch at timesteps ``t``.

    Corrupts ``x_0`` with ``forward_diffusion_sample``, asks the model to
    predict the injected noise and returns the mean absolute error between
    the true and predicted noise (L1 is used here rather than MSE).
    """
    noisy_batch, true_noise = forward_diffusion_sample(x_0, t, device)
    predicted_noise = model(noisy_batch, t)
    return F.l1_loss(true_noise, predicted_noise)

In [None]:
# --- 5. Training Loop ---

def train_diffusion(model, loader, epochs=100, lr=1e-4, sample_every=10):
    """
    Trains a noise-prediction model with the DDPM objective.

    Every batch gets independent uniformly random timesteps in [0, T) and the
    loss from ``get_loss``. Every ``sample_every`` epochs one sample is
    generated and plotted (pass 0 or None to disable).

    Args:
        model: the U-Net (optionally wrapped in ``nn.DataParallel``).
        loader: DataLoader expected to yield (B, 1, H, W) batches in [-1, 1].
        epochs: number of passes over ``loader``.
        lr: Adam learning rate.
        sample_every: sampling/plotting interval in epochs.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``loader`` yields no batches (which previously crashed
            with ZeroDivisionError when averaging the epoch loss).
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError(
            "loader yields no batches (dataset smaller than batch_size with drop_last=True?)"
        )

    history = []
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            # Sample random timesteps
            t = torch.randint(0, T, (batch.shape[0],), device=device)

            loss = get_loss(model, batch, t)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        mean_loss = epoch_loss / num_batches
        history.append(mean_loss)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {mean_loss:.4f}")
        if sample_every and (epoch + 1) % sample_every == 0:
            # Pass the device explicitly: the sampler's own default may not
            # match where the model lives (e.g. CPU-only hosts).
            sample_plot_image(model, img_size=batch.shape[-1], device=device)
    return history

In [None]:
# --- 6. Execution ---

# 1. Fetch & Download Data
pdb_ids = fetch_pdb_ids(max_results=1000)
if not pdb_ids:
    pdb_ids = ['1AIE', '1B7G', '1D0D', '6VSB']
download_pdb_data(pdb_ids)

# 2. Loader
BATCH_SIZE = 32
dataset = PDBContactMapDataset("pdb_data")
if len(dataset) == 0:
    raise RuntimeError("No PDB files found in 'pdb_data'; nothing to train on.")
# drop_last keeps batch sizes constant, but with fewer items than one batch
# (e.g. the 4-entry fallback list) it would leave the loader empty, so only
# enable it when at least one full batch exists.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,
                    drop_last=len(dataset) >= BATCH_SIZE)
print(f"Loaded {len(dataset)} items.")

# 3. Model
model = SimpleUnet()
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)
model.to(device)

# 4. Train: training is started in the next cell.

In [None]:
NUM_EPOCHS = 200  # full training run; train_diffusion also plots periodic samples
train_diffusion(model, loader, epochs=NUM_EPOCHS)

In [None]:
# --- 7. Evaluation, Visualization & Storage ---

# 1. Save the Model
print("Saving model to 'ddpm_protein_model.pth'...")
# Unwrap nn.DataParallel so the saved state_dict keys carry no "module."
# prefix and can be loaded straight into a plain SimpleUnet.
model_to_save = model.module if isinstance(model, nn.DataParallel) else model
torch.save(model_to_save.state_dict(), "ddpm_protein_model.pth")
print("Model saved successfully.")

# 2. Evaluation Helper Function
@torch.no_grad()
def generate_and_compare(model, dataset, num_samples=5, device=None):
    """
    Plots random real contact maps next to maps generated by reverse diffusion.

    Top row: ``num_samples`` random items from ``dataset``; bottom row: the
    same number of maps generated from pure noise with the DDPM ancestral
    sampler (sigma_t^2 = beta_t). Both rows are mapped from [-1, 1] to
    [0, 1] for display.

    Args:
        model: noise-prediction network called as ``model(x_t, t)``.
        dataset: indexable dataset whose items are (C, H, W) tensors,
            assumed to be in [-1, 1] (as PDBContactMapDataset produces).
        num_samples: number of columns; capped at ``len(dataset)``.
        device: device to sample on; ``None`` (default) uses the device of
            the model's parameters.
    """
    if device is None:
        device = next(model.parameters()).device
    # np.random.choice(..., replace=False) cannot draw more items than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("Nothing to compare (empty dataset or num_samples <= 0).")
        return

    was_training = model.training
    model.eval()
    try:
        # --- A. Get Real Samples ---
        idxs = np.random.choice(len(dataset), num_samples, replace=False)
        real_samples = torch.stack([dataset[i] for i in idxs]).to(device)

        # --- B. Generate Fake Samples (Reverse Diffusion) ---
        print(f"Generating {num_samples} samples from noise...")
        # Match the real data's (C, H, W) instead of assuming 1 x 64 x 64.
        img = torch.randn((num_samples, *real_samples.shape[1:]), device=device)

        for i in reversed(range(T)):
            t = torch.full((num_samples,), i, device=device, dtype=torch.long)
            img_pred_noise = model(img, t)

            alpha = alphas[i].item()
            alpha_hat = alphas_cumprod[i].item()
            beta = betas[i].item()

            # Fresh noise on every step except the final one (i == 0 -> x_0).
            if i > 0:
                noise = torch.randn_like(img)
            else:
                noise = torch.zeros_like(img)

            mean = (img - (1 - alpha) / math.sqrt(1 - alpha_hat) * img_pred_noise) / math.sqrt(alpha)
            img = mean + math.sqrt(beta) * noise
    finally:
        # Restore whatever mode the caller had the model in, even on error.
        model.train(was_training)

    # Scale back to [0, 1] for plotting
    fake_samples = (img.clamp(-1, 1) + 1) / 2
    real_samples = (real_samples + 1) / 2

    # --- C. Plot Comparison ---
    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    _, axes = plt.subplots(2, num_samples, figsize=(3 * num_samples, 6), squeeze=False)
    for col in range(num_samples):
        axes[0, col].imshow(real_samples[col].cpu().squeeze(), cmap='binary')
        axes[0, col].axis('off')
        axes[1, col].imshow(fake_samples[col].cpu().squeeze(), cmap='binary')
        axes[1, col].axis('off')
    axes[0, 0].set_title("Real Data (Ground Truth)", fontsize=14, pad=10)
    axes[1, 0].set_title("Generated (Diffusion)", fontsize=14, pad=10)

    plt.tight_layout()
    plt.show()

# 3. Run Comparison: 5 real vs. 5 generated maps, sampled on the training device
generate_and_compare(model, dataset, device=device, num_samples=5)