<a href="https://colab.research.google.com/github/KazutoYamada/VAEs/blob/main/VAEs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**VAEs**

以下の基本的なVAEの学習用ノートブック（PyTorch）です。


1.   VAE（Variational Autoencoder）
2.   Convolutional VAE
3.   Conditional VAE

学習データには、MNIST (28,28,1)、CIFAR10 (32,32,3)、ラベル付きオリジナルデータ（(32,32,3)にリサイズ）を利用できます。オリジナルデータを使用する際は、ラベルがファイル名で付けられているか（0.2）、フォルダ名で付けられているか（0.3）によって読み込み方法が異なります。


**0. 学習データの準備**

※Convolutional VAEを使用する際は、transformの一次元化（`transforms.Lambda(lambda x: x.view(-1))`）を無効に（コメントアウト）する必要があります。

※画像サイズに合わせて、モデルの入出力サイズなどのパラメータを変更する必要があります。

In [None]:
# Mount Google Drive so that the original-data sections (0.2 / 0.3) can read images from it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
#ライブラリのインポート
import os

import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.io
import skimage.transform
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import datasets, transforms

0.1. MNISTなどのダウンロード可能なデータセットを使用する際

In [None]:
# Which torchvision dataset to train on: "MNIST" (28x28x1) or "CIFAR10" (32x32x3).
dataset_name = "MNIST"
# Flatten every image to a 1-D vector for the fully connected VAE / Conditional VAE.
# Set to False for the Convolutional VAE, which expects (C, H, W) tensors.
flatten = True

transform_steps = [transforms.ToTensor()]
if flatten:
    transform_steps.append(transforms.Lambda(lambda x: x.view(-1)))
transform = transforms.Compose(transform_steps)

