In [1]:
import torch
from torch import optim
from torch.optim import Adam
from tqdm import tqdm

from utils.data import read_domain_ids_per_chain_from_txt
from common.res_infor import *
from utils.dataset import *
from diffusion_model.sequence_diffusion_model import *

In [2]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH = 20
LEARNING_RATE = 1e-3

In [3]:
train_pdbs, train_pdb_chains = read_domain_ids_per_chain_from_txt('./data/train_domains.txt')
test_pdbs, test_pdb_chains = read_domain_ids_per_chain_from_txt('./data/test_domains.txt')

In [4]:
train_loader = BackboneCoordsDataLoader(train_pdb_chains, "./data/train_backbone_coords_20.npy", "./data/train_data_res_20.npy",seq_length=20, batch_size=128, shuffle=True)
test_loader = BackboneCoordsDataLoader(test_pdb_chains, './data/test_backbone_coords_20.npy', './data/test_data_res_20.npy', seq_length=20, batch_size=128, shuffle=True)

# Print some Sequence data

In [5]:
def tensor_to_string(tensor, label_res_dict):
    sequences = []
    for i in range(tensor.size(0)):
        sequence = "".join([label_res_dict[residue.item()] if residue.item() != 21 else '*' for residue in tensor[i]])
        sequences.append(sequence)
    return sequences

In [6]:
for batch_idx, (pdb_id, res, data) in enumerate(train_loader):
    res = res.squeeze()
    sequences = tensor_to_string(res, label_res_dict)
    print(pdb_id[:5])
    print(sequences[:5])
    break

('1ijqA', '3h5nA', '3b9vB', '5h9uA', '3rnlA')
['IAYLFFTNRHEVRKMTLDRS', 'MDYILGRYVKIARYGSGGLV', 'EVQLVESGGGLVQPGGSLRL', 'MRALRLVTSESVTEGHPDKL', 'VARPNFFIVGAAKCGTSSLD']


# Visualization of the forward process

In [7]:
diffusion = SequenceDiffusion()

pdb_id, res, _ = next(iter(train_loader))
pdb_1 = pdb_id[0]
res_1 = res[0].reshape(1, -1)
print(pdb_1[:4] + " " + pdb_1[-1])

for t in range(0, 101, 10):
    t = torch.tensor([t])
    x_t, _ = diffusion.seq_q_sample(res_1, t)
    sequences = tensor_to_string(x_t, label_res_dict)
    print(f"t = {t.item()}:", sequences)
    

4r0z A
t = 0: ['TQQLKQSVMDLLTYEGSNDM']
t = 10: ['TQQLKQSVMDLLTYEGSNDM']
t = 20: ['TQQ***SVMDL*TYE*SNDM']
t = 30: ['TQ*LKQ*VM*L*TY*GSNDM']
t = 40: ['TQQL*QSV*D*L**EGSND*']
t = 50: ['T***KQ*V*DLLTY*GS*DM']
t = 60: ['TQ****S****L*Y***N*M']
t = 70: ['*Q**K******L**E*****']
t = 80: ['T*Q****V*D*LT*****D*']
t = 90: ['**********L********M']
t = 100: ['********************']


# Model training

In [8]:
diffusion = SequenceDiffusion(device=DEVICE)
model = SequenceModel().to(DEVICE)
optimizer = optim.Adam(params=model.parameters(), lr = 1e-3)

In [18]:
for epoch in range(EPOCH):    
    model.train()
    train_loss = 0
    for batch_idx, (pdb, res_label, atom_coords) in enumerate(tqdm(train_loader, leave=False)):
        # Data preparation
        x_0 = res_label.squeeze(-1)
        atom_coords = atom_coords
        n_coords = atom_coords[:, :, 0]
        ca_coords = atom_coords[:, :, 1]
        c_coords = atom_coords[:, :, 2]

        rotaions, translations = rigidFrom3Points(n_coords, ca_coords, c_coords)
        pair_repr = torch.cdist(ca_coords, ca_coords, p=2).to(torch.float32)
        
        # Foward Diffusion
        batch_size = atom_coords.shape[0]
        t = diffusion.sample_timesteps(batch_size = batch_size).to(DEVICE)
        x_t, x_0_ignore = diffusion.seq_q_sample(x_0, t)
        # Backward Diffusion
        x_0_hat_logits = model(x_t.float(), pair_repr, rotaions.float(), translations.float())
        
        # Custom loss function
        x_0_hat_logits = x_0_hat_logits.view(-1, 22)
        x_0_ignore = x_0_ignore.view(-1)
        print(x_0_hat_logits, x_0_ignore)
        break
        optimizer.zero_grad()
        loss.backward() # calc gradients
        train_loss += loss.item()
        optimizer.step() # backpropagation
    print('====> Epoch: {} Average loss: {:.10f}'.format(epoch, train_loss / len(train_loader.dataset)))

                                       

