In [1]:
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

from torchvision.datasets import ImageFolder
import torchvision.transforms as T

import matplotlib.pyplot as plt
%matplotlib inline

# https://towardsdatascience.com/beginners-guide-to-loading-image-data-with-pytorch-289c60b7afec
# β-VAE: https://github.com/1Konny/Beta-VAE

In [2]:
from PIL import Image

def pil_loader_rgba(path: str) -> Image.Image:
    """Load the image at ``path`` and flatten it onto a white background.

    The file is decoded as RGBA, alpha-composited over a white RGBA canvas of
    the same size, and the result is converted to a 3-channel RGB image.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The composited image in RGB mode.
    """
    with open(path, 'rb') as handle:
        # convert() forces the decode while the file handle is still open
        rgba = Image.open(handle).convert('RGBA')  # force alpha channel
        white = Image.new('RGBA', rgba.size, (255, 255, 255))
        flattened = Image.alpha_composite(white, rgba).convert('RGB')
    return flattened

# Image augmentation: https://pytorch.org/vision/main/auto_examples/plot_transforms.html#random-transforms
# RandomAffine fills the area uncovered by the shift with `fill` (default 0 = black).
# fill=255 pads with white, matching the white background pil_loader_rgba
# composites onto, without inverting the whole image before and after the affine.
transform = T.Compose([
    T.Resize((128, 128)),
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1),
                   interpolation=T.InterpolationMode.BILINEAR, fill=255),
    # Optional colour augmentation, currently disabled:
    # T.ColorJitter(hue=0.5, saturation=0.1, contrast=0.2),
    T.ToTensor(),
])

img = ImageFolder(root='dataset', loader=pil_loader_rgba, transform=transform)

In [3]:
# Preview dataset items 16-19 in a 2x2 grid. The transform is random, so each
# access shows a freshly augmented view. `arr`/`cls` stay bound to the last
# sample on purpose: the next cell feeds `arr` through the model.
plt.figure(figsize=(6, 6))
for panel, i in enumerate(range(16, 20), start=1):
    ax = plt.subplot(2, 2, panel)
    arr, cls = img[i]
    # move the channel axis last for imshow: (C, H, W) -> (H, W, C)
    ax.imshow(arr.permute(1, 2, 0), vmin=0, vmax=1)
plt.show()

In [4]:
from model import BetaVAE_H as VAE
from Solver import reconstruction_loss, kl_divergence

# Smoke test: build a VAE and push one sample through it to check the forward
# pass runs. `arr` is the last sample drawn in the preview cell above;
# unsqueeze(0) adds a batch dimension of size 1.
model = VAE(nc=3)
model.eval()

# Inference-only check: no_grad avoids building an autograd graph.
with torch.no_grad():
    xrecon, mu, logvar = model(arr.unsqueeze(0))

In [5]:
import wandb
import copy

class Reporter:
    """Average metrics over windows of ``dt`` steps and report each window.

    Each call to :meth:`step` adds the given metric values to running sums.
    After ``dt`` steps the per-step averages are reported and the sums are reset.
    Averages are printed and appended to ``self.record`` when ``local`` is true;
    otherwise they are passed to ``wandb.log``.

    Args:
        dt: Number of steps per reporting window.
        local: If true, print/record averages locally instead of logging to wandb.
    """

    def __init__(self, dt, local=False):
        self.dt = dt
        self.loss_count = {}  # metric name -> running sum over the current window
        self.k = 0.0          # number of steps accumulated in the current window
        self.local = local
        self.record = []      # one dict of averages per reported window (local mode)

    def report(self):
        """Report the current window's per-step averages; no-op if it is empty.

        The running sums are left untouched, so calling this more than once
        (or manually between steps) does not skew later averages.
        """
        if self.k <= 0:
            return
        averages = {name: total / self.k for name, total in self.loss_count.items()}
        if self.local:
            self.record.append(averages)  # fresh dict, independent of the sums
            for name, value in averages.items():
                print(f'{name}: {value}', end='; ')
            print('.')
        else:
            wandb.log(averages)

    def step(self, loss_dict):
        """Accumulate one step of metrics; report and reset every ``dt`` steps.

        Args:
            loss_dict: Mapping of metric name to a numeric value for this step.
        """
        self.k += 1
        for name, value in loss_dict.items():
            self.loss_count[name] = self.loss_count.get(name, 0.0) + value
        if self.k >= self.dt:
            self.report()
            self.k = 0.0
            # keep the keys so metrics missing from a window still report 0
            for name in self.loss_count:
                self.loss_count[name] = 0.0

