Skip to content

Commit

Permalink
a few improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Vermeille committed Aug 14, 2023
1 parent 7686a96 commit b08ccd5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torchelie/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .conv import *
from .debug import Debug, Dummy
from .noise import Noise
from .vq import VQ, MultiVQ
from .vq import VQ, MultiVQ, MultiVQ2
from .imagenetinputnorm import ImageNetInputNorm
from .withsavedactivations import WithSavedActivations
from .maskedconv import MaskedConv2d, TopLeftConv2d
Expand Down
19 changes: 16 additions & 3 deletions torchelie/nn/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self,
return_indices: bool = True,
max_age: int = 1000):
super(VQ, self).__init__()
self.latent_dim = latent_dim
self.embedding = nn.Embedding(num_tokens, latent_dim)
nn.init.normal_(self.embedding.weight, 0, 1.1)
self.dim = dim
Expand Down Expand Up @@ -89,13 +90,13 @@ def forward(
nb_codes = self.embedding.weight.shape[0]

codebook = self.embedding.weight
if (self.init_mode == 'first' and self.initialized.item() == 0 and
self.training):
if (self.init_mode == 'first' and self.initialized.item() == 0
and self.training):
n_proto = self.embedding.weight.shape[0]

ch_first = x.transpose(dim, -1).contiguous().view(-1, x.shape[dim])
n_samples = ch_first.shape[0]
idx = torch.randint(0, n_samples, (n_proto,))[:nb_codes]
idx = torch.randint(0, n_samples, (n_proto, ))[:nb_codes]
self.embedding.weight.data.copy_(ch_first[idx])
self.initialized[:] = 1

Expand Down Expand Up @@ -171,3 +172,15 @@ def forward(
return q, torch.cat([q[1] for q in quantized], dim=self.dim)
else:
return torch.cat(quantized, dim=self.dim)


class MultiVQ2(VQ):

def forward(self, x: torch.Tensor) -> torch.Tensor:
d = self.latent_dim
dims = x.shape
batched_dims = list(dims)
batched_dims[self.dim] = d
batched_dims[self.dim - 1] = -1
out = super(MultiVQ2, self).forward(x.view(*batched_dims))
return out.view(*dims).contiguous()
24 changes: 21 additions & 3 deletions torchelie/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def load_recursive_state_dict(x: Any, obj: Any) -> None:
load_recursive_state_dict(xx[k], oo[k])


def load_state_dict_forgiving(dst, state_dict: dict, silent: bool = False):
def load_state_dict_forgiving(dst,
state_dict: dict,
silent: bool = False,
fit_dst_size: bool = False):
"""
Loads a state dict, but don't crash if shapes don't match.
"""
Expand All @@ -384,7 +387,22 @@ def load_state_dict_forgiving(dst, state_dict: dict, silent: bool = False):
failed = set()
for name, val in state_dict.items():
try:
dst_dict[name].copy_(val)
if not fit_dst_size:
dst_dict[name].copy_(val)
else:
if dst_dict[name].ndim != val.ndim:
failed.add(name)
if not silent:
print('error in', name, ", can't load shape",
val.shape, "into shape", dst_dict[name].shape)
slices = [
slice(0, min(a, b))
for a, b in zip(val.shape, dst_dict[name].shape)
]
dst_dict[name][slices].copy_(val[slices])
if dst_dict[name].shape != val.shape and not silent:
print('shrunk', name, ': checkpoint has ', val.shape,
'-> model has', dst_dict[name].shape)
except Exception as e:
failed.add(name)
if silent:
Expand Down Expand Up @@ -719,4 +737,4 @@ def __setstate__(self, state):
# '__getstate__': __getstate__,
# '__setstate__': __setstate__
}
return type(cls.__name__, (cls,), d)
return type(cls.__name__, (cls, ), d)

0 comments on commit b08ccd5

Please sign in to comment.