In [None]:
import numpy as np
import time
import torch
import random
import sys
import os
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn, optim, autograd
from scipy.integrate import odeint

from google.colab import drive
drive.mount("/content/drive")

main_path = "/content/drive/My Drive/Workspace/Fourier_PINN/" # ENZE marked: you need to change your main_path if it's not here
sys.path.append(main_path)

from utils import draw_two_dimension, MultiSubplotDraw

Mounted at /content/drive


In [None]:
class Parameters:
    """Constants of the 10-variable ODE system.

    The values are read as plain class attributes (``Parameters.kscln`` etc.) by
    ``Config.pend`` (NumPy right-hand side used for the reference solution) and by
    ``FourierModel.ode_gradient`` (torch residual used in the training loss).
    The groups below are labelled with the equation(s) in which each constant
    appears in those two functions. Units are not stated anywhere in this file.
    """
    # Group 1: Cln equation (kscln, kdcln) and SBF equation (kasbf_, kasbf, kisbf, Jsbf)
    kscln = 0.2
    kdcln = 0.2
    kasbf_ = 1
    kasbf = 10
    kisbf = 25
    Jsbf = 1
    # Group 2: ClbSt equation
    ksclbs = 0.15
    kdclbs_ = 0.1
    kdclbs = 0.05
    # Group 3: Nrm1t equation (ksnrm1, kdnrm1), MBF equation (MBFtot, kass_, kdiss_)
    # and the auxiliary quantity MBFa (Jmbf)
    ksnrm1 = 0.05
    kdnrm1 = 0.1
    MBFtot = 0.5
    kass_ = 1
    kdiss_ = 0.001
    Jmbf = 0.01
    # Group 4: ClbMt equation
    ksclbm_ = 0.01
    ksclbm = 0.01
    kdclbm_ =0.01
    kdclbm = 1
    Jclbm = 0.05
    # The exponent n = 2 of the ClbMt equation is not a parameter here; it is
    # hard-coded as `** 2` where the equation is evaluated.
    # Group 5: Polo equation (kspolo, kdpolo_, kdpolo) and Cdc14 equation
    # (kacdc14, kicdc14, Jcdc14)
    kspolo = 0.01
    kdpolo_ = 0.01
    kdpolo = 1
    kacdc14 = 1
    kicdc14 = 0.25
    Jcdc14 = 0.01
    # Group 6: Sic1t equation (kssic_, kdsic_, kdsic, Jsic1); Kdiss enters the
    # auxiliary quantity BB used to compute Sic1Clb
    kssic_ = 0.02
    kdsic_ = 0.01
    kdsic = 2
    Jsic1 = 0.01
    Kdiss = 0.05
    # Group 7: Cdh1 equation
    kacdh1_ = 1
    kacdh1 = 10
    kicdh1_ = 0.2
    kicdh1 = 10
    Jcdh1 = 0.01
    # Group 8: constant offset added to ClbMt when forming Clbt and ClbM
    ndClbM = 0

class TrainArgs:
    """Training-loop settings read by ``FourierModel`` through ``config.args``."""
    # Number of optimisation steps performed by FourierModel.train_model.
    iteration = 1000000
    # A progress line is printed every `epoch_step` steps.
    epoch_step = 100
    # Plots and a checkpoint are produced every `test_step` steps. The check is
    # nested inside the `epoch_step` check, so it only fires on steps that are
    # multiples of both values.
    test_step = 1000
    # Learning rate handed to Adam; train_model additionally applies a LambdaLR decay.
    initial_lr = 0.0001
    # Root directory under which FourierModel creates its figure folder.
    # Hard-coded Google Drive location; adjust it together with the notebook-level
    # `main_path` defined in the first cell.
    main_path = "/content/drive/My Drive/Workspace/Fourier_PINN/"


class Start:
    """Initial values of the ten state variables at t = 0.

    ``Config`` converts ``Start.all`` into ``y0``, which seeds ``odeint`` and is
    the target of the initial-condition term in ``FourierModel.loss``.
    """
    Cln = 0.05
    ClbSt = 0.045
    MBF = 0.02
    Nrm1t = 0.54
    ClbMt = 0.9
    Polo = 0.25
    Sic1t = 0.01
    SBF = 0.031
    Cdh1 = 0.0005
    Cdc14 = 0.1
    # Order matters: it must match the index order used when the state vector is
    # unpacked in Config.pend / FourierModel.ode_gradient and Config.curve_names.
    # NOTE: the attribute name shadows the builtin `all` inside this class body only.
    all = [Cln, ClbSt, MBF, Nrm1t, ClbMt, Polo, Sic1t, SBF, Cdh1, Cdc14]

