In [None]:
import os
import sys
from PIL import Image,ImageShow
import numpy as np
import torch
import matplotlib.pyplot as plt
import yaml
from omegaconf import OmegaConf
import seaborn as sb
import torchvision.transforms as transforms
from taming.data.faceshq import NumpyPaths


from taming.models.cond_transformer import Net2NetTransformer

Install the dependencies and add the working directory to the import path.

In [None]:
%pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 einops transformers

sys.path.append(".")

Load the configuration of the CelebAHQ model and its checkpoint.

In [None]:
# Prepare CelebAHQ configurations. The default is a local absolute path; set the
# CELEBAHQ_CONFIG environment variable to point at a different config file.
config_path = os.environ.get(
    "CELEBAHQ_CONFIG",
    r"C:\Users\DripTooHard\PycharmProjects\taming-transformers\scripts\taming-transformers\configs\2021-04-23T18-11-19-project.yaml",
)
celebAHQ_config = OmegaConf.load(config_path)
print(yaml.dump(OmegaConf.to_container(celebAHQ_config)))

# Init model with the chosen architecture and configurations
model = Net2NetTransformer(**celebAHQ_config.model.params)

In [None]:
# Load checkpoint weights. The default is a local absolute path; set the CELEBAHQ_CKPT
# environment variable to point at a different checkpoint.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt_path = os.environ.get(
    "CELEBAHQ_CKPT",
    r"C:\Users\DripTooHard\PycharmProjects\taming-transformers\scripts\taming-transformers\configs\CelebAHQ.ckpt",
)
sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# Load the weights exactly once. strict=False returns the key mismatches instead of raising,
# so they can be reported; any mismatch is still treated as an error (same as a strict load).
missing, unexpected = model.load_state_dict(sd, strict=False)
if missing or unexpected:
    raise RuntimeError(
        f"Checkpoint does not match the model: missing keys={missing}, unexpected keys={unexpected}"
    )

In [None]:
# Switch the model to evaluation mode (train(False) is what Module.eval() does) and turn
# off gradient tracking globally for the rest of this kernel session.
model.train(False)
torch.set_grad_enabled(False)

### Step 2: Investigating the norms of the codebook embeddings

We will now iterate through the entries of the codebook and investigate the distribution of their L2 norms, including the share of codebook vectors whose norm exceeds 1.

In [None]:
# L2 norm of every codebook vector. Each parameter of the embedding is treated as a stack of
# vectors along its first dimension (for a 2-D weight: one norm per row), which matches
# calling np.linalg.norm on each row. detach() keeps this cell working even if autograd
# has not been disabled by an earlier cell.
codebook_params = list(model.first_stage_model.quantize.embedding.parameters())
norms = torch.cat(
    [torch.linalg.vector_norm(p.detach().reshape(p.shape[0], -1), dim=1) for p in codebook_params]
).tolist()

# Creating a histogram
fig, ax = plt.subplots(figsize=(10, 6))
sb.histplot(data=norms, bins=30, stat="probability", ax=ax)
ax.set_xlabel("Norm Lengths")
ax.set_title("Codebook Norm Lengths")

# Calculating the percentage of norms above the length of 1 (0 if the codebook is empty)
norms_above_1 = [norm for norm in norms if norm > 1]
percentage_above_1 = len(norms_above_1) / len(norms) * 100 if norms else 0.0
print(f"Percentage of norms above the length of 1: {percentage_above_1}%")

# Displaying the plot
plt.show()

In [None]:
# TODO: Add some examples of what the embeddings close to 0 and those far away show.
# For now, list the shape of every parameter held by the codebook embedding.
codebook_embedding = model.first_stage_model.quantize.embedding
for weight in codebook_embedding.parameters():
    print(weight.shape)

### Step 3: The norm distribution in practice

Because the number of initialized embeddings is fixed and does not depend on how many codes are actually needed to represent our dataset meaningfully, we would like to know which code norms (lengths) are actually used in practice.


In [122]:
import albumentations
from torch.utils.data import Dataset, ConcatDataset
from torchvision import transforms

