In [6]:
from __future__ import  print_function
import torch
from torch import nn, optim
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision
from torchvision import transforms
import cv2
from torchvision.utils import save_image

In [2]:
class MyDataset(Dataset):
    """Dataset backed by the first positional array of an ``.npz`` archive.

    The array stored under the key ``'arr_0'`` is read fully into memory and
    wrapped as a tensor; item ``i`` of the dataset is row ``i`` of that array.

    Args:
        data_npz: path (or file-like object) accepted by ``np.load``.

    Attributes:
        len:  size of the array's first axis (number of samples).
        dim:  size of the array's second axis.
        data: tensor sharing memory with the loaded array (same dtype).
    """

    def __init__(self, data_npz):
        # NpzFile keeps the underlying archive open until it is closed, so
        # read the array inside a context manager to release the handle.
        with np.load(data_npz) as archive:
            x_ww = archive['arr_0']
        self.len = x_ww.shape[0]
        self.dim = x_ww.shape[1]
        # NOTE(review): dtype is kept as stored; the model presumably expects
        # float32 -- confirm against the pre-processing step.
        self.data = torch.from_numpy(x_ww)

    def __getitem__(self, index):
        """Return the sample (row) at position ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.len

In [3]:
# Location of the pre-processed archive, relative to the notebook.
data_path = '../x_ww_bw_50176_pre-processed.npz'
datset = MyDataset(data_path)

In [4]:
# Hold-out configuration used when splitting the dataset indices.
validation_split = 0.2   # fraction of samples reserved for validation
shuffle_dataset = True   # shuffle the indices before splitting
random_seed = 42         # seed given to NumPy before shuffling

In [5]:
# Enumerate every sample index, optionally shuffle them reproducibly, then
# take the leading `validation_split` fraction for validation and the rest
# for training.
dataset_size = len(datset)
indices = list(range(0, dataset_size))
split = int(np.floor(dataset_size * validation_split))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
val_indices = indices[:split]
train_indices = indices[split:]

In [7]:
# One random-order sampler per index subset (training first, then validation).
train_sampler, val_sampler = (
    SubsetRandomSampler(subset) for subset in (train_indices, val_indices)
)

In [8]:
# Mini-batches of 32 drawn through the training sampler, with 2 loader workers.
train_loader = DataLoader(
    dataset=datset,
    batch_size=32,
    sampler=train_sampler,
    num_workers=2,
)

In [9]:
# Mini-batches of 32 drawn through the validation sampler, with 2 loader workers.
test_loader = DataLoader(
    dataset=datset,
    batch_size=32,
    sampler=val_sampler,
    num_workers=2,
)

In [10]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when tensors are moved.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent sample
    back to per-element values in (0, 1) through a sigmoid.

    Args:
        input_dim:  number of elements in one flattened sample
                    (default 50176 = 224 * 224).
        hidden_dim: width of the single hidden layer on each side.
        latent_dim: dimensionality of the latent code.
    """

    def __init__(self, input_dim=50176, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``.

        The noise must be standard normal (``randn_like``); uniform noise
        would bias the sample and contradict the Gaussian KL term in the loss.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x``, encode, sample and decode.

        Returns:
            ``(reconstruction, mu, logvar)`` where the reconstruction has
            shape ``(batch, input_dim)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

In [12]:
# Build the VAE, then move its parameters onto the selected device.
model = VAE()
model = model.to(device)

In [13]:
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [14]:
def loss_func(recon_x, x, mu, logvar):
    """Return the summed VAE loss for one batch.

    The loss is the binary cross-entropy between the reconstruction and the
    input (flattened to 50176 elements per sample, summed over everything)
    plus the KL divergence of the diagonal Gaussian described by ``mu`` and
    ``logvar`` from a standard normal.
    """
    target = x.view(-1, 50176)
    reconstruction_term = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_inner = torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    kl_term = -0.5 * kl_inner

    return reconstruction_term + kl_term

In [17]:
def train(epoch):
    """Run one training epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``loss_func``. Progress is printed every 100 batches and a
    per-sample average is printed once the epoch finishes.

    Args:
        epoch: epoch number, used only in the printed messages.

    Returns:
        The epoch's average loss per training sample.
    """
    model.train()
    train_loss = 0
    # The loader draws through a sampler, so it only visits a subset of
    # `train_loader.dataset`; the sampler length is the real sample count.
    n_samples = len(train_loader.sampler)
    for batch_idx, data in enumerate(train_loader, 0):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_func(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('train epoch: {} [{}/{} ({:.2f}%)]\tLoss:{:.3f}'.format(
                    epoch, batch_idx*len(data), n_samples,
                    100.* batch_idx/len(train_loader), loss.item()/len(data)))

    # loss_func sums over the batch, so divide by the number of samples seen.
    avg_loss = train_loss / max(n_samples, 1)
    print('=====> Epoch:{} Average Loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

def test(epoch):
    """Evaluate the model on ``test_loader`` without updating weights.

    Uses the notebook-level ``model``, ``test_loader``, ``device`` and
    ``loss_func``. For the first batch, up to 8 inputs and their
    reconstructions are written side by side to
    ``results/reconstruction_<epoch>.png``.

    Args:
        epoch: epoch number, used in the image file name.

    Returns:
        The average loss per held-out sample.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        # Evaluate on the held-out loader, not the training one.
        for batch_idx, data in enumerate(test_loader, 0):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_func(recon_batch, data, mu, logvar).item()
            if batch_idx == 0:
                n = min(data.size(0), 8)
                # -1 lets the batch axis follow the actual batch size, which
                # may be smaller than the loader's nominal 32.
                comparison = torch.cat([data.view(-1, 1, 224, 224)[:n],
                                        recon_batch.view(-1, 1, 224, 224)[:n]])
                out_dir = 'results'
                if not os.path.isdir(out_dir):
                    os.makedirs(out_dir)
                save_image(comparison.cpu(),
                           os.path.join(out_dir, 'reconstruction_{}.png'.format(epoch)),
                           nrow=n)

    # The sampler length is the number of held-out samples actually visited.
    n_samples = len(test_loader.sampler)
    test_loss /= max(n_samples, 1)
    print('=====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss
            

In [18]:
# Train for epochs 1 through 9 (the upper bound of range is exclusive).
epoch_range = range(1, 10)
for epoch in epoch_range:
    train(epoch)

