# Rank-1 Introduction

# Import packages

### Import python packages

In [1]:
import re,os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import clear_output
import pandas as pd
import math
import random
from tqdm import trange
import matplotlib.colors as mcolors
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
# Short aliases for the Keras namespaces used throughout the notebook.
tfk = tf.keras
tfkl = tf.keras.layers
tfkltd= tf.keras.layers.TimeDistributed
# Clear this cell's output.
clear_output()
# Expose only CUDA device index 7 to this process.
# NOTE(review): this is set after `import tensorflow`; presumably it still takes
# effect because no GPU has been initialised yet - TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"]="7"

2022-08-09 23:33:16.398670: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.

KeyboardInterrupt



### Import custom packages

In [None]:
from generate import generate
from plot_training import plot_training

# Generate the Dataset

### Set parameters

In [None]:
# Number of distinct states (length of `states`).
n1 = 10
# Number of distinct nuisances (length of `nuisances`).
n2 = 100
# Passed to generate(); later cells iterate the first axis of X with range(nx).
nx = 50
# Passed to generate(); later cells iterate the second axis of X with range(nt).
nt = 20
# Output width of the symmetric encoder.
p = 1
# Output width of the nuisance encoder.
q = 10
# State values 0..n1-1 and nuisance values 0, n1, 2*n1, ...
states = np.arange(n1)
nuisances = np.arange(n2)*n1
# Alternative hand-picked value sets, kept for reference.
#states = np.array([3,5,7,11,19,29,61,137,313,503])
#nuisances = np.array([1,2,4,8,16,32,64,128,256,512])
# Per-configuration directory used for the checkpoint.
path = './p=%d,q=%d'%(p,q)

### Generate X

In [None]:
# Build the dataset with the project-local generate() helper.
# NOTE(review): the meaning of `replace=0` is defined in generate.py - presumably
# sampling without replacement; confirm against that module.
data = generate(states,nuisances,nx,nt,replace=0)
# Shortcuts to the attributes used by later cells.
X = data.X
X_states = data.X_states
X_nuisances = data.X_nuisances

In [None]:
# Show the transpose of data.D as an image.
plt.matshow(data.D.T)

### Which blocks in D are selected by X?

In [None]:
# For every state i, zero the entries of D_selected at the nuisance indices
# listed in waiting_samples[i]. One assignment per state is enough: the
# right-hand side does not depend on a nuisance index.
# NOTE(review): D_selected is not used by the plot below, which shows
# data.selected_times instead.
waiting_samples = data.waiting_samples
D_selected = np.ones((n1,n2))
for i in range(n1):
    D_selected[i][waiting_samples[i]] = 0

# Heat map of data.selected_times with a grid line between neighbouring cells.
plt.figure(figsize=(5,5),dpi=100)
plt.matshow(data.selected_times.T, fignum=1)
plt.ylabel('nuisances');plt.xlabel('states')
plt.xticks(range(n1));plt.yticks(range(n2))
ax = plt.gca()
# Minor ticks sit half a cell before each major tick (the first one is dropped)
# so that the minor grid falls on cell borders.
ax.set_xticks([tick - 0.5 for tick in ax.get_xticks()][1:], minor=True)
ax.set_yticks([tick - 0.5 for tick in ax.get_yticks()][1:], minor=True)
plt.grid(which='minor')
plt.title('How many times has each block in D been selected')
plt.colorbar()
plt.show()

# Set SymAE

### Set layers

In [None]:
class SymmetricEncoderDense0D(tf.keras.Model):
    """Encoder that averages its input over axis 1 and applies one Dense layer.

    Args:
        p: number of units of the Dense layer (stored as ``nt_out``).
        nt: stored as ``nt_in``; not otherwise used by this class.
    """

    def __init__(self, p, nt):
        super().__init__(name='sym_encoder')
        self.nt_in = nt
        self.nt_out = p
        self.d1 = tfkl.Dense(p)

    def call(self, x, training=False):
        """Return ``d1(mean(x, axis=1))``; ``training`` is accepted but unused."""
        averaged = tf.math.reduce_mean(x, axis=1)
        return self.d1(averaged)
