# Evaluation


In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
sys.path.append('../') #act as if we are one directory higher so imports work 
import torch
from latent_to_timestep_model import LTT_model
from dataset import load_data_from_dir
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from torch import nn
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import plotly.express as px
import torch.nn.functional as F
from dataset import LTTDataset

In [None]:
path = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0"
steps = 5
latents, targets, conditions, unconditions, optimal_params = load_data_from_dir(data_folder=path, limit=50, use_optimal_params=True, steps=steps)
#optimal_params



In [None]:
dataset = LTTDataset(dir = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/validation",
                     size = 1000,
                     train_flag=False)

In [None]:
# Gather the middle element of every 3-tuple sample the dataset returns.
# Indexing explicitly (rather than iterating the dataset) keeps the access pattern
# identical to a plain `dataset[i]` loop.
samples = (dataset[idx] for idx in range(len(dataset)))
second_outputs = [second for _, second, _ in samples]
len(second_outputs)

## LD3 Best timesteps


In [2]:
# Best LD3 parameters for 3, 5, 7 and 10 steps (each vector has steps + 1 entries).
# Use the GPU when one is present and fall back to the CPU otherwise, so this cell
# also runs on machines without CUDA (evaluate_params applies the same fallback).
ld3_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

n3_params = torch.tensor([0.6048, 1.0274, 0.6334, 1.8439], device=ld3_device)
n5_params = torch.tensor([0.8088, 1.1801, 0.9390, 0.7322, 0.7591, 2.0050], device=ld3_device)
n7_params = torch.tensor([1.1434, 1.2401, 0.9985, 0.6071, 0.9339, 0.1873, 0.8551, 1.9311], device=ld3_device)
n10_params = torch.tensor([1.6245, 1.3128, 1.5374, 0.6975, 0.8498, 0.9843, 1.3483, 0.6511, 1.1129, 1.2806, 1.6264], device=ld3_device)


# Normalise each raw vector with a softmax so its entries are positive and sum to 1.
n3_params = F.softmax(n3_params, dim=0)
n5_params = F.softmax(n5_params, dim=0)
n7_params = F.softmax(n7_params, dim=0)
n10_params = F.softmax(n10_params, dim=0)

print(f"n3_params:\n{n3_params}")
print(f"n5_params:\n{n5_params}")
print(f"n7_params:\n{n7_params}")
print(f"n10_params:\n{n10_params}")

n3_params:
tensor([0.1427, 0.2178, 0.1468, 0.4927], device='cuda:0')
n5_params:
tensor([0.1140, 0.1652, 0.1298, 0.1056, 0.1084, 0.3770], device='cuda:0')
n7_params:
tensor([0.1300, 0.1432, 0.1124, 0.0760, 0.1054, 0.0500, 0.0974, 0.2857],
       device='cuda:0')
n10_params:
tensor([0.1337, 0.0979, 0.1225, 0.0529, 0.0616, 0.0705, 0.1014, 0.0505, 0.0802,
        0.0948, 0.1340], device='cuda:0')


### LD3 Timesteps Evaluation

In [None]:
import torch
from torch.nn import functional as F
import time
import os

from dataset import load_data_from_dir
from trainer import LD3Trainer, ModelConfig, TrainingConfig, DiscretizeModelWrapper
from utils import (
    get_solvers,
    parse_arguments,
    adjust_hyper,
    set_seed_everything,
    move_tensor_to_device
)
from models import prepare_stuff

In [None]:
def evaluate_params(params: torch.Tensor) -> float:
    """Sample with a fixed timestep parameter vector and return the mean loss.

    The CIFAR-10 configuration is parsed, the model, solver and trainer are rebuilt,
    and every loaded latent is pushed through the solver using the timesteps that
    ``DiscretizeModelWrapper.convert`` derives from ``params``. The decoded result is
    compared with the matching target via ``trainer.loss_fn``.

    Args:
        params: 1-D tensor of schedule parameters; ``len(params) - 1`` is passed as
            ``--steps``.

    Returns:
        Mean of the per-sample losses as a Python float.

    Side effects:
        Prints the elapsed wall-clock time and overwrites ``trainer.loss_vector``
        on every iteration.
    """
    start_time = time.time()
    args = parse_arguments([
        "--all_config", "configs/cifar10.yml",
        "--data_dir", "train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0",
        "--num_train", "0",
        "--num_valid", "50",
        "--steps", str(len(params)-1),
        "--training_rounds_v1", "1",
        "--seed", "0",
    ])

    set_seed_everything(args.seed)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    wrapped_model, _, decoding_fn, noise_schedule, latent_resolution, latent_channel, _, _ = prepare_stuff(args)
    adjust_hyper(args, latent_resolution, latent_channel)
    solver, steps, solver_extra_params = get_solvers(
        args.solver_name,
        NFEs=args.steps,
        order=args.order,
        noise_schedule=noise_schedule,
        unipc_variant=args.unipc_variant,
    )
    # Same loader as in training: `targets` are the reference images and `latents`
    # the starting latents they were generated from.
    latents, targets, _, _, _ = load_data_from_dir(
        data_folder=args.data_dir, limit=args.num_train + args.num_valid, use_optimal_params=False
    )

    # The trainer is only used as a container for solver/loss/decoder here;
    # the latents are passed as both train and validation data.
    training_config = TrainingConfig(
        train_data=latents,
        valid_data=latents,
        train_batch_size=args.main_train_batch_size,
        valid_batch_size=args.main_valid_batch_size,
        lr_time_1=args.lr_time_1,
        shift_lr=args.shift_lr,
        shift_lr_decay=args.shift_lr_decay,
        min_lr_time_1=args.min_lr_time_1,
        win_rate=args.win_rate,
        patient=args.patient,
        lr_time_decay=args.lr_time_decay,
        momentum_time_1=args.momentum_time_1,
        weight_decay_time_1=args.weight_decay_time_1,
        loss_type=args.loss_type,
        visualize=args.visualize,
        no_v1=args.no_v1,
        prior_timesteps=args.gits_ts,
        match_prior=args.match_prior,
    )
    model_config = ModelConfig(
        net=wrapped_model,
        decoding_fn=decoding_fn,
        noise_schedule=noise_schedule,
        solver=solver,
        solver_name=args.solver_name,
        order=args.order,
        steps=steps,
        prior_bound=args.prior_bound,
        resolution=latent_resolution,
        channels=latent_channel,
        time_mode=args.time_mode,
        solver_extra_params=solver_extra_params,
        device=device,
    )
    trainer = LD3Trainer(model_config, training_config)
    # Converts the parameter vector into solver timesteps (changed through LTT).
    dis_model = DiscretizeModelWrapper(
            lambda_max=trainer.lambda_max,
            lambda_min=trainer.lambda_min,
            noise_schedule=trainer.noise_schedule,
            time_mode = trainer.time_mode,
        )
    loss_list = torch.zeros(len(targets))
    for i, (img, latent) in enumerate(zip(targets, latents)):

        img, latent = move_tensor_to_device(img, latent, device = device)

        timestep = dis_model.convert(params.unsqueeze(0))

        x_next = trainer.noise_schedule.prior_transformation(latent)
        x_next = trainer.solver.sample_simple(
            model_fn=trainer.net,
            x=x_next,
            timesteps=timestep[0],
            order=trainer.order,
            NFEs=trainer.steps,
            **trainer.solver_extra_params,
            )
        x_next = trainer.decoding_fn(x_next)
        trainer.loss_vector = trainer.loss_fn(img.float(), x_next.float()).squeeze()
        loss = trainer.loss_vector.mean()
        # Store a detached value: only the number is needed, and writing a
        # graph-attached tensor into loss_list would keep every iteration's
        # autograd graph alive until the function returns.
        loss_list[i] = loss.detach()
    print("Time taken: ", time.time() - start_time)
    return loss_list.mean().item()




