In [None]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os
import numpy as np
from torchsummary import summary #for summary
import matplotlib.pyplot as plt

if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed = 23

In [None]:
def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x
  

num_epochs = 100
batch_size = 32
learning_rate = 1e-3

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5), (0.5,0.5))
])

dataset = MNIST('./data', transform=img_transform, download=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
class encoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(encoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )
        
    def forward(self, data):
        return self.encoder(data)
    
class decoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(decoder, self).__init__()
        self.decoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.ReLU()
        )
    
    def forward(self,data):
        return self.decoder(data)
    

class VariationalAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(VariationalAutoencoder, self).__init__()
        self.input_size = 100
        self.output_size = 8
        self.encoder_model = encoder(input_size, hidden_size, output_size).to(device)
        self.decoder_model  = decoder(self.output_size, hidden_size, input_size).to(device)
        self.fc1 = nn.Linear(self.input_size, self.output_size).to(device)
        self.fc2 = nn.Linear(self.input_size, self.output_size).to(device)
        
    def forward(self, data):
        x = self.encoder_model(data)
        encoder_out = x.to(device)
        self.mean_out = self.fc1(encoder_out)
        self.log_sigma = self.fc1(encoder_out)
        self.std_div = torch.exp(self.log_sigma)
        noraml_values = torch.from_numpy(np.random.normal(0,1,size=self.std_div.size())).float().to(device)
        latent_out = self.std_div*Variable(noraml_values, requires_grad = False) + self.mean_out
        return self.decoder_model(latent_out.float().to(device))
    

# encoder_model = encoder(28*28, 100, 100).to(device)
# decoder_model = decoder(100,100,28*28).to(device)

# summary(encoder_model, (1,28*28))
# summary(decoder_model, (100,100))

vae = VariationalAutoencoder(28*28, 100, 100).to(device)

In [16]:
def latent_loss(mean, std_div, beta):
    mean_sq = mean*mean
    std_div_sq = std_div*std_div
    return beta*torch.mean(mean_sq + std_div_sq + torch.log(std_div) - 1)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

In [17]:
# test = img[0].reshape([1,1,28,28]).to(device)
# output = vae.forward(test)
# plt.imshow(output.reshape([28,28]).cpu().detach().numpy())

In [18]:
num_epochs = 100
beta = 0.5
input_dim = 28*28

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = Variable(img.resize_(batch_size, input_dim)).to(device)
        
        # ===================forward=====================
        output = vae.forward(img)
        ll = latent_loss(vae.mean_out, vae.std_div, beta)
        loss = criterion(output, img) + ll
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch+1, num_epochs, loss.item()))
    if epoch % 2 == 0:
        pic = to_img(output.cpu().data)
        save_image(pic, './dc_img/image_{}.png'.format(epoch))

epoch [1/100], loss:0.4409
epoch [2/100], loss:0.4390
epoch [3/100], loss:0.4442
epoch [4/100], loss:0.4430
epoch [5/100], loss:0.4393
epoch [6/100], loss:0.4428
epoch [7/100], loss:0.4435
epoch [8/100], loss:0.4414
epoch [9/100], loss:0.4467
epoch [10/100], loss:0.4437
epoch [11/100], loss:0.4446
epoch [12/100], loss:0.4414
epoch [13/100], loss:0.4458
epoch [14/100], loss:0.4493
epoch [15/100], loss:0.4451
epoch [16/100], loss:0.4429
epoch [17/100], loss:0.4443
epoch [18/100], loss:0.4461
epoch [19/100], loss:0.4479


KeyboardInterrupt: 

In [None]:
data, _ = next(iter(dataloader))
