# In [None]:
# Mount Google Drive at /content/drive so that the checkpoint and output
# paths configured below (all under /content/drive/MyDrive/...) resolve.
# google.colab is only importable inside a Google Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# In [None]:
import numpy as np
import torch
import math
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.nn as nn
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import pickle
from scipy.io import loadmat
import os
import matplotlib.cm as cm


# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device to train")

# --- Paths ---------------------------------------------------------------
all_path = "/content/drive/MyDrive/Colab_Notebooks/2023_mep/0120/"

# Checkpoints loaded into the two frozen NeuralNetwork_minimum instances
# (m1 and p1) in the __main__ block.
path_m1 = f"{all_path}d4_m1_k008/model_MEP_find_minimals_d4_m1_k008.pth"
path_p1 = f"{all_path}d4_p1_k008/model_MEP_find_minimals_d4_p1_k008.pth"

# Run identifiers; together they name the run folder and checkpoint file.
mycase = "ac_4d_6"
casenum = "NBC_k008_beta10lg_6"

# Not read by the code in this section.
path_lg = f"{all_path}{casenum}/lgloss.pkl"
plot_use_all = []

torch.set_printoptions(precision=10)

# Checkpoint restored into the main model by train() when load_model is True.
model_name = f"{all_path}{casenum}/model_MEP_{mycase}_{casenum}.pth"
model_Parameters_name = model_name

# --- Hyper-parameters ----------------------------------------------------
kappa = 0.08           # presumably the "k008" in the paths; not read in this section
load_model = True      # train() loads model_Parameters_name when True
batches = 20000        # fig_cos saves its figure when j == batches - 500
beta = 10              # not read in this section
learning_rate = 1e-4   # not read in this section
dimension = 1          # not read in this section
dn = 100               # grid points per axis of fig_countour's (x1, x2) slices
dnt = 28               # not read in this section

x_dim = 4              # number of x inputs; NeuralNetwork takes x_dim + 1 (s and x)

# alpha coefficients: not read in this section (presumably used by a
# training cell that is not shown here).
alpha1 = 1
alpha4 = 0
alpha3 = 0.001
alpha5 = 1


def mkdir(path):
    """Create directory *path*, including any missing parents, if it is absent.

    Prints a status banner: "new folder... / OK" when the directory was
    created, "There is this folder!" when a directory already existed.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if *path* exists but is not a directory (e.g. a
            regular file), since later writes "into" it would fail anyway.
    """
    try:
        # EAFP: attempt creation directly instead of exists()-then-makedirs(),
        # which could race with another process creating the same path.
        os.makedirs(path)
    except FileExistsError:
        if not os.path.isdir(path):
            raise
        print("---  There is this folder!  ---")
    else:
        print("---  new folder...  ---")
        print("---  OK  ---")


def fig_loss_batch(plt_batch, loss_batch):
    """Draw the loss curve (blue solid line) against batch index on a new 5x4-inch figure.

    The figure is left open; nothing is returned or saved.
    """
    _figure, axes = plt.subplots(figsize=(5, 4))
    axes.plot(plt_batch, loss_batch, 'b-')
    axes.set_xlabel("Batches")
    axes.set_ylabel("Loss")

def fig_cos(plt_batch, lg_batch, j):
    """Plot the l_g history against batch index on a new 5x4-inch figure.

    The figure is saved to ``<all_path><casenum>/lgloss.jpg`` only when
    ``j == batches - 500``; in every case it is left open.
    """
    figure = plt.figure(figsize=(5, 4))
    plt.plot(plt_batch, lg_batch, 'b-')
    plt.xlabel("Batches")
    plt.ylabel("$l_g$")

    if j != batches - 500:
        return
    figure.savefig(all_path + casenum + "/lgloss.jpg")

