<a href="https://colab.research.google.com/github/Sibylse/ZeroShades/blob/master/SARAH_Clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code from the [proximal SGD blogpost](http://pmelchior.net/blog/proximal-matrix-factorization-in-pytorch.html)

From the [alternating minimization pytorch](https://medium.com/@rinabuoy13/explicit-recommender-system-matrix-factorization-in-pytorch-f3779bb55d74) blog:

In [7]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import collections
rng = np.random.default_rng()
cuda = torch.cuda.is_available()
dev = torch.device("cuda") if cuda else torch.device("cpu")
dev

device(type='cuda')

In [8]:
def generateBinaryFactor(n, r, q):
  B=np.zeros((n,r))
  l= int(np.ceil(n*0.01)) #lower bound of uniquely assigned ones per cluster
  t = r*l # end of block-diagonal part
  for s in range(r):
    B[s*l:(s+1)*l,s]=1 #create block for diagonal
    B[t:,s]= rng.binomial(1, q-0.01, n-t)
  return B

Generate a diagonal-dominant middle factor matrix C.

In [9]:
r=3
m=300 
n=200
Y = generateBinaryFactor(m, r, 0.2)
X = generateBinaryFactor(n, r, 0.2)
C = np.round(rng.uniform(0,5,(r,r)),2) 
# Make diagonal-dominant matrix:
# C= np.round(rng.uniform(0,0.5,(r,r)),2) + np.diag(rng.normal(1,0.1,r))
C

array([[1.64, 4.07, 0.47],
       [0.55, 0.23, 4.34],
       [3.33, 4.03, 4.98]])

In [10]:
D = Y@C@X.T + np.round(rng.normal(0,0.1,(m,n)),2)

In [11]:
np.mean((D-Y@C@X.T)**2)

0.009987398333333335

Custom dataset on the basis of [this](https://www.kaggle.com/pinocookie/pytorch-dataset-and-dataloader) code.

In [12]:
class DataMatrix(Dataset):
        
    def __init__(self, D):
        self.data_ = torch.from_numpy(D).double()
        self.m_ = D.shape[0]
        self.n_ = D.shape[1]

    def __len__(self):
        return self.m_*self.n_
    
    def __getitem__(self, index):
        j = int(index/self.n_)
        i = index%self.n_       
        return torch.tensor([j,i]), self.data_[j,i]

In [13]:
int(n*m*2.5/100)

1500

In [14]:
bs = int(n*m*5/100)
train_loader = DataLoader(DataMatrix(D), batch_size=bs, shuffle=True)
test_loader = DataLoader(DataMatrix(Y@C@X.T), batch_size=1000)
I_b=torch.eye(bs)
print(bs, len(train_loader))

3000 20


Models: Matrix Factorization

In [15]:
class GaussianNoise(nn.Module):
    """Gaussian noise regularizer.

    Args:
        sigma (float, optional): relative standard deviation used to generate the
            noise. Relative means that it will be multiplied by the magnitude of
            the value your are adding the noise to. This means that sigma can be
            the same regardless of the scale of the vector.
        is_relative_detach (bool, optional): whether to detach the variable before
            computing the scale of the noise. If `False` then the scale of the noise
            won't be seen as a constant but something to optimize: this will bias the
            network to generate vectors with smaller values.
    """

    def __init__(self, sigma=0.1, is_relative_detach=True):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0).to(dev)

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).float().normal_() * scale
            x = x + sampled_noise
        return x 

In [16]:
class MatrixFactorization(torch.nn.Module):
    
    def __init__(self, m, n, r=20,sigma=0):
        super().__init__()
        self.X = nn.Sequential(collections.OrderedDict([
          ('fact', torch.nn.Embedding(n, r).double()),
          ('noise', GaussianNoise(sigma))
        ]))
        self.Y = nn.Sequential(collections.OrderedDict([
          ('fact', torch.nn.Embedding(m, r).double()),
          ('noise', GaussianNoise(sigma))
        ]))
        self.C = nn.Sequential(collections.OrderedDict([
          ('fact', torch.nn.Embedding(r, r).double()),
          ('noise', GaussianNoise(sigma))
        ]))
        #self.Y = torch.nn.Embedding(m, r).double()
        #self.X = torch.nn.Embedding(n, r).double()
        #self.C = torch.nn.Embedding(r, r).double()
        torch.nn.init.uniform_(self.Y.fact.weight)
        torch.nn.init.uniform_(self.X.fact.weight)
        torch.nn.init.uniform_(self.C.fact.weight)
        
    def forward(self, idx):
      #j and i are torch tensors, denoting a set of indices i and j
      YC =torch.matmul(self.Y(idx[:,0]),self.C.fact.weight) 
      YCX_batch = (YC * self.X(idx[:,1])).sum(1,keepdim=True)
      return YCX_batch.squeeze() #reduces every 1x dimension for tensor

