In [1]:
import lightning
import data
import yaml   
import argparse
import os 
from torch_geometric.loader import DataLoader
import torch
import numpy as np

import matplotlib.pyplot as plt


def draw_3d(crd,sselist,chainlist):
    """Draw a 3D sketch of a topology on a new 10x10 matplotlib figure.

    The points are grouped into consecutive segments of three points each
    (``crd`` is reshaped to ``(len(sselist) // 3, 3, 3)``).

    Parameters
    ----------
    crd : array with numpy-style indexing and ``reshape``
        One row of three coordinates per point, ``len(sselist)`` rows in
        total. ``len(sselist)`` must be a multiple of 3, otherwise the
        reshape fails.
    sselist : sequence
        Per-point secondary-structure label. Only the label of the first
        point of each segment is read; ``'H'`` selects the helix colours,
        any other value selects the non-helix colours.
    chainlist : sequence
        Per-point value compared against ``0.0`` (first point of each
        segment only) to choose between the two colours of a segment type
        -- presumably a chain id; confirm against the caller.

    Colours
    -------
    helix:      red (chain value 0.0) / magenta (otherwise)
    non-helix:  blue (chain value 0.0) / cyan (otherwise)
    connectors between consecutive segments: black

    Returns
    -------
    None. The figure is created through ``plt.figure`` and left for the
    active matplotlib backend to display.
    """
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(projection='3d')
    n_points = len(sselist)

    # Scatter every point, coloured by its index along the chain, and label
    # each point with that index.
    ax.scatter(crd[:,0],crd[:,1],crd[:,2],c=list(range(n_points)),cmap='gist_rainbow')
    for i in range(n_points):
        ax.text(crd[i,0],crd[i,1],crd[i,2], str(i), size=10, zorder=1, color='k')

    # One (3, 3) slab per segment: three points with three coordinates each.
    segments = crd.reshape(n_points // 3, 3, 3)

    # Red trace through all points in order; the per-segment and connector
    # lines drawn below are painted on top of it.
    ax.plot(crd[:,0],crd[:,1],crd[:,2],c='r')

    for i, sse in enumerate(segments):
        is_helix = sselist[3*i] == 'H'
        on_chain_zero = chainlist[3*i] == 0.0
        if is_helix:
            colour = 'r' if on_chain_zero else 'm'
        else:
            colour = 'b' if on_chain_zero else 'c'
        ax.plot(sse[:,0],sse[:,1],sse[:,2],c=colour)

    # Black connector from the last point of each segment to the first point
    # of the next one.
    for sse1, sse2 in zip(segments[:-1], segments[1:]):
        ax.plot([sse1[-1,0],sse2[0,0]],[sse1[-1,1],sse2[0,1]],[sse1[-1,2],sse2[0,2]],c='k')





# Example: unconditional single-chain generation


In [4]:
## Set up the argument parser and load the model.
## Modify the checkpoint path and the data paths below to match your environment.
## Samples are generated from the secondary-structure features of the test dataset.

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', action='store', type=str,default= './ckpt/monomer.ckpt',)
parser.add_argument('--data', action='store', type=str, default='/work/lpdi/users/ymiao/TopoFlow/filtcathgraphlist_simple.pt')
parser.add_argument('--output', action='store', type=str, default='./example/output/uncondition_random_sample.npy')
parser.add_argument('--keep_frames', action='store', type=int, default=2)
parser.add_argument('--device', action='store', type=str,default='cuda')
# Parse an empty argv so the defaults above are used (the notebook kernel's
# own command line must not be parsed).
args = parser.parse_args(args=[])
model = lightning.DDPM.load_from_checkpoint(args.checkpoint,map_location=args.device,data_path=args.data,torch_device=args.device)
device = torch.device(args.device)
model.ddpm.in_node_nf=11

# Tunables for this cell.
NUM_TEST_BATCHES = 3    # number of test batches to sample from
MAX_LOOP_LENGTH = 15    # samples with any loop at least this long are discarded
# download here https://drive.google.com/drive/folders/1PCb00AA71NEmlkjWM-0XOl_HE1sOZiKc?usp=drive_link
TEST_DATASET_PATH = '/work/lpdi/users/ymiao/TopoFlow/pinder_cluster_test.pt'

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load dataset files obtained from a trusted source.
test_dataset= torch.load(TEST_DATASET_PATH)
testloader= DataLoader(test_dataset, batch_size= 16,shuffle=True, num_workers=10,pin_memory=True,prefetch_factor=2)

chain0_list=[]
batch_list= []
# zip(range(N), loader) stops after N batches without requesting batch N+1.
# The loop variable is deliberately not called `data`, so the `data` module
# imported at the top of the notebook is not shadowed.
for _, graph_batch in zip(range(NUM_TEST_BATCHES), testloader):
    graph_batch=graph_batch.to(args.device)
    chain,chain0=model.ddpm.sample_chain_gvp(graph_batch.pos,graph_batch.x[:,0:3],graph_batch.batch,keep_frames=args.keep_frames)
    chain0_list.append(chain0.cpu())
    batch_list.append(graph_batch.batch.cpu())
chain0_total = torch.cat(chain0_list)

# Graph indices restart at 0 in every loader batch; shift them so that every
# sample keeps a unique index after concatenation.
offset = 0
adjusted_batch_list = []
for batch in batch_list:
    adjusted_batch_list.append(batch + offset)  # Increment batch indices
    offset += batch.max().item() + 1  # Update offset for the next tensor
batch_total = torch.cat(adjusted_batch_list)

chain0_filt=[]
batch_filt=[]
for num in torch.unique(batch_total):
    sample_mask = batch_total == num
    sample = chain0_total[sample_mask]
    coords = sample[:,0:3]
    # A loop spans from the last point of one 3-point segment to the first
    # point of the next segment.
    loop_lengths = [
        torch.linalg.norm(coords[3*k+2]-coords[3*(k+1)])
        for k in range(len(coords)//3 - 1)
    ]
    # default=0: a sample with fewer than two segments has no loops and is kept
    # (a bare max() would raise on the empty list).
    if max(loop_lengths, default=0) < MAX_LOOP_LENGTH: #filter out samples with large loops
        chain0_filt.append(sample)
        batch_filt.append(batch_total[sample_mask])
if not chain0_filt:
    raise RuntimeError('No sample passed the loop-length filter; nothing to save.')
batch_filt = torch.cat(batch_filt)
chain0_filt = torch.cat(chain0_filt)

# Append the sample index as the last column and save to the --output path.
save_topo =torch.cat([chain0_filt,batch_filt.unsqueeze(1)],dim=1).numpy()
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
np.save(args.output,save_topo)

# Then use build_sketch.py to create the sketch and run the downstream scripts, e.g.:
#python ./buildpdb_from_sketch_dimer.py --input /work/lpdi/users/ymiao/code/DiffTopo/example/output/uncondition_random_sample.npy --output /work/lpdi/users/ymiao/code/DiffTopo/example/output/outsketch

# Example: single-chain generation from a specified secondary-structure string

In [None]:
import torch
from torch_geometric.data import Data, Batch

# Create a single protein structure (no padding needed!)
def create_fake_protein(ss_sequence='EEHEHEE'):
    """Build one placeholder PyG ``Data`` graph from a secondary-structure string.

    Every character of ``ss_sequence`` contributes three nodes. A node gets
    the feature row ``[0., 1.]`` when its character is ``'E'`` and
    ``[1., 0.]`` for any other character.

    The returned object carries:
      * ``x``         -- random normal values of shape ``(3 * len(ss_sequence), 3)``
      * ``h``         -- the per-node feature rows described above
      * ``num_nodes`` -- ``3 * len(ss_sequence)``
    No ``pos`` attribute is set.
    """
    strand_row = [0., 1.]
    other_row = [1., 0.]
    # Three identical feature rows per character, in sequence order.
    feature_rows = [
        strand_row if symbol == 'E' else other_row
        for symbol in ss_sequence
        for _ in range(3)
    ]
    n_nodes = 3 * len(ss_sequence)

    features = torch.tensor(feature_rows)
    coords = torch.randn(n_nodes, 3)
    return Data(x=coords, h=features, num_nodes=n_nodes)

# Create multiple proteins with DIFFERENT lengths (no padding!)
# Build a batch of ten identical fake proteins for the string 'EEHEHEE'.
# (Batch.from_data_list also accepts graphs of different sizes, so no padding
# would be needed for mixed-length inputs.)
chain0_list = []
batch_list = []
fake_protein = create_fake_protein('EEHEHEE')
fake_proteins = [fake_protein for _ in range(10)]
# A single device transfer is enough; the name `data` is not reused here so the
# `data` module imported at the top of the notebook is not shadowed.
batchdata = Batch.from_data_list(fake_proteins).to(args.device)
# NOTE(review): create_fake_protein stores its random coordinates under `x` and
# the secondary-structure features under `h`, and sets no `pos`. As written,
# this call therefore passes `batchdata.pos` (presumably None) as the first
# argument and the random coordinates as the second, unlike the test-set
# example, which passes `pos` and the first three columns of `x`. Confirm the
# arguments sample_chain_gvp expects before relying on this output.
chain,chain0=model.ddpm.sample_chain_gvp(batchdata.pos,batchdata.x[:,0:3],batchdata.batch,keep_frames=args.keep_frames)
chain0_list.append(chain0.cpu())
batch_list.append(batchdata.batch.cpu())
# chain0_list / batch_list are Python lists: concatenate their tensors first,
# then append the sample index as the last column.
save_topo =torch.cat([torch.cat(chain0_list),torch.cat(batch_list).unsqueeze(1)],dim=1).numpy()
np.save('./example/output/EEHEHEE_random_sample.npy',save_topo)
# Then use build_sketch.py to create the sketch and run the downstream scripts, e.g.:
#python ./buildpdb_from_sketch_dimer.py --input /work/lpdi/users/ymiao/code/DiffTopo/example/output/EEHEHEE_random_sample.npy --output /work/lpdi/users/ymiao/code/DiffTopo/example/output/outsketch