In [1]:
import torch
# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [4]:
import torch
import numpy as np
def view_Jaco_ddpm(images_list, heatmaps_list, threshold = 0.9, view_image = True, center = (14, 14)):
    """Find, per heatmap, the smallest centred square window holding most of the gradient mass.

    Optionally plots every (heatmap, image) pair. Then, for each heatmap, it takes the
    absolute gradient of ``heatmap[0]`` and normalises it by its maximum. Values at or
    below 1% of that maximum are zeroed. It returns the side of the smallest odd-sized
    square window centred on ``center`` whose share of the remaining mass exceeds
    ``threshold``.

    Parameters
    ----------
    images_list : sequence of tensors
        Images shown next to the heatmaps. Only read when ``view_image`` is True.
    heatmaps_list : sequence of tensors
        Gradient heatmaps, indexed as ``heatmap[0, row, col]``. This assumes a leading
        channel axis (e.g. shape (1, H, W)); TODO confirm against ``sample_for_Jaco``.
    threshold : float
        Fraction of the total mass that the window must exceed.
    view_image : bool
        If True, plot each heatmap beside its image.
    center : tuple[int, int]
        (row, col) the window is centred on. The default matches the ``i=14, j=14``
        used when heatmaps are generated in this notebook.

    Returns
    -------
    list[int]
        Exactly one entry per heatmap, so the list stays aligned with the timesteps:
        - the smallest qualifying odd window side;
        - the full image width if no centred window qualifies;
        - 0 for an all-zero heatmap.
    """
    if view_image:
        # Assumes one heatmap per sampling step, listed from the first step to the last.
        num_steps = len(heatmaps_list)
        for idx, (image, heatmap) in enumerate(zip(images_list, heatmaps_list)):
            fig, axes = plt.subplots(1, 2, figsize = (5,10))
            axes[0].imshow(heatmap.cpu().detach().squeeze().numpy(), cmap='viridis')
            axes[1].imshow(image.cpu().detach().squeeze().numpy(), cmap = 'gray', interpolation = 'none')
            axes[0].axis('off')
            axes[1].axis('off')
            # suptitle labels the whole figure; plt.title would only title the right-hand axes.
            fig.suptitle(f'{num_steps - 1 - idx} time step Jacobian')
            plt.show()

    ts = 0.01  # keep only gradients above 1% of the maximum magnitude
    center_row, center_col = center

    patch_size_list = []
    for heatmap in heatmaps_list:
        plane = heatmap[0]
        height, width = plane.shape[-2], plane.shape[-1]

        peak = plane.abs().max()
        if peak == 0:
            # No gradient at all. Record 0 instead of dividing by zero and dropping the entry.
            patch_size_list.append(0)
            continue

        magnitude = plane.abs() / peak
        # Mask on magnitude so strongly negative gradients are not discarded.
        magnitude = magnitude * (magnitude > ts)
        total = magnitude.sum()  # > 0: the peak pixel itself survives the mask

        # Largest half-width for which the centred window stays inside the image.
        max_half = min(center_row, height - 1 - center_row, center_col, width - 1 - center_col)
        patch_size = width  # fallback: no centred window reaches the threshold
        for half in range(1, max_half + 1):
            window = magnitude[center_row - half:center_row + half + 1,
                               center_col - half:center_col + half + 1]
            if (window.sum() / total).item() > threshold:
                patch_size = window.shape[-1]
                break
        patch_size_list.append(patch_size)

    return patch_size_list