def get_image_paths(base_path):
    """Return the full paths of all ``.png`` files directly inside ``base_path``.

    The result is sorted by file name so that index-based access (e.g. ``dataset[0]``)
    selects the same image on every machine; ``os.listdir`` order is arbitrary.

    Args:
        base_path (str): Directory to scan (not recursive).

    Returns:
        list[str]: ``os.path.join(base_path, name)`` for every file name ending in ``.png``,
        in sorted name order.
    """
    return [
        os.path.join(base_path, file_name)
        for file_name in sorted(os.listdir(base_path))
        if file_name.endswith('.png')
    ]


def _to_uint8(numpy_array, value_range=(0.0, 1.0)):
    """Map a non-uint8 image linearly from ``value_range`` onto uint8 [0, 255]; uint8 passes through."""
    if numpy_array.dtype == np.uint8:
        return numpy_array
    low, high = value_range
    unit = (numpy_array.astype(np.float64) - low) / (high - low)
    # Clip before the cast: out-of-range floats would otherwise wrap around in uint8.
    return (np.clip(unit, 0.0, 1.0) * 255).astype(np.uint8)


def display_numpy_array_as_image(numpy_array, value_range=(0.0, 1.0)):
    """Show a NumPy image array in PIL's default image viewer.

    Args:
        numpy_array (numpy.ndarray): Image array in a layout accepted by ``PIL.Image.fromarray``
            (2-D, or 3-D with channels last). uint8 arrays are shown unchanged; any other dtype
            is treated as values in ``value_range`` and rescaled to uint8 (values outside the
            range are clipped).
        value_range (tuple[float, float]): Interval mapped to [0, 255] for non-uint8 input.
            Defaults to (0, 1); use (-1, 1) for arrays scaled like
            ``ImagePaths.preprocess_image`` output.
    """
    image = Image.fromarray(_to_uint8(numpy_array, value_range))
    image.show()