tensor([[-0.8426, -1.2501, -0.1453,  ..., -1.3036,  0.3328, -0.1608],
        [-0.7808, -0.6854, -0.2793,  ..., -1.2044,  0.2098, -0.0984],
        [-0.7725, -1.1255, -0.2437,  ..., -1.1823,  0.3181, -0.0795],
        ...,
        [-0.8161, -0.9125, -0.2271,  ..., -1.1594,  0.2122,  0.0869],
        [-0.7935, -1.1852, -0.3795,  ..., -1.1442,  0.2824,  0.0167],
        [-0.8578, -1.1229, -0.3447,  ..., -1.2290,  0.2219, -0.1559]],
       grad_fn=<ViewBackward0>) tensor([17, 14, -1,  ..., -1, 11, -1], dtype=torch.int32)
====> Epoch: 0 Average loss: 0.0000000000


                                       

tensor([[-0.7927, -0.9821, -0.1681,  ..., -1.2467,  0.2788, -0.1434],
        [-0.8010, -1.0634, -0.1977,  ..., -1.2232,  0.2986, -0.0546],
        [-0.7965, -1.0186, -0.2678,  ..., -1.1871,  0.3053, -0.0578],
        ...,
        [-0.8285, -0.9154, -0.1768,  ..., -1.1600,  0.2058,  0.0225],
        [-0.7846, -1.0551, -0.3075,  ..., -1.1409,  0.2567,  0.0753],
        [-0.7706, -1.1822, -0.3538,  ..., -1.1806,  0.2271, -0.1335]],
       grad_fn=<ViewBackward0>) tensor([-1, -1, -1,  ...,  8,  5, 10], dtype=torch.int32)
====> Epoch: 1 Average loss: 0.0000000000


                                       

tensor([[-0.9399, -1.2991, -0.2951,  ..., -1.3090,  0.3878, -0.1385],
        [-0.9585, -0.9602, -0.4659,  ..., -1.2033,  0.2324, -0.2075],
        [-1.0463, -0.8382, -0.3386,  ..., -1.3623,  0.1954, -0.1835],
        ...,
        [-0.9214, -1.0627, -0.3474,  ..., -1.1868,  0.2342, -0.0220],
        [-0.9209, -1.1047, -0.3722,  ..., -1.2303,  0.2062, -0.1304],
        [-0.9463, -1.0019, -0.3547,  ..., -1.2592,  0.1653, -0.1251]],
       grad_fn=<ViewBackward0>) tensor([18, -1, -1,  ..., -1, -1, -1], dtype=torch.int32)
====> Epoch: 2 Average loss: 0.0000000000


                                       

tensor([[-0.7357, -0.9790, -0.1840,  ..., -1.1720,  0.2079, -0.0051],
        [-0.7938, -0.8120, -0.2069,  ..., -1.1592,  0.1693,  0.0689],
        [-0.8595, -1.2168, -0.2039,  ..., -1.2573,  0.3046, -0.0482],
        ...,
        [-0.8168, -1.0229, -0.3071,  ..., -1.0717,  0.1152,  0.0169],
        [-0.8176, -1.0105, -0.3739,  ..., -1.0942,  0.1204, -0.0133],
        [-0.8804, -0.8916, -0.2152,  ..., -1.1543,  0.1480, -0.0019]],
       grad_fn=<ViewBackward0>) tensor([ 2,  1, -1,  ..., 11,  9,  3], dtype=torch.int32)
====> Epoch: 3 Average loss: 0.0000000000


                                       

tensor([[-0.8106, -1.1445, -0.2239,  ..., -1.1855,  0.2693, -0.0059],
        [-0.8082, -1.0301, -0.2927,  ..., -1.1430,  0.2682,  0.1272],
        [-0.8914, -1.2552, -0.3166,  ..., -1.1963,  0.3103, -0.0183],
        ...,
        [-0.8822, -1.0149, -0.3821,  ..., -1.2257,  0.1885, -0.1073],
        [-0.8607, -1.0112, -0.3500,  ..., -1.2051,  0.1429, -0.0689],
        [-0.8933, -0.8748, -0.3287,  ..., -1.2403,  0.1370, -0.0695]],
       grad_fn=<ViewBackward0>) tensor([12,  6, 10,  ..., 10, 11,  3], dtype=torch.int32)