In [6]:
# Hyper-parameters: beta weights the KL term in the loss, lr is Adam's learning
# rate, dim is passed to the VAE as z_dim.
beta = 5
lr = 1e-3
dim = 64

run_config = {'beta': beta, 'lr': lr, 'dim': dim}
wandb.init(config=run_config, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim).cuda()
optimizer = optim.Adam(model.parameters(), lr=lr)

# average the per-step metrics and report them every 10 optimizer steps
reporter = Reporter(dt=10)

def max_weight(model):
    """Return the largest absolute parameter value in ``model``.

    Args:
        model: Object whose ``parameters()`` yields tensors (e.g. an ``nn.Module``).

    Returns:
        The maximum of ``|p|`` over every element of every parameter, as a
        Python number; ``-1`` if ``parameters()`` yields nothing.
    """
    # Tensor.max() reduces over all elements. The builtin max() iterates only
    # over the first dimension and compares multi-element tensors, which
    # raises for any weight with two or more dimensions.
    per_param = (para.detach().abs().max().item() for para in model.parameters())
    # -1 is the starting value; an entry replaces it only if strictly larger
    return max([-1, *per_param])

# Train for 10000 epochs on loss = reconstruction + beta * KL. Metrics go
# through `reporter`; the whole model object is saved every 1000 epochs.
os.makedirs('./models', exist_ok=True)  # torch.save does not create missing directories

for i in range(10000):
    for k, (x, _) in enumerate(gidata):
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss.backward()
        optimizer.step()
        # 'epco' is the fractional epoch (i + batch_index / num_batches); the key
        # is kept unchanged so it matches previously logged runs.
        reporter.step({'epco': i + k / len(gidata), 'loss': loss.item(), 'loss_rec': rec_loss.item(),
                       'kld': total_kld.item(), 'max_w': max_weight(model)})
    if i % 1000 == 0:
        torch.save(model, f'./models/model_{i}.pth')

In [7]:
# Debug inspection: display the weight of the first module in model.encoder
# (renders the full tensor).
model.encoder[0].weight