def _eval_on_slice(model, grid, t, x3, x4):
    """Evaluate *model* on rows ``[t, x1, x2, x3, x4]`` for every ``(x1, x2)`` row of *grid*.

    Args:
        model: module called on a float64 tensor of shape (len(grid), 5).
        grid: float array with two columns (x1, x2).
        t, x3, x4: scalars broadcast to every row.

    Returns:
        The model output as a NumPy array (moved to the CPU).
    """
    n_rows = grid.shape[0]

    def const(value):
        # One float64 column filled with *value*.
        return np.full((n_rows, 1), value, dtype=np.float64)

    inputs = np.hstack((const(t), grid, const(x3), const(x4)))
    # Values only: no autograd graph is needed for plotting.
    with torch.no_grad():
        out = model(torch.from_numpy(inputs).to(device))
    return out.cpu().numpy()


def _signed_to_rgb(field):
    """Encode a signed 2-D array as RGB: red = positive part, blue = |negative part|, green = 0."""
    rgb = np.zeros(field.shape + (3,))
    rgb[..., 0] = np.maximum(field, 0.0)
    rgb[..., 2] = np.maximum(-field, 0.0)
    return rgb


def fig_countour(model):
    """Plot 2-D (x1, x2) slices of ``model(s, x1, x2, x3, x4)`` and save one PDF per s.

    For each s in {0, 1/4, 1/2, 3/4, 1} a 2x2 figure is drawn. Subplot
    (row, col) fixes ``(x3, x4) = (x_plot_use[row], x_plot_use[col])`` and
    shows the output on a ``dn x dn`` grid over [0.001, 0.999]^2. Positive
    values go to the red channel and the magnitude of negative values to blue.
    Each figure is saved as ``dim4_k008_t_<s>.pdf`` in the current working
    directory and left open for the caller to show.

    Args:
        model: module taking float64 input of shape (N, 5), on ``device``.
    """
    x_plot_use = [0.25, 0.5]
    t_plot_use = [0.0, 0.25, 0.5, 0.75, 1.0]
    # LaTeX labels for the fixed coordinates; other values fall back to "%g".
    coord_tex = {0.25: r"\frac{1}{4}", 0.5: r"\frac{1}{2}"}

    axis = np.linspace(0.001, 0.999, dn, endpoint=True)
    g1, g2 = np.meshgrid(axis, axis)
    # (dn*dn, 2) table of (x1, x2) in meshgrid row-major order, so the model
    # output reshapes straight back to a (dn, dn) image with x2 along rows.
    grid = np.column_stack((g1.ravel(), g2.ravel()))

    n_side = len(x_plot_use)
    for t in t_plot_use:
        fig = plt.figure(figsize=(6, 6))
        for row, x3 in enumerate(x_plot_use):
            for col, x4 in enumerate(x_plot_use):
                phi = _eval_on_slice(model, grid, t, x3, x4).reshape(dn, dn)

                ax = fig.add_subplot(n_side, n_side, row * n_side + col + 1)
                # NOTE(review): imshow's default origin="upper" draws the
                # smallest x2 at the top of each panel.
                ax.imshow(_signed_to_rgb(phi))
                ax.set_xticks([])
                ax.set_yticks([])

                label3 = coord_tex.get(x3, f"{x3:g}")
                label4 = coord_tex.get(x4, f"{x4:g}")
                ax.set_title(r"$(x_1,x_2,$" + "$" + label3 + ",$" + "$" + label4 + ")$")
                ax.set_xlabel(r"$x_1$")
                ax.set_ylabel(r"$x_2$")

        fig.savefig(f"dim4_k008_t_{t}.pdf")


