Skip to content

Commit

Permalink
fix: default_collate changes in PyTorch 2.0
Browse files Browse the repository at this point in the history
As shown in PyTorch issue 99227,
default_collate behaves differently in PyTorch v2.0.
Thus, this commit manually reimplements the desired behavior.
  • Loading branch information
YodaEmbedding authored and fracape committed Apr 16, 2023
1 parent 50ddf91 commit b10cc7c
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions compressai/latent_codecs/rasterscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from torch.utils.data.dataloader import default_collate

from compressai.ans import BufferedRansEncoder, RansDecoder
from compressai.entropy_models import GaussianConditional
Expand All @@ -47,6 +46,9 @@
"RasterScanLatentCodec",
]

K = TypeVar("K")
V = TypeVar("V")


@register_module("RasterScanLatentCodec")
class RasterScanLatentCodec(LatentCodec):
Expand Down Expand Up @@ -309,3 +311,26 @@ def _pad_2d(x: Tensor, padding: int) -> Tensor:
def _reduce_seq(xs):
assert all(x == xs[0] for x in xs)
return xs[0]


def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
if not isinstance(batch, list) or any(not isinstance(d, dict) for d in batch):
raise NotImplementedError

result = _ld_to_dl(batch)

for k, vs in result.items():
if all(isinstance(v, Tensor) for v in vs):
result[k] = torch.stack(vs)

return result


def _ld_to_dl(ld: List[Dict[K, V]]) -> Dict[K, List[V]]:
dl = {}
for d in ld:
for k, v in d.items():
if k not in dl:
dl[k] = []
dl[k].append(v)
return dl

0 comments on commit b10cc7c

Please sign in to comment.