In [1]:
from torch import optim
from torch.optim import Adam
from tqdm import tqdm

from utils.data import read_domain_ids_per_chain_from_txt
from utils.dataset import *
from diffusion_model.embed import *
from diffusion_model.structure_diffusion_model import *

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH = 100
LEARNING_RATE = 1e-3

In [3]:
train_pdbs, train_pdb_chains = read_domain_ids_per_chain_from_txt('./data/train_domains.txt')
test_pdbs, test_pdb_chains = read_domain_ids_per_chain_from_txt('./data/test_domains.txt')

In [4]:
train_loader = BackboneCoordsDataLoader(train_pdb_chains, "./data/train_backbone_coords_20.npy", "./data/train_data_res_20.npy",seq_length=20, batch_size=128, shuffle=True)
test_loader = BackboneCoordsDataLoader(test_pdb_chains, './data/test_backbone_coords_20.npy', './data/test_data_res_20.npy', seq_length=20, batch_size=128, shuffle=True)

In [5]:
diffusion = ProteinDiffusion(device=DEVICE)
model = StructureModel().to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr = 1e-3)

In [7]:
model.train()
train_loss = 0
for batch_idx, (pdb, res_label, atom_coords) in enumerate(tqdm(train_loader, leave=False)):
    # Data preparation
    atom_coords = atom_coords.to(DEVICE)
    n_coords = atom_coords[:, :, 0]
    ca_coords = atom_coords[:, :, 1]
    c_coords = atom_coords[:, :, 2]
    R, t = rigidFrom3Points(n_coords, ca_coords, c_coords)
    q_0 = roma.rotmat_to_unitquat(R)
    single_repr = get_single_representation(pdb, res_label).to(DEVICE)
    print(single_repr.shape)
    break
    pair_repr = torch.cdist(ca_coords, ca_coords, p=2).to(torch.float32)
    

    # Foward Diffusion
    batch_size = atom_coords.shape[0]
    t = diffusion.sample_timesteps(batch_size = batch_size).to(DEVICE)
    x_t = diffusion.coord_q_sample(ca_coords, t).to(torch.float32)
    q_t = diffusion.quaternion_q_sample(q_0, t)

    pred_coords = model(single_repr, pair_repr, q_t, x_t)



                                       

tensor([[ 0.1289,  0.0791,  0.0149,  ..., -0.1022,  0.1242,  0.0654],
        [ 0.1134,  0.0371,  0.0111,  ...,  0.0153,  0.0565,  0.0277],
        [ 0.1313,  0.0752, -0.0365,  ..., -0.1950, -0.0347, -0.0643],
        ...,
        [ 0.1262,  0.1281, -0.0099,  ..., -0.0825,  0.1049,  0.0579],
        [ 0.0482,  0.2691,  0.1756,  ..., -0.1200, -0.1839,  0.0359],
        [ 0.1464,  0.0757, -0.0119,  ..., -0.1398,  0.0510, -0.0100]],
       device='cuda:0')


