In [1]:
from PIL import Image
from matplotlib import pyplot as plt
from os_paths import path_to_data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import ToTensor
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pdb
from skimage import color
from lab_utils import *
import time
from models import LVAE, ABVAE
from os_paths import lvae_state_dict_path, abvae_state_dict_path, encodings_dir_path

KeyboardInterrupt: 

In [None]:
class _FfhqLabDataset(Dataset):
    """Shared loader for FFHQ images served as scaled CIELAB tensors.

    Each item is read from ``root_dir``, converted RGB -> CIELAB, passed through
    ``transform`` (if any), scaled with ``scale_lab`` and finally reduced to the
    channels chosen by the subclass's ``_select_channels``. Subclasses also
    provide ``preshow_image`` (tensor -> displayable array) and ``_cmap``.

    Args:
        root_dir: directory whose every entry is treated as an image file.
        transform: callable applied to the float32 (H, W, 3) LAB array; assumed
            to return a channel-first tensor (e.g. a pipeline starting with
            ``ToTensor``). When ``None`` the array is converted to a
            channel-first tensor directly.
        colorchannels: accepted for backward compatibility and ignored; the
            subclass decides which channels are returned.
        cacheing: if true, keep processed items in memory keyed by file name.
        cachelim: maximum number of items kept in the cache.
    """

    _cmap = None  # colormap passed to imshow by show_grid

    def __init__(self, root_dir, transform=None, colorchannels='ab', cacheing=False, cachelim=30000):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir order is arbitrary; sorting makes index i refer to the same
        # file on every run and in every dataset built over the same directory,
        # which anything pairing items across datasets or saved outputs relies on.
        self.image_paths = sorted(os.listdir(root_dir))
        # A dictionary that prevents us from doing more compute than is necessary:
        # preprocessing a batch takes about 5x as long as actually training it.
        # NOTE(review): with a multi-worker DataLoader each worker process holds its own copy.
        self.cacheing = cacheing
        self.cachelim = cachelim
        self.already_seen_images = {}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the processed tensor for item ``idx`` (int or 0-d tensor)."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = self.image_paths[idx]

        # A cache hit skips decoding, colour conversion and resizing entirely
        # (about 1000x faster than loading the image).
        if self.cacheing and image_path in self.already_seen_images:
            return self.already_seen_images[image_path]

        image = self._select_channels(self._load_scaled_lab(image_path))

        if self.cacheing and len(self.already_seen_images) < self.cachelim:
            self.already_seen_images[image_path] = image
        return image

    def _load_scaled_lab(self, image_path):
        """Load one file as a channel-first, scaled CIELAB tensor."""
        with Image.open(os.path.join(self.root_dir, image_path)) as pil_image:
            rgb = np.array(pil_image.convert('RGB'))  # nparray so rgb2lab works
        lab = color.rgb2lab(rgb).astype(np.float32)  # (H, W, 3)
        if self.transform:
            lab = self.transform(lab)
        else:
            # Without a transform the array is still (H, W, C); move channels first
            # so the channel slicing in _select_channels selects channels, not rows.
            lab = torch.from_numpy(np.ascontiguousarray(np.moveaxis(lab, -1, 0)))
        # scale so that values are between 0 and 1
        return scale_lab(lab)

    def _select_channels(self, image):
        """Reduce a channel-first (3, H, W) LAB tensor to the served channels."""
        raise NotImplementedError

    def preshow_image(self, image):
        """Convert one served tensor into an array ``imshow`` can draw."""
        raise NotImplementedError

    def show_grid(self, nrows, ncols):
        """Plot the first ``nrows * ncols`` items (fewer if the dataset is smaller)."""
        n_shown = min(nrows * ncols, len(self))
        # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
        _fig, axes = plt.subplots(nrows, ncols, squeeze=False)
        for i, ax in enumerate(axes.flat):
            if i < n_shown:
                ax.imshow(self.preshow_image(self[i]), cmap=self._cmap)
            ax.axis('off')
        plt.show()


