# Automatically sweep through different values of p and q

### Import packages

In [1]:
import re,os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from IPython.display import clear_output
import pandas as pd
import math
import random
from datetime import datetime
from scipy.stats import linregress
from tqdm import trange
tfk = tf.keras
tfkl = tf.keras.layers
clear_output()

### Import project modules

In [1]:
from parameters import *
from MRA_generate import MRA_generate
from symae_model import SymAE
from plot_training import plot_training
from plot_redatuming import plot_redatuming
from plot_save import plot_save
from CustomCallback import CustomCallback
from redatuming import redatuming
from latent import latent 

Num GPUs Available:  2
TensorFlow Version:  2.9.1


### Generate the Dataset

In [2]:
def g(n,x):
    """Evaluate template signal number ``n`` at position ``x``.

    Templates:
        0 -- Gaussian bump ``e**(-9 x**2)``
        1 -- step: ``1`` if ``x < 0.5`` else ``0``
        2 -- triangle: ``x`` for ``x < 0.3``, ``0.6 - x`` for ``0.3 <= x < 0.6``,
             ``0`` otherwise
        3 -- ``cos(2 pi x)``
        4 -- Gaussian bump ``e**(-9 (x - 0.5)**2)``

    Any other ``n`` evaluates to ``np.inf``.
    """
    def _triangle(t):
        # Same comparison order as a plain if/elif chain, so NaN falls to 0.
        if t < 0.3:
            return t
        if t < 0.6:
            return 0.6 - t
        return 0

    templates = {
        0: lambda t: math.e**(-9*t**2),
        1: lambda t: int(t < 0.5),
        2: _triangle,
        3: lambda t: math.cos(2*math.pi*t),
        4: lambda t: math.e**(-9*(t-0.5)**2),
    }
    template = templates.get(n)
    return np.inf if template is None else template(x)
# Build the training dataset from the global parameters d, nt, N, sigma, ne
# (presumably provided by `from parameters import *`) and the template
# function g defined above.
# NOTE(review): the meaning of replace / outer_replace is defined inside
# MRA_generate and is not visible here.
MRA_data=MRA_generate(d,nt,N,sigma,ne,g,replace=False,outer_replace=True)
# Training array; everything() fits the model with X as both input and target.
X=MRA_data.X
print("Numbers of States:")
# Last expression of the cell: the notebook displays how many entries of
# MRA_data.states take each state value.
pd.DataFrame(MRA_data.states).value_counts()

Numbers of States:


1    27
3    26
0    25
2    22
dtype: int64

In [3]:
def find_state(state, max_tries=100):
    """Draw MRA data until its first state equals ``state``.

    Builds an ``MRA_generate`` object from the global parameters
    ``d, nt, sigma, ne`` and template function ``g``, with the third argument
    (passed as ``N`` when building the training set) fixed to ``1``. It then
    calls ``generate_default()`` repeatedly until ``states[0] == state``.

    Args:
        state: state label to search for.
        max_tries: maximum number of ``generate_default()`` calls
            (default 100, the previous hard-coded limit).

    Returns:
        The ``MRA_generate`` instance whose ``states[0]`` equals ``state``.

    Raises:
        RuntimeError: if no draw matches within ``max_tries`` attempts.
            Previously ``None`` was returned silently, which only failed
            later, obscurely, inside the caller.
    """
    MRA_data=MRA_generate(d,nt,1,sigma,ne,g)
    for _ in range(max_tries):
        MRA_data.generate_default()
        if MRA_data.states[0]==state:
            return MRA_data
    raise RuntimeError(
        'find_state: no sample with state %r after %d tries' % (state, max_tries))
def alter_parameters(p,q,path='parameters.py'):
    """Rewrite the ``p`` and ``q`` assignment lines of a parameters file.

    Each line that starts with an assignment to ``p`` (resp. ``q``) is
    replaced by ``p=<p>`` (resp. ``q=<q>``). All other lines are kept
    unchanged.

    Args:
        p: new value of ``p`` (formatted with ``%d``).
        q: new value of ``q`` (formatted with ``%d``).
        path: file to edit; defaults to ``'parameters.py'``.
    """
    # Match only real assignments ("p=3", "p = 3"). A plain startswith('p')
    # would also clobber lines such as "print(...)" or "quux=...".
    # The lookahead excludes "p==".
    assign_p = re.compile(r'p\s*=(?!=)')
    assign_q = re.compile(r'q\s*=(?!=)')
    with open(path,'r') as f:
        lines = f.readlines()
    new_lines = []
    for line in lines:
        if assign_p.match(line):
            line = 'p=%d\n'%p
        elif assign_q.match(line):
            line = 'q=%d\n'%q
        new_lines.append(line)
    # 'w' truncates the file. 'r+' would leave stale trailing bytes whenever
    # the new content is shorter than the old one (e.g. q=19 -> q=1).
    with open(path,'w') as f:
        f.writelines(new_lines)

In [4]:
# Progress-status file that the training callback in everything() overwrites
# every 100 epochs, and a placeholder status string.
# NOTE: the callback assigns its own local status string, so this global
# auto_status keeps the value 'empty'.
auto_path, auto_status = './pq_checkpoint/status.txt', 'empty'

In [5]:
def _pq_tag(p, q):
    """Return the 'p=..,q=..' prefix shared by every artifact of one (p, q) run."""
    return 'p=%d,q=%d' % (p, q)


