Skip to content

Commit

Permalink
make integral compatible with Pytorch1.0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
lck1201 committed Apr 4, 2019
1 parent 8e1e1a4 commit ad3f875
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pytorch_projects/common_pytorch/common_loss/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ def generate_3d_integral_preds_tensor(heatmaps, num_joints, x_dim, y_dim, z_dim)
accu_z = heatmaps.sum(dim=3)
accu_z = accu_z.sum(dim=3)

accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim), devices=[accu_x.device.index])[0]
accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim), devices=[accu_y.device.index])[0]
accu_z = accu_z * torch.cuda.comm.broadcast(torch.arange(z_dim), devices=[accu_z.device.index])[0]
accu_x = accu_x * torch.cuda.comm.broadcast(torch.arange(x_dim).type(torch.cuda.FloatTensor), devices=[accu_x.device.index])[0]
accu_y = accu_y * torch.cuda.comm.broadcast(torch.arange(y_dim).type(torch.cuda.FloatTensor), devices=[accu_y.device.index])[0]
accu_z = accu_z * torch.cuda.comm.broadcast(torch.arange(z_dim).type(torch.cuda.FloatTensor), devices=[accu_z.device.index])[0]

accu_x = accu_x.sum(dim=2, keepdim=True)
accu_y = accu_y.sum(dim=2, keepdim=True)
Expand Down

0 comments on commit ad3f875

Please sign in to comment.