In [1]:
import torch
from torch import Tensor                  
import torch.nn as nn                     

import matplotlib.pyplot as plt

import numpy as np
import time


torch.set_default_tensor_type(torch.DoubleTensor)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from tqdm import tqdm
print(device)
if device == 'cuda':
    print(torch.cuda.get_device_name()) 

cuda


In [2]:
# for complex roots
class NNCR(nn.Module):
    
    def __init__(self,layers):
        super().__init__() 
              
        self.activation = nn.Tanh()
        self.loss_function = nn.MSELoss(reduction ='mean')
        self.linears = nn.ModuleList([nn.Linear(layers[i], layers[i+1]) for i in range(len(layers)-1)]).to(device)
        
        for i in range(len(layers)-1):
            nn.init.xavier_normal_(self.linears[i].weight.data, gain=1.0)
            nn.init.zeros_(self.linears[i].bias.data)
            
    def forward(self,x):
            if torch.is_tensor(x) !=True:
                x= torch.from_numpy(x).to(device)
            sigma = x.type(torch.DoubleTensor).to(device)
            for i in range(len(layers)-2):
                z = self.linears[i](sigma)
                sigma = self.activation(z)
            sigma = self.linears[-1](sigma)
            return sigma

    
    #Modify the loss function as per the problem
    def loss_func(self, x_train1):                     
        g = x_train1.clone()
                        
        g.requires_grad = True
        
        u = self.forward(g)
        real,imag = com_eqn(u,g)        
        return (real**2).mean()+(imag**2).mean()
        
                                           
    def closure(self,steps,eps=1e-8,lr=1e-1,show=True):
            start = time.time()
            optimizer = torch.optim.Adam(self.parameters(),lr=lr)
            for i in tqdm(range(steps)):
                loss = self.loss_func(x_train1)
                self.mse = loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                #Learning rate scheduling. It performs better using this even for Adam.
                if i%(steps/4)==0:
                    lr=lr/5
                    optimizer = torch.optim.Adam(self.parameters(),lr=lr)
                    if show==True:
                        with torch.no_grad():
                            print('Iter: ',i,'Loss: ',loss.detach().cpu().numpy(),' lr: ',lr)
                if self.mse<=eps:
                    print('Converged !')
                    break
            print('MSE Loss: ',float(loss.detach().cpu().numpy()))
            print('total time: ',time.time()-start)

In [3]:
#Quadratic Equations
def com_eqn(u,x_train1):
    x = u[:,[0]]
    y = u[:,[1]]
    b = x_train1[:,[0]]
    c = x_train1[:,[1]]
    return (x**2-y**2+b*x+c),(2*x*y+b*y)

In [4]:
# Make input data
b = np.linspace(-5,5,500)
c = np.linspace(-5,5,500)
B,C = np.meshgrid(b,c)
x_train1 = torch.from_numpy(np.hstack(( B.flatten()[:,None],C.flatten()[:,None]))).to(device)
x_train1.shape

torch.Size([250000, 2])

In [5]:
steps=10000
layers = np.array([2,30,30,30,30,2])
Root_NN2 = NNCR(layers)
Root_NN2.to(device)
Root_NN2.closure(steps=steps,show=True)

  0%|                                         | 7/10000 [00:00<06:28, 25.70it/s]

Iter:  0 Loss:  12.375783004329111  lr:  0.02


 25%|█████████▌                            | 2509/10000 [00:43<02:09, 57.99it/s]

Iter:  2500 Loss:  0.006408895225553652  lr:  0.004


 50%|███████████████████                   | 5011/10000 [01:26<01:26, 57.83it/s]

Iter:  5000 Loss:  0.0017655809424622707  lr:  0.0008


 75%|████████████████████████████▌         | 7507/10000 [02:09<00:42, 58.12it/s]

Iter:  7500 Loss:  0.0007208401733932557  lr:  0.00016


100%|█████████████████████████████████████| 10000/10000 [02:51<00:00, 58.16it/s]

MSE Loss:  0.0005287277058215822
total time:  171.94635891914368





In [6]:
#torch.save(Root_NN2,'models/Complex_quad.pth')
test=np.array([[-4,5],[-2,4],[1,3],[4,2],[1.5,4.5]])[None,:]
print(Root_NN2(test))

tensor([[[ 1.9869e+00, -1.0042e+00],
         [ 1.0000e+00, -1.7314e+00],
         [-4.9917e-01, -1.6562e+00],
         [-5.8320e-01,  4.4085e-04],
         [-7.5040e-01, -1.9853e+00]]], device='cuda:0',
       grad_fn=<ViewBackward0>)
