From 965411d61f9b0a544f2080c674137f3523a357cd Mon Sep 17 00:00:00 2001 From: ydveshan <76656875+Eshan-Yadav@users.noreply.github.com> Date: Sun, 26 Mar 2023 15:00:56 +0530 Subject: [PATCH] Update equibind.py The copy_edge() function is deprecated and thus is preventing the code to run properly, thus i updated it to the current version copy_e(). --- models/equibind.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/equibind.py b/models/equibind.py index 72120ae0..4bb94602 100644 --- a/models/equibind.py +++ b/models/equibind.py @@ -482,8 +482,8 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand else: x_evolved_rec = coords_rec - lig_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg')) - rec_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg')) + lig_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg')) + rec_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg')) if self.fine_tune: x_evolved_lig = x_evolved_lig + self.att_mlp_cross_coors_V_lig(h_feats_lig) * ( @@ -515,7 +515,7 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand Loss = torch.sum((d_squared - geometry_graph.edata['feat'] ** 2)**2) # this is the loss whose gradient we are calculating here grad_d_squared = 2 * (x_evolved_lig[src] - x_evolved_lig[dst]) geometry_graph.edata['partial_grads'] = 2 * (d_squared - geometry_graph.edata['feat'] ** 2)[:,None] * grad_d_squared - geometry_graph.update_all(fn.copy_edge('partial_grads', 'partial_grads_msg'), + geometry_graph.update_all(fn.copy_e('partial_grads', 'partial_grads_msg'), fn.sum('partial_grads_msg', 'grad_x_evolved')) grad_x_evolved = geometry_graph.ndata['grad_x_evolved'] x_evolved_lig = x_evolved_lig + self.geometry_reg_step_size * grad_x_evolved