# Neural Network for sparse basis transformation B

## Definitions 

In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import warnings
from datetime import datetime
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageSequence
import os
import subprocess

time = datetime.now()
# Output directory for the per-iteration training snapshots; no error if it exists.
os.makedirs('plots', exist_ok=True)

warnings.simplefilter(action='ignore', category=FutureWarning)
# `np.warnings` no longer exists (removed in NumPy 1.24) and VisibleDeprecationWarning
# now lives in `np.exceptions`; fall back to the top-level name on older NumPy.
_np_exceptions = getattr(np, 'exceptions', np)
warnings.filterwarnings('ignore', category=_np_exceptions.VisibleDeprecationWarning)
torch.manual_seed(12345)
np.random.seed(12345)

# Report the GPU, if any. stderr is merged into stdout so that nvidia-smi's
# "... has failed ..." message is caught; a missing binary is reported the same way.
try:
    gpu_info = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT, text=True).stdout
except OSError as err:
    gpu_info = f'nvidia-smi failed: {err}'
if gpu_info.find('failed') >= 0:
    print('Not connected to a GPU')
else:
    print(gpu_info)

# ----------------------------------------------------------------------------- 
# -------------------------- Fully Connected Network --------------------------
# ----------------------------------------------------------------------------- 
class FCN(nn.Module):
    """Fully connected tanh network mapping N_INPUT values to N_OUTPUT values.

    Layout:
        fc1 -- Linear(N_INPUT, N_HIDDEN) followed by Tanh.
        fc2 -- N_LAYERS - 1 blocks, each Linear(N_HIDDEN, N_HIDDEN) followed by
               Tanh and wrapped in its own Sequential (empty when N_LAYERS <= 1).
        fc3 -- Linear(N_HIDDEN, N_OUTPUT) read-out without activation.
    """

    def __init__(self, N_INPUT, N_OUTPUT, N_HIDDEN, N_LAYERS):
        super().__init__()
        act = nn.Tanh  # activation class; a fresh instance is created per layer

        def dense(n_in, n_out):
            # One affine map followed by its own activation module.
            return nn.Sequential(nn.Linear(n_in, n_out), act())

        self.fc1 = dense(N_INPUT, N_HIDDEN)
        hidden_blocks = []
        for _ in range(N_LAYERS - 1):
            hidden_blocks.append(dense(N_HIDDEN, N_HIDDEN))
        self.fc2 = nn.Sequential(*hidden_blocks)
        self.fc3 = nn.Linear(N_HIDDEN, N_OUTPUT)

    def forward(self, *args):
        """Run the network on one input or on several inputs combined into one.

        A single argument is used unchanged. With several arguments, 0-/1-D ones
        are stacked into a float tensor and transposed (one column per argument);
        higher-rank ones are concatenated along dimension 1.
        """
        if len(args) != 1:
            if len(np.shape(args[0])) > 1:
                x = torch.FloatTensor(torch.cat([*args], 1))
            else:
                x = torch.FloatTensor([*args]).T
        else:
            (x,) = args

        for layer in (self.fc1, self.fc2, self.fc3):
            x = layer(x)
        return x

# ----------------------------------------------------------------------------- 
# -------------------------- Creating GIF animations -------------------------- 
# ----------------------------------------------------------------------------- 
def save_gif(outfile, files, fps=5, loop=0):
    """Combine image files into an animated GIF.

    Args:
        outfile: path of the GIF to write.
        files: paths of the frames, in display order (must be non-empty).
        fps: frames per second; each frame is shown for int(1000 / fps) ms.
        loop: GIF loop count (0 = loop forever).

    Raises:
        ValueError: if `files` is empty.
    """
    import contextlib

    files = list(files)
    if not files:
        raise ValueError('save_gif needs at least one frame file')
    # ExitStack closes every opened frame once the GIF has been written.
    with contextlib.ExitStack() as stack:
        imgs = [stack.enter_context(Image.open(file)) for file in files]
        imgs[0].save(fp=outfile, format='GIF', append_images=imgs[1:], save_all=True,
                     duration=int(1000/fps), loop=loop)