class ImagePaths(Dataset):
    """Dataset over a list of image files, with optional per-item labels.

    Each item is a dict holding the preprocessed image under ``"image"`` plus one entry per
    label key, including ``"file_path_"`` (the source path).

    Args:
        paths: Sequence of image file paths.
        size: If a positive int, images go through ``albumentations.SmallestMaxSize(max_size=size)``
            followed by a ``size`` x ``size`` crop; otherwise no resizing/cropping is applied.
        random_crop: Use a random crop instead of a center crop (only relevant when ``size`` is set).
        labels: Optional dict of label name -> sequence indexed like ``paths``. It is copied,
            so the caller's dict is not modified.
    """

    def __init__(self, paths, size=None, random_crop=False, labels=None):
        self.size = size
        self.random_crop = random_crop

        # Copy so that adding "file_path_" below does not mutate the caller's dict.
        self.labels = {} if labels is None else dict(labels)
        self.labels["file_path_"] = paths
        self._length = len(paths)

        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
        else:
            # Identity: returns its keyword arguments, so preprocessor(image=x)["image"] is x.
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        """Load ``image_path`` as RGB, apply the resize/crop pipeline and scale to [-1, 1] float32."""
        # The context manager closes the file handle instead of leaving it to the GC.
        with Image.open(image_path) as pil_image:
            rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
            image = np.array(rgb_image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # uint8 [0, 255] -> float32 [-1, 1]
        return (image / 127.5 - 1.0).astype(np.float32)

    def __getitem__(self, i):
        """Return ``{"image": preprocessed image, <label>: labels[<label>][i], ...}``."""
        example = dict()
        example["image"] = self.preprocess_image(self.labels["file_path_"][i])
        for k in self.labels:
            example[k] = self.labels[k][i]
        return example


class NumpyPaths(ImagePaths):
    """ImagePaths variant that decodes each file to a NumPy array before preprocessing.

    NOTE(review): this definition shadows the ``NumpyPaths`` imported from
    ``taming.data.faceshq`` in the first cell; later cells use this local version.
    """

    def preprocess_image(self, image_path):
        """Load ``image_path``, force RGB, resize/crop and scale to a [-1, 1] float32 H x W x 3 array."""
        image = self.load_image_nparray(image_path)  # as decoded by PIL: H x W or H x W x C
        # fromarray infers the mode from the array's shape, so grayscale/RGBA files are
        # converted to RGB rather than being reinterpreted as 3-channel data.
        image = np.array(Image.fromarray(image).convert("RGB")).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # uint8 [0, 255] -> float32 [-1, 1]
        return (image / 127.5 - 1.0).astype(np.float32)

    def path_to_tensor(self, path):
        """Return the preprocessed image at ``path`` as a 1 x C x H x W float tensor."""
        preprocessed_image_np = self.preprocess_image(path)
        image_tensor = torch.from_numpy(preprocessed_image_np).float()
        image_tensor = image_tensor.permute(2, 0, 1)  # H x W x C -> C x H x W
        return image_tensor.unsqueeze(0)  # Add a batch dimension

    def load_image_nparray(self, path):
        """Decode the image file at ``path`` into a NumPy array, closing the file afterwards."""
        with Image.open(path) as image:
            return np.array(image)

# Folder of FFHQ .png images; set FFHQ_IMAGE_DIR to override the default local path.
base_image_path = os.environ.get(
    "FFHQ_IMAGE_DIR",
    r"C:\Users\DripTooHard\PycharmProjects\taming-transformers\scripts\taming-transformers\data\ffhq_images\01000",
)

In [124]:
image_size = 1024

FFHQ_image_paths = get_image_paths(base_image_path)
assert FFHQ_image_paths, f"No .png files found in {base_image_path}"
prep = NumpyPaths(FFHQ_image_paths, size=image_size, random_crop=False)
# NumpyPaths.preprocess_image returns an H x W x 3 float32 NumPy array scaled to [-1, 1]
# (not a tensor), so it has no .permute/.cpu/.detach.
test_image = prep[0]["image"]

# Check that the images look alright: map [-1, 1] back to [0, 1] for display.
display_numpy_array_as_image(np.clip((test_image + 1.0) / 2.0, 0.0, 1.0))

In [133]:
epsilon = 1e-1  # Global default; the helpers in this cell take epsilon as an explicit argument.



def deconstruct_reconstruct(image, epsilon):
    """Encode ``image`` with the global ``model`` (passing ``epsilon``) and decode it back.

    The input is cast with ``.type(torch.FloatTensor)`` first. Returns whatever
    ``model.decode_to_img`` produces for the encoded indices and the quantized shape.
    """
    as_float = image.type(torch.FloatTensor)
    quantized, indices = model.encode_to_z(as_float, epsilon)
    return model.decode_to_img(indices, quantized.shape)

def deconstruct_reconstruct_mix(image1,image2,epsilon):
    """Encode ``image1``/``image2`` jointly via ``model.encode_to_z_mix`` and decode the result.

    Both inputs are cast with ``.type(torch.FloatTensor)`` (``image1`` first). Returns whatever
    ``model.decode_to_img`` produces for the mixed indices and the quantized shape.
    """
    first, second = (img.type(torch.FloatTensor) for img in (image1, image2))
    quantized, indices = model.encode_to_z_mix(first, second, epsilon)
    return model.decode_to_img(indices, quantized.shape)


# Epsilon sweep: 0, step_size, 2*step_size, ... with max_range itself excluded.
max_range = 1  # kept as an int: it is formatted into the folder name below
step_size = 0.01

# Output folder named after the sweep parameters. NOTE(review): the two numbers are joined
# without a separator (1 and 0.01 give "Laplace10.01"); the name is left unchanged because an
# already-created folder may depend on it.
folder = fr"C:\Users\DripTooHard\PycharmProjects\taming-transformers\scripts\Noisy Reconstructions\Laplace{max_range}{step_size}"
# The folder is not created here; if it does not exist yet, run: os.mkdir(folder)

epsilon_values = np.arange(0, max_range, step_size)



AttributeError: 'Tensor' object has no attribute 'show'