In [17]:
model = MatrixFactorization(n, m, r=2)
model.to(dev)

MatrixFactorization(
  (X): Sequential(
    (fact): Embedding(300, 2)
    (noise): GaussianNoise()
  )
  (Y): Sequential(
    (fact): Embedding(200, 2)
    (noise): GaussianNoise()
  )
  (C): Sequential(
    (fact): Embedding(2, 2)
    (noise): GaussianNoise()
  )
)

Get batch-Lipschitz constants:

In [18]:
def L_X(batch_Y,batch_i):
  with torch.no_grad():
    unique_i, inverse_i = torch.unique(batch_i, return_inverse=True)
    I1hot =I_b[inverse_i].to(dev)
    W = torch.einsum('js,jt,ju->stu', batch_Y , batch_Y, I1hot)
    L = torch.sqrt((W**2).sum([0,1]).max()).item()
    L= L/bs*2
    return max(L,0.001)
def L_C(batch_X,batch_Y):
  with torch.no_grad():
    d = torch.einsum('js,jt->j', batch_Y , batch_X)
    W = torch.matmul(torch.transpose(batch_X,0,1)*d, batch_Y)
    L = torch.sqrt((W**2).sum()).item()
    L = L/bs*2
    return max(L,0.001)
def stepsizeX(model,batch):
  with torch.no_grad():
    batch_Y = model.Y.fact(batch[:,0]).float().to(dev)
    batch_Y = torch.matmul(batch_Y,model.C.fact.weight.float()).to(dev)
    L= L_X(batch_Y,batch[:,1])
    return 1/4.1/L
def stepsizeY(model,batch):
  with torch.no_grad():
    batch_X = model.X.fact(batch[:,1]).float().to(dev)
    batch_X = torch.matmul(batch_X,torch.transpose(model.C.fact.weight.float(),0,1)).to(dev) 
    L= L_X(batch_X,batch[:,0])
    return 1/4.1/L
def stepsizeC(model,batch):
  with torch.no_grad():
    batch_X = model.X.fact(batch[:,1]).float().to(dev)
    batch_Y = model.Y.fact(batch[:,0]).float().to(dev)
    L= L_C(batch_X,batch_Y)
    return 1/4.1/L

Prox operators

In [19]:
def phi(x):
    return 1-torch.abs(1-2*x)
def prox_binary_(x,lambdas,lr,alpha=-1e-8):
  with torch.no_grad():
    idx_up = x>0.5
    idx_down = x<=0.5
    x[idx_up] += 2*lr*lambdas[idx_up]
    x[idx_down] -= 2*lr*lambdas[idx_down]
    x[x>1] = 1 
    x[x<0] = 0 
    lambdas.add_(phi(x)-1,alpha=alpha)

In [20]:
def prox_pos_(X,weight,lr,alpha=0):
  with torch.no_grad():
    X[X<0] = 0

Optimization

In [21]:
import torch
from torch.optim.optimizer import Optimizer,required