# Evaluate every stored LD3 schedule and report its mean loss.
ld3_schedules = {"n3": n3_params, "n5": n5_params, "n7": n7_params, "n10": n10_params}
for name, params in ld3_schedules.items():
    print(f"Loss for {name}: {evaluate_params(params)}")

## LTT Model

In [None]:
# Checkpoint candidates from earlier experiments.
# The first group of assignments all rebind the same `model_path` name, so only the last
# of them survives; `model_path` itself is not read again in this cell.
# `steps` is rebound by every tuple assignment and ends up as the value from the last line (5).
model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N10-val50-train450-rv12-seed0/final_ltt_model.pt", 10
model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N10-val50-train50-rv12-seed0/ltt_model.pt", 10

model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N10-val50-train450-rv12-seed0-fixed_scaling/final_ltt_model.pt", 10
model_path, steps  = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N7-val50-train450-rv12-seed0-fixed_scaling/final_ltt_model.pt", 7
model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N5-val50-train450-rv12-seed0-fixed_scaling/final_ltt_model.pt", 5
# model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/N3-val50-train450-rv12-seed0-fixed_scaling/final_ltt_model.pt", 3
model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/LTT_batch3_moreData_N5-val50-train450-r5/final_ltt_model.pt", 5
without_dropout_model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/LTT_after_ltt_change_batch1_without_dropout_N5-val50-train450-r10/final_ltt_model.pt", 5
with_dropout_model_path, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/LTT_after_ltt_change_batch1_with_dropout_N5-val50-train450-r10/final_ltt_model.pt", 5
trained_on_optimal_without_dropout, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/runs/RandomModels/model_lr0.0001_batch5_without_dropout.pth", 5
after_ltt_change, steps = "/netpool/homes/connor/DiffusionModels/LD3_connor/logs/logs_cifar10/LTT_After_LTT_DatasetAdjustement_batch3_N5-val50-train450-r2/final_ltt_model.pt", 5

# Build the model for the final `steps` value and load the checkpoint that is actually
# evaluated below: `trained_on_optimal_without_dropout`.
# NOTE(review): to evaluate another checkpoint, change the variable passed to torch.load
# and make sure `steps` matches that checkpoint.
ltt_model = LTT_model(steps=steps)
state_dict = torch.load(trained_on_optimal_without_dropout, weights_only=True)
ltt_model.load_state_dict(state_dict)  # Load the model state

