In [4]:
from __future__ import print_function

import os
import sys
import glob
import h5py
import numpy as np
import math


import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset , DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorboard
import tensorboardX
from torch.utils.tensorboard import SummaryWriter


from log import Logger
from data import  trainDataset, testDataset, trainlabelDataset,testlabelDataset
from util import r2, mse, rmse, mae, pp_mse, pp_rmse, pp_mae
from model import autoencoder_999, autoencoder_333

def to_img(x):   # image size 
    x = x.view(x.size(0), 1, 64, 64)
    return x

In [7]:
class vae_501(nn.Module):   # 
    def __init__(self):            #  1x 64 x 64 
        
        super(vae_501, self).__init__()
        
        self.fc11 = nn.Linear(14*14, 14)
        self.fc12 = nn.Linear(14*14, 14)

        self.fc21 = nn.Linear(14, 14*14)
        
        self.enc = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=1, padding=1),  # 64 * 64 * 64  
            nn.BatchNorm2d(64), 
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),   # 64 * 31 * 31 
            
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 128 *16 * 16 
            nn.ReLU(True),
            nn.BatchNorm2d(128), 
            nn.MaxPool2d(2, stride=1),  # 128 * 15 * 15 
            
            nn.Conv2d(128, 64, 3, stride=1, padding=1),  # b,  * 15 * 15 
            nn.ReLU(True),
            nn.BatchNorm2d(64), 
            nn.MaxPool2d(2, stride=1),  # b, 256, 14, 14 
            
            nn.Conv2d(64, 1, 3, stride=1, padding=1),  # b, 1  x 14, 14 
            nn.BatchNorm2d(1), 
            nn.ReLU(True)
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(1, 64, 3, stride=2),  # b, 2,  
            nn.BatchNorm2d(64), 
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 128, 3, stride=2),  # b, 8, 55, 55 
            nn.BatchNorm2d(128), 
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, stride=1),  # b, 16, 
            nn.BatchNorm2d(64), 
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, stride=1),  # b, 1,  64 x 64 
        )

        
        
    def encoder(self, x):
        h1=self.enc(x)
      #  print("h1", h1.shape)
        h2=h1.view(-1,14*14)
        
        return  self.fc11(h2), self.fc12(h2)
        
    def reparametrize(self, mu, logvar):  # mu, sigma --> mu + sigma * N(0,1)
        std = logvar.mul(0.5).exp_()
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)
    
        
        
    def decoder (self, z):
        
        h3= self.fc21(z)
        h4=h3.view(-1,1,14,14)

        return self.dec(h4)
    
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparametrize(mu, logvar)
        return self.decoder(z), mu, logvar
    
    
    
    
    

In [9]:
vae = vae_501().cuda()

In [10]:
a = torch.randn(2, 1, 64, 64).cuda()

In [11]:
b = vae(a)

In [15]:
b[2].size()

torch.Size([2, 14])

In [None]:

if not os.path.exists('./gal_img1001'):
    os.mkdir('./gal_img1001')

    
dataset= trainlabelDataset()
dataloader= DataLoader(dataset=dataset, batch_size=64,shuffle=True)

test_dataset = testlabelDataset()
test_dataloader= DataLoader(dataset=test_dataset, batch_size=64,shuffle=True)


writer = SummaryWriter("run1001/exp350",)  ################################################### change name 

num_epochs =20000
batch_size = 64
learning_rate = 5e-1




model = vae_501().cuda()   ############################################################## AE model 

reconstruction_function = nn.L1Loss()

def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    BCE = reconstruction_function(recon_x, x)  # mse loss
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return BCE + KLD

#scheduler 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1000,4000], gamma=0.1)




for epoch in range(num_epochs):
    total_loss = 0.0
    total_mse = 0.0 
    num_examples = 0.0
    test_num_examples=0.0
    
    
    test_total_loss = 0.0    
    test_total_mse=0.0
    
    model.train()
    for data in dataloader:
        img,label= [x.type(torch.float32).cuda() for x in data]
        img = img.view(img.size(0), 1,64,64)

       # print(img.shape)
        # forward
        
        recon_batch, mu, logvar = model(img)
        
        loss = loss_function(recon_batch, img, mu, logvar)

        
        
    #    MSE_loss = nn.MSELoss()(output, img)
        batch_size = img.size(0)
        total_loss += loss.item() * batch_size
     #   total_mse += MSE_loss.item() * batch_size
        num_examples += batch_size

        
        optimizer.zero_grad()
    # backward
        loss.backward()
        optimizer.step()
        scheduler.step()
        
    model.eval()
    for data in test_dataloader:
        test_img,test_label= [x.type(torch.float32).cuda() for x in data]

        test_img = test_img.view(test_img.size(0), 1,64,64)
       # print(img.shape)
        

        # forward
       # test_output = model(test_img)
       # test_z=test_z.view(test_z.size(0),14*14)
                
       #print("output ",output.shape)
     #   test_loss = criterion(test_output, test_img)  #  + 1e-5*  criterion(test_z[:,:7], test_label) 
        
        
    #    test_MSE_loss = nn.MSELoss()(test_output, test_img)
    #    batch_size = test_img.size(0)
    #    test_total_loss += test_loss.item() * batch_size
    #    test_total_mse += test_MSE_loss.item() * batch_size
    #    test_num_examples += batch_size

        
 #  writer.add_scalar('Loss/train',total_loss / num_examples,epoch)
 #   writer.add_scalar('Mse/train', total_mse / num_examples,epoch)        
 #   writer.add_scalar('Loss/test',test_total_loss / test_num_examples,epoch)
 #   writer.add_scalar('Mse/test', test_total_mse / test_num_examples,epoch)
    
    print("hah")
 #   print('epoch [{}/{}], loss:{:.4f}, MSE_loss:'
  #        .format(epoch + 1, num_epochs, total_loss / num_examples))  
 #   
    
'''

    print(' epoch [{}/{}],test_loss:{:.4f}, test_MSE_loss:{:.4f}'
          .format(epoch + 1, num_epochs, test_total_loss / test_num_examples, test_total_mse/ test_num_examples))

    if epoch % 10 == 0:
        x = to_img(img.cpu().data)
        x_hat = to_img(output.cpu().data)
        test_x = to_img(test_img.cpu().data)    ########## change name 
        test_x_hat = to_img(test_output.cpu().data)
        torch.save(x, './gal_img1001/exp350_x_{}.pt'.format(epoch))
        torch.save(x_hat, './gal_img1001/exp350_x_hat_{}.pt'.format(epoch))
        torch.save(test_x, './gal_img1001/exp350_test_x_{}.pt'.format(epoch))
        torch.save(test_x_hat, './gal_img1001/exp350_test_x_hat_{}.pt'.format(epoch))
        torch.save(model.state_dict(), './gal_img1001/exp350_{}.pth'.format(epoch))       
           
           
''' 
    



hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah
hah


In [None]:
a=np.zeros((2,2))


In [None]:
1+1

In [61]:
a.sum(axis=0)

array([3., 3.])