class Config:
    """Problem definition shared by the model and the training loop.

    Building an instance
      * fixes the time grid (``T``, ``T_unit``, ``T_N``, ``t``, ``t_torch``),
      * builds the network input ``x`` of shape ``(1, T_N, prob_dim)`` in which
        every feature column holds the time coordinate,
      * integrates the reference solution ``truth`` (shape ``(T_N, prob_dim)``)
        with ``scipy.integrate.odeint`` starting from ``Start.all``,
      * stores the network hyper-parameters (``modes``, ``width``, ``fc_map_dim``).
    """

    def __init__(self):
        self.model_name = "BYCC_Fourier"
        self.curve_names = ["Cln", "ClbSt", "MBF", "Nrm1t", "ClbMt","Polo","Sic1t","SBF","Cdh1","Cdc14"]
        self.params = Parameters
        self.args = TrainArgs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.seed = 0

        self.T = 100
        self.T_unit = 1e-2
        # round() before int(): a quotient such as 0.3 / 0.1 == 2.999... would
        # otherwise be truncated and silently drop the last grid point.
        self.T_N = int(round(self.T / self.T_unit))

        self.prob_dim = 10
        self.y0 = np.asarray(Start.all)
        # Uniform grid t_i = i * T_unit for i = 0 .. T_N - 1.
        self.t = np.arange(self.T_N) * self.T_unit
        self.t_torch = torch.tensor(self.t, dtype=torch.float32).to(self.device)
        # Network input: the time coordinate repeated across all prob_dim features.
        self.x = self.t_torch.reshape(1, self.T_N, 1).repeat(1, 1, self.prob_dim)
        self.truth = odeint(self.pend, self.y0, self.t)

        self.modes = 64  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.width = 16
        self.fc_map_dim = 128

    def pend(self, y, t):
        """Right-hand side dy/dt of the ODE system, in the form ``odeint`` expects.

        Args:
            y: sequence whose first ten entries are, in order, Cln, ClbSt, MBF,
               Nrm1t, ClbMt, Polo, Sic1t, SBF, Cdh1, Cdc14.
            t: time; accepted for the ``odeint`` callback signature but not used
               (none of the expressions below depend on it).

        Returns:
            ``np.ndarray`` of shape ``(10,)`` with the time derivatives in the
            same order as ``y``.

        The expressions mirror ``FourierModel.ode_gradient`` term by term
        (including the ``abs(...) + eps`` guards on denominators) so that the
        reference solution and the training residual describe the same system.
        """
        p = self.params
        eps = 1e-20  # keeps guarded denominators away from exactly zero
        Cln, ClbSt, MBF, Nrm1t, ClbMt, Polo, Sic1t, SBF, Cdh1, Cdc14 = y[:10]

        # Auxiliary algebraic quantities shared by several equations.
        MBFa = MBF * Cln / (np.abs(p.Jmbf + Cln) + eps)
        Clbt = ClbSt + ClbMt + p.ndClbM
        BB = Sic1t + Clbt + p.Kdiss
        Sic1Clb = 2 * Sic1t * Clbt / (BB + np.sqrt(np.abs(BB ** 2 - 4 * Sic1t * Clbt) + eps))
        Clb = Clbt - Sic1Clb
        ClbM = (ClbMt + p.ndClbM) * (Clbt - Sic1Clb) / (np.abs(Clbt) + eps)

        dCln = p.kscln * SBF - p.kdcln * Cln
        dClbSt = p.ksclbs * MBFa - (p.kdclbs_ + p.kdclbs * Cdh1) * ClbSt
        dMBF = p.kdiss_ * (p.MBFtot - MBF) - p.kass_ * MBF * (Nrm1t - (p.MBFtot - MBF))
        dNrm1t = p.ksnrm1 * MBFa - p.kdnrm1 * Cdh1 * Nrm1t
        dClbMt = p.ksclbm_ + p.ksclbm * ClbM ** 2 / (np.abs(p.Jclbm ** 2 + ClbM ** 2) + eps) - (p.kdclbm_ + p.kdclbm * Cdh1) * ClbMt
        dPolo = p.kspolo * ClbM - (p.kdpolo_ + p.kdpolo * Cdh1) * Polo
        dSic1t = p.kssic_ - (p.kdsic_ + p.kdsic * Clb * (Cln + Clb) / (np.abs(p.Jsic1 + Cln + Clb) + eps)) * Sic1t
        # The second denominator (Jsbf + SBF) is intentionally left unguarded,
        # exactly as in the torch residual.
        dSBF = (p.kasbf_ + p.kasbf * Cln) * (1 - SBF) / (np.abs(p.Jsbf + 1 - SBF) + eps) - p.kisbf * ClbM * SBF / (p.Jsbf + SBF)
        dCdh1 = (p.kacdh1_ + p.kacdh1 * Cdc14) * (1 - Cdh1) / (np.abs(p.Jcdh1 + 1 - Cdh1) + eps) - (p.kicdh1_ * Cln + p.kicdh1 * Clb) * Cdh1 / (np.abs(p.Jcdh1 + Cdh1) + eps)
        dCdc14 = p.kacdc14 * Polo * (1 - Cdc14) / (np.abs(p.Jcdc14 + 1 - Cdc14) + eps) - p.kicdc14 * Cdc14 / (np.abs(p.Jcdc14 + Cdc14) + eps)

        return np.asarray([dCln, dClbSt, dMBF, dNrm1t, dClbMt, dPolo, dSic1t, dSBF, dCdh1, dCdc14])