def split(array, nrows, ncols):
    """Cut a 2-D array into non-overlapping nrows x ncols tiles.

    Tiles are returned in row-major order (left to right, then top to bottom)
    as an array of shape (n_tiles, nrows, ncols).

    Raises:
        ValueError: if the array's height/width is not a multiple of nrows/ncols.
    """
    n_rows, n_cols = array.shape
    if n_rows % nrows or n_cols % ncols:
        raise ValueError(
            f'array of shape {array.shape} cannot be tiled into {nrows}x{ncols} blocks')
    # The leading axis must be the number of tile rows (height // nrows), not
    # width // nrows; the two only coincide for square arrays.
    return (array.reshape(n_rows // nrows, nrows, -1, ncols)
                 .swapaxes(1, 2)
                 .reshape(-1, nrows, ncols))

# Option 1: Load uploaded images

In [None]:
# Load each uploaded image as grayscale, rescale it to [0, 1] and store it as a
# flattened vector in `images`; the images are displayed side by side.
images = []
img_files = ['amyloid64.png','checker64.png','zehner64.png','astrocyte64.png']
fig, axes = plt.subplots(1, len(img_files), figsize=(20, 15), squeeze=False)
for ax, img_file in zip(axes[0], img_files):
    # `with` closes the file; convert('L') returns a single grayscale frame
    # (the current, i.e. first, frame of the file).
    with Image.open(img_file) as im:
        pixels = np.asarray(im.convert('L'), dtype=np.float64).ravel()
    lo, hi = pixels.min(), pixels.max()
    # A constant image has hi == lo; map it to zeros instead of 0/0 = NaN.
    normalized = (pixels - lo) / (hi - lo) if hi > lo else np.zeros_like(pixels)
    pixelsize = int(np.sqrt(normalized.size))
    if pixelsize * pixelsize != normalized.size:
        raise ValueError(f'{img_file}: expected a square image, got {normalized.size} pixels')
    ax.title.set_text('$x$')
    ax.imshow(normalized.reshape(pixelsize,pixelsize), interpolation='none', cmap=cm.Greys_r, vmin=0, vmax=1)
    images.append(normalized)
plt.show()

# Option 2: Load CIFAR-10 images

In [None]:
# Convert PIL images to tensors and map every RGB channel from [0, 1] to [-1, 1]
# via (v - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
batch_size = 300

# Training split, reshuffled on every pass; download=True fetches it into ./data if missing.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names indexed by integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
def imshow(img):
    """Show a channel-first image tensor that was normalized to [-1, 1].

    The tensor is mapped back to [0, 1] (v / 2 + 0.5) and its channel axis is
    moved last, the layout `plt.imshow` expects.
    """
    pixels = (img / 2 + 0.5).numpy().transpose(1, 2, 0)
    plt.imshow(pixels)
    plt.show()

# Draw one shuffled batch of training images, show it and print its labels.
dataiter = iter(trainloader)
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
# Iterate the labels actually returned rather than assuming a full batch.
print(' '.join(f'{classes[label]:5s}' for label in labels))

# Collapse RGB to one channel, flatten each image into one row and rescale the
# whole batch jointly to [0, 1]. A single reshape replaces the per-image loop and
# the deprecated np.row_stack.
imgs = torchvision.transforms.functional.rgb_to_grayscale(images)
A = imgs.reshape(imgs.shape[0], -1).numpy()
images = (A - A.min()) / (A.max() - A.min())
pixelsize = int(np.sqrt(images.shape[1]))

# Training the network

In [None]:
# ----------------------------------------------------------------------------- 
# ------------------------------- NN parameters ------------------------------- 
# ----------------------------------------------------------------------------- 
lr = 1e-1                     #Learning rate
INPUT = pixelsize**2          #Amount of input values
N_HIDDEN = 16                 #Amount of neurons in hidden layers
N_LAYERS = 3                  #Amount of hidden layers
OUTPUT = pixelsize**2         #Amount of output values
epsilon = 10e-8               #Smoothing term (10e-8 == 1e-7) in sqrt(b**2 + epsilon), a differentiable stand-in for |b| (L1) in the sparsity loss
im_number = min(250, len(images))  #Amount of images used; capped at the number loaded (Option 1 loads only 4)
image_iterations = 250        #Amount of iterations executed per image

time = datetime.now()         #Keep track of time

# One network per Kronecker factor of B = kron(C, D); both receive the same image.
model1 = FCN(INPUT,OUTPUT,N_HIDDEN,N_LAYERS) #Create a network for C
optimizer1 = torch.optim.Adam(model1.parameters(),lr=lr)   #Adam optimizer
model2 = FCN(INPUT,OUTPUT,N_HIDDEN,N_LAYERS) #Create a network for D
optimizer2 = torch.optim.Adam(model2.parameters(),lr=lr)   #Adam optimizer

# Train model1 (-> C) and model2 (-> D) jointly so that B = kron(C, D) stays close
# to orthogonal while the coefficients B^{-1} x of each image become sparse. The
# first `im_number` images are visited in order, `image_iterations` steps each.


def _count_near_zero(v, threshold):
    """Return the number of entries of `v` whose magnitude is below `threshold`."""
    # abs() matters: B^{-1} x can have large negative coefficients, which must not
    # be counted as zeros.
    return int((np.abs(v) < threshold).sum())


def _set_lr(new_lr, *optimizers):
    """Set the learning rate of every parameter group of `optimizers` to `new_lr`."""
    for opt in optimizers:
        for group in opt.param_groups:
            group['lr'] = new_lr


def _save_snapshot(i, x_np, B_np, B_inv, sparsity1, sparsity2):
    """Plot x, B^{-1}, B B^T x and B^{-1} x for iteration `i`; return the saved PNG path."""
    fig = plt.figure(figsize=(20,15))
    panels = [
        ('$x$', x_np.reshape(pixelsize,pixelsize)),
        ('$B^{-1}$', B_inv),
        ('$BB^{T}x$', (B_np @ B_np.T @ x_np).reshape(pixelsize,pixelsize)),
        ('$B^{-1}x$', (B_inv @ x_np).reshape(pixelsize,pixelsize)),
    ]
    for k, (title, data) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, k)
        ax.title.set_text(title)
        ax.imshow(data, interpolation='none', cmap=cm.Greys_r, vmin=0, vmax=1)

    path = "plots/nn_%.8i.png"%(i)
    # The annotations are positioned relative to the last (right-most) axes.
    plt.annotate("Iteration: %i"%(i),xy=(1.05, 0.87),xycoords='axes fraction',fontsize="x-large",color="k")
    plt.annotate("Sparsity $x$: %i"%(sparsity1),xy=(1.05, 0.77),xycoords='axes fraction',fontsize="x-large",color="k")
    plt.annotate("Sparsity $B^{-1}x$: %i"%(sparsity2),xy=(1.05, 0.67),xycoords='axes fraction',fontsize="x-large",color="k")
    plt.savefig(path, bbox_inches='tight', pad_inches=0.1, dpi=100, facecolor="white")
    plt.show()
    plt.close(fig)  # avoid accumulating hundreds of open figures
    return path