====> Epoch: 4 Average loss: 0.0000000000


                                       

tensor([[-0.7747, -0.9979, -0.2132,  ..., -1.2097,  0.2851, -0.1296],
        [-0.8523, -1.1039, -0.3490,  ..., -1.2041,  0.2847, -0.1575],
        [-0.9025, -0.6765, -0.3084,  ..., -1.2408,  0.1401, -0.1007],
        ...,
        [-0.7535, -1.0148, -0.2816,  ..., -1.2183,  0.2227, -0.2315],
        [-0.7603, -1.1797, -0.3074,  ..., -1.2154,  0.2597, -0.3251],
        [-0.7822, -0.8942, -0.1887,  ..., -1.1820,  0.1821, -0.1862]],
       grad_fn=<ViewBackward0>) tensor([18, 14, -1,  ...,  5, -1, 14], dtype=torch.int32)
====> Epoch: 5 Average loss: 0.0000000000


                                       

tensor([[-0.7367, -0.9324, -0.3440,  ..., -1.1232,  0.2432, -0.0859],
        [-0.7886, -0.8026, -0.2569,  ..., -1.1829,  0.2756, -0.0808],
        [-0.7840, -0.8041, -0.3468,  ..., -1.1092,  0.2725, -0.0697],
        ...,
        [-0.8364, -1.2181, -0.3181,  ..., -1.2005,  0.2930,  0.0038],
        [-0.8759, -1.1683, -0.3258,  ..., -1.1770,  0.2773, -0.0678],
        [-0.8810, -0.9654, -0.2204,  ..., -1.2077,  0.2217,  0.0401]],
       grad_fn=<ViewBackward0>) tensor([ 3, -1, -1,  ..., 18, -1,  4], dtype=torch.int32)
====> Epoch: 6 Average loss: 0.0000000000


                                       

tensor([[-0.8028, -1.1483, -0.1346,  ..., -1.2180,  0.2755, -0.0981],
        [-0.7795, -0.9385, -0.3399,  ..., -1.1320,  0.2158, -0.0128],
        [-0.8440, -1.1172, -0.1983,  ..., -1.2489,  0.2702, -0.0954],
        ...,
        [-0.8346, -0.9721, -0.3806,  ..., -1.0798,  0.1734,  0.0429],
        [-0.7525, -1.1291, -0.4116,  ..., -1.0943,  0.1416, -0.0918],
        [-0.8083, -0.9767, -0.3018,  ..., -1.0926,  0.1002, -0.0951]],
       grad_fn=<ViewBackward0>) tensor([20, 10,  3,  ...,  1, -1, -1], dtype=torch.int32)
====> Epoch: 7 Average loss: 0.0000000000


                                       

tensor([[-0.9002, -1.1488, -0.3514,  ..., -1.2562,  0.2597, -0.2083],
        [-0.9926, -0.9201, -0.2182,  ..., -1.3897,  0.2118, -0.2141],
        [-0.9319, -1.1402, -0.2154,  ..., -1.2922,  0.3399, -0.1501],
        ...,
        [-0.8787, -1.2271, -0.3376,  ..., -1.2218,  0.2823, -0.0340],
        [-0.8759, -1.0242, -0.2647,  ..., -1.2595,  0.2131, -0.0882],
        [-0.8351, -0.8929, -0.2384,  ..., -1.1880,  0.1747,  0.0521]],
       grad_fn=<ViewBackward0>) tensor([-1, -1, -1,  ..., 17,  6,  4], dtype=torch.int32)
====> Epoch: 8 Average loss: 0.0000000000


                                       

tensor([[-0.7968, -0.7680, -0.2112,  ..., -1.1874,  0.2349, -0.0839],
        [-0.8238, -1.0964, -0.2365,  ..., -1.2533,  0.3532, -0.1219],
        [-0.8137, -0.8641, -0.2752,  ..., -1.1874,  0.2669, -0.0444],
        ...,
        [-0.9275, -0.9963, -0.3898,  ..., -1.2235,  0.2070, -0.0933],
        [-0.8720, -1.0554, -0.3378,  ..., -1.2895,  0.1835, -0.2099],
        [-0.9430, -1.0584, -0.3259,  ..., -1.3122,  0.1868, -0.2605]],
       grad_fn=<ViewBackward0>) tensor([-1, -1, -1,  ..., -1, 10, -1], dtype=torch.int32)
