# Compare models

In [12]:
import os
import random
import torch
import numpy as np

from src.data.data_converter import tokens_to_weights, weights_to_flattened_weights, flattened_weights_to_weights
from src.data.inr_dataset import INRDataset
from src.data.utils import get_files_from_selectors
from src.data.inr import INR

from src.core.config import TransformerExperimentConfig, DataConfig, DataSelector, DatasetType
from src.core.config_diffusion import DiffusionExperimentConfig

from src.models.diffusion.pl_diffusion import HyperDiffusion
from src.models.autoencoder.pl_transformer import Autoencoder

from src.evaluation import model_utils, visualization_utils, metrics

# Report whether a CUDA device is visible to this kernel.
print(torch.cuda.is_available())


# Reload edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [32]:
# Checkpoint locations (relative paths) of the three models loaded below:
# the autoencoder, the standard diffusion model and the latent ("stable") diffusion model.
vae_path = "logs/best_overfit_so_far_099_split.ckpt"
standard_hyperdiffusion="diffusion_logs/lightning_checkpoints/2025-01-26 13-20-53.879333-standard hyperdiffusion, all 2 digits, big ass parameters, pls work 20250126_132035-962oirc5/last.ckpt"
stable_hyperdiffusion_path = "logs/last (1).ckpt"

In [14]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [15]:
# Build the INR that later cells pass to model_utils.compute_image to turn weights into images,
# and switch it to inference mode.
mlp = INR(up_scale=16)
mlp.eval()

In [16]:
# Select the INR files for MNIST class label 2 and wrap them in two dataset views,
# one created with is_flattened=True and one with is_flattened=False.
# NOTE(review): the "cpu" argument is presumably the device the samples are kept on — confirm.
files = get_files_from_selectors("mnist-inrs", [DataSelector(dataset_type=DatasetType.MNIST, class_label=2)])
dataset_flattened = INRDataset(files, "cpu", is_flattened=True)
dataset_tokenized = INRDataset(files, "cpu", is_flattened=False)
# State dict of sample 0; later cells pass it as the reference checkpoint to tokens_to_weights
# and to the VAE sampling/interpolation helpers.
ref_cp = dataset_flattened.get_state_dict(0)

### 1. Create Hyperdiffusion model

In [17]:
# Start from the `sanity` preset and override the transformer size.
# NOTE(review): these sizes presumably have to match the checkpoint loaded below — confirm.
config: DiffusionExperimentConfig = DiffusionExperimentConfig.sanity()

config.transformer_config.n_embd = 1024
config.transformer_config.n_head = 8
config.transformer_config.n_layer = 8
#config.data = DataConfig.small()
#config.data.selector = 
#data_path = os.path.join(os.getcwd(), config.data.data_path)

# Shape of one flattened sample with a leading dimension of 1 added.
data_shape = dataset_flattened[0].unsqueeze(0).shape
# Initialize model
hyperdiffusion = HyperDiffusion(
    config, data_shape
)
hyperdiffusion.eval()

# Alternative checkpoints that were used previously:
#checkpoint = "good_checkpoints/ferdy/hyperdiffusion/standard_hyperdiffusion.ckpt"
#checkpoint = "good_checkpoints/ferdy/hyperdiffusion/hyperdiffusion_20250126_132035.ckpt"
# Load on the CPU and keep only the "state_dict" entry of the checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(standard_hyperdiffusion, map_location=torch.device('cpu'))["state_dict"]
hyperdiffusion.load_state_dict(state_dict)
print("Dataset size:", len(dataset_flattened))

### 2. Create VAE

In [27]:
# Start from the default autoencoder config and override the model hyper-parameters.
# NOTE(review): these values presumably have to match the checkpoint loaded below — confirm.
config_ae: TransformerExperimentConfig = TransformerExperimentConfig.default()
config_ae.model.num_heads = 8
config_ae.model.num_layers = 8
config_ae.model.d_model = 512  # 256 -> 4
config_ae.model.latent_dim = 8
config_ae.model.layer_norm = False
config_ae.model.use_mask = True

# The checkpoint path comes from `vae_path`, defined in the paths cell near the top.
# Earlier hard-coded alternatives, kept for reference only (the first one used to be
# assigned here and was immediately overwritten, i.e. a dead store):
#autoencoder_checkpoint = "good_checkpoints/best_overfit_so_far_099_split.ckpt"
#autoencoder_checkpoint = "good_checkpoints/overfit.ckpt"
autoencoder_checkpoint = vae_path
# Initialize model
vae = Autoencoder(config_ae)
vae.eval()
# Load on the CPU; the weights live under the "state_dict" key of the checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(autoencoder_checkpoint, map_location=torch.device('cpu'))
vae.load_state_dict(state_dict["state_dict"])
# Cell output: total number of parameters in the autoencoder.
sum(p.numel() for p in vae.parameters())

### 3. Create Stable Hyperdiffusion Model

In [33]:
# Config for the latent ("stable") diffusion model: `sanity` preset with a custom transformer size.
config_shyp: DiffusionExperimentConfig = DiffusionExperimentConfig.sanity()
config_shyp.transformer_config.n_embd = 512
config_shyp.transformer_config.n_head = 8
config_shyp.transformer_config.n_layer = 8

# The model works on the autoencoder latent flattened to n_tokens * latent_dim values.
# NOTE(review): the leading dimension is 0 here while the standard model was built with a
# leading dimension of 1 — confirm HyperDiffusion ignores it.
data_shape = (0, config_ae.model.n_tokens * config_ae.model.latent_dim)
# Third element of a tokenized dataset item: the positions handed to the model.
_,_,positions = dataset_tokenized[0]

# Initialize model
stable_hyperdiffusion = HyperDiffusion(
    config_shyp, data_shape, vae, positions
)