files = []
counter = 0
base_lr = lr          # rate the optimizers were created with
decayed_lr = 10e-3    # (== 1e-2) rate for the second half of each image's iterations
half = max(image_iterations // 2, 1)
identity = torch.eye(pixelsize)
for i in range(1,int(image_iterations*im_number)):
    # Adaptive learning rate: back to base_lr whenever a new image starts, lowered
    # once half of that image's iterations have passed. The rate is written into
    # the optimizers' param_groups; rebinding `lr` alone would not affect them.
    if i%image_iterations == 0:
        counter += 1
        _set_lr(base_lr, optimizer1, optimizer2)
    elif i%half == 0:
        _set_lr(decayed_lr, optimizer1, optimizer2)
    x_np = images[counter]
    x = torch.as_tensor(x_np, dtype=torch.float32)

    optimizer1.zero_grad()
    optimizer2.zero_grad()
    C = torch.reshape(model1(x), (pixelsize,pixelsize))
    D = torch.reshape(model2(x), (pixelsize,pixelsize))
    B = torch.kron(C, D)
    bx = torch.linalg.solve(B, x) # Solve B(B_inv x)=x for 'B_inv x'

    # sum(sqrt(b^2 + eps)) is a smooth surrogate for the L1 norm of B^{-1} x; the
    # penalties push C and D towards orthogonality (C C^T = I, D D^T = I).
    loss_sparse = torch.sum(torch.sqrt(bx**2 + epsilon))
    loss_orth_C = (10e5)*torch.mean((C @ C.T - identity)**2)
    loss_orth_D = (10e5)*torch.mean((D @ D.T - identity)**2)
    loss = loss_sparse + loss_orth_C + loss_orth_D

    loss.backward()
    optimizer1.step()
    optimizer2.step()
    loss = loss.detach()

    if i%50 == 0:
        # NumPy copies for reporting; C, D and B themselves stay torch tensors.
        C_np = C.detach().numpy()
        D_np = D.detach().numpy()
        B_np = np.kron(C_np, D_np)
        # (C kron D)^{-1} = C^{-1} kron D^{-1}: two small inverses instead of one large one.
        B_inv = np.kron(np.linalg.inv(C_np), np.linalg.inv(D_np))

        threshold = 1e-15
        sparsity1 = _count_near_zero(x_np, threshold)
        sparsity2 = _count_near_zero(B_inv @ x_np, threshold)
        print('Sparsity before:', sparsity1)
        print('Sparsity after:', sparsity2)

        files.append(_save_snapshot(i, x_np, B_np, B_inv, sparsity1, sparsity2))
        print(i, datetime.now() - time, loss_sparse.item(), loss_orth_C.item(), loss_orth_D.item())

if files:
    save_gif("nn.gif", files, fps=20, loop=0)

# B from the last iteration (computed for the last image, before its optimizer step).
B = B.detach().numpy().reshape(INPUT,INPUT)
B_inv = np.linalg.inv(B)
print('error',np.mean((B @ (B_inv @ x_np) - x_np)**2))
np.savetxt("B.csv", B, delimiter=",")
# The sparse representation is B^{-1} x, matching the per-iteration report above.
print('Sparsity before:', _count_near_zero(x_np, 1e-10))
print('Sparsity after:', _count_near_zero(B_inv @ x_np, 1e-10))

# Obtaining samples for GPSR

In [None]:
def _save_sample(index):
    """Rescale images[index] to 8-bit grayscale and write it to img<index>.png."""
    img_arr = images[index].reshape(pixelsize,pixelsize)
    span = img_arr.max() - img_arr.min()
    # A constant image would divide by zero; export it as all black instead.
    scaled = (img_arr - img_arr.min()) / span if span > 0 else np.zeros_like(img_arr, dtype=float)
    # Scaling by 255.9 before truncation lets the maximum land on 255 while
    # keeping the 256 output levels roughly equally wide.
    img_u8 = (scaled * 255.9).astype(np.uint8)
    Image.fromarray(img_u8).save(f"img{index}.png")


#Obtain 5 training data samples (indices below im_number were used for training)
for _ in range(5):
    _save_sample(np.random.randint(im_number))
#Obtain 5 test data samples (indices from im_number on were never trained on)
if len(images) > im_number:
    for _ in range(5):
        _save_sample(np.random.randint(im_number, len(images)))
else:
    print('No held-out images available for test samples')