In [None]:
import torch
from torch.nn import functional as F
from diffusers import PNDMScheduler, UNet2DModel
# from diffusers import schedulers
# from PIL import Image

import matplotlib.pyplot as plt
from datasets import load_dataset
from torchvision import transforms

In [None]:
def show_images(images_tensor: torch.Tensor, cmap = "grey"):
    """Display a single image or a batch of images with ``plt.imshow``.

    Args:
        images_tensor: Tensor with 2, 3 or 4 dimensions.
            * 2 dims: shown as-is using ``cmap``.
            * 3 dims: dimension 0 is moved to the end before plotting
              (channels-first -> channels-last).
            * 4 dims: iterated along dimension 0 and the items are concatenated
              along their own dimension 2, so a batch is drawn side by side;
              the result is then handled like the 3-dim case.
        cmap: Colormap forwarded to ``plt.imshow`` for 2-dim input and for
            input whose channel dimension has size 1. It is not passed when
            there is more than one channel.

    Raises:
        SyntaxError: If ``images_tensor`` does not have 2, 3 or 4 dimensions.
            (Exception type kept unchanged so existing callers are unaffected.)
    """
    n_dims = images_tensor.dim()
    if n_dims not in (2, 3, 4):
        raise SyntaxError("The dimensions of images_tensor must be between 2 and 4")

    # Matplotlib converts its input to a NumPy array, which fails for tensors
    # that live on the GPU or that are tracking gradients; plot a detached CPU
    # view instead so callers do not have to remember `.cpu()`.
    images = images_tensor.detach().cpu()

    if n_dims == 4:
        # One tensor per batch item, glued together along dimension 2 of each item.
        images = torch.cat(list(images), dim=2)

    if n_dims == 2:
        plt.imshow(images, cmap);
        return

    channels_last = images.movedim(0, -1)
    if channels_last.shape[-1] == 1:
        # Single channel: let the colormap decide the colours.
        plt.imshow(channels_last, cmap);
    else:
        plt.imshow(channels_last);
def show_images_list(images_list:list[torch.Tensor], cmap = "grey") -> None:
    """Join a list of tensors and display the result with ``show_images``.

    The tensors are concatenated along dimension 0, a new dimension of size 1
    is inserted at position 1, and the result is forwarded to ``show_images``
    together with ``cmap``.
    """
    joined = torch.cat(images_list, dim=0)
    with_extra_dim = joined[:, None]    # same as joined.unsqueeze(1)
    show_images(with_extra_dim, cmap)


# Cargar el Dataset de Test

