[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Albly/Sparce_image_reconstruction/blob/master/notebooks/learn_lista.ipynb)

In [None]:
from IPython import get_ipython
import os 
import sys
from pathlib import Path

%load_ext autoreload
%autoreload 2

# check if we use colab or local machine
if 'google.colab' in str(get_ipython()):
    IS_COLAB = True
    print('Running on colab')
else:
    IS_COLAB = False
    print('Running on local machine')

if IS_COLAB:
    git_root = !git rev-parse --show-toplevel
    already_in_repo = os.path.exists(git_root[0])

    if not already_in_repo:
        !git clone https://github.com/Albly/Sparce_image_reconstruction $repo_dir

    sys.path.append('Sparce_image_reconstruction')

else:
    os.chdir(Path().absolute().parent)

In [None]:
import torch
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm, trange
from scipy import sparse
import numpy as np

from metrics import *
from utils import *
from recoverers import activations as act
from recoverers.classical import *
from recoverers.lernable import Lista


In [None]:
from sklearn.datasets import fetch_openml
# Load MNIST (one flattened 28x28 image per row of X) from https://www.openml.org/d/554.
# as_frame=False makes fetch_openml return NumPy arrays directly, so no DataFrame
# conversion -- and no blanket exception handler hiding real errors -- is needed.
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)

In [None]:
# Toggle Weights & Biases logging of metrics and sample reconstructions.
USE_WANDB = True

if USE_WANDB:
    try:
        import wandb
    except ImportError:
        # Degrade gracefully instead of aborting the notebook when wandb is absent.
        print('wandb is not installed; disabling W&B logging')
        USE_WANDB = False


def _measure(image, A, snr_db, device):
    """Build one noisy measurement of a flattened image.

    Returns ``(x, y)``: ``x`` is ``image`` as a complex128 tensor on ``device``
    and ``y = A @ (x + n)`` with ``n`` produced by ``get_noise`` at ``snr_db`` dB.
    ``A`` is expected to already be on ``device``.
    """
    x = torch.tensor(image, dtype=torch.complex128, device=device)
    n = get_noise(x.real, SNR_dB=snr_db)
    n = n + 1j * 0  # promote the noise to a complex dtype before adding it to x
    y = A @ (x + n.to(device))
    return x, y


def train_net(Images, model, optimizer, scheduler, A, SNR, *, n_epochs=100,
              n_train=300, n_test=50, test_SNR=None, n_logged=5,
              image_shape=(28, 28), save_path='./Trained_best_params.pt'):
    """Train ``model`` to recover images from noisy measurements ``A @ (x + n)``.

    Each epoch makes one full-batch step: the losses of the first ``n_train``
    rows of ``Images`` are summed and back-propagated together. The next
    ``n_test`` rows are then evaluated without gradients, ``scheduler.step`` is
    called with the summed test loss, and the weights are saved to
    ``save_path`` whenever that loss improves.

    Args:
        Images: 2-D array indexed as ``Images[i, :]``, one flattened image per row.
        model: module mapping a measurement vector to an image estimate.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler whose ``step`` takes a metric
            (e.g. ``ReduceLROnPlateau``).
        A: measurement matrix (torch tensor).
        SNR: noise level in dB for training measurements.
        n_epochs: number of epochs.
        n_train: number of leading rows used for training.
        n_test: number of rows after the training rows used for evaluation.
        test_SNR: noise level in dB for test measurements; defaults to ``SNR``.
        n_logged: number of leading test images tiled into the W&B preview.
        image_shape: shape used to un-flatten an image for the preview.
        save_path: file the best ``state_dict`` is written to.

    Returns:
        ``(model, train_mses, test_mses)``; the lists hold per-epoch summed losses
        as Python floats.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Starting Training using device : ', device)

    model.to(device)
    criterion = MSE
    if test_SNR is None:
        test_SNR = SNR

    # A and its pseudo-inverse (used only for the preview) never change, so
    # transfer / compute them once instead of per image and per epoch.
    A_dev = A.to(device)
    A_pinv = torch.linalg.pinv(A_dev)

    train_mses = []
    test_mses = []
    best_test_mse = float('inf')

    for epoch in range(n_epochs):
        # --- training: one optimizer step on the summed loss of the training split ---
        model.train()
        optimizer.zero_grad()
        current_mse = 0
        for n_img in range(n_train):
            x, y = _measure(Images[n_img, :], A_dev, SNR, device)
            current_mse = current_mse + criterion(x, model(y))
        current_mse.backward()
        optimizer.step()
        train_loss = current_mse.item()
        train_mses.append(train_loss)

        # --- evaluation: no autograd graph, so nothing is kept alive across epochs ---
        model.eval()
        test_mse = 0.0
        noisy_tiles = []
        rec_tiles = []
        with torch.no_grad():
            for n_img in range(n_train, n_train + n_test):
                x, y = _measure(Images[n_img, :], A_dev, test_SNR, device)
                x_hat = model(y)
                test_mse += criterion(x, x_hat).item()
                if n_img < n_train + n_logged:
                    noisy_tiles.append((A_pinv @ y).real.reshape(image_shape).cpu())
                    rec_tiles.append(x_hat.real.reshape(image_shape).cpu())

        test_mses.append(test_mse)
        scheduler.step(test_mse)

        if USE_WANDB:
            # A single log call per epoch keeps all metrics on the same W&B step.
            log = {'Train_sum_MSE': train_loss, 'Test_sum_MSE': test_mse}
            if epoch % 5 == 0 and noisy_tiles:
                caption = 'Epoch = {}'.format(epoch)
                log["Noisy"] = wandb.Image(torch.cat(noisy_tiles, dim=1), caption=caption)
                log["Recovered"] = wandb.Image(torch.cat(rec_tiles, dim=1), caption=caption)
            wandb.log(log)

        if test_mse < best_test_mse:
            best_test_mse = test_mse
            torch.save(model.state_dict(), save_path)

        print('Epoch: {0}. Train Loss : {1}. Test Loss: {2}'.format(epoch, train_loss, test_mse))

    return model, train_mses, test_mses

In [None]:
from recoverers.lernable import Lista

M = 392        # number of measurements passed to get_partial_FFT
N = 784        # signal length: a flattened 28x28 image
SNR = 40       # training noise level in dB
LAYERS = 16    # `layers` argument of Lista
LR = 1.0e+0    # initial Adam learning rate (reduced on plateau by the scheduler)

A = get_partial_FFT(M, N)
lista = Lista(A=A, layers=LAYERS)
optimizer = torch.optim.Adam(lista.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.1, patience=4, min_lr=1.0e-6)

if USE_WANDB:
    import wandb

    # Record every hyperparameter at init so runs can be told apart in the W&B UI.
    wandb.init(project="Lista", config={
        "model_name": "Lista",
        "measurement_type": "Partial DFT",
        "M": M,
        "N": N,
        "SNR": SNR,
        "layers": LAYERS,
        "lr": LR,
    })
    wandb.watch(lista, log=None)
    

In [None]:
# train_net expects (Images, model, optimizer, scheduler, A, SNR); omitting the
# scheduler shifted A and SNR one slot left and raised a TypeError.
model, train_loss, test_loss = train_net(X, lista, optimizer, scheduler, A, SNR)