Parameter containing:
tensor([[[[ 0.1351,  0.2457,  0.0470, -0.0712],
          [-0.0396,  0.0321, -0.0349,  0.0991],
          [ 0.4036, -0.2128,  0.0321,  0.1213],
          [-0.0840,  0.0621, -0.3310, -0.3118]],

         [[ 0.2368, -0.2397, -0.0726,  0.0765],
          [ 0.1789,  0.0852,  0.0208,  0.0166],
          [-0.1838, -0.4289, -0.0909, -0.0076],
          [-0.0275, -0.1189,  0.2796, -0.1264]],

         [[-0.1004, -0.0921,  0.1749,  0.0333],
          [-0.1626, -0.1281, -0.1731,  0.2038],
          [ 0.0268, -0.0474,  0.1778, -0.1087],
          [-0.2194,  0.1509,  0.2749, -0.2196]]],


        [[[-0.1895, -0.0515,  0.1415,  0.2240],
          [ 0.4438, -0.2663, -0.1410,  0.0297],
          [ 0.4714,  0.0717,  0.2180,  0.0187],
          [-0.0660,  0.3439,  0.4332, -0.0726]],

         [[ 0.1908, -0.1847, -0.0679,  0.2698],
          [-0.3016,  0.0480,  0.2567,  0.0652],
          [-0.0749,  0.0280, -0.0883, -0.3398],
          [ 0.2700, -0.3424,  0.0724, -0.3297]],

      

In [8]:
# Debug inspection: element-wise absolute values of the same weight, to eyeball magnitudes.
abs(model.encoder[0].weight)

tensor([[[[0.1351, 0.2457, 0.0470, 0.0712],
          [0.0396, 0.0321, 0.0349, 0.0991],
          [0.4036, 0.2128, 0.0321, 0.1213],
          [0.0840, 0.0621, 0.3310, 0.3118]],

         [[0.2368, 0.2397, 0.0726, 0.0765],
          [0.1789, 0.0852, 0.0208, 0.0166],
          [0.1838, 0.4289, 0.0909, 0.0076],
          [0.0275, 0.1189, 0.2796, 0.1264]],

         [[0.1004, 0.0921, 0.1749, 0.0333],
          [0.1626, 0.1281, 0.1731, 0.2038],
          [0.0268, 0.0474, 0.1778, 0.1087],
          [0.2194, 0.1509, 0.2749, 0.2196]]],


        [[[0.1895, 0.0515, 0.1415, 0.2240],
          [0.4438, 0.2663, 0.1410, 0.0297],
          [0.4714, 0.0717, 0.2180, 0.0187],
          [0.0660, 0.3439, 0.4332, 0.0726]],

         [[0.1908, 0.1847, 0.0679, 0.2698],
          [0.3016, 0.0480, 0.2567, 0.0652],
          [0.0749, 0.0280, 0.0883, 0.3398],
          [0.2700, 0.3424, 0.0724, 0.3297]],

         [[0.0154, 0.0673, 0.1052, 0.0231],
          [0.0840, 0.1351, 0.1442, 0.0221],
          [0.1285, 0

In [9]:
# Largest absolute value in the first encoder module's weight. The builtin max()
# would iterate over dim 0 and compare multi-element tensors, which raises for a
# multi-dimensional weight; Tensor.max() reduces over all elements instead.
model.encoder[0].weight.detach().abs().max().item()

In [10]:
# Hyper-parameters: beta weights the KL term in the loss, lr is Adam's learning
# rate, dim is passed to the VAE as z_dim.
beta = 5
lr = 1e-3
dim = 64

run_config = {'beta': beta, 'lr': lr, 'dim': dim}
wandb.init(config=run_config, project="Genshin VAE")  # upload args

gidata = data.DataLoader(img, batch_size=16, shuffle=True)
model = VAE(nc=3, z_dim=dim).cuda()
optimizer = optim.Adam(model.parameters(), lr=lr)

# average the per-step metrics and report them every 10 optimizer steps
reporter = Reporter(dt=10)

def max_weight(model):
    """Largest absolute value found in any of ``model``'s parameters.

    Returns ``-1`` when ``model.parameters()`` yields nothing.
    """
    magnitudes = (p.data.abs().max().item() for p in model.parameters())
    # -1 is the starting value; a later entry replaces it only if strictly larger
    return max([-1, *magnitudes])

# Train for 10000 epochs on loss = reconstruction + beta * KL. Metrics go
# through `reporter`; the whole model object is saved every 1000 epochs.
os.makedirs('./models', exist_ok=True)  # torch.save does not create missing directories

for i in range(10000):
    for k, (x, _) in enumerate(gidata):
        optimizer.zero_grad()
        x = x.cuda()
        xrecon, mu, logvar = model(x)
        rec_loss = reconstruction_loss(x, xrecon, distribution='gaussian')
        total_kld, dimension_wise_kld, mean_kld = kl_divergence(mu, logvar)

        loss = rec_loss + beta * total_kld
        loss.backward()
        optimizer.step()
        # 'epco' is the fractional epoch (i + batch_index / num_batches); the key
        # is kept unchanged so it matches previously logged runs.
        reporter.step({'epco': i + k / len(gidata), 'loss': loss.item(), 'loss_rec': rec_loss.item(),
                       'kld': total_kld.item(), 'max_w': max_weight(model)})
    if i % 1000 == 0:
        torch.save(model, f'./models/model_{i}.pth')