# CNN to solve a multiband linear equation system

In [2]:
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# import torch.nn.functional as F
# from torch.utils.data import Dataset, DataLoader

# from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

## Samples Generator

In [90]:
def tensor2matrix(t):

    pass

def make_matrix(number_generator,
                args={},
                symmetric=True,
                off_diagonal_abs_mean=0.5,
                grid_size=(5, 5),
                as_tensor=True):

    #Generate values
    ni, nj = grid_size
    n_diag_1 = (ni-1)*nj
    n_diag_ni = ni*(nj-1)

    diag_1_u = [number_generator(**args) for _ in range(n_diag_1)]
    diag_ni_u = [number_generator(**args) for _ in range(n_diag_ni)]

    if symmetric:
        diag_1_l = diag_1_u
        diag_ni_l = diag_ni_u
    else:
        diag_1_l = [number_generator(**args) for _ in range(n_diag_1)]
        diag_ni_l = [number_generator(**args) for _ in range(n_diag_ni)]

    #To tensor
    diag_1_u = torch.tensor(diag_1_u).float()
    diag_1_l = torch.tensor(diag_1_l).float()
    diag_ni_u = torch.tensor(diag_ni_u).float()
    diag_ni_l = torch.tensor(diag_ni_l).float()

    #Scale off main diagonal
    off_diagonal = torch.cat([diag_ni_u, diag_1_u, diag_1_l, diag_ni_l])
    off_diagonal = torch.abs(off_diagonal)
    mean_abs = torch.mean(off_diagonal)
    alpha = off_diagonal_abs_mean / mean_abs
    diag_1_u = torch.mul(diag_1_u, alpha)
    diag_1_l = torch.mul(diag_1_l, alpha)
    diag_ni_u = torch.mul(diag_ni_u, alpha)
    diag_ni_l = torch.mul(diag_ni_l, alpha)

    print(f'ni_u = {diag_ni_u}')
    print(f'1_u = {diag_1_u}')
    print(f'1_l = {diag_1_l}')
    print(f'ni_l = {diag_ni_l}')

    #Adjust zeroes
    diag_1_u = torch.reshape(diag_1_u, (ni-1,nj))
    diag_1_u = torch.cat([torch.zeros(1,nj), diag_1_u], dim=0)
    diag_1_u = torch.flatten(diag_1_u.transpose(0,1))

    diag_1_l = torch.reshape(diag_1_l, (ni-1,nj))
    diag_1_l = torch.cat([diag_1_l, torch.zeros(1,nj)], dim=0)
    diag_1_l = torch.flatten(diag_1_l.transpose(0,1))

    diag_ni_l = torch.reshape(diag_ni_l, (ni,nj-1))
    diag_ni_l = torch.cat([diag_ni_l, torch.zeros(ni,1)], dim=1)
    diag_ni_l = torch.flatten(diag_ni_l.transpose(0,1))

    diag_ni_u = torch.reshape(diag_ni_u, (ni,nj-1))
    diag_ni_u = torch.cat([torch.zeros(ni,1), diag_ni_u], dim=1)
    diag_ni_u = torch.flatten(diag_ni_u.transpose(0,1))

    if as_tensor:
        t = torch.cat([diag_ni_u.reshape((nj,ni)).transpose(0,1).unsqueeze(0),
                       diag_1_u.reshape((nj,ni)).transpose(0,1).unsqueeze(0),
                       diag_1_l.reshape((nj,ni)).transpose(0,1).unsqueeze(0),
                       diag_ni_l.reshape((nj,ni)).transpose(0,1).unsqueeze(0)],
                       dim=0)
        return t

    n_diag = ni*nj
    diag = [1.]*n_diag
    diag = torch.tensor(diag).float()

    m = torch.zeros((n_diag,n_diag))
    m = m + torch.diag_embed(diag)
    m[1:,:-1] = m[1:,:-1] + torch.diag_embed(diag_1_l[:-1])
    m[:-1,1:] = m[:-1,1:] + torch.diag_embed(diag_1_u[1:])
    m[ni:,:-ni] = m[ni:,:-ni] + torch.diag_embed(diag_ni_l[:-ni])
    m[:-ni,ni:] = m[:-ni,ni:] + torch.diag_embed(diag_ni_u[ni:])

    return m

def one():
    return 1

m = make_matrix(
    number_generator=random.uniform,
    args={'a':-1,'b':3},
    # one,
    grid_size=(2,3),
    off_diagonal_abs_mean=2,
    as_tensor=True,
    symmetric=False)
m

ni_u = tensor([ 3.7700, -1.0961,  2.8185,  1.8514])
1_u = tensor([0.6850, 2.9814, 3.5380])
1_l = tensor([ 0.2580, -1.0260,  3.3171])
ni_l = tensor([-0.9402,  1.9439, -0.2893,  3.4852])
tensor([[[ 0.0000,  3.7700, -1.0961],
         [ 0.0000,  2.8185,  1.8514]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.6850,  2.9814,  3.5380]],

        [[ 0.2580, -1.0260,  3.3171],
         [ 0.0000,  0.0000,  0.0000]],

        [[-0.9402,  1.9439,  0.0000],
         [-0.2893,  3.4852,  0.0000]]])


tensor([[ 1.0000,  0.6850,  3.7700,  0.0000,  0.0000,  0.0000],
        [ 0.2580,  1.0000,  0.0000,  2.8185,  0.0000,  0.0000],
        [-0.9402,  0.0000,  1.0000,  2.9814, -1.0961,  0.0000],
        [ 0.0000, -0.2893, -1.0260,  1.0000,  0.0000,  1.8514],
        [ 0.0000,  0.0000,  1.9439,  0.0000,  1.0000,  3.5380],
        [ 0.0000,  0.0000,  0.0000,  3.4852,  3.3171,  1.0000]])