In [None]:
import numpy as np
import matplotlib.pyplot as plt

from utils import generate_combinations
from utils import load_lora_info

In [2]:
import os
from PIL import Image

In [3]:
def calculate_radial_profile(log_amplitude_spectrum):
    """Azimuthally average a centred 2-D array into integer-radius bins.

    Each pixel's Euclidean distance from the centre pixel
    ``(rows // 2, cols // 2)`` (where ``np.fft.fftshift`` puts the DC term) is
    truncated to an integer ``r``. Bin ``r`` therefore holds every pixel with
    ``r <= distance < r + 1``. The profile value for a bin is the mean of the
    array values in that bin.

    Args:
        log_amplitude_spectrum: 2-D array, typically a centred (log-)amplitude
            spectrum.

    Returns:
        (radii, radial_profile): ``radii`` is ``arange(max_radius + 1)`` and
        ``radial_profile`` is the float32 per-bin mean, same length.
    """
    spectrum = np.asarray(log_amplitude_spectrum)
    rows, cols = spectrum.shape
    crow, ccol = rows // 2, cols // 2

    # Integer (truncated) distance of every pixel from the centre pixel.
    yy, xx = np.indices((rows, cols))
    radius_bins = np.sqrt((xx - ccol) ** 2 + (yy - crow) ** 2).astype(int).ravel()

    # One pass over the pixels instead of one boolean scan per radius.
    # Distance changes by at most 1 between neighbouring pixels, so every bin in
    # 0..max is populated and `counts` never contains 0.
    sums = np.bincount(radius_bins, weights=spectrum.ravel())
    counts = np.bincount(radius_bins)

    radii = np.arange(sums.size)
    radial_profile = (sums / counts).astype(np.float32)

    return radii, radial_profile

def log_amplitude_spectrum(image, take_log=False, eps=1e-8):
    """Centred 2-D amplitude spectrum of ``image``, optionally log-scaled.

    NOTE: despite the function name, the default (``take_log=False``) returns
    the *linear* magnitude ``|FFT|``. Pass ``take_log=True`` to get
    ``log(|FFT| + eps)``, the form computed inline for the radial profiles in
    this notebook.

    Args:
        image: 2-D array (e.g. a grayscale image).
        take_log: If True, return the natural log of the magnitude.
        eps: Offset added before the log to avoid ``log(0)``. Only used when
            ``take_log`` is True.

    Returns:
        Array shaped like ``image`` with the zero-frequency component shifted
        to the centre via ``np.fft.fftshift``.
    """
    centred = np.fft.fftshift(np.fft.fft2(image))
    magnitude = np.abs(centred)
    if take_log:
        return np.log(magnitude + eps)
    return magnitude

In [4]:
# Load LoRA metadata for the "anime" set, then enumerate combinations of it.
# The second argument (1) is presumably the combination size — confirm against
# utils.generate_combinations.
lora_info = load_lora_info("anime")
combinations = generate_combinations(lora_info, 1)

In [None]:
# Per-combination high-frequency score.
# For each merged image: take the centred log-amplitude spectrum, average it
# over rings of equal radius, normalise by the maximum, and print the mean of
# the bins at radius >= `low`.
path1 = ""  # TODO: directory containing merge_<ids>_<inter1>.png files

inter1 = 1  # interpolation suffix in the file name
low = 72    # first radial bin included in the high-frequency mean

for combo in combinations:
    lora_ids = [lora['id'] for lora in combo]
    file_name1 = 'merge' + '_' + '_'.join(lora_ids) + '_' + str(inter1) + '.png'

    # Context manager releases the file handle. convert('RGB') keeps the
    # channel mean valid for grayscale/palette images and drops any alpha channel.
    with Image.open(os.path.join(path1, file_name1)) as img:
        image = np.asarray(img.convert('RGB')).mean(axis=2).astype(np.uint8)

    # Deliberately not named `log_amplitude_spectrum`: assigning to that name
    # would overwrite the helper function of the same name in the notebook
    # namespace.
    magnitude = np.abs(np.fft.fftshift(np.fft.fft2(image)))
    log_amplitude = np.log(magnitude + 1e-8)  # 1e-8 avoids log(0)

    _, radial_profile = calculate_radial_profile(log_amplitude)

    # Relative log amplitude: profile scaled so its maximum is 1.
    relative_log_amplitude = radial_profile / np.max(radial_profile)

    print(lora_ids, np.mean(relative_log_amplitude[low:]))