class SAGA(Optimizer):
    r"""Implements SAGA TODO.
    """

    def __init__(self, params, prox, L, reg_weight=0.1, saga=False):

        defaults = dict(reg_weight=reg_weight,saga=saga)
        super(Spring, self).__init__(params, defaults)
        self.prox = prox
        self.L=L

    def __setstate__(self, state):
        super(Spring, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('saga', False)
            group.setdefault('reg_weight', 0.1)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        """
        loss = None


        for group in self.param_groups:
            reg_weight = group['reg_weight'] #lambda
            estimator = group['estimator'] #SAGA, SARAH or SGD?

            for p in group['params']:
                if p.grad is None:
                    continue
                # p.grad gradient has same shape like the parameter but 
                # zeros at the positions which are not updated.
                # grad_p has same size as batch times r 
                if estimator == "SAGA": 
                    
                    param_state = self.state[p]
                    batch_idx = p.grad.sum(1) != 0
                    grad_p = p.grad[batch_idx,:]
                    if 'grad_buffer' not in param_state:
                        G = grad_p
                        param_state['grad_buffer'] = torch.clone(p.grad).detach() 
                    else:
                        G = param_state['grad_buffer'][batch_i,:]
                        G.mul_(-1).add_(grad_p, alpha=1) #G = grad_t - grad_t-1
                        param_state['grad_buffer'][batch_i,:] = torch.clone(grad_p).detach() #memory?
                    if 'avg_buffer' not in param_state:
                        grad_p = grad_p.add(buf, alpha=momentum)
                    else:
                        param_state['avg_buffer']
                p.add_(p.grad, alpha=-group['lr'])

        return loss

In [22]:
import torch
from torch.optim.optimizer import Optimizer,required

class SARAH(Optimizer):
    r"""Implements SARAH
    """
    def __init__(self, params, params_prev, lr=required, weight_decay=0):

        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SARAH, self).__init__(params, defaults)
        self.grad_buff = None
        self.params_prev = params_prev

    def __setstate__(self, state):
        super(SARAH, self).__setstate__(state)
    
    def zero_grad(self):
        super(SARAH, self).zero_grad()
        for p in self.params_prev:
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
        """
        loss = None

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            for x, prev_x in zip(group['params'],self.params_prev):
                if x.grad is None:
                    continue
                if weight_decay != 0:
                    x.grad.add_(x, alpha=weight_decay)
                # x.grad gradient has same shape like the parameter but 
                # zeros at the positions which are not updated.
                # grad_p has same size as batch times r 
                #param_state = self.state[x]
                if self.grad_buff is None: #do full gradient update
                    self.grad_buff = torch.clone(x.grad).detach() 
                else:
                    self.grad_buff.add_(x.grad, alpha=1)
                    self.grad_buff.add_(prev_x.grad, alpha=-1) #g_t = g_t-1  +  grad_t - grad_t-1
                prev_x.mul_(0).add_(torch.clone(x),alpha=1)  
                x.add_(x.grad, alpha=-group['lr'])

        return loss

In [23]:
n*5

1000

In [24]:
def train(epoch,alpha):
  model.train()
  model_prev.train()
  cum_loss = 0.
  for batch_idx, (data, target) in enumerate(train_loader):
      lr_mean, lambda_mean =0,0
      if cuda:
        data, target = data.cuda(), target.cuda()
      data, target = Variable(data), Variable(target.double())
      
      for group in param_list:
        optimizer = group["optimizer"]
        optParam = optimizer.param_groups[0]
        stepsize = group["step"]
        optimizer.zero_grad()
        output = model(data)
        output_prev = model_prev(data)
        loss = loss_func(output, target)
        loss_prev = loss_func(output_prev, target)
        loss.backward()
        loss_prev.backward()
        #print("grad:",optParam['params'][0].grad)
        #print("grad nonzero:",(optParam['params'][0].grad !=0)*reg_weight*optParam['lr'])
        optParam['lr'] =stepsize(model,data)/2
        lr_mean += optParam['lr']
        optimizer.step()
        if "prox" in group:
            prox = group["prox"]
            #The whole factor matrix is proxed for one batch? 
            #Most of the time, this is ok because not often there is no tupel for one row/column.
            prox(optParam['params'][0].data,group["lambda"],optParam['lr'], alpha=alpha) 
            lambda_mean += torch.mean(group["lambda"])
        cum_loss+=loss.item()/3
        if batch_idx % 2 +1 == 0:
            print('Train Epoch:\t\t\t {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()))
  if epoch % 3 == 0:
    print('==Train Epoch:\t\t\t {} \tLoss: {:.6f}\t lambda: {:.3e}\t lr: {:.3f}'.format(
        epoch, cum_loss/len(train_loader),lambda_mean/2,lr_mean/3))

#
# Train full grad
#
def train_full_grad(epoch,alpha):
  model.train()
  cum_loss = 0.
  lambda_mean, lr_mean=0,0
  #for every factor matrix - optimizer
  for group in param_list: 
      optimizer = group["optimizer"]
      optimizer.grad_buff =None #Sign that we do a full grad update
      optimizer.zero_grad()
      optParam = optimizer.param_groups[0]
      stepsize = group["step"]
      for batch_idx, (data, target) in enumerate(train_loader):
          if cuda:
            data, target = data.cuda(), target.cuda()
          data, target = Variable(data), Variable(target.double())
            
          output = model(data)
          # loss is mean squared error over a batch 
          loss = loss_func(output, target)*bs/n/m
          loss.backward()
          cum_loss+=loss.item()
      # gamma = 1/(2L), L is normalized with bs but this is a full grad update  
      optParam['lr'] =stepsize(model,data)/2 #TODO stepsize is computed for one batch! 
      lr_mean+= optParam['lr']
      optimizer.step()
      if "prox" in group:
          prox = group["prox"]
          prox(optParam['params'][0].data,group["lambda"],optParam['lr'], alpha=alpha) 
          lambda_mean += torch.mean(group["lambda"])
      if batch_idx % 2 +1 == 0:
          print('Train Full Grad Batch:\t {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader),
                100. * batch_idx / len(train_loader), loss.item()))
  if epoch % 3 == 0:
    print('==Train Full Grad Epoch:\t {} \tLoss: {:.6f}\t lambda: {:.3e}\t lr: {:.3f}'.format(
        epoch, cum_loss/3,lambda_mean/2,lr_mean/3))


In [25]:
def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        output = model(data)
        # sum up batch loss
        test_loss += loss_func(output, target).item()

    test_loss /= len(test_loader)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))