# The checkpoint path comes from `stable_hyperdiffusion_path`, defined in the paths cell.
# Earlier hard-coded alternative, kept for reference only (it used to be assigned here
# and was immediately overwritten, i.e. a dead store):
#diffusion_checkpoint = "diffusion_logs/lightning_checkpoints/stable_hyperdiffusion_whole_dataset_2025-01-24 16-17-16.651982-hyperdiffusion_num2 20250124_161713-z8ljsssp/last.ckpt"
diffusion_checkpoint = stable_hyperdiffusion_path
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
state_dict = torch.load(diffusion_checkpoint, map_location=torch.device('cpu'))
stable_hyperdiffusion.load_state_dict(state_dict["state_dict"])
stable_hyperdiffusion.eval()
# Printing an empty line keeps the module repr returned by eval() out of the cell output.
print("")

### 4. Visualize the results

In [10]:
# Number of samples requested per generate_samples() call in the generation cell below.
n_samples = 100
# Move the INR and the standard diffusion model to the selected device.
mlp = mlp.to(device)
hyperdiffusion = hyperdiffusion.to(device)

In [None]:
# Generate 100 batches of `n_samples` samples and save each batch to its own .pt file.
path = "evaluation/hyperdiffusion"
os.makedirs(path, exist_ok=True)
for i in range(100):
    samples = hyperdiffusion.generate_samples(n_samples)
    #hyperdiffusion_images = model_utils.generate_diffusion_images(hyperdiffusion, mlp, n_samples, ref_cp=ref_cp)
    #hyperdiffusion_images_numpy = np.array(hyperdiffusion_images)
    print(samples.shape)
    torch.save(samples, os.path.join(path, f"hyperdiffusion_{i}.pt"))

## HyperDiffusion

In [9]:
# Load every file found below `path` and concatenate the stored tensors along dim 0.
# NOTE(review): the generation cell writes to "evaluation/hyperdiffusion", while this cell reads
# the "ferdy" sub-directory — confirm this is the intended sample set.
# NOTE(review): os.walk does not return files in sorted order, and torch.load unpickles the
# files — only load trusted files.
path = "evaluation/hyperdiffusion/ferdy"
file_paths = [os.path.join(root, file) for root, _, files in os.walk(path) for file in files]
tensors = [torch.load(f, map_location="cpu") for f in file_paths]
hyperdiffusion_vectors = torch.cat(tensors, dim=0)
print(hyperdiffusion_vectors.shape)

In [75]:
# Render one image per generated weight vector: the flat vector is converted back to weights
# with flattened_weights_to_weights and passed to compute_image together with the INR.
hyperdiffusion_images = [model_utils.compute_image(mlp, flattened_weights_to_weights(weights, mlp)) for weights in hyperdiffusion_vectors]

