# Latent Space Interpolation

- This notebook generates latent-space interpolation images.

- A trained model checkpoint is required.

- Before running, make sure the dataset path and the model path below point to the correct locations.

In [None]:
import numpy as np
import torch
import os
import torchvision
import tqdm
from models import *
from dataloader import *
import random

In [None]:
# Root directory of the CelebA data and path of the trained t3VAE checkpoint.
# Edit these two values before running the remaining cells.
data_path, model_path = './', './t3VAE_best.pt'

In [None]:
def make_reproducibility(seed):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
            (CPU and all visible CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects subprocesses started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly nondeterministic
    # algorithms per run, which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False

In [None]:
# Select the device, fix all seeds, build the CelebA test loader, load the
# trained model, and encode the whole test split into latent codes `test_z`.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda' if USE_CUDA else 'cpu')
make_reproducibility(2023)

# NOTE(review): `transforms` and `CustomCelebA` are expected to come from the
# star imports (`models` / `dataloader`) -- confirm.
transform = transforms.Compose(
    [
        transforms.CenterCrop(148),
        transforms.Resize(64),
        transforms.ToTensor(),
    ]
)
testset = CustomCelebA(
    root=data_path,
    split='test',
    transform=transform,
    download=False,
)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)

print(f'the number of batches : {len(testloader)}')
# NOTE(review): unused in the visible cells, and its batch size (64) differs
# from the loader's batch_size=16 -- kept in case other cells reference it.
img_shape = torch.tensor([64, 3, 64, 64])  # img : [B, C, H, W]

# The checkpoint is a whole pickled model object, so weights_only=False is
# required on PyTorch >= 2.6 (where it defaults to True); the argument needs
# PyTorch >= 1.13. Unpickling can execute arbitrary code -- only load trusted
# checkpoints. map_location lets a GPU-saved checkpoint load on a CPU-only host.
model = torch.load(model_path, map_location=DEVICE, weights_only=False)
model = model.to(DEVICE)
# Inference mode: BatchNorm uses running stats and Dropout is disabled
# (if the model contains such layers).
model.eval()
print(f"load the best model from {model_path}")

test_z = []
with torch.no_grad():
    tqdm_testloader = tqdm.tqdm(testloader)
    for batch_idx, (x, label) in enumerate(tqdm_testloader):
        x = x.to(DEVICE)
        # forward is unpacked as (recon_x, z, mu, logvar); only z is kept.
        recon_x, z, mu, logvar = model(x)
        # Store latents on the CPU so GPU memory does not grow with the set.
        test_z.append(z.cpu())

    # One row per test image, in dataset order (shuffle=False).
    test_z = torch.cat(test_z, dim=0)

In [38]:
# For each of `n_grids` trials, pick four random test latents and decode an
# 8x8 grid of bilinear interpolations between them, saved as one PNG per trial.
out_dir = './interpolations'
# Create the directory the images are actually written to.
os.makedirs(out_dir, exist_ok=True)
result = torch.zeros((8 * 8, 3, 64, 64))
n_grids = 100  # renamed from `iter` to stop shadowing the builtin
steps = 7  # 8 positions per axis -> interpolation weights i/7, i = 0..7
with torch.no_grad():
    for k in range(n_grids):
        # randint's upper bound is exclusive; the lower bound must be 0 so
        # the first test sample can also be chosen.
        a, b, c, d = np.random.randint(0, len(test_z), 4)
        # test_z is already a detached CPU tensor (built under no_grad).
        interpol1 = test_z[a]
        interpol2 = test_z[b]
        interpol3 = test_z[c]
        interpol4 = test_z[d]

        # make 8x8 interpolation images
        for i in range(8):
            # Row endpoints depend only on i, so compute them once per row.
            int_z1 = interpol1 * (i / steps) + interpol2 * (1 - i / steps)
            int_z2 = interpol3 * (i / steps) + interpol4 * (1 - i / steps)
            for j in range(8):
                int_z = (j / steps) * int_z1 + (1 - j / steps) * int_z2
                # NOTE(review): int_z is a single unbatched latent vector; this
                # assumes the decoder accepts it and that the first element of
                # its output is one [3, 64, 64] image -- confirm.
                recon_x, *_ = model.decoder(int_z.to(DEVICE))
                result[i * 8 + j] = recon_x

        filename = f'{out_dir}/INTERPOL_TEST_{k}.png'
        torchvision.utils.save_image(result, filename, normalize=True, nrow=8)