In [6]:
def high_pass_filter(image, cutoff):
    """Remove low spatial frequencies of a 2-D image with a square FFT mask.

    In the centred spectrum, the square
    ``[crow - cutoff, crow + cutoff) x [ccol - cutoff, ccol + cutoff)`` around
    the DC term is set to zero, and everything else passes unchanged.

    Args:
        image: 2-D array (e.g. a grayscale image).
        cutoff: Half-width, in frequency bins, of the zeroed square. Values at
            or beyond the image half-size zero the whole spectrum.

    Returns:
        (filtered, spectrum): ``filtered`` is ``np.abs`` of the inverse FFT of
        the masked spectrum (the magnitude, not the real part). ``spectrum`` is
        ``np.abs`` of the masked, centred spectrum.
    """
    image = np.asarray(image)

    # Forward FFT with DC shifted to the centre.
    fshift = np.fft.fftshift(np.fft.fft2(image))

    rows, cols = image.shape
    crow, ccol = rows // 2, cols // 2

    # Clamp the lower bounds at 0. A negative slice start would wrap around to
    # the end of the array and zero the wrong (and far too small) region.
    row_lo = max(crow - cutoff, 0)
    col_lo = max(ccol - cutoff, 0)
    mask = np.ones((rows, cols), np.uint8)
    mask[row_lo:crow + cutoff, col_lo:ccol + cutoff] = 0

    # Apply the mask, undo the shift and invert.
    fshift = fshift * mask
    img_back = np.fft.ifft2(np.fft.ifftshift(fshift))

    return np.abs(img_back), np.abs(fshift)

def compare_high_frequency_components(fshift1, fshift2):
    """Mean absolute element-wise difference between two arrays.

    Intended for comparing high-frequency spectra, e.g. the second value
    returned by ``high_pass_filter`` for two images.

    Args:
        fshift1, fshift2: Arrays of the same (or broadcastable) shape.

    Returns:
        Mean of ``|fshift2 - fshift1|`` over all elements.
    """
    delta = fshift2 - fshift1
    return np.absolute(delta).mean()

In [None]:
from matplotlib.colors import Normalize

# Visualise how the high-pass-filtered content of the same LoRA combination
# differs between two image sets (one under `path1`, one under `path2`).
inter1 = 1
cutoff = 72  # half-width of the low-frequency square removed by high_pass_filter

path1 = ""  # TODO: directory of the first image set
path2 = ""  # TODO: directory of the second image set

for combo in combinations[:1]:
    file_name1 = 'merge' + '_' + '_'.join([lora['id'] for lora in combo]) + '_' + str(inter1) + '.png'

    # Load both images as 8-bit grayscale. Context managers close the files,
    # and convert('RGB') keeps the channel mean valid for grayscale/RGBA inputs.
    with Image.open(os.path.join(path1, file_name1)) as img:
        image1 = np.asarray(img.convert('RGB')).mean(axis=2).astype(np.uint8)
    with Image.open(os.path.join(path2, file_name1)) as img:
        image2 = np.asarray(img.convert('RGB')).mean(axis=2).astype(np.uint8)

    high_freq_image1, fshift1 = high_pass_filter(image1, cutoff)
    high_freq_image2, fshift2 = high_pass_filter(image2, cutoff)

    diff = high_freq_image2 - high_freq_image1

    # 'seismic' is a diverging colormap: centre the scale on 0 so white means
    # "no difference" and red/blue show the sign.
    limit = np.max(np.abs(diff))
    if limit == 0:
        limit = 1.0  # identical images: avoid a degenerate vmin == vmax
    norm = Normalize(vmin=-limit, vmax=limit)

    fig, ax = plt.subplots()
    ax.imshow(diff, cmap='seismic', norm=norm)
    ax.axis('off')
    # NOTE: fixed filename 'a' (matplotlib appends the default savefig.format
    # extension), so each loop iteration would overwrite the previous figure.
    fig.savefig('a', bbox_inches='tight', dpi=200)
    plt.show()