dataset_cls = {"MNIST": datasets.MNIST, "CIFAR10": datasets.CIFAR10}[dataset_name]
# Every dataset is stored under the same relative root, ./data/<name>.
data_root = os.path.join('./data', dataset_name)
train_data = dataset_cls(data_root, train=True, download=True, transform=transform)
val_data = dataset_cls(data_root, train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

0.2. オリジナルデータを使用する際（ファイル名でラベル付けしている場合。例：`3.png` → ラベル3）

In [None]:
# Directory holding the labelled images; each file name encodes its label (e.g. "3.png" -> 3).
dir_name = "/content/drive/My Drive/"
# Sort for a deterministic order (os.listdir order is arbitrary, which would make the
# train/val split differ between runs) and keep regular files only, because
# sub-directories cannot be read as images.
image_name_list = sorted(
    name for name in os.listdir(dir_name)
    if os.path.isfile(os.path.join(dir_name, name))
)

# Data preprocessing

class dataset(Dataset):
    """Image files whose label is encoded in the file name.

    The label is the part of the file name before the first ".", parsed as an
    int (e.g. ``"3.png"`` -> ``3``). Files are read from the global ``dir_name``,
    resized to 32x32 and forced to exactly 3 channels.

    Args:
        image_name_list: File names (relative to ``dir_name``) to serve.
        flatten: If True (default), each image is returned as a flat float32
            vector of length 32*32*3 (for the fully connected VAEs). Pass False
            to keep the (3, 32, 32) layout for the Convolutional VAE.
    """

    def __init__(self, image_name_list, flatten=True):
        self.image_name_list = image_name_list
        steps = [transforms.ToTensor()]
        if flatten:
            steps.append(transforms.Lambda(lambda x: x.view(-1)))
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` where ``label`` is a plain int."""
        image_name = self.image_name_list[idx]
        image = skimage.io.imread(os.path.join(dir_name, image_name))
        # skimage resize returns a floating-point image.
        image = skimage.transform.resize(image, (32, 32))
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        if image.shape[2] < 3:
            # Grayscale (optionally with alpha): replicate the first channel to RGB
            # so every sample has the same size.
            image = np.repeat(image[:, :, :1], 3, axis=2)
        # Drop an alpha channel if present, and cast to float32: ToTensor keeps the
        # ndarray dtype, and the models' weights are float32.
        trans_image = self.transform(image[:, :, :3].astype(np.float32))
        # Plain int label (like the torchvision datasets), so batches collate into
        # an int64 tensor usable with F.one_hot and as a list index.
        label = int(image_name.split(".")[0])
        return trans_image, label

# Load the data and split it into training / validation subsets.
train_val_dataset = dataset(image_name_list)

train_rate = 0.7

train_size = int(train_rate * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

# A seeded generator makes the split reproducible across runs.
train_data, val_data = random_split(
    train_val_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

# Build the loaders from the split subsets (not from the Dataset class itself).
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, drop_last=False)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, drop_last=False)

0.3. オリジナルデータを使用する際（フォルダ名でラベル付けしている場合。クラスごとのサブフォルダに画像を格納する`ImageFolder`形式）

In [None]:
# Root directory in ImageFolder layout: one sub-folder per class, e.g.
# <dir_name>/cat/xxx.png, <dir_name>/dog/yyy.png. The sub-folder names define the labels.
dir_name = "/content/drive/My Drive/"

# Preprocessing: resize to 32x32, convert to a tensor and (optionally) flatten.
# Set flatten = False for the Convolutional VAE, which expects (C, H, W) tensors.
flatten = True
transform_steps = [transforms.Resize((32, 32)), transforms.ToTensor()]
if flatten:
    transform_steps.append(transforms.Lambda(lambda x: x.view(-1)))
data_transform = transforms.Compose(transform_steps)

# Load the data.
train_val_dataset = datasets.ImageFolder(dir_name, data_transform)

# Split into training / validation subsets (same procedure as section 0.2).
train_rate = 0.7
train_size = int(train_rate * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
train_data, val_data = random_split(
    train_val_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, drop_last=False)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, drop_last=False)

1. VAE

In [None]:
# Hyperparameters
epochs = 5
lr = 0.001
beta = 1          # weight of the KL term: loss = beta * KL + recon
z_dim = 200       # latent dimension
class_size = 10   # number of label classes for the Conditional VAE (MNIST / CIFAR10 have 10)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
batch_size = 32

In [None]:
# Model definition

def torch_log(x):
    """Return log(x) with x floored at 1e-10, so zero inputs give a finite value instead of -inf."""
    floored = x.clamp(min=1e-10)
    return floored.log()

class VAE(nn.Module):
    """Fully connected Variational Autoencoder operating on flattened images.

    Args:
        z_dim: Latent dimension.
        x_dim: Length of the flattened input vector. Defaults to 32*32*3
            (flattened 32x32 RGB images); pass 28*28 for flattened MNIST.
        h_dim: Width of the single hidden layer in encoder and decoder.
    """

    def __init__(self, z_dim, x_dim=32*32*3, h_dim=1000):
        super().__init__()

        # Encoder
        self.ln_en = nn.Linear(x_dim, h_dim)
        self.ln_en_mean = nn.Linear(h_dim, z_dim)
        self.ln_en_std = nn.Linear(h_dim, z_dim)

        # Decoder
        self.ln_de1 = nn.Linear(z_dim, h_dim)
        self.ln_de2 = nn.Linear(h_dim, x_dim)

    def encoder(self, x):
        """Map ``x`` of shape (batch, x_dim) to the posterior mean and std, each (batch, z_dim)."""
        x = F.relu(self.ln_en(x))
        mean = self.ln_en_mean(x)
        std = F.softplus(self.ln_en_std(x))  # softplus keeps std positive
        return mean, std

    def re_trick(self, mean, std):
        """Reparameterization trick: z = mean + std * eps in training mode, the mean in eval mode."""
        if self.training:
            # randn_like samples on mean's device and dtype; no global `device` needed.
            epsilon = torch.randn_like(mean)
            return mean + std * epsilon
        return mean

    def decoder(self, z):
        """Map latent ``z`` to a reconstruction in [0, 1] (sigmoid output)."""
        x = F.relu(self.ln_de1(z))
        return torch.sigmoid(self.ln_de2(x))

    def forward(self, x):
        """Return ``(reconstruction, z)`` for input ``x``."""
        mean, std = self.encoder(x)
        z = self.re_trick(mean, std)
        return self.decoder(z), z

    def loss(self, x):
        """Return ``(KL, recon)``, each summed per sample and averaged over the batch.

        KL is the closed-form KL divergence between N(mean, std^2) and N(0, I);
        recon is the sum of squared errors. Uses the module-level ``torch_log``.
        """
        mean, std = self.encoder(x)
        KL = -0.5 * torch.mean(torch.sum(1 + torch_log(std**2) - mean**2 - std**2, dim=1))

        z = self.re_trick(mean, std)
        y = self.decoder(z)
        # Bernoulli-likelihood alternative:
        # recon = -torch.mean(torch.sum(x*torch_log(y)+(1-x)*torch_log(1-y), dim=1))
        recon = torch.mean(torch.sum((y - x)**2, dim=1))
        return KL, recon

In [None]:
# Training
model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

epoch_train_losses = []
epoch_val_losses = []

for epoch in range(epochs):

    batch_train_loss = []
    batch_val_loss = []

    model.train()
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        KL, recon = model.loss(x)
        loss = beta * KL + recon
        loss.backward()
        optimizer.step()
        # .item() yields a Python float and releases the reference to the graph.
        batch_train_loss.append(loss.item())
    epoch_train_losses.append(np.average(batch_train_loss))

    # Validation: gradients are never used, so disable autograd to save memory and time.
    model.eval()
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            KL, recon = model.loss(x)
            loss = beta * KL + recon
            batch_val_loss.append(loss.item())
    epoch_val_losses.append(np.average(batch_val_loss))

    print("{} EPOCH: Train_Loss -> {}, Val_Loss -> {}".format(epoch, epoch_train_losses[-1], epoch_val_losses[-1]))

plt.plot(range(len(epoch_train_losses)), epoch_train_losses, label="Train_Loss")
plt.plot(range(len(epoch_val_losses)), epoch_val_losses, label="Val_Loss")
plt.legend()

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()


In [None]:
# Visual check: validation originals (first figure) and their reconstructions (second figure).

# Image size of the flattened vectors: (28, 28) for MNIST; use (32, 32) for 32x32 data.
img_h, img_w = 28, 28


def to_image(v):
    """Reshape a flat (C*H*W) tensor into an (H, W) or (H, W, C) numpy array for imshow."""
    return v.view(-1, img_h, img_w).permute(1, 2, 0).squeeze().cpu().numpy()


# Originals
fig = plt.figure(figsize=(3, 3))
for i in range(9):
    x, _ = val_data[i]
    ax = fig.add_subplot(3, 3, i + 1)
    ax.imshow(to_image(x), cmap="gray")  # cmap is ignored for RGB images

# Reconstructions (inference only, so no autograd graph)
model.eval()

fig = plt.figure(figsize=(3, 3))
with torch.no_grad():
    for i in range(9):
        x, _ = val_data[i]
        y, z = model(x.unsqueeze(0).to(device))
        ax = fig.add_subplot(3, 3, i + 1)
        ax.imshow(to_image(y), cmap="gray")

In [None]:
# Inspect the latent space: encode the validation set and project z to 2-D.

model.eval()

z_list = []
t_list = []

# Encode in mini-batches without autograd instead of one sample at a time.
with torch.no_grad():
    for x, t in val_loader:
        _, z = model(x.to(device))
        z_list.append(z.cpu().numpy())
        # Labels as plain ints so they can index the colour list.
        t_list.extend(int(v) for v in t.view(-1))

z_val = np.concatenate(z_list)

colors = ['khaki', 'lightgreen', 'cornflowerblue', 'violet', 'sienna', 'darkturquoise', 'slateblue', 'orange', 'darkcyan', 'tomato']


def plot_latent(z_2d, title):
    """Scatter 2-D latent points coloured by label, with one legend entry per class."""
    plt.figure(figsize=(5, 5))
    plt.scatter(*z_2d, s=0.7, c=[colors[t] for t in t_list])
    for i in range(len(colors)):
        plt.scatter([], [], c=colors[i], label=i)
    plt.legend()
    plt.title(title)


# Linear dimensionality reduction (PCA)
plot_latent(PCA(n_components=2).fit_transform(z_val).T, "PCA")

# Non-linear dimensionality reduction (t-SNE)
plot_latent(TSNE(n_components=2).fit_transform(z_val).T, "t-SNE")

**2. Convolutional VAE**

In [None]:
# Model definition

def torch_log(x):
    # Numerically safe log: inputs below 1e-10 (including 0) are raised to
    # 1e-10 first, so the result stays finite.
    return torch.log(x.clamp(1e-10))

class VAE(nn.Module):
    """Convolutional VAE for 3-channel 32x32 images.

    Input ``x`` has shape (batch, 3, 32, 32), i.e. it is NOT flattened.
    Encoder: two conv blocks with 2x2 max-pooling (32 -> 16 -> 8), an unpadded
    3x3 conv (8 -> 6), then fully connected layers to the latent mean / std.
    Decoder: fully connected layers to a (128, 9, 9) map, two bilinear 2x
    upsamplings (9 -> 18 -> 36) with conv blocks, and two unpadded 3x3 convs
    (36 -> 34 -> 32) ending in a sigmoid.

    Args:
        z_dim: Latent dimension.
    """

    def __init__(self, z_dim):
        super().__init__()

        # Encoder
        self.convblock1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),  # expects (3, 32, 32) images
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
        )
        self.convblock2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2d1 = nn.Conv2d(128, 128, 3)

        self.maxpooling = nn.MaxPool2d(2)

        self.ln_en = nn.Linear(6*6*128, 1000)
        self.ln_en_mean = nn.Linear(1000, z_dim)
        self.ln_en_std = nn.Linear(1000, z_dim)

        # Decoder
        self.ln_de1 = nn.Linear(z_dim, 1000)
        self.ln_de2 = nn.Linear(1000, 9*9*128)

        self.conv2d2 = nn.Conv2d(128, 128, 3, padding=1)

        self.convblock3 = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.convblock4 = nn.Sequential(
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3),
            nn.Sigmoid()
        )
        # align_corners=False is the default; stated explicitly for clarity.
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def encoder(self, x):
        """Map (batch, 3, 32, 32) images to the posterior mean and std, each (batch, z_dim)."""
        x = self.convblock1(x)
        x = self.maxpooling(x)
        x = self.convblock2(x)
        x = self.maxpooling(x)
        x = self.conv2d1(x)
        # flatten(1) keeps the batch dimension, so a wrongly sized input raises in
        # ln_en instead of being silently reshaped into a different batch size.
        x = torch.flatten(x, 1)
        # ReLU between ln_en and the mean/std heads (as in the fully connected VAE);
        # without it the two linear layers collapse into a single affine map.
        x = F.relu(self.ln_en(x))
        mean = self.ln_en_mean(x)
        std = F.softplus(self.ln_en_std(x))  # softplus keeps std positive
        return mean, std

    def re_trick(self, mean, std):
        """Reparameterization trick: z = mean + std * eps in training mode, the mean in eval mode."""
        if self.training:
            # randn_like samples on mean's device and dtype; no global `device` needed.
            epsilon = torch.randn_like(mean)
            return mean + std * epsilon
        return mean

    def decoder(self, z):
        """Map latent ``z`` (batch, z_dim) to images (batch, 3, 32, 32) in [0, 1]."""
        x = F.relu(self.ln_de1(z))
        x = F.relu(self.ln_de2(x))
        x = x.view(-1, 128, 9, 9)
        x = self.conv2d2(x)
        x = self.upsample(x)    # 9 -> 18
        x = self.convblock3(x)
        x = self.upsample(x)    # 18 -> 36
        x = self.convblock4(x)  # 36 -> 34 -> 32
        return x

    def forward(self, x):
        """Return ``(reconstruction, z)`` for input ``x``."""
        mean, std = self.encoder(x)
        z = self.re_trick(mean, std)
        return self.decoder(z), z

    def loss(self, x):
        """Return ``(KL, recon)``, each summed over all elements and divided by the batch size.

        Uses the module-level ``torch_log``.
        """
        n = x.size(0)
        mean, std = self.encoder(x)
        # Divide by the actual batch size: the last batch of an epoch can be smaller
        # than the configured batch size when drop_last=False.
        KL = -0.5 * torch.sum(1 + torch_log(std**2) - mean**2 - std**2) / n

        z = self.re_trick(mean, std)
        y = self.decoder(z)
        # Bernoulli-likelihood alternative:
        # recon = -torch.sum(x*torch_log(y) + (1-x)*torch_log(1-y)) / n
        recon = torch.sum((y - x)**2) / n
        return KL, recon

In [None]:
# Training (Convolutional VAE). The data must be loaded WITHOUT flattening,
# so that batches have shape (batch, 3, 32, 32).
model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

epoch_train_losses = []
epoch_val_losses = []

for epoch in range(epochs):

    batch_train_loss = []
    batch_val_loss = []

    model.train()
    for x, _ in train_loader:
        x = x.to(device)
        optimizer.zero_grad()
        KL, recon = model.loss(x)
        loss = beta * KL + recon
        loss.backward()
        optimizer.step()
        # .item() yields a Python float and releases the reference to the graph.
        batch_train_loss.append(loss.item())
    epoch_train_losses.append(np.average(batch_train_loss))

    # Validation: gradients are never used, so disable autograd to save memory and time.
    model.eval()
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            KL, recon = model.loss(x)
            loss = beta * KL + recon
            batch_val_loss.append(loss.item())
    epoch_val_losses.append(np.average(batch_val_loss))

    print("{} EPOCH: Train_Loss -> {}, Val_Loss -> {}".format(epoch, epoch_train_losses[-1], epoch_val_losses[-1]))

plt.plot(range(len(epoch_train_losses)), epoch_train_losses, label="Train_Loss")
plt.plot(range(len(epoch_val_losses)), epoch_val_losses, label="Val_Loss")
plt.legend()

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()


In [None]:
# Visual check: validation originals (first figure) and their reconstructions
# (second figure). Samples are (3, 32, 32) tensors, since the Convolutional VAE
# uses unflattened data.

# Originals
fig = plt.figure(figsize=(3, 3))
for i in range(9):
    x, _ = val_data[i]
    im = x.view(-1, 32, 32).permute(1, 2, 0).squeeze().numpy()

    ax = fig.add_subplot(3, 3, i + 1)
    ax.imshow(im)

# Reconstructions (inference only, so no autograd graph)
model.eval()

fig = plt.figure(figsize=(3, 3))
with torch.no_grad():
    for i in range(9):
        x, _ = val_data[i]
        y, z = model(x.unsqueeze(0).to(device))
        im = y.view(-1, 32, 32).permute(1, 2, 0).cpu().squeeze().numpy()

        ax = fig.add_subplot(3, 3, i + 1)
        ax.imshow(im)

**3. Conditional VAE**

In [None]:
# Model definition

def torch_log(x):
    """Natural log with a 1e-10 floor on the input (avoids log(0) = -inf)."""
    safe_x = torch.clamp(x, min=1e-10)
    result = torch.log(safe_x)
    return result

class VAE(nn.Module):
    """Conditional VAE: the encoder and decoder both receive a one-hot class label.

    Args:
        z_dim: Latent dimension.
        class_size: Number of classes used for the one-hot label (default 10).
        x_dim: Length of the flattened input vector (default 28*28, flattened MNIST).
    """

    def __init__(self, z_dim, class_size=10, x_dim=28*28):
        super().__init__()
        self.class_size = class_size

        # Encoder: input is [x, one_hot(label)]
        self.ln_en1 = nn.Linear(x_dim + class_size, 1000)
        self.ln_en_mean = nn.Linear(1000, z_dim)
        self.ln_en_std = nn.Linear(1000, z_dim)

        # Decoder: input is [z, one_hot(label)]
        self.ln_de1 = nn.Linear(z_dim + class_size, 1000)
        self.ln_de3 = nn.Linear(1000, x_dim)

    def _one_hot(self, label):
        """One-hot encode ``label`` as float32 of shape (batch, class_size).

        Accepts labels of shape (batch,) or (batch, 1), including float tensors
        holding integral values; they are flattened and cast to int64 because
        F.one_hot requires an integer tensor.
        """
        return F.one_hot(label.view(-1).long(), num_classes=self.class_size).to(torch.float32)

    def encoder(self, x):
        """Map ``[x, one_hot]`` to the posterior mean and std, each (batch, z_dim)."""
        x = F.relu(self.ln_en1(x))
        mean = self.ln_en_mean(x)
        std = F.softplus(self.ln_en_std(x))  # softplus keeps std positive
        return mean, std

    def re_trick(self, mean, std):
        """Reparameterization trick: z = mean + std * eps in training mode, the mean in eval mode."""
        if self.training:
            # randn_like samples on mean's device and dtype; no global `device` needed.
            epsilon = torch.randn_like(mean)
            return mean + std * epsilon
        return mean

    def decoder(self, z):
        """Map ``[z, one_hot]`` to a reconstruction in [0, 1] (sigmoid output)."""
        x = F.relu(self.ln_de1(z))
        return torch.sigmoid(self.ln_de3(x))

    def forward(self, x, label):
        """Return ``(reconstruction, z)`` for input ``x`` conditioned on ``label``."""
        one_hot_label = self._one_hot(label)
        mean, std = self.encoder(torch.cat((x, one_hot_label), dim=1))
        z = self.re_trick(mean, std)
        y = self.decoder(torch.cat((z, one_hot_label), dim=1))
        return y, z

    def loss(self, x, label):
        """Return ``(KL, recon)``, each summed per sample and averaged over the batch.

        Uses the module-level ``torch_log``.
        """
        one_hot_label = self._one_hot(label)

        mean, std = self.encoder(torch.cat((x, one_hot_label), dim=1))
        KL = -0.5 * torch.mean(torch.sum(1 + torch_log(std**2) - mean**2 - std**2, dim=1))

        z = self.re_trick(mean, std)
        y = self.decoder(torch.cat((z, one_hot_label), dim=1))

        recon = torch.mean(torch.sum((y - x)**2, dim=1))
        # Bernoulli-likelihood alternative:
        # recon = -torch.mean(torch.sum(x*torch_log(y)+(1-x)*torch_log(1-y), dim=1))
        return KL, recon

In [None]:
# Training (Conditional VAE)
model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

epoch_train_losses = []
epoch_val_losses = []

for epoch in range(epochs):

    batch_train_loss = []
    batch_val_loss = []

    model.train()
    for x, label in train_loader:
        x = x.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        KL, recon = model.loss(x, label)
        loss = beta * KL + recon
        loss.backward()
        optimizer.step()
        # .item() yields a Python float and releases the reference to the graph.
        batch_train_loss.append(loss.item())
    epoch_train_losses.append(np.average(batch_train_loss))

    # Validation: gradients are never used, so disable autograd to save memory and time.
    model.eval()
    with torch.no_grad():
        for x, label in val_loader:
            x = x.to(device)
            label = label.to(device)

            KL, recon = model.loss(x, label)
            loss = beta * KL + recon
            batch_val_loss.append(loss.item())
    epoch_val_losses.append(np.average(batch_val_loss))

    print("{} EPOCH: Train_Loss -> {}, Val_Loss -> {}".format(epoch, epoch_train_losses[-1], epoch_val_losses[-1]))

plt.plot(range(len(epoch_train_losses)), epoch_train_losses, label="Train_Loss")
plt.plot(range(len(epoch_val_losses)), epoch_val_losses, label="Val_Loss")
plt.legend()

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()


In [None]:
# Visual check for the Conditional VAE (28x28 single-channel images).

# Originals
fig = plt.figure(figsize=(3, 3))
for i in range(9):
    x, _ = val_data[i]
    im = x.view(-1, 28, 28).permute(1, 2, 0).squeeze().numpy()

    ax = fig.add_subplot(3, 3, i + 1)
    ax.imshow(im, cmap="gray")

# Reconstructions, each conditioned on the image's own label
model.eval()

fig = plt.figure(figsize=(3, 3))
with torch.no_grad():
    for i in range(9):
        x, label = val_data[i]
        # Works whether the dataset yields an int or a tensor label; F.one_hot needs int64.
        label = torch.as_tensor(label).view(-1).long().to(device)
        y, z = model(x.unsqueeze(0).to(device), label)
        im = y.view(-1, 28, 28).permute(1, 2, 0).cpu().squeeze().numpy()

        ax = fig.add_subplot(3, 3, i + 1)
        ax.imshow(im, cmap="gray")

# Label-conditioned generation: decode random z ~ N(0, I) together with a chosen label.
target_label = 5
label = torch.tensor([target_label])
one_hot_label = F.one_hot(label, num_classes=class_size).to(torch.float32).to(device)

model.eval()

fig = plt.figure(figsize=(3, 3))
with torch.no_grad():
    for i in range(9):
        z = torch.randn((1, z_dim), device=device)
        z_label = torch.cat((z, one_hot_label), dim=1)
        y = model.decoder(z_label)
        im = y.view(-1, 28, 28).permute(1, 2, 0).cpu().squeeze().numpy()

        ax = fig.add_subplot(3, 3, i + 1)
        ax.imshow(im, cmap="gray")