### De-Amortize

In [None]:
#!/usr/bin/env python
# coding: utf-8

# ### De-Amortize

# In[1]:


import numpy as np
import scipy
import matplotlib.pyplot as plt
import os
import logging
import sys
from sklearn import preprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F


import random
import numpy as np
import pickle
import time


# In[2]:


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# In[3]:


SMALL_SIZE =  14
MEDIUM_SIZE = 16
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


# In[4]:


RUN             = '1'

seeds           = {'1':5123, '2':879, '3':9981, '4': 20075, '5': 66, '6': 276, '7': 936664}

conditional     = False
cond_on         = 'type'

root_model_data = '/global/cscratch1/sd/vboehm/Datasets/sdss/by_model/'
root_models     = '/global/cscratch1/sd/vboehm/Models/SDSS_AE/'
root_encoded    = '/global/cscratch1/sd/vboehm/Datasets/encoded/sdss/'
root_decoded    = '/global/cscratch1/sd/vboehm/Datasets/decoded/sdss/'


wlmin, wlmax    = (3388,8318)
fixed_num_bins  = 1000
min_SN          = 50
min_z           = 0.1
max_z           = 0.36
label           = 'galaxies_quasars_bins%d_wl%d-%d'%(fixed_num_bins,wlmin,wlmax)
label_          = label+'_minz%s_maxz%s_minSN%d'%(str(int(min_z*100)).zfill(3),str(int(max_z*100)).zfill(3),min_SN)
label_2         = label_+'_10_fully_connected_mean_div'

plotpath        = '/global/homes/v/vboehm/codes/SDSS_PAE/figures'


if conditional:
    label_2='conditional_%s'%cond_on+label_2
    
upsampling      = 'SMOTE'
fac             = 10


load_path = '/global/homes/v/vboehm/codes/SDSS_PAE/notebooks/'

_,_,train_, le = pickle.load(open(os.path.join(root_model_data,'combined_%s.pkl'%label_),'rb'))
_,_,encoded_train = np.load(os.path.join(root_encoded,'encoded_%s_RUN%s.npy'%(label_2,RUN)), allow_pickle=True)
_,_,train, mean, std = np.load(os.path.join(root_decoded,'decoded_%s_RUN%s.npy'%(label_2,RUN)), allow_pickle=True)



enc_weights = np.load(os.path.join(load_path,'encoder_weights.npy'), allow_pickle=True)
dec_weights = np.load(os.path.join(load_path,'decoder_weights.npy'), allow_pickle=True)


# In[7]:


for ii in range(len(enc_weights)):
    print(enc_weights[ii].shape)


# In[8]:


latent_dim      = 10
lr_final        = 1.3e-05
lr_init         = 7e-4
out_features    = [1000,590]


# In[9]:


def hidden_init(layer):
    fan_in = layer.weight.data.size()[0]
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)  


# In[10]:


class Encoder(nn.Module):

    def __init__(self, seed):
        super(Encoder, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(fixed_num_bins, out_features[0])
        self.fc2 = nn.Linear(out_features[0],out_features[1])
        self.fc3 = nn.Linear(out_features[1], latent_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(*hidden_init(self.fc3))

    def forward(self, state):
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc1(state))
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc2(x))
        x = self.fc3(x)
        return x

encoder=Encoder(time.time()).to(device)


# In[11]:


encoder.fc1.weight.data=torch.from_numpy(np.transpose(enc_weights[0])).to(device)
encoder.fc1.bias.data=torch.from_numpy(enc_weights[1]).to(device)
encoder.fc2.weight.data=torch.from_numpy(np.transpose(enc_weights[2])).to(device)
encoder.fc2.bias.data=torch.from_numpy(enc_weights[3]).to(device)
encoder.fc3.weight.data=torch.from_numpy(np.transpose(enc_weights[4])).to(device)
encoder.fc3.bias.data=torch.from_numpy(enc_weights[5]).to(device)


# In[12]:


pred = encoder.forward(torch.from_numpy(train_['spec'][:,:,0]).float().to(device))


# In[13]:


pred.device


# In[14]:


class Decoder(nn.Module):
    """Decoder"""

    def __init__(self, seed):

        super(Decoder, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(latent_dim, out_features[1])
        self.fc2 = nn.Linear(out_features[1],out_features[0])
        self.fc3 = nn.Linear(out_features[0],fixed_num_bins)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(*hidden_init(self.fc3))

    def forward(self, state):
        """Build an actor (policy) network that maps states -> actions."""
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc1(state))
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc2(x))
        x = self.fc3(x)
        return x

