In [None]:
import numpy as np
from PIL import Image
from sklearn.decomposition import PCA
import torch
from torchvision import transforms
from model import Generator
from dataset import MultiResolutionDataset

import matplotlib.pyplot as plt

In [None]:
# Image pipeline: ToTensor, then (x - 0.5) / 0.5 on each of the 3 channels,
# mapping ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=True),
])

# Label pipeline: identical mapping, but with single-element mean/std.
transform_label = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,), inplace=True),
])

# NOTE(review): the trailing 256 is presumably the resolution read from the
# LMDB — confirm against MultiResolutionDataset's signature.
dataset = MultiResolutionDataset(
    'dataset.lmdb', transform, transform_label, 256
)

In [None]:
# Rebuild the SPADE-architecture generator on the GPU and restore the
# 'g_ema' weights stored in the checkpoint file.
# NOTE(review): positional args are presumably (size, style_dim, n_mlp,
# channel_multiplier) — confirm against model.Generator.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True;
# if this checkpoint pickles non-tensor objects the load may need
# weights_only=False (only ever for trusted files).
generator = Generator(256, 512, 8, 2, architecture='spade')
generator = generator.cuda()
ckpt = torch.load('checkpoint_256_spade_with_noise/250000.pt')
generator.load_state_dict(ckpt['g_ema'])

In [None]:
def generate_with_w(label, w, truncation_latent=None, truncation=1, model=None):
    """Render one image from a label tensor and a latent already in W space.

    Args:
        label: label tensor for a single sample, without a batch axis (one is
            added here). Presumably (C, H, W) — confirm against the dataset.
        w: W-space latent with a leading batch axis, e.g. ``w[k].unsqueeze(0)``.
        truncation_latent: optional W-space centre for truncation (e.g.
            ``w_avg``); forwarded to the generator.
        truncation: truncation strength, forwarded to the generator.
        model: generator to call; defaults to the notebook-global ``generator``.

    Returns:
        ``np.uint8`` array of shape (H, W, C) with values in [0, 255].
    """
    if model is None:
        model = generator  # notebook-global generator loaded earlier
    # Follow the model's device instead of hard-coding .cuda().
    device = next(model.parameters()).device
    if truncation_latent is not None:
        truncation_latent = truncation_latent.to(device)

    with torch.no_grad():
        outputs = model(label.unsqueeze(0).to(device), [w.to(device)],
                        input_is_latent=True,
                        truncation_latent=truncation_latent,
                        truncation=truncation)

    # The first output holds the image batch; take its only sample as HWC.
    image = outputs[0][0].cpu().numpy()
    image = np.transpose(image, (1, 2, 0))
    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 255]. Round rather than
    # truncate so the uint8 cast is not biased ~0.5 levels low.
    image = np.rint((image * 0.5 + 0.5) * 255)
    return np.clip(image, 0, 255).astype(np.uint8)

## Examples

### Same label, different W

In [None]:
# Eight W vectors from a fixed seed, one per generated panel in the next grid.
torch.manual_seed(3)
z = torch.randn((8, 512), device='cuda')
with torch.no_grad():  # sampling only; no autograd graph needed
    w = generator.style(z)  # z -> W (presumably the mapping network)

In [None]:
# Left panel of each pair: the fixed pose label; right panel: the image
# generated from that label with one of the eight sampled W vectors.
img, label = dataset[5]  # loop-invariant: load the pose once