class NeuralNetwork_minimum(nn.Module):
    """Plain MLP with 4 inputs and 1 output: four tanh-activated hidden layers of width 64.

    The ``__main__`` block loads checkpoints into two frozen instances and
    passes them to ``NeuralNetwork`` as ``p1`` / ``m1``. The layer order inside
    ``linear_tanh_stack`` fixes the state-dict keys
    (``linear_tanh_stack.{0,2,4,6,8}``), so saved checkpoints keep loading.
    """

    def __init__(self):
        super().__init__()
        # Registered but not used by forward().
        self.flatten = nn.Flatten()
        width = 64
        layers = [nn.Linear(4, width), nn.Tanh()]
        for _ in range(3):
            layers.extend((nn.Linear(width, width), nn.Tanh()))
        layers.append(nn.Linear(width, 1))
        self.linear_tanh_stack = nn.Sequential(*layers)

    def forward(self, s):
        """Apply the MLP to *s* (last dimension 4); the output's last dimension is 1."""
        return self.linear_tanh_stack(s)




class NeuralNetwork(nn.Module):
    """Trainable MLP blended with two fixed endpoint modules along the s coordinate.

    Input rows are ``[s, x_1, ..., x_{x_dim}]`` (column 0 is s, the remaining
    columns are x). The output is

        s * (1 - s) * N(s, x) + (1 - s) * p1(x) + s * m1(x),

    where N is the internal tanh MLP. The output therefore equals ``p1(x)`` at
    s = 0 and ``m1(x)`` at s = 1, whatever N's weights are.

    Args:
        p1, m1: modules applied to the x columns. In the ``__main__`` block
            these are frozen ``NeuralNetwork_minimum`` instances; their outputs
            are assumed to broadcast against an (N, 1) tensor.
    """

    def __init__(self, p1, m1):
        super().__init__()
        # Registered but not used by forward().
        self.flatten = nn.Flatten()
        width = 100
        layers = [nn.Linear(x_dim + 1, width), nn.Tanh()]
        for _ in range(3):
            layers.extend((nn.Linear(width, width), nn.Tanh()))
        layers.append(nn.Linear(width, 1))
        self.linear_tanh_stack = nn.Sequential(*layers)
        self.p1 = p1
        self.m1 = m1

    def forward(self, s):
        """Return the blended output for a batch *s* of shape (N, x_dim + 1)."""
        free_part = self.linear_tanh_stack(s)
        s_col = s[:, 0].reshape(-1, 1)   # s as an (N, 1) column
        coords = s[:, 1:]                # the x columns

        # s(1-s) vanishes at both ends, so only p1 (s=0) / m1 (s=1) survive there.
        bump = s_col * (1 - s_col)
        return bump * free_part + (1 - s_col) * self.p1(coords) + s_col * self.m1(coords)


def train(model):
    """Optionally restore *model*'s weights, then plot and show its contour slices.

    No optimisation happens here. When the module-level flag ``load_model``
    is true, the state dict at ``model_Parameters_name`` is loaded into
    *model* (in place). ``fig_countour`` then draws and saves the slice
    figures, and ``plt.show()`` displays them.

    Args:
        model: the ``nn.Module`` to restore and visualise.

    Returns:
        None.
    """
    if load_model:
        # map_location lets a checkpoint written from a GPU be restored when
        # this runtime fell back to the CPU (and vice versa).
        state = torch.load(model_Parameters_name, map_location=device)
        model.load_state_dict(state)

    fig_countour(model)
    plt.show()




if __name__ == '__main__':

    def _load_frozen_minimum(path):
        """Build a float64 NeuralNetwork_minimum on ``device``, load *path* into it, and freeze it.

        map_location lets GPU-saved checkpoints load on a CPU-only runtime.
        """
        net = NeuralNetwork_minimum().to(device).double()
        net.load_state_dict(torch.load(path, map_location=device))
        # Endpoint networks stay fixed: no gradients flow into their weights.
        net.requires_grad_(False)
        return net

    net_p1 = _load_frozen_minimum(path_p1)
    net_m1 = _load_frozen_minimum(path_m1)

    mkdir(all_path + casenum + "_energyfig/")

    model = NeuralNetwork(net_p1, net_m1).to(device).double()
    train(model)