decoder=Decoder(time.time()).to(device)


# In[15]:


decoder.fc1.weight.data=torch.from_numpy(np.transpose(dec_weights[0])).to(device)
decoder.fc1.bias.data=torch.from_numpy(dec_weights[1]).to(device)
decoder.fc2.weight.data=torch.from_numpy(np.transpose(dec_weights[2])).to(device)
decoder.fc2.bias.data=torch.from_numpy(dec_weights[3]).to(device)
decoder.fc3.weight.data=torch.from_numpy(np.transpose(dec_weights[4])).to(device)
decoder.fc3.bias.data=torch.from_numpy(dec_weights[5]).to(device)


# In[79]:


for param in decoder.parameters():
    param.requires_grad = False
    
for param in encoder.parameters():
    param.requires_grad = False


# In[80]:


dec = decoder.forward(pred)


# In[81]:


z_init = torch.Tensor(np.squeeze(encoded_train)).to(device)


# In[82]:


z = torch.autograd.Variable(torch.Tensor(np.squeeze(encoded_train)).to(device),requires_grad=True).to(device)




def neg_log_posterior(z,x,noise,mask,decoder, dens, class_, prior=False):
    pred  = decoder.forward(z)
    ll    = torch.sum(0.5*(pred-x)*(pred-x)*noise*mask,axis=1)
    if prior:
        ll = ll -dens(z, class_)
    return torch.mean(ll,axis=0)


sys.path.append('/global/u2/v/vboehm/codes/SIG_GIS/')
from sig_gis import *
from sig_gis.GIS import *


model = torch.load(os.path.join(root_models,'conditional_SINF_%s_%s_%d_AE1'%(label_2,upsampling,fac)))
model = model.to(device)

for param in model.parameters():
    param.requires_grad = False


optim = torch.optim.Adam([z], lr=1e-2)


num_iter = 2#int(len(z)/500)
resid    = len(z)%500
print(resid)

a = time.time()
neg_logs =[]
zs = []
#stepping through dataset
for jj in range(2):
    # minimization steps
    for ii in range(70):
        if ii>40:
            for param_group in optim.param_groups:
                param_group['lr'] = 1e-3
        if jj<num_iter:
            if 0<ii<30:
                neg_log = neg_log_posterior(z[jj*500:(jj+1)*500],torch.from_numpy(np.squeeze(train_['spec'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500:(jj+1)*500]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500:(jj+1)*500]).to(device),prior=False)
            else:
                neg_log = neg_log_posterior(z[jj*500:(jj+1)*500],torch.from_numpy(np.squeeze(train_['spec'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500:(jj+1)*500]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500:(jj+1)*500]).to(device), prior=True)

        else:
            if 0<ii<30:
                neg_log = neg_log_posterior(z[jj*500::],torch.from_numpy(np.squeeze(train_['spec'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500::]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500::]).to(device),prior=False)
            else:
                neg_log = neg_log_posterior(z[jj*500::],torch.from_numpy(np.squeeze(train_['spec'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500::]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500::]).to(device), prior=True)
        optim.zero_grad()
        neg_log.backward()
        optim.step()
        neg_logs.append(neg_log.cpu().detach().numpy())
    zs.append(z[jj*500:(jj+1)*500].cpu().detach().numpy())
end = time.time()


zs = np.reshape(np.asarray(zs),(-1,10))
#np.save(os.path.join(root_encoded,'encoded_MAP_test_%s_RUN%s.npy'%(label_2,RUN)),zs)

print((end -a)/60/60)

In [1]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import os
import logging
import sys
from sklearn import preprocessing

import torch
import torch.nn as nn
import torch.nn.functional as F


import random
import numpy as np
import pickle
import time

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
SMALL_SIZE =  14
MEDIUM_SIZE = 16
BIGGER_SIZE = 18

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=MEDIUM_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title


In [4]:
RUN             = '1'

seeds           = {'1':5123, '2':879, '3':9981, '4': 20075, '5': 66, '6': 276, '7': 936664}

conditional     = False
cond_on         = 'type'

root_model_data = '/global/cscratch1/sd/vboehm/Datasets/sdss/by_model/'
root_models     = '/global/cscratch1/sd/vboehm/Models/SDSS_AE/'
root_encoded    = '/global/cscratch1/sd/vboehm/Datasets/encoded/sdss/'
root_decoded    = '/global/cscratch1/sd/vboehm/Datasets/decoded/sdss/'


