In [None]:
import torch
import numpy as np

from data import X

In [63]:
class torchAA(torch.nn.Module):
    """Archetypal analysis (AA) as a PyTorch module.

    Learns ``C`` (n_col x rank) and ``S`` (rank x n_col) such that
    ``X @ softmax(C) @ softmax(S)`` approximates ``X``. Softmax is applied over
    dim 0, so every column of ``softmax(C)`` and of ``softmax(S)`` is
    non-negative and sums to 1: each archetype (column of ``X @ softmax(C)``)
    is a convex combination of the columns of ``X``, and each column of ``X``
    is reconstructed as a convex combination of the archetypes.

    Args:
        X: 2-D array-like (numpy array, torch tensor, nested sequence) to
            factorise. A detached float64 copy is stored, so later changes to
            the caller's object do not affect the model.
        rank: number of archetypes.

    Raises:
        ValueError: if ``X`` is not 2-D.
    """

    def __init__(self, X, rank):
        super().__init__()

        # Detached float64 copy of X (copy=True guarantees no memory sharing
        # with the caller's array/tensor).
        X_tensor = torch.as_tensor(X).detach().to(torch.float64, copy=True)
        if X_tensor.dim() != 2:
            raise ValueError(
                f"X must be a 2-D matrix, got shape {tuple(X_tensor.shape)}"
            )
        # A non-persistent buffer moves with module.to(device)/.cuda() together
        # with C and S, but is not written into state_dict().
        self.register_buffer("X", X_tensor, persistent=False)
        n_col = X_tensor.shape[1]

        # Softmax over dim 0: every column of the result sums to 1.
        self.softmax = torch.nn.Softmax(dim=0)

        # Shapes:
        #   X  (N x M) @ C (M x D) = XC  (N x D)
        #   XC (N x D) @ S (D x M) = XCS (N x M)
        # nn.Parameter already enables requires_grad.
        self.C = torch.nn.Parameter(torch.rand(n_col, rank))
        self.S = torch.nn.Parameter(torch.rand(rank, n_col))

    def forward(self):
        """Return the float64 residual ``X - X @ softmax(C) @ softmax(S)``.

        The AA objective is the squared Frobenius norm of this residual. The
        intermediate products are also stored on ``self.XC`` and ``self.XCS``
        so they can be inspected after the call.
        """
        # C and S are float32 parameters; cast so the products match X (float64).
        self.XC = torch.matmul(self.X, self.softmax(self.C.double()))
        self.XCS = torch.matmul(self.XC, self.softmax(self.S.double()))
        return self.X - self.XCS

In [56]:
# Frobenius-norm loss used as the training objective
class frobeniusLoss(torch.nn.Module):
    """Frobenius norm of a matrix, or of each matrix in a batch.

    Args:
        squared: if True, return the squared norm ``||input||_F^2`` (the
            least-squares objective ``||X - XCS||^2``). Defaults to False,
            which returns the plain norm ``||input||_F``.
    """

    def __init__(self, squared=False):
        super().__init__()
        self.loss = torch.linalg.matrix_norm
        self.squared = squared

    def forward(self, input):
        """Return the (optionally squared) Frobenius norm over the last two dims.

        ``input`` needs at least 2 dimensions; any leading dimensions are
        treated as batch dimensions and kept in the result.
        """
        if self.squared:
            # Sum of squared magnitudes; avoids differentiating through a sqrt.
            return (input.abs() ** 2).sum(dim=(-2, -1))
        return self.loss(input, ord='fro')

In [None]:
# Training configuration (named constants instead of magic numbers).
SEED = 0
RANK = 10            # number of archetypes
LEARNING_RATE = 0.5
N_EPOCHS = 100

# Seed torch so the random initialisation of C and S is reproducible.
torch.manual_seed(SEED)

# NOTE(review): the model is archetypal analysis, not NMF; the name `nmf` is
# kept in case later cells reference it.
nmf = torchAA(X, RANK)

# NOTE(review): batch_size is not used by the full-batch loop below; kept only
# in case a later cell references it.
batch_size = 10

# Adam updates C and S from their gradients; rationale for choosing Adam:
# https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/
optimizer = torch.optim.Adam(nmf.parameters(), lr=LEARNING_RATE)

# The loss module holds no state, so build it once instead of every epoch.
criterion = frobeniusLoss()
losses = []  # per-epoch loss values, kept for later plotting/inspection

for epoch in range(N_EPOCHS):
    # Clear gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Forward: residual X - X softmax(C) softmax(S)
    output = nmf()

    # Backward: call the module (not .forward) so nn.Module hooks run.
    loss = criterion(output)
    loss.backward()

    # Update C and S
    optimizer.step()

    losses.append(loss.item())
    print(losses[-1])
