In [1]:
import sys
import os
from pathlib import Path

# Put the project root (parent of the notebook's working directory) first on the import
# path so the `lib` package can be imported below. This persists for the whole kernel session.
sys.path.insert(0, str(Path.cwd().resolve().parent))

In [2]:
from __future__ import annotations

import warnings
from itertools import chain
from collections import OrderedDict

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from tqdm import tqdm

from lib.Simulation import Simulation
from lib.Simulation_gpu import Simulation as Simulation_GPU
from lib.Loader import Loader
from lib.nn.helper import SaveLoad, BatchProcessing
from lib.nn.nets import FCN, MShuffle
from lib.nn.netsdd import LinearDD, SoftplusDD, FCNDD, MShuffleDD
# NOTE: silences every Python warning for the rest of the session (deprecations included).
warnings.filterwarnings('ignore')

# First CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Common filename prefix for the artifacts this notebook writes (figures, checkpoint, video).
path = "main"

KeyboardInterrupt: 

In [None]:
# Seed NumPy's global RNG before the simulation is built and run.
np.random.seed(20)

# Grid spacing (d), number of points (N) and the `s` argument of the simulation, per (t, x, y).
dt, dx, dy = 0.03, 90 / 600, 90 / 600
Nt, Nx, Ny = 200, 600, 600
St, Sx, Sy = 100, 1, 1

# Extent of the domain along each axis.
Lt = Nt * dt
Lx = Nx * dx
Ly = Ny * dy

# Size of the coarse myu grid handed to the simulation.
myu_size = (5, 8, 8)

S = Simulation_GPU(
    d=(dt, dx, dy),
    N=(Nt, Nx, Ny),
    s=(St, Sx, Sy),
    myu_size=myu_size,
    myu_mstd=(5.4, 0.8),
)
A, myu = S.compute()
S.check_properties(A, myu)
u = A

In [None]:
import torch
from lib.Loader import Loader

class DatasetLoader(Loader):
    """Point-wise dataset over a space-time grid, held as tensors on `device`.

    Every grid point is one sample. Its input is the (x, y, t) coordinate and its
    target is the (real, imag) pair of the complex field A at that point. An
    optional per-sample myu tensor can be attached with `set_myu`; once set,
    indexing returns it as a third item.

    Uses the module-level `device`. Loader.__init__ is not called here.
    """

    def __init__(self, X, Y, T, A):
        # Flatten each coordinate grid into a 1-D float32 tensor on the device.
        self.X = torch.tensor(X, dtype=torch.float32).view(-1).to(device)
        self.Y = torch.tensor(Y, dtype=torch.float32).view(-1).to(device)
        self.T = torch.tensor(T, dtype=torch.float32).view(-1).to(device)
        # (N, 3) network inputs; column order is (x, y, t).
        self.XYT = torch.stack((self.X, self.Y, self.T), dim=1)

        # (N, 2) targets: real and imaginary parts of A as two channels
        # (already on the device, so no further .to() is needed).
        A_real = torch.tensor(A.real, dtype=torch.float32).view(-1).to(device)
        A_imag = torch.tensor(A.imag, dtype=torch.float32).view(-1).to(device)
        self.A = torch.stack((A_real, A_imag), dim=1)
        if self.A.shape[0] != self.XYT.shape[0]:
            # Otherwise inputs and targets would silently drift out of alignment.
            raise ValueError(
                f"A has {self.A.shape[0]} points but the coordinate grids have {self.XYT.shape[0]}"
            )
        self._setmyu = False

    def set_myu(self, myu):
        """Attach per-sample myu values; their first dimension must equal len(self)."""
        if len(myu) != len(self):
            raise ValueError(f"myu has {len(myu)} rows, expected {len(self)}")
        self._setmyu = True
        self.myu = myu

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        """Return (xyt, A), or (xyt, A, myu) after set_myu, for an index/slice/index tensor."""
        if self._setmyu:
            return self.XYT[idx], self.A[idx], self.myu[idx]
        return self.XYT[idx], self.A[idx]