# You are making the nuisance encoder memorize the states
class NuisanceEncoderDense0D(tf.keras.Model):
    """Two-layer dense encoder: a 10000-unit LeakyReLU layer followed by a q-unit layer.

    Args:
        q: number of units of the output Dense layer.
        nt: accepted for signature symmetry with the other encoder; unused.
    """

    def __init__(self, q, nt):
        super().__init__(name='nui_encoder')
        hidden_activation = tf.keras.layers.LeakyReLU()
        self.d1 = tfkl.Dense(10000, activation=hidden_activation)
        self.d2 = tfkl.Dense(q)

    def call(self, x, training=False):
        """Apply ``d1`` then ``d2``; ``training`` is accepted but unused."""
        hidden = self.d1(x)
        return self.d2(hidden)
class DistributeZsym(tf.keras.Model):
    """Split a flat latent code and lay it out along a new axis of length ``ntau``.

    The first ``nz0`` columns are repeated ``ntau`` times; the remaining
    ``ntau*nzi`` columns are reshaped to ``(ntau, nzi)``. Both parts are then
    concatenated along axis 2.

    Args:
        ntau: length of the new axis.
        nz0: width of the shared (repeated) part.
        nzi: width of the per-step part.
    """

    def __init__(self, ntau, nz0, nzi):
        super().__init__(name='dist')
        self.ntau = ntau
        self.nz0 = nz0
        self.nzi = nzi
        self.ri = tfkl.Reshape(target_shape=(ntau, nzi))
        self.repeat = tfkl.RepeatVector(ntau)

    def call(self, z, training=False):
        """Return the concatenation of the repeated and the reshaped parts of ``z``."""
        split_sizes = [self.nz0, self.ntau*self.nzi]
        shared, per_step = tf.split(z, split_sizes, axis=1)
        per_step = self.ri(per_step)
        shared = self.repeat(shared)
        return tfkl.concatenate([shared, per_step], axis=2)
class MixerDense0D(tf.keras.Model):
    """Decoder head that sums its input over the last axis.

    Two Dense layers (``d1``, ``d2``) are created but ``call`` does not apply
    them; the output is a plain sum with the last axis kept at size 1.

    Args:
        n_out: accepted but unused.
        n_in: accepted but unused (default is evaluated from the notebook
            globals ``p`` and ``q`` when the class is defined).
    """

    def __init__(self, n_out=1, n_in=p+q):
        super().__init__(name='mixer')
        leaky = tf.keras.layers.LeakyReLU()
        self.d1 = tfkl.Dense(10000, activation=leaky)
        self.d2 = tfkl.Dense(1)

    def call(self, x, training=False):
        """Return ``reduce_sum(x, axis=-1, keepdims=True)``."""
        # The dense path (self.d1) is intentionally left disabled here.
        return tf.math.reduce_sum(x, axis=-1, keepdims=True)

In [None]:
class LatentCat(tf.keras.Model):
    """Apply dropout to the nuisance code, flatten it, and append it to the symmetric code.

    Args:
        alpha: rate handed to the Dropout layer.
    """

    def __init__(self, alpha=1.0):
        super().__init__(name='latentcat')
        self.drop = tfkl.Dropout(alpha)

    def call(self, zsym, znuisance, training=False):
        """Return ``concatenate([zsym, flatten(dropout(znuisance))])``.

        ``training`` is forwarded to the Dropout layer.
        """
        dropped = self.drop(znuisance, training=training)
        flattened = tfkl.Flatten()(dropped)
        return tfkl.concatenate([zsym, flattened])

### model

