In [1]:
from utils import device, generate_A_H_sol, decompose_matrix
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

Code run on : cpu


In [2]:
# # ! Tout ca devrait être des variables
total_itr = 25  # Total number of iterations (multiple of "itr")
n = 300  # Number of rows # ? suppose to be a variable ?
m = 600  # Number of columns # ? suppose to be a variable ?
bs = 10000  # Mini-batch size (samples)
num_batch = 500  # Number of mini-batches
lr_adam = 0.002  # Learning rate of optimizer
init_val_SORNet = 1.1  # Initial value of omega for SORNet
init_val_SOR_CHEBY_Net_omega = 0.6  # Initial value of omega for SOR_CHEBY_Net
init_val_SOR_CHEBY_Net_gamma = 0.8  # Initial value of gamma for SOR_CHEBY_Net
init_val_SOR_CHEBY_Net_alpha = 0.9  # Initial value of alpha for SOR_CHEBY_Net
init_val_AORNet_r = 0.9  # Initial value of r for AORNet
init_val_AORNet_omega = 1.5  # Initial value of omega for AORNet
init_val_RINet = 0.1  # Initial value of omega for RINet

# Generate A and H
seed = 12

A, H, W, solution, y = generate_A_H_sol(n=n, m=m, seed=seed, bs=bs)
A, D, L, U, Dinv, Minv = decompose_matrix(A)

Condition number, min. and max. eigenvalues of A:
30.828752475024576 5.678370598467175 0.18419073567986302


In [6]:
class SORNet(nn.Module):
    """Deep unfolded SOR with a constant step size."""

    def __init__(self, init_val_SORNet, A, H, bs, y, device=device):
        """
        Initialize the SORNet model.

        Args:
            num_itr (int): Number of iterations.
            init_val_SORNet (float): Initial value for inv_omega.
            D (torch.Tensor): Diagonal matrix D.
            L (torch.Tensor): Lower triangular matrix L.
            U (torch.Tensor): Upper triangular matrix U.
            H (torch.Tensor): Matrix H.
            bs (int): Batch Size
            y (toch.Tensor): Solution
            device (str): Device to run the model on ('cpu' or 'cuda').

        """
        super(SORNet, self).__init__()
        self.device = device
        self.inv_omega = nn.Parameter(torch.tensor(init_val_SORNet, device=device))
        
        A, D, L, U, _, _ = decompose_matrix(A)
        
        self.A = A.to(device)
        self.D = D.to(device)
        self.L = L.to(device)
        self.U = U.to(device)
        self.H = H.to(device)
        self.Dinv = torch.linalg.inv(D).to(device)
        self.bs = bs
        self.y = y.to(device)

    def forward(self, num_itr):
        """
        Perform forward pass of the SORNet model.

        Args:
            bs (int): Batch size.
            y (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
            list: List of intermediate results.

        """
        traj = []

        invM = torch.linalg.inv(self.inv_omega * self.D + self.L)
        s = torch.zeros(self.bs, self.H.size(1), device=self.device)
        traj.append(s)
        yMF = torch.matmul(self.y, self.H.T)
        s = torch.matmul(yMF, self.Dinv)

        for _ in range(num_itr):
            temp = torch.matmul(s, (self.inv_omega - 1) * self.D - self.U) + yMF
            s = torch.matmul(temp, invM)
            traj.append(s)

        return s, traj

# ========================================================================================================

