diff --git a/torch_dreams/maco/features_visualizations/preconditioning.py b/torch_dreams/maco/features_visualizations/preconditioning.py index 26e4d72..2935b19 100644 --- a/torch_dreams/maco/features_visualizations/preconditioning.py +++ b/torch_dreams/maco/features_visualizations/preconditioning.py @@ -256,25 +256,26 @@ def init_maco_buffer(image_shape, std=1.0): Returns ------- - magnitude - Magnitude of the spectrum - phase - Phase of the spectrum + magnitude + Magnitude of the spectrum + phase + Phase of the spectrum ''' spectrum_shape = (image_shape[0], image_shape[1]//2+1) phase = np.random.normal(size=(3, *spectrum_shape), scale=std).astype(np.float32) - magnitude_path = get_file("spectrum_decorrelated.npy", IMAGENET_SPECTRUM_URL, cache_subdir="spectrums") - magnitude = np.load(magnitude_path) + magnitude_p = get_file("imagenet_decorrelated.npy", IMAGENET_SPECTRUM_URL) + magnitude = np.load(magnitude_p) magnitude = np.moveaxis(magnitude, 0, -1) - magnitude = Resize(spectrum_shape)(torch.tensor(magnitude).permute(2, 0, 1)).numpy() - magnitude = np.moveaxis(magnitude, -1, 0) + magnitude_resized = torch.nn.functional.interpolate(torch.tensor(magnitude).permute(2, 0, 1).unsqueeze(0), size=spectrum_shape, mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0).numpy() + + magnitude = np.moveaxis(magnitude_resized, -1, 0) + return torch.tensor(magnitude, dtype=torch.float32), torch.tensor(phase, dtype=torch.float32) - -def maco_image_parametrization(magnitude, phase ,values_range): +def maco_image_parametrization(magnitude, phase, values_range): """ Generate the image from the magnitude and phase using MaCo method. @@ -282,40 +283,35 @@ def maco_image_parametrization(magnitude, phase ,values_range): ---------- magnitude : torch.Tensor Magnitude of the spectrum. - phase : torch.Tensor Phase of the spectrum. - values_range : tuple Range of values of the inputs that will be provided to the model, e.g (0, 1) or (-1, 1). Returns ------- - image : torch.Tensor + img : torch.Tensor Image in the 'pixels' basis. """ phase = phase - torch.mean(phase) phase = phase / (torch.std(phase) + 1e-5) - buffer = torch.complex(torch.cos(phase) * magnitude, torch.sin(phase) * magnitude) + buffer = torch.complex(magnitude * torch.cos(phase), magnitude * torch.sin(phase)) + img = torch.fft.irfft2(buffer) img = img.permute(1, 2, 0) - img = img - torch.mean(img) img = img / (torch.std(img) + 1e-5) - - img = recorrelate_colors(img) + img = recorrelate_colors(img) # Assuming you have a similar function in PyTorch img = torch.sigmoid(img) - - img = img * (values_range[1] - values_range[0]) + values_range[0] + img = img * (values_range[1] - values_range[0]) + values_range[0] return img -