In [None]:
def count_parameters(model):
    """Return the total number of scalar entries in the model's trainable parameters.

    Parameters whose ``requires_grad`` flag is False (frozen weights) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(count_parameters(ltt_model))

In [None]:
# Predict the schedule parameters for every collected dataset element.
# Inference only: switch the model to evaluation mode so any dropout/normalisation
# layers behave deterministically, and skip autograd bookkeeping.
# NOTE(review): eval() changes the model state for all later cells as well — confirm
# that no later cell relies on training-mode behaviour.
ltt_model.eval()
with torch.no_grad():
    predicted_params = ltt_model(torch.stack(second_outputs))

# Visualize the distribution of each of the steps + 1 predicted parameters as a violin plot.
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

params_list = predicted_params.cpu().numpy().reshape(-1, steps+1)
params_list = pd.DataFrame(params_list, columns=[f"{i}" for i in range(steps+1)])
sns.violinplot(data=params_list)
plt.show()


In [None]:
params_list


### U-net encoding

In [None]:
# Create a dictionary to store the output
hook_storage = {}

# Define the hook function using a closure
def get_hook(storage):
    """Build a forward hook that records a module's output in ``storage``.

    The returned callable has the ``(module, input, output)`` forward-hook signature,
    writes the output under the key ``"unet_output"`` (overwriting any earlier value)
    and returns None.
    """
    def _record_output(module, input, output):
        storage.update({"unet_output": output})
    return _record_output
# Register the hook on the UNet
hook_handle = ltt_model.unet.register_forward_hook(get_hook(hook_storage))

# Run the forward pass
output = ltt_model.forward(latents[0].unsqueeze(0))

# Retrieve the stored UNet output
unet_output = hook_storage["unet_output"]
print("Stored UNet Output:", unet_output.shape)




In [None]:
def mse(tensor1, tensor2):
    """Mean squared error between two tensors, returned as a 0-dim tensor."""
    difference = tensor1 - tensor2
    return torch.mean(difference.pow(2))

In [None]:
# Collect the hooked UNet output for the first `num_matrices` latents.
num_matrices = 20
encodings = []
for latent in latents[:num_matrices]:
    # The forward pass fires the registered hook, which refreshes hook_storage.
    output = ltt_model.forward(latent.unsqueeze(0))
    encodings.append(hook_storage["unet_output"])


# Fill the pairwise MSE matrix between all encodings.
mse_matrix = np.zeros((num_matrices, num_matrices))
for row, col in np.ndindex(num_matrices, num_matrices):
    mse_matrix[row, col] = mse(encodings[row], encodings[col])

# Plot the heatmap
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(mse_matrix, annot=False, cmap="viridis", linewidths=0.5, ax=ax)
ax.set_title("Pairwise MSE Heatmap")
ax.set_xlabel("Matrix Index")
ax.set_ylabel("Matrix Index")
plt.show()



In [None]:
encodings[1]

In [None]:
encodings[2]

## Optimal Timesteps Per Image

In [None]:
# The loss matrix was written with
# torch.save(loss_matrix, os.path.join(args.data_dir, "loss_matrix.pt")).
optimal_dir = os.path.join(path, "OldOptimSteps")
# Load it here so the following cells do not depend on leftover kernel state.
# NOTE(review): assumes `path` points at the same directory that was used as
# args.data_dir when the file was saved — confirm the file still lives there.
loss_matrix = torch.load(os.path.join(path, "loss_matrix.pt"), weights_only=True)

In [None]:
loss_matrix

In [None]:
#plot loss matrix

plt.figure(figsize=(10, 8))
sns.heatmap(loss_matrix, annot=False, cmap="viridis", linewidths=0.5)
plt.title("Pairwise MSE Heatmap")
plt.xlabel("Matrix Index")
plt.ylabel("Matrix Index")
plt.show()

In [None]:
#find lowest loss matrix
min_loss = torch.min(loss_matrix, axis=1)
min_loss

In [None]:
plt.figure(figsize=(10, 8))
for i in range(2):
    min_values = [torch.min(loss_matrix[i, :j+1,]).item() for j in range(300)]
    plt.plot(range(300), min_values, label=f'Line {i+1}')
plt.title("Lowest Value of Second Dimension Up to That Point")
plt.xlabel("Steps")
plt.ylabel("Lowest Value")
plt.legend()
plt.show()

In [None]:
# `data_dir` is only defined much further down in the notebook, so build the
# location from `path` to keep this cell runnable on a fresh kernel.
# NOTE(review): assumes the gradient matrix sits directly inside `path` — confirm.
gradient_matrix = torch.load(os.path.join(path, "loss_grad_matrix.pt"), weights_only=True)
gradient_matrix.shape

In [None]:
plt.figure(figsize=(10, 8))
sns.heatmap(gradient_matrix[2], annot=False, cmap="viridis", linewidths=0.5)
plt.title("Pairwise MSE Heatmap")
plt.xlabel("Matrix Index")
plt.ylabel("Matrix Index")
plt.show()


In [None]:
plt.figure(figsize=(10, 8))
for i in range(2):
    plt.plot(range(300), abs(gradient_matrix[1, :, i]), label=f'Line {i+1}')
plt.yscale('log')
plt.title("Gradient Matrix Lines")
plt.xlabel("Steps")
plt.ylabel("Gradient Value (log scale)")
plt.legend()
plt.show()


In [None]:
# Sanity check of softmax behaviour: scaling the input changes the result,
# while adding a constant to every entry leaves it unchanged.
test_tensor = torch.tensor([1,2,3,4,5,6], dtype=torch.float32)
test_tensor2 = torch.tensor([1,2,3,4,5,6], dtype=torch.float32)*2
test_tensor3 = torch.tensor([1,2,3,4,5,6], dtype=torch.float32)+1

# Give the dimension explicitly: the implicit choice is deprecated and warns.
# For these 1-D inputs dim=0 is the dimension that was chosen implicitly.
m = nn.Softmax(dim=0)
print(m(test_tensor))
print(m(test_tensor2))
print(m(test_tensor3))

### Compare optimal params for different images

In [None]:
# One figure per optimal-parameter file: standardise that file's parameter sets,
# project them onto two principal components and colour the points by loss.
pt_files = [f for f in os.listdir(path) if "optimal_params" in f]
# NOTE(review): seven files are processed here (indices 0-6), whereas the later cells
# use the first six — confirm which was intended.
for i, file_name in enumerate(sorted(pt_files)[:7]):
    params, loss = torch.load(os.path.join(path, file_name), weights_only=True)
    data = params.detach().numpy()
    losses = loss.detach().numpy()

    # Standardize the data, then reduce it to 2 dimensions with PCA.
    data_scaled = StandardScaler().fit_transform(data)
    principal_components = PCA(n_components=2).fit_transform(data_scaled)

    # Plot the results
    fig, ax = plt.subplots(figsize=(10, 6))
    scatter = ax.scatter(principal_components[:, 0], principal_components[:, 1],
                         c=losses, cmap='viridis', edgecolor='k', s=100)
    fig.colorbar(scatter, ax=ax, label='Loss')
    ax.set_xlabel('Principal Component 1')
    ax.set_ylabel('Principal Component 2')
    ax.set_title('PCA Visualization Colored by Loss')
    ax.grid(True)
    plt.show()
    print("params\n", data)
    print("losses\n", losses)

In [None]:
# Joint 2D PCA over the parameter sets of the first six optimal-parameter files.
pt_files = [f for f in os.listdir(path) if "optimal_params" in f]
pt_files = sorted(pt_files)[:6]  # Load first 6 training files

shapes = ['circle', 'square', 'diamond', 'triangle-up', 'triangle-down', 'cross']  # Different marker shapes

all_data = []
best_losses = []



for i, file_name in enumerate(pt_files):
    file_path = os.path.join(path, file_name)
    # Load parameters and loss
    params, loss = torch.load(file_path, weights_only=True)  
    data = params.detach().numpy()
    losses = loss.detach().numpy()
    all_data.append(data)
    best_losses.append(losses)





# Flatten to one row per parameter set (steps + 1 values each) and one loss per row.
all_data = np.stack(all_data).reshape(-1, steps+1)
best_losses = np.stack(best_losses).reshape(-1)

pca = PCA(n_components=2)
principal_components = pca.fit_transform(all_data)
    
# Convert to DataFrame
df = pd.DataFrame(principal_components, columns=["PC1", "PC2"])
df["Loss"] = best_losses
# One marker shape per source file.
# NOTE(review): the hard-coded 10 assumes ten parameter sets per file and exactly six files — confirm.
df["Shape"] = [x for x in shapes for _ in range(10)]
df["params"] = list(all_data)
df["params"] = df["params"].apply(lambda x: [round(v, 4) for v in x])

# Create interactive 2D scatter plot using Plotly
fig = px.scatter(df, x="PC1", y="PC2", color="Loss", symbol="Shape",
                    hover_data={"Loss": ":.4f", "PC1": False, "PC2": False, "params": True},
                    color_continuous_scale="viridis",
                    title="2D PCA Visualization Colored by Loss with Different Shapes")

# Move color bar to the left
fig.update_layout(coloraxis_colorbar=dict(x=-0.2))  

fig.update_traces(marker=dict(size=5, line=dict(width=1, color="black")))

# Show interactive 2D plot
fig.show() 

In [None]:
pt_files = [f for f in os.listdir(path) if "optimal_params" in f]
pt_files = sorted(pt_files)[:6]  # Load first 6 training files

shapes = ['circle', 'square', 'diamond', 'triangle-up', 'triangle-down', 'cross']  # Different marker shapes

all_data = []
best_losses = []



for i, file_name in enumerate(pt_files):
    file_path = os.path.join(path, file_name)
    # Load parameters and loss
    params, loss = torch.load(file_path, weights_only=True)  
    data = params.detach().numpy()
    losses = loss.detach().numpy()
    all_data.append(data)
    best_losses.append(losses)


all_data = np.stack(all_data).reshape(-1, steps+1)
best_losses = np.stack(best_losses).reshape(-1)

pca = PCA(n_components=3)
principal_components = pca.fit_transform(all_data)
    
# Convert to DataFrame
df = pd.DataFrame(principal_components, columns=["PC1", "PC2", "PC3"])
df["Loss"] = best_losses
df["Shape"] = [x for x in shapes for _ in range(10)]
df["params"] = list(all_data)
df["params"] = df["params"].apply(lambda x: [round(v, 4) for v in x])

# # Create interactive 3D scatter plot using Plotly
fig = px.scatter_3d(df, x="PC1", y="PC2", z="PC3", color="Loss", symbol="Shape",
                    hover_data={"Loss": ":.4f", "PC1": False, "PC2": False, "PC3": False, "params": True},
                    color_continuous_scale="viridis",
                    title="3D PCA Visualization Colored by Loss with Different Shapes")

# Move color bar to the left
fig.update_layout(coloraxis_colorbar=dict(x=-0.2))  

fig.update_traces(marker=dict(size=5, line=dict(width=1, color="black")))

# Show interactive 3D plot
fig.show() 

In [None]:
plt.figure(figsize=(10, 8))
sns.histplot(best_losses, bins=50, kde=True)
plt.title("Loss Distribution")
plt.xlabel("Loss")
plt.ylabel("Frequency")
plt.show()


#### Only best trial

In [None]:
# 3D PCA over the first parameter set of every file in `optimal_dir`.
pt_files = [f for f in os.listdir(optimal_dir) if "optimal_params" in f]
pt_files = sorted(pt_files) # Load all files, in sorted order

all_data = []
best_losses = []
all_losses = []

for i, file_name in enumerate(pt_files):
    file_path = os.path.join(optimal_dir, file_name)
    # Load parameters and loss
    params, loss = torch.load(file_path, weights_only=True)  
    # Keep only entry 0 of each file.
    # NOTE(review): this treats entry 0 as the best trial — confirm the files are sorted by loss.
    data = params.detach().numpy()[0]
    losses = loss.detach().numpy()
    all_data.append(data)
    best_losses.append(losses[0])
    all_losses.append(losses)






all_data = np.stack(all_data).reshape(-1, steps+1)
best_losses = np.stack(best_losses).reshape(-1)

pca = PCA(n_components=3)
principal_components = pca.fit_transform(all_data)
    
# Convert to DataFrame
df = pd.DataFrame(principal_components, columns=["PC1", "PC2", "PC3"])
df["Loss"] = best_losses
df["params"] = list(all_data)
df["params"] = df["params"].apply(lambda x: [round(v, 4) for v in x])

# Create interactive 3D scatter plot using Plotly
# NOTE(review): the title still mentions "Different Shapes" although no symbol column is used here.
fig = px.scatter_3d(df, x="PC1", y="PC2", z="PC3", color="Loss",
                    hover_data={"Loss": ":.4f", "PC1": False, "PC2": False, "PC3": False, "params": True},
                    color_continuous_scale="viridis",
                    title="3D PCA Visualization Colored by Loss with Different Shapes")

# Move color bar to the left
fig.update_layout(coloraxis_colorbar=dict(x=-0.2))  

fig.update_traces(marker=dict(size=5, line=dict(width=1, color="black")))

# Show interactive 3D plot
fig.show()

In [None]:
# Plot the loss distribution of best_losses (one loss per file, taken from entry 0).

plt.figure(figsize=(10, 8))
sns.histplot(best_losses, bins=50, kde=True)
plt.title("Loss Distribution")
plt.xlabel("Loss")
plt.ylabel("Frequency")
plt.show()



In [None]:
import plotly.graph_objects as go

# Requires all_losses and best_losses from the "Only best trial" cell above.
# Flatten the per-file loss arrays into a single 1-D array.
all_losses = np.stack(all_losses).reshape(-1)

fig = go.Figure()

# Add histogram for all_losses
fig.add_trace(go.Histogram(x=all_losses, nbinsx=75, histnorm='probability density', 
                           name="All Losses", marker_color='blue', opacity=0.5))

# Add histogram for best_losses
fig.add_trace(go.Histogram(x=best_losses, nbinsx=75, histnorm='probability density', 
                           name="Best Losses", marker_color='red', opacity=0.5))

# Update layout
fig.update_layout(
    title="Loss Distribution",
    xaxis_title="Loss",
    yaxis_title="Density",
    barmode='overlay',  # Overlay both histograms
    template="plotly_white"  # Optional: use a clean background
)

fig.show()
# NOTE(review): the first 500 flattened losses are reported as "first 50" — presumably
# 50 files with 10 losses each; confirm the count per file.
print(f"Mean lost of first 50 validation losses: {np.mean(all_losses[:500]):.4f}")
print(f"Mean lost of 50 validation best losses: {np.mean(best_losses[:50]):.4f}")

### New generated ones

In [None]:
# Inspect one freshly generated optimal-parameter file.
# weights_only=True matches how the other optimal_params files are loaded in this
# notebook and avoids unpickling arbitrary objects.
torch.load("/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/optimal_params_000003_N10_steps5.pth", weights_only=True)

#### With 30 iterations and only 1 trial

In [None]:
# Mean loss stored alongside the optimal timesteps of the training split.
opt_t_dir = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/train/opt_t"
opt_t_files = sorted(name for name in os.listdir(opt_t_dir) if name.endswith('.pth'))

loss_np = np.zeros(len(opt_t_files))
for idx, file_name in enumerate(opt_t_files):
    # Element [1] of each saved object is read as the loss
    # (presumably the files hold a (params, loss) pair — confirm).
    loss_np[idx] = torch.load(os.path.join(opt_t_dir, file_name), weights_only=True)[1]


print(f"Mean loss of optimal train timesteps: {np.mean(loss_np):.4f}")



In [None]:
# Mean loss stored alongside the optimal timesteps of the validation split.
opt_t_dir = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/validation/opt_t"
opt_t_files = sorted(name for name in os.listdir(opt_t_dir) if name.endswith('.pth'))

loss_np = np.zeros(len(opt_t_files))
for idx, file_name in enumerate(opt_t_files):
    # Element [1] of each saved object is read as the loss
    # (presumably the files hold a (params, loss) pair — confirm).
    loss_np[idx] = torch.load(os.path.join(opt_t_dir, file_name), weights_only=True)[1]


print(f"Mean loss of optimal validation timesteps: {np.mean(loss_np):.4f}")

#### With 50 iterations and 3 trials

In [None]:
# Mean loss of the training split for the "clever initialisation" runs.
opt_t_dir = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/train/opt_t_clever_initialisation"
opt_t_files = sorted(name for name in os.listdir(opt_t_dir) if name.endswith('.pth'))

loss_np = np.zeros(len(opt_t_files))
for idx, file_name in enumerate(opt_t_files):
    # Element [1] is itself indexable here; only its entry 0 is used
    # (presumably the loss of the first stored trial — confirm).
    loss_np[idx] = torch.load(os.path.join(opt_t_dir, file_name), weights_only=True)[1][0]


print(f"Mean loss of optimal train timesteps: {np.mean(loss_np):.4f}")



In [None]:
# Mean loss of the validation split for the "clever initialisation" runs.
opt_t_dir = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/validation/opt_t_clever_initialisation"
opt_t_files = sorted(name for name in os.listdir(opt_t_dir) if name.endswith('.pth'))

loss_np = np.zeros(len(opt_t_files))
for idx, file_name in enumerate(opt_t_files):
    # Element [1] is itself indexable here; only its entry 0 is used
    # (presumably the loss of the first stored trial — confirm).
    loss_np[idx] = torch.load(os.path.join(opt_t_dir, file_name), weights_only=True)[1][0]


print(f"Mean loss of optimal validation timesteps: {np.mean(loss_np):.4f}")

## Generating latent -> image pairs efficiently en masse

In [None]:
from PIL import Image
import torchvision.transforms as transforms
import torch
import numpy as np

# Load the PNG image
image_path = '/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/train/img/000000.png'
image = Image.open(image_path).convert("RGB")

# Define a transform to convert the image to a tensor
transform_to_tensor = transforms.ToTensor()

# Apply the transform to the image
tensor_from_img = transform_to_tensor(image)

# Verify the tensor shape and type
print(f"Tensor shape: {tensor_from_img.shape}")  # Should print: torch.Size([C, H, W])
print(f"Tensor dtype: {tensor_from_img.dtype}")  # Should print: torch.float32



In [None]:
def png_to_tensor(image):
    """Convert an already-loaded (H, W, C) image into a (C, H, W) float32 tensor in [-1, 1].

    ``image`` is anything ``np.array`` can convert (for example a PIL image);
    pixel values are expected in the 0-255 range.
    """
    # Scale pixel values from 0-255 to 0-1.
    unit_hwc = np.array(image, dtype=np.float32) / 255.0

    # Channels first: (H, W, C) -> (C, H, W).
    unit_chw = torch.tensor(unit_hwc).permute(2, 0, 1)

    # Map [0, 1] onto [-1, 1].
    return (unit_chw * 2.0) - 1.0

In [None]:
tensor_original = torch.load("/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/test_images/img_000000.pt")
print(f"Tensor shape: {tensor_original.shape}")  # Should print: torch.Size([C, H, W])
print(f"Tensor dtype: {tensor_original.dtype}")  # Should print: torch.float32


to_pil = transforms.ToPILImage()
image_pil = to_pil((tensor_original +1 ) / 2)
# image_pil = to_pil(torch.clip((tensor_original +1 ) / 2, 0,1))
plt.imshow(image_pil)


In [None]:
img_tensor = transform_to_tensor(image_pil)
img_tensor = (img_tensor - 0.5) * 2



In [None]:
torch.sum(img_tensor - tensor_original)

In [None]:
from PIL import Image
import numpy as np

# Convert the tensor to a numpy array and transpose the dimensions to match the image format
tensor_original_np = tensor_original.permute(1, 2, 0).numpy()

# Convert the tensor values from the range [-1, 1] to [0, 255].
# Clip before the uint8 cast: any value outside [-1, 1] would otherwise wrap around
# instead of saturating (other cells in this notebook clip tensor_original for the same reason).
tensor_original_np = np.clip((tensor_original_np + 1) * 127.5, 0, 255).astype(np.uint8)

# Create a PIL image from the numpy array
image_pil = Image.fromarray(tensor_original_np)

# Save the image
image_pil.save("/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/test_images/image.png")

In [None]:
torch.sum(png_to_tensor(image) - torch.clip(tensor_original, -1,1))

In [None]:
image = np.transpose(tensor_original, (1, 2, 0))
convert =  lambda x: (x + 1.0) / 2.0
image = convert(image)


# Display the image
plt.imshow(image)
plt.axis("off")  # Hide axes
plt.savefig("/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/test_images/original_image.png")
plt.show()

In [None]:

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# Create a random tensor [3, 32, 32] with values in range [-1, 1]
original_tensor = torch.rand(3, 32, 32) * 2 - 1  # Scale to [-1, 1]

# Convert tensor to image
def tensor_to_image(tensor):
    """Turn a (C, H, W) tensor with values in [-1, 1] into an RGB PIL image.

    Values are rescaled to [0, 1], clamped to that range, multiplied by 255 and
    truncated to uint8.
    """
    unit = torch.clamp((tensor + 1.0) / 2.0, 0, 1)
    # Channels last for PIL: (C, H, W) -> (H, W, C).
    hwc = unit.permute(1, 2, 0).cpu().numpy()
    pixels = (hwc * 255).astype(np.uint8)
    return Image.fromarray(pixels, "RGB")

# Convert image back to tensor
def image_to_tensor(image):
    """Turn an (H, W, C) image with 0-255 values into a (C, H, W) float32 tensor in [-1, 1]."""
    unit_hwc = np.array(image).astype(np.float32) / 255.0
    unit_chw = torch.tensor(unit_hwc).permute(2, 0, 1)
    # Map [0, 1] onto [-1, 1].
    return (unit_chw * 2.0) - 1.0

# Convert tensor to image and display it
image = tensor_to_image(original_tensor)
plt.imshow(image)
plt.axis("off")
plt.show()

# Convert image back to tensor
recovered_tensor = image_to_tensor(image)

# Calculate difference
difference = torch.abs(original_tensor - recovered_tensor)
error = torch.mean(difference).item()  # Mean absolute error

print(f"Mean Absolute Error between original and recovered tensor: {error}")


In [None]:
def visualize(tensor: torch.Tensor) -> None:
    """Display a single (C, H, W) image tensor with values in [-1, 1].

    Values are rescaled to [0, 255], clipped and truncated to uint8 before being
    shown with matplotlib. The image is indexed directly rather than reshaped to a
    fixed 32x32x3 layout, so inputs of other spatial sizes are shown whole.

    Args:
        tensor: image tensor laid out as (C, H, W).
    """
    # Map [-1, 1] to [0, 1]; detach so tensors that track gradients can be converted.
    unit = (tensor.detach() + 1.0) / 2.0
    # Channels last for matplotlib: (C, H, W) -> (H, W, C).
    image_np = np.clip(
        unit.permute(1, 2, 0).cpu().numpy() * 255.0, 0, 255
    ).astype(np.uint8)

    plt.imshow(image_np)
    plt.title('Generated Image')
    plt.axis('off')
    plt.show()

# PIL.Image.fromarray(image_np, "RGB").save(image_path)

In [None]:
tensor_original_np = tensor_original.detach().cpu().numpy()
tensor_from_img_np = tensor_from_img.detach().cpu().numpy()
undo_convert = lambda x: (x * 2.0) - 1.0
tensor_from_img_np = undo_convert(tensor_from_img_np)
tensor_original_np = np.clip(tensor_original_np, 0, 1)
tensor_from_img_np = np.clip(tensor_from_img_np, 0, 1)
np.sum(tensor_original_np - tensor_from_img_np)


In [None]:
# Plot the original tensor image and the tensor from the loaded image side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Convert tensors to numpy arrays for plotting
tensor_original_np = tensor_original.permute(1, 2, 0).numpy()
tensor_from_img_np = tensor_from_img.permute(1, 2, 0).numpy()
undo_convert = lambda x: (x * 2.0) - 1.0
tensor_from_img_np = undo_convert(tensor_from_img_np)

# Plot the original tensor image
axes[0].imshow(tensor_original_np)
axes[0].set_title('Original Tensor Image')
axes[0].axis('off')

# Plot the tensor from the loaded image
axes[1].imshow(tensor_from_img_np)
axes[1].set_title('Tensor from Loaded Image')
axes[1].axis('off')

plt.show()

In [None]:
visualize(tensor_original)

In [None]:
from dataset import LTTDataset
path = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/validation"
dataset = LTTDataset(path)


def _plot_image_latent_pair(index, img, latent):
    """Draw the image and its latent side by side (the caller decides when to show the figure)."""
    # Convert the tensors to numpy arrays laid out as (H, W, C) for matplotlib.
    img_np = img.permute(1, 2, 0).numpy()
    latent_np = latent.permute(1, 2, 0).numpy()

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, pixels, label in zip(axes, (img_np, latent_np), ("Image", "Latent")):
        ax.imshow(pixels)
        ax.set_title(f'{label} {index}')
        ax.axis('off')


# First read of the sample: plot it and print the stored optimal timesteps.
for i in range(1):
    img, latent, opt_t = dataset[i]
    _plot_image_latent_pair(i, img, latent)

    print(opt_t)

    plt.show()

# Reading the same index again should give exactly the same output,
# which shows that the data generator is working.
for i in range(1):
    img, latent, _ = dataset[i]
    _plot_image_latent_pair(i, img, latent)

    plt.show()


In [None]:
from dataset import LTTDataset
path = "/netpool/homes/connor/DiffusionModels/LD3_connor/train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0/train/img"
dataset = LTTDataset(path)

for i in range(1):
    img, latent = dataset[i]
    # Convert the tensor to a numpy array and transpose the dimensions to match the image format
    img_np = img.permute(1, 2, 0).numpy()
    latent_np = latent.permute(1, 2, 0).numpy()

    # Plot the image and latent side by side
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(img_np)
    axes[0].set_title(f'Image {i}')
    axes[0].axis('off')

    axes[1].imshow(latent_np)
    axes[1].set_title(f'Latent {i}')
    axes[1].axis('off')

    plt.show()

# the second output should be exactly the same proving that the data generator working correctly
for i in range(1):
    img, latent = dataset[i]
    # Convert the tensor to a numpy array and transpose the dimensions to match the image format
    img_np = img.permute(1, 2, 0).numpy()
    latent_np = latent.permute(1, 2, 0).numpy()

    # Plot the image and latent side by side
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(img_np)
    axes[0].set_title(f'Image {i}')
    axes[0].axis('off')

    axes[1].imshow(latent_np)
    axes[1].set_title(f'Latent {i}')
    axes[1].axis('off')

    plt.show()


## Delta Timestep Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # Add this import
import lpips
from trainer import LD3Trainer, ModelConfig, TrainingConfig, DiscretizeModelWrapper
from utils import get_solvers, move_tensor_to_device, parse_arguments, set_seed_everything

from dataset import load_data_from_dir, LTTDataset
from latent_to_timestep_model import  Delta_LTT_model
from models import prepare_stuff
import torch.optim.lr_scheduler as lr_scheduler
from utils import visual


args = parse_arguments([
    "--all_config", "configs/cifar10.yml",
    "--data_dir", "train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0",
    "--num_train", "1000",
    "--num_valid", "1000",
    "--main_train_batch_size", "200",
    "--main_valid_batch_size", "200",
    "--training_rounds_v1", "1",
    "--log_path", "logs/logs_cifar10",
    "--force_train", "True",
    "--steps", "5",
    "--lr_time_1", "0.00005",
    "--mlp_dropout", "0.0",
    "--log_suffix", "BiggerValidation_GroupNorm_EvalTrue"
])

set_seed_everything(args.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Dataset
data_dir = 'train_data/train_data_cifar10/uni_pc_NFE20_edm_seed0'
model_dir = "runs_delta_timesteps/models"
steps = 5
lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)


wrapped_model, _, decoding_fn, noise_schedule, latent_resolution, latent_channel, _, _ = prepare_stuff(args)
solver, steps, solver_extra_params = get_solvers(
    args.solver_name,
    NFEs=args.steps,
    order=args.order,
    noise_schedule=noise_schedule,
    unipc_variant=args.unipc_variant,
)

order = args.order  

def custom_collate_fn(batch):
    """Collate a batch of tuple-like samples field by field, tolerating missing fields.

    Args:
        batch: Sequence of samples, each a tuple/list with the same number of fields.

    Returns:
        list: One entry per field position. The entry is ``None`` when at least one
        sample holds ``None`` at that position; otherwise it is the result of
        PyTorch's default collation applied to that field across all samples.
        An empty batch yields an empty list.
    """
    # Use the public torch.utils.data.default_collate instead of the private
    # torch.utils.data._utils.collate module, which is not a stable API.
    return [
        None if any(item is None for item in samples)
        else torch.utils.data.default_collate(samples)
        for samples in zip(*batch)
    ]

# Validation and training splits read from the "validation" / "train" sub-folders of data_dir.
valid_dataset = LTTDataset(dir=os.path.join(data_dir, "validation"), size=args.num_valid, train_flag=False, use_optimal_params=False)
train_dataset = LTTDataset(dir=os.path.join(data_dir, "train"), size=args.num_train, train_flag=True, use_optimal_params=False)

# Model under evaluation; its weights are loaded from a checkpoint in a later cell.
delta_ltt_model = Delta_LTT_model(steps=steps, mlp_dropout=args.mlp_dropout)
delta_ltt_model = delta_ltt_model.to(device)

# wrapped_model, decoding_fn, noise_schedule, latent_resolution, latent_channel,
# solver, steps and solver_extra_params are reused from the prepare_stuff(args) /
# get_solvers(...) calls earlier in this cell rather than being built a second
# time with identical arguments (which would load the network twice).

training_config = TrainingConfig(
    train_data=train_dataset,
    valid_data=valid_dataset,
    train_batch_size=args.main_train_batch_size,
    valid_batch_size=args.main_valid_batch_size,
    lr_time_1=args.lr_time_1,
    shift_lr=args.shift_lr,
    shift_lr_decay=args.shift_lr_decay,
    min_lr_time_1=args.min_lr_time_1,
    win_rate=args.win_rate,
    patient=args.patient,
    lr_time_decay=args.lr_time_decay,
    momentum_time_1=args.momentum_time_1,
    weight_decay_time_1=args.weight_decay_time_1,
    loss_type=args.loss_type,
    visualize=args.visualize,
    no_v1=args.no_v1,
    prior_timesteps=args.gits_ts,
    match_prior=args.match_prior,
)
model_config = ModelConfig(
    net=wrapped_model,
    decoding_fn=decoding_fn,
    noise_schedule=noise_schedule,
    solver=solver,
    solver_name=args.solver_name,
    order=args.order,
    steps=steps,
    prior_bound=args.prior_bound,
    resolution=latent_resolution,
    channels=latent_channel,
    time_mode=args.time_mode,
    solver_extra_params=solver_extra_params,
    device=device,
)
# The trainer is used below only for its loaders, solver, network and loss function.
trainer = LD3Trainer(model_config, training_config)


dis_model = DiscretizeModelWrapper( #Changed through LTT
        lambda_max=trainer.lambda_max,
        lambda_min=trainer.lambda_min,
        noise_schedule=trainer.noise_schedule,
        time_mode = trainer.time_mode,
    )


# Keep the first validation sample around as a single-image / single-latent example.
img, latent, _ = valid_dataset[0]
latent = latent.to(device)

In [None]:
# Checkpoint file names available under model_dir; one of them is selected below.
group_norm_model = "model_lr5e-05_batch3_nTrain500000_BiggerValidation_GroupNorm_EvalTrue"
rerun_alpha_3 = "model_lr5e-05_batch3_nTrain500000_RerunAlpha"
rerun_alpha_30 = "model_lr5e-05_batch30_nTrain500000_RerunAlpha"

# Read the selected checkpoint as a plain state dict (weights only), mapped onto `device`.
checkpoint_path = os.path.join(model_dir, rerun_alpha_30)
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)

# Copy the stored parameters into the already constructed delta_ltt_model.
delta_ltt_model.load_state_dict(state_dict)

print("Parameters successfully loaded into delta_ltt_model.")

### Training Data


In [None]:
# Evaluate the loaded Delta-LTT model on the training split. Inference only, so
# the model is put in eval mode and gradient tracking is disabled.
delta_ltt_model.eval()
with torch.no_grad():
    for i, batch in enumerate(trainer.train_loader):
        # Batch-specific names keep the notebook-level single-sample `img` /
        # `latent` (used by later cells) from being overwritten with whole batches.
        img_batch, latent_batch, _ = batch
        latent_batch = latent_batch.to(device)
        img_batch = img_batch.to(device)

        x_next_list = trainer.noise_schedule.prior_transformation(latent_batch) #Multiply with timestep in edm case (x80 in beginning)
        x_next_computed = []
        for x_single in x_next_list:
            # Each sample is solved on its own; unsqueeze restores the batch dimension.
            x_next, _, _ = trainer.solver.delta_sample_simple(
                model_fn=trainer.net,
                delta_ltt=delta_ltt_model,
                x=x_single.unsqueeze(0),
                order=trainer.order,
                steps=trainer.steps,
                start_timestep=80,
                NFEs=trainer.steps,
                condition=None,
                unconditional_condition=None,
                **trainer.solver_extra_params,
            )
            # Collect the final sample of each trajectory.
            x_next_computed.append(x_next)

        x_next_computed = torch.cat(x_next_computed, dim=0)
        loss_vector = trainer.loss_fn(img_batch.float(), x_next_computed.float()).squeeze()
        loss = loss_vector.mean()
        # This loop runs over the training loader, so label the value as a train loss.
        print(f"Train loss on iter{i}: {loss.item()}")

### Validation Data

In [None]:

# Per-sample timestep schedules (one row per validation sample) and per-batch losses.
all_timesteps = np.zeros((args.num_valid, steps+1))
# Losses are gathered in a list so a final partial batch (num_valid not divisible
# by the batch size) cannot index past a preallocated array.
batch_losses = []
sample_idx = 0  # running row index into all_timesteps

delta_ltt_model.eval()
with torch.no_grad():
    for i, batch in enumerate(trainer.valid_only_loader):
        # Batch-specific names keep the notebook-level single-sample `img` /
        # `latent` (used by later cells) from being overwritten with whole batches.
        img_batch, latent_batch, _ = batch
        latent_batch = latent_batch.to(device)
        img_batch = img_batch.to(device)

        x_next_list = trainer.noise_schedule.prior_transformation(latent_batch) #Multiply with timestep in edm case (x80 in beginning)
        x_next_computed = []
        for x_single in x_next_list:
            # Each sample is solved on its own; unsqueeze restores the batch dimension.
            x_next, _, t_list = trainer.solver.delta_sample_simple(
                model_fn=trainer.net,
                delta_ltt=delta_ltt_model,
                x=x_single.unsqueeze(0),
                order=trainer.order,
                steps=trainer.steps,
                start_timestep=80,
                NFEs=trainer.steps,
                condition=None,
                unconditional_condition=None,
                **trainer.solver_extra_params,
            )
            x_next_computed.append(x_next)
            # Record the timesteps chosen for this sample.
            all_timesteps[sample_idx] = t_list
            sample_idx += 1

        x_next_computed = torch.cat(x_next_computed, dim=0)
        loss_vector = trainer.loss_fn(img_batch.float(), x_next_computed.float()).squeeze()
        loss = loss_vector.mean()
        batch_losses.append(loss.item())
        print(f"Validated on iter{i}: {loss.item()}")

# Array form for downstream cells (e.g. np.mean over the batch losses).
all_losses = np.array(batch_losses)

In [None]:
# Mean of the per-batch validation losses collected in all_losses.
mean_validation_loss = np.mean(all_losses)
print(f"Average Loss: {mean_validation_loss}")

### Timestep Distribution

In [None]:
# One histogram per column of all_timesteps, arranged on a two-row grid.
num_columns = all_timesteps.shape[1]
plt.figure(figsize=(15, 10))
panels_per_row = (num_columns + 1) // 2

for i in range(num_columns):
    panel = plt.subplot(2, panels_per_row, i + 1)
    panel.hist(all_timesteps[:, i], bins=30, alpha=0.7, color='blue', edgecolor='black')
    panel.set_title(f"Timestep {i}")
    panel.set_xlabel("Value")
    panel.set_ylabel("Frequency")

plt.tight_layout()
plt.show()

### Influence of Prior Timestep and Steps Left

In [None]:
# Probe the model output for one fixed latent while sweeping the timestep
# argument over 0..79; the third argument is held at 5 (presumably the number
# of steps left — confirm against Delta_LTT_model.forward).
# The latent is fetched explicitly so this cell does not depend on whatever
# value `latent` was last left with by other cells.
probe_latent = valid_dataset[0][1].to(device).unsqueeze(0)
steps_left = torch.tensor(5, device=device)

all_ratios = np.zeros(80)
delta_ltt_model.eval()
with torch.no_grad():  # inference only; avoids building an autograd graph per call
    for t in range(80):
        delta_timestep_ratio = delta_ltt_model(probe_latent, torch.tensor(t, device=device), steps_left)
        all_ratios[t] = delta_timestep_ratio.item()


In [None]:
# Scatter each entry of all_ratios against its position in the array.
ratio_axes = plt.gca()
ratio_axes.scatter(range(len(all_ratios)), all_ratios, color='blue', alpha=0.7)
ratio_axes.set_title("All Ratios at differnt timestep with same latent")
ratio_axes.set_xlabel("Index")
ratio_axes.set_ylabel("Ratio")
ratio_axes.grid(True, linestyle='--', alpha=0.5)
plt.show()

### Runtime of Diffusion Model vs DLTT 

In [None]:
# Re-fetch the first validation sample and keep only its second field (the latent) on the device.
first_valid_sample = valid_dataset[0]
latent = first_valid_sample[1].to(device)

In [None]:
import pickle
from noise_schedulers import NoiseScheduleVE
from models.edm_uncond import model_wrapper
# Unpickle the pretrained checkpoint and keep its "ema" entry on `device`.
# NOTE(review): pickle.load can execute arbitrary code — only load checkpoint
# files that come from a trusted source.
with open("pretrained/edm-cifar10-32x32-uncond-vp.pkl", "rb") as checkpoint_file:
    net = pickle.load(checkpoint_file)["ema"].to(device)
noise_schedule = NoiseScheduleVE(schedule='edm')

# The network is only evaluated here, so switch off gradients for every parameter.
for param in net.parameters():
    param.requires_grad = False

model_fn = model_wrapper(net, noise_schedule)

# Single-sample batch built from `latent`, evaluated at t = 80; the call is the
# cell's last expression, so its result is displayed.
x = latent.unsqueeze(0)
t = torch.tensor(80, device=device)
batch_t = t.expand(x.shape[0])
model_fn(x, batch_t)

In [None]:
# Call the unwrapped network `net` directly (without model_wrapper) on the
# notebook-level `x` and `t`; as the cell's last expression the result is displayed.
net(x, t.expand((x.shape[0])))

In [None]:
# Stack 1000 copies of the latent and time one batched forward pass of the diffusion model.
import time

all_latents = torch.stack([latent for _ in range(1000)])
all_latents = all_latents.to(device)
t = torch.tensor(80, device = device)

# CUDA kernels are launched asynchronously, so the GPU must be synchronized
# before reading the clock on either side of the call; otherwise only the
# launch overhead is measured.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():  # pure inference timing
    model_fn(all_latents, t.expand((all_latents.shape[0])))
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"Time taken for 1000 latents: {end-start}")


In [None]:
# Average forward time of delta_ltt_model on the stacked latents over several runs.
num_timing_runs = 100
t = torch.tensor(80, device = device)
# Built once outside the timed loop so tensor construction is not part of the measurement.
steps_left = torch.tensor(5, device=device)

# CUDA kernels are launched asynchronously: synchronize before reading the
# clock so the measured interval covers the actual GPU work.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():  # inference only; avoids building an autograd graph per call
    for i in range(num_timing_runs):
        delta_ltt_model(all_latents, t, steps_left)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print(f"Time taken for 1000 latents: {(end-start) / num_timing_runs}")

In [None]:
# Ratio of two hard-coded durations — presumably the timings printed by the two
# preceding cells on one particular run (TODO confirm). Being literals, the
# values do not update when those cells are re-run.
4.404688119888306 / 0.05439952373504639

#### Attempting to get Bottleneck layer

In [None]:
# Ask PyTorch's caching allocator to release GPU memory it holds but no longer uses.
torch.cuda.empty_cache()
from torchsummary import summary

In [None]:
# Create a custom wrapper to handle the additional argument
class ModelSummaryWrapper(nn.Module):
    """Adapter that exposes a two-argument ``model(x, t)`` as a single-input module.

    Summary tools call the module with one input tensor only; this wrapper
    supplies the second argument itself by broadcasting a fixed scalar to the
    batch size of ``x``.

    Args:
        model: Callable module invoked as ``model(x, t)``.
        timestep: Scalar passed as ``t`` for every sample. Defaults to 80, the
            value that was previously hard-coded in ``forward``.
    """

    def __init__(self, model, timestep=80):
        super().__init__()
        self.model = model
        self.timestep = timestep

    def forward(self, x):
        """Run the wrapped model on ``x`` with the fixed timestep expanded to ``x.shape[0]``."""
        batch_size = x.shape[0]
        # expand() creates a broadcast view, so no per-sample copy of the scalar is made.
        t = torch.tensor(self.timestep, device=x.device).expand(batch_size)
        return self.model(x, t)

# Wrap the network so it can be called with a single input tensor.
summary_wrapper = ModelSummaryWrapper(net)

# Random 1x3x32x32 tensor on `device`.
# NOTE: it is not passed to summary() below, which is given input_size instead.
dummy_input = torch.randn(1, 3, 32, 32).to(device)  # Adjust dimensions as needed

# Print the layer-by-layer summary for a (3, 32, 32) input.
summary(summary_wrapper, input_size=(3, 32, 32))

In [None]:
# Print a 1-based numbered listing of every sub-module name together with its type.
for layer_number, (name, layer) in enumerate(net.named_modules(), start=1):
    layer_type = type(layer)
    print(f"{layer_number}: Layer Name: {name}, Layer Type: {layer_type}")

In [None]:
# Display the `model` attribute of `net` (the cell's last expression is rendered as output).
net.model

In [None]:
# Materialize the (name, module) pairs yielded by named_modules() as a list for display.
list(net.named_modules())

In [34]:
model = net.model

def hook_fn(module, input, output):
    """Forward hook that stores the hooked module's output in the global ``bottleneck_output``.

    Args:
        module: The module the hook is registered on (unused).
        input: Positional inputs passed to that module (unused).
        output: The module's forward output, captured as-is.
    """
    global bottleneck_output
    bottleneck_output = output

# Register the hook on the `affine` sub-module of encoder block "8x8_block3".
hook = model.enc["8x8_block3"].affine.register_forward_hook(hook_fn)

# NOTE: input_image is not fed to the model below; the forward pass runs on the
# notebook-level `x` and `t`.
input_image = torch.randn(1, 3, 32, 32)
try:
    output_image = model(x, t.expand((x.shape[0])), None)
finally:
    # Always detach the hook, even if the forward pass raises, so it does not
    # keep firing on later calls to the network.
    hook.remove()
print(bottleneck_output.shape)

bottleneck_output

torch.Size([1, 256])


tensor([[-1.4278e+00,  6.5677e+00,  3.9757e+00,  6.1155e+00, -5.6272e-02,
          4.7583e+00, -3.5497e+00,  1.2434e+00,  3.0584e+00, -2.9782e+00,
          1.0482e+00,  3.8307e+00, -2.7746e+00,  4.4532e+00,  2.0363e+00,
          2.6655e+00,  4.5572e-01,  6.7098e+00, -4.3862e+00, -2.5329e+00,
          4.5053e-01, -5.6717e+00, -5.5204e+00,  6.4583e+00, -3.1902e-01,
          7.1224e+00, -8.4573e-01,  3.0359e+00, -2.1529e+00, -2.4064e+00,
          7.7525e-01,  5.5615e+00,  2.2871e+00,  3.2362e+00,  4.9820e+00,
         -1.1608e+00,  2.2242e+00,  3.0207e+00,  1.5397e+00, -1.2069e+01,
         -2.5013e+00,  3.8893e-01,  3.8898e+00,  5.8242e+00,  3.8508e+00,
          4.8703e+00, -7.6412e+00,  3.8769e+00,  2.8533e+00, -7.9984e-01,
          3.5449e+00,  1.0804e+00,  4.4722e+00,  5.4717e+00,  1.0273e+01,
         -5.7373e+00,  7.9053e+00, -3.6632e+00,  1.0591e+00,  4.0612e+00,
         -1.9804e+00,  1.2799e+00,  1.7201e+00,  3.8070e+00,  4.9658e+00,
         -7.9767e+01,  1.3754e+00,  5.

#### Visualize Delta Ltt

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Single source of truth for the log directory, so the command printed at the
# end points at the directory the graph is actually written to.
log_dir = "runs/delta_ltt_model_visualization"

# Initialize the SummaryWriter
writer = SummaryWriter(log_dir=log_dir)

# Add the model graph to TensorBoard, traced with a dummy latent plus the two
# scalar arguments (80 and 5) the model is called with elsewhere in this notebook.
dummy_input = torch.randn(1, 3, 32, 32).to(device)  # Adjust dimensions as needed
try:
    writer.add_graph(delta_ltt_model, (dummy_input, torch.tensor(80, device=device), torch.tensor(5, device=device)))
finally:
    # Close the writer even if tracing fails so the event file is flushed and released.
    writer.close()

print("Model graph has been added to TensorBoard. Run the following command to view it:")
print(f"tensorboard --logdir={log_dir}")