In [None]:
class DatasetLoaderDD(DatasetLoader):
    """DatasetLoader whose input item is the triple (xyt, dxyt, ddxyt).

    * dxyt:  (3, batch, 3). Row k is a one-hot direction over the (x, y, t) input
      columns, ordered d/dt, d/dx, d/dy. These are presumably the seed tangents for the
      *DD networks; confirm against lib.nn.netsdd.
    * ddxyt: (3, batch, 3) zeros.
    """

    # One-hot directions over the (x, y, t) columns: t first, then x, then y.
    _DIRECTIONS = ((0, 0, 1), (1, 0, 0), (0, 1, 0))

    def __getitem__(self, idx):
        xyt, *rest = super().__getitem__(idx)
        # Size the seeds from the batch actually returned. self.batch_size is wrong
        # whenever the final batch is shorter than the configured size.
        n = xyt.shape[0]
        dxyt = (
            torch.tensor(self._DIRECTIONS, dtype=torch.float32, device=xyt.device)
            .unsqueeze(1)
            .repeat(1, n, 1)
        )
        ddxyt = torch.zeros(3, n, 3, dtype=torch.float32, device=xyt.device)
        return ((xyt, dxyt, ddxyt), *rest)

In [None]:
# 1-D axes as column vectors: x and y span [-L/2, L/2], t spans [0, Lt].
x = np.linspace(-Lx, Lx, Nx)[:, None] / 2
y = np.linspace(-Ly, Ly, Ny)[:, None] / 2
t = np.linspace(0, Lt, Nt)[:, None]

# np.meshgrid flattens its inputs. With the default indexing='xy' the outputs have shape
# (len(t), len(x), len(y)): T varies along axis 0, X along axis 1, Y along axis 2.
X, T, Y = np.meshgrid(x, t, y)
dldd = DatasetLoaderDD(X, Y, T, A)

In [None]:
class PCNNDD(nn.Module, SaveLoad, BatchProcessing):
    """Encoder -> shuffler -> decoder network built from the *DD layers of lib.nn.netsdd.

    layers_list=[3, 512] / [512, 2] presumably means 3 inputs -> 512 latent -> 2 outputs.
    """

    def __init__(self, device = 'cpu'):
        super().__init__()
        # Sub-modules are registered in this order (it fixes state_dict key order).
        self.encoder = FCNDD(layers_list=[3, 512]).to(device)
        self.decoder = FCNDD(layers_list=[512, 2]).to(device)
        self.shuffler = MShuffleDD(exp_size=9, n_depth=4).to(device)

        # Adam over the `_Wtmx` sub-modules only, in encoder, shuffler, decoder order.
        weight_params = [
            *self.encoder._Wtmx.parameters(),
            *self.shuffler._Wtmx.parameters(),
            *self.decoder._Wtmx.parameters(),
        ]
        self.optimizer = torch.optim.Adam(params=weight_params, lr=0.01)

    def forward(self, x):
        encoded = self.encoder.forward(x)
        mixed = self.shuffler.forward(encoded)
        return self.decoder.forward(mixed)