class SOR_CHEBY_Net(nn.Module):
    """Deep unfolded SOR with Chebyshev acceleration."""

    def __init__(self, num_itr, init_val_SOR_CHEBY_Net_omega, init_val_SOR_CHEBY_Net_gamma, init_val_SOR_CHEBY_Net_alpha, A, H, bs, y, device=device):
        """
        Initialize the SOR_CHEBY_Net model.

        Args:
            num_itr (int): Number of iterations.
            init_val_SOR_CHEBY_Net_omega (float): Initial value for omega.
            init_val_SOR_CHEBY_Net_gamma (float): Initial value for gamma.
            init_val_SOR_CHEBY_Net_alpha (float): Initial value for inv_omega.
            D (torch.Tensor): Diagonal matrix D.
            L (torch.Tensor): Lower triangular matrix L.
            U (torch.Tensor): Upper triangular matrix U.
            H (torch.Tensor): Matrix H.
            bs (int): Batch Size
            y (torch.Tensor): Solution of the linear equation
            device (str): Device to run the model on ('cpu' or 'cuda').

        """
        super(SOR_CHEBY_Net, self).__init__()
        self.device = device
        self.gamma = nn.Parameter(init_val_SOR_CHEBY_Net_gamma * torch.ones(num_itr, device=device))
        self.omega = nn.Parameter(init_val_SOR_CHEBY_Net_omega * torch.ones(num_itr, device=device))
        self.inv_omega = nn.Parameter(torch.tensor(init_val_SOR_CHEBY_Net_alpha, device=device))
        
        A, D, L, U, _, _ = decompose_matrix(A)
        self.A = A
        self.D = D.to(device)
        self.L = L.to(device)
        self.U = U.to(device)
        self.H = H.to(device)
        self.Dinv = torch.linalg.inv(D).to(device)
        self.bs = bs
        self.y = y.to(device)

    def forward(self, num_itr):
        """
        Perform forward pass of the SOR_CHEBY_Net model.

        Args:
            bs (int): Batch size.
            y (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
            list: List of intermediate results.

        """
        traj = []

        invM = torch.linalg.inv(self.inv_omega * self.D + self.L)
        s = torch.zeros(self.bs, self.H.size(1), device=self.device)
        s_new = torch.zeros(self.bs, self.H.size(1), device=self.device)
        traj.append(s)
        yMF = torch.matmul(self.y, self.H.T)
        s = torch.matmul(yMF, self.Dinv)
        s_present = s
        s_old = torch.zeros_like(s_present)

        for i in range(num_itr):
            temp = torch.matmul(s, (self.inv_omega - 1) * self.D - self.U) + yMF
            s = torch.matmul(temp, invM)

            s_new = self.omega[i] * (self.gamma[i] * (s - s_present) + (s_present - s_old)) + s_old
            s_old = s
            s_present = s_new
            traj.append(s_new)

        return s_new, traj

# =====================================================================================

class AORNet(nn.Module):
    """Deep unfolded AOR with a constant step size."""

    def __init__(self, init_val_AORNet_r, init_val_AORNet_omega, A, H, bs, y, device=device):
        """
        Initialize the AORNet model.

        Args:
            init_val_AORNet_r (float): Initial value for r.
            init_val_AORNet_omega (float): Initial value for omega.
            D (torch.Tensor): Diagonal matrix D.
            L (torch.Tensor): Lower triangular matrix L.
            U (torch.Tensor): Upper triangular matrix U.
            H (torch.Tensor): Matrix H.
            device (str): Device to run the model on ('cpu' or 'cuda').
        """
        super(AORNet, self).__init__()
        self.device = device
        self.r = nn.Parameter(torch.tensor(init_val_AORNet_r, device=device))
        self.omega = nn.Parameter(torch.tensor(init_val_AORNet_omega, device=device))
        
        A, D, L, U, _, _ = decompose_matrix(A)
        self.A = A.to(device)
        self.D = D.to(device)
        self.L = L.to(device)
        self.U = U.to(device)
        self.H = H.to(device)
        self.Dinv = torch.linalg.inv(D).to(device)
        self.bs = bs
        self.y = y

    def forward(self, num_itr):
        """
        Perform forward pass of the AORNet model.

        Args:
            num_itr (int): Number of iterations.
            bs (int): Batch size.
            y (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
            list: List of intermediate results.
        """
        traj = []

        invM = torch.linalg.inv(self.L - self.r * self.D)
        N = (1 - self.omega) * self.D + (self.omega - self.r) * self.L + self.omega * self.U
        s = torch.zeros(self.bs, self.H.size(1), device=self.device)
        traj.append(s)
        yMF = torch.matmul(self.y, self.H.T)
        s = torch.matmul(yMF, self.Dinv)

        for _ in range(num_itr):
            s = torch.matmul(s, torch.matmul(invM, N)) + torch.matmul(yMF, invM)
            traj.append(s)

        return s, traj