wlmin, wlmax    = (3388,8318)
fixed_num_bins  = 1000
min_SN          = 50
min_z           = 0.1
max_z           = 0.36
label           = 'galaxies_quasars_bins%d_wl%d-%d'%(fixed_num_bins,wlmin,wlmax)
label_          = label+'_minz%s_maxz%s_minSN%d'%(str(int(min_z*100)).zfill(3),str(int(max_z*100)).zfill(3),min_SN)
label_2         = label_+'_10_fully_connected_mean_div'

plotpath        = '/global/homes/v/vboehm/codes/SDSS_PAE/figures'


if conditional:
    label_2='conditional_%s'%cond_on+label_2
    
upsampling      = 'SMOTE'
fac             = 10

In [5]:
train_,valid_,test_,le = pickle.load(open(os.path.join(root_model_data,'combined_%s.pkl'%label_),'rb'))
encoded_train, encoded_valid, encoded_test = np.load(os.path.join(root_encoded,'encoded_%s_RUN%s.npy'%(label_2,RUN)), allow_pickle=True)
train, valid, test, mean, std = np.load(os.path.join(root_decoded,'decoded_%s_RUN%s.npy'%(label_2,RUN)), allow_pickle=True)



In [6]:
enc_weights = np.load('encoder_weights.npy', allow_pickle=True)
dec_weights = np.load('decoder_weights.npy', allow_pickle=True)

In [7]:
for ii in range(len(enc_weights)):
    print(enc_weights[ii].shape)


(1000, 1000)
(1000,)
(1000, 590)
(590,)
(590, 10)
(10,)


In [8]:
latent_dim      = 10
lr_final        = 1.3e-05
lr_init         = 7e-4
out_features    = [1000,590]

In [9]:
def hidden_init(layer):
    fan_in = layer.weight.data.size()[0]
    lim = 1. / np.sqrt(fan_in)
    return (-lim, lim)  

In [10]:
class Encoder(nn.Module):

    def __init__(self, seed):
        super(Encoder, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(fixed_num_bins, out_features[0])
        self.fc2 = nn.Linear(out_features[0],out_features[1])
        self.fc3 = nn.Linear(out_features[1], latent_dim)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(*hidden_init(self.fc3))

    def forward(self, state):
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc1(state))
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc2(x))
        x = self.fc3(x)
        return x

encoder=Encoder(time.time()).to(device)

In [11]:
encoder.fc1.weight.data=torch.from_numpy(np.transpose(enc_weights[0])).to(device)
encoder.fc1.bias.data=torch.from_numpy(enc_weights[1]).to(device)
encoder.fc2.weight.data=torch.from_numpy(np.transpose(enc_weights[2])).to(device)
encoder.fc2.bias.data=torch.from_numpy(enc_weights[3]).to(device)
encoder.fc3.weight.data=torch.from_numpy(np.transpose(enc_weights[4])).to(device)
encoder.fc3.bias.data=torch.from_numpy(enc_weights[5]).to(device)

In [12]:
pred = encoder.forward(torch.from_numpy(train_['spec'][:,:,0]).float().to(device))

In [13]:
pred.device

device(type='cuda', index=0)

In [14]:
class Decoder(nn.Module):
    """Decoder"""

    def __init__(self, seed):

        super(Decoder, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(latent_dim, out_features[1])
        self.fc2 = nn.Linear(out_features[1],out_features[0])
        self.fc3 = nn.Linear(out_features[0],fixed_num_bins)
        self.reset_parameters()

    def reset_parameters(self):
        self.fc1.weight.data.uniform_(*hidden_init(self.fc1))
        self.fc2.weight.data.uniform_(*hidden_init(self.fc2))
        self.fc3.weight.data.uniform_(*hidden_init(self.fc3))

    def forward(self, state):
        """Build an actor (policy) network that maps states -> actions."""
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc1(state))
        x = torch.nn.LeakyReLU(negative_slope=0.3)(self.fc2(x))
        x = self.fc3(x)
        return x

decoder=Decoder(time.time()).to(device)