def _save_epoch_loss_plot(losses, slope, intercept, p, q):
    """Plot the loss curve, its linear fit and each half's minimum; save, then close."""
    M = len(losses)
    epochs = range(M)
    half = M // 2
    fig = plt.figure(figsize=(6,4),dpi=150)
    plt.plot(epochs,losses)
    plt.plot(epochs,intercept+slope*np.array(epochs))
    plt.scatter(np.argmin(losses[:half]),min(losses[:half]),c=['C2'])
    plt.scatter(half+np.argmin(losses[half:]),min(losses[half:]),c=['C2'])
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.tight_layout()
    fig.savefig('./pq_checkpoint/%s,epoch-loss.png' % _pq_tag(p, q))
    # Close explicitly: this runs once per training loop for every (p, q)
    # pair, and open figures would otherwise accumulate for the whole sweep.
    plt.close(fig)


def _save_redatuming_plot(model, state_a, state_b, p, q, suffix):
    """Run ``redatuming`` on one state-``state_a`` and one state-``state_b`` draw; save the plot, then close it."""
    MRA1=find_state(state_a)
    MRA2=find_state(state_b)
    redatum=redatuming(model,MRA1,MRA2,1,p,q)
    fig=plot_redatuming(redatum,p,q)
    plt.tight_layout()
    fig.savefig('./pq_checkpoint/%s,%s.png' % (_pq_tag(p, q), suffix))
    plt.close(fig)


def everything(p,q):
    """Train a SymAE for one (p, q) pair until its loss plateaus, saving artifacts.

    Steps:
      * Rewrites ``p``/``q`` in ``parameters.py`` via ``alter_parameters``.
      * Builds and compiles the model: ``SymAE(N, nt, d, p, q, kernel_size,
        filters, dropout_rate)`` with MSE loss and Adam (lr 1e-3).
      * Resumes from ``./pq_checkpoint/p=..,q=..`` when that checkpoint loads.
      * Repeatedly fits 500 epochs on ``X -> X``. After each run it saves the
        weights, an epoch-loss plot and two redatuming plots (states 0/3 ->
        ``cos_tri``, states 0/1 -> ``squ_tri``).
      * Every 100 epochs, overwrites ``auto_path`` with a one-line status.

    Returns:
        True if the latest run's second-half minimum loss is not at least
        0.001 below its first-half minimum (the loss has plateaued).
        False if more than 5 runs have completed and every loss of the latest
        run is above 0.1.
    """
    alter_parameters(p,q)
    # Plots and the status file are written here; make sure the directory exists.
    os.makedirs('./pq_checkpoint', exist_ok=True)
    checkpoint = './pq_checkpoint/' + _pq_tag(p, q)
    model=SymAE(N,nt,d,p,q,kernel_size,filters,dropout_rate)
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-07)
    model.compile(loss='mse',optimizer=optimizer)
    try:
        model.load_weights(checkpoint)
    except Exception:
        # No usable checkpoint yet; start from fresh weights. Catch Exception
        # rather than using a bare except, so KeyboardInterrupt still stops
        # the sweep.
        print('create '+_pq_tag(p, q))
    M=500
    losses=[0.0]*M  # overwritten epoch by epoch during each fit
    loop_time = 0

    class _ProgressCallback(tf.keras.callbacks.Callback):
        """Record per-epoch loss into ``losses`` and report progress."""
        def on_epoch_end(self, epoch, logs=None):
            clear_output(wait=True)
            losses[epoch]=logs["loss"]
            print("p={:d}, q={:d}. For epoch {:d}, loss is {:f}.".format(p,q,epoch,logs["loss"]))
            if epoch%100==0:
                # model_reset_times is a global maintained by the main sweep
                # loop. Default to 0 so everything() also works stand-alone.
                reset_times = globals().get('model_reset_times', 0)
                status="p={:d}, q={:d}. model_reset_times={:d}. loop_time={:d}. For epoch {:d}, loss is {:f}.".format(p,q,reset_times,loop_time,epoch,logs["loss"])
                with open(auto_path,'w') as auto_f:
                    print(status,file=auto_f)

    while True:
        model.fit(X,X,epochs=M,verbose=0,callbacks=[_ProgressCallback()])
        model.save_weights(checkpoint)
        loop_time += 1
        slope, intercept, _, _, _ = linregress(range(M), losses)
        min_first_half = min(losses[:M//2])
        min_second_half = min(losses[M//2:])
        _save_epoch_loss_plot(losses, slope, intercept, p, q)
        _save_redatuming_plot(model, 0, 3, p, q, 'cos_tri')
        _save_redatuming_plot(model, 0, 1, p, q, 'squ_tri')
        # Termination
        if min_second_half > min_first_half - 0.001:
            return True
        if loop_time > 5 and min(losses) > 0.1:
            return False

### Main loop

In [None]:
# Sweep p = 3..19 and q = 1..19. For each pair, everything() is called once.
# If it returns False, it is retried up to MAX_MODEL_RESETS more times.
# model_reset_times must stay a module-level global: the progress callback
# inside everything() reads it when writing the status file.
# NOTE(review): everything() reloads './pq_checkpoint/p=..,q=..' when it
# exists, and saves there after each run. A "reset" therefore resumes from the
# failed attempt's weights rather than from a fresh initialisation. Confirm
# that this is intended.
MAX_MODEL_RESETS = 3
model_reset_times = 0
for p in range(3, 20):
    for q in range(1, 20):
        print('training p=%d q=%d\n' % (p, q))
        model_reset_times = 0
        converged = everything(p, q)
        while not converged and model_reset_times < MAX_MODEL_RESETS:
            model_reset_times += 1
            converged = everything(p, q)

p=9, q=8. For epoch 320, loss is 0.037999.