In [None]:
class SymAE(tf.keras.Model):
    """Autoencoder assembled from a symmetric encoder, a nuisance encoder and a mixer.

    The encoder maps an input of shape ``(nt, 1)`` (per sample) to a flat code
    of width ``p + q*nt``; the decoder redistributes that code along an axis of
    length ``nt`` and feeds it to the mixer.

    Args:
        N: accepted but not used by this class.
        nt: length of the second input axis.
        p: width of the symmetric code.
        q: width of the nuisance code per step.
        dropout_rate: rate of the dropout applied to the nuisance code.
    """

    def __init__(self, N, nt, p, q, dropout_rate):
        super(SymAE, self).__init__()
        # Build symmetric encoder
        sym_encoder = SymmetricEncoderDense0D(p,nt)
        self.sym_encoder=sym_encoder
        # Build nuisance encoder
        nui_encoder = NuisanceEncoderDense0D(q,nt)
        self.nui_encoder = nui_encoder
        # Build latentcat
        latentcat = LatentCat(alpha=dropout_rate)
        self.latentcat = latentcat
        # Build distribute in decoder
        distzsym = DistributeZsym(nt, p, q)
        self.distzsym = distzsym
        # Build mixer in decoder
        mixer = MixerDense0D(1,p+q)
        self.mixer = mixer
        # Build encoder
        encoder_input = tfk.Input(shape=(nt,1), dtype='float32', name='encoder_input')
        znuisance = nui_encoder(encoder_input)
        zsym = sym_encoder(encoder_input)
        encoder_output=latentcat(zsym,znuisance)
        encoder=tfk.Model(encoder_input, encoder_output, name="encoder")
        self.encoder=encoder
        # Build decoder
        # `shape` must be a tuple; `(p+q*nt)` without the trailing comma is a
        # plain int and is not accepted by every Keras version.
        decoder_input = tfk.Input(shape=(p+q*nt,), name='latentcode')
        decoder_output=mixer(distzsym(decoder_input))
        decoder=tfk.Model(decoder_input,decoder_output, name="decoder")
        self.decoder=decoder

    def call(self, input_tensor, training=False):
        """Encode then decode ``input_tensor``.

        ``training`` is forwarded explicitly to both sub-models so the dropout
        mode does not depend on implicit propagation by the framework.
        """
        latent = self.encoder(input_tensor, training=training)
        return self.decoder(latent, training=training)

### Initialize SymAE

In [None]:
# Instantiate the autoencoder: nx is passed as the (unused) N argument and the
# nuisance dropout rate is 0.5.
model = SymAE(nx,nt,p,q,0.5)
clear_output()

In [None]:
# Best-effort restore of earlier weights; a missing checkpoint is not fatal.
try:
    model.load_weights(path+'/checkpoint')
    print("weight exists")
except Exception as err:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report the error
    # type so a shape mismatch is not mistaken for a missing file.
    print("weight doesn't exist")
    print("(%s)" % type(err).__name__)

In [None]:
#model.latentcat.drop.rate = 0.0

### Select optimizer

In [None]:
# Two optimizer instances; only Adam is handed to compile() below.
Adam = tf.keras.optimizers.Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-07)
SGD = tf.keras.optimizers.SGD(learning_rate=0.001,momentum=0.0,nesterov=False)
# NOTE(review): the hand-written train_step in this notebook applies gradients
# with the separate `optimizer` object, not with this compiled optimizer.
model.compile(loss='mse',optimizer=Adam)
clear_output()

In [None]:
# Loss object and optimizer used by the hand-written training functions.
mse = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-07)

In [None]:
# M sizes the per-epoch history buffers below.
M=10000
epochs=range(M)
# Histories are pre-filled with NaN so epochs that were never run stay NaN.
losses=[np.nan]*M
redata=[np.nan]*M
# Number of (i1, i2) index pairs used by the redatum check; each index is
# drawn from [0, nx).
sample_size = 10
redatum_list = list(zip(np.random.randint(0,nx,sample_size),np.random.randint(0,nx,sample_size)))

### Plot

In [None]:
#tf.keras.utils.plot_model(model)

# Train

### From scratch

In [None]:
@tf.function
def redatum(X1, X2):
    """Decode the symmetric code of ``X1`` combined with the nuisance code of ``X2``.

    The codes are merged by ``model.latentcat`` (called without ``training``)
    and passed through ``model.decoder``.
    """
    coherent_code = model.sym_encoder(X1)
    nuisance_code = model.nui_encoder(X2)
    merged_code = model.latentcat(coherent_code, nuisance_code)
    return model.decoder(merged_code)