class FfhqDatasetAB(_FfhqLabDataset):
    """FFHQ images as scaled CIELAB a/b (chroma) channels, shape (2, H, W)."""

    def __init__(self, root_dir, transform=None, colorchannels='ab', cacheing=False, cachelim=30000):
        super().__init__(root_dir, transform=transform, colorchannels=colorchannels,
                         cacheing=cacheing, cachelim=cachelim)

    def _select_channels(self, image):
        return image[1:]  # isolate the ab layers (drop L at channel 0)

    def preshow_image(self, image):
        """
        input: torch.tensor in scaled CIELAB color space without the L channel, dims = (2, H, W)
        output: np array in RGB color space, dims = (H, W, 3), dtype uint8

        A constant L channel of 0.5 (scaled units) is prepended so that only the
        chroma information varies across the rendered image.
        """
        lightness = torch.zeros_like(image[0:1]) + 0.5
        lab = torch.cat((lightness, image), dim=0).numpy()
        lab = descale_lab(lab)
        lab = np.moveaxis(lab, 0, -1)  # (C, H, W) -> (H, W, C) for lab2rgb / imshow
        rgb = color.lab2rgb(lab, channel_axis=-1)
        return (rgb * 255).astype(np.uint8)


class FfhqDatasetL(_FfhqLabDataset):
    """FFHQ images as the scaled CIELAB L (lightness) channel, shape (1, H, W)."""

    _cmap = 'gray'

    def __init__(self, root_dir, transform=None, colorchannels='ab', cacheing=True, cachelim=50000):
        super().__init__(root_dir, transform=transform, colorchannels=colorchannels,
                         cacheing=cacheing, cachelim=cachelim)

    def _select_channels(self, image):
        return image[0:1]  # isolate the L layer, keeping a channel dimension

    def preshow_image(self, image):
        """
        input: torch.tensor in scaled CIELAB color space without ab channels, dims = (1, H, W)
        output: np array, dims = (H, W, 1), dtype uint8 (drawn with a gray colormap)
        """
        image = np.moveaxis(image.numpy(), 0, -1)  # (C, H, W) -> (H, W, C) so imshow works
        return (image * 255).astype(np.uint8)

# ToTensor turns the float32 (H, W, 3) LAB array into a channel-first tensor,
# which is then resized to a fixed square resolution.
IMAGE_SIZE = (128, 128)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE, antialias=True),
])

path_to_trn = os.path.join(path_to_data, 'training')
path_to_tst = os.path.join(path_to_data, 'test')

# Training datasets: caching explicitly disabled.
dataset_ab = FfhqDatasetAB(root_dir=path_to_trn, transform=transform, cacheing=False)
dataset_L = FfhqDatasetL(root_dir=path_to_trn, transform=transform, cacheing=False)

# Test datasets: each class's default caching behaviour.
tst_dataset_ab = FfhqDatasetAB(root_dir=path_to_tst, transform=transform)
tst_dataset_L = FfhqDatasetL(root_dir=path_to_tst, transform=transform)

In [None]:
dataset_L.show_grid(nrows=1, ncols=2)  # preview the L channel of the first two training images

In [None]:
dataset_ab.show_grid(nrows=1, ncols=2)  # preview the ab channels of the first two training images

In [None]:
def parse_filepath(filepath):
    """Extract hyper-parameters encoded in a checkpoint path.

    The latent size is read from the last ``latent_dim_<int>`` token in the
    path, e.g. ``ckpt/lvae_latent_dim_64_epoch_10.pt`` -> ``{'latent_dim': 64}``.

    Args:
        filepath: str or os.PathLike pointing at a checkpoint file.

    Returns:
        dict with key ``'latent_dim'`` mapped to an int.

    Raises:
        ValueError: if the path contains no ``latent_dim_<int>`` token.
    """
    import re

    path_str = os.fspath(filepath)
    matches = re.findall(r'latent_dim_(\d+)', path_str)
    if not matches:
        raise ValueError(f"Could not find 'latent_dim_<int>' in checkpoint path: {path_str!r}")
    # Last occurrence wins, so a directory name containing the token does not
    # override the value in the file name itself.
    return {'latent_dim': int(matches[-1])}

