## Overview of NN implementation details for 2D problem 

- No changes to the training strategy established in "OnlineImplementation.ipynb"


- Both the time-to-solution and the $n$th iterate error have proven to be good metrics for deciding when to add data to the training set


- Fewer training epochs are needed (500 instead of 2000)


- Altered PyTorch structures to handle 2D inputs (1D, i.e. flattened, data can also be used for 2D problems; see the final point)


- Can continue to ensure the "spread" of the data by checking the orthogonality of flattened 2D solutions


- Identified 2 neural network architectures that work exceptionally well

#### Modification to Data Filter for Adding Data to Training Set 

- Originally, the only metric used to decide whether to add data to the training set was the time-to-solution. This worked well because the time-to-solution encodes information about the rate of convergence that the error of a single iteration does not. However, on a distributed system the run time for the same problem can vary with system load. To fix this, we add a second filter based on the error of an iterate. We now add data to the training set only when the time-to-solution is above average **AND** the $n$th iterate error is above average. This extra condition prevents adding a problem that actually converges faster than average but whose time-to-solution was pushed above average by system load. The key change is the data-filter `if` statement, which now compares both quantities against their moving averages, in the code below:



            Initial_set=5
            IterTime_AVG=0.0
            IterErr10_AVG=0.0
            
            # Check if we are in first GMRES e1 tolerance run. If so, we compute prediction, and check the prediction is "good" before moving forward. 
            if func.predictor.is_trained and refine==False:
                pred_x0 = func.predictor.predict(b_flat)
                target_test=GMRES(A, b, pred_x0, e, 2,1, False)  # test the *predicted* initial guess
                IterErr_test = resid(A, target_test, b)
                print('size',len(IterErr_test))
                print(IterErr_test[-1],max(Err_list))
                if (IterErr_test[-1]>max(Err_list)): 
                    print('poor prediction,using initial x0')
                    pred_x0 = x0
            else:
                pred_x0 = x0


            #Time GMRES function 
            tic = time.perf_counter()
            target  = func(A, b,b_flat, pred_x0, e, nmax_iter,ML_GMRES_Time_list,ProbCount,restart,debug,refine,blist,reslist,Err_list,ML_GMRES_Time_list2, *eargs)
            toc = time.perf_counter()

            res = target[-1]
            res_flat=np.reshape(res.T,(1,-1),order='C').squeeze(0)


            # Check if we are in first e tolerance loop
            if refine==False :
                IterErr = resid(A, target, b)
                ML_GMRES_Time_list.append((toc-tic))
                Err_list.append(IterErr[2])  
                if ProbCount<=Initial_set:
                    func.predictor.add_init(b_flat, res_flat)
                if ProbCount==Initial_set:
                    timeLoop=func.predictor.retrain_timed()
                    print('Initial Training')
            else :
                ML_GMRES_Time_list2.append((toc-tic))


            # Compute moving averages used to filter data
            if ProbCount>Initial_set:
                IterTime_AVG=moving_average(np.asarray(ML_GMRES_Time_list),ProbCount)
                IterErr10_AVG=moving_average(np.asarray(Err_list),ProbCount)
                print(ML_GMRES_Time_list[-1],IterTime_AVG,Err_list[-1],IterErr10_AVG)


            # Filter for data to be added to training set.
            # KEY CHANGE: require BOTH an above-average time-to-solution AND an above-average n-th iterate error
            if (ML_GMRES_Time_list[-1]>IterTime_AVG and Err_list[-1]>IterErr10_AVG) and refine==True and ProbCount>Initial_set :
                

                blist.append(b_flat)
                reslist.append(res_flat)
                
                # check orthogonality of 3 solutions that met the training set criteria
                if   len(blist)==3 :
                    resMat=np.asarray(reslist)
                    resMat_square=resMat**2
                    row_sums = resMat_square.sum(axis=1,keepdims=True)
                    resMat= resMat/np.sqrt(row_sums)
                    InnerProd=np.dot(resMat,resMat.T)
                    print('InnerProd',InnerProd)
                    func.predictor.add(np.asarray(blist[0]), np.asarray(reslist[0]))
                    cutoff=0.8
                    
                    # Picking out sufficiently orthogonal subset of 3 solutions gathered
                    if np.abs(InnerProd[0,1])<cutoff and np.abs(InnerProd[0,2])<cutoff :
                        if np.abs(InnerProd[1,2])<cutoff :
                            func.predictor.add(np.asarray(blist[1]), np.asarray(reslist[1]))
                            func.predictor.add(np.asarray(blist[2]), np.asarray(reslist[2]))
                        elif np.abs(InnerProd[1,2])>=cutoff: 
                            func.predictor.add(np.asarray(blist[1]), np.asarray(reslist[1]))
                    elif np.abs(InnerProd[0,1])<cutoff :
                        func.predictor.add(np.asarray(blist[1]), np.asarray(reslist[1]))
                    elif np.abs(InnerProd[0,2])<cutoff :
                        func.predictor.add(np.asarray(blist[2]), np.asarray(reslist[2]))
                    
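                    if func.predictor.counter>=retrain_freq:
                        if func.debug:
                            print("retraining")
                            print(func.predictor.counter)
                        # Retrain regardless of the debug flag; only the prints are debug-only
                        timeLoop=func.predictor.retrain_timed()
                        trainTime=float(timeLoop[-1])
                    # Start a fresh group of 3 candidates whether or not we retrained
                    blist=[]
                    reslist=[]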
            return target,ML_GMRES_Time_list,trainTime,blist,reslist,Err_list,ML_GMRES_Time_list2
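
To make the two-condition filter easy to test in isolation, a stripped-down NumPy sketch follows. `moving_average` and `passes_filter` are hypothetical stand-ins: the notebook's own `moving_average` is not shown here, so this version assumes a plain mean over the last `window` entries, and the timings and errors are toy values rather than measured data.

```python
import numpy as np

def moving_average(values, window):
    """Hypothetical stand-in for the notebook's helper: mean of the last
    `window` entries (all entries if window >= len(values))."""
    values = np.asarray(values, dtype=float)
    return values[-window:].mean()

def passes_filter(times, errors, window):
    """Two-condition filter: the latest problem qualifies only if BOTH its
    time-to-solution and its n-th iterate error exceed their moving averages."""
    slow = times[-1] > moving_average(times, window)
    hard = errors[-1] > moving_average(errors, window)
    return bool(slow and hard)

# Toy history of per-problem run times (s) and n-th iterate errors.
times = [1.0, 1.1, 0.9, 1.0, 1.0]
errors = [1e-2, 1.2e-2, 0.9e-2, 1e-2, 1e-2]

# A load spike makes the latest run slow, but its iterate error is below
# average (it actually converged quickly), so it is NOT added.
print(passes_filter(times + [2.5], errors + [0.5e-2], window=6))  # False

# Slow AND large iterate error: a genuinely hard problem, so it is added.
print(passes_filter(times + [2.5], errors + [3e-2], window=6))    # True
```

With the time-only filter, the first case would have been added to the training set; the error condition is what rejects it.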

#### Adjustments to the Training Loop and PyTorch Tensors

- The main change to the training loop is that the number of epochs has been reduced from 2000 to 500, which still gave excellent results in our runs. Furthermore, the PyTorch tensors now have shape $(N, n_x, n_y)$, where $N$ is the number of samples, $n_x$ is the number of grid points in $x$, and $n_y$ is the number of grid points in $y$. These changes have been made consistently throughout the relevant code so that 2D data are handled correctly. Note, however, that the data can also be kept in a fully flattened representation throughout; we do this for one of the neural network architectures we have developed.

    def retrain_timed(self):

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Model, stored data and new data must all live on the same device
        self.model.to(device)
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        self.xNew = self.xNew.to(device)
        self.yNew = self.yNew.to(device)

        self.loss_val = list()  # clear loss val history
        self.loss_val.append(10.0)

        batch_size=32
        numEpochs=500  # reduced from 2000
        e1=1e-3
        epoch=0
        timeLoop=[]  # elapsed training time after each epoch; callers read timeLoop[-1]
        tic=time.perf_counter()

        while self.loss_val[-1]> e1 and epoch<numEpochs:
            permutation = torch.randperm(self.x.size()[0])
            for t in range(0,self.x.size()[0],batch_size):

                indices = permutation[t:t+batch_size]

                batch_x, batch_y = self.x[indices],self.y[indices]

                # Adding new data to each batch
                # Note: only adding at most 3 data points to each batch
                batch_xMix=torch.cat((batch_x,self.xNew))
                batch_yMix=torch.cat((batch_y,self.yNew))

                # Forward pass: Compute predicted y by passing x to the model
                y_pred = self.model(batch_xMix)

                # Compute and record loss
                loss = self.criterion(y_pred, batch_yMix)
                self.loss_val.append(loss.item())

                # Zero gradients, perform a backward pass, and update the weights.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # One epoch = one full pass over the permuted data (not one mini-batch)
            epoch=epoch+1
            timeLoop.append(time.perf_counter()-tic)

        print('Final loss:',loss.item())
        self.loss_val.append(loss.item())

        # Merge the new data into the training set, then reset the new-data buffers
        self.x=torch.cat((self.x,self.xNew))
        self.y=torch.cat((self.y,self.yNew))
        # 2D layout (0, n_x, n_y); assumes a square grid with D_in (resp. D_out) points per side
        self.xNew = torch.empty(0, self.D_in,self.D_in)
        self.yNew = torch.empty(0, self.D_out,self.D_out)

        numparams=sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('parameters',numparams)

        self.is_trained = True
        return timeLoop
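
The mini-batch scheme in `retrain_timed` (permute the stored samples, then append every newly gathered sample to each batch) can be reproduced with NumPy arrays in place of PyTorch tensors, so that the shapes are easy to check without a GPU. `mixed_batches` is a hypothetical helper, and the grid size and sample counts are toy values.

```python
import numpy as np

def mixed_batches(x, y, x_new, y_new, batch_size, rng):
    """Hypothetical helper: yield one epoch of mini-batches over the stored
    data (x, y), appending ALL newly gathered samples (x_new, y_new) to every
    batch. Arrays use the 2D layout (N, n_x, n_y)."""
    perm = rng.permutation(x.shape[0])  # NumPy analogue of torch.randperm
    for start in range(0, x.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        # np.concatenate along axis 0 plays the role of torch.cat
        yield (np.concatenate((x[idx], x_new)),
               np.concatenate((y[idx], y_new)))

n_x, n_y = 8, 8
rng = np.random.default_rng(0)
x = rng.standard_normal((70, n_x, n_y))      # 70 stored RHS fields (toy data)
y = rng.standard_normal((70, n_x, n_y))      # matching solutions
x_new = rng.standard_normal((3, n_x, n_y))   # at most 3 new samples
y_new = rng.standard_normal((3, n_x, n_y))

sizes = [bx.shape for bx, _ in mixed_batches(x, y, x_new, y_new, 32, rng)]
print(sizes)  # [(35, 8, 8), (35, 8, 8), (9, 8, 8)]

# Fully flattened alternative layout: (N, n_x, n_y) -> (N, n_x * n_y)
x_flat = x.reshape(x.shape[0], -1)
print(x_flat.shape)  # (70, 64)
```

Because the new samples appear in every mini-batch, they are effectively up-weighted relative to the stored data during a retrain.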

#### Ensuring the "Spread" of 2D Data 

Good performance has been obtained by flattening the solution data and taking normalized inner products between solutions, as a means of ensuring that the data added to the training set are sufficiently "different". This was originally implemented for 1D data and works at least as well for flattened 2D data. The code changes very little, since we essentially only compare vectors of data:


                blist.append(b)
                reslist.append(res)
                reslist_flat.append(np.reshape(res,(1,-1),order='C').squeeze(0))   
                
                # check orthogonality of 3 solutions that met the training set criteria
                if   len(blist)==3 :
                    resMat=np.asarray(reslist_flat)
                    resMat_square=resMat**2
                    row_sums = resMat_square.sum(axis=1,keepdims=True)
                    resMat= resMat/np.sqrt(row_sums)
                    InnerProd=np.dot(resMat,resMat.T)
                    print('InnerProd',InnerProd)
                    func.predictor.add(np.asarray(blist)[0], np.asarray(reslist)[0])
                    cutoff=0.8
                    
                    # Picking out sufficiently orthogonal subset of 3 solutions gathered
                    if np.abs(InnerProd[0,1])<cutoff and np.abs(InnerProd[0,2])<cutoff :
                        if np.abs(InnerProd[1,2])<cutoff :
                            func.predictor.add(np.asarray(blist)[1], np.asarray(reslist)[1])
                            func.predictor.add(np.asarray(blist)[2], np.asarray(reslist)[2])
                        elif np.abs(InnerProd[1,2])>=cutoff: 
                            func.predictor.add(np.asarray(blist)[1], np.asarray(reslist)[1])
                    elif np.abs(InnerProd[0,1])<cutoff :
                        func.predictor.add(np.asarray(blist)[1], np.asarray(reslist)[1])
                    elif np.abs(InnerProd[0,2])<cutoff :
                        func.predictor.add(np.asarray(blist)[2], np.asarray(reslist)[2])
                    
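                    if func.predictor.counter>=retrain_freq:
                        if func.debug:
                            print("retraining")
                            print(func.predictor.counter)
                        # Retrain regardless of the debug flag; only the prints are debug-only
                        timeLoop=func.predictor.retrain_timed()
                        trainTime=float(timeLoop[-1])
                    # Start a fresh group of 3 candidates whether or not we retrained
                    blist=[]
                    reslist=[]
                    reslist_flat=[]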
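
The selection rule can be isolated as a small function. This is a sketch assuming the rule reads: always keep the first candidate, then keep each later candidate only if the absolute normalized inner product with every already-kept candidate is below `cutoff` (0.8 in the notes). For three candidates this greedy loop reproduces the `if`/`elif` cascade above when every comparison is written as `< cutoff`. `select_spread` and the test fields are hypothetical.

```python
import numpy as np

def select_spread(solutions, cutoff=0.8):
    """Hypothetical helper: return indices of candidate solutions to keep.
    Each 2D solution is flattened and unit-normalized; candidate 0 is always
    kept, and a later candidate is kept only if its |inner product| with
    every already-kept candidate is below `cutoff`."""
    mat = np.asarray([np.ravel(s) for s in solutions], dtype=float)
    mat = mat / np.linalg.norm(mat, axis=1, keepdims=True)  # assumes no all-zero solution
    gram = np.abs(mat @ mat.T)                               # |cosine similarity|
    keep = [0]
    for i in range(1, len(mat)):
        if all(gram[i, j] < cutoff for j in keep):
            keep.append(i)
    return keep

# Toy 2D "solutions" on a 16 x 16 grid.
xs = np.linspace(0.0, 1.0, 16)
X, Y = np.meshgrid(xs, xs, indexing="ij")
s1 = np.sin(np.pi * X) * np.sin(np.pi * Y)
s2 = 1.05 * s1                                   # same shape, different scale
s3 = np.sin(2 * np.pi * X) * np.sin(np.pi * Y)   # orthogonal to s1 on this grid

print(select_spread([s1, s2, s3]))  # [0, 2]: s2 is redundant with s1
```

Taking the absolute value means a sign-flipped copy of a kept solution is also treated as redundant.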

#### Neural Network Architectures 

- Two simple neural network architectures have worked best in our experiments. Here "best" means a network that is quick to train (i.e. not too deep or wide) and that quickly provides a speed-up to GMRES.

##### 2D CNN 
- Two single-channel 2D convolutional layers (with different kernel sizes) + 1 fully connected linear output layer
    - The inspiration for this is that the solution of the Poisson equation can be expressed as a convolution of the RHS with the corresponding Green's function. For a linear, translation-invariant operator $L$ with Green's function $G$ we have
    
    $$Lu=f \implies u=G*f$$
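
As a numerical illustration of $u=G*f$, the sketch below solves a discrete Poisson problem $\nabla^2 u = f$ on an $n\times n$ grid by explicitly convolving the RHS with the discrete Green's function. It assumes periodic boundary conditions, which the notes do not specify; with Dirichlet boundaries $G$ depends on both the source and target points, so the pure-convolution picture is only an analogy there. The grid size and random RHS are arbitrary.

```python
import numpy as np

n = 32
h = 1.0 / n                      # grid spacing on the periodic unit square
rng = np.random.default_rng(1)
f = rng.standard_normal((n, n))
f -= f.mean()                    # the periodic Poisson problem needs a mean-zero RHS

def laplacian(u):
    """5-point discrete Laplacian with periodic boundaries."""
    return (np.roll(u, 1, 0) + np.roll(u, -1, 0)
            + np.roll(u, 1, 1) + np.roll(u, -1, 1) - 4.0 * u) / h**2

# Eigenvalues of the periodic Laplacian for each 2D Fourier mode (k, l).
k = np.arange(n)
lam1d = (2.0 * np.cos(2.0 * np.pi * k / n) - 2.0) / h**2
lam = lam1d[:, None] + lam1d[None, :]
inv = np.zeros_like(lam)
nonzero = lam != 0.0
inv[nonzero] = 1.0 / lam[nonzero]        # the constant mode (k = l = 0) is dropped

# Discrete Green's function: the kernel whose 2D DFT is 1 / lambda.
G = np.real(np.fft.ifft2(inv))

# u = G * f written out as an explicit circular convolution:
# u[a, b] = sum_{i, j} f[i, j] * G[a - i, b - j]
u_conv = np.zeros_like(f)
for i in range(n):
    for j in range(n):
        u_conv += f[i, j] * np.roll(G, (i, j), axis=(0, 1))

print(np.abs(laplacian(u_conv) - f).max())   # round-off level
```

A convolutional layer with a learnable kernel has the same structure as this sum, which is what motivates the 2D CNN.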

##### 1D CNN 

- 1D convolution with $n^2$ channels (kernel size $n$) + 1 fully connected output layer
    - The inspiration for this network comes from the underlying linear algebra problem for the discrete Laplacian. In particular, we can think of the first layer as playing a role similar to the action of the matrix inverse on the RHS: for an $n\times n$ grid, multiplying an $n^2\times n^2$ matrix by the flattened RHS (a vector of length $n^2$) can be written as a 1D convolution with $n^2$ output channels, where the kernel of every channel has length $n$, provided the grid is fed in as $n$ input channels (one per grid row) of length $n$.
    
    $$Ax=b \implies x=A^{-1}b$$
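
The equivalence between a dense matrix-vector product and a 1D convolution can be checked numerically. The sketch below implements the cross-correlation that a PyTorch-style `torch.nn.Conv1d(n, n**2, kernel_size=n, bias=False)` layer computes, in plain NumPy so it runs without PyTorch. It assumes the grid rows are fed in as the $n$ input channels; `conv1d_valid` is a hypothetical helper and `M` is a random matrix standing in for $A^{-1}$.

```python
import numpy as np

def conv1d_valid(x, w):
    """Cross-correlation as computed by a Conv1d layer with no bias and no
    padding. x: (C_in, L), w: (C_out, C_in, K) -> output (C_out, L - K + 1)."""
    c_out, c_in, k = w.shape
    length = x.shape[1] - k + 1
    out = np.empty((c_out, length))
    for t in range(length):
        out[:, t] = np.einsum("oik,ik->o", w, x[:, t:t + k])
    return out

n = 4                                   # n x n grid -> flattened vector of length n**2
rng = np.random.default_rng(2)
M = rng.standard_normal((n**2, n**2))   # stands in for A^{-1}
f = rng.standard_normal((n, n))         # RHS on the grid

# Grid rows become the n input channels; each kernel spans a whole row (K = n),
# and there are n**2 output channels, one per row of M.
w = M.reshape(n**2, n, n)
y_conv = conv1d_valid(f, w)[:, 0]       # output length is 1

print(np.allclose(y_conv, M @ f.ravel()))  # True
```

Since the kernel is as long as each input channel, every output channel sees every grid value with its own weight, so the layer holds exactly $n^4$ weights, the same number as a dense $n^2\times n^2$ matrix.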