@tf.function
def redatum_loss(X):
    """Average MSE between redatumed outputs and their expected values.

    For every pair ``(i1, i2)`` in ``redatum_list`` the state of sample ``i1``
    is combined with the nuisances of sample ``i2``; the expected value is
    ``states[X_states[i1]] + nuisances[X_nuisances[i2, :]]``.

    Args:
        X: the data array; rows ``i1`` and ``i2`` are sliced out as batches of one.

    Returns:
        The sum of the per-pair losses divided by ``sample_size``.
    """
    pair_losses = []
    for (i1,i2) in redatum_list:
        X1 = X[i1:i1+1]
        X2 = X[i2:i2+1]
        prediction = redatum(X1,X2)
        target = states[X_states[i1]]+nuisances[X_nuisances[i2,:]]
        # Give the target the prediction's dtype and shape: an integer target
        # would force the prediction to be cast to int, and a flat target
        # would broadcast against the prediction instead of pairing
        # element-wise.
        target = tf.reshape(tf.cast(target, prediction.dtype), tf.shape(prediction))
        # Keras losses take (y_true, y_pred) in that order.
        pair_losses.append(mse(target, prediction))
    return sum(pair_losses)/sample_size
@tf.function
def reconstruction_loss(model, x, training=False):
    """Mean squared error between ``x`` and ``model(x)``.

    Args:
        model: callable model accepting a ``training`` keyword.
        x: input batch, also used as the reconstruction target.
        training: forwarded to the model call (previously ignored and
            hard-coded to True).

    Returns:
        The scalar loss produced by ``mse``.
    """
    x_hat = model(x, training=training)
    return mse(x, x_hat)
@tf.function
def train_step(model, x, training=True):
    """Run one optimisation step on the reconstruction loss and return that loss.

    Args:
        model: the model whose trainable variables are updated.
        x: input batch.
        training: forwarded to ``reconstruction_loss``.
    """
    with tf.GradientTape() as tape:
        loss_rec = reconstruction_loss(model, x, training)
    variables = model.trainable_variables
    grads = tape.gradient(loss_rec, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss_rec

In [None]:
# Full-batch training. The epoch count comes from M because the history
# buffers `losses` and `redata` were allocated with length M.
for epoch in range(M):
    loss_rec = train_step(model, X)
    print("For epoch {:d}, reconstruction loss is {:f}."
          .format(epoch, loss_rec))
    # Store plain floats so the histories hold numbers rather than tensors.
    losses[epoch]=float(loss_rec)
    redata[epoch]=float(redatum_loss(X))
    # Replace the previous progress line instead of accumulating output.
    clear_output(wait=True)

### Plot

In [None]:
# df1 = pd.DataFrame(tf.reshape(model.nui_encoder(X),[-1,q]), columns = ['latent nuisance'])
# df2 = pd.DataFrame(nuisances[X_nuisances.reshape(-1,1)],columns=['nuisance'])
# df3 = pd.DataFrame(np.repeat(states[X_states],nt,axis=0),columns=['state'])
# df4 = pd.DataFrame(np.zeros(df1.shape[0]),columns=['0'])
# df = pd.concat([df1,df2,df3,df4],axis=1)
# fig = px.scatter(df, x='latent nuisance', color='nuisance', y='0')
# fig.update_layout(title="For epoch {:d}, reconstruction loss is {:f}."
#       .format(epoch, loss_rec))
# fig.show()

### Loss-epoch graph

In [None]:
# Reconstruction loss (left axis) and redatum loss (right axis) per epoch.
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(epochs,losses,'C0')
ax2.plot(epochs,redata,'C1')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss', color='C0')
ax2.set_ylabel('redatum loss', color='C1')
# Save inside the run directory: `path+'train.pdf'` has no separator and
# would write next to the directory instead of into it.
os.makedirs(path, exist_ok=True)
plt.savefig(os.path.join(path,'train.pdf'), format='pdf', bbox_inches='tight')
plt.show()

### Save weights

In [None]:
#model.save_weights('./checkpoint/'+datetime.now().strftime("%B%d"))
# Write the weights to the same location the load cell reads from.
model.save_weights(path+'/checkpoint')
print("weights saved")

# Visualize training loss

### Create a dict mapping from subscripts of D to subscripts of X

This map is the inverse of X_states and X_nuisances

In [None]:
# Map (state index, nuisance index) -> (row, column) in X. If a pair occurs
# more than once, the last occurrence in row-major order wins.
subscript_map = {
    (X_states[row], X_nuisances[row, col]): (row, col)
    for row in range(nx)
    for col in range(nt)
}

### Plot D and SymAE(X)

In [None]:
D = data.D
# Reconstruction of every sample; the trailing axis of size 1 is dropped.
X_hat = model.predict(X)[:,:,0]
# Scatter the reconstructions back onto the (state, nuisance) grid of D,
# keeping the largest and the smallest value seen per cell. Cells that X never
# visits stay at -inf / +inf.
X_converted_max = np.full((n1,n2), -np.inf)
X_converted_min = np.full((n1,n2), np.inf)
for i in range(nx):
    for j in range(nt):
        i_D = X_states[i]
        j_D = X_nuisances[i,j]
        X_converted_max[i_D,j_D] = max(X_hat[i,j],X_converted_max[i_D,j_D])
        # Compare against the running minimum (the original compared against
        # the running maximum, which made this array wrong).
        X_converted_min[i_D,j_D] = min(X_hat[i,j],X_converted_min[i_D,j_D])
def plot_reconstruct(D,X_converted):
    """Draw D and X_converted side by side with a shared colour scale.

    Args:
        D: reference array; its min/max define the colour normalisation.
        X_converted: array shown in the right panel with the same normalisation.

    Returns:
        The matplotlib figure holding both panels and one shared colour bar.
    """
    fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(10,5))
    norm = mcolors.Normalize(vmin=np.amin(D), vmax=np.amax(D))
    # Raw string: '\h' is not a valid escape sequence in a normal string.
    panels = ((D, 'D'), (X_converted, r'$\hat{X}$'))
    for ax, (values, title) in zip(axs, panels):
        pcm = ax.matshow(values.T, norm = norm)
        ax.set_ylabel('nuisances')
        ax.set_xlabel('states')
        ax.set_xticks(range(n1))
        ax.set_yticks(range(n2))
        ax.set_title(title)
    # Both panels share `norm`, so the last mappable describes both.
    fig.colorbar(pcm,ax=axs)
    return fig