In [None]:
# Plot up to `num_images` generated images, `num_images_per_row` per call (one row each).
# NOTE(review): other cells call plot_n_images with `row=True` instead of `single_row=True` —
# confirm which keyword the helper accepts.
num_images_per_row = 5
num_images = 200
for i in range(min(num_images, len(hyperdiffusion_images))//num_images_per_row):
    images = hyperdiffusion_images[i*num_images_per_row:i*num_images_per_row+num_images_per_row]
    visualization_utils.plot_n_images(images, single_row=True)

In [11]:
# Stack every flattened dataset sample into one tensor (one row per sample).
dataset_vectors = [dataset_flattened[i] for i in range(len(dataset_flattened))]
dataset_vectors_array = torch.stack(dataset_vectors)
print(dataset_vectors_array.shape)

In [77]:
# Render one image per dataset sample.
dataset_images = [model_utils.compute_image(mlp, flattened_weights_to_weights(dataset_flattened[i], mlp)) for i in range(len(dataset_flattened))]
# Convert through a single NumPy array first: building a tensor directly from a Python list of
# arrays is the slow element-by-element path. The result is still a float32 tensor.
# NOTE(review): assumes compute_image returns equally shaped NumPy arrays (another cell feeds
# its result to torch.from_numpy) — confirm.
dataset_images_array = torch.Tensor(np.array(dataset_images))
# Flatten each image into one row: (num_images, num_pixels).
dataset_images_array = dataset_images_array.view(dataset_images_array.size(0), -1)
print(dataset_images_array.shape)

In [21]:
# Convert the rendered images to a float tensor and flatten each image into one row.
hyperdiffusion_images_array =  torch.Tensor(hyperdiffusion_images)
hyperdiffusion_images_array = hyperdiffusion_images_array.view(hyperdiffusion_images_array.size(0), -1)
print(hyperdiffusion_images_array.shape)

In [None]:
# Get cosine similarity for closest images
# find_nearest_neighbor is called without `metric`, so its default is used
# (presumably cosine, as the printed label says — confirm in metrics).
indices, distances = metrics.find_nearest_neighbor(hyperdiffusion_images_array, dataset_images_array)
print("Cosine similarity for images: ", distances.mean().item())

In [12]:
# Get euclidean distance for closest images
num_images = 100  # generated images per nearest-neighbour query
distances_total = []
# Walk over the generated images in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(hyperdiffusion_images_array), num_images):
    images = hyperdiffusion_images_array[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_images_array, metric="euclidean")
    distances_total.extend(distances)
print("Euclidean distance for images: ", np.mean(distances_total))

In [13]:
# Get manhattan distance for closest images
num_images = 100  # generated images per nearest-neighbour query
distances_total = []
# Walk over the generated images in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(hyperdiffusion_images_array), num_images):
    images = hyperdiffusion_images_array[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_images_array, metric="manhattan")
    distances_total.extend(distances)
print("Manhattan distance for images: ", np.mean(distances_total))

In [14]:
# Get cosine similarity for closest weights
# Compares the generated weight vectors with the dataset weight vectors; `metric` is left
# at the helper's default (presumably cosine — confirm in metrics).
indices, distances = metrics.find_nearest_neighbor(hyperdiffusion_vectors, dataset_vectors_array)
print("Cosine similarity for weights: ", distances.mean().item())

In [15]:
# Get euclidean distance for closest weights
num_images = 100  # generated weight vectors per nearest-neighbour query
distances_total = []
# Walk over the generated weight vectors in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(hyperdiffusion_vectors), num_images):
    images = hyperdiffusion_vectors[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_vectors_array, metric="euclidean")
    distances_total.extend(distances)
print("Euclidean distance for weights: ", np.mean(distances_total))

In [16]:
# Get manhattan distance for closest weights
num_images = 100  # generated weight vectors per nearest-neighbour query
distances_total = []
# Walk over the generated weight vectors in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(hyperdiffusion_vectors), num_images):
    images = hyperdiffusion_vectors[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_vectors_array, metric="manhattan")
    distances_total.extend(distances)
print("Manhattan distance for weights: ", np.mean(distances_total))

## Stable HyperDiffusion

In [None]:
#path = "evaluation/stable_hyperdiffusion"
#n_samples = 100
#os.makedirs(path, exist_ok=True)
#for i in range(100):
#    samples_tokenized, samples_reconstructed, positions = stable_hyperdiffusion.generate_samples(n_samples)
#    #hyperdiffusion_images = model_utils.generate_diffusion_images(hyperdiffusion, mlp, n_samples, ref_cp=ref_cp)
#    #hyperdiffusion_images_numpy = np.array(hyperdiffusion_images)
#    print(samples_tokenized.shape, samples_reconstructed.shape, positions.shape)
#    torch.save((samples_tokenized, samples_reconstructed, positions), os.path.join(path, f"stable_hyperdiffusion_{i}.pt"))

In [17]:
# Load every saved sample file below `path`. Each file is unpacked into three parts, which are
# collected separately and concatenated along dim 0.
# NOTE(review): os.walk does not return files in sorted order, and torch.load unpickles the
# files — only load trusted files.
path = "evaluation/stable_hyperdiffusion"
file_paths = [os.path.join(root, file) for root, _, files in os.walk(path) for file in files]
print(file_paths)
latens = []
tokens = []
pos = []
# NOTE: this loop rebinds the notebook-level names `file` and `positions`.
for file in file_paths:
    samples_tokenized, samples_reconstructed, positions = torch.load(file, map_location="cpu")
    latens.append(samples_tokenized)
    tokens.append(samples_reconstructed)
    pos.append(positions)
stable_hyperdiffusion_latent = torch.cat(latens, dim=0)
stable_hyperdiffusion_recon = torch.cat(tokens, dim=0)
stable_hyperdiffusion_pos = torch.cat(pos, dim=0)
print(stable_hyperdiffusion_latent.shape)
print(stable_hyperdiffusion_recon.shape)
print(stable_hyperdiffusion_pos.shape)

In [74]:
# Render one image per generated sample: tokens and positions are converted to weights with
# tokens_to_weights (using `ref_cp` as the reference checkpoint) and passed to compute_image.
stable_hyperdiffusion_images = [model_utils.compute_image(mlp, tokens_to_weights(t, p, ref_cp)) for t, p in zip(stable_hyperdiffusion_recon, stable_hyperdiffusion_pos)]

In [19]:
# Convert the rendered images to a float tensor and flatten each image into one row.
stable_hyperdiffusion_images_array =  torch.Tensor(stable_hyperdiffusion_images)
stable_hyperdiffusion_images_array = stable_hyperdiffusion_images_array.view(stable_hyperdiffusion_images_array.size(0), -1)
print(stable_hyperdiffusion_images_array.shape)

In [20]:
# Get cosine similarity for closest images
# `metric` is left at the helper's default (presumably cosine — confirm in metrics).
indices, distances = metrics.find_nearest_neighbor(stable_hyperdiffusion_images_array, dataset_images_array)
print("Cosine similarity for images: ", distances.mean().item())

In [21]:
# Get euclidean distance for closest images
num_images = 100  # generated images per nearest-neighbour query
distances_total = []
# Walk over the generated images in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_images_array), num_images):
    images = stable_hyperdiffusion_images_array[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_images_array, metric="euclidean")
    distances_total.extend(distances)
print("Euclidean distance for images: ", np.mean(distances_total))

In [22]:
# Get manhattan distance for closest images
num_images = 100  # generated images per nearest-neighbour query
distances_total = []
# Walk over the generated images in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_images_array), num_images):
    images = stable_hyperdiffusion_images_array[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_images_array, metric="manhattan")
    distances_total.extend(distances)
print("Manhattan distance for images: ", np.mean(distances_total))

In [23]:
# Convert every generated (tokens, positions) pair to weights, flatten those weights,
# and stack the flat vectors into a single tensor with one row per sample.
stable_hyperdiffusion_vectors = torch.stack(
    [
        weights_to_flattened_weights(tokens_to_weights(t, p, ref_cp))
        for t, p in zip(stable_hyperdiffusion_recon, stable_hyperdiffusion_pos)
    ]
)
print(stable_hyperdiffusion_vectors.shape)

In [24]:
# Get cosine similarity for closest weights
# `metric` is left at the helper's default (presumably cosine — confirm in metrics).
indices, distances = metrics.find_nearest_neighbor(stable_hyperdiffusion_vectors, dataset_vectors_array)
print("Cosine similarity for weights: ", distances.mean().item())

In [25]:
# Get euclidean distance for closest weights
num_images = 100  # generated weight vectors per nearest-neighbour query
distances_total = []
# Walk over the generated weight vectors in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_vectors), num_images):
    images = stable_hyperdiffusion_vectors[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_vectors_array, metric="euclidean")
    distances_total.extend(distances)
print("Euclidean distance for weights: ", np.mean(distances_total))

In [26]:
# Get manhattan distance for closest weights
num_images = 100  # generated weight vectors per nearest-neighbour query
distances_total = []
# Walk over the generated weight vectors in chunks of `num_images`. Stepping by start offset
# (instead of looping len // num_images times) keeps the final partial chunk, so no
# sample is silently dropped and a set smaller than one chunk no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_vectors), num_images):
    images = stable_hyperdiffusion_vectors[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(images, dataset_vectors_array, metric="manhattan")
    distances_total.extend(distances)
print("Manhattan distance for weights: ", np.mean(distances_total))

### 5. Calculate metrics

#### 5.1 Image comparison

In [71]:
# Draw `num_samples` fresh samples from each diffusion model.
# The standard model returns a single tensor; the latent model returns three values, which are
# bound here as latent outputs, results and positions.
# NOTE(review): no random seed is set, so the samples differ between runs.
num_samples = 200
hyperdiffusion_results = hyperdiffusion.generate_samples(num_samples)
stable_hyperdiffusion_latent_outputs, stable_hyperdiffusion_results, stable_hyperdiffusion_positions = stable_hyperdiffusion.generate_samples(num_samples)

In [None]:
# Sanity check: shapes of the freshly generated samples.
print(hyperdiffusion_results.shape)
print(stable_hyperdiffusion_results.shape)
print(stable_hyperdiffusion_latent_outputs.shape)

In [None]:
# Get n random samples from dataset
# Draw `num_samples` distinct dataset indices at random (random.sample raises ValueError when
# the dataset is smaller than that) and stack the selected flattened samples.
indices = random.sample(range(len(dataset_flattened)), num_samples)
dataset_samples = torch.stack([dataset_flattened[i] for i in indices])
print(dataset_samples.shape)

In [None]:
# Convert every generated (tokens, positions) pair to weights, flatten those weights,
# and stack the flat vectors into a single tensor with one row per sample.
stable_hyperdiffusion_results_flattened = torch.stack(
    [
        weights_to_flattened_weights(tokens_to_weights(t, p, ref_cp))
        for t, p in zip(stable_hyperdiffusion_results, stable_hyperdiffusion_positions)
    ]
)
print(stable_hyperdiffusion_results_flattened.shape)

In [83]:
# Generate images and stack to tensor
# For each of the three sample sets, every flat weight vector is converted back to weights,
# rendered by compute_image, wrapped with torch.from_numpy, and the per-sample tensors are
# stacked along a new leading axis.
hyperdiffusion_images = torch.stack(
    [
        torch.from_numpy(model_utils.compute_image(mlp, flattened_weights_to_weights(s, mlp)))
        for s in hyperdiffusion_results
    ]
)
stable_hyperdiffusion_images = torch.stack(
    [
        torch.from_numpy(model_utils.compute_image(mlp, flattened_weights_to_weights(s, mlp)))
        for s in stable_hyperdiffusion_results_flattened
    ]
)
dataset_images = torch.stack(
    [
        torch.from_numpy(model_utils.compute_image(mlp, flattened_weights_to_weights(s, mlp)))
        for s in dataset_samples
    ]
)

In [51]:
# Bind the `*_images_array` names used by the metric cells below.
# NOTE(review): the inputs are already stacked tensors when the previous cell has run, so
# torch.Tensor(...) is presumably just a float-tensor conversion here — confirm.
dataset_images_array = torch.Tensor(dataset_images)
hyperdiffusion_images_array = torch.Tensor(hyperdiffusion_images)
stable_hyperdiffusion_images_array = torch.Tensor(stable_hyperdiffusion_images)

In [26]:
# Show both shapes. These used to be joined as `print(...), print(...)`, which builds a tuple
# of the two None return values and made the cell also display `(None, None)`.
print(dataset_images_array.shape)
print(hyperdiffusion_images_array.shape)

##### Calculate FID scores:

In [31]:
def _mean_chunked_fid(real_images, generated_images, chunk_size):
    """Average FID over consecutive, equally sized chunks of ``generated_images``.

    Every chunk is scored with ``metrics.calculate_fid`` against a fresh random subset
    (drawn without replacement) of ``real_images`` that has the same size as the chunk.

    Args:
        real_images: indexable collection of reference images.
        generated_images: sliceable collection of generated images.
        chunk_size: requested images per chunk. It is capped at the size of the smaller
            collection, so small sets give one score instead of NaN or a sampling error.

    Returns:
        The mean of the per-chunk FID values, or NaN when either collection is empty.
    """
    chunk_size = min(chunk_size, len(real_images), len(generated_images))
    if chunk_size == 0:
        return float("nan")
    scores = []
    # Only full chunks are scored, so every FID value is based on the same number of images.
    for start in range(0, len(generated_images) - chunk_size + 1, chunk_size):
        generated_subset = generated_images[start:start + chunk_size]
        # Take random samples from the dataset for computational efficiency
        subset_indices = np.random.choice(len(real_images), chunk_size, replace=False)
        scores.append(metrics.calculate_fid(real_images[subset_indices], generated_subset))
    return np.mean(scores)


# Requested number of images per FID evaluation.
# NOTE: this rebinds the notebook-level `n_samples` that other cells also read.
n_samples = 500
fid_hyperdiffusion = _mean_chunked_fid(dataset_images_array, hyperdiffusion_images_array, n_samples)
fid_stable_hyperdiffusion = _mean_chunked_fid(dataset_images_array, stable_hyperdiffusion_images_array, n_samples)
print(f"FID Hyperdiffusion: {fid_hyperdiffusion}")
print(f"FID Stable Hyperdiffusion: {fid_stable_hyperdiffusion}")

In [61]:
# Flatten every image into one row and truncate both generated sets to at most as many images
# as the dataset set.
# NOTE: this overwrites the arrays in place of their un-flattened versions; cells that use an
# einops pattern of "b h w" on these names need the un-flattened arrays.
stable_hyperdiffusion_images_array = stable_hyperdiffusion_images_array.view(stable_hyperdiffusion_images_array.size(0), -1)[:len(dataset_images_array)]
hyperdiffusion_images_array = hyperdiffusion_images_array.view(hyperdiffusion_images_array.size(0), -1)[:len(dataset_images_array)]
dataset_images_array = dataset_images_array.view(dataset_images_array.size(0), -1)

##### Calculate MSE scores:

In [38]:
# MSE between the dataset image array and each generated image array.
# NOTE(review): the dataset subset is drawn at random elsewhere in the notebook, so the two
# arrays are presumably unpaired — confirm that a position-wise comparison is intended.
mse_hyperdiffusion = metrics.calculate_mse(dataset_images_array, hyperdiffusion_images_array)
mse_stable_hyperdiffusion = metrics.calculate_mse(dataset_images_array, stable_hyperdiffusion_images_array)
print(f"MSE Hyperdiffusion: {mse_hyperdiffusion}")
print(f"MSE Stable Hyperdiffusion: {mse_stable_hyperdiffusion}")

##### Calculate the Minimum Matching Distance

In [41]:
# metrics.compute_mmd for each model; both inputs are reshaped to (num_images, -1) first.
mmd_hyperdiffusion = metrics.compute_mmd(dataset_images_array.reshape(dataset_images_array.shape[0], -1), hyperdiffusion_images_array.reshape(hyperdiffusion_images_array.shape[0], -1))
mmd_stable_hyperdiffusion = metrics.compute_mmd(dataset_images_array.reshape(dataset_images_array.shape[0], -1), stable_hyperdiffusion_images_array.reshape(stable_hyperdiffusion_images_array.shape[0], -1))
print(f"Minimum Matching Distance Hyperdiffusion: {mmd_hyperdiffusion}")
print(f"Minimum Matching Distance Stable Hyperdiffusion: {mmd_stable_hyperdiffusion}")

##### Calculate the coverage

In [42]:
# metrics.compute_coverage for each model; both inputs are reshaped to (num_images, -1) first.
coverage_hyperdiffusion = metrics.compute_coverage(dataset_images_array.reshape(dataset_images_array.shape[0], -1), hyperdiffusion_images_array.reshape(hyperdiffusion_images_array.shape[0], -1))
coverage_stable_hyperdiffusion = metrics.compute_coverage(dataset_images_array.reshape(dataset_images_array.shape[0], -1), stable_hyperdiffusion_images_array.reshape(stable_hyperdiffusion_images_array.shape[0], -1))
print(f"Coverage Hyperdiffusion: {coverage_hyperdiffusion}")
print(f"Coverage Stable Hyperdiffusion: {coverage_stable_hyperdiffusion}")

##### Calculate the 1-Nearest-Neighbor Accuracy

In [43]:
# metrics.compute_1nna for each model; both inputs are reshaped to (num_images, -1) first.
nna_hyperdiffusion = metrics.compute_1nna(dataset_images_array.reshape(dataset_images_array.shape[0], -1), hyperdiffusion_images_array.reshape(hyperdiffusion_images_array.shape[0], -1))
nna_stable_hyperdiffusion = metrics.compute_1nna(dataset_images_array.reshape(dataset_images_array.shape[0], -1), stable_hyperdiffusion_images_array.reshape(stable_hyperdiffusion_images_array.shape[0], -1))
print(f"1-Nearest-Neighbor Accuracy Hyperdiffusion: {nna_hyperdiffusion}")
print(f"1-Nearest-Neighbor Accuracy Stable Hyperdiffusion: {nna_stable_hyperdiffusion}")

##### Calculate the novelty

In [44]:
# NOTE: maybe we should use the whole training dataset here
# metrics.detect_novelty for each model; both inputs are reshaped to (num_images, -1) first.
# The first element of each result is summed and its length reported
# (presumably a per-sample novelty flag — confirm in metrics).
novelty_hyperdiffusion = metrics.detect_novelty(dataset_images_array.reshape(dataset_images_array.shape[0], -1), hyperdiffusion_images_array.reshape(hyperdiffusion_images_array.shape[0], -1))
novelty_stable_hyperdiffusion = metrics.detect_novelty(dataset_images_array.reshape(dataset_images_array.shape[0], -1), stable_hyperdiffusion_images_array.reshape(stable_hyperdiffusion_images_array.shape[0], -1))
print(f"Novelty Hyperdiffusion: {novelty_hyperdiffusion[0].sum()} of {len(novelty_hyperdiffusion[0])}")
print(f"Novelty Stable Hyperdiffusion: {novelty_stable_hyperdiffusion[0].sum()} of {len(novelty_stable_hyperdiffusion[0])}")

In [47]:
# metrics.novelty_svm for each model; the number of entries equal to -1 is reported
# (presumably -1 marks a sample classified as novel — confirm in metrics).
# NOTE: this rebinds the `novelty_*` names set by the detect_novelty cell.
novelty_hyperdiffusion = metrics.novelty_svm(dataset_images_array.reshape(dataset_images_array.shape[0], -1), hyperdiffusion_images_array.reshape(hyperdiffusion_images_array.shape[0], -1))
novelty_stable_hyperdiffusion = metrics.novelty_svm(dataset_images_array.reshape(dataset_images_array.shape[0], -1), stable_hyperdiffusion_images_array.reshape(stable_hyperdiffusion_images_array.shape[0], -1))
print(f"Novelty Hyperdiffusion: {(novelty_hyperdiffusion == -1).sum()}")
print(f"Novelty Stable Hyperdiffusion: {(novelty_stable_hyperdiffusion == -1).sum()}")

#### Calculate inception score (IS)

In [56]:
from torchmetrics.image.inception import InceptionScore
import einops

# normalize=True tells torchmetrics that the inputs are float images in [0, 1].
is_metric = InceptionScore(normalize=True)
# Replicate the single channel three times for the Inception network. The images stay float:
# casting [0, 1] floats to uint8 (as this cell did before) truncates nearly every pixel to 0
# before the metric sees it.
# NOTE(review): assumes pixel values lie in [0, 1] (the SSIM cell uses data_range=1.0) and that
# hyperdiffusion_images_array is still (b, h, w), i.e. has not been flattened — confirm.
hyperdiffusion_images = einops.repeat(hyperdiffusion_images_array, "b h w -> b c h w", c=3)
scores = []
# Score in chunks of up to 1000 images. Stepping by start offset keeps a final partial chunk,
# so a set with fewer than 1000 images still produces a score instead of an empty list.
for start in range(0, len(hyperdiffusion_images), 1000):
    scores.append(is_metric(hyperdiffusion_images[start:start + 1000]))

In [57]:
# Average the per-chunk results column-wise (axis 0 runs over the chunks).
print(f"IS Hyperdiffusion: {np.mean(np.array(scores), axis=0)}")

In [58]:
# Replicate the single channel three times for the Inception network. The images stay float:
# `is_metric` was created with normalize=True, i.e. it expects floats in [0, 1]; casting them
# to uint8 (as this cell did before) truncates nearly every pixel to 0.
# NOTE(review): assumes pixel values lie in [0, 1] and that stable_hyperdiffusion_images_array
# is still (b, h, w), i.e. has not been flattened — confirm.
stable_hyperdiffusion_images = einops.repeat(stable_hyperdiffusion_images_array, "b h w -> b c h w", c=3)
scores = []
# Score in chunks of up to 1000 images. Stepping by start offset keeps a final partial chunk,
# so a set with fewer than 1000 images still produces a score instead of an empty list.
for start in range(0, len(stable_hyperdiffusion_images), 1000):
    scores.append(is_metric(stable_hyperdiffusion_images[start:start + 1000]))

In [59]:
# Average the per-chunk results column-wise (axis 0 runs over the chunks).
print(f"IS Stable Hyperdiffusion: {np.mean(np.array(scores), axis=0)}")

#### Peak Signal-to-Noise Ratio (PSNR)

In [62]:
import torch
import torchmetrics

# PSNR of each generated image array against the dataset image array.
# NOTE(review): no data_range is passed here although the SSIM cell uses data_range=1.0 —
# confirm which range is intended.
# NOTE(review): the same metric object is reused for both models — confirm no state carries
# over between the two calls.
psnr = torchmetrics.PeakSignalNoiseRatio()
score = psnr(hyperdiffusion_images_array, dataset_images_array)
print(f"PSNR hyperdiffusion: {score}")
score = psnr(stable_hyperdiffusion_images_array, dataset_images_array)
print(f"PSNR stable hyperdiffusion: {score}")

#### Structural Similarity Index (SSIM)

In [78]:
# Rebuild the image arrays from the `*_images` variables and truncate the generated sets to
# at most the number of dataset images.
# NOTE(review): the Inception Score cells rebind `hyperdiffusion_images` and
# `stable_hyperdiffusion_images` to 4-D tensors, so the result depends on which cells ran
# before this one — confirm the intended inputs.
dataset_images_array = torch.Tensor(dataset_images)
hyperdiffusion_images_array = torch.Tensor(hyperdiffusion_images)[:len(dataset_images)]
stable_hyperdiffusion_images_array = torch.Tensor(stable_hyperdiffusion_images)[:len(dataset_images)]
print(f"Dataset images shape: {dataset_images_array.shape}")
print(f"Hyperdiffusion images shape: {hyperdiffusion_images_array.shape}")
print(f"Stable hyperdiffusion images shape: {stable_hyperdiffusion_images_array.shape}")

In [80]:
import torchmetrics

# SSIM with a data range of 1.0; every (b, h, w) array is first expanded to three identical
# channels, giving (b, 3, h, w).
# NOTE: this rebinds `hyperdiffusion_images`, `stable_hyperdiffusion_images` and
# `dataset_images` to the 4-D tensors, so re-running the array-building cell afterwards
# yields 4-D arrays.
ssim = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0)
hyperdiffusion_images = einops.repeat(hyperdiffusion_images_array, "b h w -> b c h w", c=3)
stable_hyperdiffusion_images = einops.repeat(stable_hyperdiffusion_images_array, "b h w -> b c h w", c=3)
dataset_images = einops.repeat(dataset_images_array, "b h w -> b c h w", c=3)
score = ssim(hyperdiffusion_images, dataset_images)
print(f"SSIM score hyperdiffusion: {score}")
score = ssim(stable_hyperdiffusion_images, dataset_images)
print(f"SSIM score stable hyperdiffusion: {score}")

#### 5.2 Weights comparison

In [106]:
#TODO:

### 6. VAE Latent Space

#### Sample from latent space

In [None]:
# Number of latent samples to draw and plot.
num_samples = 50
# Pass the count defined in this cell. The call previously received `n_samples`, an unrelated
# notebook variable set by earlier cells, so `num_samples` had no effect.
# NOTE(review): assumes the last positional argument of sample_from_latent_space is the
# number of samples — confirm against model_utils.
latent_images = model_utils.sample_from_latent_space(
    vae, mlp, ref_cp, dataset_tokenized[0][2], config_ae.model.n_tokens, config_ae.model.latent_dim, num_samples
)
visualization_utils.plot_n_images(latent_images)

##### Interpolate latent space

In [None]:
num_interpolation_steps = 10

num_interpolations = 100

# Each round draws two random dataset indices, encodes both samples with the VAE encoder and
# plots the images produced by interpolating between the two latent codes.
for i in range(num_interpolations):
    idx = random.randint(0, len(dataset_tokenized) - 1)
    idy = random.randint(0, len(dataset_tokenized) - 1)
    # Element 0 of a dataset item is fed to the encoder as tokens, element 2 as positions.
    latent_vectors = torch.stack([dataset_tokenized[idx][0], dataset_tokenized[idy][0]])
    latent_positions = torch.stack([dataset_tokenized[idx][2], dataset_tokenized[idy][2]])

    vae.eval()
    with torch.no_grad():
        latent_vector, _, _ = vae.encoder(latent_vectors, latent_positions)

    latent_images = model_utils.interpolate_latent_space(
        latent_vector[0],
        latent_vector[1],
        num_interpolation_steps,
        dataset_tokenized[0][2],
        vae,
        ref_cp,
        mlp,
    )
    print("Indices:", idx, idy)
    visualization_utils.plot_n_images(latent_images, row=True)

In [47]:
# Run the helper over the whole tokenized dataset (count = len(dataset), random=False); it
# returns four values, bound here as images, reconstructions and two MSE collections.
images, recons, mse_weights, mse_images = model_utils.get_n_images_and_mses(vae, dataset_tokenized, mlp, len(dataset_tokenized), device=device, random=False)

In [48]:
# Pick the 10 best samples according to each MSE collection
# (presumably the lowest errors — confirm in model_utils.get_best_samples).
best_mse_weight_indices = model_utils.get_best_samples(mse_weights, best_n=10)
best_mse_image_indices = model_utils.get_best_samples(mse_images, best_n=10)

In [None]:
# Interpolate between every ordered pair (i, j) of the best image-MSE samples, including
# i == j, and plot each interpolation as one row of images.
for i in best_mse_image_indices:
    for j in best_mse_image_indices:
        # Element 0 of a dataset item is fed to the encoder as tokens, element 2 as positions.
        latent_vectors = torch.stack([dataset_tokenized[i][0], dataset_tokenized[j][0]])
        latent_positions = torch.stack([dataset_tokenized[i][2], dataset_tokenized[j][2]])

        vae.eval()
        with torch.no_grad():
            latent_vector, _, _ = vae.encoder(latent_vectors, latent_positions)

        latent_images = model_utils.interpolate_latent_space(
            latent_vector[0],
            latent_vector[1],
            num_interpolation_steps,
            dataset_tokenized[0][2],
            vae,
            ref_cp,
            mlp,
        )
        visualization_utils.plot_n_images(latent_images, row=True)

### 7. Get nearest neighbors

In [123]:
# Number of generated samples and of neighbours per sample; the plotting cell reads both.
num_samples = 30
k = 2
# Pass `num_samples` instead of a second hard-coded 30 so that generation and plotting cannot
# drift apart when the value is changed.
dataset_images, hyperdiffusion_images, mse_images, mse_weights, distances = model_utils.generate_nearest_neighbors(
    hyperdiffusion,
    inr=mlp,
    dataset=dataset_flattened,
    num_samples=num_samples,
    k=k,
)

In [None]:
# Plot the generated images of the standard model next to their nearest dataset neighbours.
visualization_utils.plot_diffusion_knn(
    hyperdiffusion_images,
    dataset_images,
    mse_images,
    mse_weights,
    k=k,
    num_samples=num_samples
)

In [125]:
# Same nearest-neighbour search for the latent model, on the tokenized dataset.
# `num_samples` (set together with `k` in the previous nearest-neighbour cell) replaces a
# hard-coded 30 so that this call stays in sync with the plotting cell, which reads the variable.
# NOTE: `stable_hyperdifusion_images` is spelled with one "f"; the plotting cell uses the same
# spelling, so the name is kept.
dataset_images, stable_hyperdifusion_images, mse_images, mse_weights, distances = model_utils.generate_nearest_neighbors(
    stable_hyperdiffusion,
    inr=mlp,
    dataset=dataset_tokenized,
    num_samples=num_samples,
    k=k,
)

In [None]:
# Plot the generated images of the latent model next to their nearest dataset neighbours.
visualization_utils.plot_diffusion_knn(
    stable_hyperdifusion_images,
    dataset_images,
    mse_images,
    mse_weights,
    k=k,
    num_samples=num_samples
)

## Analyse latent space

In [9]:
# Stack element 0 (tokens) and element 2 (positions) of every tokenized dataset item into
# two tensors with one entry per sample.
dataset_tokens = torch.stack([dataset_tokenized[i][0] for i in range(len(dataset_tokenized))])
dataset_positions = torch.stack([dataset_tokenized[i][2] for i in range(len(dataset_tokenized))])
print("Dataset tokens:", dataset_tokens.shape)
print("Dataset positions:", dataset_positions.shape)

In [10]:
# Load every saved sample file below `path`. Each file is unpacked into three parts, which are
# collected separately and concatenated along dim 0.
# NOTE: this repeats the loading cell of the "Stable HyperDiffusion" section.
# NOTE(review): os.walk does not return files in sorted order, and torch.load unpickles the
# files — only load trusted files.
path = "evaluation/stable_hyperdiffusion"
file_paths = [os.path.join(root, file) for root, _, files in os.walk(path) for file in files]
print(file_paths)
latens = []
tokens = []
pos = []
# NOTE: this loop rebinds the notebook-level names `file` and `positions`.
for file in file_paths:
    samples_tokenized, samples_reconstructed, positions = torch.load(file, map_location="cpu")
    latens.append(samples_tokenized)
    tokens.append(samples_reconstructed)
    pos.append(positions)
stable_hyperdiffusion_latent = torch.cat(latens, dim=0)
stable_hyperdiffusion_recon = torch.cat(tokens, dim=0)
stable_hyperdiffusion_pos = torch.cat(pos, dim=0)
print(stable_hyperdiffusion_latent.shape)
print(stable_hyperdiffusion_recon.shape)
print(stable_hyperdiffusion_pos.shape)

In [12]:
# Encode the whole dataset in one call, without gradient tracking; only the first of the
# encoder's three return values is kept.
# NOTE(review): the full dataset goes through the encoder as a single batch — confirm it
# fits in memory.
vae.eval()
with torch.no_grad():
    dataset_latent,_,_ = vae.encoder(dataset_tokens, dataset_positions)
print(dataset_latent.shape)

In [17]:
# Collapse everything after the first dimension so each latent becomes a single vector.
dataset_latent_flattened = dataset_latent.flatten(start_dim =1)
stable_hyperdiffusion_latent_flattened = stable_hyperdiffusion_latent.flatten(start_dim =1)
print(dataset_latent_flattened.shape)
print(stable_hyperdiffusion_latent_flattened.shape)

In [18]:
# Get cosine similarity for closest latent vectors
# `metric` is left at the helper's default (presumably cosine — confirm in metrics).
indices, distances = metrics.find_nearest_neighbor(stable_hyperdiffusion_latent_flattened, dataset_latent_flattened)
print("Cosine similarity for latent vectors: ", distances.mean().item())

In [19]:
# Get euclidean distance for closest latent vectors
num_images = 200  # generated latents per nearest-neighbour query
distances_total = []
# Use the flattened latents, as the cosine-similarity cell does: `dataset_latent` is the raw
# encoder output, whereas the flattened versions hold one vector per sample.
# Stepping by start offset (instead of looping len // num_images times) keeps the final
# partial chunk, so no sample is dropped and a small set no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_latent_flattened), num_images):
    latent_vectors = stable_hyperdiffusion_latent_flattened[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(latent_vectors, dataset_latent_flattened, metric="euclidean")
    distances_total.extend(distances)
print("Euclidean distance for latent vectors: ", np.mean(distances_total))

In [20]:
# Get manhattan distance for closest latent vectors
num_images = 200  # generated latents per nearest-neighbour query
distances_total = []
# Use the flattened latents, as the cosine-similarity cell does: `dataset_latent` is the raw
# encoder output, whereas the flattened versions hold one vector per sample.
# Stepping by start offset (instead of looping len // num_images times) keeps the final
# partial chunk, so no sample is dropped and a small set no longer yields NaN.
for start in range(0, len(stable_hyperdiffusion_latent_flattened), num_images):
    latent_vectors = stable_hyperdiffusion_latent_flattened[start:start + num_images]
    indices, distances = metrics.find_nearest_neighbor(latent_vectors, dataset_latent_flattened, metric="manhattan")
    distances_total.extend(distances)
print("Manhattan distance for latent vectors: ", np.mean(distances_total))

# 8. Training and inference speed

In [18]:
# Select the device for the timing section (GPU when available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [34]:
# Move both diffusion models to the selected device before timing them.
stable_hyperdiffusion = stable_hyperdiffusion.to(device)
hyperdiffusion = hyperdiffusion.to(device)

In [41]:
# Time the generation of 1000 samples with the latent model.
%timeit stable_hyperdiffusion.generate_samples(1000)

In [42]:
# Time the generation of 1000 samples with the standard model.
%timeit hyperdiffusion.generate_samples(1000)

In [37]:
from src.data.inr_dataset import DataHandler

In [38]:
# Data configuration for the training-speed measurement: `small` data preset with batch size
# 128 and the sample limit set to None.
# NOTE: this rebinds the notebook-level `config` that was used to build the standard model.
config: DiffusionExperimentConfig = DiffusionExperimentConfig.sanity()
config.data = DataConfig.small()
config.data.batch_size = 128
config.data.sample_limit = None
print(config.data)

In [39]:
# Setup dataloaders
# use_autoencoder=False: loader for the standard model, whose batches train() moves to the
# device as a whole.
data_handler = DataHandler(config, use_autoencoder=False)
data_handler.setup()
hyperdiffusion_loader = data_handler.train_dataloader()

In [54]:
from transformers import get_linear_schedule_with_warmup
import einops

def train(model, data_loader, epochs, device):
    """Run a plain AdamW training loop with linear warm-up, e.g. for timing a model.

    Args:
        model: module exposing ``_compute_loss(batch)`` and an ``autoencoder`` attribute.
            When ``autoencoder`` is ``None`` the batches are fed to the loss directly;
            otherwise they are first encoded with ``autoencoder.encoder``.
        data_loader: sized iterable of batches. Without an autoencoder each batch is a
            tensor; with one it is a 3-tuple whose first and third items are the tokens
            and the positions passed to the encoder.
        epochs: number of passes over ``data_loader``.
        device: device the model and the batches are moved to.

    Returns:
        None. The model is updated in place and left in training mode.
    """
    optimizer = torch.optim.AdamW(model.parameters())

    total_steps = len(data_loader) * epochs
    warmup_steps = int(total_steps * 0.1)  # first 10% of all steps are warm-up

    # Linear warmup scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    model = model.to(device)
    model.train()
    if model.autoencoder is not None:
        # model.train() switches registered sub-modules to training mode as well. The encoder
        # is only evaluated under no_grad to produce inputs, so keep it in inference mode.
        model.autoencoder.eval()

    for _epoch in range(epochs):
        for batch in data_loader:
            if model.autoencoder is None:
                inputs = batch.to(device)
            else:
                original_tokens, _, original_positions = batch
                original_tokens = original_tokens.to(device)
                original_positions = original_positions.to(device)

                with torch.no_grad():
                    latent_vector, _, _ = model.autoencoder.encoder(
                        original_tokens, original_positions
                    )  # Shape: (batch_size, n_tokens, latent_dim)

                # Flatten all dimensions after batch (b) into a single dimension
                inputs = latent_vector.flatten(start_dim=1)

            loss = model._compute_loss(inputs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

In [56]:
# Time two training epochs of the standard model.
# NOTE: %timeit calls train() repeatedly, so every run continues from the weights left by the
# previous run.
%timeit train(hyperdiffusion, hyperdiffusion_loader, epochs=2, device=device)

In [44]:
# Setup dataloaders
# use_autoencoder=True: loader for the latent model, whose batches train() unpacks into
# tokens and positions.
data_handler = DataHandler(config, use_autoencoder=True)
data_handler.setup()
stable_hyperdiffusion_loader = data_handler.train_dataloader()

In [55]:
# Time two training epochs of the latent model.
# NOTE: %timeit calls train() repeatedly, so every run continues from the weights left by the
# previous run.
%timeit train(stable_hyperdiffusion, stable_hyperdiffusion_loader, epochs=2, device=device)