Skip to content

Commit

Permalink
update the functions maco_image_parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
sushmanthreddy committed Mar 25, 2024
1 parent 9cc6feb commit 264a1b6
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions torch_dreams/maco/features_visualizations/preconditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,66 +256,62 @@ def init_maco_buffer(image_shape, std=1.0):
Returns
-------
magnitude
Magnitude of the spectrum
phase
Phase of the spectrum
magnitude
Magnitude of the spectrum
phase
Phase of the spectrum
'''
spectrum_shape = (image_shape[0], image_shape[1]//2+1)
phase = np.random.normal(size=(3, *spectrum_shape), scale=std).astype(np.float32)
magnitude_path = get_file("spectrum_decorrelated.npy", IMAGENET_SPECTRUM_URL, cache_subdir="spectrums")
magnitude = np.load(magnitude_path)
magnitude_p = get_file("imagenet_decorrelated.npy", IMAGENET_SPECTRUM_URL)
magnitude = np.load(magnitude_p)
magnitude = np.moveaxis(magnitude, 0, -1)
magnitude = Resize(spectrum_shape)(torch.tensor(magnitude).permute(2, 0, 1)).numpy()
magnitude = np.moveaxis(magnitude, -1, 0)
magnitude_resized = torch.nn.functional.interpolate(torch.tensor(magnitude).permute(2, 0, 1).unsqueeze(0), size=spectrum_shape, mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0).numpy()

magnitude = np.moveaxis(magnitude_resized, -1, 0)

return torch.tensor(magnitude, dtype=torch.float32), torch.tensor(phase, dtype=torch.float32)





def maco_image_parametrization(magnitude, phase ,values_range):
def maco_image_parametrization(magnitude, phase, values_range):
"""
Generate the image from the magnitude and phase using MaCo method.
Parameters
----------
magnitude : torch.Tensor
Magnitude of the spectrum.
phase : torch.Tensor
Phase of the spectrum.
values_range : tuple
Range of values of the inputs that will be provided to the model, e.g (0, 1) or (-1, 1).
Returns
-------
image : torch.Tensor
img : torch.Tensor
Image in the 'pixels' basis.
"""
phase = phase - torch.mean(phase)
phase = phase / (torch.std(phase) + 1e-5)

buffer = torch.complex(torch.cos(phase) * magnitude, torch.sin(phase) * magnitude)
buffer = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase))

img = torch.fft.irfft2(buffer)
img = img.permute(1, 2, 0)


img = img - torch.mean(img)
img = img / (torch.std(img) + 1e-5)


img = recorrelate_colors(img)
img = recorrelate_colors(img) # Assuming you have a similar function in PyTorch
img = torch.sigmoid(img)


img = img * (values_range[1] - values_range[0]) + values_range[0]
img = img * (values_range[1] - values_range[0]) + values_range[0]

return img






0 comments on commit 264a1b6

Please sign in to comment.