# Compare D with the per-cell maximum of the reconstructions.
fig = plot_reconstruct(D,X_converted_max)
fig.suptitle('Training loss',fontsize=15)
plt.show()
#fig = plot_reconstruct(D,X_converted_min)
#fig.suptitle('Training loss X_min',fontsize=15)
#plt.show()

# Visualize performance of redatuming

### Evaluate latent code

In [None]:
# Encode X with each encoder separately.
# Output of the symmetric encoder.
Cs = model.sym_encoder.predict(X, verbose=0)
# Output of the nuisance encoder.
Ns = model.nui_encoder.predict(X, verbose=0)
print(Cs.shape)
print(Ns.shape)

### Create a dict mapping from coordinates of missing blocks (in D) to coordinates i, i', j (in X)

In [None]:
# For every cell (i_D, j_D) of D pick, at random,
#   i_s        - a row of X whose state is i_D, and
#   (i_n, j_n) - a position in X whose nuisance is j_D,
# and store them as (i_n, i_s, j_n).
missing_map = {}
# The candidate positions depend on one index only, so they are computed once
# per index rather than once per cell. The random draws keep their order.
nuisance_candidates_by_j = [np.argwhere(X_nuisances==j_D) for j_D in range(n2)]
for i_D in range(n1):
    state_candidates = np.argwhere(X_states==i_D)
    for j_D in range(n2):
        nuisance_candidates = nuisance_candidates_by_j[j_D]
        i_s, = state_candidates[np.random.choice(state_candidates.shape[0])]
        i_n, j_n = nuisance_candidates[np.random.choice(nuisance_candidates.shape[0])]
        missing_map[i_D,j_D] = (i_n,i_s,j_n)
# Sanity check: every stored triple points at the intended state and nuisance.
for (i_D,j_D), (i,i_prime,j) in missing_map.items():
    assert X_states[i_prime] == i_D
    assert X_nuisances[i,j] == j_D

### Define a function dec

In [None]:
def dec(latent_code):
    """Decode a single latent vector with the mixer and return a scalar.

    Args:
        latent_code: 1-D array holding one concatenated latent code.

    Returns:
        The single value produced by ``model.mixer`` for that code.
    """
    # Add batch and step axes of length 1: shape (1, 1, len(latent_code)).
    batch = latent_code[np.newaxis, np.newaxis, :]
    # Call the model directly: predict() is built for large batches and is
    # slow when invoked once per sample inside a loop.
    decoded = model.mixer(batch, training=False)
    return np.asarray(decoded)[0,0,0]

