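# CNN solver

Solves the steady diffusion equation $\nabla\cdot(K\nabla T) + s = 0$ on the unit square with $T = 0$ on the boundary. The finite-difference operators are written as 2-tap convolutions. The interior values of $T$ are found by minimising the mean squared residual with L-BFGS.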

In [240]:
import torch
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from torch import nn
import torch.nn.functional as F
from time import time
import contextlib

@contextlib.contextmanager
def timer(msg='timer'):
    """Print the wall-clock time spent inside a ``with`` block.

    Usage::

        with timer("solve"):
            ...

    Prints ``"<msg>: <seconds>"`` (two decimals) when the block exits. This
    also happens when the block exits via an exception such as a
    KeyboardInterrupt, so an interrupted run still reports how long it ran.

    Args:
        msg: label printed in front of the elapsed time.
    """
    tic = time()
    try:
        yield
    finally:
        # `finally` so the timing is reported even if the body raises.
        print(f"{msg}: {time() - tic:.2f}")

# Neural net solver class

In [326]:
class NeuralNetSolver(nn.Module):
    """Residual of the steady diffusion equation div(K grad T) + s = 0 as a module.

    The unknowns are the interior values of T, stored as the trainable
    parameter ``self.T``. ``pad_T`` re-attaches a fixed T = 0 boundary. The
    grid spacing is ``1/(n-1)`` in each direction, i.e. the domain is taken
    to be the unit square. Calling the module returns the residual at every
    interior node, so driving it to zero (e.g. by minimising its square)
    solves the equation.

    Derivatives are 2-tap convolutions (forward differences), and K is
    averaged onto the cell faces with a 2-tap mean before the flux
    divergence is taken.

    Args:
        T_ini: initial field of shape (1, 1, nx, ny). Only the interior
            ``[..., 1:-1, 1:-1]`` is used. It is copied, so optimiser updates
            never write back into the caller's tensor. The boundary values
            of ``T_ini`` are ignored (the boundary is always zero).
        K: coefficient field of shape (1, 1, nx, ny), same dtype/device as T_ini.
        s: source field of shape (1, 1, nx, ny), same dtype/device as T_ini.
    """

    def __init__(self, T_ini, K, s):
        super().__init__()
        self.nx = T_ini.shape[2]
        self.ny = T_ini.shape[3]
        # Own, detached copy of the interior: a plain slice would be a view into
        # T_ini (optimiser steps would mutate it) and could drag along its graph.
        self.T = nn.Parameter(T_ini[:, :, 1:-1, 1:-1].detach().clone())

        dx = 1. / (self.nx - 1)
        dy = 1. / (self.ny - 1)
        # Every constant tensor uses T_ini's dtype/device so conv2d/cat never mix types.
        opts = dict(dtype=T_ini.dtype, device=T_ini.device)
        constants = {
            "K": K,
            "s": s,
            # conv2d is a cross-correlation, so out[i] = (T[i+1] - T[i]) / h.
            # dim 2 is the x direction, dim 3 the y direction.
            "grad_x": torch.tensor([-1. / dx, 1. / dx], **opts).reshape((1, 1, 2, 1)),
            # Same sign convention as grad_x, so dT_dy is the true forward difference.
            "grad_y": torch.tensor([-1. / dy, 1. / dy], **opts).reshape((1, 1, 1, 2)),
            # Mean of two neighbouring nodes, i.e. K evaluated on the faces.
            "avg_x": torch.tensor([.5, .5], **opts).reshape((1, 1, 2, 1)),
            "avg_y": torch.tensor([.5, .5], **opts).reshape((1, 1, 1, 2)),
            # Zero boundary strips; left/right are added first (interior rows only),
            # then bottom/top span the full width.
            "BC_left": torch.zeros((1, 1, self.nx - 2, 1), **opts),
            "BC_right": torch.zeros((1, 1, self.nx - 2, 1), **opts),
            "BC_top": torch.zeros((1, 1, 1, self.ny), **opts),
            "BC_bot": torch.zeros((1, 1, 1, self.ny), **opts),
        }
        # Non-persistent buffers follow .to()/.double() with the module but are
        # kept out of state_dict, so saved checkpoints contain only T as before.
        for name, value in constants.items():
            self.register_buffer(name, value, persistent=False)

    def forward(self):
        """Return the residual div(K grad T) + s on the interior, shape (1, 1, nx-2, ny-2)."""
        T = self.pad_T()  # (1, 1, nx, ny) including the zero boundary

        # Face gradients: (nx-1, ny) in x and (nx, ny-1) in y.
        dT_dx = F.conv2d(T, self.grad_x)
        dT_dy = F.conv2d(T, self.grad_y)

        # K interpolated onto the same faces.
        K_avg_x = F.conv2d(self.K, self.avg_x)
        K_avg_y = F.conv2d(self.K, self.avg_y)

        # Differences of the face fluxes: (nx-2, ny) in x and (nx, ny-2) in y.
        K_d2T_dx2 = F.conv2d(K_avg_x * dT_dx, self.grad_x)
        K_d2T_dy2 = F.conv2d(K_avg_y * dT_dy, self.grad_y)

        # Crop both terms (and the source) to the (nx-2, ny-2) interior before summing.
        return K_d2T_dx2[:, :, :, 1:-1] + K_d2T_dy2[:, :, 1:-1, :] + self.s[:, :, 1:-1, 1:-1]

    def pad_T(self):
        """Return the interior unknowns with the zero boundary re-attached, shape (1, 1, nx, ny)."""
        T = torch.cat([self.BC_left, self.T, self.BC_right], dim=3)
        return torch.cat([self.BC_bot, T, self.BC_top], dim=2)

In [327]:
def loss_fn(y_hat):
    """Mean of the squared residual.

    The residual returned by the solver should be identically zero, so the
    loss needs no target tensor: it is simply mean(y_hat ** 2).
    """
    squared_residual = y_hat.pow(2)
    return squared_residual.mean()

# Training function

In [332]:
def train(net, loss_fn, optimizer, abs_loss_limit, rel_loss_limit, max_iter=500, log_every=20):
    """Minimise ``loss_fn(net())`` with a closure-based optimiser and return the solution.

    Each ``optimizer.step`` receives a closure that re-evaluates the loss and
    its gradient, as required by ``torch.optim.LBFGS``.

    Every ``log_every`` steps the loss is printed and two stop criteria are
    checked:

    * absolute: the logged loss is below ``abs_loss_limit``;
    * "relative": the logged loss changed by less than ``rel_loss_limit``
      since the previous check. This is an absolute difference of losses,
      not a ratio, and it is skipped at the first check, where there is no
      previous value (printed as ``inf``).

    Args:
        net: module whose ``forward()`` takes no arguments and returns the
            residual, and which provides ``pad_T()``.
        loss_fn: maps the residual to a scalar loss tensor.
        optimizer: optimiser over ``net.parameters()``.
        abs_loss_limit: stop when the logged loss drops below this value.
        rel_loss_limit: stop when the logged loss changes by less than this.
        max_iter: maximum number of optimiser steps (default 500).
        log_every: logging / convergence-check interval in steps (default 20).

    Returns:
        ``net.pad_T()`` detached from the autograd graph.
    """

    def closure():
        optimizer.zero_grad()
        loss = loss_fn(net())
        loss.backward()
        return loss

    last_loss = None  # loss at the previous checkpoint; None before the first one
    with timer("solve"):
        print("epoch  |  absolute loss  |  relative loss")
        print("---------------------------------------------")
        for i in range(max_iter):
            step_loss = optimizer.step(closure)
            if i % log_every != 0:
                continue

            loss = step_loss.item()
            delta = float("inf") if last_loss is None else abs(loss - last_loss)
            print(f" {i:04d}       {loss:.2e}          {delta:.2e}")

            if loss < abs_loss_limit:
                print(f"Stop! absolute loss target reached ({abs_loss_limit:.2e})")
                break
            if delta < rel_loss_limit:
                print(f"Stop! relative loss target reached ({rel_loss_limit:.2e})")
                break
            last_loss = loss
    # The solution is a result, not part of any further graph.
    return net.pad_T().detach()

# Simple solver

In [343]:
# Single-grid solve on the full 200 x 200 grid, starting from T = 0.
nx = 200
ny = 200
T_ini = torch.zeros((1, 1, nx, ny))
K     = torch.ones((1, 1, nx, ny))
s     = torch.ones((1, 1, nx, ny))
net = NeuralNetSolver(T_ini, K, s)

optimizer = torch.optim.LBFGS(net.parameters())
abs_loss = 1e-10
rel_loss = 1e-100  # effectively disables the loss-change stop criterion

T_padded = train(net, loss_fn, optimizer, abs_loss, rel_loss)

# The epoch counter and loss live inside `train`; recompute the final residual here.
with torch.no_grad():
    final_loss = loss_fn(net()).item()

fig, ax = plt.subplots(figsize=[8, 8])
ax.set_title(f"single-grid solve, loss = {final_loss:.2e}")
im = ax.imshow(T_padded[0, 0].detach().cpu().numpy())
fig.colorbar(im, ax=ax)
plt.show()

epoch  |  absolute loss  |  relative loss
---------------------------------------------
 0000       1.00e+00          1.00e+00
 0020       8.78e-01          1.22e-01
 0040       8.20e-01          5.79e-02
 0060       7.79e-01          4.13e-02
 0080       7.34e-01          4.46e-02
 0100       7.16e-01          1.83e-02
 0120       6.73e-01          4.30e-02
 0140       6.29e-01          4.38e-02
 0160       6.07e-01          2.18e-02
 0180       5.97e-01          1.07e-02
 0200       5.86e-01          1.07e-02
 0220       5.80e-01          6.58e-03
 0240       5.66e-01          1.40e-02
 0260       5.27e-01          3.82e-02
 0280       5.15e-01          1.28e-02
 0300       4.45e-01          6.99e-02
 0320       2.05e-01          2.39e-01


KeyboardInterrupt: 

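# Multigrid

Solve on the coarsest grid first, then bilinearly interpolate each solution onto the next finer grid (twice the resolution) and use it as that grid's initial guess. This is a single coarse-to-fine pass, not a full multigrid V-cycle. $K$ and $s$ are defined on the finest grid and block-averaged down to each level.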

In [346]:
nx, ny = 200, 200
n_level = 4  # each level halves the resolution; level 0 is the coarsest grid
reduction_fac = [2**(n_level-1-level) for level in range(n_level)]
print(reduction_fac)
# Material properties are defined on the finest grid and averaged down to each level
K_ref     = torch.ones((1, 1, nx, ny))
s_ref     = torch.ones((1, 1, nx, ny))
print(nx, ny)

# Stop criteria shared by all levels
abs_loss = 1e-6
rel_loss = 1e-7

for level, red_fac in enumerate(reduction_fac):
    # Block-average K and s: each red_fac x red_fac patch becomes one cell.
    K = F.avg_pool2d(K_ref, red_fac, stride=red_fac)
    s = F.avg_pool2d(s_ref, red_fac, stride=red_fac)

    if level == 0:
        T_ini = torch.zeros_like(K)
    else:
        # Interpolate the coarser solution onto this level's grid. Passing the target
        # `size` (instead of scale_factor=2) keeps T_ini the same shape as K and s
        # even when nx or ny is not divisible by the reduction factor.
        T_ini = F.interpolate(T_previous.detach(), size=K.shape[2:],
                              mode='bilinear', align_corners=True)
    print(T_ini.shape)

    net = NeuralNetSolver(T_ini, K, s)
    optimizer = torch.optim.LBFGS(net.parameters())
    T_previous = train(net, loss_fn, optimizer, abs_loss, rel_loss)

[8, 4, 2, 1]
200 200
torch.Size([1, 1, 25, 25])
epoch  |  absolute loss  |  relative loss
---------------------------------------------
 0000       1.00e+00          1.00e+00
 0020       7.64e-10          1.00e+00
Stop! absolute loss target reached (1.00e-06)
solve: 0.41
torch.Size([1, 1, 50, 50])
epoch  |  absolute loss  |  relative loss
---------------------------------------------
 0000       3.25e-01          3.25e-01
 0020       9.65e-06          3.25e-01
 0040       9.65e-06          4.88e-09
Stop! relative loss target reached (1.00e-07)
solve: 1.37
torch.Size([1, 1, 100, 100])
epoch  |  absolute loss  |  relative loss
---------------------------------------------
 0000       3.28e-01          3.28e-01
 0020       1.74e-05          3.28e-01
 0040       1.58e-05          1.59e-06
 0060       1.57e-05          1.14e-07
 0080       1.56e-05          4.33e-08
Stop! relative loss target reached (1.00e-07)
solve: 5.16
torch.Size([1, 1, 200, 200])
epoch  |  absolute loss  |  relative lo