<a href="https://colab.research.google.com/github/Sibylse/Diss/blob/master/SARAH_Biclustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code from the [proximal SGD blogpost](http://pmelchior.net/blog/proximal-matrix-factorization-in-pytorch.html)

From the [alternating minimization pytorch](https://medium.com/@rinabuoy13/explicit-recommender-system-matrix-factorization-in-pytorch-f3779bb55d74) blog:

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import collections
# NumPy random generator used for the synthetic data below (no seed is given,
# so every run produces different data).
rng = np.random.default_rng()
# Run on the GPU when one is available, otherwise on the CPU.
cuda = torch.cuda.is_available()
dev = torch.device("cuda") if cuda else torch.device("cpu")
dev

device(type='cuda')

In [None]:
def generateBinaryFactor(n, r, q, unique_frac=0.01, generator=None):
    """Generate a random binary ``n x r`` cluster-indicator matrix.

    The first ``r * l`` rows, with ``l = ceil(n * unique_frac)``, form a
    block-diagonal part: each column owns ``l`` rows that belong to that
    column only. Every remaining entry is drawn independently from a
    Bernoulli distribution with success probability ``q - unique_frac``.

    Args:
        n: Number of rows.
        r: Number of columns (clusters).
        q: Density parameter; the random part uses ``q - unique_frac``.
        unique_frac: Fraction of rows reserved per column for the
            block-diagonal part (previously the hard-coded 0.01).
        generator: ``numpy.random.Generator`` to draw from. Defaults to the
            notebook-level ``rng``.

    Returns:
        A float ``numpy.ndarray`` of shape ``(n, r)`` containing only 0 and 1.

    Raises:
        ValueError: If the block-diagonal part does not fit into ``n`` rows
            or ``q - unique_frac`` is not a valid probability.
    """
    if generator is None:
        generator = rng  # notebook-level generator
    l = int(np.ceil(n * unique_frac))  # uniquely assigned ones per cluster
    t = r * l  # end of the block-diagonal part
    if t > n:
        raise ValueError(
            "block-diagonal part needs {} rows but only {} are available".format(t, n))
    p = q - unique_frac
    if not 0 <= p <= 1:
        raise ValueError(
            "q - unique_frac must lie in [0, 1], got {}".format(p))
    B = np.zeros((n, r))
    for s in range(r):
        B[s * l:(s + 1) * l, s] = 1  # diagonal block of column s
        B[t:, s] = generator.binomial(1, p, n - t)  # random remainder
    return B

Generate a uniformly distributed or diagonal-dominant middle factor matrix C.

In [None]:
# Problem size: rank r, data matrix of shape m x n.
r=3
m=300 
n=200
# Ground-truth binary factors: Y is m x r, X is n x r.
Y = generateBinaryFactor(m, r, 0.2)
X = generateBinaryFactor(n, r, 0.2)
# Make uniformly distributed matrix:
C = np.round(rng.uniform(0,5,(r,r)),2) 
# Make diagonal-dominant matrix:
# C= np.round(rng.uniform(0,0.5,(r,r)),2) + np.diag(rng.normal(1,0.1,r))
C

array([[4.21, 4.23, 0.5 ],
       [1.01, 3.82, 2.7 ],
       [4.78, 1.05, 2.86]])

The data matrix $D=YCX^\top+N$ where $N$ is the noise matrix.

In [None]:
# Noisy observations: the product Y C X^T plus Gaussian noise
# (mean 0, standard deviation 0.1, rounded to two decimals).
D = Y@C@X.T + np.round(rng.normal(0,0.1,(m,n)),2)

The loss of the ground truth is:

In [None]:
# Mean squared error between D and the noise-free product, i.e. the mean
# squared noise.
np.mean((D-Y@C@X.T)**2)

0.010031036666666666

Custom dataset class on the basis of [this](https://www.kaggle.com/pinocookie/pytorch-dataset-and-dataloader) code.

In [None]:
class DataMatrix(Dataset):
    """Dataset that exposes every entry of a 2-D matrix as one sample.

    Entries are enumerated in row-major order: flat index ``k`` maps to row
    ``k // n`` and column ``k % n``, where ``n`` is the number of columns.
    """

    def __init__(self, D):
        """Store the NumPy matrix ``D`` as a float64 tensor."""
        self.data_ = torch.from_numpy(D).double()
        self.m_ = D.shape[0]
        self.n_ = D.shape[1]

    def __len__(self):
        """Return the total number of matrix entries."""
        return self.m_ * self.n_

    def __getitem__(self, index):
        """Return ``(tensor([row, col]), value)`` for the flat ``index``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        index = int(index)
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(
                "index out of range for DataMatrix with {} entries".format(size))
        # Integer arithmetic only: float division loses precision for very
        # large indices and mis-rounds negative ones.
        j, i = divmod(index, self.n_)
        return torch.tensor([j, i]), self.data_[j, i]

Initialize the batch loaders. In every epoch, a loader partitions the set of all tuples $(j,i)$ of data-matrix indices ($1\leq j\leq m$, $1\leq i\leq n$) into batches of $\texttt{bs}$ elements each.

In [None]:
# Batch size: 30% of all m*n matrix entries.
bs = int(n*m*30/100) #originally 10%
# Training batches are drawn from the noisy matrix D and reshuffled every epoch.
train_loader = DataLoader(DataMatrix(D), batch_size=bs, shuffle=True)
# Evaluation targets are the entries of the noise-free product Y C X^T.
test_loader = DataLoader(DataMatrix(Y@C@X.T), batch_size=1000)
#I_b=torch.eye(bs)
print(bs, len(train_loader))

18000 4


Models: Matrix Factorization

In [None]:
class MatrixFactorization(torch.nn.Module):
    """Tri-factorization model that predicts entries of ``Y C X^T``.

    ``Y`` has shape ``(m, r)``, ``X`` has shape ``(n, r)`` and ``C`` has shape
    ``(r, r)``. All three are stored as float64 embedding tables and
    initialised uniformly on [0, 1).
    """

    def __init__(self, m, n, r=20):
        """Create the factors for an ``m x n`` data matrix of rank ``r``."""
        super().__init__()
        self.Y = torch.nn.Embedding(m, r).double()
        self.X = torch.nn.Embedding(n, r).double()
        self.C = torch.nn.Embedding(r, r).double()
        torch.nn.init.uniform_(self.Y.weight)
        torch.nn.init.uniform_(self.X.weight)
        torch.nn.init.uniform_(self.C.weight)

    def forward(self, idx):
        """Predict the matrix entries addressed by ``idx``.

        Args:
            idx: Integer tensor of shape ``(batch, 2)``; column 0 holds row
                indices into ``Y``, column 1 holds column indices into ``X``.

        Returns:
            Tensor of shape ``(batch,)`` with the values
            ``Y[j] @ C @ X[i]`` for every index pair ``(j, i)``.
        """
        YC = torch.matmul(self.Y(idx[:, 0]), self.C.weight)
        # Reduce over the rank dimension only. The former
        # ``sum(1, keepdim=True).squeeze()`` also dropped the batch dimension
        # when the batch held a single element.
        return (YC * self.X(idx[:, 1])).sum(1)

In [None]:
# MatrixFactorization takes (rows, columns) = (m, n), matching the m x n data
# matrix D; the arguments were previously passed in swapped order.
model = MatrixFactorization(m, n, r=2)
model.to(dev)

MatrixFactorization(
  (Y): Embedding(200, 2)
  (X): Embedding(300, 2)
  (C): Embedding(2, 2)
)

Get batch-stepsizes by Lipschitz constants of the gradients:

In [None]:
def stepsizeX(model, batch):
    """Return the step size ``1 / (2 * max(L, 0.001))`` for the X update.

    ``L`` is the Lipschitz-constant estimate of the gradient with respect to
    ``X``. Relies on the notebook-level ``n``, ``m``, ``bs`` and the scratch
    buffer ``n_array`` (overwritten in place).

    Args:
        model: Model exposing ``Y.weight`` and ``C.weight``.
        batch: ``(batch, 2)`` index tensor of (row, column) pairs, or ``None``
            for the full-gradient estimate.

    Returns:
        The step size as a Python float in both cases (the full-gradient
        branch used to return a 0-d tensor).
    """
    with torch.no_grad():
        YC = torch.matmul(model.Y.weight, model.C.weight)
        if batch is None:
            # 2 * Frobenius norm of (YC)^T (YC), normalised by n * m.
            gram = torch.matmul(torch.transpose(YC, 0, 1), YC)
            L = (2 * torch.sqrt((gram ** 2).sum()) / n / m).item()
        else:
            # Per column index: sum of squared row norms of YC over the batch
            # entries that fall into that column; the largest sum bounds L.
            n_array.fill_(0)
            y = (YC ** 2).sum(1)[batch[:, 0]]
            n_array.index_add_(0, batch[:, 1], y)
            L = n_array.max().item() / bs * 2
        # The floor on L keeps the step size finite when L is (close to) zero.
        return 1 / 2 / max(L, 0.001)
def stepsizeY(model, batch):
    """Return the step size ``1 / (2 * max(L, 0.001))`` for the Y update.

    ``L`` is the Lipschitz-constant estimate of the gradient with respect to
    ``Y``. Relies on the notebook-level ``n``, ``m``, ``bs`` and the scratch
    buffer ``m_array`` (overwritten in place).

    Args:
        model: Model exposing ``X.weight`` and ``C.weight``.
        batch: ``(batch, 2)`` index tensor of (row, column) pairs, or ``None``
            for the full-gradient estimate.

    Returns:
        The step size as a Python float in both cases (the full-gradient
        branch used to return a 0-d tensor).
    """
    with torch.no_grad():
        XCt = torch.matmul(model.X.weight, torch.transpose(model.C.weight, 0, 1))
        if batch is None:
            # 2 * Frobenius norm of (X C^T)^T (X C^T), normalised by n * m.
            gram = torch.matmul(torch.transpose(XCt, 0, 1), XCt)
            L = (2 * torch.sqrt((gram ** 2).sum()) / n / m).item()
        else:
            # Per row index: sum of squared row norms of X C^T over the batch
            # entries that fall into that row; the largest sum bounds L.
            m_array.fill_(0)
            x = (XCt ** 2).sum(1)[batch[:, 1]]
            m_array.index_add_(0, batch[:, 0], x)
            L = m_array.max().item() / bs * 2
        # The floor on L keeps the step size finite when L is (close to) zero.
        return 1 / 2 / max(L, 0.001)
def stepsizeC(model, batch):
    """Return the step size ``1 / (2 * max(L, 0.001))`` for the C update.

    ``L`` is the Lipschitz-constant estimate of the gradient with respect to
    ``C``. Relies on the notebook-level ``n``, ``m`` and ``bs``.

    Args:
        model: Model exposing ``X.weight`` and ``Y.weight``.
        batch: ``(batch, 2)`` index tensor of (row, column) pairs, or ``None``
            for the full-gradient estimate.

    Returns:
        The step size as a Python float in both cases (the full-gradient
        branch used to return a 0-d tensor).
    """
    with torch.no_grad():
        if batch is None:
            # 2 * ||X||_F^2 * ||Y||_F^2, normalised by n * m.
            L = (2 * (model.X.weight ** 2).sum()
                 * (model.Y.weight ** 2).sum() / n / m).item()
        else:
            # Sum over the batch of ||X[i]||^2 * ||Y[j]||^2, normalised by bs.
            x = (model.X.weight ** 2).sum(1)[batch[:, 1]]
            y = (model.Y.weight ** 2).sum(1)[batch[:, 0]]
            L = torch.dot(x, y).item()
            L = L / bs * 2
        # The floor on L keeps the step size finite when L is (close to) zero.
        return 1 / 2 / max(L, 0.001)

Prox operators

In [None]:
def phi(x):
    """Element-wise tent function ``1 - |1 - 2x|``.

    It is 0 at ``x = 0`` and ``x = 1`` and reaches its maximum 1 at ``x = 0.5``.
    """
    distance_from_half = (1 - 2 * x).abs()
    return 1 - distance_from_half
def prox_binary_(x, lambdas, lr, alpha=-1e-8):
    """In-place proximal step that pushes the entries of ``x`` towards {0, 1}.

    Entries above 0.5 are increased and all others decreased by
    ``2 * lr * lambdas`` (element-wise); the result is then clipped to
    ``[0, 1]``. Finally ``lambdas`` is updated in place by adding
    ``alpha * (phi(x) - 1)``.

    Args:
        x: Tensor to modify in place.
        lambdas: Element-wise weights with the same shape as ``x``; modified
            in place.
        lr: Step size of the preceding gradient step.
        alpha: Scaling of the ``lambdas`` update.
    """
    with torch.no_grad():
        shift = 2 * lr * lambdas
        # The direction is decided on the values before the shift is applied.
        x.add_(torch.where(x > 0.5, shift, -shift))
        x.masked_fill_(x > 1, 1)
        x.masked_fill_(x < 0, 0)
        lambdas.add_(phi(x) - 1, alpha=alpha)

In [None]:
def prox_pos_(X,weight,lr,alpha=0):
    """Project ``X`` in place onto the non-negative orthant.

    Every negative entry is set to 0. ``weight``, ``lr`` and ``alpha`` are
    accepted so the call signature matches ``prox_binary_``; they are unused.
    """
    with torch.no_grad():
        X.masked_fill_(X < 0, 0)

Optimization

In [None]:
import torch
from torch.optim.optimizer import Optimizer,required

class SARAH(Optimizer):
    r"""Implements SARAH (stochastic recursive gradient) updates.

    The optimizer keeps a running gradient estimate ``grad_buff``:

    * ``grad_buff is None`` marks a restart: the next ``step`` uses the
      gradient currently stored in ``param.grad`` (callers accumulate the
      full gradient there and reset ``grad_buff`` to ``None`` beforehand).
    * Otherwise ``v_t = v_{t-1} + grad(w_t) - grad(w_{t-1})``, where
      ``grad(w_{t-1})`` is read from ``.grad`` of the matching tensor in
      ``params_prev``.

    The parameter is then moved along ``-lr * v_t`` and its pre-step value is
    copied into the matching ``params_prev`` tensor.

    NOTE: a single ``grad_buff`` tensor is kept per optimizer and
    ``params_prev`` is zipped against every parameter group, so an instance is
    expected to manage exactly one parameter tensor.

    Args:
        params: Parameters to optimize.
        params_prev: Tensors holding the previous iterate, in the same order
            as ``params``; their ``.grad`` must be populated by the caller
            before every recursive step.
        lr: Learning rate.
        weight_decay: L2 penalty coefficient added to ``param.grad``.
    """

    def __init__(self, params, params_prev, lr=required, weight_decay=0):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.grad_buff = None
        # Materialise so that a generator argument survives repeated iteration.
        self.params_prev = list(params_prev)

    def zero_grad(self, *args, **kwargs):
        """Clear the gradients of the parameters and of ``params_prev``.

        Extra arguments (e.g. ``set_to_none``) are forwarded to
        ``Optimizer.zero_grad``.
        """
        super().zero_grad(*args, **kwargs)
        for p in self.params_prev:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: If a recursive step is requested but the matching
                previous-iterate tensor has no gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for x, prev_x in zip(group['params'], self.params_prev):
                if x.grad is None:
                    continue
                if weight_decay != 0:
                    x.grad.add_(x, alpha=weight_decay)
                if self.grad_buff is None:
                    # Restart: take the gradient in x.grad as the estimate.
                    self.grad_buff = x.grad.detach().clone()
                else:
                    if prev_x.grad is None:
                        raise RuntimeError(
                            "SARAH recursive step needs the gradient at the "
                            "previous iterate; run backward() on the "
                            "previous-iterate model first")
                    # v_t = v_{t-1} + grad(w_t) - grad(w_{t-1})
                    self.grad_buff.add_(x.grad).sub_(prev_x.grad)
                # Remember w_t before it is overwritten. copy_ also replaces
                # inf/NaN entries, which the former mul_(0) turned into NaN.
                prev_x.copy_(x)
                # Step along the recursive estimate. Stepping along x.grad
                # reduced the method to plain SGD and left grad_buff unused.
                x.add_(self.grad_buff, alpha=-group['lr'])

        return loss

In [None]:
def train(epoch, alpha, log_interval=None):
    """Run one epoch of stochastic (mini-batch) updates.

    For every batch and every entry of the notebook-level ``param_list`` this
    computes the batch loss of ``model`` and of ``model_prev``, back-propagates
    both, sets the learning rate from the group's step-size function, takes an
    optimizer step and applies the group's prox operator. Relies on the
    notebook-level ``model``, ``model_prev``, ``train_loader``, ``param_list``,
    ``loss_func`` and ``dev``.

    Args:
        epoch: Epoch counter, used for logging only.
        alpha: Passed through to the prox operators.
        log_interval: If set, print the batch loss every ``log_interval``
            batches. The default ``None`` prints no per-batch lines; the
            former condition ``batch_idx % 2 + 1 == 0`` could never be true.
    """
    model.train()
    model_prev.train()
    n_groups = len(param_list)
    n_batches = len(train_loader)
    cum_loss = 0.
    # Defined up front so the epoch summary also works for an empty loader.
    lr_mean, lambda_mean = 0, 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # Reset per batch: the epoch summary reports the last batch's values.
        lr_mean, lambda_mean = 0, 0
        data, target = data.to(dev), target.to(dev).double()

        for group in param_list:
            optimizer = group["optimizer"]
            optParam = optimizer.param_groups[0]
            stepsize = group["step"]
            optimizer.zero_grad()
            output = model(data)
            output_prev = model_prev(data)
            loss = loss_func(output, target)
            loss_prev = loss_func(output_prev, target)
            loss.backward()
            # Gradient at the previous iterate, read by the optimizer.
            loss_prev.backward()
            optParam['lr'] = stepsize(model, data)
            lr_mean += optParam['lr']
            optimizer.step()
            if "prox" in group:
                prox = group["prox"]
                # The whole factor matrix is proxed, not only the rows that
                # occur in the current batch.
                prox(optParam['params'][0].data, group["lambda"],
                     optParam['lr'], alpha=alpha)
                lambda_mean += torch.mean(group["lambda"])
            cum_loss += loss.item() / n_groups

        if log_interval and (batch_idx + 1) % log_interval == 0:
            print('Train Epoch:\t\t\t {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / n_batches, loss.item()))
    if epoch % 3 == 0:
        # lambda_mean is divided by 2: in this notebook's param_list only the
        # two binary-prox groups carry a non-zero lambda.
        print('==Train Epoch:\t\t\t {} \tLoss: {:.6f}\t lambda: {:.3e}\t lr: {:.3f}'.format(
            epoch, cum_loss / max(n_batches, 1), lambda_mean / 2,
            lr_mean / max(n_groups, 1)))

#
# Train full grad
#
def train_full_grad(epoch, alpha, log_interval=None):
    """Run one full-gradient update per entry of ``param_list``.

    For each group the optimizer's ``grad_buff`` is reset to ``None`` (the
    signal for a full-gradient step), the gradient of the summed squared
    error divided by ``n * m`` is accumulated over all batches, and one
    optimizer step followed by the group's prox operator is applied. Relies on
    the notebook-level ``model``, ``train_loader``, ``param_list``,
    ``loss_func_full``, ``n``, ``m`` and ``dev``.

    Args:
        epoch: Epoch counter, used for logging only.
        alpha: Passed through to the prox operators.
        log_interval: If set, print the batch loss every ``log_interval``
            batches. The default ``None`` prints no per-batch lines; the
            former condition ``batch_idx % 2 + 1 == 0`` could never be true.
    """
    model.train()
    n_groups = len(param_list)
    n_batches = len(train_loader)
    cum_loss = 0.
    lambda_mean, lr_mean = 0, 0
    # One pass over the data per factor matrix / optimizer.
    for group in param_list:
        optimizer = group["optimizer"]
        optimizer.grad_buff = None  # signals a full-gradient update
        optimizer.zero_grad()
        optParam = optimizer.param_groups[0]
        stepsize = group["step"]
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(dev), target.to(dev).double()

            output = model(data)
            # Summed squared error scaled by 1/(n*m): accumulated over all
            # batches this yields the gradient of the mean squared error.
            loss = loss_func_full(output, target) / n / m
            loss.backward()
            cum_loss += loss.item()
            if log_interval and (batch_idx + 1) % log_interval == 0:
                print('Train Full Grad Batch:\t {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / n_batches, loss.item()))
        # gamma = 1/(2L), with L computed for the full gradient (batch=None).
        optParam['lr'] = stepsize(model, None)
        lr_mean += optParam['lr']
        optimizer.step()
        if "prox" in group:
            prox = group["prox"]
            prox(optParam['params'][0].data, group["lambda"],
                 optParam['lr'], alpha=alpha)
            lambda_mean += torch.mean(group["lambda"])
    if epoch % 3 == 0:
        # lambda_mean is divided by 2: in this notebook's param_list only the
        # two binary-prox groups carry a non-zero lambda.
        print('==Train Full Grad Epoch:\t {} \tLoss: {:.6f}\t lambda: {:.3e}\t lr: {:.3f}'.format(
            epoch, cum_loss / max(n_groups, 1), lambda_mean / 2,
            lr_mean / max(n_groups, 1)))


In [None]:
def test():
    """Print the average per-batch loss of ``model`` on ``test_loader``.

    Relies on the notebook-level ``model``, ``test_loader``, ``loss_func`` and
    ``dev``.
    """
    model.eval()
    test_loss = 0
    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(dev), target.to(dev)
            output = model(data)
            # sum up batch loss
            test_loss += loss_func(output, target).item()

    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))

Can the Lipschitz constant $L$ be 0? If it cannot, there may be a bug: I have observed $L_C = 0$.

In [None]:
def is_binary(A):
    """Return True when no entry of ``A`` lies strictly between 0 and 1.

    Only the open interval (0, 1) is checked: entries outside ``[0, 1]`` do
    not make the result False.
    """
    strictly_inside = (A > 0) & (A < 1)
    return not bool(strictly_inside.any())

In [None]:
# Scratch buffers with one slot per column (n) / per row (m) of D.
# stepsizeX and stepsizeY zero and refill them in place on every batch call.
n_array = torch.zeros(n).double().to(dev)
m_array = torch.zeros(m).double().to(dev)


In [None]:
# Current iterate, plus a second model that holds the previous iterate whose
# gradients the SARAH optimizers read.
model = MatrixFactorization(m, n, r=r)
model.to(dev)
model_prev = MatrixFactorization(m, n, r=r)
model_prev.to(dev)
# One optimizer per factor matrix; the lr given here is overwritten before
# every step by the group's step-size function.
optimizerY = SARAH([model.Y.weight],[model_prev.Y.weight], lr=0.1) # learning rate
optimizerX = SARAH([model.X.weight],[model_prev.X.weight], lr=0.1) # learning rate
optimizerC = SARAH([model.C.weight],[model_prev.C.weight], lr=0.1)
# Element-wise weights that prox_binary_ reads and updates in place.
lambdasY = torch.zeros_like(model.Y.weight) #TODO is prev_model also proxed?
lambdasX = torch.zeros_like(model.X.weight)
# Per-factor configuration consumed by train() and train_full_grad():
# optimizer, step-size function, prox operator and the prox's lambda tensor
# (a dummy zero for C, whose prox ignores it).
param_list = [{'optimizer': optimizerX, 'step': stepsizeX, 'prox':prox_binary_, 'lambda':lambdasX},
              {'optimizer': optimizerY, 'step': stepsizeY, 'prox':prox_binary_, 'lambda':lambdasY},
              {'optimizer': optimizerC, 'step': stepsizeC, 'prox':prox_pos_, 'lambda':torch.tensor([0.0])}]
# Mean loss for stochastic steps and evaluation; summed loss for the full gradient.
loss_func = torch.nn.MSELoss()
loss_func_full = torch.nn.MSELoss(reduction='sum')
epoch = 1
test()
alpha=-1e-8
#alpha = -1/n/m/len(train_loader)/10
thresh=0.1
# Probability that an epoch performs a full-gradient update.
full_grad_prob=0.3
# Iterate until neither X nor Y has an entry strictly between 0 and 1.
# NOTE(review): there is no cap on the number of epochs.
#while not is_binary(model.Y.fact.weight.data) or not is_binary(model.X.fact.weight.data):
while not is_binary(model.Y.weight.data) or not is_binary(model.X.weight.data):
    # Bernoulli draw: 1 selects a full-gradient epoch, 0 a stochastic epoch.
    full_batch_grad = np.random.binomial(1,full_grad_prob)
    if full_batch_grad:
      train_full_grad(epoch,alpha)
    else:
      train(epoch,alpha)
    if epoch % 6 == 0:
      #phiX, phiY = torch.mean(phi(model.X.fact.weight.data)), torch.mean(phi(model.Y.fact.weight.data))
      phiX, phiY = torch.mean(phi(model.X.weight.data)), torch.mean(phi(model.Y.weight.data))
      print('--\t\t\tphi(X):\t {:.3f} \tphi(Y): {:.3f}'.format(phiX,phiY))
      # Once both factors are this close to binary, double alpha and halve
      # the threshold for the next doubling.
      if max(phiX,phiY) < thresh:
        alpha*=2
        thresh/=2
    epoch+=1
    if epoch % 50 == 0:
      test()


Test set: Average loss: 4.7010

==Train Epoch:			 3 	Loss: 0.642183	 lambda: 7.969e-08	 lr: 10.663
==Train Epoch:			 6 	Loss: 0.521559	 lambda: 1.447e-07	 lr: 8.140
--			phi(X):	 0.272 	phi(Y): 0.274
==Train Epoch:			 9 	Loss: 0.467683	 lambda: 2.317e-07	 lr: 7.031
==Train Epoch:			 12 	Loss: 0.444347	 lambda: 2.752e-07	 lr: 7.029
--			phi(X):	 0.277 	phi(Y): 0.273
==Train Epoch:			 15 	Loss: 0.448108	 lambda: 3.620e-07	 lr: 6.492
==Train Full Grad Epoch:	 18 	Loss: 0.452143	 lambda: 4.269e-07	 lr: 4.167
--			phi(X):	 0.283 	phi(Y): 0.277
==Train Epoch:			 21 	Loss: 0.428145	 lambda: 4.917e-07	 lr: 6.701
==Train Epoch:			 24 	Loss: 0.422547	 lambda: 5.564e-07	 lr: 6.409
--			phi(X):	 0.284 	phi(Y): 0.279
==Train Epoch:			 27 	Loss: 0.421124	 lambda: 6.426e-07	 lr: 5.883
==Train Full Grad Epoch:	 30 	Loss: 0.421714	 lambda: 6.857e-07	 lr: 3.788
--			phi(X):	 0.284 	phi(Y): 0.279
==Train Epoch:			 33 	Loss: 0.416888	 lambda: 7.504e-07	 lr: 5.998
==Train Full Grad Epoch:	 36 	Loss: 0.415

In [None]:
# Scratch cell: draw the index pairs of one training batch and look up the
# matching rows of the learned Y factor.
batch = next(iter(train_loader))[0]
# NOTE(review): unconditional .cuda() fails on a CPU-only machine; the rest of
# the notebook selects the device through `dev`.
batch = batch.cuda()
batch_Y = model.Y(batch[:,0]).float().to(dev)
batch_i = batch[:,1]
unique_i, inverse_i, count_i = torch.unique(batch_i, return_inverse=True, return_counts=True)
# NOTE(review): I_b is only defined in a commented-out line
# (`#I_b=torch.eye(bs)`), so this raises NameError unless it was defined in an
# earlier session - confirm.
I1hot =I_b[inverse_i].to(dev)
# W[s,t,u] = sum_j batch_Y[j,s] * batch_Y[j,t] * I1hot[j,u]
W = torch.einsum('js,jt,ju->stu', batch_Y.float() , batch_Y.float(), I1hot)

In [None]:
# Rebinds n_array as a float32 buffer (stepsizeX uses this same global name).
n_array = torch.zeros(n).float().to(dev)
# Per column index: sum of the squared norms of the batch's Y rows.
n_array.index_add_(0,batch_i,(batch_Y**2).sum(1))

tensor([54., 50., 51., 52., 45., 43., 58., 38., 42., 64., 49., 42., 50., 46.,
        35., 61., 50., 40., 46., 46., 57., 62., 51., 52., 51., 47., 61., 54.,
        69., 45., 48., 49., 51., 39., 47., 63., 54., 50., 52., 45., 47., 48.,
        56., 48., 49., 57., 49., 43., 55., 48., 49., 57., 55., 55., 34., 58.,
        53., 45., 64., 50., 44., 38., 53., 54., 53., 47., 59., 45., 45., 38.,
        43., 43., 32., 48., 50., 52., 43., 49., 59., 64., 48., 54., 58., 54.,
        46., 67., 56., 53., 51., 61., 46., 61., 50., 56., 52., 62., 41., 56.,
        48., 51., 51., 49., 50., 46., 56., 52., 45., 61., 37., 49., 39., 53.,
        48., 53., 51., 47., 54., 56., 60., 46., 47., 49., 48., 48., 52., 43.,
        42., 45., 49., 45., 46., 50., 47., 42., 50., 53., 57., 48., 53., 51.,
        44., 43., 52., 42., 37., 36., 50., 64., 43., 30., 43., 39., 61., 46.,
        45., 43., 41., 49., 46., 56., 47., 42., 41., 62., 43., 56., 55., 51.,
        46., 43., 48., 44., 48., 38., 51., 48., 46., 62., 51., 5

In [None]:
# Largest per-column sum accumulated above.
n_array.max().item()

69.0

In [None]:
# Second component of every index pair in the batch (the column index i
# produced by DataMatrix.__getitem__).
batch[:,1]

tensor([39, 38, 62,  ..., 67, 95, 32], device='cuda:0')

In [None]:
# First ten rows of the ground-truth factor Y.
Y[0:10,:]

array([[1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 0.]])

In [None]:
# First ten rows of the learned factor Y, for comparison with the
# ground-truth rows shown before.
model.Y.weight[0:10,:]

tensor([[0., 0., 1.],
        [0., 0., 1.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 0.]], device='cuda:0', dtype=torch.float64,
       grad_fn=<SliceBackward>)

In [None]:
# Column sums of the learned factor Y.
model.Y.weight.sum(0)

tensor([50., 59., 54.], device='cuda:0', dtype=torch.float64,
       grad_fn=<SumBackward1>)

In [None]:
# Column sums of the ground-truth factor Y.
Y.sum(0)

array([54., 50., 59.])

In [None]:
# Ground-truth middle factor C.
C

array([[4.21, 4.23, 0.5 ],
       [1.01, 3.82, 2.7 ],
       [4.78, 1.05, 2.86]])

In [None]:
# Learned middle factor C.
model.C.weight

Parameter containing:
tensor([[1.0055, 3.8218, 2.7007],
        [4.7808, 1.0474, 2.8617],
        [4.2134, 4.2298, 0.4993]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)

In [None]:
# First ten rows of the ground-truth factor X.
X[0:10,:]

array([[1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 1.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 1.]])

In [None]:
# First ten rows of the learned factor X. The embedding holds its matrix in
# `.weight`; MatrixFactorization defines no `.fact` attribute.
model.X.weight[0:10,:]

ModuleAttributeError: ignored

In [None]:
# Inspect the gradient currently stored on the X factor.
optParam = optimizerX.param_groups[0]
# Number of non-zero gradient entries.
print((optParam['params'][0].grad !=0).sum())
gradX = optParam['params'][0].grad
# Boolean mask of the rows whose gradient entries sum to a non-zero value
# (this rebinds batch_i from the earlier scratch cell).
batch_i = gradX.sum(1) != 0
gradX[batch_i,:]

In [None]:
# Draw the index pairs of one (shuffled) training batch; left on the CPU.
batch = next(iter(train_loader))[0]
#batch = batch.to(dev)
batch

In [None]:
# Number of distinct column indices that occur in the batch.
len(np.unique(batch[:,1]))

In [None]:
# Learned reconstruction Y C X^T, rounded to two decimals. The embeddings hold
# their matrices in `.weight`; MatrixFactorization defines no `.fact` attribute.
torch.round(torch.matmul(torch.matmul(model.Y.weight,model.C.weight),torch.transpose(model.X.weight,0,1))*100)/100

In [None]:
# Noise-free ground-truth product, for comparison with the reconstruction.
Y@C@X.T

In [None]:
# Noisy data matrix the model was trained on.
D