# =====================================================================================

class RINet(nn.Module):
    """Deep unfolded Richardson iteration."""

    def __init__(self, init_val_RINet, A, H, bs, y, device=device):
        """
        Initialize the RINet model.

        Args:
            num_itr (int): Number of iterations.

        """
        super(RINet, self).__init__()
        self.inv_omega = nn.Parameter(torch.tensor(init_val_RINet, device=device))
        
        A, D, L, U, _, _ = decompose_matrix(A)
        self.A = A.to(device)
        self.D = D.to(device)
        self.L = L.to(device)
        self.U = U.to(device)
        self.H = H.to(device)
        self.Dinv = torch.linalg.inv(D).to(device)
        self.bs = bs
        self.y = y

    def forward(self, num_itr):
        """
        Perform forward pass of the RINet model.

        Args:
            num_itr (int): Number of iterations.
            bs (int): Batch size.
            y (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
            list: List of intermediate results.

        """
        traj = []

        s = torch.zeros(self.bs, A.shape[0]).to(device)
        traj.append(s)
        yMF = torch.matmul(self.y, self.H.T)
        s = torch.matmul(yMF, self.Dinv)

        for _ in range(num_itr):
            s = s + torch.mul(self.inv_omega[0], (yMF - torch.matmul(s, self.A)))
            traj.append(s)

        return s, traj

In [7]:
loss_func = nn.MSELoss()

# Models
model_SorNet = SORNet(init_val_SORNet, A, H, bs, y, device=device)
model_Sor_Cheby_Net = SOR_CHEBY_Net(total_itr, init_val_SOR_CHEBY_Net_omega, init_val_SOR_CHEBY_Net_gamma, init_val_SOR_CHEBY_Net_alpha, A, H, bs, y, device=device)
model_AorNet = AORNet(init_val_AORNet_r, init_val_AORNet_omega, A, H, bs, y, device=device)
model_RINet = RINet( init_val_RINet, A, H, bs, y, device=device)

# Optimizers
opt_SORNet = optim.Adam(model_SorNet.parameters(), lr=lr_adam)
opt_SORNet_Cheby = optim.Adam(model_Sor_Cheby_Net.parameters(), lr=lr_adam)
opt_AORNet = optim.Adam(model_AorNet.parameters(), lr=lr_adam)
opt_RINet = optim.Adam(model_RINet.parameters(), lr=lr_adam)

In [8]:
loss_gen=[]
for gen in range(total_itr):
    """
    Training process of SORNet.

    Args:
        gen (int): Generation number.

    """
    for i in range(num_batch):
        opt_SORNet.zero_grad()
        solution = torch.normal(0.0 * torch.ones(bs, n), 1.0).to(device)
        y = solution @ H
        x_hat, _ = model_SorNet(gen + 1)
        loss = loss_func(x_hat, solution)
        loss.backward()
        opt_SORNet.step()
        
        if i % 200 == 0:
            print("generation:", gen + 1, " batch:", i, "\t MSE loss:", loss.item())

    loss_gen.append(loss.item())
## training process of SOR_CHEBY_Net
# it takes about several minutes on Google Colaboratory

generation: 1  batch: 0 	 MSE loss: 1.791316032409668
generation: 1  batch: 200 	 MSE loss: 1.7642379999160767
generation: 1  batch: 400 	 MSE loss: 1.762949824333191
generation: 2  batch: 0 	 MSE loss: 1.7931795120239258
generation: 2  batch: 200 	 MSE loss: 1.7605116367340088
generation: 2  batch: 400 	 MSE loss: 1.7603600025177002
generation: 3  batch: 0 	 MSE loss: 1.7796863317489624
generation: 3  batch: 200 	 MSE loss: 1.7681561708450317
generation: 3  batch: 400 	 MSE loss: 1.7641069889068604
generation: 4  batch: 0 	 MSE loss: 1.777740478515625
generation: 4  batch: 200 	 MSE loss: 1.7655141353607178


KeyboardInterrupt: 