### Fill out X_redatum

Then we evaluate $\hat{X}_{i_n \to i_s}[j_n]$ and put it at X_redatum[i_D,j_D].  
In the code below, i=i_n, i_prime=i_s, j=j_n.  

In [None]:
# Start from the reconstruction grid and overwrite every cell listed in
# missing_map with its redatumed value.
X_redatum = np.copy(X_converted_max)
for (i_D,j_D), (i,i_prime,j) in missing_map.items():
    # Symmetric code of a sample with the wanted state ...
    coherent_i_prime = Cs[i_prime,:]
    # ... joined with the nuisance code of a position with the wanted nuisance.
    nuisance_i_j = Ns[i,j]
    merger = np.concatenate([coherent_i_prime,nuisance_i_j])
    X_redatum[i_D,j_D] = dec(merger)
clear_output()

### Plot

In [None]:
# Compare D with the redatumed grid side by side.
fig = plot_reconstruct(D,X_redatum)
fig.suptitle('Redatuming',fontsize=15)
plt.show()

# PCA

In [None]:
# One row per (sample, step): each symmetric code is repeated nt times so it
# lines up with the flattened nuisance codes.
C_reshaped = list(map(tuple, np.repeat(Cs, nt, axis=0))) 
N_reshaped = list(map(tuple, Ns.reshape(-1,q)))
# True values matching each row, looked up through the index arrays.
state_reshaped = states[np.repeat(X_states, nt, axis=0)]
nuisance_reshaped = nuisances[X_nuisances.reshape(-1,1)[:,0]]
data_dict = {'latent state':C_reshaped,
             'latent nuisance':N_reshaped,
             'true state':state_reshaped,
             'true nuisance':nuisance_reshaped}
# Project each latent space onto at most three principal components.
C_pca = list(map(tuple,PCA(min(3,p)).fit_transform(C_reshaped)))
N_pca = list(map(tuple,PCA(min(3,q)).fit_transform(N_reshaped)))
data_dict['PCA latent state'] = C_pca
data_dict['PCA latent nuisance'] = N_pca
df = pd.DataFrame(data_dict)
df.shape

### p-space (latent coherent space)

In [None]:
# Scatter plot of the symmetric codes coloured by state index.
if p==1:
    # A single latent dimension is plotted as is, without PCA.
    df = pd.concat([pd.DataFrame(Cs, columns = ['1st']),
                    pd.DataFrame(X_states,columns=['state'])],axis=1)
    fig = px.scatter(df, x='1st', color='state')
    fig.show()
elif p>=2:
    # Two components for p == 2, three for anything larger.
    n_components = 2 if p==2 else 3
    axis_names = ['1st','2nd','3rd'][:n_components]
    pca = PCA(n_components=n_components)
    pca_C = pca.fit_transform(Cs)
    tem = pd.DataFrame(X_states,columns=['state'])
    df = pd.concat([pd.DataFrame(pca_C, columns = axis_names),tem],axis=1)
    df = df.sort_values('state')
    if n_components==2:
        fig = px.scatter(df, x='1st', y='2nd', color='state')
    else:
        fig = px.scatter_3d(df, x='1st', y='2nd', z='3rd', color='state')
    fig.update_layout(title_text='p space PCA')
    fig.show()

### q space (latent nuisance space)

In [None]:
# One row per (sample, step) holding the q raw nuisance-code components.
# NOTE(review): despite the variable name and the plot title, no PCA is
# applied in this cell.
pca_N = Ns.reshape(-1,q)
df = pd.DataFrame(pca_N, columns = np.arange(q))
tem = pd.DataFrame(X_nuisances.reshape(-1,1),columns=['nuisance'])
df = pd.concat([df,tem],axis=1)
tem = pd.DataFrame(np.repeat(X_states,nt,axis=0),columns=['state'])
df = pd.concat([df,tem],axis=1)
# Component shown on the y axis: index 4 as before, clamped to the last
# available component so the cell also works when q <= 4.
latent_dim_to_plot = min(4, q-1)
fig = px.scatter(df, x='nuisance', y=latent_dim_to_plot, color='state')
fig.update_layout(title_text='q space PCA')
fig.update_traces(textposition='top center')
fig.show()