In [None]:
# Both checkpoints encode their latent size in the file name; the L and ab VAEs
# are required to agree on it, so fail fast before any model is built.
lvar, abvar = (parse_filepath(p) for p in (lvae_state_dict_path, abvae_state_dict_path))
if lvar['latent_dim'] != abvar['latent_dim']:
    raise ValueError("Latent dims must match")

In [None]:
# Build both VAEs with the shared latent size and restore their trained weights.
lvae = LVAE(lvar['latent_dim']).to('cuda')
abvae = ABVAE(abvar['latent_dim']).to('cuda')

# map_location places the stored tensors on the models' device regardless of
# where the checkpoint was saved; weights_only=True refuses to unpickle arbitrary
# objects (a state dict should contain only tensors).
lvae.load_state_dict(torch.load(lvae_state_dict_path, map_location='cuda', weights_only=True))
abvae.load_state_dict(torch.load(abvae_state_dict_path, map_location='cuda', weights_only=True))

# The models are only used for inference below: disable training-mode behaviour
# (dropout, batch-norm statistic updates) in case the architectures contain any.
lvae.eval()
abvae.eval();

In [None]:
def encode_dataset(dataset, model, printinterval=2000, device='cuda'):
    """Encode every item of ``dataset`` with ``model.encode`` and stack the results.

    Items are fed one at a time, exactly as the dataset returns them (no batching),
    and the per-item encodings are concatenated along dim 0.

    Args:
        dataset: iterable of tensors that also supports ``len()``.
        model: ``nn.Module`` exposing ``encode(tensor) -> tensor``.
        printinterval: print a progress line every this many items.
        device: device each item is moved to before encoding; should match the model's.

    Returns:
        CPU tensor of all encodings concatenated along dim 0 (an empty 1-D
        tensor if ``dataset`` is empty).
    """
    t0 = time.time()
    n_items = len(dataset)
    chunks = []
    # Encode in eval mode (no dropout / batch-norm updates), then restore the
    # caller's mode so this function has no lasting side effect on the model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i, img in enumerate(dataset):
                t1 = time.time()
                chunks.append(model.encode(img.to(device)))
                if i % printinterval == 0:
                    print(f"loop {i:5}/{n_items}, loop time: {time.time()-t1: 5f} sec")
    finally:
        model.train(was_training)

    # One cat at the end instead of growing a tensor each iteration, which
    # re-copied all previous encodings every step (quadratic time and memory traffic).
    encodings = torch.cat(chunks) if chunks else torch.Tensor([]).to(device)
    print("Time taken for encoding:", time.time()-t0, "seconds")
    print("Shape:", encodings.shape)
    return encodings.to('cpu')

In [None]:
# Encode the ab channels of every training image with the ab VAE.
abencodings = encode_dataset(dataset=dataset_ab, model=abvae)

In [None]:
# Encode the L channel of every training image with the L VAE.
Lencodings = encode_dataset(dataset=dataset_L, model=lvae)

In [None]:
# Persist both encoding tensors under the directory configured in os_paths
# (previously overridden here by a hardcoded './encoded_data'). Row i of each
# tensor corresponds to item i of the dataset it was encoded from.
os.makedirs(encodings_dir_path, exist_ok=True)  # torch.save does not create missing dirs
torch.save(abencodings, os.path.join(encodings_dir_path, 'ab_encodings.pt'))
torch.save(Lencodings, os.path.join(encodings_dir_path, 'L_encodings.pt'))