In [None]:
# Build the shared configuration. The constructor also integrates the reference
# solution with odeint (see Config.__init__), so this cell does real work.
config=Config()

In [None]:
class SpectralConv1d(nn.Module):
    """1-D spectral convolution: rFFT -> per-mode complex channel mixing -> inverse rFFT.

    Only the lowest ``config.modes`` frequencies are mixed with learned complex
    weights; all higher frequencies are set to zero before transforming back.
    Input and output both have ``config.width`` channels.
    """

    def __init__(self, config):
        """Create the layer.

        Args:
            config: object providing ``width`` (channel count) and ``modes``
                (number of retained Fourier modes).
        """
        super().__init__()
        self.config = config
        self.in_channels = self.config.width
        self.out_channels = self.config.width
        self.scale = 1 / (self.in_channels * self.out_channels)
        # One complex (in_channels x out_channels) mixing matrix per retained mode.
        self.weights = nn.Parameter(self.scale * torch.rand(self.in_channels, self.out_channels, self.config.modes, dtype=torch.cfloat))

    def compl_mul1d(self, input, weights):
        """Mix channels independently for every mode.

        ``(batch, in_channel, mode) x (in_channel, out_channel, mode)
        -> (batch, out_channel, mode)``
        """
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        """Apply the spectral convolution along the last axis.

        Args:
            x: real tensor of shape ``(batch, width, n)``.

        Returns:
            Real tensor of shape ``(batch, width, n)``.
        """
        n_points = x.size(-1)
        x_ft = torch.fft.rfft(x)
        n_freq = x_ft.size(-1)  # == n_points // 2 + 1
        # A short sequence may offer fewer frequencies than config.modes; use
        # only the ones that exist instead of failing on a shape mismatch.
        kept = min(self.config.modes, n_freq)
        # Allocate directly on the input's device (no host allocation + copy).
        out_ft = torch.zeros(x.shape[0], self.out_channels, n_freq, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :kept] = self.compl_mul1d(x_ft[:, :, :kept], self.weights[:, :, :kept])
        return torch.fft.irfft(out_ft, n=n_points)


