In [1]:
import torch
import torch.nn as nn

from affinity_module.config import get_config

import numpy as np

from dgllife.utils import EarlyStopping

from affinity_module.utils import set_random_seed, load_model
from affinity_module.utils import run_a_train_epoch, run_stat_epoch, run_an_eval_epoch
from affinity_module.utils import Collate
from affinity_module.dataset import VPLGDataset, FoldsOf_VPLGDataset

from affinity_module.protein_graph_loaders import DSSP_loader

from torch.backends import cudnn

Using backend: pytorch


## Prepare configuration & environment

In [2]:
# Ask cuDNN for deterministic kernels and switch off its benchmark-based
# algorithm selection, so runs with the same seed are comparable.
cudnn.deterministic = True
cudnn.benchmark = False

# Project configuration; treated below as a dict-like object.
args = get_config()
collate = Collate(args)
# torch.device expects "<type>:<index>" with no whitespace. The earlier
# "cuda: 0" spelling is rejected as an invalid device string by current
# PyTorch releases, so use the canonical form.
args['device'] = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
set_random_seed(args['random_seed'])


# Fold ids reserved for validation / testing, and the cache location prefix.
argv_valFold = args['argv_valFold']
argv_testFold = args['argv_testFold']
cache_dir_prefix = args['cache_dir_prefix']

# DSSP-backed loader that is handed to VPLGDataset as its pdb2graph_translator.
# Amino-acid physico-chemical features are switched off for this run.
pdb2graph_translator = DSSP_loader(dssp_files_path = args['dssp_files_path'],
                                   includeAminoacidPhyschemFeatures = False,
                                   cache_dir_prefix = cache_dir_prefix)

# Checkpoint path later given to EarlyStopping and reused when saving predictions.
best_model_filename = pdb2graph_translator.get_best_model_filename()

# Keyword arguments shared by every VPLGDataset construction in this notebook:
# the master data table plus the names of the columns to read from it.
_colNames = dict(master_data_table = args['master_data_table'],
                 pdb_id_col_name="PDBs", smiles_col_name="SMILES", target_col_name="logKi",
                 foldId_col_name='Fold')

## Create & cache graphs

In [3]:
# Build the dataset with load=False. Per the cell output this deletes the old
# cache and processes the graphs from scratch, so that a later load=True
# construction can read them back.
dataset = VPLGDataset(
    pdb2graph_translator=pdb2graph_translator,
    load=False,
    # ligand (SMILES) graph construction and featurization
    smiles_to_graph=args['smiles_to_graph'],
    smiles_node_featurizer=args['smiles_node_featurizer'],
    smiles_edge_featurizer=args['smiles_edge_featurizer'],
    # data table and column names
    **_colNames,
)