====> Epoch: 9 Average loss: 0.0000000000


                                       

tensor([[-0.8657, -1.0053, -0.2318,  ..., -1.2620,  0.2011, -0.1405],
        [-0.8489, -1.0124, -0.2207,  ..., -1.1497,  0.2620, -0.0390],
        [-0.8452, -1.1446, -0.2743,  ..., -1.1502,  0.3071,  0.0266],
        ...,
        [-0.8478, -1.1300, -0.3745,  ..., -1.1818,  0.2291, -0.0534],
        [-0.8704, -1.0252, -0.2952,  ..., -1.2627,  0.1813, -0.1523],
        [-0.8552, -1.1301, -0.2275,  ..., -1.3130,  0.2059, -0.2180]],
       grad_fn=<ViewBackward0>) tensor([-1, -1,  7,  ..., 11, 14, -1], dtype=torch.int32)
====> Epoch: 10 Average loss: 0.0000000000


                                       

tensor([[-0.9255, -1.2293, -0.2254,  ..., -1.3156,  0.2658, -0.1405],
        [-0.9369, -1.1729, -0.2129,  ..., -1.3496,  0.2512, -0.1567],
        [-0.8523, -0.7388, -0.2857,  ..., -1.2155,  0.1372,  0.0302],
        ...,
        [-0.8740, -1.1105, -0.3992,  ..., -1.1504,  0.2977,  0.0255],
        [-0.9643, -1.0825, -0.3312,  ..., -1.2381,  0.2841, -0.0422],
        [-0.9353, -1.1344, -0.3905,  ..., -1.2563,  0.2496, -0.1078]],
       grad_fn=<ViewBackward0>) tensor([10, 13, 18,  ..., -1,  1, -1], dtype=torch.int32)
====> Epoch: 11 Average loss: 0.0000000000


                                       

tensor([[-0.7777, -1.0035, -0.2236,  ..., -1.2014,  0.3326, -0.1113],
        [-0.8709, -1.1229, -0.2347,  ..., -1.2656,  0.3835, -0.1571],
        [-0.8090, -1.0531, -0.2744,  ..., -1.2007,  0.3469, -0.0418],
        ...,
        [-0.7767, -0.9123, -0.2577,  ..., -1.0904,  0.1764,  0.0236],
        [-0.9053, -0.8690, -0.1944,  ..., -1.1634,  0.2069, -0.0244],
        [-0.8282, -1.1990, -0.3234,  ..., -1.1946,  0.2823, -0.0577]],
       grad_fn=<ViewBackward0>) tensor([-1, -1, -1,  ...,  9, -1,  1], dtype=torch.int32)
====> Epoch: 12 Average loss: 0.0000000000


                                       

tensor([[-0.7673, -0.6636, -0.2927,  ..., -1.1405,  0.2453, -0.1082],
        [-0.7850, -0.7933, -0.3011,  ..., -1.1647,  0.2972, -0.1137],
        [-0.7663, -1.0120, -0.3078,  ..., -1.1683,  0.3554, -0.0912],
        ...,
        [-0.8826, -0.9874, -0.2996,  ..., -1.1318,  0.1519, -0.0688],
        [-0.7733, -0.9303, -0.3217,  ..., -1.0765,  0.1105,  0.0596],
        [-0.8679, -1.0529, -0.2004,  ..., -1.2084,  0.2000, -0.1568]],
       grad_fn=<ViewBackward0>) tensor([ 9, 17,  2,  ..., -1, -1,  7], dtype=torch.int32)
====> Epoch: 13 Average loss: 0.0000000000


                                       

tensor([[-0.8765, -0.8929, -0.1852,  ..., -1.3492,  0.1806, -0.2269],
        [-0.8049, -0.7520, -0.1915,  ..., -1.1687,  0.1806, -0.0559],
        [-0.8253, -1.0412, -0.2342,  ..., -1.1947,  0.3025, -0.0796],
        ...,
        [-0.8589, -1.2183, -0.3878,  ..., -1.1587,  0.3143,  0.0081],
        [-0.9149, -1.0179, -0.3037,  ..., -1.2224,  0.2556,  0.0235],
        [-0.9175, -1.1999, -0.3585,  ..., -1.2709,  0.2617, -0.1072]],
       grad_fn=<ViewBackward0>) tensor([18, 10,  2,  ..., 10, 16, -1], dtype=torch.int32)