plt.figure(figsize=(10, 10))
for i in range(16):
    # 16 panels = 8 (label, image) pairs, so a 4x4 grid leaves no empty row.
    plt.subplot(4, 4, i + 1)
    if i % 2 == 0:
        plt.imshow(label[0])
    else:
        res = generate_with_w(label, w[i // 2].unsqueeze(0))
        plt.imshow(res)
    plt.axis('off')

plt.suptitle('Same pose, different W', y=1.02)
plt.tight_layout()

### Same W, different labels

In [None]:
# A single W from a fixed seed, shared by every pose in the next grid.
torch.manual_seed(16)
z = torch.randn((1, 512), device='cuda')
with torch.no_grad():  # sampling only; no autograd graph needed
    w = generator.style(z)

In [None]:
# Left panel of each pair: a pose label from the dataset; right panel: the
# image generated from that pose with the single fixed W sampled above.
plt.figure(figsize=(10, 10))
for i in range(16):
    # 16 panels = 8 (label, image) pairs, so a 4x4 grid leaves no empty row.
    plt.subplot(4, 4, i + 1)
    if i % 2 == 0:
        img, label = dataset[i + 16]  # dataset items 16, 18, ..., 30
        plt.imshow(label[0])
    else:
        # Reuses `label` loaded on the previous (even) iteration.
        res = generate_with_w(label, w)
        plt.imshow(res)
    plt.axis('off')

plt.suptitle('Same W, different poses', y=1.02)
plt.tight_layout()

## PCA

In [None]:
# A large W sample (16384 vectors) used below both to fit the PCA and to
# compute the mean W for truncation.
z = torch.randn((16384, 512), device='cuda')
with torch.no_grad():  # sampling only; no autograd graph needed
    w = generator.style(z)

In [None]:
# Full 512-component PCA over the sampled W vectors (fitted on CPU/NumPy).
pca = PCA(n_components=512)
pca.fit(w.cpu().numpy())

In [None]:
# Share of W's variance captured by each of the first ten principal components.
plt.plot(pca.explained_variance_ratio_[:10], marker='o')
plt.xlabel('principal component index')
plt.ylabel('explained variance ratio')
plt.title('PCA of W: first 10 components');

In [None]:
# Shape of the sampled W batch (rows = the samples drawn above).
w.size()

## Truncation

In [None]:
# Mean W over the large sample; used as the truncation centre below.
w_avg = w.mean(dim=0)

In [None]:
# Label transform for the standalone mean-pose image: resize to 256x256, then
# the same (x - 0.5) / 0.5 normalisation used by `transform_label`.
transform_label_256 = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,), inplace=True),
    ]
)

# Open the image in a context manager so the file handle is released once the
# transform has read the pixels (Image.open is lazy and otherwise keeps it open).
with Image.open('mean_pose_label.jpg') as mean_pose_img:
    mean_pose = transform_label_256(mean_pose_img)

In [None]:
# Visualise channel 0 of the mean-pose label tensor used for the truncation grid.
plt.imshow(mean_pose[0])
plt.title('Mean pose label')
plt.axis('off');

In [None]:
# Four W vectors from a fixed seed: the columns of the truncation grid.
torch.manual_seed(8)
z = torch.randn((4, 512), device='cuda')
with torch.no_grad():  # sampling only; no autograd graph needed
    w = generator.style(z)

In [None]:
# Rows: truncation strengths; columns: the sampled W vectors. Every panel is
# generated from the mean pose with truncation_latent=w_avg.
truncations = [1, 0.8, 0.6, 0.4, 0.2]
# Derive the grid from the data so editing `truncations` or `w` keeps it in sync.
n_rows, n_cols = len(truncations), len(w)

plt.figure(figsize=(7, 10))
for i, trunc in enumerate(truncations):
    for j in range(n_cols):
        plt.subplot(n_rows, n_cols, n_cols * i + j + 1)
        res = generate_with_w(mean_pose, w[j].unsqueeze(0),
                              truncation_latent=w_avg,
                              truncation=trunc)
        plt.imshow(res)
        plt.axis('off')
        if j == 0:
            # Label each row once, on its first panel.
            plt.title(f'truncation:{trunc}')

plt.suptitle('Truncation for different W', y=1.02)
plt.tight_layout()

## PCA directions

In [None]:
# Five W vectors from a fixed seed to walk along a PCA direction.
torch.manual_seed(2)
z = torch.randn((5, 512), device='cuda')
with torch.no_grad():  # sampling only; no autograd graph needed
    w = generator.style(z)

In [None]:
# Third principal axis of W (sklearn sorts components_ by explained variance).
direction = pca.components_[2, :]

In [None]:
# Shift each sampled W along `direction` by alpha in np.linspace(-4, 4, 7) and
# render every shifted W with its own pose label (dataset items 0..len(w)-1).
# Build the direction tensor once with w's dtype/device: depending on the
# NumPy / scikit-learn versions, `a * direction` can come out float64, which
# torch.tensor() would keep and silently upcast w_new to float64.
direction_t = torch.as_tensor(direction, dtype=w.dtype, device=w.device)
labels = [dataset[j][1] for j in range(len(w))]  # loop-invariant: load once

results = []
for a in np.linspace(-4, 4, 7):
    w_new = w + float(a) * direction_t
    results.append([
        generate_with_w(label, w_new[j].unsqueeze(0))
        for j, label in enumerate(labels)
    ])

In [None]:
# One row per alpha step, one column per sampled W / pose.
n_rows, n_cols = len(results), len(results[0])

plt.figure(figsize=(7, 10))
for i, row in enumerate(results):
    for j, image in enumerate(row):
        plt.subplot(n_rows, n_cols, i * n_cols + j + 1)
        plt.imshow(image)
        plt.axis('off')
plt.tight_layout()
# The alphas are np.linspace(-4, 4, 7), so the title must say -4 to 4.
plt.suptitle('Third PCA component, alpha from -4 to 4', y=1.005);