deleting old cache...
graph_cache_dssp_noPCFAA/smiles_graphs: 10 deleted
graph_cache_dssp_noPCFAA/fasta_graphs: 10 deleted
graph_cache_dssp_noPCFAA/smiles: 10 deleted
graph_cache_dssp_noPCFAA/labels: 10 deleted
Processing dgl graphs from scratch...
...to load: 4514 4514 4514
Processing graph 0/0
Exception in pdbId_to_graph in line: 200, for SMILES; CCCCCC(=O)OC(CO)CO[PH](O)(O)OCC[N+](C)(C)C, PDB_ID: 4TNW
('4TNW', 7592, 'not in resnum2idx')
Exception in pdbId_to_graph in line: 315, for SMILES; CCC(O)CC(=O)NC(CC(N)=O)C(=O)NC(CCC(N)=O)C(=O)NC(CO)CC(C)C, PDB_ID: 3KRD
('3KRD', 5868, 'not in resnum2idx')
Processing graph 500/498
Caching graph 499/499
Processing graph 1000/998
Caching graph 999/999
Processing graph 1500/1498
Caching graph 1499/1499
Processing graph 2000/1998
Caching graph 1999/1999
Processing graph 2500/2498
Caching graph 2499/2499
Exception in pdbId_to_graph in line: 2698, for SMILES; CC(C#Cc1cccc(OC23CC4CC(CC(C4)C2)C3)c1)N([O])C(=N)O, PDB_ID: 3SHJ
('3SHJ', 6192, 'not in res

## Train the model

In [5]:
# Resolve the compute device again for this cell. torch.device expects
# "<type>:<index>" with no whitespace; the earlier "cuda: 0" spelling is
# rejected as an invalid device string by current PyTorch releases.
args['device'] = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# load=True: per the cell output, this reads the previously saved graphs
# instead of rebuilding them.
raw_dataset = VPLGDataset(
    smiles_to_graph=args['smiles_to_graph'],
    smiles_node_featurizer=args['smiles_node_featurizer'],
    smiles_edge_featurizer=args['smiles_edge_featurizer'],
    **_colNames,
    pdb2graph_translator = pdb2graph_translator,
    load=True)


# Protein-graph feature widths are taken from the first cached graph and stored
# in args (presumably consumed by load_model(args) - confirm against the model code).
args['fasta_node_feat_size'] = raw_dataset.fasta_graphs[0].ndata['h'].shape[1]
args['fasta_edge_feat_size'] = raw_dataset.fasta_graphs[0].edata['e'].shape[1]
print(args['fasta_node_feat_size'], args['fasta_edge_feat_size'])

# Announce where the early-stopping checkpoint will be written.
print('will save best model to:', best_model_filename)

# DataLoader settings shared by all three splits.
# NOTE(review): shuffling is therefore also enabled for the validation and test
# loaders - confirm that this is intended.
shfl = True
p = {
    'batch_size': args['batch_size'],
    'shuffle': shfl,
    'collate_fn': collate.collate_molgraphs,
}

mx_nodes = 5000 # filter out large graphs

# Training folds: every fold in 0..4 that is not reserved for validation or test.
folds5 = list(range(5))
folds5.remove(argv_valFold)
if argv_testFold in folds5:
    folds5.remove(argv_testFold)
print('Folds to use: train=%s, val=%s, test=%s' % (folds5, argv_valFold, argv_testFold))

# Loaders are built in train, val, test order (same order as before).
train_loader, val_loader, test_loader = (
    FoldsOf_VPLGDataset(raw_dataset, fold_ids, max_nodes=mx_nodes).asDataLoader(**p)
    for fold_ids in (folds5, [argv_valFold], [argv_testFold])
)
print('number of batches: train %d, val %d, test %d' % (len(train_loader), len(val_loader), len(test_loader)))

# Release cached CUDA memory before the model is created.
torch.cuda.empty_cache()

# Element-wise squared error; reduction='none' keeps one loss value per element
# (presumably aggregated inside run_a_train_epoch - confirm).
loss_fn = nn.MSELoss(reduction='none')

# Model built from the configuration.
model = load_model(args)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args['lr'],
    weight_decay=args['weight_decay'],
)

# Early stopper that checkpoints to best_model_filename.
stopper = EarlyStopping(
    filename=best_model_filename,
    mode=args['mode'],
    patience=args['patience'],
)

# Optionally resume from the existing checkpoint before moving the model to the device.
if args['load_checkpoint']:
    print('Loading checkpoint...')
    stopper.load_checkpoint(model)
model.to(args['device'])

for epoch in range(args['num_epochs']):
    # One training pass over the training folds.
    run_a_train_epoch(args, epoch, model, train_loader, loss_fn, optimizer)

    # Validation statistics; the configured metric is the early-stopping score.
    val_score_ext = run_stat_epoch(args, model, val_loader)
    _metric = args['metric_name']
    val_score = val_score_ext[_metric]

    # The test score is only reported; the stopper sees the validation score alone.
    test_score = run_an_eval_epoch(args, model, test_loader)
    early_stop = stopper.step(val_score, model)

    _progress = 'epoch {:d}/{:d}, validation {} {:.4f}, test {} {:.4f}, best validation {} {:.4f}'.format(
        epoch + 1, args['num_epochs'],
        _metric, val_score,
        _metric, test_score,
        _metric, stopper.best_score)
    print(_progress, ', now R2 = %.4f' % val_score_ext['R2'])

    if early_stop:
        break

print('-'*80)
# Reload the checkpoint written by the early stopper before the final evaluation.
stopper.load_checkpoint(model)

print()

# Per split: (metrics, y_true, y_pred). The second value returned by
# run_stat_epoch is discarded.
all_metrics = {}
for dsName, data_loader in (('train', train_loader),
                            ('val', val_loader),
                            ('test', test_loader)):
    metrics, _, y_true, y_pred = run_stat_epoch(
        args, model, data_loader, return_pred=True)
    all_metrics[dsName] = (metrics, y_true, y_pred)


# Print a metrics table: one row per metric, one column per split.
print('-'*50)


# Union of the metric names reported for any split.
all_metric_names = set()
for dsName in ['train', 'val', 'test']:
    metrics, _, _ = all_metrics[dsName]
    all_metric_names.update(metrics.keys())

# Header row: blank label column followed by the split names.
print('%25s' % '', end='')
for dsName in ['train', 'val', 'test']:
    print('%12s' % dsName, end='')
print()

# Rows are emitted in sorted order: iterating the set directly gives an order
# that changes between interpreter sessions (string hash randomization), which
# makes the printed table hard to compare across runs.
for mName in sorted(all_metric_names, key=str):
    print('%25s' % mName, end='')
    for dsName in ['train', 'val', 'test']:
        metrics, _, _ = all_metrics[dsName]
        # A metric missing from one split is shown as nan instead of raising KeyError.
        print('%12.5f' % metrics.get(mName, float('nan')), end='')
    print()

print('-'*50)


Loading previously saved dgl graphs...
smiles: 4504
smiles_graphs: 4504
fasta_graphs: 4504
labels: 4504
32 6
will save best model to: graph_cache_best_model_dssp_noPCFAA.pth
Folds to use: train=[0, 1, 2, 3], val=4, test=4
skip: len(fg.ndata['h']) == 5808
skip: len(fg.ndata['h']) == 5616
skip: len(fg.ndata['h']) == 5016
skip: len(fg.ndata['h']) == 8384
skip: len(fg.ndata['h']) == 6528
skip: len(fg.ndata['h']) == 6224
skip: len(fg.ndata['h']) == 6224
skip: len(fg.ndata['h']) == 6818
skip: len(fg.ndata['h']) == 6208
skip: len(fg.ndata['h']) == 5891
skip: len(fg.ndata['h']) == 5889
skip: len(fg.ndata['h']) == 6173
skip: len(fg.ndata['h']) == 6155
skip: len(fg.ndata['h']) == 6162
skip: len(fg.ndata['h']) == 6165
skip: len(fg.ndata['h']) == 6178
skip: len(fg.ndata['h']) == 6189
skip: len(fg.ndata['h']) == 6173
skip: len(fg.ndata['h']) == 7078
skip: len(fg.ndata['h']) == 6148
skip: len(fg.ndata['h']) == 6168
skip: len(fg.ndata['h']) == 6173
skip: len(fg.ndata['h']) == 7078
skip: len(fg.ndata[

	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)


epoch 1/1000, training mae 1.9223
  [run_stat_epoch] will ignore batch 28 with size 1
  [run_an_eval_epoch] will ignore batch 28 with size 1
epoch 1/1000, validation mae 1.4517, test mae 1.4526, best validation mae 1.4517 , now R2 = 0.1896
epoch 2/1000, training mae 1.4391
  [run_stat_epoch] will ignore batch 28 with size 1
  [run_an_eval_epoch] will ignore batch 28 with size 1
EarlyStopping counter: 1 out of 50
epoch 2/1000, validation mae 1.5507, test mae 1.5514, best validation mae 1.4517 , now R2 = 0.1056
epoch 3/1000, training mae 1.4115
  [run_stat_epoch] will ignore batch 28 with size 1
  [run_an_eval_epoch] will ignore batch 28 with size 1
EarlyStopping counter: 2 out of 50
epoch 3/1000, validation mae 1.4701, test mae 1.4701, best validation mae 1.4517 , now R2 = 0.1500
epoch 4/1000, training mae 1.4062
  [run_stat_epoch] will ignore batch 28 with size 1
  [run_an_eval_epoch] will ignore batch 28 with size 1
epoch 4/1000, validation mae 1.3728, test mae 1.3736, best validation

In [6]:
# Filename stem for the prediction dumps: the best-model path with '.pth' removed.
_baseFname = pdb2graph_translator.get_best_model_filename().replace('.pth', '')

# Write one two-column text file (y_true, y_pred) per split.
for dsName in ('train', 'val', 'test'):
    _, y_true, y_pred = all_metrics[dsName]
    # NOTE(review): y_true is indexed as 2-D (first column kept) while y_pred is
    # used as-is - presumably (N, 1) vs (N,); confirm against run_stat_epoch.
    y_true = np.array(y_true)[:, 0]
    y_pred = np.array(y_pred)

    _outPath = _baseFname+'_yy_%s.txt' % dsName
    _header = 'y_true, y_pred (%s)' % dsName
    np.savetxt(_outPath, np.vstack((y_true, y_pred)).transpose(), header=_header)