In [1]:
import numpy as np
import util.npose_util as nu
import os
import pathlib
import dgl
from dgl import backend as F
import torch_geometric
from torch.utils.data import random_split, DataLoader, Dataset
from typing import Dict
from torch import Tensor
from dgl import DGLGraph
from torch import nn
# from chemical import cos_ideal_NCAC #from RoseTTAFold2
from torch import einsum
import time
import torch
torch.cuda.is_available()  # sanity check: the model and score targets below are moved to 'cuda', so this should be True

True

In [2]:
from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.transformer import Sequential, SE3Transformer
from se3_transformer.model.transformer_topk import SE3Transformer_topK
from se3_transformer.model.FAPE_Loss import FAPE_loss, Qs2Rs, normQ
from se3_transformer.model.layers.attentiontopK import AttentionBlockSE3
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling, Latent_Unpool, Unpool_Layer
from se3_transformer.runtime.utils import str2bool, to_cuda
from se3_transformer.model.fiber import Fiber
from se3_transformer.model.transformer import get_populated_edge_features

In [3]:
from se3_diffuse import rigid_utils as ru
from se3_diffuse import utils as du
from se3_diffuse import se3_diffuser
from gudiff_model.Data_Graph import Helix4_Dataset_Score, Make_KNN_MP_Graphs
from gudiff_model.Graph_UNet import GraphUNet
from gudiff_model.Data_Graph import build_npose_from_coords, dump_coord_pdb

In [4]:
def model_score_step(backbone_dict, noised_dict, graph_maker, graph_unet, device='cuda'):
    """Run the score network on one noised batch and return per-sample score-matching losses.

    Args:
        backbone_dict: the clean batch from the DataLoader. It is not read here;
            the parameter is kept so existing call sites keep working.
        noised_dict: the noised batch (e.g. from ``SE3Diffuser.forward_marginal``).
            Keys read: 'batched_t', 'trans_score', 'rot_score',
            'trans_score_scaling', 'rot_score_scaling'. Each '*_scaling' entry must be
            indexable as ``[:, None, None]`` -- one value per batch element.
        graph_maker: object whose ``prep_for_network(noised_dict)`` builds the network input.
        graph_unet: score network, called as ``graph_unet(x, batched_t)``. Its output's
            ``'1'`` entry is indexed as ``[:, channel, :]``: channel 0 is taken as the
            predicted translation score and channel 1 as the predicted rotation score.
        device: device the target tensors are moved to.

    Returns:
        Tensor of shape (B,): rotation + translation loss of each batch element, i.e.
        the squared score error divided by that element's own scaling, summed over
        xyz and averaged over its residues.
    """
    batched_t = noised_dict['batched_t'].to(device)

    x = graph_maker.prep_for_network(noised_dict)
    out = graph_unet(x, batched_t)

    # One scaling per batch element, broadcast over (residue, xyz).
    tss = noised_dict['trans_score_scaling'][:, None, None].to(device)
    rss = noised_dict['rot_score_scaling'][:, None, None].to(device)
    num_batch = tss.shape[0]

    # Regroup node rows by batch element: (B*L, 3) -> (B, L, 3). Predictions and
    # targets are compared row-by-row, so node order is assumed to be batch-major,
    # matching the flattened layout of the score targets.
    pred_trans_score = out['1'][:, 0, :].reshape(num_batch, -1, 3)
    pred_rots_score = out['1'][:, 1, :].reshape(num_batch, -1, 3)
    true_trans_score = noised_dict['trans_score'].reshape(num_batch, -1, 3).to(device)
    true_rots_score = noised_dict['rot_score'].reshape(num_batch, -1, 3).to(device)
    num_res = max(true_trans_score.shape[1], 1)

    # NOTE(review): the squared error is divided by the scaling itself, not its
    # square -- confirm this is the intended weighting.
    loss_trans_mse = torch.square(pred_trans_score - true_trans_score) / tss
    loss_rot_mse = torch.square(pred_rots_score - true_rots_score) / rss

    # Sum over (residue, xyz) within each sample, then average over its residues.
    trans_loss = torch.sum(loss_trans_mse, dim=(-1, -2)) / num_res
    rot_loss = torch.sum(loss_rot_mse, dim=(-1, -2)) / num_res

    return rot_loss + trans_loss
    

