In [1]:
 
from __future__ import print_function

import os
import sys
import glob
import h5py
import numpy as np
import math


import torch
from torch import nn
from torch.autograd import Variable
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset , DataLoader
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import tensorboard
import tensorboardX
from torch.utils.tensorboard import SummaryWriter


from log import Logger
#from data import   trainlabelDataset_reduced,testlabelDataset_reduced
from util import r2, mse, rmse, mae, pp_mse, pp_rmse, pp_mae
#from model import autoencoder_999, autoencoder_333_2,autoencoder_1015

def to_img(x):
    """Reshape a batch of images to ``(N, 1, 64, 64)`` for saving/inspection.

    ``N`` is taken from ``x.size(0)``.  Uses ``view``, so ``x`` must contain
    exactly ``N * 64 * 64`` elements and be viewable without a copy; the
    result shares memory with ``x``.
    """
    n_images = x.size(0)
    return x.view(n_images, 1, 64, 64)


# Output directories: model checkpoints / saved sample tensors, and TensorBoard runs.
# Each is created only if nothing exists at that path yet.
for _out_dir in ('./gal_img1107', './run1107'):
    if not os.path.exists(_out_dir):
        os.mkdir(_out_dir)
del _out_dir


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
class trainlabelDataset_cae(Dataset):
    """Training split of the galaxy-image dataset, loaded eagerly from HDF5.

    Item ``i`` is ``(image, labels)``:
      * ``image``  -- float32 tensor ``f['img'][i] / 1.5e4``
      * ``labels`` -- float32 numpy array
        ``[gal_flux / 1.5e4, bulge_n, bulge_re, g1, g2]``
    where ``g1 = (1 - q) / (1 + q) * cos(2 * beta) + 1`` and
    ``g2 = (1 - q) / (1 + q) * sin(2 * beta) + 1`` are computed from the
    stored ``gal_q`` (q) and ``gal_beta`` (beta).

    Args:
        path: HDF5 file holding the datasets ``img``, ``gal_flux``,
            ``bulge_re``, ``bulge_n``, ``gal_q`` and ``gal_beta``
            (default ``'train.h5'``).  The per-galaxy datasets are presumably
            1-D with one entry per image -- TODO confirm against the data.
    """

    def __init__(self, path='train.h5'):
        # The context manager closes the file even if a key is missing.
        with h5py.File(path, 'r') as f:
            image = f['img'][:] / 1.5e4          # fixed scale, not a min-max normalisation
            gal_flux = f['gal_flux'][:] / 1.5e4  # same scale as the images
            bulge_re = f['bulge_re'][:]
            bulge_n = f['bulge_n'][:]
            gal_q = f['gal_q'][:]
            gal_beta = f['gal_beta'][:]

        amp = (1 - gal_q) / (1 + gal_q)
        g1 = amp * np.cos(2 * gal_beta) + 1
        g2 = amp * np.sin(2 * gal_beta) + 1

        # astype returns a new array, so the converted result must be kept.
        image, gal_flux, bulge_re, bulge_n, g1, g2 = (
            a.astype(np.float32, copy=False)
            for a in (image, gal_flux, bulge_re, bulge_n, g1, g2)
        )

        self.len = image.shape[0]
        self.image = torch.from_numpy(image)
        self.gal_flux = torch.from_numpy(gal_flux)
        self.bulge_re = torch.from_numpy(bulge_re)
        self.bulge_n = torch.from_numpy(bulge_n)
        self.g1 = torch.from_numpy(g1)
        self.g2 = torch.from_numpy(g2)
        # Label table, one row per image, in the order returned by __getitem__.
        self.labels = np.stack([gal_flux, bulge_n, bulge_re, g1, g2], axis=1)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return ``(image, labels)`` for integer ``index``; ``labels`` is a fresh copy."""
        return self.image[index], self.labels[index].copy()
    

    
    
class testlabelDataset_cae(Dataset):
    """Test split of the galaxy-image dataset, loaded eagerly from HDF5.

    Item ``i`` is ``(image, labels)``:
      * ``image``  -- float32 tensor ``f['img'][i] / 1.5e4``
      * ``labels`` -- float32 numpy array
        ``[gal_flux / 1.5e4, bulge_n, bulge_re, g1, g2]``
    where ``g1 = (1 - q) / (1 + q) * cos(2 * beta) + 1`` and
    ``g2 = (1 - q) / (1 + q) * sin(2 * beta) + 1`` are computed from the
    stored ``gal_q`` (q) and ``gal_beta`` (beta).

    Args:
        path: HDF5 file holding the datasets ``img``, ``gal_flux``,
            ``bulge_re``, ``bulge_n``, ``gal_q`` and ``gal_beta``
            (default ``'test.h5'``).  The per-galaxy datasets are presumably
            1-D with one entry per image -- TODO confirm against the data.
    """

    def __init__(self, path='test.h5'):
        # The context manager closes the file even if a key is missing.
        with h5py.File(path, 'r') as f:
            image = f['img'][:] / 1.5e4          # fixed scale, not a min-max normalisation
            gal_flux = f['gal_flux'][:] / 1.5e4  # same scale as the images
            bulge_re = f['bulge_re'][:]
            bulge_n = f['bulge_n'][:]
            gal_q = f['gal_q'][:]
            gal_beta = f['gal_beta'][:]

        amp = (1 - gal_q) / (1 + gal_q)
        g1 = amp * np.cos(2 * gal_beta) + 1
        g2 = amp * np.sin(2 * gal_beta) + 1

        # astype returns a new array, so the converted result must be kept.
        image, gal_flux, bulge_re, bulge_n, g1, g2 = (
            a.astype(np.float32, copy=False)
            for a in (image, gal_flux, bulge_re, bulge_n, g1, g2)
        )

        self.len = image.shape[0]
        self.image = torch.from_numpy(image)
        self.gal_flux = torch.from_numpy(gal_flux)
        self.bulge_re = torch.from_numpy(bulge_re)
        self.bulge_n = torch.from_numpy(bulge_n)
        self.g1 = torch.from_numpy(g1)
        self.g2 = torch.from_numpy(g2)
        # Label table, one row per image, in the order returned by __getitem__.
        self.labels = np.stack([gal_flux, bulge_n, bulge_re, g1, g2], axis=1)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        """Return ``(image, labels)`` for integer ``index``; ``labels`` is a fresh copy."""
        return self.image[index], self.labels[index].copy()
    






In [3]:
class autoencoder_1015_con(nn.Module):
    """Convolutional autoencoder with a 197-unit fully connected bottleneck.

    Input: a batch of single-channel 64x64 images, shape ``(B, 1, 64, 64)``.

    The encoder maps the batch to ``(B, 1, 14, 14)``.  The 196 flattened
    values are concatenated with one scalar per sample: the pixel sum divided
    by ``64 * 64`` and by ``1.5e6``.  The resulting 197-vector goes through
    ``lin_1`` (197 -> 197) to give the latent ``z``, then through ``lin_2``
    (197 -> 196).  That output is reshaped to ``(B, 1, 14, 14)`` and decoded
    back to ``(B, 1, 64, 64)``.

    ``forward`` returns ``(reconstruction, z)``.
    """

    def __init__(self):
        super(autoencoder_1015_con, self).__init__()
        # Module creation order and attribute names are deliberately fixed:
        # they determine parameter-initialisation order and the state_dict
        # keys of saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=1, padding=1),     # (B, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                    # (B, 64, 32, 32)

            nn.Conv2d(64, 128, 3, stride=2, padding=1),   # (B, 128, 16, 16)
            nn.ReLU(True),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2, stride=1),                    # (B, 128, 15, 15)

            nn.Conv2d(128, 64, 3, stride=1, padding=1),   # (B, 64, 15, 15)
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2, stride=1),                    # (B, 64, 14, 14)

            nn.Conv2d(64, 1, 3, stride=1, padding=1),     # (B, 1, 14, 14)
        )

        # Bottleneck: 196 encoder outputs + 1 flux feature.
        self.lin_1 = nn.Linear(14 * 14 + 1, 14 * 14 + 1)
        self.lin_2 = nn.Linear(14 * 14 + 1, 14 * 14)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1, 64, 3, stride=2),       # (B, 64, 29, 29)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 128, 3, stride=2),     # (B, 128, 59, 59)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 3, stride=1),     # (B, 64, 61, 61)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, stride=1),       # (B, 1, 64, 64)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        flat = encoded.view(encoded.size(0), 14 * 14)

        # Extra bottleneck input: per-sample pixel sum / (64 * 64) / 1.5e6.
        pixel_sum = x.sum(dim=3).sum(dim=2).sum(dim=1)
        flux_feature = (pixel_sum / 64 / 64 / 1.5e6).view(x.size(0), 1)

        z = self.lin_1(torch.cat((flat, flux_feature), dim=1))
        bottleneck = self.lin_2(z).view(z.size(0), 1, 14, 14)
        return self.decoder(bottleneck), z
    


In [4]:
# --- Hyperparameters (defined first so the DataLoaders use batch_size) ---
SEED = 0
num_epochs = 40000
batch_size = 64
learning_rate = 1e-2
reg = 1e-3   # weight of the L1 penalty on the unsupervised latent units

# Seed every RNG that affects initialisation and shuffling.
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Data ---
dataset = trainlabelDataset_cae()
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation loader: keep every test sample (drop_last=False) so metrics cover
# the whole test set and a test set smaller than one batch still yields
# batches.  No shuffling, so the same images appear in the same order each epoch.
test_dataset = testlabelDataset_cae()
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

writer = SummaryWriter("run1107/rec400_1")  # change the run name per experiment

# --- Model ---
model = autoencoder_1015_con().cuda()
# To resume from a checkpoint, uncomment and point at the saved state dict:
# model.load_state_dict(torch.load('gal_img1001/rec200_6_3_36500.pth'))

criterion_mean = nn.L1Loss(reduction='mean')
criterion_none = nn.L1Loss(reduction='none')   # element-wise L1, reduced manually in the loss
# NOTE(review): despite its name this uses reduction='mean'; it and
# criterion_mean appear unused in this notebook.
criterion_none_mse = nn.MSELoss(reduction='mean')

# --- Optimiser / schedule ---
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
# LR is multiplied by 0.1 at each milestone.  Milestones count scheduler.step()
# calls, and their scale (up to 30000 of 40000) matches num_epochs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 3000, 10000, 20000, 30000], gamma=0.1)



In [None]:






RUN_PREFIX = './gal_img1107/rec400_1'   # prefix for checkpoints and saved sample tensors
METRIC_TAGS = ('Loss', 'Mse', 'Recon', 'Latent')


def _to_batch(data):
    """Move an ``(images, labels)`` batch to the GPU as float32.

    Images are reshaped to ``(B, 1, 64, 64)``.
    """
    img, label = [x.type(torch.float32).cuda() for x in data]
    return img.view(img.size(0), 1, 64, 64), label


def _batch_losses(img, label):
    """Forward one batch through ``model`` and compute its loss terms.

    "flux" below is the per-image pixel sum.  Returns
    ``(output, loss, loss_recon, loss_latent, mse)``:
      * ``loss``: the optimised objective, which is the sum of
          - ``loss_recon``;
          - the batch mean of (summed L1 error of ``z[:, :5]`` vs ``label``) * flux;
          - ``reg`` * L1 norm of ``z[:, 5:]`` (summed over the whole batch).
      * ``loss_recon``: batch mean of the per-image L1 reconstruction error / flux.
      * ``loss_latent``: unweighted mean L1 error of ``z[:, :5]`` vs ``label`` (logged only).
      * ``mse``: mean squared reconstruction error (logged only).
    """
    output, z = model(img)
    z = z.view(z.size(0), 14 * 14 + 1)
    flux = img.sum(dim=3).sum(dim=2).sum(dim=1)

    loss_recon = (criterion_none(output, img).sum(dim=3).sum(dim=2).sum(dim=1) / flux).mean()
    latent_err = criterion_none(z[:, :5], label)
    loss = loss_recon + (latent_err.sum(dim=1) * flux).mean() + reg * torch.norm(z[:, 5:], p=1)

    return output, loss, loss_recon, latent_err.mean(), F.mse_loss(output, img)


for epoch in range(num_epochs):
    # Batch-size-weighted running sums, in METRIC_TAGS order.
    train_sums = [0.0] * len(METRIC_TAGS)
    test_sums = [0.0] * len(METRIC_TAGS)
    num_examples = 0
    test_num_examples = 0

    # --- train ---
    model.train()
    for data in dataloader:
        img, label = _to_batch(data)
        output, loss, loss_recon, loss_latent, mse = _batch_losses(img, label)

        n_batch = img.size(0)   # local name: do not clobber the global batch_size
        for i, value in enumerate((loss, mse, loss_recon, loss_latent)):
            train_sums[i] += value.item() * n_batch
        num_examples += n_batch

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # The MultiStepLR milestones (500 ... 30000) are on the num_epochs scale,
    # so advance the schedule once per epoch, not once per batch.
    scheduler.step()

    if epoch % 10 == 0:
        torch.save(model.state_dict(), '{}_train_{}.pth'.format(RUN_PREFIX, epoch))

    # --- evaluate (no autograd graph needed) ---
    model.eval()
    with torch.no_grad():
        for data in test_dataloader:
            test_img, test_label = _to_batch(data)
            test_output, test_loss, test_loss_recon, test_loss_latent, test_mse = _batch_losses(test_img, test_label)

            n_batch = test_img.size(0)
            for i, value in enumerate((test_loss, test_mse, test_loss_recon, test_loss_latent)):
                test_sums[i] += value.item() * n_batch
            test_num_examples += n_batch

    for tag, total in zip(METRIC_TAGS, train_sums):
        writer.add_scalar(tag + '/train', total / num_examples, epoch)
    for tag, total in zip(METRIC_TAGS, test_sums):
        writer.add_scalar(tag + '/test', total / test_num_examples, epoch)

    if epoch % 10 == 0:
        # Save the last train/test batch and its reconstruction for inspection.
        samples = {'x': img, 'x_hat': output, 'test_x': test_img, 'test_x_hat': test_output}
        for name, tensor in samples.items():
            torch.save(to_img(tensor.detach().cpu()), '{}_{}_{}.pt'.format(RUN_PREFIX, name, epoch))
        # Same weights as the *_train_ checkpoint above (evaluation does not
        # modify parameters); kept for compatibility with existing loaders.
        torch.save(model.state_dict(), '{}_{}.pth'.format(RUN_PREFIX, epoch))

writer.close()

