In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from torchvision import datasets, transforms
import torchvision.transforms as transforms
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision.utils import save_image

from torchsummary import summary

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import math
from PIL import Image
from IPython.display import display
import glob

In [2]:
# Use the current CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [3]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for square 3-channel images.

    Encoder: three stride-2 convolutions (3 -> 6 -> 12 -> 24 channels), each
    halving height and width. The result is flattened and projected by two
    linear heads to the mean and log-variance of the latent distribution.

    Decoder: a linear layer back to the flattened feature map, followed by
    three stride-2 transposed convolutions that each double height and width.
    A final sigmoid keeps outputs in [0, 1].

    Args:
        latent_size: Dimensionality of the latent vector ``z``.
        image_size: Height/width of the (square) input images. Must be a
            positive multiple of 8, because of the three halving stages, so
            that the reconstruction has the same size as the input. The
            default of 160 reproduces the original hard-coded 20x20 feature
            map (160 / 2**3 == 20), so existing checkpoints still load.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 8.
    """

    # Channel count of the deepest feature map (output of ``l3``, input of ``l4``).
    _FEATURE_CHANNELS = 24
    # Number of stride-2 down/upsampling stages in the encoder and decoder.
    _NUM_STAGES = 3

    def __init__(self, latent_size=15, image_size=160):
        super().__init__()

        downscale = 2 ** self._NUM_STAGES
        if image_size <= 0 or image_size % downscale != 0:
            raise ValueError(
                f"image_size must be a positive multiple of {downscale}, got {image_size}"
            )

        self.latent_size = latent_size
        self.image_size = image_size
        # Spatial size of the deepest feature map (20 for the default 160).
        self.feature_hw = image_size // downscale
        flat_size = self._FEATURE_CHANNELS * self.feature_hw * self.feature_hw

        # Encoder: each conv (kernel 4, stride 2, padding 1) halves H and W.
        self.l1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=2, padding=1)
        self.l2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=4, stride=2, padding=1)
        self.l3 = nn.Conv2d(in_channels=12, out_channels=self._FEATURE_CHANNELS, kernel_size=4, stride=2, padding=1)

        # Linear heads for the latent mean (l31) and log-variance (l32).
        self.l31 = nn.Linear(flat_size, self.latent_size)
        self.l32 = nn.Linear(flat_size, self.latent_size)

        # Projects a latent vector back to the flattened deepest feature map.
        self.f = nn.Linear(self.latent_size, flat_size)

        # Decoder: each transposed conv (kernel 4, stride 2, padding 1) doubles H and W.
        self.l4 = nn.ConvTranspose2d(in_channels=self._FEATURE_CHANNELS, out_channels=12, kernel_size=4, stride=2, padding=1)
        self.l5 = nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=4, stride=2, padding=1)
        self.l6 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=4, stride=2, padding=1)

    def encoder(self, x_in):
        """Encode images into latent Gaussian parameters.

        Args:
            x_in: Tensor of shape (batch, 3, image_size, image_size).

        Returns:
            Tuple ``(mu, log_var)``, each of shape (batch, latent_size).
        """
        h = F.relu(self.l1(x_in))
        h = F.relu(self.l2(h))
        h = F.relu(self.l3(h))

        # Flatten all non-batch dimensions for the linear heads.
        h = h.view(h.size(0), -1)

        return self.l31(h), self.l32(h)

    def decoder(self, z):
        """Decode latent vectors into images with values in [0, 1].

        Args:
            z: Tensor of shape (batch, latent_size).

        Returns:
            Tensor of shape (batch, 3, image_size, image_size).
        """
        z = self.f(z)
        # Undo the encoder's flatten: recover the deepest feature map layout.
        z = z.view(z.size(0), self._FEATURE_CHANNELS, self.feature_hw, self.feature_hw)

        z = F.relu(self.l4(z))
        z = F.relu(self.l5(z))
        z = torch.sigmoid(self.l6(z))

        return z

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick).

        Note: this always samples, including in ``eval()`` mode.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return torch.add(eps.mul(std), mu)

    def forward(self, x_in):
        """Encode, sample a latent, and decode.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x_in)
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [4]:
# Latent size spelled out explicitly (15, the class default); it must match the
# checkpoint restored next, since load_state_dict rejects mismatched shapes.
vae = VAE(latent_size=15)

In [5]:
# Restore the trained weights. map_location remaps tensors in the checkpoint
# onto the configured `device`. Without it, a checkpoint saved from a GPU fails
# to load on a machine where CUDA is unavailable.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (newer PyTorch versions accept weights_only=True to restrict this).
vae.load_state_dict(torch.load('./model_weight/wf.pt', map_location=device))

<All keys matched successfully>