In [11]:
def make_save_folder(name=''):
    """Create a timestamped run folder under ``log/`` and return its path.

    The folder is ``log/<yy><Mon><dd>_<hh><MM><AM|PM>_<name>/`` (local time, trailing
    slash included) and it gets a ``models/`` subfolder. Existing folders are reused.

    Args:
        name (str): suffix appended after the timestamp. It is used literally: any
            '%' is escaped so ``time.strftime`` does not read it as a directive.

    Returns:
        str: the run folder path, ending in '/'.
    """
    safe_name = name.replace('%', '%%')
    base_folder = time.strftime(f'log/%y%b%d_%I%M%p_{safe_name}/', time.localtime())
    # makedirs also creates base_folder; exist_ok avoids a check-then-create race.
    for subfolder in ('models',):
        os.makedirs(base_folder + subfolder, exist_ok=True)
    return base_folder
        
def save_chkpt(model_path, model, optimizer, epoch, batch, val_losses, train_losses):
    """Save a training checkpoint to ``f'{model_path}model_e{epoch}'``.

    Args:
        model_path (str): path *prefix* of the checkpoint file. ``model_e{epoch}`` is
            appended directly, so pass a directory with a trailing separator
            (e.g. ``'log/run/models/'``). Its directory part is created if missing.
        model (nn.Module): model whose ``state_dict()`` is saved under 'model'.
        optimizer (torch.optim.Optimizer): optimizer whose ``state_dict()`` is saved
            under 'optimizer'.
        epoch (int): current epoch; stored under 'epoch' and used in the file name.
        batch (int): stored as-is under 'batch'.
        val_losses: validation loss(es), stored under 'val_losses'.
        train_losses: training loss(es), stored under 'train_losses'.

    Returns:
        str: path of the written checkpoint file.
    """
    save_dir = os.path.dirname(model_path)
    # os.makedirs('') raises, so only create a directory when the prefix names one.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    state_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'batch': batch,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }
    out_path = f'{model_path}model_e{epoch}'
    torch.save(state_dict, out_path)
    return out_path
    
def load_model(model_path, model_class, device='cuda:0', **model_kwargs):
    """Load a saved model from a checkpoint dict that stores its weights under 'model'.

    Args:
        model_path (str): path of the checkpoint file.
        model_class: callable that builds an untrained model; called as
            ``model_class(**model_kwargs)``.
        device: device the checkpoint tensors are mapped to and the model is moved to.
        **model_kwargs: constructor arguments forwarded to ``model_class``.

    Returns:
        The model with the checkpoint weights loaded, on ``device``.
    """
    model = model_class(**model_kwargs)
    # map_location puts the saved tensors straight onto `device`, so a checkpoint
    # written on another GPU (or on a GPU machine) still loads here.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    return model

In [5]:
# Helix-bundle coordinate sets used for the val/train split:
#   apa = helices kept apart, tog = helices kept together.
# NOTE(review): the dataset built from these later supplies the N-CA / C-CA
# vectors used as type-1 features -- confirm against Helix4_Dataset_Score.
apa_path_str = 'data_npose/h4_apa_coords.npz'
tog_path_str = 'data_npose/h4_tog_coords.npz'

# Maximum number of structures kept from each file.
test_limit = 5048


def _load_first_array(npz_path, limit):
    """Return the first array stored in an .npz archive, truncated to `limit` rows.

    Only that one array is read, and the archive is closed afterwards.
    """
    with np.load(npz_path) as archive:
        return archive[archive.files[0]][:limit, :]


# NOTE(review): no atom slicing happens here; the arrays are presumably stored
# with only the N, CA, C backbone atoms already -- confirm against the data prep.
coords_apa = _load_first_array(apa_path_str, test_limit)
coords_tog = _load_first_array(tog_path_str, test_limit)

In [6]:
# Batch size, structure length (presumably residues per structure -- confirm),
# and how many "together" structures to train on.
B, L, limit = 32, 65, 5048

h4_trainData = Helix4_Dataset_Score(coords_tog[:limit])
# drop_last=True keeps every batch at exactly B samples.
train_dL = DataLoader(h4_trainData, batch_size=B, shuffle=True, drop_last=True)