====> Epoch: 14 Average loss: 0.0000000000


                                       

tensor([[-0.9630, -1.0069, -0.3352,  ..., -1.3359,  0.1982, -0.0941],
        [-0.9340, -0.9518, -0.2570,  ..., -1.3619,  0.2122, -0.1832],
        [-0.8351, -0.7896, -0.3319,  ..., -1.1719,  0.1724, -0.0115],
        ...,
        [-0.8857, -0.9691, -0.2001,  ..., -1.2390,  0.2702, -0.0388],
        [-0.8684, -1.0859, -0.3517,  ..., -1.2248,  0.3152, -0.0230],
        [-0.8988, -0.9522, -0.2748,  ..., -1.2904,  0.2602, -0.0079]],
       grad_fn=<ViewBackward0>) tensor([13,  9, 14,  ..., -1, -1,  6], dtype=torch.int32)
====> Epoch: 15 Average loss: 0.0000000000


                                       

tensor([[-0.7909, -0.7968, -0.2517,  ..., -1.1488,  0.1437,  0.0180],
        [-0.8140, -1.0436, -0.1843,  ..., -1.1748,  0.2402, -0.0616],
        [-0.8146, -1.0952, -0.2034,  ..., -1.1307,  0.2689,  0.0564],
        ...,
        [-0.8821, -1.1029, -0.3180,  ..., -1.1731,  0.2526,  0.0344],
        [-0.8753, -1.1991, -0.3155,  ..., -1.2175,  0.2183, -0.0842],
        [-0.8863, -0.9907, -0.2364,  ..., -1.2027,  0.1897, -0.0379]],
       grad_fn=<ViewBackward0>) tensor([ 3,  6, 15,  ..., 15,  7,  1], dtype=torch.int32)
====> Epoch: 16 Average loss: 0.0000000000


                                       

tensor([[-0.7829, -1.2338, -0.2064,  ..., -1.1632,  0.2715, -0.0449],
        [-0.7761, -0.9839, -0.3244,  ..., -1.0620,  0.1702,  0.1135],
        [-0.8403, -0.8958, -0.2319,  ..., -1.1351,  0.1679,  0.0528],
        ...,
        [-0.8108, -1.0544, -0.3417,  ..., -1.0501,  0.1359,  0.0096],
        [-0.7915, -1.0668, -0.3983,  ..., -1.0677,  0.1268, -0.0615],
        [-0.7847, -0.9630, -0.3411,  ..., -1.0661,  0.0781, -0.0432]],
       grad_fn=<ViewBackward0>) tensor([10,  0, -1,  ...,  3, -1, -1], dtype=torch.int32)
====> Epoch: 17 Average loss: 0.0000000000


                                       

tensor([[-0.7767, -0.7911, -0.2886,  ..., -1.1718,  0.1809, -0.0792],
        [-0.8407, -1.0944, -0.1810,  ..., -1.2844,  0.2639, -0.0964],
        [-0.8076, -1.0410, -0.3350,  ..., -1.1397,  0.2915, -0.1129],
        ...,
        [-0.8468, -1.1884, -0.2758,  ..., -1.2542,  0.2635, -0.1792],
        [-0.9006, -0.9564, -0.1955,  ..., -1.2184,  0.2267, -0.0758],
        [-0.7952, -1.0457, -0.2906,  ..., -1.1865,  0.2366,  0.0616]],
       grad_fn=<ViewBackward0>) tensor([10,  4, 16,  ...,  5, -1, -1], dtype=torch.int32)
====> Epoch: 18 Average loss: 0.0000000000


                                       

tensor([[-0.7795, -0.7227, -0.2324,  ..., -1.1778,  0.2333, -0.0802],
        [-0.7927, -0.8767, -0.3007,  ..., -1.1748,  0.2865, -0.0799],
        [-0.8274, -0.8604, -0.2640,  ..., -1.2130,  0.2719, -0.0399],
        ...,
        [-0.8877, -1.1754, -0.3359,  ..., -1.2360,  0.2303, -0.1845],
        [-0.8909, -0.8285, -0.2650,  ..., -1.2557,  0.1523, -0.0762],
        [-0.8770, -1.1737, -0.3105,  ..., -1.2730,  0.2153, -0.2829]],
       grad_fn=<ViewBackward0>) tensor([17,  8, -1,  ...,  2,  8, 10], dtype=torch.int32)
====> Epoch: 19 Average loss: 0.0000000000