def view_Jaco_ddim(images_list, heatmaps_list, threshold = 0.9, view_image=True, center = (14, 14)):
    """Find, per heatmap, the smallest centred square window holding most of the gradient mass.

    Intended for DDIM-sampled heatmaps. Optionally plots every (heatmap, image) pair.
    Then, for each heatmap, it takes the absolute gradient of ``heatmap[0]`` and
    normalises it by its maximum. Values at or below 1% of that maximum are zeroed.
    It returns the side of the smallest odd-sized square window centred on ``center``
    whose share of the remaining mass exceeds ``threshold``.

    Parameters
    ----------
    images_list : sequence of tensors
        Images shown next to the heatmaps. Only read when ``view_image`` is True.
    heatmaps_list : sequence of tensors
        Gradient heatmaps, indexed as ``heatmap[0, row, col]``. This assumes a leading
        channel axis (e.g. shape (1, H, W)); TODO confirm against ``sample_for_Jaco``.
    threshold : float
        Fraction of the total mass that the window must exceed.
    view_image : bool
        If True, plot each heatmap beside its image.
    center : tuple[int, int]
        (row, col) the window is centred on. The default matches the ``i=14, j=14``
        used when heatmaps are generated in this notebook.

    Returns
    -------
    list[int]
        Exactly one entry per heatmap, so the list stays aligned with the timesteps:
        - the smallest qualifying odd window side;
        - the full image width if no centred window qualifies;
        - 0 for an all-zero heatmap.
    """
    if view_image:
        # Label steps from the actual count rather than a hard-coded 50. This assumes
        # one heatmap per sampling step, listed from the first step to the last.
        num_steps = len(heatmaps_list)
        for idx, (image, heatmap) in enumerate(zip(images_list, heatmaps_list)):
            fig, axes = plt.subplots(1, 2, figsize = (5,10))
            axes[0].imshow(heatmap.cpu().detach().squeeze().numpy(), cmap='viridis')
            axes[1].imshow(image.cpu().detach().squeeze().numpy(), cmap = 'gray', interpolation = 'none')
            axes[0].axis('off')
            axes[1].axis('off')
            # suptitle labels the whole figure; plt.title would only title the right-hand axes.
            fig.suptitle(f'{num_steps - 1 - idx} time step Jacobian')
            plt.show()

    ts = 0.01  # keep only gradients above 1% of the maximum magnitude
    center_row, center_col = center

    patch_size_list = []
    for heatmap in heatmaps_list:
        plane = heatmap[0]
        height, width = plane.shape[-2], plane.shape[-1]

        # Normalise by the largest magnitude. A signed max can be <= 0, which would
        # flip signs or mask out every pixel.
        peak = plane.abs().max()
        if peak == 0:
            # No gradient at all. Record 0 instead of dividing by zero and dropping the entry.
            patch_size_list.append(0)
            continue

        magnitude = plane.abs() / peak
        magnitude = magnitude * (magnitude > ts)
        total = magnitude.sum()  # > 0: the peak pixel itself survives the mask

        # Largest half-width for which the centred window stays inside the image.
        max_half = min(center_row, height - 1 - center_row, center_col, width - 1 - center_col)
        patch_size = width  # fallback: no centred window reaches the threshold
        for half in range(1, max_half + 1):
            window = magnitude[center_row - half:center_row + half + 1,
                               center_col - half:center_col + half + 1]
            if (window.sum() / total).item() > threshold:
                patch_size = window.shape[-1]
                break
        patch_size_list.append(patch_size)

    return patch_size_list

def view_Jaco_ddim1(images_list, heatmaps_list, view_image=True):
    """Summarise each heatmap by the spatial spread of its gradient magnitude.

    Optionally plots every (heatmap, image) pair. Then, for each heatmap, it uses
    |heatmap| as pixel weights and computes the weighted centre of mass. It reports
    twice the mean of the row and column weighted standard deviations.

    Parameters
    ----------
    images_list : sequence of tensors
        Images shown next to the heatmaps. Only read when ``view_image`` is True.
    heatmaps_list : sequence of tensors
        Gradient heatmaps whose last two axes are (row, col). Presumably shaped
        (1, H, W); TODO confirm.
    view_image : bool
        If True, plot each heatmap beside its image.

    Returns
    -------
    list[float]
        One value per heatmap, in input order. A heatmap with no non-zero weight gives
        0.0, so the list stays aligned with the timesteps.
    """
    if view_image:
        # Assumes one heatmap per sampling step, listed from the first step to the last.
        num_steps = len(heatmaps_list)
        for idx, (image, heatmap) in enumerate(zip(images_list, heatmaps_list)):
            fig, axes = plt.subplots(1, 2, figsize = (5,10))
            axes[0].imshow(heatmap.cpu().detach().squeeze().numpy(), cmap='viridis')
            axes[1].imshow(image.cpu().detach().squeeze().numpy(), cmap = 'gray', interpolation = 'none')
            axes[0].axis('off')
            axes[1].axis('off')
            # suptitle labels the whole figure; plt.title would only title the right-hand axes.
            fig.suptitle(f'{num_steps - 1 - idx} time step Jacobian')
            plt.show()

    spread_list = []
    for heatmap in heatmaps_list:
        # Use magnitudes as weights. Signed weights can make the weighted variance
        # negative (sqrt -> NaN) or the total weight non-positive. Normalising by the
        # max is unnecessary because the weighted moments are scale-invariant.
        weights = np.abs(heatmap.cpu().detach().numpy())
        total_weight = weights.sum()
        if not total_weight > 0:
            # Keep one entry per heatmap so the output lines up with the timesteps.
            spread_list.append(0.0)
            continue

        h, w = weights.shape[-2:]
        y_indices, x_indices = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')

        # Weighted centre of mass
        center_y = np.sum(y_indices * weights) / total_weight
        center_x = np.sum(x_indices * weights) / total_weight

        # Weighted standard deviation along each axis (a measure of spread)
        std_y = np.sqrt(np.sum(((y_indices - center_y) ** 2) * weights) / total_weight)
        std_x = np.sqrt(np.sum(((x_indices - center_x) ** 2) * weights) / total_weight)

        # "Effective radius" in pixels. For a roughly Gaussian profile, 2 sigma covers
        # most of the mass.
        effective_radius = np.mean([std_y, std_x]) * 2
        spread_list.append(effective_radius.item())

    return spread_list