In [7]:
se3d = se3_diffuser.SE3Diffuser()  # default-constructed diffuser; the training loop calls se3d.forward_marginal on each batch

In [8]:
# Score network (on the GPU), its Adam optimizer, and the KNN graph builder.
gu = GraphUNet(num_layers_ca=2, batch_size=B).to('cuda')
opti = torch.optim.Adam(gu.parameters(), weight_decay=5e-6, lr=0.001)
# TODO: consider precalculating graphs for training
gm = Make_KNN_MP_Graphs()

In [9]:
t_vec = np.full(B, 0.05)  # one diffusion time per batch element, all fixed at 0.05 (NOTE(review): training sees only this t -- confirm intended)
    

In [12]:
# Smoke test of the checkpoint helpers before training. The epoch is passed
# explicitly (0) instead of the loop variable `e`, which only exists after the
# training cell has run; the saved file (models/model_e0) holds untrained weights.
# make_save_folder already returns a path under 'log/', so the checkpoint goes
# into its 'models/' subfolder rather than under 'log/log/'.
avg_tloss = 0
avg_vloss = 0
model_path = make_save_folder(name='full_diff_score')
save_chkpt(f'{model_path}models/', gu, opti, 0, B, avg_vloss, avg_tloss)

In [10]:
num_epochs = 200
e_start = 0
save_per = 10
# NOTE(review): there is no validation pass yet; this placeholder is what gets checkpointed.
avg_vloss = 0
model_path = make_save_folder(name='full_diff_score')
# make_save_folder already returns 'log/...', so write into its 'models/' subfolder.
ckpt_prefix = f'{model_path}models/'
train_loss_history = []  # average train loss of each epoch, saved with every checkpoint
last_epoch = e_start + num_epochs - 1

for e in range(e_start, e_start + num_epochs):
    running_tloss = 0.0
    num_batches = 0
    start = time.time()
    for i, bb_dict in enumerate(train_dL):
        # Noise the clean batch at the times in t_vec, then score the prediction.
        noisy = se3d.forward_marginal(bb_dict, t_vec=t_vec)
        train_loss = torch.sum(model_score_step(bb_dict, noisy, gm, gu))

        opti.zero_grad()
        train_loss.backward()
        opti.step()

        running_tloss += train_loss.item()
        num_batches = i + 1
    end = time.time()

    if num_batches == 0:
        raise RuntimeError('train_dL yielded no batches; is the dataset smaller than the batch size with drop_last=True?')
    avg_tloss = running_tloss / num_batches
    train_loss_history.append(avg_tloss)
    print(f'Average Train Loss Epoch {e}: {avg_tloss};   Epoch time: {end-start:.0f}')

    # Save every `save_per` epochs and always after the final epoch.
    if e % save_per == save_per - 1 or e == last_epoch:
        save_chkpt(ckpt_prefix, gu, opti, e, B, avg_vloss, train_loss_history)
    

  assert input.numel() == input.storage().size(), (


Average Train Loss Epoch 0: 779.9620361328125;   Epoch time: 58
Average Train Loss Epoch 1: 689.7782592773438;   Epoch time: 53
Average Train Loss Epoch 2: 668.9260864257812;   Epoch time: 53
Average Train Loss Epoch 3: 651.8217163085938;   Epoch time: 52
Average Train Loss Epoch 4: 615.6889038085938;   Epoch time: 50
Average Train Loss Epoch 5: 568.6504516601562;   Epoch time: 49
Average Train Loss Epoch 6: 543.7719116210938;   Epoch time: 50
Average Train Loss Epoch 7: 522.166015625;   Epoch time: 52
Average Train Loss Epoch 8: 501.2473449707031;   Epoch time: 52
Average Train Loss Epoch 9: 475.0062255859375;   Epoch time: 52
Average Train Loss Epoch 10: 447.1307067871094;   Epoch time: 50
Average Train Loss Epoch 11: 425.4366149902344;   Epoch time: 49
Average Train Loss Epoch 12: 408.2470703125;   Epoch time: 50
Average Train Loss Epoch 13: 398.56719970703125;   Epoch time: 50
Average Train Loss Epoch 14: 388.7009582519531;   Epoch time: 51
Average Train Loss Epoch 15: 379.79180908