class FourierModel(nn.Module):
    """Fourier-layer network trained as a physics-informed solver for the ODE system.

    The network maps the time grid ``config.x`` (shape ``(1, T_N, prob_dim)``)
    to the ten state trajectories (same shape, squashed to (0, 1) by a sigmoid).
    Training minimises the loss assembled in :meth:`loss`; progress plots are
    written under ``<config.args.main_path>/figure/`` and checkpoints under
    ``<config.args.main_path>/saves/``.
    """

    def __init__(self, config):
        """Build the layers, create the figure folder and print the loss of the
        reference solution.

        Args:
            config: a ``Config``-like object (hyper-parameters, time grid,
                ODE constants, reference solution, training arguments).
        """
        super(FourierModel, self).__init__()
        self.time_string = time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time()))
        self.config = config
        self.setup_seed(self.config.seed)

        # False: stage 1 (loss1..loss4). True: stage 2 (loss1..loss3).
        self.step_2_flag = False

        # Lifts the prob_dim input features to `width` channels.
        self.fc0 = nn.Linear(self.config.prob_dim, self.config.width)

        # Four Fourier blocks: spectral convolution + pointwise (1x1) convolution.
        self.conv0 = SpectralConv1d(self.config)
        self.conv1 = SpectralConv1d(self.config)
        self.conv2 = SpectralConv1d(self.config)
        self.conv3 = SpectralConv1d(self.config)
        self.w0 = nn.Conv1d(self.config.width, self.config.width, 1)
        self.w1 = nn.Conv1d(self.config.width, self.config.width, 1)
        self.w2 = nn.Conv1d(self.config.width, self.config.width, 1)
        self.w3 = nn.Conv1d(self.config.width, self.config.width, 1)

        # Projection head back to prob_dim outputs.
        self.fc1 = nn.Linear(self.config.width, self.config.fc_map_dim)
        self.fc2 = nn.Linear(self.config.fc_map_dim, self.config.fc_map_dim)
        self.fc3 = nn.Linear(self.config.fc_map_dim, self.config.fc_map_dim)
        self.fc4 = nn.Linear(self.config.fc_map_dim, self.config.prob_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight.data,mode='fan_out', nonlinearity='relu')

        # The previous spelling `MSELoss("sum")` passed "sum" as the deprecated
        # positional `size_average` argument, which PyTorch maps to mean
        # reduction. The reduction is spelled out here so the code says what it
        # has always done; the loss weights below are tuned for it.
        self.criterion = torch.nn.MSELoss(reduction="mean").to(self.config.device)

        # Scratch state handed from train_model to test_model.
        self.y_tmp = None
        self.epoch_tmp = None
        self.loss_record_tmp = None

        self.figure_save_path_folder = "{0}/figure/{1}_{2}/".format(self.config.args.main_path, self.config.model_name, self.time_string)
        os.makedirs(self.figure_save_path_folder, exist_ok=True)
        self.default_colors = ["red", "blue", "green", "pink", "cyan", "lime", "pink", "indigo", "brown", "grey"]

        print("using {}".format(str(self.config.device)))
        print("iteration = {}".format(self.config.args.iteration))
        print("epoch_step = {}".format(self.config.args.epoch_step))
        print("test_step = {}".format(self.config.args.test_step))
        print("model_name = {}".format(self.config.model_name))
        print("time_string = {}".format(self.time_string))
        self.truth_loss()

    def forward(self, x):
        """Map the input grid to the predicted trajectories.

        Args:
            x: tensor of shape ``(batch, T_N, prob_dim)``.

        Returns:
            Tensor of shape ``(batch, T_N, prob_dim)`` with values in (0, 1).
        """
        x = self.fc0(x)
        x = x.permute(0, 2, 1)  # (batch, width, T_N) for the convolutions

        blocks = (
            (self.conv0, self.w0),
            (self.conv1, self.w1),
            (self.conv2, self.w2),
            (self.conv3, self.w3),
        )
        last = len(blocks) - 1
        for idx, (spectral, pointwise) in enumerate(blocks):
            x = spectral(x) + pointwise(x)
            # No activation after the final Fourier block.
            if idx != last:
                x = F.gelu(x)

        x = x.permute(0, 2, 1)  # back to (batch, T_N, width)
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = F.gelu(self.fc3(x))
        return torch.sigmoid(self.fc4(x))

    @staticmethod
    def setup_seed(seed):
        """Seed torch (CPU and CUDA), NumPy and ``random``; request deterministic cuDNN."""
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True

    def set_flag(self, flag):
        """Switch between stage 1 (``False``) and stage 2 (``True``) of training."""
        self.step_2_flag = flag

    def _time_derivative(self, columns):
        """Finite-difference d/dt of each 1-D series in ``columns`` on the
        ``config.t_torch`` grid; returns them stacked as ``(T_N, len(columns))``."""
        coords = (self.config.t_torch,)
        return torch.stack([torch.gradient(col, spacing=coords)[0] for col in columns], dim=1)

    def ode_gradient(self, x, y, with_g=True):
        """Evaluate the ODE residual of a candidate trajectory.

        Args:
            x: unused; kept for call-site compatibility.
            y: tensor of shape ``(1, T_N, 10)``; only ``y[0]`` is read. Column
               order: Cln, ClbSt, MBF, Nrm1t, ClbMt, Polo, Sic1t, SBF, Cdh1, Cdc14.
            with_g: when ``False`` the time derivative of the residual is not
               computed and ``None`` is returned in its place.

        Returns:
            ``(f_y, d_y, g_y)``, each of shape ``(T_N, 10)``:
              * ``d_y`` - finite-difference time derivative of ``y``,
              * ``f_y`` - residual ``d_y - rhs(y)``,
              * ``g_y`` - finite-difference time derivative of ``f_y``
                (or ``None`` if ``with_g`` is ``False``).
        """
        p = self.config.params
        eps = 1e-20  # same denominator guard as Config.pend
        series = [y[0, :, i] for i in range(10)]
        Cln, ClbSt, MBF, Nrm1t, ClbMt, Polo, Sic1t, SBF, Cdh1, Cdc14 = series

        d_y = self._time_derivative(series)

        # Auxiliary algebraic quantities shared by several equations.
        MBFa = MBF * Cln / (torch.abs(p.Jmbf + Cln) + eps)
        Clbt = ClbSt + ClbMt + p.ndClbM
        BB = Sic1t + Clbt + p.Kdiss
        Sic1Clb = 2 * Sic1t * Clbt / (BB + torch.sqrt(torch.abs(BB ** 2 - 4 * Sic1t * Clbt) + eps))
        Clb = Clbt - Sic1Clb
        ClbM = (ClbMt + p.ndClbM) * (Clbt - Sic1Clb) / (torch.abs(Clbt) + eps)

        # Right-hand sides, term by term identical to Config.pend.
        rhs = torch.stack([
            p.kscln * SBF - p.kdcln * Cln,
            p.ksclbs * MBFa - (p.kdclbs_ + p.kdclbs * Cdh1) * ClbSt,
            p.kdiss_ * (p.MBFtot - MBF) - p.kass_ * MBF * (Nrm1t - (p.MBFtot - MBF)),
            p.ksnrm1 * MBFa - p.kdnrm1 * Cdh1 * Nrm1t,
            p.ksclbm_ + p.ksclbm * ClbM ** 2 / (torch.abs(p.Jclbm ** 2 + ClbM ** 2) + eps) - (p.kdclbm_ + p.kdclbm * Cdh1) * ClbMt,
            p.kspolo * ClbM - (p.kdpolo_ + p.kdpolo * Cdh1) * Polo,
            p.kssic_ - (p.kdsic_ + p.kdsic * Clb * (Cln + Clb) / (torch.abs(p.Jsic1 + Cln + Clb) + eps)) * Sic1t,
            (p.kasbf_ + p.kasbf * Cln) * (1 - SBF) / (torch.abs(p.Jsbf + 1 - SBF) + eps) - p.kisbf * ClbM * SBF / (p.Jsbf + SBF),
            (p.kacdh1_ + p.kacdh1 * Cdc14) * (1 - Cdh1) / (torch.abs(p.Jcdh1 + 1 - Cdh1) + eps) - (p.kicdh1_ * Cln + p.kicdh1 * Clb) * Cdh1 / (torch.abs(p.Jcdh1 + Cdh1) + eps),
            p.kacdc14 * Polo * (1 - Cdc14) / (torch.abs(p.Jcdc14 + 1 - Cdc14) + eps) - p.kicdc14 * Cdc14 / (torch.abs(p.Jcdc14 + Cdc14) + eps),
        ], dim=1)

        f_y = d_y - rhs
        g_y = self._time_derivative([f_y[:, i] for i in range(10)]) if with_g else None
        return f_y, d_y, g_y

    def loss(self, y):
        """Assemble the training objective for a candidate trajectory.

        Terms (``criterion`` is a mean-reduced MSE):
          * loss1 - initial condition: ``y[0, 0, :]`` against ``config.y0``;
          * loss2 - ODE residual ``f_y`` against zero, summed over variables;
          * loss3 - range penalty, zero exactly when ``0 <= y <= 1``;
          * loss4 - ``(0.1 / mean(d_y**2))**2`` per variable, which grows as the
            predicted trajectory becomes flat; used in stage 1 only.

        Args:
            y: tensor of shape ``(1, T_N, prob_dim)``.

        Returns:
            ``(total, parts)`` where ``parts`` is ``[loss1, loss2, loss3, loss4]``
            in stage 1 and ``[loss1, loss2, loss3]`` in stage 2, and ``total`` is
            their sum.
        """
        device = self.config.device
        n_vars = self.config.prob_dim

        y0_pred = y[0, 0, :]
        y0_true = torch.tensor(self.config.y0, dtype=torch.float32).to(device)

        # The time derivative of the residual (g_y) was previously computed and
        # turned into a fifth term that was never added to the objective; it is
        # skipped here, which leaves the returned values unchanged.
        ode_f, ode_d, _ = self.ode_gradient(self.config.x, y, with_g=False)
        zeros_1d = torch.zeros_like(ode_f[:, 0])
        zero = torch.tensor(0.00).to(device)

        lambda_1 = 1
        lambda_2 = 0.1
        lambda_3 = 0.1
        lambda_4 = 0.01

        loss1 = lambda_1 * sum([self.criterion(y0_pred[i], y0_true[i]) for i in range(n_vars)])
        loss2 = lambda_2 * sum([self.criterion(ode_f[:, i], zeros_1d) for i in range(n_vars)])
        loss3 = lambda_3 * (self.criterion(torch.abs(y - 0), y - 0) + self.criterion(torch.abs(1 - y), 1 - y))
        loss4 = lambda_4 * sum([self.criterion(0.1 / self.criterion(ode_d[:, i], zeros_1d), zero) for i in range(n_vars)])

        if not self.step_2_flag:
            loss = loss1 + loss2 + loss3 + loss4
            loss_list = [loss1, loss2, loss3, loss4]
        else:
            loss = loss1 + loss2 + loss3
            loss_list = [loss1, loss2, loss3]
        return loss, loss_list

    def truth_loss(self):
        """Print the loss of the odeint reference solution (a sanity baseline)."""
        # float32 to match the network output and the float32 time grid.
        y_truth = torch.tensor(
            self.config.truth.reshape([1, self.config.T_N, self.config.prob_dim]),
            dtype=torch.float32,
        ).to(self.config.device)
        tl, tl_list = self.loss(y_truth)
        loss_print_part = " ".join(["Loss_{0:d}:{1:.8f}".format(i + 1, loss_part.item()) for i, loss_part in enumerate(tl_list)])
        print("Ground truth has loss: Loss:{0:.8f} {1}".format(tl.item(), loss_print_part))

    def train_model(self):
        """Run ``config.args.iteration`` Adam steps on the full time grid.

        Every ``epoch_step`` steps a progress line is printed; when the step is
        also a multiple of ``test_step`` the prediction is plotted and the whole
        module is pickled to ``<main_path>/saves/BYCC_step1.pt`` (stage 1) or
        ``BYCC_step2.pt`` (stage 2).
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.args.initial_lr, weight_decay=0)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1 / (e / 10000 + 1))
        self.train()

        # Use the configured root rather than a notebook global, and make sure
        # the target directory exists before the first checkpoint is written.
        save_dir = os.path.join(self.config.args.main_path, "saves")
        os.makedirs(save_dir, exist_ok=True)

        start_time = time.time()
        start_time_0 = start_time
        loss_record = []

        for epoch in range(1, self.config.args.iteration + 1):
            optimizer.zero_grad()

            y = self.forward(self.config.x)
            loss, loss_list = self.loss(y)
            loss_record.append(loss.item())

            loss.backward()
            optimizer.step()
            scheduler.step()

            if epoch % self.config.args.epoch_step != 0:
                continue

            now_time = time.time()
            loss_print_part = " ".join(["Loss_{0:d}:{1:.6f}".format(i + 1, loss_part.item()) for i, loss_part in enumerate(loss_list)])
            print("Epoch [{0:05d}/{1:05d}] Loss:{2:.6f} {3} Lr:{4:.6f} Time:{5:.6f}s ({6:.2f}min in total, {7:.2f}min remains)".format(epoch, self.config.args.iteration, loss.item(), loss_print_part, optimizer.param_groups[0]["lr"], now_time - start_time, (now_time - start_time_0) / 60.0, (now_time - start_time_0) / 60.0 / epoch * (self.config.args.iteration - epoch)))
            start_time = now_time

            if epoch % self.config.args.test_step == 0:
                self.y_tmp = y
                self.epoch_tmp = epoch
                self.loss_record_tmp = loss_record
                self.test_model()
                checkpoint_name = "BYCC_step2.pt" if self.step_2_flag else "BYCC_step1.pt"
                # NOTE: this pickles the entire module (including config).
                torch.save(self, os.path.join(save_dir, checkpoint_name))

    def test_model(self):
        """Plot the latest prediction against the reference solution (one panel
        per variable), save the figure, then plot the loss history."""
        y_draw = self.y_tmp[0].cpu().detach().numpy().swapaxes(0, 1)
        x_draw = self.config.t
        y_draw_truth = self.config.truth.swapaxes(0, 1)
        save_path = "{}/{}_{}_epoch={}.png".format(self.figure_save_path_folder, self.config.model_name, self.time_string, self.epoch_tmp)
        m = MultiSubplotDraw(row=2, col=5, fig_size=(20, 6), tight_layout_flag=True, show_flag=True, save_flag=True,
                             save_path=save_path)
        for name, item, item_target, color in zip(self.config.curve_names, y_draw, y_draw_truth, self.default_colors[:self.config.prob_dim]):
            m.add_subplot(
                y_lists=[item, item_target],
                x_list=x_draw,
                color_list=[color, "black"],
                legend_list=["pred", "true"],
                line_style_list=["solid", "dashed"],
                fig_title=name,
            )
        m.draw()
        print("Figure is saved to {}".format(save_path))
        self.draw_loss_multi(self.loss_record_tmp, [1.0, 0.5, 0.25])

    @staticmethod
    def draw_loss_multi(loss_list, last_rate_list):
        """Plot the tail of the loss history once per fraction in ``last_rate_list``.

        Args:
            loss_list: per-step loss values, oldest first.
            last_rate_list: fractions of the history to show (e.g. ``1.0`` for
                everything, ``0.25`` for the most recent quarter).
        """
        m = MultiSubplotDraw(row=1, col=len(last_rate_list), fig_size=(8 * len(last_rate_list), 6), tight_layout_flag=True, show_flag=True, save_flag=False, save_path=None)
        total = len(loss_list)
        for one_rate in last_rate_list:
            # Clamp to [1, total] so the slice and the x-range always have the
            # same length (a count of 0 used to select the whole list while
            # producing an empty x-range).
            count = min(total, max(1, int(total * one_rate)))
            first = total - count
            m.add_subplot(
                y_lists=[loss_list[first:]],
                x_list=range(first + 1, total + 1),
                color_list=["blue"],
                line_style_list=["solid"],
                fig_title="Loss - lastest ${}$% - epoch ${}$ to ${}$".format(int(100 * one_rate), first + 1, total),
                fig_x_label="epoch",
                fig_y_label="loss")
        m.draw()




 


In [None]:
# Stage 1: build a fresh configuration and model, then train. A new model starts
# with step_2_flag == False, so train_model optimises loss1..loss4 and writes its
# checkpoints as BYCC_step1.pt.
config = Config()
model = FourierModel(config).to(config.device)
model.train_model()

In [None]:
# Stage 2: resume from the stage-1 checkpoint and continue training with the
# reduced objective (set_flag(True) drops loss4; see FourierModel.loss).
checkpoint_path = main_path + '/saves/BYCC_step1.pt'
# SECURITY NOTE: the checkpoint is a fully pickled nn.Module, and unpickling can
# execute arbitrary code. Only load checkpoints produced by this notebook.
try:
    # Recent PyTorch releases default to weights_only=True, which refuses to
    # unpickle a whole module, so the flag has to be disabled explicitly.
    model = torch.load(checkpoint_path, map_location=config.device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not know the weights_only keyword.
    model = torch.load(checkpoint_path, map_location=config.device)
# Take the device from `config` instead of from a `model` variable left over
# from a previous cell.
model = model.to(config.device)
model.set_flag(True)
model.truth_loss()
model.train_model()

In [None]:
# NOTE(review): this cell repeats the stage-1 training cell above verbatim. When
# executed it replaces `config` and `model` with fresh objects (discarding the
# stage-2 model held in `model`) and trains from scratch; because a new model has
# step_2_flag == False, its checkpoints are written as BYCC_step1.pt again.
# Confirm whether it is still needed.
config = Config()
model = FourierModel(config).to(config.device)
model.train_model()