In [0]:
## Importing necessary libraries
import torch as tor
import torch.nn as nn
import torch.utils.data
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import os
import math

from sklearn.datasets import make_sparse_coded_signal

In [0]:
class listaUnit(nn.Module):

    """
        One unrolled LISTA iteration.

        The unit owns no learnable weights; it only stores the shrinkage
        callable and applies it to the affine update built from the tensors
        handed to ``forward``.

        activation : callable invoked as ``activation(C, theta)``
    """

    def __init__(self,activation):

        super().__init__()
        self.activation = activation

    def forward(self,inp,B,S,theta):
        """
            Return ``activation(B + S @ inp, theta)``.

            inp   : current code estimate
            B     : bias term added to the matrix product
            S     : matrix multiplied with ``inp``
            theta : second argument forwarded to the activation
        """

        pre_activation = tor.matmul(S,inp) + B

        return self.activation(pre_activation,theta)

In [0]:
class listanet(nn.Module):

    """
        LISTA network: maps a signal vector to a sparse code by applying an
        initial shrinkage step followed by three unrolled iterations that all
        share the same ``We``, ``S`` and ``theta`` parameters.

        activation   : callable invoked as ``activation(C, theta)``
        n_features   : number of rows expected in the input signal
        n_components : number of rows in the produced code
        device       : device on which the parameters are created
    """

    def __init__(self,activation,n_features,n_components,device):

        super().__init__()

        self.n = n_features
        self.m = n_components
        self.activation = activation

        def small_random(rows,cols):
            # uniform values in [0, 0.001), created on the requested device
            return tor.rand(rows, cols).to(device) * 0.001

        ## learnable parameters (created in the order We, S, theta)
        self.We = nn.Parameter(small_random(self.m, self.n), requires_grad = True)
        self.S = nn.Parameter(small_random(self.m, self.m), requires_grad = True)
        self.theta = nn.Parameter(small_random(self.m, 1), requires_grad = True)

        ## three parameter-free iteration units
        self.l1 = listaUnit(self.activation)
        self.l2 = listaUnit(self.activation)
        self.l3 = listaUnit(self.activation)

    def reset_parameters(self):
        """Re-initialise ``We`` and ``S`` (Kaiming uniform) and ``theta`` (uniform in [-0.001, 0.001])."""
        for weight in (self.We, self.S):
            nn.init.kaiming_uniform_(weight,a = math.sqrt(5))
        nn.init.uniform_(self.theta,-0.001,0.001)

    def forward(self,x):
        """Return the code obtained after the initial shrinkage and the three units."""

        B = tor.matmul(self.We,x)

        code = self.activation(B,self.theta)
        for unit in (self.l1, self.l2, self.l3):
            code = unit(code,B, self.S, self.theta)

        return code

In [0]:
def soft_threshold(v,theta):

    """
        The soft-thresholding (shrinkage) function.

        Computes ``sign(v) * max(|v| - theta, 0)`` element-wise, with
        ``theta`` broadcast against ``v``.

        The lower bound is applied with ``clamp`` so the function does not
        depend on any module-level device variable and works for tensors on
        any device.

        v     : tensor to shrink
        theta : threshold, broadcastable to the shape of ``v``

        Returns a new tensor; ``v`` is not modified.
    """

    shrunk_magnitude = tor.clamp(tor.abs(v) - theta, min = 0.0)

    return tor.sign(v) * shrunk_magnitude

In [0]:
def batch_soft_threshold(v,theta):

    """
        Perform soft-thresholding on an entire batch.

        ``theta`` is broadcast against the trailing dimensions of ``v``, so
        a single vectorised call produces the same values as thresholding
        every ``v[k]`` separately, without a Python loop over the batch or
        in-place writes into a cloned tensor.

        v     : batched tensor, e.g. of shape (batch_size, i, j)
        theta : threshold broadcastable to one sample of ``v``

        Returns a new tensor with the shape of ``v``.
    """

    return soft_threshold(v,theta)

