In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

from torchvision.datasets import ImageFolder
import torchvision.transforms as T

import matplotlib.pyplot as plt
%matplotlib inline

# https://towardsdatascience.com/beginners-guide-to-loading-image-data-with-pytorch-289c60b7afec
# β-VAE: https://github.com/1Konny/Beta-VAE

In [2]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=(-30, 30), translate=(0.1,0.1), scale=(0.8, 1.2), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [3]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [4]:
from model import BetaVAE_H as VAE
from Solver import reconstruction_loss, kl_divergence

model = VAE(nc=3)

xrecon, mu, logvar = model(arr.unsqueeze(0))

In [5]:
from model import BetaVAE_H as VAE
from Solver import reconstruction_loss, kl_divergence

model = VAE(nc=3)
model.eval()

xrecon, mu, logvar = model(arr.unsqueeze(0))

In [6]:
beta = 4.0; lr=1e-4

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(100):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 5 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [7]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).cpu().detach())
plt.show()

In [8]:
beta = 4.0; lr=1e-3

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(100):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 5 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [9]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).cpu().detach())
plt.show()

In [10]:
beta = 1.0; lr=1e-3

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(100):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 5 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [11]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).cpu().detach())
plt.show()

In [12]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).cpu().detach())
plt.show()

In [13]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).cpu().detach())
plt.show()

In [14]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).cpu().detach())
plt.show()

In [15]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [16]:
beta = 1.0; lr=1e-2

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(100):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 5 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [17]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [18]:
beta = 1.0; lr=1e-2

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [19]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [20]:
beta = 1.0; lr=1e-2; dim=32

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [21]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [22]:
beta = 1.0; lr=1e-4; dim=32

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [23]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [24]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=(-30, 30), translate=(0.1,0.1), scale=(0.8, 1.2), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [25]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [26]:
beta = 1.0; lr=1e-4; dim=32

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [27]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [28]:
beta = 1.0; lr=1e-4; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [29]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [30]:
beta = 0.0; lr=1e-4; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [31]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [32]:
beta = 0.1; lr=1e-4; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [33]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [34]:
beta = 0.1; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [35]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [36]:
img[0]

(tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]]]),
 0)

In [37]:
img[0][0]

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [38]:
img[0][0].shape

torch.Size([3, 256, 256])

In [39]:
img[17][0].shape

torch.Size([3, 256, 256])

In [40]:
img[17][0].unsqueeze(0)

tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9961,  ..., 0.5176, 0.3490, 0.3020],
          [1.0000, 1.0000, 0.9961,  ..., 0.4667, 0.3294, 0.2941],
          [1.0000, 1.0000, 0.9961,  ..., 0.4510, 0.3255, 0.2902]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9961,  ..., 0.4392, 0.2941, 0.2510],
          [1.0000, 1.0000, 0.9961,  ..., 0.3922, 0.2745, 0.2431],
          [1.0000, 1.0000, 0.9961,  ..., 0.3804, 0.2706, 0.2392]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1

In [41]:
model.cpu()

model.encoder(img[17][0].unsqueeze(0))

