Torch implementation of NMF - When done should be moved to a .py file to allow easy implementation, or alternatively allow saving of model parameters

In [1]:
import torch
import numpy as np
import pandas as pd

from helpers.data import X
from helpers.callbacks import earlyStop
from helpers.losses import VolLoss

Non negative matrix factorization works by finding matrices W and H such that X = WH

In [2]:
class torchNMF(torch.nn.Module):
    def __init__(self, X, rank):
        super(torchNMF, self).__init__()
        
        #Shape of Matrix for reproduction
        n_row, n_col = X.shape
        self.X = torch.tensor(X)
        
        self.softmax = torch.nn.Softmax(dim = 0)
        self.softplus = torch.nn.Softplus()
        
        #Initialization of Tensors/Matrices a and b with size NxR and RxM
        # W is the basis matrix
        self.W = torch.nn.Parameter(torch.rand(n_row, rank, requires_grad=True))
        # H is the encoding matrix
        self.H = torch.nn.Parameter(torch.rand(rank, n_col, requires_grad=True))
        
    def forward(self):
        
        #Implementation of NMF - F(W, H) = ||X - WH||^2
        self.WH = torch.matmul(self.softmax(self.W), self.softplus(self.H))
        x = self.WH
        
        return self.softmax(self.W), self.softplus(self.H), x

In [3]:
torch.manual_seed(0)

X = torch.tensor(X)

nmf = torchNMF(X, 3)

#optimizer for modifying learning rate, ADAM chosen because of https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/
optimizer = torch.optim.Adam(nmf.parameters(), lr=0.3)

#early stopping
es = earlyStop(patience=10, offset=-0.00000001)

running_loss = []

while (not es.trigger()):
    #zero optimizer gradient
    optimizer.zero_grad()

    #forward
    w_out, h_out, x_out = nmf()
    #backward
    loss = VolLoss(X)
    loss = loss.forward(w_out, h_out, x_out)
    loss.backward()
    
    #Update A and B
    optimizer.step()
    
    running_loss.append(loss.item())
    es.count(loss.item())
    
    #print loss
    print(f"epoch: {len(running_loss)}, Loss: {loss.item()}", end='\r')

  self.X = torch.tensor(X)


epoch: 2932, Loss: 773.42848274616069

In [6]:
W, H = list(nmf.parameters())

W = nmf.softmax(W)
H = nmf.softplus(H)

print(W.size())

print(W[:,2].sum())

W = W.detach().numpy()
H = H.detach().numpy()

rec = np.dot(W, H)

rec = rec.T

rec_frame = pd.DataFrame(rec)
rec_frame.columns = rec_frame.columns.astype(str)


rec_frame.to_parquet("recons_x_nmf_vol.parquet",
                     engine = 'fastparquet')

torch.Size([29, 3])
tensor(1.0000, grad_fn=<SumBackward0>)
