Skip to content

Commit

Permalink
make sure we're checking image param reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayukhdeb committed Feb 25, 2024
1 parent c60d462 commit cc44368
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions torch_dreams/maco/image_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,65 @@ def __init__(
assert height == magnitude_spectrum.height
assert width == magnitude_spectrum.width

param = init_image_param(
image_param = init_image_param(
height=height,
width=width,
sd=standard_deviation,
device=device
)

"""
torch.angle(image_param) would contain only 0 or pi. nothing else. this is because image_param contains only real values
"""
# 2. Decouple the tensor into phase and amplitude
amplitude, phase = torch.abs(param), torch.angle(param)

amplitude, phase = torch.abs(image_param), torch.angle(image_param)
# 3. Set the phase spectrum to be trainable
self.param = phase.requires_grad_()

# 4. Hardcode the amplitude values to be magnitude_spectrum.data
self.amplitude_spectrum = magnitude_spectrum.data
# self.amplitude_spectrum = magnitude_spectrum.data.requires_grad_(False)
self.amplitude_spectrum = amplitude

self.height = height
self.width = width
self.device = device
self.batch_size = batch_size
# Assuming an optimizer is set up later with self.param as the parameter to optimize

def get_image_parameter(self):

"""
Compute image param from the self.param and self.amplitude_spectrum.
Here we recombine the amplitude and phase into a complex tensor,
then perform an inverse FFT to get the spatial domain representation.
we should check whether we can recover image_param given phase and amplitude
"""
# Convert amplitude and phase back to a complex tensor
complex_spectrum = self.amplitude_spectrum * torch.exp(1j * self.param)
self.check_whether_we_get_image_param(
amplitude=amplitude,
phase=phase,
image_param=image_param
)

def reconstruct_image_param(self, amplitude, phase):
reconstructed_image_param = amplitude * torch.exp(1j * phase)
return reconstructed_image_param

def check_whether_we_get_image_param(self, amplitude, phase, image_param):
reconstructed_image_param = self.reconstruct_image_param(
amplitude=amplitude,
phase=phase
)

assert torch.allclose(
image_param,
reconstructed_image_param.real
), f"Could not reconstruct image param. Very sad."

# Inverse FFT to go from frequency to spatial domain
img_spatial = torch.fft.ifft2(complex_spectrum).real # Taking the real part if necessary
def get_image_param_from_phase_and_amplitude_spectrum(self):
pass

return img_spatial


def postprocess(self, device):
img = fft_to_rgb(
height=self.height,
width=self.width,
image_parameter=self.get_image_parameter(),
image_parameter=self.get_image_param_from_phase_and_amplitude_spectrum(),
device=device,
)
img = lucid_colorspace_to_rgb(t=img, device=device)
Expand Down

0 comments on commit cc44368

Please sign in to comment.