In [None]:
class PINNDD(PCNNDD):
    """PCNNDD with an additional trainable myu field.

    `myuparam` is an nn.Parameter of shape (Nt, mNx, mNy): full time resolution and
    the coarse spatial resolution from `myu_size`. The `myu` property upsamples it by
    nearest-neighbour interpolation to (Nt, Nx, Ny). Nt, Nx, Ny and myu_size are read
    from module-level globals.
    """

    def __init__(self, *args, device = 'cpu', **kwargs):
        super(PINNDD, self).__init__(*args, device = device, **kwargs)
        self.device = device
        # Only the spatial entries of myu_size are used; time keeps the full Nt resolution.
        _, mNx, mNy = myu_size
        if Nx % mNx or Ny % mNy:
            # Integer-factor nearest upsampling could not reproduce the (Nx, Ny) grid exactly.
            raise ValueError(
                f"grid ({Nx}, {Ny}) is not an integer multiple of the myu grid ({mNx}, {mNy})"
            )
        self.mshape = (Nt, mNx, mNy)
        # Per-axis upsampling factors; `scale` (the x factor) is kept for existing callers.
        self.scale = Nx // mNx
        self.scale_xy = (Nx // mNx, Ny // mNy)
        self.myureset()

    def myureset(self):
        """Re-initialise myuparam with |N(0, 1)| samples of shape self.mshape."""
        # Drawn on the CPU, then moved, so a given torch seed yields the same values on any device.
        myu = torch.abs(torch.randn(*self.mshape, dtype=torch.float32).to(self.device))
        self.myuparam = nn.Parameter(myu)

    @property
    def myu(self):
        """myuparam upsampled to (Nt, Nx, Ny) by nearest-neighbour interpolation."""
        # interpolate sees (1, Nt, mNx, mNy) as (batch, channels, H, W), so only the two
        # spatial axes are scaled, each by its own factor.
        upsampled = F.interpolate(
            self.myuparam.unsqueeze(0), scale_factor=self.scale_xy, mode='nearest'
        )
        # Drop only the added batch axis; a bare squeeze() would also drop real size-1 axes.
        return upsampled.squeeze(0)

    def save_myu(self,filename):
        """Save the coarse myuparam (not the upsampled myu) with np.save."""
        np.save(filename, self.myuparam.detach().cpu().numpy())

    def load_myu(self,filename):
        """Load a coarse myu written by save_myu into a fresh myuparam on self.device."""
        myu = np.load(filename)
        if tuple(myu.shape) != tuple(self.mshape):
            raise ValueError(
                f"myu in {filename!r} has shape {myu.shape}, expected {self.mshape}"
            )
        myuparam = torch.tensor(myu, dtype=torch.float32).to(self.device)
        self.myuparam = nn.Parameter(myuparam)



## Training the model


In [None]:
def fmse_mse_batch_train(self, dataloader, lr=0.01, verbose=1, device="cpu"):
    """Fit `self` to the data with MSE, logging PDE residuals for monitoring.

    For each batch (inputs, outputs, myu) yielded by `dataloader`, one Adam step is
    taken on MSE(u, outputs). The residual

        f = u_t - u_xx - u_yy + |u|^2 u - myu * u

    is evaluated with myu column 0 (network myu) and column 1 (reference myu). It is
    only reported; it does not contribute to the gradient.

    A fresh Adam optimizer over ``self.parameters()`` is created on every call.
    `verbose` and `device` are accepted for compatibility and are unused.

    Returns:
        list of (mse, fmse_network_myu, fmse_reference_myu) floats, one per batch.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(self.parameters(), lr=lr)

    all_losses = []
    for inputs, outputs, myu in dataloader:
        optimizer.zero_grad()

        net_myu, real_myu = myu.T
        net_myu, real_myu = net_myu.view(-1, 1), real_myu.view(-1, 1)

        u, (u_t, u_x, u_y), (u_tt, u_xx, u_yy) = self(inputs)
        loss = criterion(u, outputs)

        # Monitoring only: no_grad keeps these residuals from building an autograd graph
        # that would otherwise stay alive until the next iteration.
        with torch.no_grad():
            pref = u_t - u_xx - u_yy + torch.pow(torch.abs(u), 2).sum(dim=1, keepdim=True) * u
            netfloss = torch.mean((pref - u * net_myu) ** 2)
            realfloss = torch.mean((pref - u * real_myu) ** 2)

        loss.backward()
        optimizer.step()

        all_losses.append((loss.item(), netfloss.item(), realfloss.item()))

    return all_losses

In [None]:
# Seed torch before building the net so the randomly initialised myu (and presumably the weights) are reproducible.
torch.manual_seed(1)
net = PINNDD(device=device)
net.device = device  # redundant: PINNDD.__init__ already stores it
dldd.set(
    epochs=20000,
    batch_size=4000,
    shuffle=True,
    verbose=2,
    device=device,
)

In [None]:
# Flatten both myu fields to one value per grid point and stack them:
# row 0 = the network's current (upsampled) myu, row 1 = the simulated reference myu.
M = torch.stack(
    (
        net.myu.view(-1),
        torch.tensor(myu).view(-1).to(device),
    ),
    dim=0,
)

In [None]:
# One (network myu, reference myu) row per sample; each batch then yields it as its 3rd item.
dldd.set_myu(M.t())

In [None]:
# Stage 1: fit the network to the simulated field (MSE) while logging the physics residuals.
lr = 1e-3
L = fmse_mse_batch_train(net, dataloader=dldd, lr=lr, verbose=1, device=device)


In [None]:
# One curve per logged loss (data MSE, residual with network myu, residual with reference myu).
fig, ax = plt.subplots()
ax.plot(L, label=("NMSE", "FMSE(Network)", "FMSE(Real)"))
ax.legend()
ax.set_yscale('log')
ax.set_xlabel('epochs')
ax.set_ylabel('Custom Loss')
ax.set_title(f'Training of the PCNN \n lr={lr}')
fig.savefig(f'{path}_{lr}.png')

## Recover the myu field from the trained network

In [None]:
import numpy as np

def loader_predict(self, dataloader, complex = True):
    """Run `self` over every batch of `dataloader` and collect the results as NumPy arrays.

    Args:
        dataloader: iterable of (inputs, outputs, myu) batches. `outputs` is ignored;
            myu columns are (network myu, reference myu).
        complex: if True, fold the two channels (columns 0 and 1) of u and of every
            derivative into complex arrays ``ch0 + 1j * ch1``.

    Returns:
        (u, (u_t, u_x, u_y), (u_tt, u_xx, u_yy), (net_myu, real_myu)), concatenated over
        batches. The myu arrays have shape (N, 1).

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    field_keys = ('u', 'u_t', 'u_x', 'u_y', 'u_tt', 'u_xx', 'u_yy')
    collected = {key: [] for key in (*field_keys, 'net_myu', 'real_myu')}

    for inputs, _outputs, myu in dataloader:
        net_myu, real_myu = myu.T
        u, (u_t, u_x, u_y), (u_tt, u_xx, u_yy) = self(inputs)

        batch = dict(zip(field_keys, (u, u_t, u_x, u_y, u_tt, u_xx, u_yy)))
        batch['net_myu'] = net_myu.view(-1, 1)
        batch['real_myu'] = real_myu.view(-1, 1)
        for key, value in batch.items():
            collected[key].append(value.detach().cpu().numpy())

    if not collected['u']:
        raise ValueError("dataloader yielded no batches")

    arrays = {key: np.concatenate(parts, axis=0) for key, parts in collected.items()}

    if complex:
        # The myu columns are real and are left untouched.
        for key in field_keys:
            arrays[key] = arrays[key][:, 0] + arrays[key][:, 1] * 1j

    return (
        arrays['u'],
        (arrays['u_t'], arrays['u_x'], arrays['u_y']),
        (arrays['u_tt'], arrays['u_xx'], arrays['u_yy']),
        (arrays['net_myu'], arrays['real_myu']),
    )


In [None]:
# Ordered (unshuffled) pass for prediction. epochs = len(dldd) // batch_size, so this presumably
# visits each sample once, minus any remainder of len(dldd) % batch_size samples.
batch_size = 10000
dldd.set(
    epochs=len(dldd) // batch_size,
    batch_size=batch_size,
    shuffle=False,
    verbose=1,
)


In [None]:
# Network field and derivatives at every visited grid point (complex form), plus the paired myu columns.
u, (u_t, u_x, u_y), (u_tt, u_xx, u_yy), (net_myu, real_myu) = loader_predict(net, dldd, complex=True)


In [None]:
# Network Laplacian assembled from its second derivatives.
u_laplase = u_xx + u_yy
# Reference Laplacian of the simulated field, computed spectrally: FFT over the last two axes,
# multiply by S.q (presumably the Laplacian's Fourier symbol), then inverse FFT.
# S.q is fetched with .get(), which suggests it lives on the GPU (CuPy) — confirm.
A_laplase = np.fft.ifft2(np.fft.fft2(A) * S.q.get())

In [None]:
# Relative L2 error of the network Laplacian against the spectral reference. Normalising by the
# reference, and by a norm rather than point by point, avoids the unbounded ratios the pointwise
# version produced wherever its denominator (the prediction |u_laplase|) was close to zero.
np.linalg.norm(u_laplase - A_laplase.reshape(-1)) / np.linalg.norm(A_laplase.reshape(-1))

In [None]:
# Fit only the coarse myu field (net.myuparam) so that the PDE residual
#   f = u_t - u_xx - u_yy + |u|^2 u - myu * u
# evaluated on the network's fixed, already-predicted outputs is minimised.
# Adam runs 40 steps at each of several decreasing learning rates, and the
# loss curves are drawn end-to-end on one log-scale plot.
step_offset = 0
pref = u_t - u_xx - u_yy + (np.abs(u)**2) * u
pref = torch.tensor(pref)
U = torch.tensor(u)

for myu_lr in [10, 1, 0.1, 0.01, 0.001]:

    myuoptimizer = torch.optim.Adam(params=[net.myuparam], lr=myu_lr)
    myuoptimizer.zero_grad()

    FL = []
    for _ in tqdm(range(40)):
        # Rebuild the upsampled myu from the *current* myuparam on every step. Computing it
        # once before the loop freezes M: myuparam would be updated, but the loss (and its
        # gradient) would keep being evaluated on the stale copy and never change.
        M = net.myu.view(-1).to("cpu")
        f = pref - U * M
        FMSE = (torch.abs(f)**2).mean()
        FMSE.backward()  # fresh graph every step, so retain_graph is not needed
        FL.append(FMSE.item())
        myuoptimizer.step()
        myuoptimizer.zero_grad()

    # Each segment starts at the x position where the previous segment ended.
    plt.plot(range(step_offset, step_offset + len(FL)), FL, label=f'lr={myu_lr}')
    step_offset += len(FL) - 1

plt.yscale('log')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('FMSE')
plt.title('MYU Training')
plt.tight_layout()
plt.savefig(f'{path}_myutraining.png')
plt.show()

In [None]:
# Persist the trained network, but only if no checkpoint file exists yet.
# NOTE(review): if an older checkpoint is left over from a previous run, saving is skipped and
# the following load cell restores those older weights over the ones trained above — confirm intended.
checkpoint_path = path + ".pt"
if not os.path.exists(checkpoint_path):
    net.save_model(checkpoint_path)


### 3rd Stage

In [None]:
# Restore the network from the checkpoint whenever one is on disk.
checkpoint_path = path + ".pt"
if os.path.exists(checkpoint_path):
    net.load_model(checkpoint_path)


In [None]:
def fmse_mse_batch_trainv2(self, dataloader, lr=0.01, verbose=1, device="cpu"):
    """Evaluate data MSE and PDE residuals over `dataloader`; no parameters are updated.

    FIXME(review): this was written as a training loop, but its ``loss.backward()`` was
    commented out, so ``optimizer.step()`` never received a gradient and never changed any
    parameter. The dead optimizer has been removed to make that explicit. If stage 3 is
    meant to fine-tune, add a backward pass on the intended objective (the MSE and/or the
    network-myu residual) and an optimizer step.

    For each batch (inputs, outputs, myu) the residual
        f = u_t - u_xx - u_yy + |u|^2 u - myu * u
    is evaluated with myu column 0 (network myu) and column 1 (reference myu).

    `lr`, `verbose` and `device` are kept for signature compatibility and are unused.

    Returns:
        list of (mse, fmse_network_myu, fmse_reference_myu) floats, one per batch.
    """
    criterion = nn.MSELoss()

    all_losses = []
    for inputs, outputs, myu in dataloader:
        net_myu, real_myu = myu.T
        net_myu, real_myu = net_myu.view(-1, 1), real_myu.view(-1, 1)

        u, (u_t, u_x, u_y), (u_tt, u_xx, u_yy) = self(inputs)
        loss = criterion(u, outputs)

        pref = u_t - u_xx - u_yy + torch.pow(torch.abs(u), 2).sum(dim=1, keepdim=True) * u
        netfloss = torch.mean((pref - u * net_myu) ** 2)
        realfloss = torch.mean((pref - u * real_myu) ** 2)

        all_losses.append((loss.item(), netfloss.item(), realfloss.item()))

    return all_losses

In [None]:
# Stage 3: run the v2 loop over the loader (epochs=200, batches of 4000 samples).
lr = 2e-5
dldd.set(
    epochs=200,
    batch_size=4000,
    shuffle=True,
    verbose=2,
    device=device,
)
L = fmse_mse_batch_trainv2(net, dataloader=dldd, lr=lr, verbose=1, device=device)


In [None]:
# One curve per logged loss (data MSE, residual with network myu, residual with reference myu).
fig, ax = plt.subplots()
ax.plot(L, label=("NMSE", "FMSE(Network)", "FMSE(Real)"))
ax.legend()
ax.set_yscale('log')
ax.set_xlabel('epochs')
ax.set_ylabel('Custom Loss')
ax.set_title(f'Training of the PCNN \n lr={lr}')
fig.savefig(f'{path}_L_{lr}.png')

## Visualizing the results and saving them as an MP4 video

In [None]:
class PCNN(nn.Module, SaveLoad, BatchProcessing):
    """Plain (non-DD) encoder -> shuffler -> decoder network built from lib.nn.nets layers.

    Mirrors PCNNDD's layout so the same state_dict keys can be loaded into it.
    """

    def __init__(self, device = 'cpu'):
        super().__init__()
        # Sub-modules are registered in this order (it fixes state_dict key order).
        self.encoder = FCN(layers_list=[3, 512]).to(device)
        self.decoder = FCN(layers_list=[512, 2]).to(device)
        self.shuffler = MShuffle(exp_size=9, n_depth=4).to(device)

        # Adam over the `_Wtmx` sub-modules only, in encoder, shuffler, decoder order.
        weight_params = [
            *self.encoder._Wtmx.parameters(),
            *self.shuffler._Wtmx.parameters(),
            *self.decoder._Wtmx.parameters(),
        ]
        self.optimizer = torch.optim.Adam(params=weight_params, lr=0.01)

    def forward(self, x):
        encoded = self.encoder.forward(x)
        mixed = self.shuffler.forward(encoded)
        return self.decoder.forward(mixed)

In [None]:
# Copy the trained weights into a plain (non-DD) PCNN for prediction. myuparam only
# exists on PINNDD, so it is dropped from the state dict before loading.
lightnet = PCNN().to(device)
sd = net.state_dict()
sd.pop('myuparam')
lightnet.load_state_dict(sd)

In [None]:
def module_sq(a):
    """Squared modulus |a|^2 = Re(a)^2 + Im(a)^2, elementwise."""
    return np.real(a)**2 + np.imag(a)**2


def phase(a):
    """Phase (argument) of a in (-pi, pi], elementwise; 0 where a == 0.

    Unlike arcsin(Re(a) / |a|), which equals pi/2 - |arg a|, this keeps the
    quadrant and is defined at the origin instead of returning NaN.
    """
    return np.angle(a)


def real_imag(a):
    """Elementwise product Re(a) * Im(a)."""
    return np.real(a) * np.imag(a)


def real(a):
    """Real part, elementwise."""
    return np.real(a)


def imag(a):
    """Imaginary part, elementwise."""
    return np.imag(a)


# Feature maps rendered per frame; the names are used as the video panel titles.
funlist_name = ["module_sq", "phase","real_imag","real","imag"]
funlist = [module_sq, phase, real_imag, real, imag]

# Network estimates on the full grid: the upsampled myu field, and the predicted field
# reshaped to the simulated A's shape.
myupred = net.myu.detach().cpu().numpy()
A_pred = lightnet.batch_predict(dldd.XYT).reshape(A.shape)

In [None]:
# All feature functions are elementwise, so apply them to whole arrays at once rather than
# frame by frame. Assuming A is (Nt, Nx, Ny), each stack below is (Nt, n_funcs, Nx, Ny).
true_feats = np.stack([fun(A) for fun in funlist], axis=1)
pred_feats = np.stack([fun(A_pred) for fun in funlist], axis=1)
# (Nt, 2, n_funcs, Nx, Ny): index 0 on axis 1 is the simulation, index 1 the prediction.
ATenzor = np.stack((true_feats, pred_feats), axis=1)
# (Nt, 2, Nx, Ny): reference myu and predicted myu per frame.
MTenzor = np.stack((myu, myupred), axis=1)
# Prepend myu as an extra feature panel: (Nt, 2, 1 + n_funcs, Nx, Ny).
AMTenzor = np.concatenate((MTenzor[:, :, np.newaxis, :, :], ATenzor), axis=2)

In [None]:
from lib.Video import create_video
# Presumably one title row per AMTenzor row (reference, prediction): 'myu' then each feature name.
video_titles = [['myu'] + funlist_name] * 2
create_video(AMTenzor, titles=video_titles, videotitle=f'{path}_after_FMSE.mp4')