In [11]:
import math

import torch
import torch.nn.functional as F
from torch import optim
from torch.optim import Adam
from tqdm import tqdm

from utils.data import read_domain_ids_per_chain_from_txt
from common.res_infor import *
from utils.dataset import *
from diffusion_model.sequence_diffusion_model import *

In [12]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
EPOCH = 20             # number of passes over the training loader
LEARNING_RATE = 1e-3   # learning rate handed to the Adam optimizer

# Seed torch's global RNG so reruns of this notebook are repeatable.
# NOTE(review): this assumes the data loaders and diffusion helpers draw their
# randomness from torch's RNG; if they use numpy/random, seed those too.
SEED = 42
torch.manual_seed(SEED)

In [13]:
# Parse the domain-id list of each split; the project helper returns a pair that
# is unpacked into the split's pdb ids and pdb chains.
train_pdbs, train_pdb_chains = read_domain_ids_per_chain_from_txt(
    './data/train_domains.txt'
)
test_pdbs, test_pdb_chains = read_domain_ids_per_chain_from_txt(
    './data/test_domains.txt'
)

In [14]:
# The .npy files carry the same `_20` suffix as the loader's seq_length, so build
# both from one constant to keep them in sync.
SEQ_LENGTH = 20   # value passed as seq_length (presumably residues per sample)
BATCH_SIZE = 128

train_loader = BackboneCoordsDataLoader(
    train_pdb_chains,
    f"./data/train_backbone_coords_{SEQ_LENGTH}.npy",
    f"./data/train_data_res_{SEQ_LENGTH}.npy",
    seq_length=SEQ_LENGTH,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# Evaluation data does not need shuffling; a fixed order keeps test-time
# inspection and per-batch results reproducible.
test_loader = BackboneCoordsDataLoader(
    test_pdb_chains,
    f"./data/test_backbone_coords_{SEQ_LENGTH}.npy",
    f"./data/test_data_res_{SEQ_LENGTH}.npy",
    seq_length=SEQ_LENGTH,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Inspect a few training sequences

In [15]:
def tensor_to_string(tensor, label_res_dict, mask_token=21, mask_char='*'):
    """Decode tensors of residue indices into strings, one per sequence.

    Args:
        tensor: tensor of residue indices. Each row is one sequence; a 1-D
            tensor is treated as a single sequence. Every element of a row
            must be convertible with ``.item()``, so a trailing singleton
            axis such as ``(batch, seq_len, 1)`` is also accepted.
        label_res_dict: mapping from residue index to the symbol to emit.
        mask_token: index rendered as ``mask_char`` instead of being looked
            up in ``label_res_dict`` (default 21).
        mask_char: placeholder emitted for ``mask_token`` positions.

    Returns:
        list[str]: one decoded string per row; a 1-D input yields a
        single-element list, an empty batch yields ``[]``.

    Raises:
        KeyError: if a non-mask index is missing from ``label_res_dict``.
    """
    if tensor.dim() == 1:
        # Iterating a 1-D tensor yields 0-d scalars, which cannot be iterated
        # again, so view a lone sequence as a batch of one.
        tensor = tensor.unsqueeze(0)

    def _decode(residue):
        # Convert once and reuse the Python value for the mask test and the lookup.
        index = residue.item()
        return mask_char if index == mask_token else label_res_dict[index]

    return ["".join(_decode(residue) for residue in row) for row in tensor]

In [16]:
# Peek at the first training batch: chain ids and their decoded residue strings.
pdb_id, res, data = next(iter(train_loader))
# squeeze(-1) removes only a trailing singleton axis (as the training loop does),
# so a batch of size 1 keeps its batch dimension; a bare squeeze() would not.
res = res.squeeze(-1)
sequences = tensor_to_string(res, label_res_dict)
print(pdb_id[:5])
print(sequences[:5])

('4j8sA', '2wa0A', '3ay5A', '3qdsA', '4g56A')
['QQTDLSQVWPEANQHFSKEI', 'SGVDLGTYFQSMDAESLFRE', 'LASPLEQLRHLAEELRLLLP', 'TYSITLRVFQRNPGRGFFSI', 'RVSSGRDVACVTEVADTLGA']


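# Visualization of the forward process

Each timestep below corrupts the same clean sequence independently, so the rows are separate samples rather than one trajectory. This is why a later row can show fewer `*` than an earlier one. `*` marks positions carrying token 21, the token the forward process substitutes for residues.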

In [17]:
diffusion = SequenceDiffusion()

# Take the first chain of a fresh training batch, flattened to shape (1, -1).
pdb_id, res, _ = next(iter(train_loader))
pdb_1, res_1 = pdb_id[0], res[0].reshape(1, -1)
print(f"{pdb_1[:4]} {pdb_1[-1]}")

# Corrupt the same clean sequence res_1 at every 10th timestep from 0 to 100.
# Each call starts again from res_1, so the printed rows are independent draws.
for step in range(0, 101, 10):
    t = torch.tensor([step])
    x_t, _, _ = diffusion.seq_q_sample(res_1, t)
    sequences = tensor_to_string(x_t, label_res_dict)
    print(f"t = {t.item()}:", sequences)
    

4d53 A
t = 0: ['MGHKIDTKEDMKILYSEIAE']
t = 10: ['MGHKIDTKED*KILYSEIAE']
t = 20: ['MGHKI*TKEDMK*LYS*IA*']
t = 30: ['MGH**DTKEDM***YSE*AE']
t = 40: ['MGH*IDTKEDM*ILYS*IAE']
t = 50: ['M*H***TK*D*K***SE*A*']
t = 60: ['*GHKID*K*D*****SE***']
t = 70: ['*GHK**TK******YSE***']
t = 80: ['*****D*********S****']
t = 90: ['********************']
t = 100: ['********************']


# Model training

In [18]:
# Rebuild the diffusion helper on DEVICE for training (the visualization cell
# used the default constructor) and place the model on the same device.
diffusion = SequenceDiffusion(device=DEVICE)
model = SequenceModel().to(DEVICE)
# Use the configured learning rate rather than repeating the literal here.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [26]:
# Per batch: sample a timestep per sequence, corrupt the residues with the
# forward process, predict the clean residues from the corrupted sequence plus
# backbone geometry, and minimise the reweighted cross-entropy below.
NUM_RES_CLASSES = 21  # number of classes the logits are reshaped into for the loss

for epoch in range(EPOCH):
    model.train()
    train_loss = 0.0
    num_batches = 0
    for _pdb, res_label, atom_coords in tqdm(train_loader, leave=False):
        # Data preparation. The model lives on DEVICE and the timesteps are
        # sampled on DEVICE, so the batch must be moved there too (.to is a
        # no-op if the loader already yields tensors on DEVICE).
        x_0 = res_label.squeeze(-1).to(DEVICE)
        atom_coords = atom_coords.to(DEVICE)
        n_coords = atom_coords[:, :, 0]
        ca_coords = atom_coords[:, :, 1]
        c_coords = atom_coords[:, :, 2]

        rotations, translations = rigidFrom3Points(n_coords, ca_coords, c_coords)
        # Pairwise Euclidean distances between the atoms at index 1 of each position.
        pair_repr = torch.cdist(ca_coords, ca_coords, p=2).to(torch.float32)

        # Forward diffusion
        batch_size = x_0.shape[0]
        seq_len = x_0.shape[1:].numel()
        t, pt = diffusion.sample_timesteps(batch_size=batch_size, device=DEVICE)
        x_t, x_0_ignore, _ = diffusion.seq_q_sample(x_0, t)

        # Reverse diffusion: predict logits for the clean residues.
        x_0_hat_logits = model(x_t.float(), pair_repr, rotations.float(), translations.float())

        # Loss. Positions whose target is -1 are excluded via ignore_index.
        # NOTE(review): reshape() is only correct if the model already emits
        # logits as (batch, classes, seq_len). If it emits (batch, seq_len,
        # classes), this scrambles them and permute(0, 2, 1) is needed — confirm
        # against SequenceModel.
        cross_entropy_loss = F.cross_entropy(
            x_0_hat_logits.reshape(batch_size, NUM_RES_CLASSES, seq_len),
            # .long() keeps the target on its current device; the original
            # .type(torch.LongTensor) forced it onto the CPU.
            x_0_ignore.reshape(batch_size, seq_len).long(),
            ignore_index=-1,
            reduction='none',
        ).mean(1)
        # NOTE(review): assumes sample_timesteps never returns t == 0 (or pt == 0);
        # otherwise this division yields inf — TODO confirm.
        vb_loss = cross_entropy_loss / t / pt
        # NOTE(review): .mean(1) already averages over positions, so dividing by
        # seq_len again scales the loss by 1/seq_len twice — confirm intended.
        vb_loss = vb_loss / (math.log(2) * seq_len)
        loss = vb_loss.mean()

        optimizer.zero_grad()
        loss.backward()   # compute gradients
        optimizer.step()  # apply the Adam parameter update

        train_loss += loss.item()
        num_batches += 1
    # loss.item() is already a per-batch mean, so average over batches, not samples.
    avg_loss = train_loss / max(num_batches, 1)
    print('====> Epoch: {} Average loss: {:.10f}'.format(epoch, avg_loss))

                                               

KeyboardInterrupt: 