tensor([[ -0.8147,  -4.2963,   1.9796,   2.2398,  -3.4094,  -0.1408,  -0.7454,
          -4.8866,  -3.5917,  -0.0475,   2.4621,   0.3947,   1.5482,  -5.5894,
           4.3102,  -8.2950,  -8.8949,  -7.1701,   0.9033,  -0.5024,   0.2268,
           1.8123,   5.1527,   1.5640,   7.2273,  -2.5669,   0.7683,   6.9144,
          -2.6649,  -2.8633,   6.5264,   0.5516,  -4.5259,  -4.4192,   1.0600,
           0.8273,  -0.1298,   3.9295,  -3.3490,  -0.7514,  -1.9961,   1.8203,
           0.5205,  -1.8626,  -5.0708,  -2.4133,  -0.4544,   1.3296,   1.2036,
          -2.6726,   0.5506,   0.2406,   2.3767,   1.0173,  -1.1693,  -0.5004,
          -1.7596,   1.6321, -10.0739,   1.5267,   0.0278,  -6.4631,  -3.1417,
          -3.0257,  -2.6074,  -3.0530,  -3.3341,  -3.2207,  -3.3812,  -2.6282,
          -3.1419,  -2.4923,  -2.3608,  -2.2267,  -3.0264,  -2.1589,  -2.9918,
          -2.7435,  -3.1926,  -2.4976,  -2.7894,  -1.3035,  -2.8470,  -2.3606,
          -1.5444,  -2.9351,  -3.7645,  -2.6623,  -3

In [42]:
model.cpu()

model.encoder(img[17][0].unsqueeze(0))[:model.z_dim]

tensor([[ -0.8147,  -4.2963,   1.9796,   2.2398,  -3.4094,  -0.1408,  -0.7454,
          -4.8866,  -3.5917,  -0.0475,   2.4621,   0.3947,   1.5482,  -5.5894,
           4.3102,  -8.2950,  -8.8949,  -7.1701,   0.9033,  -0.5024,   0.2268,
           1.8123,   5.1527,   1.5640,   7.2273,  -2.5669,   0.7683,   6.9144,
          -2.6649,  -2.8633,   6.5264,   0.5516,  -4.5259,  -4.4192,   1.0600,
           0.8273,  -0.1298,   3.9295,  -3.3490,  -0.7514,  -1.9961,   1.8203,
           0.5205,  -1.8626,  -5.0708,  -2.4133,  -0.4544,   1.3296,   1.2036,
          -2.6726,   0.5506,   0.2406,   2.3767,   1.0173,  -1.1693,  -0.5004,
          -1.7596,   1.6321, -10.0739,   1.5267,   0.0278,  -6.4631,  -3.1417,
          -3.0257,  -2.6074,  -3.0530,  -3.3341,  -3.2207,  -3.3812,  -2.6282,
          -3.1419,  -2.4923,  -2.3608,  -2.2267,  -3.0264,  -2.1589,  -2.9918,
          -2.7435,  -3.1926,  -2.4976,  -2.7894,  -1.3035,  -2.8470,  -2.3606,
          -1.5444,  -2.9351,  -3.7645,  -2.6623,  -3

In [43]:
model.cpu()

model.encoder(img[17][0].unsqueeze(0))[:model.z_dim].detach()

tensor([[ -0.8147,  -4.2963,   1.9796,   2.2398,  -3.4094,  -0.1408,  -0.7454,
          -4.8866,  -3.5917,  -0.0475,   2.4621,   0.3947,   1.5482,  -5.5894,
           4.3102,  -8.2950,  -8.8949,  -7.1701,   0.9033,  -0.5024,   0.2268,
           1.8123,   5.1527,   1.5640,   7.2273,  -2.5669,   0.7683,   6.9144,
          -2.6649,  -2.8633,   6.5264,   0.5516,  -4.5259,  -4.4192,   1.0600,
           0.8273,  -0.1298,   3.9295,  -3.3490,  -0.7514,  -1.9961,   1.8203,
           0.5205,  -1.8626,  -5.0708,  -2.4133,  -0.4544,   1.3296,   1.2036,
          -2.6726,   0.5506,   0.2406,   2.3767,   1.0173,  -1.1693,  -0.5004,
          -1.7596,   1.6321, -10.0739,   1.5267,   0.0278,  -6.4631,  -3.1417,
          -3.0257,  -2.6074,  -3.0530,  -3.3341,  -3.2207,  -3.3812,  -2.6282,
          -3.1419,  -2.4923,  -2.3608,  -2.2267,  -3.0264,  -2.1589,  -2.9918,
          -2.7435,  -3.1926,  -2.4976,  -2.7894,  -1.3035,  -2.8470,  -2.3606,
          -1.5444,  -2.9351,  -3.7645,  -2.6623,  -3

In [44]:
model.cpu()

mu = model.encoder(img[17][0].unsqueeze(0))[:model.z_dim].detach()

In [45]:
mu.shape

torch.Size([1, 128])

In [46]:
model.cpu()

mu = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()

In [47]:
mu.shape

torch.Size([1, 64])

In [48]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

In [49]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

In [50]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg)

In [51]:
xrecon.shape

torch.Size([1, 3, 256, 256])

In [52]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)

In [53]:
plt.imshow(xrecon)

<matplotlib.image.AxesImage at 0x7f11f25e1250>

In [54]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
plt.imshow(xrecon)

<matplotlib.image.AxesImage at 0x7f1201cab810>

In [55]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu1).detach()[0].transpose(0, -1).transpose(0, 1)
plt.imshow(xrecon)

<matplotlib.image.AxesImage at 0x7f1201cc48d0>

In [56]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu1).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)

<matplotlib.image.AxesImage at 0x7f11f24d4c90>

In [57]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)

<matplotlib.image.AxesImage at 0x7f11f2545590>

In [58]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [59]:
beta = 1; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(1000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 50 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [60]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [61]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

mu_avg = (mu1 + mu2) / 2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [62]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.5
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [63]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.1
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [64]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.0
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [65]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.2
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [66]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.3
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [67]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.4
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [68]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.5
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [69]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.6
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [70]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.7
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [71]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

alpha = 0.8
mu_avg = alpha * mu1 + (1-alpha) * mu2

xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon)
plt.show()

In [72]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
plt.show()

In [73]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [74]:
model.cpu()

mu1 = model.encoder(img[25][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [75]:
model.cpu()

mu1 = model.encoder(img[25][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [76]:
model.cpu()

mu1 = model.encoder(img[23][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [77]:
model.cpu()

mu1 = model.encoder(img[0][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [78]:
model.cpu()

mu1 = model.encoder(img[0][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[2][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [79]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[2][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [80]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[3][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [81]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[4][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [82]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[5][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [83]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[6][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [84]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[7][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [85]:
model.cpu()

mu1 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[8][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [86]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [87]:
beta = 1; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(2000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 100 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [88]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [89]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [90]:
beta = 2; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(2000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 100 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [91]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [92]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [93]:
beta = 2; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(10000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 100 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [94]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [95]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [96]:
rec_loss, total_kld

(tensor(1077.6022, device='cuda:0', grad_fn=<DivBackward0>),
 tensor([107.0838], device='cuda:0', grad_fn=<MeanBackward1>))

In [97]:
beta = 5; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(10000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 100 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [98]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [99]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [100]:
model.cpu()

mu1 = model.encoder(img[0][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [101]:
model.cpu()

mu1 = model.encoder(img[0][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[2][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [102]:
model.cpu()

mu1 = model.encoder(img[2][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[0][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [103]:
model.cpu()

mu1 = model.encoder(img[2][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [104]:
model.cpu()

mu1 = model.encoder(img[3][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [105]:
model.cpu()

mu1 = model.encoder(img[4][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [106]:
model.cpu()

mu1 = model.encoder(img[5][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [107]:
model.cpu()

mu1 = model.encoder(img[6][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [108]:
model.cpu()

mu1 = model.encoder(img[7][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [109]:
model.cpu()

mu1 = model.encoder(img[8][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [110]:
model.cpu()

mu1 = model.encoder(img[9][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [111]:
model.cpu()

mu1 = model.encoder(img[10][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [112]:
model.cpu()

mu1 = model.encoder(img[11][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [113]:
model.cpu()

mu1 = model.encoder(img[12][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [114]:
model.cpu()

mu1 = model.encoder(img[13][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [115]:
model.cpu()

mu1 = model.encoder(img[14][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [116]:
model.cpu()

mu1 = model.encoder(img[15][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [117]:
model.cpu()

mu1 = model.encoder(img[16][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [118]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [119]:
model.cpu()

mu1 = model.encoder(img[18][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [120]:
model.cpu()

mu1 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [121]:
model.cpu()

mu1 = model.encoder(img[20][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [122]:
model.cpu()

mu1 = model.encoder(img[21][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [123]:
model.cpu()

mu1 = model.encoder(img[22][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [124]:
model.cpu()

mu1 = model.encoder(img[23][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [125]:
model.cpu()

mu1 = model.encoder(img[23][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [126]:
model.cpu()

mu1 = model.encoder(img[24][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [127]:
model.cpu()

mu1 = model.encoder(img[25][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [128]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=(-5, 5), translate=(0.1,0.1), scale=(0.8, 1.2), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       #T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [129]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [130]:
beta = 5; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

for i in range(10000):
    loss_count = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
    if i % 100 == 0:
        print(f'epco {i}: {loss_count / len(gidata)}')

In [131]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [132]:
model.cpu()

mu1 = model.encoder(img[25][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[1][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [133]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [134]:
import wandb
import copy

class Reporter:
    def __init__(self, dt, local=False):
        self.dt = dt
        self.loss_count = {}
        self.k = 0.0
        self.local = local
        self.record = []
    
    def report(self):
        if self.k > 0:
            for k in self.loss_count.keys():
                self.loss_count[k] /= self.k
            if self.local:
                self.record.append(copy.deepcopy(self.loss_count))
                for k,v in self.loss_count.items():
                    print(f'{k}: {v}', end='; ')
                print('.')
            else:
                wandb.log(self.loss_count)
    
    def step(self, loss_dict):
        self.k += 1
        for k, v in loss_dict.items():
            if not (k in self.loss_count):
                self.loss_count[k] = 0.0
            self.loss_count[k] += v
        if self.k >= self.dt:
            self.report()
            self.k = 0
            for k in self.loss_count.keys():
                self.loss_count[k] = 0.0

In [135]:
beta = 5; lr=1e-3; dim=64

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [136]:
beta = 5; lr=1e-3; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [137]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [138]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [139]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [140]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [141]:
beta = 5; lr=1e-2; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [142]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [143]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [144]:
beta = 5; lr=1e-2; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=32, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [145]:
beta = 5; lr=1e-3; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=32, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [146]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [147]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [148]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [149]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [150]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [151]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [152]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [153]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=(-5, 5), translate=(0.1,0.1), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       #T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [154]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [155]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [156]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(translate=(0.1,0.1), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       #T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [157]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [158]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=0, translate=(0.1,0.1), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       #T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [159]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [160]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [161]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=0, translate=(0.1,0.1), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [162]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [163]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [164]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [165]:
beta = 5; lr=1e-3; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=32, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1

In [166]:
model.eval()
xrecon, mu, logvar = model(arr.unsqueeze(0).cuda())
xrecon = torch.sigmoid(xrecon)
plt.imshow(xrecon[0].transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()
plt.imshow(arr.transpose(0, -1).transpose(0, 1).cpu().detach())
plt.show()

In [167]:
model.cpu()

mu1 = model.encoder(img[17][0].unsqueeze(0))[:, :model.z_dim].detach()
mu2 = model.encoder(img[19][0].unsqueeze(0))[:, :model.z_dim].detach()

fig = plt.figure(figsize=(20,3))
i = 0
for alpha in [0, 0.25, 0.5, 0.75, 1]:
    #alpha = 0.8
    ax1 = plt.subplot(151+i)
    i += 1
    mu_avg = alpha * mu1 + (1-alpha) * mu2

    xrecon = model.decoder(mu_avg).detach()[0].transpose(0, -1).transpose(0, 1)
    xrecon = torch.sigmoid(xrecon)
    plt.imshow(xrecon)
    plt.title(f'alpha={alpha}')
plt.show()

In [168]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    with open(path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGBA')  # force alpha channel
        background = Image.new('RGBA', img.size, (255, 255, 255))
        alpha_composite = Image.alpha_composite(background, img).convert('RGB')
    return alpha_composite

# 图像增强：https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
transform = T.Compose([T.Resize((256, 256)),
                       T.RandomInvert(p=1),
                       T.RandomHorizontalFlip(),
                       T.RandomAffine(degrees=0, translate=(0.1,0.1), interpolation=T.InterpolationMode.BILINEAR),
                       T.RandomInvert(p=1),
                       #T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
                       T.ToTensor()])
'''transform = T.Compose([T.Resize((256, 256)),
                       T.ToTensor()])'''

img = ImageFolder(root='dataset', loader = pil_loader_rgba, transform=transform)

In [169]:
plt.figure(figsize=(6,6))
for i in [16, 17, 18, 19]:
    ax = plt.subplot(221 + i - 16)
    arr, cls = img[i]

    plt.imshow(arr.transpose(0,-1).transpose(0,1), vmin=0, vmax=1)
plt.show()

In [170]:
import wandb
import copy

class Reporter:
    def __init__(self, dt, local=False):
        self.dt = dt
        self.loss_count = {}
        self.k = 0.0
        self.local = local
        self.record = []
    
    def report(self):
        if self.k > 0:
            for k in self.loss_count.keys():
                self.loss_count[k] /= self.k
            if self.local:
                self.record.append(copy.deepcopy(self.loss_count))
                for k,v in self.loss_count.items():
                    print(f'{k}: {v}', end='; ')
                print('.')
            else:
                wandb.log(self.loss_count)
    
    def step(self, loss_dict):
        self.k += 1
        for k, v in loss_dict.items():
            if not (k in self.loss_count):
                self.loss_count[k] = 0.0
            self.loss_count[k] += v
        if self.k >= self.dt:
            self.report()
            self.k = 0
            for k in self.loss_count.keys():
                self.loss_count[k] = 0.0

In [171]:
beta = 5; lr=1e-3; dim=64

wandb.init(config={'beta':beta, 'lr':lr, 'dim':dim}, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=32, shuffle=True)
model = VAE(nc=3, z_dim=dim)
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

reporter = Reporter(dt=10)

for i in range(10000):
    loss_count = 0
    k = 0
    for x, cls in gidata:
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss_count += loss.item()
        loss.backward()
        optimizer.step()
        reporter.step({'epco':i+k/len(gidata), 'loss':loss.item(), 'loss_rec':rec_loss.item(), 'kld':total_kld.item()})
        k += 1