In [None]:
!pip install biopython

In [None]:
# --- 1. Setup & Dependencies ---
!pip install biopython

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.spectral_norm as spectral_norm
from torch.utils.data import Dataset, DataLoader
from scipy.spatial.distance import pdist, squareform
from skimage.metrics import structural_similarity as ssim

try:
    from Bio.PDB import PDBList, PDBParser
except ImportError:
    print("Biopython not correctly installed. Please rerun the cell.")

In [None]:
# --- 2. Data Preparation ---

def fetch_pdb_ids(max_results=1000, timeout=60.0):
    """
    Fetches a list of PDB IDs for protein structures from the RCSB PDB search API.

    The query combines three filters with a logical AND: protein-only
    entries, combined resolution below 3.0, and a sample sequence length in
    the range 60-200 residues.

    Args:
        max_results: Maximum number of identifiers to return.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        list: Up to ``max_results`` PDB identifiers. An empty list is
        returned when the service answers with a non-200 status; a small
        hard-coded fallback list is returned when the request or the
        response parsing raises.
    """
    # Imported lazily so the rest of the notebook still loads without requests.
    import requests

    print(f"Fetching list of up to {max_results} PDB IDs from RCSB...")

    query = {
        "query": {
            "type": "group",
            "logical_operator": "and",
            "nodes": [
                {
                    "type": "terminal",
                    "service": "text",
                    "parameters": {
                        "attribute": "rcsb_entry_info.selected_polymer_entity_types",
                        "operator": "exact_match",
                        "value": "Protein (only)"
                    }
                },
                {
                    "type": "terminal",
                    "service": "text",
                    "parameters": {
                        "attribute": "rcsb_entry_info.resolution_combined",
                        "operator": "less",
                        "value": 3.0
                    }
                },
                {
                    "type": "terminal",
                    "service": "text",
                    "parameters": {
                        "attribute": "entity_poly.rcsb_sample_sequence_length",
                        "operator": "range",
                        "value": {"from": 60, "to": 200}
                    }
                }
            ]
        },
        "request_options": {
            "return_all_hits": True
        },
        "return_type": "entry"
    }

    url = "https://search.rcsb.org/rcsbsearch/v2/query"
    try:
        # Without a timeout a stalled connection would hang the cell forever.
        response = requests.post(url, json=query, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to query RCSB: {response.status_code}")
            return []

        result_set = response.json().get("result_set", [])

        # Hits may be dictionaries carrying an 'identifier' key or bare values;
        # accept both and drop anything empty/None.
        ids = [
            item.get('identifier') if isinstance(item, dict) else item
            for item in result_set
        ]
        ids = [x for x in ids if x]

        print(f"Found {len(ids)} potential structures.")
        return ids[:max_results]
    except Exception as e:
        # Deliberately best-effort: any network/parsing problem falls back to
        # a tiny built-in list so the notebook can still run end to end.
        print(f"Error checking RCSB: {e}")
        return ['1AIE', '1B7G', '1D0D', '6VSB'] # Fallback

def download_pdb_data(pdb_ids, download_dir="pdb_data"):
    """
    Downloads PDB files for the given IDs into ``download_dir``.

    IDs that already have a file in the directory (``<ID>.pdb``, ``<ID>.ent``
    or ``pdb<id>.ent``) are skipped. A freshly downloaded ``pdb<id>.ent`` file
    is renamed to ``<ID>.pdb``. Downloading is best-effort: a failing ID never
    aborts the loop, but failures are collected and reported at the end
    instead of being dropped silently.

    Args:
        pdb_ids: Iterable (with ``len``) of PDB identifiers.
        download_dir: Target directory; created if it does not exist.
    """
    os.makedirs(download_dir, exist_ok=True)
    pdbl = PDBList()

    print(f"Downloading {len(pdb_ids)} proteins to {download_dir}...")
    failed = []
    for pdb_id in pdb_ids:
        # Check if already exists (as .pdb or .ent)
        final_path = os.path.join(download_dir, f"{pdb_id}.pdb")
        ent_path_upper = os.path.join(download_dir, f"{pdb_id}.ent")
        ent_path_lower = os.path.join(download_dir, f"pdb{pdb_id.lower()}.ent")
        candidates = (final_path, ent_path_upper, ent_path_lower)

        if any(os.path.exists(path) for path in candidates):
            continue

        try:
            pdbl.retrieve_pdb_file(pdb_id, pdir=download_dir, file_format="pdb")
            # Normalise the downloaded .ent name to <ID>.pdb when present.
            if os.path.exists(ent_path_lower):
                os.rename(ent_path_lower, final_path)
        except Exception as exc:
            failed.append(pdb_id)
            print(f"Warning: could not download {pdb_id}: {exc}")
            continue

        # The retrieval call may return without raising even though no file
        # was written; treat a missing file as a failure too.
        if not any(os.path.exists(path) for path in candidates):
            failed.append(pdb_id)

    print("Download complete.")
    if failed:
        print(f"Warning: {len(failed)} of {len(pdb_ids)} structures were not downloaded. First 5: {failed[:5]}")
    # Debug: Check file count
    files = os.listdir(download_dir)
    print(f"DEBUG: {len(files)} files found in {download_dir}. First 5: {files[:5]}")

def get_ca_coordinates(pdb_file):
    """
    Extracts alpha-carbon (CA) coordinates from the first model of a PDB file.

    Only amino-acid residues are considered. A plain ``'CA' in residue`` test
    also matches calcium ions, whose single atom is named ``CA`` as well, so
    residues are first filtered with Biopython's ``is_aa`` (non-standard
    amino acids are kept via ``standard=False``).

    Args:
        pdb_file: Path to the structure file.

    Returns:
        numpy.ndarray: One coordinate row per residue that has a CA atom,
        across all chains of the first model. An empty array is returned
        when the file cannot be parsed or contains no model.
    """
    from Bio.PDB.Polypeptide import is_aa

    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', pdb_file)
    except Exception:
        return np.array([])

    # Only the first model is used (e.g. multi-model files are not averaged).
    first_model = next(iter(structure), None)
    if first_model is None:
        return np.array([])

    ca_coords = [
        residue['CA'].get_coord()
        for chain in first_model
        for residue in chain
        if is_aa(residue, standard=False) and 'CA' in residue
    ]
    return np.array(ca_coords)

def get_contact_map(coords, threshold=8.0, size=64):
    """
    Builds a fixed-size binary contact map from a set of coordinates.

    Two points are "in contact" when their Euclidean distance is strictly
    below ``threshold``. The map is cropped to the first ``size`` points and
    zero-padded up to ``size x size``.

    Args:
        coords: Sequence/array of coordinates, one row per point.
        threshold: Distance cut-off for a contact.
        size: Side length of the returned square map.

    Returns:
        torch.Tensor: float32 tensor of shape ``(1, size, size)``. All zeros
        when fewer than 10 coordinates are supplied.
    """
    if len(coords) < 10:
        return torch.zeros((1, size, size))

    pairwise = squareform(pdist(coords))
    contacts = np.less(pairwise, threshold).astype(float)

    # Copy the top-left corner that fits; everything else stays zero padding.
    keep_rows = min(contacts.shape[0], size)
    keep_cols = min(contacts.shape[1], size)
    canvas = np.zeros((size, size))
    canvas[:keep_rows, :keep_cols] = contacts[:keep_rows, :keep_cols]

    as_float32 = torch.tensor(canvas, dtype=torch.float32)
    return as_float32.unsqueeze(0)

class PDBContactMapDataset(Dataset):
    """
    Dataset that turns every structure file in a directory into a contact map.

    Files are picked up by extension (``.pdb`` or ``.ent``). Each item is the
    result of ``get_contact_map`` applied to the file's CA coordinates; if
    anything goes wrong while producing an item, an all-zero map of shape
    ``(1, size, size)`` is returned instead.
    """

    def __init__(self, pdb_dir, size=64):
        # Support both .pdb and .ent (common download format)
        structure_suffixes = ('.pdb', '.ent')
        self.pdb_files = []
        for entry in os.listdir(pdb_dir):
            if entry.endswith(structure_suffixes):
                self.pdb_files.append(os.path.join(pdb_dir, entry))
        self.size = size

    def __len__(self):
        """Number of structure files discovered in the directory."""
        return len(self.pdb_files)

    def __getitem__(self, idx):
        """Return the contact map for file ``idx`` (zeros on any failure)."""
        try:
            path = self.pdb_files[idx]
            return get_contact_map(get_ca_coordinates(path), size=self.size)
        except Exception:
            return torch.zeros((1, self.size, self.size))

In [None]:
# --- 3. High-Capacity Model Architectures ---

class ResNetBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions (each followed by batch norm, with a
    LeakyReLU in between) whose output is added to the block's input.

    The channel count and spatial size are preserved, so the block can be
    dropped between any two layers operating on ``channels`` feature maps.
    """

    def __init__(self, channels):
        super().__init__()
        residual_layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
        ]
        # Attribute name and layer order define the state_dict keys.
        self.conv = nn.Sequential(*residual_layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch applied to ``x``."""
        residual = self.conv(x)
        return x + residual

class ResNetVAE(nn.Module):
    """
    Convolutional VAE with residual blocks for single-channel maps.

    The encoder halves the spatial size four times and then applies a 4x4
    convolution, so the spatial comments below assume a 64x64 input. The
    decoder mirrors this and ends in a sigmoid, producing values in [0, 1].

    Args:
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self, latent_dim=512):
        super().__init__()

        # Encoder: Deep and Wide
        encoder_layers = [
            nn.Conv2d(1, 64, 4, 2, 1),  # 32x32
            nn.LeakyReLU(0.2),
        ]
        # Three down-sampling stages: 16x16, 8x8, 4x4
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
                ResNetBlock(c_out),
            ]
        encoder_layers += [
            nn.Conv2d(512, 1024, 4, 1, 0),  # 1x1
            nn.Flatten(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.fc_mu = nn.Linear(1024, latent_dim)
        self.fc_logvar = nn.Linear(1024, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 1024)
        decoder_layers = [
            nn.Unflatten(1, (1024, 1, 1)),
            nn.ConvTranspose2d(1024, 512, 4, 1, 0),  # 4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            ResNetBlock(512),
        ]
        # Three up-sampling stages: 8x8, 16x16, 32x32
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(64, 1, 4, 2, 1),  # 64x64
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """
        Encode ``x``, sample a latent vector and decode it.

        Returns:
            tuple: ``(reconstruction, mu, logvar)``.
        """
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(self.decoder_input(latent))
        return reconstruction, mu, logvar

class DeepDiscriminator(nn.Module):
    """
    Spectrally-normalised convolutional discriminator.

    Five stride-2 convolutions (1 -> 64 -> 128 -> 256 -> 512 -> 1024
    channels), each followed by LeakyReLU, feed a linear layer that expects
    ``1024 * 2 * 2`` features, then a sigmoid. The output is one probability
    per input sample.
    """

    def __init__(self):
        super().__init__()

        channel_plan = (1, 64, 128, 256, 512, 1024)
        layers = []
        for in_c, out_c in zip(channel_plan[:-1], channel_plan[1:]):
            # Each stage halves the spatial resolution (kernel 4, stride 2, pad 1).
            layers.append(spectral_norm(nn.Conv2d(in_c, out_c, 4, 2, 1)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.Flatten())
        layers.append(nn.Linear(1024 * 2 * 2, 1))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-sample probability produced by the network."""
        return self.model(x)

In [None]:
# --- 4. Training Functions ---

def calculate_metrics(real_batch, recon_batch):
    """
    Computes the structural similarity (SSIM) between real and reconstructed maps.

    Both tensors are detached, moved to the CPU and squeezed. If the squeezed
    real array is 2-D (a single image) its SSIM is returned directly;
    otherwise the SSIM is computed image by image along the first axis and
    the mean is returned. ``data_range`` is fixed to 1.0.
    """
    real = real_batch.detach().cpu().numpy().squeeze()
    recon = recon_batch.detach().cpu().numpy().squeeze()

    # A single image collapses to 2-D after squeeze(); score it directly.
    if real.ndim == 2:
        return ssim(real, recon, data_range=1.0)

    per_image = [
        ssim(real[idx], recon[idx], data_range=1.0)
        for idx in range(len(real))
    ]
    return np.mean(per_image)

def train_vae_phase(vae, loader, epochs=50, device='cuda'):
    """
    Phase 1: trains the VAE with summed binary cross-entropy plus KL divergence.

    Per epoch, the printed loss is the summed loss divided by the number of
    samples actually processed (so it stays correct when the loader drops a
    partial batch). SSIM is evaluated on a random ~10% of batches to keep the
    overhead low, and its mean is reported when at least one batch was sampled.

    Args:
        vae: Model whose forward pass returns ``(recon, mu, logvar)``.
        loader: Iterable yielding input batches.
        epochs: Number of passes over ``loader``.
        device: Device the model and batches are moved to.

    Returns:
        The trained ``vae`` (same object that was passed in).
    """
    print("\n[Phase 1] Training VAE...")
    optimizer = optim.Adam(vae.parameters(), lr=1e-4)
    vae.to(device)
    vae.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        samples_seen = 0
        ssim_total = 0.0
        ssim_batches = 0
        for batch in loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            recon, mu, logvar = vae(batch)
            bce = nn.functional.binary_cross_entropy(recon, batch, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = bce + kld

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            samples_seen += batch.size(0)

            # SSIM runs on the CPU, so only sample about one batch in ten.
            if np.random.rand() < 0.1:
                ssim_total += calculate_metrics(batch, recon)
                ssim_batches += 1

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        mean_loss = epoch_loss / max(samples_seen, 1)
        message = f"Epoch {epoch+1}/{epochs} - VAE Loss: {mean_loss:.2f}"
        if ssim_batches:
            message += f" | SSIM: {ssim_total / ssim_batches:.4f}"
        print(message)
    return vae

def train_gan_phase(generator, discriminator, loader, epochs=50, device='cuda', latent_dim=512):
    """
    Phase 2: adversarial training using the VAE's decoder path as generator.

    Each step first updates the discriminator on a real batch and a detached
    fake batch, then updates the generator so that the same fakes are
    classified as real. Only ``decoder_input`` and ``decoder`` of the
    generator take part in the forward pass.

    Args:
        generator: Model exposing ``decoder_input`` and ``decoder`` (optionally
            wrapped in ``nn.DataParallel``).
        discriminator: Model mapping a batch to per-sample probabilities.
        loader: Iterable yielding real batches.
        epochs: Number of passes over ``loader``.
        device: Device the models and batches are moved to.
        latent_dim: Size of the noise vector fed to ``decoder_input``.
    """
    print("\n[Phase 2] Training Pure GAN...")
    generator.to(device); discriminator.to(device)
    opt_g = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    # The decoder is called directly, so bypass a DataParallel wrapper if present.
    inner_gen = generator.module if isinstance(generator, nn.DataParallel) else generator

    for epoch in range(epochs):
        # Reset so an empty loader is reported instead of raising at the print.
        d_loss = g_loss = None
        for batch in loader:
            batch = batch.to(device)
            b_size = batch.size(0)

            # Train Disc
            opt_d.zero_grad()
            real_labels = torch.ones(b_size, 1, device=device)
            fake_labels = torch.zeros(b_size, 1, device=device)

            d_real_loss = criterion(discriminator(batch), real_labels)

            z = torch.randn(b_size, latent_dim, device=device)
            fake_imgs = inner_gen.decoder(inner_gen.decoder_input(z))

            # detach(): the discriminator update must not push gradients into the generator.
            d_fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            opt_d.step()

            # Train Gen
            opt_g.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            opt_g.step()

        if d_loss is None:
            print(f"Epoch {epoch+1}/{epochs} - no batches were produced by the loader; nothing trained.")
        else:
            print(f"Epoch {epoch+1}/{epochs} - D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

def train_vaegan_phase(vae, discriminator, loader, epochs=100, device='cuda', gamma=20.0, latent_dim=512):
    """
    Phase 3: joint VAE-GAN training.

    The discriminator is trained to accept real maps and reject both
    reconstructions and maps decoded from random latent vectors. The VAE is
    then trained on reconstruction (summed BCE) + KL divergence +
    ``gamma`` times the adversarial loss on its reconstructions.

    Args:
        vae: Model returning ``(recon, mu, logvar)`` and exposing
            ``decoder_input`` / ``decoder`` (optionally ``nn.DataParallel``).
        discriminator: Model mapping a batch to per-sample probabilities.
        loader: Iterable yielding real batches.
        epochs: Number of passes over ``loader``.
        device: Device the models and batches are moved to.
        gamma: Weight of the adversarial term in the VAE loss.
        latent_dim: Size of the random latent vectors; must match the VAE.
    """
    print("\n[Phase 3] Training VAE-GAN...")
    vae.to(device); discriminator.to(device)

    opt_vae = optim.Adam(vae.parameters(), lr=1e-4)
    opt_disc = optim.Adam(discriminator.parameters(), lr=1e-5)
    criterion = nn.BCELoss()

    # The decoder is called directly, so bypass a DataParallel wrapper if present.
    inner_vae = vae.module if isinstance(vae, nn.DataParallel) else vae

    for epoch in range(epochs):
        # Reset so an empty loader is reported instead of raising at the print.
        vae_loss = d_loss = None
        for batch in loader:
            real_imgs = batch.to(device)
            b_size = real_imgs.size(0)

            # --- Train Discriminator ---
            opt_disc.zero_grad()
            valid = torch.ones(b_size, 1, device=device)
            fake = torch.zeros(b_size, 1, device=device)

            d_real_loss = criterion(discriminator(real_imgs), valid)

            # The fakes are only inputs for the discriminator here, so no
            # autograd graph through the VAE is needed for this step.
            with torch.no_grad():
                recon, _, _ = vae(real_imgs)
                z = torch.randn(b_size, latent_dim, device=device)
                gen_imgs = inner_vae.decoder(inner_vae.decoder_input(z))

            d_recon_loss = criterion(discriminator(recon), fake)
            d_gen_loss = criterion(discriminator(gen_imgs), fake)

            d_loss = d_real_loss + 0.5 * (d_recon_loss + d_gen_loss)
            d_loss.backward()
            opt_disc.step()

            # --- Train VAE ---
            opt_vae.zero_grad()

            recon, mu, logvar = vae(real_imgs)
            bce = nn.functional.binary_cross_entropy(recon, real_imgs, reduction='sum')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            g_adv_loss = criterion(discriminator(recon), valid)

            vae_loss = bce + kld + (gamma * g_adv_loss)
            vae_loss.backward()
            opt_vae.step()

        if vae_loss is None:
            print(f"Epoch {epoch+1}/{epochs} - no batches were produced by the loader; nothing trained.")
        else:
            print(f"Epoch {epoch+1}/{epochs} - VAE Loss: {vae_loss.item():.2f} | D Loss: {d_loss.item():.4f}")

In [None]:
# --- 5. Initialization & Setup ---

# Configuration
BATCH_SIZE = 32 # Suitable for 25GB VRAM (2x T4)
LATENT_DIM = 512
SEED = 42

# Seed the RNGs used downstream (torch for weight init / latent sampling,
# numpy for the SSIM sub-sampling) so a fresh "Run All" is repeatable.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fetch Data
pdb_ids = fetch_pdb_ids(max_results=500)
if not pdb_ids: pdb_ids = ['1AIE', '1B7G', '1D0D', '6VSB']
download_pdb_data(pdb_ids)

# Create Loader
dataset = PDBContactMapDataset("pdb_data")
if len(dataset) == 0:
    raise RuntimeError("No .pdb/.ent files found in 'pdb_data'; cannot build a training loader.")

# Dropping the last partial batch is only safe when at least one full batch
# exists; with a tiny dataset (e.g. the 4-ID fallback) drop_last=True would
# leave the loader with zero batches and nothing would be trained.
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(dataset) >= BATCH_SIZE,
)
print(f"Dataset loaded: {len(dataset)} proteins.")

# Initialize Models & Hardware
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = ResNetVAE(latent_dim=LATENT_DIM)
disc = DeepDiscriminator()

# Wrap in DataParallel only when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    vae = nn.DataParallel(vae)
    disc = nn.DataParallel(disc)

vae.to(device)
disc.to(device)
print("Models initialized.")

In [None]:
def save_model(vae, discriminator=None, phase_name="checkpoint", save_dir="checkpoints"):
    """
    Saves model weights to disk.

    Each model that is not ``None`` is written as a plain ``state_dict`` to
    ``<save_dir>/vae_<phase_name>.pth`` / ``<save_dir>/disc_<phase_name>.pth``.
    ``nn.DataParallel`` wrappers are unwrapped first so the saved keys carry
    no ``module.`` prefix.

    Args:
        vae: VAE model, or ``None`` to skip it.
        discriminator: Discriminator model, or ``None`` to skip it.
        phase_name: Tag used in the checkpoint file names.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    for prefix, model in (("vae", vae), ("disc", discriminator)):
        # Compare against None explicitly: module truthiness can follow
        # __len__ (an empty nn.Sequential is falsy), which is not what we mean.
        if model is None:
            continue
        # standardizing to single GPU state dict
        inner = model.module if isinstance(model, nn.DataParallel) else model
        torch.save(inner.state_dict(), os.path.join(save_dir, f"{prefix}_{phase_name}.pth"))

    print(f"Saved models for '{phase_name}' to {save_dir}/")
def load_model(vae, discriminator=None, phase_name="checkpoint", save_dir="checkpoints", device='cuda'):
    """
    Loads model weights from disk.

    Looks for ``<save_dir>/vae_<phase_name>.pth`` and
    ``<save_dir>/disc_<phase_name>.pth``. A missing file only produces a
    warning and leaves the corresponding model untouched. Safe to use even
    if models are wrapped in ``nn.DataParallel``. Every model that is not
    ``None`` is moved to ``device`` at the end.

    Args:
        vae: VAE model to load into, or ``None`` to skip it.
        discriminator: Discriminator to load into, or ``None`` to skip it.
        phase_name: Tag used in the checkpoint file names.
        save_dir: Directory containing the checkpoints.
        device: Target device for the tensors and the models.
    """
    targets = (
        ("VAE", vae, os.path.join(save_dir, f"vae_{phase_name}.pth")),
        ("Discriminator", discriminator, os.path.join(save_dir, f"disc_{phase_name}.pth")),
    )

    for label, model, path in targets:
        # Compare against None explicitly: module truthiness can follow
        # __len__ (an empty nn.Sequential is falsy), which is not what we mean.
        if model is None:
            continue
        if not os.path.exists(path):
            print(f"Warning: {label} checkpoint not found at {path}")
            continue

        # map_location ensures we can load CUDA models on CPU if needed (or vice versa).
        # NOTE: torch.load unpickles the file - only load checkpoints you trust.
        state_dict = torch.load(path, map_location=device)
        inner = model.module if isinstance(model, nn.DataParallel) else model
        inner.load_state_dict(state_dict)
        print(f"Loaded {label} from {path}")

    # Ensure on device
    if vae is not None: vae.to(device)
    if discriminator is not None: discriminator.to(device)

print("Checkpoint utilities initialized.")

In [None]:
# --- Visualization Functions ---

def visualize_results(model, loader, device='cuda', num_samples=5, title="Model Results", latent_dim=512):
    """
    Visualizes up to three rows of contact maps:
    1. Real Contact Maps (first batch of ``loader``)
    2. Reconstructed Maps (skipped if the model's forward pass does not
       yield a ``(recon, mu, logvar)`` triple)
    3. Generated Maps (decoded from random latent vectors)

    The model is put in eval mode for the forward passes and restored to
    whatever mode it was in before the call.

    Args:
        model: Model exposing ``decoder_input`` / ``decoder`` (optionally
            wrapped in ``nn.DataParallel``).
        loader: Iterable of batches; only the first batch is used.
        device: Device used for the forward passes.
        num_samples: Maximum number of columns; fewer are shown if the batch
            is smaller.
        title: Figure title.
        latent_dim: Size of the random latent vectors; must match the model.
    """
    # 1. Get Real Images
    try:
        data = next(iter(loader))
    except StopIteration:
        print("visualize_results: loader produced no batches; nothing to plot.")
        return

    real_imgs = data[:num_samples].to(device)
    n_cols = real_imgs.size(0)  # may be smaller than num_samples
    if n_cols == 0:
        print("visualize_results: no samples to plot.")
        return

    # Handle DataParallel (if using multiple GPUs)
    inner_gen = model.module if isinstance(model, nn.DataParallel) else model

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # 2. Get Reconstructions (only valid for VAE/VAE-GAN phases)
            try:
                recon_imgs, _, _ = model(real_imgs)
                show_recon = True
            except Exception as exc:
                # Model does not behave like a VAE; skip the reconstruction row.
                print(f"visualize_results: skipping reconstructions ({exc})")
                show_recon = False

            # 3. Generate New Samples by decoding random latent vectors
            z = torch.randn(n_cols, latent_dim, device=device)
            gen_imgs = inner_gen.decoder(inner_gen.decoder_input(z))
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)

    # Move to CPU for plotting. Only the channel axis is squeezed so a batch
    # of one image keeps its leading sample axis.
    rows = [("Real", real_imgs.cpu().squeeze(1).numpy())]
    if show_recon:
        rows.append(("Reconstructed", recon_imgs.cpu().squeeze(1).numpy()))
    rows.append(("Generated", gen_imgs.cpu().squeeze(1).numpy()))

    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(
        len(rows), n_cols, figsize=(n_cols * 3, len(rows) * 3), squeeze=False
    )
    fig.suptitle(title, fontsize=16)

    for row_idx, (label, images) in enumerate(rows):
        for col_idx in range(n_cols):
            ax = axes[row_idx, col_idx]
            ax.imshow(images[col_idx], cmap='viridis', origin='lower')
            # Hide ticks only: axis('off') would also hide the row label.
            ax.set_xticks([])
            ax.set_yticks([])
            if col_idx == 0:
                ax.set_ylabel(label, fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

In [None]:
# --- Execution: Phase 1 (VAE Training) ---
# Trains `vae` on `train_loader` for 50 epochs on `device`.
# train_vae_phase returns the model, so as the last expression of the cell
# the model is also echoed as the cell output.
train_vae_phase(vae, train_loader, epochs=50, device=device)

In [None]:
# Checkpoint the Phase 1 weights (VAE only; no discriminator is saved at this
# stage), then plot maps from the current model.
save_model(vae, discriminator=None, phase_name="phase1_vae")
visualize_results(vae, train_loader, device=device, title="Result: Phase 1 (VAE)")

In [None]:
# --- Execution: Phase 2 (GAN Training) ---
# The VAE is passed as the generator: train_gan_phase drives its
# decoder_input/decoder path against `disc`. Both models are then
# checkpointed under the "phase2_gan" tag and the results are plotted.
train_gan_phase(vae, disc, train_loader, epochs=50, device=device, latent_dim=LATENT_DIM)
save_model(vae, disc, "phase2_gan")
visualize_results(vae, train_loader, device=device, title="Result: Phase 2 (GAN)")

In [None]:
# --- 8. Execution: Phase 3 (VAE-GAN Training) ---
# NOTE: these calls are active and run with the cell; comment them out to skip Phase 3.
# Continues training the same `vae` and `disc` objects used in the earlier
# phases, then checkpoints both under the "phase3_vaegan" tag and plots results.
train_vaegan_phase(vae, disc, train_loader, epochs=100, device=device, gamma=20.0)
save_model(vae, disc, "phase3_vaegan")
visualize_results(vae, train_loader, device=device, title="Result: Phase 3 (VAE-GAN)")

In [None]:
# NOTE(review): looks like a leftover smoke-test cell; it does not use any
# notebook state - confirm whether it can be removed.
print("hello")