In [0]:
def getdata(n_features,n_components,n_samples,n_nonzero_coefs,random_state = None,train_size = 0.7):

    """
        Obtain the data of signal, the corresponding dictionary and the sparse codes,
        split into train / validation / test parts along the sample axis.

        n_features, n_components, n_samples, n_nonzero_coefs, random_state :
            forwarded to ``make_sparse_coded_signal``
        train_size : fraction of the samples used for training; the
            validation part gets half of ``int(n * (1 - train_size))`` samples
            and the test part gets everything that is left

        Returns ``Xtrain, Xval, Xtest, Wd, Ztrain, Zval, Ztest``. The X / Z
        splits are torch tensors with one sample per column; ``Wd`` is
        returned as the NumPy array produced by scikit-learn.
    """ 

    # keyword arguments: scikit-learn releases that make these parameters
    # keyword-only reject the positional form
    X,Wd,Z = make_sparse_coded_signal(
        n_samples = n_samples,
        n_components = n_components,
        n_features = n_features,
        n_nonzero_coefs = n_nonzero_coefs,
        random_state = random_state,
    )

    # The slicing below expects one sample per column. Some scikit-learn
    # releases return the arrays with one sample per row instead; transpose
    # in that case. NOTE(review): when n_samples == n_features the layout
    # cannot be detected from the shape and the arrays are used as returned.
    if X.shape[1] != n_samples:
        X,Wd,Z = X.T,Wd.T,Z.T

    total = X.shape[1]
    mid1 = int(total * (train_size))
    mid2 = mid1 + int(total * (1 - train_size))//2

    boundaries = (slice(None,mid1), slice(mid1,mid2), slice(mid2,None))

    Xtrain,Xval,Xtest = (tor.from_numpy(X[:,part]) for part in boundaries)
    Ztrain,Zval,Ztest = (tor.from_numpy(Z[:,part]) for part in boundaries)

    return Xtrain,Xval,Xtest,Wd,Ztrain,Zval,Ztest

In [0]:
class dataset(torch.utils.data.Dataset):

    """
        Dataset class serving (code, signal) pairs for one split.

        The split tensors are read from the module-level variables
        ``Xtrain`` / ``Ztrain``, ``Xval`` / ``Zval`` and ``Xtest`` / ``Ztest``
        at construction time; samples are indexed along dimension 1.

        phase : one of "train", "val" or "test"

        Raises ValueError for any other ``phase`` instead of failing later
        with a missing-attribute error.
    """

    def __init__(self, phase = "train"):

        super().__init__()

        # the globals are looked up only inside the matching branch so an
        # invalid phase is reported before any of them is touched
        if(phase == "train"):
            self.X,self.Z = Xtrain,Ztrain
        elif(phase == "val"):
            self.X,self.Z = Xval,Zval
        elif(phase == "test"):
            self.X,self.Z = Xtest,Ztest
        else:
            raise ValueError(
                "phase must be 'train', 'val' or 'test', got {!r}".format(phase)
            )

        # number of samples = number of columns
        self.data_size = self.X.size(1)

    def __len__(self):
        return self.data_size

    def __getitem__(self,idx):
        """Return ``(Zi, Xi)``: column ``idx`` of Z and of X as float32 column vectors."""

        Zi = self.Z[:,idx].unsqueeze(1).type(tor.FloatTensor)
        Xi = self.X[:,idx].unsqueeze(1).type(tor.FloatTensor)

        return Zi,Xi

### Divide the dataset into train, validation and test sets

In [0]:
# 50000 signals with 100 features each, built from a 100-atom dictionary
# using 60 non-zero coefficients per code
Xtrain,Xval,Xtest,Wd,Ztrain,Zval,Ztest = getdata(
    n_features = 100,
    n_components = 100,
    n_samples = 50000,
    n_nonzero_coefs = 60,
)

### Saving or loading the input data

In [0]:
# "save" writes the splits currently in memory to .npy files in the working
# directory; "load" replaces the in-memory splits with the files on disk.
task = "load"

if task == "save":
    np.save("Xtrain.npy",Xtrain)
    np.save("Xval.npy",Xval)
    np.save("Xtest.npy",Xtest)
    np.save("Wd.npy",Wd)
    np.save("Ztrain.npy",Ztrain)
    np.save("Zval.npy",Zval)
    np.save("Ztest.npy",Ztest)
elif task == "load":
    # the dictionary is not reloaded; its load line is left disabled
    # Wd = np.load("Wd.npy")

    # read each file and wrap it as a torch tensor (X files first, then Z)
    Xtrain,Xval,Xtest = (
        tor.from_numpy(np.load(path)) for path in ("Xtrain.npy","Xval.npy","Xtest.npy")
    )
    Ztrain,Zval,Ztest = (
        tor.from_numpy(np.load(path)) for path in ("Ztrain.npy","Zval.npy","Ztest.npy")
    )

    print("data loaded")

data loaded


### Create the training and validation sets

In [0]:
# wrap the train and validation splits
trainset = dataset(phase = "train")
valset = dataset(phase = "val")