def view_image(images, nrows=25, ncols=40):
    """Show a collection of image tensors on an ``nrows`` x ``ncols`` grid of grayscale panels.

    ``images`` is turned into a list and concatenated along dim 0, so its elements
    must share all trailing dimensions. Each grid cell shows one concatenated entry,
    squeezed to 2-D. Cells beyond the number of images are left blank instead of
    raising IndexError. Images beyond ``nrows * ncols`` are not shown.
    """
    all_images = torch.cat(list(images), dim=0)
    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 14), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < len(all_images):
            # detach() so tensors that still require grad can be converted to numpy.
            ax.imshow(all_images[i].detach().cpu().squeeze().numpy(), cmap='gray')  # squeeze to 2-D before display
        ax.axis('off')  # hide axes (also on empty cells)

    plt.show()

def view_one_image(images):
    """Display a single image tensor in grayscale, nearest-neighbour interpolated, without axes."""
    # Drop singleton dimensions so imshow receives a plain 2-D array.
    pixels = images.detach().cpu().squeeze().numpy()
    plt.imshow(pixels, cmap='gray', interpolation="nearest")
    plt.axis('off')
    plt.show()

def analyze_receptive_field_window_size(heatmap_list, threshold):
    """
    Estimate a p×p receptive-field window size for each heatmap.

    For every heatmap, the function:
    - takes the absolute gradient of the first channel (``heatmap[0]``);
    - normalises it by its maximum;
    - finds the bounding box of pixels whose normalised value exceeds ``threshold``;
    - rounds the larger side of that box up to an odd size of at least 3.

    Parameters:
    - heatmap_list: sequence of gradient heatmaps (e.g. from ``sample_for_Jaco``), each
      indexed as ``heatmap[0]`` to obtain a 2-D map. This assumes a leading channel axis.
    - threshold: fraction of the maximum absolute gradient that a pixel must exceed
      to count as part of the receptive field.

    Returns:
    - list[int] with one window size p per heatmap, in input order. A heatmap with no
      pixel above the threshold (e.g. all zeros, or threshold >= 1) yields 0.
    """
    patch_size_list = []

    for heatmap in heatmap_list:
        # Absolute gradient of the first channel, as a numpy array
        heatmap_np = heatmap[0].cpu().detach().abs().numpy()

        peak = heatmap_np.max()
        if not peak > 0:
            # All-zero (or NaN) heatmap: there is nothing to measure.
            patch_size_list.append(0)
            continue

        # Pixels whose normalised gradient exceeds the threshold
        significant_mask = (heatmap_np / peak) > threshold
        y_indices, x_indices = np.nonzero(significant_mask)
        if y_indices.size == 0:
            # Possible when threshold >= 1; .min() on an empty array would raise.
            patch_size_list.append(0)
            continue

        # Bounding box of the significant region
        height = y_indices.max() - y_indices.min() + 1
        width = x_indices.max() - x_indices.min() + 1

        # Take the larger side so the window covers the whole region
        p = max(height, width)

        # Round up to an odd size (3×3, 5×5, 7×7, ...), minimum 3
        p = max(3, (p // 2) * 2 + 1)

        patch_size_list.append(int(p))

    return patch_size_list

def normalize_heatmaps(heatmaps):
    """Scale each heatmap so its largest absolute value becomes 1.

    The list is updated in place and also returned. All-zero heatmaps are left
    unchanged rather than becoming NaN through a 0/0 division, which would poison
    any running sum they are added to. Dividing by the largest magnitude, rather
    than the signed max, preserves the sign of every entry.
    """
    for idx, heatmap in enumerate(heatmaps):
        peak = heatmap.abs().max()
        if peak > 0:
            heatmaps[idx] = heatmap / peak

    return heatmaps

def _plot_patch_sizes(patches, thres_hold):
    """Plot one receptive-field size per step against a normalised time axis in [0, 1]."""
    ts = np.linspace(0.0, 1.0, len(patches))

    # Line plot, with entries reversed relative to the order in which the sampler returned them.
    plt.figure(figsize=(10, 6))
    plt.plot(ts, list(reversed(patches)))
    plt.xlabel('Time (Forward Process)')
    plt.ylabel('Receptive Field')
    plt.title(f'threshold = {thres_hold}')

    plt.ylim(0, 28)

    plt.grid(True)
    plt.show()


def calculate_average_heatmap(trainer, sampler_method = 'ddpm' , epoch = 100, thres_hold = 0.9,
                              plot_every = 10, pixel = (14, 14)):
    """Average normalised Jacobian heatmaps over ``epoch`` independent sampling runs.

    Each run starts ``trainer.ema.ema_model.sample_for_Jaco`` from fresh Gaussian noise
    of shape (1, 1, 28, 28). The returned heatmaps are normalised with
    ``normalize_heatmaps`` and accumulated in float64. Every ``plot_every`` runs, the
    running average is converted to per-step patch sizes and plotted.

    Parameters
    ----------
    trainer : project Trainer
        Must expose ``.device`` and ``.ema.ema_model.sample_for_Jaco``.
    sampler_method : {'ddpm', 'ddim'}
        Selects which patch-size function is used for the intermediate plots.
    epoch : int
        Number of sampling runs to average (must be >= 1).
    thres_hold : float
        Mass fraction passed as ``threshold`` to the patch-size function.
    plot_every : int
        Plot after every this many runs. Pass 0 to disable plotting.
    pixel : tuple[int, int]
        Passed to ``sample_for_Jaco`` as ``i``/``j`` (presumably the probed output
        pixel — confirm).

    Returns
    -------
    torch.Tensor
        The float64 average over all ``epoch`` runs, stacked along dim 0 (one entry
        per returned heatmap).

    Raises
    ------
    ValueError
        If ``sampler_method`` is unknown or ``epoch`` < 1.
    """
    patch_fns = {'ddpm': view_Jaco_ddpm, 'ddim': view_Jaco_ddim}
    if sampler_method not in patch_fns:
        raise ValueError(f"sampler_method must be 'ddpm' or 'ddim', got {sampler_method!r}")
    if epoch < 1:
        raise ValueError(f"epoch must be >= 1, got {epoch}")
    patch_fn = patch_fns[sampler_method]
    row, col = pixel

    device = trainer.device
    total_heatmaps = None

    for run in range(epoch):
        start_point = torch.randn((1,1,28,28), device = device)
        images, heatmaps, noise_list = trainer.ema.ema_model.sample_for_Jaco(img = start_point, i = row, j = col)
        stacked = torch.stack(normalize_heatmaps(heatmaps)).to(device=device, dtype=torch.float64)
        # Allocate lazily, so the step count and heatmap shape come from the sampler
        # instead of a hard-coded (1000, 1, 28, 28) buffer.
        total_heatmaps = stacked if total_heatmaps is None else total_heatmaps + stacked

        if plot_every and run % plot_every == plot_every - 1:
            running_average = total_heatmaps / (run + 1)
            patches = patch_fn(images, list(running_average), threshold=thres_hold, view_image=False)
            _plot_patch_sizes(patches, thres_hold)

    # Average over all runs, including any that come after the last plotting checkpoint.
    return total_heatmaps / epoch

In [7]:
# A single 1-channel 28x28 Gaussian noise image on the selected device.
start = torch.randn(1, 1, 28, 28, device=device)

In [8]:
from compared_ddpm_v3 import Unet, GaussianDiffusion, Trainer
import torchvision
from torchvision import transforms
from matplotlib import pyplot as plt

# Run configuration. Within this cell, `dataset_size`, `num_res` and `resolutions` are only
# used to build the checkpoint folder name; they are not passed to the model.
dataset_size = 60000
num_res = 1
resolutions = 2  # NOTE(review): presumably meant to match len(dim_mults) below — confirm
init_dim = 32


unet = Unet(dim = init_dim, dim_mults=(1,2), channels = 1)
folder = f'init{init_dim}resolutions{resolutions}size{dataset_size}res{num_res}'

transform = transforms.Compose([
            transforms.ToTensor(),
        ])
# NOTE(review): absolute, machine-specific dataset path; with download=False the data must already exist there.
ds = torchvision.datasets.MNIST(root="/home/dataset/mnist", train=True, transform=transform, download=False)


# The full MNIST training set is passed to the Trainer (no Subset is applied in this cell).
model = GaussianDiffusion(unet, image_size=28, sampling_timesteps=20)
trainer = Trainer(model, folder, ds = ds, train_num_steps = 50000, save_and_sample_every = 5000, train_batch_size=128)
# Load saved milestone 30 from `folder` (checkpoint file naming is defined by Trainer, not visible here).
trainer.load(30)

def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Usage example: print the number of trainable parameters in the UNet.
total_params = count_parameters(unet)
print(f'모델의 총 파라미터 수: {total_params:,}')
#trainer.load(30)


# Run sample_for_Jaco once from fresh Gaussian noise; it returns per-step images, heatmaps and noise.
# NOTE(review): i=14, j=14 presumably selects the probed output pixel — confirm in sample_for_Jaco.
start = torch.randn((1,1,28,28), device=device)
images, heatmaps, noise_list = trainer.ema.ema_model.sample_for_Jaco(img = start, i = 14, j = 14)


# Show the last returned image and print its value range.
view_one_image(images[-1])
print(images[-1].max(), images[-1].min())

# Disabled: the call below sits inside a string literal, so it is not executed.
'''
average_heatmaps = calculate_average_heatmap(trainer, sampler_method = 'ddpm')
'''



ModuleNotFoundError: No module named 'ELS_2'

In [113]:
# `sys` used to be imported only on a commented-out line, so this cell raised NameError on a fresh kernel.
import sys


def purge_stale_modules(module_names):
    """Remove already-imported modules from ``sys.modules`` so the next import re-executes them.

    Prints a notice for each module that was actually removed. Names that were never
    imported are ignored.
    """
    for name in module_names:
        if name in sys.modules:
            print(f'기존의 import된 {name} Module이 제거되었습니다.')
            del sys.modules[name]


# NOTE(review): other cells import package-qualified modules (e.g. `ELS_2.out_dated.compared_ddpm_resnet`,
# `compared_ddpm_v3`); those sys.modules keys do not match the bare names below — confirm what should be purged.
purge_stale_modules(('ELS_machine', 'compared_ddpm', 'compared_ddpm_resnet'))


기존의 import된 compared_ddpm Module이 제거되었습니다.


In [None]:
# Load previously saved averaged heatmaps for the current `folder` and plot receptive-field size per step.
file_name = f"{folder}/average_heatmaps.pth"
average_heatmaps = torch.load(file_name, weights_only=True)

average_heatmaps_list = list(average_heatmaps)

thres_hold = 0.9
# With view_image=False, view_Jaco_ddim never reads images_list, so placeholders suffice.
# This avoids allocating 1000 random tensors and overwriting the sampled `images` from an earlier cell.
placeholder_images = [None] * len(average_heatmaps_list)
patches = view_Jaco_ddim(placeholder_images, average_heatmaps_list, threshold=thres_hold, view_image = False)
#patches = analyze_receptive_field_window_size(average_heatmaps_list, threshold = thres_hold)

ts = np.linspace(0.0, 1.0, len(patches))

# Line plot of window size against normalised time, with steps reversed relative to the heatmap order.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(ts, list(reversed(patches)))
ax.set_xlabel('Time (Forward Process)')
ax.set_ylabel('Receptive Field')
ax.set_title(f'UNet / threshold = {thres_hold}')
ax.set_ylim(0, 28)
ax.grid(True)
plt.show()

In [None]:

import torchvision
from torchvision import transforms
from torch.utils.data import Subset
from ELS_2.out_dated.compared_ddpm_version1 import Unet, GaussianDiffusion, Trainer

transform = transforms.Compose([
            transforms.ToTensor(),
        ])
ds = torchvision.datasets.MNIST(root="/home/dataset/mnist", train=True, transform=transform, download=False)
num_samples = 1000
subset_indices = list(range(num_samples))

mnist_train_dataset = Subset(ds, subset_indices)


def load_and_show_jacobian(folder, ds):
    """Load milestone 5 of the UNet diffusion model from ``folder`` and plot one Jacobian trace.

    Builds the model (dim 64, dim_mults (1, 2, 4), 1000 sampling timesteps), loads the
    checkpoint and runs ``sample_for_Jaco`` once from fresh Gaussian noise at pixel
    (14, 14). Every step is plotted with ``view_Jaco_ddpm``.

    Returns ``(trainer, images, heatmaps)``.
    """
    unet = Unet(dim = 64, dim_mults=(1,2,4), channels = 1)
    model = GaussianDiffusion(unet, image_size=28, sampling_timesteps=1000)
    trainer = Trainer(model, folder, ds = ds, train_num_steps = 100000, save_and_sample_every = 10000)
    trainer.load(5)

    trainer.ema.ema_model.eval()
    start_point = torch.randn((1,1,28,28), device=device)
    images, heatmaps, noise_list = trainer.ema.ema_model.sample_for_Jaco(img = start_point, i = 14, j = 14)
    view_Jaco_ddpm(images, heatmaps)
    return trainer, images, heatmaps


# Same architecture and settings with two checkpoint folders, to compare their Jacobians.
# NOTE(review): 'results_20000_256' suggests a 20000-sample run, but the 1000-sample subset is passed;
# confirm this only matters for training, not for loading.
trainer, images, heatmaps = load_and_show_jacobian('results_1000_256', mnist_train_dataset)
trainer, images1, heatmaps1 = load_and_show_jacobian('results_20000_256', mnist_train_dataset)

# Disabled experiment (string literal, not executed). NOTE: it passes a list as `threshold`
# to analyze_receptive_field_window_size, which compares against a scalar; fix before re-enabling.
'''
heatmaps_list = []

for i in range(1000):
    heatmaps_list.append(torch.zeros((1,28,28), device = device, dtype = torch.long))

for i in range(10):
    images, heatmaps, noise_list = trainer.ema.ema_model.sample_for_Jaco(img = start_point, i = 14, j = 14)

    for idx, heatmap in enumerate(heatmaps):
        heatmaps_list[idx] = heatmaps_list[idx] + heatmap


patch_size_list = analyze_receptive_field_window_size(heatmaps_list, list(np.arange(1000)))

'''


In [None]:
# Value range of the final image from each checkpoint, then the full first image.
for final_image in (images[-1], images1[-1]):
    print(torch.max(final_image), torch.min(final_image))
print(images[-1])

In [None]:
from ELS_2.out_dated.compared_ddpm_resnet import *

# Train a ResNet-backed diffusion model on the full MNIST training set.
# NOTE(review): ResNet, GaussianDiffusion and Trainer presumably come from the star import above;
# `transforms` and `torchvision` rely on imports made in earlier cells.
unet = ResNet()
model = GaussianDiffusion(unet, image_size=28, sampling_timesteps=50)
folder = 'results_resnet'
transform = transforms.Compose([
            transforms.ToTensor(),
        ])
ds = torchvision.datasets.MNIST(root="/home/dataset/mnist", train=True, transform=transform, download=False)
#num_samples = 1000
#subset_indices = list(range(num_samples))  # always use samples 0-999

# Restricting the training data with a Subset is disabled; the full dataset is used.
#mnist_train_dataset = Subset(ds, subset_indices)
trainer = Trainer(model, folder, ds = ds, train_num_steps = 150000, save_and_sample_every = 10000, train_batch_size=128)
# Start training with the settings above (150000 steps configured; presumably long-running).
trainer.train()
