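# Latent Space Interpolation

Load the trained class-conditional generator (`Generator_32`, checkpoint `netG_epoch_999.pth`) and draw a few class-conditioned latent codes per CIFAR-10 class. For each class, decode a linear interpolation between two of those codes and write the resulting images to `./lsi/<class>.png`.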

In [3]:
import time

import torch
import torchvision.transforms as transforms
import torchvision
import torch.nn as nn
import torch.optim as optim

import numpy as np

import matplotlib.pyplot as plt
import random
from torchsummary import summary

import os

from model import Generator_32, Discriminator, weights_init, compute_acc

In [4]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [9]:
batch_size = 100
input_size = 110
num_classes = 10
image_size = 32
EPOCH = 100
noise_sd = 1
LR = 0.0002


In [6]:
# Test-time preprocessing: resize to `image_size`, convert to a [0, 1] tensor,
# then map each channel to [-1, 1] via (x - 0.5) / 0.5.
transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 test split (downloaded into ./data on first run). Use the configured
# `batch_size` rather than repeating the literal 100, so the two cannot drift apart.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
# CIFAR-10 class names, looked up by integer label (classes[label]).
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())


Files already downloaded and verified


In [76]:
# Directory holding the trained generator checkpoint(s); the next cell loads
# netG_epoch_999.pth from here. Relative to the notebook's working directory.
save_folder = "./model/"

In [77]:
# Build the generator on the target device and switch it to inference mode, so
# layers such as BatchNorm/Dropout (if Generator_32 has any) do not depend on the
# composition of each batch. Without this, the same latent code could render
# differently in the sampling batch and in the interpolation batch.
model = Generator_32().to(device).eval()
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine
# (the `device` cell above falls back to 'cpu' in that case).
checkpoint_path = os.path.join(save_folder, "netG_epoch_999.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [78]:
def get_sample_images(genmod, l, num_classes, noise_sd, postfix,
                      latent_dim=110, seed=123, save_dir=None):
    """Generate ``l`` class-conditioned samples for each of ``num_classes`` classes.

    Each latent vector is drawn from N(0, noise_sd**2), and its first
    ``num_classes`` entries are overwritten with a one-hot class code. Rows are
    ordered class-major: ``l`` rows of class 0, then ``l`` rows of class 1, and so on.

    Args:
        genmod: generator callable applied to a float32 tensor of shape
            (l * num_classes, latent_dim, 1, 1).
        l: number of samples per class.
        num_classes: number of classes (width of the one-hot code).
        noise_sd: standard deviation of the Gaussian latent noise.
        postfix: file-name suffix; only used when ``save_dir`` is given.
        latent_dim: length of each latent vector (110 by default, as before).
        seed: seed for every RNG touched here, so repeated calls return
            identical latent codes.
        save_dir: if given, the generated batch is saved there as a grid
            ``example_<postfix>.png`` with ``l`` images per row.

    Returns:
        (noise, fake): the latent batch (on the module-level ``device``) and the
        generator output for it.

    Raises:
        ValueError: if ``latent_dim`` is too small to hold the one-hot code.
    """
    if latent_dim < num_classes:
        raise ValueError(
            "latent_dim (%d) must be >= num_classes (%d)" % (latent_dim, num_classes))

    # Keep seeding the global RNGs as before (the generator itself may draw from
    # torch's RNG). Also seed the NumPy RNG that actually produces the noise; it
    # was previously unseeded, so "seeded" samples changed on every call.
    random.seed(seed)
    torch.manual_seed(seed)
    rng = np.random.default_rng(seed)

    k = int(l * num_classes)
    # Class-major labels [0]*l + [1]*l + ... for any num_classes (was hard-coded to 10).
    labels = np.repeat(np.arange(num_classes), l)
    latent = rng.normal(0, noise_sd, (k, latent_dim))
    latent[:, :num_classes] = np.eye(num_classes)[labels]
    noise = torch.from_numpy(latent).float().view(k, latent_dim, 1, 1).to(device)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        fake = genmod(noise)

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        torchvision.utils.save_image(
            fake, os.path.join(save_dir, 'example_%s.png' % postfix), nrow=l)
    return noise, fake

In [79]:
# Draw 4 seeded samples per class. `noise` (class-major latent codes) is reused
# by the interpolation cell below.
noise, fake_imgs = get_sample_images(
    genmod=model,
    l=4,
    num_classes=num_classes,
    noise_sd=noise_sd,
    postfix="",
)

In [84]:
# For every class, linearly interpolate between two of the latent codes drawn
# above and decode the path, saving it as one image strip per class in ./lsi/.
#
# `noise` is class-major (l consecutive codes per class), so derive l from the
# batch instead of repeating the sampling cell's hard-coded value.
l = noise.shape[0] // num_classes
start_idx, end_idx = 1, 3            # which two samples of each class to connect
ratios = np.linspace(0, 1, num=10)   # blend weights: 0 -> start code, 1 -> end code
out_dir = './lsi'
os.makedirs(out_dir, exist_ok=True)  # save_image does not create missing directories

# Blend in torch with the weights as an (n, 1, 1, 1) column. This keeps the batch
# in noise's dtype (float32) and in the (n, latent, 1, 1) layout the sampling cell
# feeds the generator. Mixing np.float64 ratios with float32 NumPy codes promotes
# to float64 under NumPy 2, which the float32 generator rejects.
weights = torch.as_tensor(ratios, dtype=noise.dtype, device=noise.device).view(-1, 1, 1, 1)

with torch.no_grad():  # inference only; no autograd graph needed
    for label, label_text in enumerate(classes):
        print("Process: ", label_text)
        class_noise = noise[label * l:(label + 1) * l]
        start, end = class_noise[start_idx], class_noise[end_idx]
        interpolate = (1.0 - weights) * start + weights * end
        gen_imgs = model(interpolate)
        torchvision.utils.save_image(
            gen_imgs,
            os.path.join(out_dir, '%s.png' % label_text),
            nrow=len(ratios),  # one row, so the path reads left to right
        )

Process:  plane
Process:  car
Process:  bird
Process:  cat
Process:  deer
Process:  dog
Process:  frog
Process:  horse
Process:  ship
Process:  truck
