diff --git a/models/equibind.py b/models/equibind.py
index 72120ae0..4bb94602 100644
--- a/models/equibind.py
+++ b/models/equibind.py
@@ -482,8 +482,8 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand
             else:
                 x_evolved_rec = coords_rec

-            lig_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg'))
-            rec_graph.update_all(fn.copy_edge('msg', 'm'), fn.mean('m', 'aggr_msg'))
+            lig_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg'))
+            rec_graph.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg'))

             if self.fine_tune:
                 x_evolved_lig = x_evolved_lig + self.att_mlp_cross_coors_V_lig(h_feats_lig) * (
@@ -515,7 +515,7 @@ def forward(self, lig_graph, rec_graph, coords_lig, h_feats_lig, original_ligand
                     Loss = torch.sum((d_squared - geometry_graph.edata['feat'] ** 2)**2) # this is the loss whose gradient we are calculating here
                     grad_d_squared = 2 * (x_evolved_lig[src] - x_evolved_lig[dst])
                     geometry_graph.edata['partial_grads'] = 2 * (d_squared - geometry_graph.edata['feat'] ** 2)[:,None] * grad_d_squared
-                    geometry_graph.update_all(fn.copy_edge('partial_grads', 'partial_grads_msg'),
+                    geometry_graph.update_all(fn.copy_e('partial_grads', 'partial_grads_msg'),
                                               fn.sum('partial_grads_msg', 'grad_x_evolved'))
                     grad_x_evolved = geometry_graph.ndata['grad_x_evolved']
                     x_evolved_lig = x_evolved_lig + self.geometry_reg_step_size * grad_x_evolved
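
Note on the change: both hunks replace the DGL builtin fn.copy_edge with fn.copy_e. fn.copy_edge was deprecated in favor of fn.copy_e and has been removed from recent DGL releases, so the old calls fail there; the message-passing semantics are identical. A minimal standalone sketch of the pattern, on a toy graph with assumed feature names (not from the repo):

import dgl
import dgl.function as fn
import torch

# Toy directed graph with a 4-dim feature on every edge (illustrative only).
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
g.edata['msg'] = torch.randn(g.num_edges(), 4)

# Same pattern as the lig_graph/rec_graph calls in the first hunk:
# copy each edge's 'msg' into mailbox 'm', then mean-reduce into node data.
g.update_all(fn.copy_e('msg', 'm'), fn.mean('m', 'aggr_msg'))
assert g.ndata['aggr_msg'].shape == (3, 4)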
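
The second hunk sits inside a manual gradient-descent step on the distance-geometry loss Loss = sum_e (d_e^2 - d0_e^2)^2, with the gradient written out in closed form rather than obtained via autograd. The per-edge formula can be sanity-checked against torch.autograd; a sketch with toy tensors (the model's copy_e + sum scatter over the geometry graph is replaced here by explicit index_add_ onto both endpoints):

import torch

x = torch.randn(5, 3, requires_grad=True)   # toy coordinates (hypothetical)
src = torch.tensor([0, 1, 2, 3])            # toy edge list (hypothetical)
dst = torch.tensor([1, 2, 3, 4])
ref_sq = torch.rand(4)                      # stands in for geometry_graph.edata['feat'] ** 2

d_squared = torch.sum((x[src] - x[dst]) ** 2, dim=1)
loss = torch.sum((d_squared - ref_sq) ** 2)
loss.backward()

with torch.no_grad():
    # Closed-form per-edge gradient, matching the patch line for line:
    grad_d_squared = 2 * (x[src] - x[dst])
    partial_grads = 2 * (d_squared - ref_sq)[:, None] * grad_d_squared
    # Each edge contributes +partial_grads to its src node and -partial_grads
    # to its dst node; accumulate and compare with autograd's result.
    manual = torch.zeros_like(x)
    manual.index_add_(0, src, partial_grads)
    manual.index_add_(0, dst, -partial_grads)

assert torch.allclose(manual, x.grad)

This only verifies the per-edge derivative; how those terms are scattered and signed in the model depends on the geometry graph storing both edge directions, which the toy check sidesteps by updating both endpoints explicitly.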