Can L be 0? If not, there might be an error! I've seen that L_C is zero

In [26]:
def is_binary(A):
  return ((A<1) *(A>0)).sum().item() == 0

In [None]:
model = MatrixFactorization(m, n, r=r,sigma=0)
model.to(dev)
model_prev = MatrixFactorization(m, n, r=r,sigma=0)
model_prev.to(dev)
optimizerY = SARAH([model.Y.fact.weight],[model_prev.Y.fact.weight], lr=0.1) # learning rate
optimizerX = SARAH([model.X.fact.weight],[model_prev.X.fact.weight], lr=0.1) # learning rate
optimizerC = SARAH([model.C.fact.weight],[model_prev.C.fact.weight], lr=0.1)
lambdasY = torch.zeros_like(model.Y.fact.weight) #TODO is prev_model also proxed?
lambdasX = torch.zeros_like(model.X.fact.weight)
param_list = [{'optimizer': optimizerX, 'step': stepsizeX, 'prox':prox_binary_, 'lambda':lambdasX},
              {'optimizer': optimizerY, 'step': stepsizeY, 'prox':prox_binary_, 'lambda':lambdasY},
              {'optimizer': optimizerC, 'step': stepsizeC, 'prox':prox_pos_, 'lambda':torch.tensor([0.0])}]
loss_func = torch.nn.MSELoss()
epoch = 1
test()
alpha=-1e-8
thresh=0.01
full_grad_prob=0.3
while not is_binary(model.Y.fact.weight.data) or not is_binary(model.X.fact.weight.data):
    full_batch_grad = np.random.binomial(1,full_grad_prob)
    if full_batch_grad:
      train_full_grad(epoch,alpha)
    else:
      train(epoch,alpha)
    if epoch % 6 == 0:
      phiX, phiY = torch.mean(phi(model.X.fact.weight.data)), torch.mean(phi(model.Y.fact.weight.data))
      print('--\t\t\tphi(X):\t {:.3f} \tphi(Y): {:.3f}'.format(phiX,phiY))
      if max(phiX,phiY) < thresh:
        alpha*=2
        thresh/=2
    epoch+=1
    if epoch % 50 == 0:
      test()


Test set: Average loss: 4.4215

==Train Epoch:			 3 	Loss: 2.256917	 lambda: 1.229e-07	 lr: 1.643
==Train Epoch:			 6 	Loss: 0.893256	 lambda: 4.191e-07	 lr: 1.060
--			phi(X):	 0.253 	phi(Y): 0.232
==Train Full Grad Epoch:	 9 	Loss: 0.743900	 lambda: 5.873e-07	 lr: 0.878
==Train Epoch:			 12 	Loss: 0.621190	 lambda: 1.052e-06	 lr: 0.740
--			phi(X):	 0.260 	phi(Y): 0.188
==Train Epoch:			 15 	Loss: 0.534607	 lambda: 1.517e-06	 lr: 0.620
==Train Epoch:			 18 	Loss: 0.481876	 lambda: 1.980e-06	 lr: 0.517
--			phi(X):	 0.261 	phi(Y): 0.198
==Train Epoch:			 21 	Loss: 0.458119	 lambda: 2.295e-06	 lr: 0.516
==Train Full Grad Epoch:	 24 	Loss: 0.437350	 lambda: 2.609e-06	 lr: 0.511
--			phi(X):	 0.263 	phi(Y): 0.207
==Train Epoch:			 27 	Loss: 0.427172	 lambda: 2.922e-06	 lr: 0.473
==Train Epoch:			 30 	Loss: 0.416427	 lambda: 3.234e-06	 lr: 0.479
--			phi(X):	 0.269 	phi(Y): 0.215
==Train Epoch:			 33 	Loss: 0.404738	 lambda: 3.687e-06	 lr: 0.467
==Train Full Grad Epoch:	 36 	Loss: 0.4003

In [None]:
Y[0:10,:]

In [None]:
model.Y.fact.weight[0:10,:]

In [None]:
model.Y.fact.weight.sum(0)

In [None]:
Y.sum(0)

In [None]:
C

In [None]:
model.C.fact.weight

In [None]:
X[0:10,:]

In [None]:
model.X.fact.weight[0:10,:]

In [None]:
optParam = optimizerX.param_groups[0]
print((optParam['params'][0].grad !=0).sum())
gradX = optParam['params'][0].grad
batch_i = gradX.sum(1) != 0
gradX[batch_i,:]

In [None]:
batch = next(iter(train_loader))[0]
#batch = batch.to(dev)
batch

In [None]:
len(np.unique(batch[:,1]))

In [None]:
torch.round(torch.matmul(torch.matmul(model.Y.fact.weight,model.C.fact.weight),torch.transpose(model.X.fact.weight,0,1))*100)/100

In [None]:
Y@C@X.T

In [None]:
D