In [15]:
decoder.fc1.weight.data=torch.from_numpy(np.transpose(dec_weights[0])).to(device)
decoder.fc1.bias.data=torch.from_numpy(dec_weights[1]).to(device)
decoder.fc2.weight.data=torch.from_numpy(np.transpose(dec_weights[2])).to(device)
decoder.fc2.bias.data=torch.from_numpy(dec_weights[3]).to(device)
decoder.fc3.weight.data=torch.from_numpy(np.transpose(dec_weights[4])).to(device)
decoder.fc3.bias.data=torch.from_numpy(dec_weights[5]).to(device)

In [79]:
for param in decoder.parameters():
    param.requires_grad = False
    
for param in encoder.parameters():
    param.requires_grad = False

In [80]:
dec = decoder.forward(pred)

In [81]:
z_init = torch.Tensor(np.squeeze(encoded_train)).to(device)

In [82]:
z = torch.autograd.Variable(torch.Tensor(np.squeeze(encoded_train)).to(device),requires_grad=True).to(device)

In [97]:
def neg_log_posterior(z,x,noise,mask,decoder, dens, class_, prior=False):
    pred  = decoder.forward(z)
    ll    = torch.sum(0.5*(pred-x)*(pred-x)*noise*mask,axis=1)
    if prior:
        ll = ll -dens(z, class_)
    return torch.mean(ll,axis=0)

In [98]:
sys.path.append('/global/u2/v/vboehm/codes/SIG_GIS/')
from sig_gis import *
from sig_gis.GIS import *

In [99]:
model = torch.load(os.path.join(root_models,'conditional_SINF_%s_%s_%d_AE1'%(label_2,upsampling,fac)))
model = model.to(device)

for param in model.parameters():
    param.requires_grad = False

In [100]:
optim = torch.optim.Adam([z], lr=1e-2)

In [102]:
a = time.time()
neg_logs =[]
zs = []
for jj in range(1):
    print(jj)
    for ii in range(60):
        if ii>40:
            for param_group in optim.param_groups:
                param_group['lr'] = 1e-3
        if ii%10==0:
            print(ii)
        if jj<538:
            if 0<ii<30:
                neg_log = neg_log_posterior(z[jj*500:(jj+1)*500],torch.from_numpy(np.squeeze(train_['spec'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500:(jj+1)*500]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500:(jj+1)*500]).to(device),prior=False)
            else:
                neg_log = neg_log_posterior(z[jj*500:(jj+1)*500],torch.from_numpy(np.squeeze(train_['spec'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500:(jj+1)*500]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500:(jj+1)*500]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500:(jj+1)*500]).to(device), prior=True)

        else:
            neg_log = neg_log_posterior(z[jj*500::],torch.from_numpy(np.squeeze(train_['spec'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['noise'])[jj*500::]).to(device),torch.from_numpy(np.squeeze(train_['mask'])[jj*500::]).to(device),decoder,model.evaluate_density, torch.from_numpy(train_['subclass'][jj*500::]).to(device))            
        optim.zero_grad()
        neg_log.backward()
        optim.step()
        neg_logs.append(neg_log.cpu().detach().numpy())
        print(neg_logs[-1])
    zs.append(z.cpu().detach().numpy())
end = time.time()

0
0
568.8936066867742
563.0592233470292
561.8220430475985
561.235402184114
560.6292078741234
560.0169584898983
559.4918379761167
559.0636817957812
558.7036932938506
558.3770510786653
10
558.0616825293624
557.7584776637384
557.4789212125698
557.2329226686647
557.0314399435201
556.8562380323206
556.6688485786419
556.4706651450545
556.2979158244719
556.1617649585437
20
556.0384016660194
555.9073609174055
555.7706716981852
555.644911568783
555.5369902268005
555.437177779751
555.3373097646258
555.2402599572752
555.1495969699603
555.0644880946152
30
563.2744658691438
563.1917897257603
563.0739602555844
562.9346819797395
562.7860590271636
562.6381943229947
562.4961113513822
562.358535157655
562.2243088516549
562.09909230735
40
561.983274700259
561.874405362947
561.8629659336483
561.85027310221
561.8367102966263
561.8226554638816
561.8083418122113
561.7939550345936
561.7797074979188
561.7657930553144
50
561.7520815082769
561.7384156384447
561.7247235416542
561.7110781438056
561.6974370160831
5

In [103]:
zs = np.reshape(np.asarray(zs),(-1,10))

In [104]:
np.save(os.path.join(root_encoded,'encoded_MAP_train_%s_RUN%s.npy'%(label_2,RUN)),zs)


In [105]:
(end -a)/60/60*len(train_['spec'])/500

4.48136737956312