In [None]:
# Per-image preprocessing pipeline applied by `dataset_preprocess`.
preprocess = transforms.Compose(
    transforms=[
        # Convert the input image to a torch tensor
        transforms.ToTensor(),
        # Add 2 pixels of padding on every side
        # (presumably 28x28 -> 32x32 to match the model's sample_size -- TODO confirm)
        transforms.Pad(padding=2),
        # (x - 0.5) / 0.5: maps values from [0, 1] to [-1, 1]
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

def dataset_preprocess(examples):
    images = [preprocess(example) for example in examples["image"]]
    return {"images": images}
# Load every split of the Fashion-MNIST dataset
dataset = load_dataset("fashion_mnist")

# train_dataset,test_dataset = torch.utils.data.random_split(dataset["train"].with_transform(dataset_preprocess),(0.8,0.2))

# random_split with the single length 1 yields a one-element list; its only
# entry wraps the "test" split with `dataset_preprocess` applied on access.
(val_dataset,) = torch.utils.data.random_split(
    dataset=dataset["test"].with_transform(dataset_preprocess),
    lengths=(1,),
)


# Only the validation data is needed
val_dataloder = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=32,      # Batches of 32 images. The validation set has 10,000 images -> 313 batches
    shuffle=True,
)

# Modelo
Importamos el Modelo del fichero de pesos

In [None]:
# Build the UNet with the architecture the saved weights are loaded into.
base_model = UNet2DModel(
    sample_size=32,     # Input size
    in_channels=1,      # 1 channel for grey scale
    out_channels=1,
    # The number of channels per block affects the model size
    block_out_channels=(32, 64, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D"),
)
base_model = base_model.cuda()   # Move the model to the GPU

# Restore the trained parameters from the weights file
# (weights_only=True restricts unpickling to tensor data).
state_dict = torch.load("Base_model_OOD_detection.pth", weights_only=True)
base_model.load_state_dict(state_dict)

# Switch to evaluation mode; as the last expression it also displays the model.
base_model.eval()

In [None]:
# Noise scheduler shared by the noising (`add_noise`) and denoising (`step`) code below.
scheduler = PNDMScheduler(
    num_train_timesteps=1000,
    beta_start=0.0015,
    beta_end=0.0195,
)

# Number of inference steps used when denoising
scheduler.set_timesteps(num_inference_steps=50)

# Generación de imágenes

In [None]:
def PNDM_generation_loop(input_img:torch.Tensor, input_timestep : int, model: "UNet2DModel", scheduler : "PNDMScheduler"):
    """Denoise a batch by iterating over ``scheduler.timesteps`` from ``input_timestep`` onwards.

    Args:
        input_img: 4-dim tensor whose first dimension is the batch. It is used
            as the starting sample of the denoising chain.
        input_timestep: Where to enter the chain. ``1000`` starts at the first
            entry of ``scheduler.timesteps`` (generation from scratch); any
            other value must be present in ``scheduler.timesteps`` and the
            chain starts at its first occurrence.
        model: Called as ``model(sample, t, return_dict=False)``; element 0 of
            the result is used as the noise prediction.
        scheduler: Provides ``timesteps`` and ``step(noise_pred, t, sample)``,
            whose result exposes ``prev_sample``.

    Returns:
        The ``prev_sample`` produced by the last scheduler step.

    Raises:
        SyntaxError: If ``input_img`` is not 4-dimensional or ``input_timestep``
            is outside ``[0, 1000]`` (exception type kept for compatibility).
        IndexError: If ``input_timestep`` is not 1000 and does not appear in
            ``scheduler.timesteps``.

    NOTE(review): the scheduler object is used as-is and is never reset here.
    If it keeps internal state between ``step`` calls (diffusers' PNDM scheduler
    appears to keep a step counter and a history of model outputs), repeated
    calls on the same scheduler may influence each other -- TODO confirm.
    """
    if input_img.dim() != 4:     # Input validation
        raise SyntaxError("Error de Dimensiones. El Tensor de entrada dbe tener 4 dimensions, siendo la primera la dimensión de lote")

    if input_timestep < 0 or input_timestep > 1000: # Input validation
        raise SyntaxError("El timestep debe estar entre 0 y 1000")

    if input_timestep == 1000:
        # Timestep 1000 means "generate from scratch": run the whole schedule.
        start = 0
    else:
        # Otherwise enter the chain at the first occurrence of the requested timestep.
        hits = torch.nonzero(scheduler.timesteps == input_timestep).flatten()
        if hits.numel() == 0:
            # Previously this surfaced as an opaque "index 0 is out of bounds"
            # IndexError; keep the type but say what is actually wrong.
            raise IndexError(
                f"Timestep {float(input_timestep):g} is not in scheduler.timesteps; "
                f"available timesteps: {sorted(set(scheduler.timesteps.tolist()))}"
            )
        start = int(hits[0])

    noisy_x = input_img
    for t in scheduler.timesteps[start:]:   # Each element is one timestep of the chain
        with torch.inference_mode():        # No autograd bookkeeping for the forward pass
            noise_pred = model(noisy_x, t, return_dict=False)[0]

        # Scheduler step; its output is fed back as the next iteration's sample
        noisy_x = scheduler.step(noise_pred, t, noisy_x).prev_sample

    return noisy_x

# Métricas

In [None]:
from torchmetrics.functional import mean_squared_error as MSE
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity as LPIPS

# print(MSE(regeneration,img))
# print(LPIPS(regeneration,img))

# Obtención de datos

In [None]:
# Number of reconstructions (i.e. of distinct starting noise levels)
N = 50

# Starting timesteps: from 1000 downwards in steps of 1000/N, stopping before 0.
# The step is a float, so the resulting tensor is floating point.
reconstructions_timesteps = torch.arange(start=1000, end=0, step=-1000 / N)
print(reconstructions_timesteps)

In [None]:
def _rescale_for_lpips(batch: torch.Tensor) -> torch.Tensor:
    """Min-max rescale the whole batch to [-1, 1] and repeat dimension 1 three times."""
    low, high = batch.min(), batch.max()
    rescaled = 2 * (((batch - low) / (high - low)) - 0.5)
    return rescaled.repeat(1, 3, 1, 1)


# One list per starting timestep; each list receives one value per batch.
MSE_data = [[] for _ in range(N)]
LPIPS_data = [[] for _ in range(N)]

n_epochs = len(val_dataloder)

for epoch,batch in enumerate(val_dataloder):
    imgs = batch["images"].cuda()
    print(f"Época {epoch}/{n_epochs}")
    print("====================")
    print()

    # The reference images do not change across timesteps, so rescale them once per batch.
    lpips_reference = _rescale_for_lpips(imgs)

    for idx,t in enumerate(reconstructions_timesteps):      # Estimated run time: 4h 20min
        noise = torch.randn_like(imgs)

        if t == 1000:
            # Full generation: start the chain from pure noise.
            restoration = PNDM_generation_loop(noise,t,base_model,scheduler)

        else:
            # Partial reconstruction: noise the images up to timestep t, then denoise from there.
            noisy_img = scheduler.add_noise(imgs,noise,t.int())
            restoration = PNDM_generation_loop(noisy_img,t,base_model,scheduler)

        # Store plain Python floats rather than 0-dim CUDA tensors: the values
        # are later written to CSV, where a tensor would be serialised as its
        # repr ("tensor(..., device='cuda:0')") instead of a number.
        MSE_data[idx].append(MSE(restoration,imgs).item())
        LPIPS_data[idx].append(LPIPS(_rescale_for_lpips(restoration),lpips_reference).item())

In [None]:
import csv                      # Save the data to a file for later analysis

file_name = "Base_data.csv"

with open(file_name,"w",newline='') as csv_file:
    writer = csv.writer(csv_file)

    # One row per starting timestep: a label followed by that timestep's values.
    # The loop variables deliberately avoid the name `list`, which would shadow
    # the builtin for every cell executed afterwards.
    for label, per_timestep in (("MSE_data", MSE_data), ("LPIPS_data", LPIPS_data)):
        for values in per_timestep:
            # float() accepts both Python numbers and 0-dim tensors, so the
            # file contains numbers rather than tensor reprs.
            writer.writerow([label] + [float(value) for value in values])

In [None]:
# img = next(iter(val_dataloder))["images"].cuda()    # Elegimos un lote de imágenes
# print(len(img)) # Batch_size = 32

In [None]:
# show_images(img[0].cpu())   # Mostramos una de las imágenes

In [None]:
# noise = torch.randn_like(img)   # Generamos el Ruido
# noise.shape

In [None]:
# rand_timestep_idx = torch.randint(0,len(timesteps_list),(1,))               # Elegimos el timestep aleatorio de la lista del Scheduler
# rand_timestep = timesteps_list[rand_timestep_idx]

# rand_timestep_tensor = torch.ones((val_dataloder.batch_size,),dtype=int)*rand_timestep
# rand_timestep_tensor[0]

In [None]:
# noisy_img = scheduler.add_noise(img,noise,rand_timestep_tensor)         # Añadimos Ruido a las imágenes
# show_images(noisy_img[0].cpu())     # Mostramos una imagen con ruido

In [None]:
# regeneration = PNDM_generation_loop(noisy_img,rand_timestep,base_model,scheduler)       # Realizamos la Restauración

In [None]:
# show_images(regeneration[10:20].cpu())  # Mostramos las imágenes restauradas
# show_images(regeneration.cpu())  # Mostramos las imágenes restauradas


In [None]:
# show_images(img[10:20].cpu())   # Mostramos las imágenes originales
# show_images(img.cpu())   # Mostramos las imágenes originales


In [None]:
# MSE_list = []
# for idx in range(val_dataloder.batch_size):
#     MSE_list.append(MSE(regeneration[idx],img[idx]).item())

In [None]:
# import statistics
# statistics.mean(MSE_list)
# MSE_list

Pasos a seguir para obtener las distribuciones base:
1. Obtengo una imagen
2. Se le añaden 50 cantidades distintas de ruido.
3. Para cada cantidad de ruido, se realiza la reconstrucción de la imagen
4. Se calculan las métricas: MSE y LPIPS

5. Tras la obtención de los datos, calculamos su distribución: Media y desviación estándar