Commit 54479d5
committed
Fix sparse tensor gradients and add backend checks
- Preserve PyTorch sparse tensors through numpy conversion for autograd
- Verify gradient w.r.t. M equals transport plan
- Add sparse backend compatibility checks and teststhrow error when unsupported backend used for sparse"1 parent 1a3dc41 commit 54479d5
File tree
4 files changed
+185
-15
lines changed- ot
- lp
- test
4 files changed
+185
-15
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
178 | 178 | | |
179 | 179 | | |
180 | 180 | | |
181 | | - | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
182 | 191 | | |
183 | 192 | | |
184 | 193 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
13 | | - | |
14 | 13 | | |
15 | 14 | | |
16 | 15 | | |
| |||
295 | 294 | | |
296 | 295 | | |
297 | 296 | | |
298 | | - | |
299 | | - | |
300 | | - | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
301 | 302 | | |
302 | | - | |
303 | | - | |
| 303 | + | |
| 304 | + | |
304 | 305 | | |
305 | 306 | | |
306 | 307 | | |
| |||
579 | 580 | | |
580 | 581 | | |
581 | 582 | | |
582 | | - | |
583 | | - | |
584 | | - | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
585 | 588 | | |
586 | | - | |
587 | | - | |
| 589 | + | |
| 590 | + | |
588 | 591 | | |
589 | | - | |
| 592 | + | |
| 593 | + | |
590 | 594 | | |
591 | 595 | | |
592 | 596 | | |
| |||
599 | 603 | | |
600 | 604 | | |
601 | 605 | | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
602 | 609 | | |
603 | 610 | | |
604 | 611 | | |
| |||
641 | 648 | | |
642 | 649 | | |
643 | 650 | | |
644 | | - | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
645 | 654 | | |
646 | 655 | | |
647 | 656 | | |
| |||
713 | 722 | | |
714 | 723 | | |
715 | 724 | | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
716 | 739 | | |
717 | 740 | | |
718 | 741 | | |
719 | 742 | | |
720 | 743 | | |
721 | 744 | | |
722 | | - | |
| 745 | + | |
723 | 746 | | |
724 | 747 | | |
725 | 748 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
75 | 75 | | |
76 | 76 | | |
77 | 77 | | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
78 | 120 | | |
79 | 121 | | |
80 | 122 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1083 | 1083 | | |
1084 | 1084 | | |
1085 | 1085 | | |
| 1086 | + | |
| 1087 | + | |
| 1088 | + | |
| 1089 | + | |
| 1090 | + | |
| 1091 | + | |
| 1092 | + | |
| 1093 | + | |
| 1094 | + | |
| 1095 | + | |
| 1096 | + | |
| 1097 | + | |
| 1098 | + | |
| 1099 | + | |
| 1100 | + | |
| 1101 | + | |
| 1102 | + | |
| 1103 | + | |
| 1104 | + | |
| 1105 | + | |
| 1106 | + | |
| 1107 | + | |
| 1108 | + | |
| 1109 | + | |
| 1110 | + | |
| 1111 | + | |
| 1112 | + | |
| 1113 | + | |
| 1114 | + | |
| 1115 | + | |
| 1116 | + | |
| 1117 | + | |
| 1118 | + | |
| 1119 | + | |
| 1120 | + | |
| 1121 | + | |
| 1122 | + | |
| 1123 | + | |
| 1124 | + | |
| 1125 | + | |
| 1126 | + | |
| 1127 | + | |
| 1128 | + | |
| 1129 | + | |
| 1130 | + | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
| 1136 | + | |
| 1137 | + | |
| 1138 | + | |
| 1139 | + | |
| 1140 | + | |
| 1141 | + | |
| 1142 | + | |
| 1143 | + | |
| 1144 | + | |
| 1145 | + | |
| 1146 | + | |
| 1147 | + | |
| 1148 | + | |
| 1149 | + | |
| 1150 | + | |
| 1151 | + | |
| 1152 | + | |
| 1153 | + | |
| 1154 | + | |
| 1155 | + | |
| 1156 | + | |
| 1157 | + | |
| 1158 | + | |
| 1159 | + | |
| 1160 | + | |
| 1161 | + | |
| 1162 | + | |
| 1163 | + | |
| 1164 | + | |
| 1165 | + | |
| 1166 | + | |
| 1167 | + | |
| 1168 | + | |
| 1169 | + | |
| 1170 | + | |
| 1171 | + | |
| 1172 | + | |
| 1173 | + | |
| 1174 | + | |
| 1175 | + | |
| 1176 | + | |
| 1177 | + | |
| 1178 | + | |
| 1179 | + | |
| 1180 | + | |
| 1181 | + | |
1086 | 1182 | | |
1087 | 1183 | | |
1088 | 1184 | | |
| |||
0 commit comments