# mini-batches of 128 samples, served in dataset order
trainloader = torch.utils.data.DataLoader(trainset,batch_size = 128)
valloader = torch.utils.data.DataLoader(valset,batch_size = 128)

### Set up the network and device configuration

In [0]:
# use the first CUDA device when one is available, otherwise the CPU
if tor.cuda.is_available():
    device = tor.device("cuda:0")
else:
    device = tor.device("cpu")
print("using: ",device)

# LISTA network for 100-feature signals and 100-component codes
net = listanet(
    activation = batch_soft_threshold,
    n_features = 100,
    n_components = 100,
    device = device,
)

using:  cuda:0


In [0]:
def train(net,epochs,lr,reg,dataloaders,reset,save = False,
          save_path = "/content/drive/My Drive/datasets/model_chk.pth.tar"):

    """
        Method that performs training given a model.

        net         : module to optimise; must expose ``reset_parameters()``
                      when ``reset`` is true
        epochs      : number of passes over the training loader
        lr          : learning rate handed to Adam
        reg         : weight decay handed to Adam
        dataloaders : ``(trainloader, valloader)``; each yields ``(z, x)``
                      batches where ``x`` is the network input and ``z`` the
                      regression target
        reset       : re-initialise the weights before training
        save        : write ``net.state_dict()`` to ``save_path`` after every epoch
        save_path   : checkpoint location used when ``save`` is true

        Returns a list with the mean training loss of every epoch.
    """

    if(reset):
        net.reset_parameters()
        print("/////////////////////// weights reset \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")

    criterion = nn.MSELoss()
    # the optimiser uses the learning rate supplied by the caller
    optimizer = optim.Adam(net.parameters(), lr = lr,weight_decay = reg)

    trainloader,valloader = dataloaders

    # batches are moved to wherever the model's parameters live, so the
    # function does not rely on a module-level device variable
    run_device = next(net.parameters()).device

    epoch_losses = []

    for epoch in range(epochs):

        batch_losses = []

        for batch_idx,(z,x) in enumerate(trainloader):

            z,x = z.to(run_device),x.to(run_device)

            optimizer.zero_grad()

            out = net(x)
            loss = criterion(out,z)

            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        epoch_loss = np.mean(batch_losses)
        # record the per-epoch training loss so the returned history is populated
        epoch_losses.append(epoch_loss)

        print("epoch: ",epoch,"epoch loss: ",epoch_loss)

        # validation pass: evaluation mode, no gradient tracking
        net.eval()

        with tor.no_grad():

            batch_losses = []
            for batch_idx,(z,x) in enumerate(valloader):

                z,x = z.to(run_device),x.to(run_device)

                out = net(x)

                loss = criterion(out,z)

                batch_losses.append(loss.item())

            print("val loss: ",np.mean(batch_losses))

        net.train()

        print("-------------------------------------------------------------------------------------")

        if(save):
            state = net.state_dict()
            tor.save(state,save_path)
            print("******************** saving the model ****************************")
    return epoch_losses

### Loading a pre-existing model

In [0]:
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written on a GPU can still be loaded on a
# CPU-only machine
state = tor.load("/content/drive/My Drive/datasets/model_chk.pth.tar",map_location = device)
net.load_state_dict(state)

<All keys matched successfully>

### Training the model with a given set of hyper-parameters

In [0]:
epochs = 100
lr = 1.38 * 1e-6
# train() takes the weight decay as its fourth positional argument; leaving
# it out shifts every later argument and the call fails.
# NOTE(review): 0.0 (no weight decay) is an assumed value - confirm the intended setting.
reg = 0.0
dataloaders = [trainloader,valloader]
reset = False

# resume from the stored checkpoint, remapped onto the current device
state = tor.load("/content/drive/My Drive/datasets/model_chk.pth.tar",map_location = device)
net.load_state_dict(state)
_ = train(net,epochs,lr,reg,dataloaders,reset,save = True)

### Hyper-parameter tuning

In [0]:
# random search: every round re-initialises the weights and trains with a
# freshly sampled learning rate and weight decay
epochs = 35
rounds = 10

reset = True
dataloaders = [trainloader,valloader]

for count in range(rounds):

    # sample both values log-uniformly: lr in [1e-6, 1e-1), reg in [1e-5, 1e-1)
    lr = 10 ** np.random.uniform(-6,-1)
    reg = 10 ** np.random.uniform(-5,-1)

    print("learning rate: ",lr,"reg: ",reg,"---> ",count+1,"/",rounds)

    _ = train(
        net,
        epochs = epochs,
        lr = lr,
        reg = reg,
        dataloaders = dataloaders,
        reset = reset,
    )